use crate::codec::{CodecBuilder, Direction, kafka::KafkaCodecBuilder};
use crate::connection::SinkConnection;
use crate::frame::kafka::{KafkaFrame, RequestBody, ResponseBody};
use crate::frame::{Frame, MessageType};
use crate::message::Messages;
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::{
ChainState, Transform, TransformBuilder, TransformContextBuilder, TransformContextConfig,
};
use crate::transforms::{DownChainProtocol, TransformConfig, UpChainProtocol};
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use tokio::time::timeout;
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct KafkaSinkSingleConfig {
pub destination_port: u16,
pub connect_timeout_ms: u64,
pub read_timeout: Option<u64>,
pub tls: Option<TlsConnectorConfig>,
}
const NAME: &str = "KafkaSinkSingle";
#[typetag::serde(name = "KafkaSinkSingle")]
#[async_trait(?Send)]
impl TransformConfig for KafkaSinkSingleConfig {
async fn get_builder(
&self,
transform_context: TransformContextConfig,
) -> Result<Box<dyn TransformBuilder>> {
let tls = self.tls.as_ref().map(TlsConnector::new).transpose()?;
Ok(Box::new(KafkaSinkSingleBuilder::new(
self.destination_port,
transform_context.chain_name,
self.connect_timeout_ms,
self.read_timeout,
tls,
)))
}
fn up_chain_protocol(&self) -> UpChainProtocol {
UpChainProtocol::MustBeOneOf(vec![MessageType::Kafka])
}
fn down_chain_protocol(&self) -> DownChainProtocol {
DownChainProtocol::Terminating
}
}
struct KafkaSinkSingleBuilder {
address_port: u16,
connect_timeout: Duration,
read_timeout: Option<Duration>,
tls: Option<TlsConnector>,
}
impl KafkaSinkSingleBuilder {
fn new(
address_port: u16,
_chain_name: String,
connect_timeout_ms: u64,
timeout: Option<u64>,
tls: Option<TlsConnector>,
) -> KafkaSinkSingleBuilder {
let receive_timeout = timeout.map(Duration::from_secs);
KafkaSinkSingleBuilder {
address_port,
connect_timeout: Duration::from_millis(connect_timeout_ms),
read_timeout: receive_timeout,
tls,
}
}
}
impl TransformBuilder for KafkaSinkSingleBuilder {
fn build(&self, transform_context: TransformContextBuilder) -> Box<dyn Transform> {
Box::new(KafkaSinkSingle {
connection: None,
address_port: self.address_port,
connect_timeout: self.connect_timeout,
tls: self.tls.clone(),
read_timeout: self.read_timeout,
force_run_chain: transform_context.force_run_chain,
})
}
fn get_name(&self) -> &'static str {
NAME
}
fn is_terminating(&self) -> bool {
true
}
}
struct KafkaSinkSingle {
address_port: u16,
connection: Option<SinkConnection>,
connect_timeout: Duration,
read_timeout: Option<Duration>,
tls: Option<TlsConnector>,
force_run_chain: Arc<Notify>,
}
#[async_trait]
impl Transform for KafkaSinkSingle {
fn get_name(&self) -> &'static str {
NAME
}
async fn transform<'shorter, 'longer: 'shorter>(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
) -> Result<Messages> {
if self.connection.is_none() {
let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkSingle".to_owned());
let address = (chain_state.local_addr.ip(), self.address_port);
self.connection = Some(
SinkConnection::new(
address,
codec,
&self.tls,
self.connect_timeout,
self.force_run_chain.clone(),
self.read_timeout,
)
.await?,
);
}
let mut responses = vec![];
if chain_state.requests.is_empty() {
self.connection
.as_mut()
.unwrap()
.try_recv_into(&mut responses)?;
} else {
for request in &mut chain_state.requests {
if let Some(Frame::Kafka(KafkaFrame::Request {
body: RequestBody::LeaderAndIsr(leader_and_isr),
..
})) = request.frame()
{
for leader in &mut leader_and_isr.live_leaders {
leader.port = self.address_port as i32;
}
request.invalidate_cache();
}
}
let connection = self.connection.as_mut().unwrap();
let requests_count = chain_state.requests.len();
connection.send(std::mem::take(&mut chain_state.requests))?;
while responses.len() < requests_count {
if let Some(read_timeout) = self.read_timeout {
timeout(read_timeout, connection.recv_into(&mut responses)).await?
} else {
connection.recv_into(&mut responses).await
}?;
}
}
for response in &mut responses {
let port = chain_state.local_addr.port() as i32;
match response.frame() {
Some(Frame::Kafka(KafkaFrame::Response {
body: ResponseBody::FindCoordinator(find_coordinator),
version,
..
})) => {
if *version <= 3 {
find_coordinator.port = port;
} else {
for coordinator in &mut find_coordinator.coordinators {
coordinator.port = port;
}
}
response.invalidate_cache();
}
Some(Frame::Kafka(KafkaFrame::Response {
body: ResponseBody::Metadata(metadata),
..
})) => {
for broker in &mut metadata.brokers {
broker.port = port;
}
response.invalidate_cache();
}
Some(Frame::Kafka(KafkaFrame::Response {
body: ResponseBody::DescribeCluster(describe_cluster),
..
})) => {
for broker in &mut describe_cluster.brokers {
broker.port = port;
}
response.invalidate_cache();
}
_ => {}
}
}
Ok(responses)
}
}