use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use kafka_protocol::{
indexmap::IndexMap,
messages::{
fetch_request::{FetchPartition, FetchTopic},
list_offsets_request::{ListOffsetsPartition, ListOffsetsTopic},
produce_request::PartitionProduceData,
FetchRequest, ListOffsetsRequest, ProduceRequest, ResponseKind,
},
protocol::StrBytes,
records::{Compression, Record, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions, TimestampType},
};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::{CancellationToken, DropGuard};
use crate::clitask::{ClientTask, ClusterMeta, Msg};
use crate::error::ClientError;
use crate::{broker::BrokerResponse, error::ClientResult};
/// Timeout used by [`Client::topic_producer`] when the caller passes `None` for
/// `timeout_ms`. The value is `10 * 1000`, i.e. 10 seconds expressed in milliseconds.
pub const DEFAULT_TIMEOUT: i32 = 10 * 1000;
/// Headers carried by a [`Message`]: an insertion-ordered map from header name to an
/// optional header value.
pub type MessageHeaders = IndexMap<StrBytes, Option<Bytes>>;
/// Handle to a Kafka cluster.
///
/// Creating a client spawns a background `ClientTask`; this handle shares that task's
/// cluster metadata and is cheap to clone (every field is a channel handle, a metadata
/// handle or an `Arc`).
#[derive(Clone)]
pub struct Client {
/// Sender half of the message channel whose receiver is owned by the background task.
/// Not read by any method in this file; it is only held.
_tx: mpsc::Sender<Msg>,
/// Cluster metadata handle cloned from the background task. It exposes the bootstrap
/// signal and the topic -> partition -> broker mapping used to route requests.
cluster: ClusterMeta,
/// Drop guard created from the cancellation token handed to the background task.
/// It is shared across clones via `Arc`, so it is dropped only with the last clone.
// NOTE(review): per tokio-util, dropping a `DropGuard` cancels its token, which is
// presumably what stops the background task — confirm in `ClientTask::run`.
_shutdown: Arc<DropGuard>,
}
impl Client {
/// Create a new client from the given list of seed brokers.
///
/// This builds a `ClientTask` around `seed_list`, keeps a handle to the task's cluster
/// metadata, and spawns the task with `tokio::spawn`. The returned client holds the
/// sending half of the task's message channel and a drop guard for the task's
/// cancellation token.
///
/// # Panics
/// `tokio::spawn` panics when called outside of a Tokio runtime.
pub fn new(seed_list: Vec<String>) -> Self {
    // Bounded channel (capacity 1_000) between client handles and the background task.
    let (msg_tx, msg_rx) = mpsc::channel(1_000);
    let cancel = CancellationToken::new();

    let task = ClientTask::new(seed_list, msg_rx, cancel.clone());
    // Grab the shared metadata handle before the task is moved into the spawned future.
    let cluster = task.cluster.clone();
    tokio::spawn(task.run());

    let guard = Arc::new(cancel.drop_guard());
    Self {
        _tx: msg_tx,
        cluster,
        _shutdown: guard,
    }
}
/// Look up an offset of partition `ptn` of `topic`.
///
/// `pos` selects which offset is requested: `Earliest` is sent as timestamp `-2`,
/// `Latest` as `-1`, and `Timestamp(val)` sends `val` unchanged.
///
/// # Returns
/// The `offset` field of the matching partition entry in the broker's response.
///
/// # Errors
/// - `ClientError::UnknownTopic` / `ClientError::UnknownPartition` when the topic or
///   partition is not present in the loaded cluster metadata.
/// - `ClientError::Other` when the response channel is dropped before a reply arrives.
/// - `ClientError::BrokerError` when the broker connection reports an error.
/// - `ClientError::MalformedResponse` when the reply is not a `ListOffsetsResponse` or
///   does not contain the requested topic/partition.
pub async fn list_offsets(&self, topic: StrBytes, ptn: i32, pos: ListOffsetsPosition) -> ClientResult<i64> {
    // If the cluster has not finished bootstrapping, wait for the bootstrap signal and
    // then reload the metadata. The wait result is intentionally ignored: whatever
    // metadata is loaded afterwards decides the outcome of the lookups below.
    let mut cluster = self.cluster.load();
    if !*cluster.bootstrap.borrow() {
        let mut sig = cluster.bootstrap.clone();
        let _ = sig.wait_for(|val| *val).await;
        cluster = self.cluster.load();
    }

    // `ok_or_else` keeps the error's `String` allocation off the success path.
    let topic_ptns = cluster
        .topics
        .get(&topic)
        .ok_or_else(|| ClientError::UnknownTopic(topic.to_string()))?;
    let broker = topic_ptns
        .get(&ptn)
        .ok_or_else(|| ClientError::UnknownPartition(topic.to_string(), ptn))?;

    // Build a request for exactly one topic and one partition.
    let uid = uuid::Uuid::new_v4();
    let mut req_ptn = ListOffsetsPartition::default();
    req_ptn.partition_index = ptn;
    req_ptn.timestamp = match pos {
        ListOffsetsPosition::Earliest => -2,
        ListOffsetsPosition::Latest => -1,
        ListOffsetsPosition::Timestamp(val) => val,
    };
    let mut req_topic = ListOffsetsTopic::default();
    req_topic.name = topic.clone().into();
    req_topic.partitions.push(req_ptn);
    let mut req = ListOffsetsRequest::default();
    req.topics.push(req_topic);

    // Send the request and wait for the reply on a dedicated oneshot channel.
    let (tx, rx) = oneshot::channel();
    broker.conn.list_offsets(uid, req, tx).await;
    let (_, kind) = rx
        .await
        .map_err(|_| ClientError::Other("response channel dropped by broker, which should never happen".into()))?
        .result
        .map_err(ClientError::BrokerError)?;
    let ResponseKind::ListOffsetsResponse(res) = kind else {
        return Err(ClientError::MalformedResponse);
    };

    // NOTE(review): the partition entry's error code is not inspected here; confirm
    // whether broker-side partition errors are surfaced elsewhere.
    res.topics
        .iter()
        .find(|topic_res| topic_res.name.0 == topic)
        .and_then(|topic_res| topic_res.partitions.iter().find(|ptn_res| ptn_res.partition_index == ptn))
        .map(|ptn_res| ptn_res.offset)
        .ok_or(ClientError::MalformedResponse)
}
/// Fetch records from partition `ptn` of `topic`, starting at offset `start`.
///
/// The request asks for at most `1024^2` bytes overall and per partition, with a
/// `max_wait_ms` of `10_000`.
///
/// # Returns
/// - `Ok(None)` when the matching partition entry in the response carries no records.
/// - `Ok(Some(records))` with the decoded records otherwise.
///
/// # Errors
/// - `ClientError::UnknownTopic` / `ClientError::UnknownPartition` when the topic or
///   partition is not present in the loaded cluster metadata.
/// - `ClientError::Other` when the response channel is dropped before a reply arrives.
/// - `ClientError::BrokerError` when the broker connection reports an error.
/// - `ClientError::MalformedResponse` when the reply is not a `FetchResponse`, does not
///   contain the requested topic/partition, or its record data fails to decode.
pub async fn fetch(&self, topic: StrBytes, ptn: i32, start: i64) -> ClientResult<Option<Vec<Record>>> {
    // If the cluster has not finished bootstrapping, wait for the bootstrap signal and
    // then reload the metadata. The wait result is intentionally ignored: whatever
    // metadata is loaded afterwards decides the outcome of the lookups below.
    let mut cluster = self.cluster.load();
    if !*cluster.bootstrap.borrow() {
        let mut sig = cluster.bootstrap.clone();
        let _ = sig.wait_for(|val| *val).await;
        cluster = self.cluster.load();
    }

    // `ok_or_else` keeps the error's `String` allocation off the success path.
    let topic_ptns = cluster
        .topics
        .get(&topic)
        .ok_or_else(|| ClientError::UnknownTopic(topic.to_string()))?;
    let broker = topic_ptns
        .get(&ptn)
        .ok_or_else(|| ClientError::UnknownPartition(topic.to_string(), ptn))?;

    // Build a request for exactly one topic and one partition.
    let uid = uuid::Uuid::new_v4();
    let max_bytes = 1024i32.pow(2);
    let mut req_ptn = FetchPartition::default();
    req_ptn.partition = ptn;
    req_ptn.partition_max_bytes = max_bytes;
    req_ptn.fetch_offset = start;
    let mut req_topic = FetchTopic::default();
    req_topic.topic = topic.clone().into();
    req_topic.partitions.push(req_ptn);
    let mut req = FetchRequest::default();
    req.max_bytes = max_bytes;
    req.max_wait_ms = 10_000;
    req.topics.push(req_topic);
    tracing::debug!("about to send request: {:?}", req);

    // Send the request and wait for the reply on a dedicated oneshot channel.
    let (tx, rx) = oneshot::channel();
    broker.conn.fetch(uid, req, tx).await;
    let (_, kind) = rx
        .await
        .map_err(|_| ClientError::Other("response channel dropped by broker, which should never happen".into()))?
        .result
        .map_err(ClientError::BrokerError)?;
    tracing::debug!("res: {:?}", kind);
    let ResponseKind::FetchResponse(res) = kind else {
        return Err(ClientError::MalformedResponse);
    };

    // NOTE(review): the partition entry's error code is not inspected here; confirm
    // whether broker-side partition errors are surfaced elsewhere.
    let batch_opt = res
        .responses
        .iter()
        .find(|topic_res| topic_res.topic.0 == topic)
        .and_then(|topic_res| topic_res.partitions.iter().find(|ptn_res| ptn_res.partition_index == ptn))
        .map(|ptn_res| ptn_res.records.clone())
        .ok_or(ClientError::MalformedResponse)?;

    // No record data for this partition: report "nothing fetched" instead of an error.
    let Some(mut batch) = batch_opt else { return Ok(None) };
    let records = RecordBatchDecoder::decode(&mut batch).map_err(|_| ClientError::MalformedResponse)?;
    Ok(Some(records))
}
/// Build a [`TopicProducer`] for `topic`.
///
/// - `acks`: acknowledgement level placed on every produce request.
/// - `timeout_ms`: request timeout; `None` falls back to [`DEFAULT_TIMEOUT`].
/// - `compression`: record batch compression; `None` falls back to `Compression::None`.
///
/// The producer gets its own unbounded response channel, a clone of this client and of
/// its cluster metadata handle, and preallocated encode buffers.
pub fn topic_producer(&self, topic: &str, acks: Acks, timeout_ms: Option<i32>, compression: Option<Compression>) -> TopicProducer {
    let (tx, rx) = mpsc::unbounded_channel();
    let encode_opts = RecordEncodeOptions {
        version: 2,
        compression: compression.unwrap_or(Compression::None),
    };
    let timeout_ms = timeout_ms.unwrap_or(DEFAULT_TIMEOUT);
    let topic = StrBytes::from_string(String::from(topic));

    TopicProducer {
        _client: self.clone(),
        tx,
        rx,
        cluster: self.cluster.clone(),
        topic,
        acks,
        timeout_ms,
        encode_opts,
        buf: BytesMut::with_capacity(1024 * 1024),
        batch_buf: Vec::with_capacity(1024),
        // -1 makes the first `produce` call start its partition search at partition 0.
        last_ptn: -1,
    }
}
}
/// A single message to be produced to a topic.
pub struct Message {
/// Optional message key.
pub key: Option<Bytes>,
/// Optional message value (payload).
pub value: Option<Bytes>,
/// Headers attached to the message.
pub headers: MessageHeaders,
}
impl Message {
pub fn new(key: Option<Bytes>, value: Option<Bytes>, headers: MessageHeaders) -> Self {
Self { key, value, headers }
}
}
/// Acknowledgement level for produce requests.
///
/// The `i16` discriminant of each variant is the value written to the produce request's
/// `acks` field (via `as i16`), so the numeric values must not be changed.
// NOTE(review): variant meanings below follow Kafka's documented `acks` setting — confirm.
#[derive(Clone, Copy)]
#[repr(i16)]
pub enum Acks {
/// Sent as `-1`; presumably waits for acknowledgement from all in-sync replicas.
All = -1,
/// Sent as `0`; presumably requests no acknowledgement.
None = 0,
/// Sent as `1`; presumably waits for acknowledgement from the partition leader only.
Leader = 1,
}
/// Producer bound to a single topic, created by `Client::topic_producer`.
///
/// Each `produce` call sends one batch to one partition, moving on to the next
/// available partition on every call.
pub struct TopicProducer {
/// Clone of the client that created this producer. Not read by the producer itself.
// NOTE(review): presumably held so the client's shutdown guard outlives the producer — confirm.
_client: Client,
/// Sender half of this producer's response channel; a clone is passed along with every
/// produce request so the response can be delivered back here.
tx: mpsc::UnboundedSender<BrokerResponse>,
/// Receiver half of this producer's response channel.
rx: mpsc::UnboundedReceiver<BrokerResponse>,
/// Cluster metadata handle used to find the topic's partitions and their brokers.
cluster: ClusterMeta,
/// Topic this producer writes to.
pub topic: StrBytes,
/// Acknowledgement level written to each produce request.
acks: Acks,
/// Timeout written to each produce request's `timeout_ms` field.
timeout_ms: i32,
/// Options (record version and compression) passed to the record batch encoder.
encode_opts: RecordEncodeOptions,
/// Reusable buffer the record batch is encoded into; the encoded bytes are split off
/// and attached to the request on each call.
buf: BytesMut,
/// Reusable scratch list of records for the batch currently being built.
batch_buf: Vec<Record>,
/// Partition used by the most recent `produce` call; the next call starts searching at
/// `last_ptn + 1` and wraps around to the first partition. Initialised to `-1`.
last_ptn: i32,
}
impl TopicProducer {
pub async fn produce(&mut self, messages: &[Message]) -> ClientResult<(i64, i64)> {
if messages.is_empty() {
return Err(ClientError::ProducerMessagesEmpty);
}
self.batch_buf.clear(); let mut cluster = self.cluster.load();
if !*cluster.bootstrap.borrow() {
let mut sig = cluster.bootstrap.clone();
let _ = sig.wait_for(|val| *val).await; cluster = self.cluster.load();
}
let Some(topic_ptns) = cluster.topics.get(&self.topic) else {
return Err(ClientError::UnknownTopic(self.topic.to_string()));
};
let Some((sticky_ptn, sticky_broker)) = topic_ptns
.range((self.last_ptn + 1)..)
.next()
.or_else(|| topic_ptns.range(..).next())
.map(|(key, val)| (*key, val.clone()))
else {
return Err(ClientError::NoPartitionsAvailable(self.topic.to_string()));
};
self.last_ptn = sticky_ptn;
let timestamp = chrono::Utc::now().timestamp();
for msg in messages.iter() {
self.batch_buf.push(Record {
transactional: false,
control: false,
partition_leader_epoch: 0,
producer_id: 0,
producer_epoch: 0,
timestamp,
timestamp_type: TimestampType::Creation,
offset: 0,
sequence: 0,
key: msg.key.clone(),
value: msg.value.clone(),
headers: msg.headers.clone(),
});
}
let size = self.batch_buf.iter().fold(0usize, |mut acc, record| {
acc += 21; if let Some(key) = record.key.as_ref() {
acc += key.len();
}
if let Some(val) = record.value.as_ref() {
acc += val.len();
}
for (k, v) in record.headers.iter() {
acc += 4 + k.len() + 4 + v.as_ref().map(|v| v.len()).unwrap_or(0);
}
acc
});
self.buf.reserve(size);
let res = RecordBatchEncoder::encode(&mut self.buf, self.batch_buf.iter(), &self.encode_opts).map_err(|err| ClientError::EncodingError(format!("{:?}", err)));
self.batch_buf.clear();
res?;
let mut req = ProduceRequest::default();
req.acks = self.acks as i16;
req.timeout_ms = self.timeout_ms;
let topic = req.topic_data.entry(self.topic.clone().into()).or_default();
let mut ptn_data = PartitionProduceData::default();
ptn_data.index = sticky_ptn;
ptn_data.records = Some(self.buf.split().freeze());
topic.partition_data.push(ptn_data);
let uid = uuid::Uuid::new_v4();
sticky_broker.conn.produce(uid, req, self.tx.clone()).await;
let res = loop {
let Some(res) = self.rx.recv().await else {
unreachable!("both ends of channel are heald, receiving None should not be possible")
};
if res.id == uid {
break res;
}
};
res.result
.map_err(ClientError::BrokerError)
.and_then(|res| {
if let ResponseKind::ProduceResponse(inner) = res.1 {
Ok(inner)
} else {
tracing::error!("expected broker to return a ProduceResponse, got: {:?}", res.1);
Err(ClientError::MalformedResponse)
}
})
.and_then(|res| {
res.responses
.iter()
.find(|topic| topic.0 .0 == self.topic)
.and_then(|val| {
val.1.partition_responses.first().map(|val| {
debug_assert!(messages.len() > 0, "messages len should always be validated at start of function");
let last_offset = val.base_offset + (messages.len() - 1) as i64;
(val.base_offset, last_offset)
})
})
.ok_or(ClientError::MalformedResponse)
})
}
}
/// Which offset `Client::list_offsets` should look up.
pub enum ListOffsetsPosition {
/// Sent to the broker as timestamp `-2`.
Earliest,
/// Sent to the broker as timestamp `-1`.
Latest,
/// Sent to the broker unchanged as the request timestamp.
// NOTE(review): presumably milliseconds since the Unix epoch, per Kafka — confirm.
Timestamp(i64),
}