use std::{sync::Arc, time::Duration};

use bytes::{Bytes, BytesMut};
use kafka_protocol::{
    indexmap::IndexMap,
    messages::{
        create_topics_request::CreateTopicsRequest,
        create_topics_response::CreateTopicsResponse,
        delete_topics_request::DeleteTopicsRequest,
        delete_topics_response::DeleteTopicsResponse,
        fetch_request::{FetchPartition, FetchTopic},
        list_offsets_request::{ListOffsetsPartition, ListOffsetsTopic},
        produce_request::PartitionProduceData,
        FetchRequest, FindCoordinatorRequest, FindCoordinatorResponse, ListOffsetsRequest, MetadataResponse, ProduceRequest, ResponseHeader, ResponseKind,
    },
    protocol::StrBytes,
    records::{Compression, Record, RecordBatchDecoder, RecordBatchEncoder, RecordEncodeOptions, TimestampType},
    ResponseError,
};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::{CancellationToken, DropGuard};

use crate::clitask::{ClientTask, Cluster, ClusterMeta, MetadataPolicy, Msg};
use crate::error::ClientError;
#[cfg(feature = "internal")]
use crate::internal::InternalClient;
use crate::{broker::BrokerResponse, error::ClientResult};

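/// Default timeout in milliseconds (10 seconds) used for produce requests when the caller does
/// not provide an explicit timeout.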
pub const DEFAULT_TIMEOUT: i32 = 10 * 1000;

pub type MessageHeaders = IndexMap<StrBytes, Option<Bytes>>;

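/// A Kafka client.
///
/// The client can be cloned: clones share the same background task and cached cluster metadata,
/// and the background task is signalled to shut down once the last clone is dropped.
///
/// A rough usage sketch; the broker address, topic name and payload below are illustrative only:
///
/// ```ignore
/// let client = Client::new(vec!["localhost:9092".into()]);
/// let mut producer = client.topic_producer("events", Acks::All, None, None);
/// let msg = Message::new(None, Some(Bytes::from("hello")), Default::default());
/// let (first_offset, last_offset) = producer.produce(&[msg]).await?;
/// ```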
#[derive(Clone)]
pub struct Client {
    pub(crate) tx: mpsc::Sender<Msg>,
    pub(crate) cluster: ClusterMeta,
    _shutdown: Arc<DropGuard>,
}

impl Client {
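    /// Construct a new client which will bootstrap its cluster metadata from the given seed list
    /// of brokers.
    ///
    /// The client's management task is spawned onto the ambient Tokio runtime, so this must be
    /// called from within a runtime context.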
    pub fn new(seed_list: Vec<String>) -> Self {
        let (tx, rx) = mpsc::channel(1_000);
        let shutdown = CancellationToken::new();
        let task = ClientTask::new(seed_list, None, MetadataPolicy::default(), rx, false, shutdown.clone());
        let cluster = task.cluster.clone();
        tokio::spawn(task.run());
        Self {
            tx,
            cluster,
            _shutdown: Arc::new(shutdown.drop_guard()),
        }
    }

    #[cfg(feature = "internal")]
    pub fn new_internal(seed_list: Vec<String>, block_list: Vec<i32>) -> InternalClient {
        let (tx, rx) = mpsc::channel(1_000);
        let shutdown = CancellationToken::new();
        let task = ClientTask::new(seed_list, Some(block_list), MetadataPolicy::Manual, rx, true, shutdown.clone());
        let cluster = task.cluster.clone();
        tokio::spawn(task.run());
        let cli = Self {
            tx,
            cluster,
            _shutdown: Arc::new(shutdown.drop_guard()),
        };
        InternalClient::new(cli)
    }

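    /// Fetch cluster metadata from the specified broker.
    ///
    /// Returns an error if the given broker ID is not part of the currently discovered metadata.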
    pub async fn get_metadata(&self, broker_id: i32) -> ClientResult<MetadataResponse> {
        let cluster = self.get_cluster_metadata_cache().await?;
        let broker = cluster
            .brokers
            .get(&broker_id)
            .ok_or(ClientError::Other("broker does not exist in currently discovered metadata".into()))?;

        let uid = uuid::Uuid::new_v4();
        let (tx, rx) = oneshot::channel();
        broker.conn.get_metadata(uid, tx.into(), true).await;

        let res = unpack_broker_response(rx).await.and_then(|(_, res)| {
            if let ResponseKind::MetadataResponse(res) = res {
                Ok(res)
            } else {
                Err(ClientError::MalformedResponse)
            }
        })?;

        Ok(res)
    }

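    /// List offsets for the given topic partition, resolving the offset at the requested
    /// position. The request is sent to the partition's current leader.
    ///
    /// A rough usage sketch; the topic name and partition are illustrative only:
    ///
    /// ```ignore
    /// let latest = client
    ///     .list_offsets(StrBytes::from_string("events".into()), 0, ListOffsetsPosition::Latest)
    ///     .await?;
    /// ```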
    pub async fn list_offsets(&self, topic: StrBytes, ptn: i32, pos: ListOffsetsPosition) -> ClientResult<i64> {
        let cluster = self.get_cluster_metadata_cache().await?;

        let topic_ptns = cluster.topics.get(&topic).ok_or(ClientError::UnknownTopic(topic.to_string()))?;
        let ptn_meta = topic_ptns.get(&ptn).ok_or(ClientError::UnknownPartition(topic.to_string(), ptn))?;
        let broker = ptn_meta.leader.clone().ok_or(ClientError::NoPartitionLeader(topic.to_string(), ptn))?;

        let uid = uuid::Uuid::new_v4();
        let mut req = ListOffsetsRequest::default();
        let mut req_topic = ListOffsetsTopic::default();
        req_topic.name = topic.clone().into();
        let mut req_ptn = ListOffsetsPartition::default();
        req_ptn.partition_index = ptn;
        req_ptn.timestamp = match pos {
            ListOffsetsPosition::Earliest => -2,
            ListOffsetsPosition::Latest => -1,
            ListOffsetsPosition::Timestamp(val) => val,
        };
        req_topic.partitions.push(req_ptn);
        req.topics.push(req_topic);

        let (tx, rx) = oneshot::channel();
        broker.conn.list_offsets(uid, req, tx).await;

        let offset = unpack_broker_response(rx)
            .await
            .and_then(|(_, res)| {
                if let ResponseKind::ListOffsetsResponse(res) = res {
                    Ok(res)
                } else {
                    Err(ClientError::MalformedResponse)
                }
            })
            .and_then(|res| {
                res.topics
                    .iter()
                    .find(|topic_res| topic_res.name.0 == topic)
                    .and_then(|topic_res| topic_res.partitions.iter().find(|ptn_res| ptn_res.partition_index == ptn).map(|ptn_res| ptn_res.offset))
                    .ok_or(ClientError::MalformedResponse)
            })?;

        Ok(offset)
    }

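    /// Fetch a batch of records from the given topic partition, starting at the given offset.
    ///
    /// Returns `Ok(None)` when the response carries no record data. A rough usage sketch; the
    /// topic name, partition and offset are illustrative only:
    ///
    /// ```ignore
    /// if let Some(records) = client.fetch(StrBytes::from_string("events".into()), 0, 0).await? {
    ///     for record in records {
    ///         println!("offset={} value={:?}", record.offset, record.value);
    ///     }
    /// }
    /// ```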
    pub async fn fetch(&self, topic: StrBytes, ptn: i32, start: i64) -> ClientResult<Option<Vec<Record>>> {
        let cluster = self.get_cluster_metadata_cache().await?;

        let topic_ptns = cluster.topics.get(&topic).ok_or(ClientError::UnknownTopic(topic.to_string()))?;
        let ptn_meta = topic_ptns.get(&ptn).ok_or(ClientError::UnknownPartition(topic.to_string(), ptn))?;
        let broker = ptn_meta.leader.clone().ok_or(ClientError::NoPartitionLeader(topic.to_string(), ptn))?;

        let uid = uuid::Uuid::new_v4();
        let mut req = FetchRequest::default();
        req.max_bytes = 1024i32.pow(2);
        req.max_wait_ms = 10_000;
        let mut req_topic = FetchTopic::default();
        req_topic.topic = topic.clone().into();
        let mut req_ptn = FetchPartition::default();
        req_ptn.partition = ptn;
        req_ptn.partition_max_bytes = 1024i32.pow(2);
        req_ptn.fetch_offset = start;
        req_topic.partitions.push(req_ptn);
        req.topics.push(req_topic);
        tracing::debug!("about to send request: {:?}", req);

        let (tx, rx) = oneshot::channel();
        broker.conn.fetch(uid, req, tx).await;

        let batch_opt = unpack_broker_response(rx)
            .await
            .and_then(|(_, res)| {
                tracing::debug!("res: {:?}", res);
                if let ResponseKind::FetchResponse(res) = res {
                    Ok(res)
                } else {
                    Err(ClientError::MalformedResponse)
                }
            })
            .and_then(|res| {
                res.responses
                    .iter()
                    .find(|topic_res| topic_res.topic.0 == topic)
                    .and_then(|topic_res| topic_res.partitions.iter().find(|ptn_res| ptn_res.partition_index == ptn).map(|ptn_res| ptn_res.records.clone()))
                    .ok_or(ClientError::MalformedResponse)
            })?;

        let Some(mut batch) = batch_opt else { return Ok(None) };
        let records = RecordBatchDecoder::decode(&mut batch).map_err(|_| ClientError::MalformedResponse)?;

        Ok(Some(records))
    }

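    /// Get a handle to the cached cluster metadata, waiting up to 10 seconds for the initial
    /// metadata bootstrap to complete if it has not done so already.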
    pub(crate) async fn get_cluster_metadata_cache(&self) -> ClientResult<Arc<Cluster>> {
        let mut cluster = self.cluster.load();
        if !*cluster.bootstrap.borrow() {
            let mut sig = cluster.bootstrap.clone();
            let _ = tokio::time::timeout(Duration::from_secs(10), sig.wait_for(|val| *val))
                .await
                .map_err(|_err| ClientError::ClusterMetadataTimeout)?
                .map_err(|_err| ClientError::ClusterMetadataTimeout)?;
            cluster = self.cluster.load();
        }
        Ok(cluster.clone())
    }

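    /// Find the coordinator for the given key and key type (such as a consumer group).
    ///
    /// The request is sent to the broker specified by `broker_id` when given, otherwise to the
    /// first known broker.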
    pub async fn find_coordinator(&self, key: StrBytes, key_type: i8, broker_id: Option<i32>) -> ClientResult<FindCoordinatorResponse> {
        let cluster = self.get_cluster_metadata_cache().await?;

        let broker = broker_id
            .and_then(|id| cluster.brokers.get(&id).cloned())
            .or_else(|| cluster.brokers.first_key_value().map(|(_, broker)| broker.clone()))
            .ok_or(ClientError::NoBrokerFound)?;

        let uid = uuid::Uuid::new_v4();
        let mut req = FindCoordinatorRequest::default();
        req.key = key;
        req.key_type = key_type;

        let (tx, rx) = oneshot::channel();
        broker.conn.find_coordinator(uid, req, tx).await;

        unpack_broker_response(rx)
            .await
            .and_then(|(_, res)| {
                tracing::debug!("res: {:?}", res);
                if let ResponseKind::FindCoordinatorResponse(res) = res {
                    Ok(res)
                } else {
                    Err(ClientError::MalformedResponse)
                }
            })
            .and_then(|res| {
                if res.error_code != 0 {
                    return Err(ClientError::ResponseError(res.error_code, ResponseError::try_from_code(res.error_code), res.error_message));
                }
                Ok(res)
            })
    }

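    /// Create a producer which produces batches of messages to the given topic.
    ///
    /// `acks`, `timeout_ms` and `compression` control the produce requests issued by the
    /// returned producer; the timeout defaults to [`DEFAULT_TIMEOUT`] and compression defaults
    /// to [`Compression::None`] when not specified.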
    pub fn topic_producer(&self, topic: &str, acks: Acks, timeout_ms: Option<i32>, compression: Option<Compression>) -> TopicProducer {
        let (tx, rx) = mpsc::unbounded_channel();
        let compression = compression.unwrap_or(Compression::None);
        let encode_opts = RecordEncodeOptions { version: 2, compression };
        TopicProducer {
            _client: self.clone(),
            tx,
            rx,
            cluster: self.cluster.clone(),
            topic: StrBytes::from_string(topic.into()),
            acks,
            timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT),
            encode_opts,
            buf: BytesMut::with_capacity(1024 * 1024),
            batch_buf: Vec::with_capacity(1024),
            last_ptn: -1,
        }
    }

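    /// Create an [`Admin`] handle for topic management operations.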
    pub fn admin(&self) -> Admin {
        Admin { _client: self.clone() }
    }
}

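/// A message to be produced to a topic, composed of an optional key, an optional value and a set
/// of headers.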
pub struct Message {
    pub key: Option<Bytes>,
    pub value: Option<Bytes>,
    pub headers: MessageHeaders,
}

impl Message {
    pub fn new(key: Option<Bytes>, value: Option<Bytes>, headers: MessageHeaders) -> Self {
        Self { key, value, headers }
    }
}

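/// The acknowledgement level required from brokers for produce requests, matching the Kafka
/// `acks` values: all in-sync replicas (`-1`), no acknowledgement (`0`), or the leader only (`1`).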
#[derive(Clone, Copy)]
#[repr(i16)]
pub enum Acks {
    All = -1,
    None = 0,
    Leader = 1,
}

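/// A producer which produces batches of messages to a single topic.
///
/// Created via [`Client::topic_producer`].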
pub struct TopicProducer {
    _client: Client,
    tx: mpsc::UnboundedSender<BrokerResponse>,
    rx: mpsc::UnboundedReceiver<BrokerResponse>,
    cluster: ClusterMeta,
    pub topic: StrBytes,
    acks: Acks,
    timeout_ms: i32,
    encode_opts: RecordEncodeOptions,

    buf: BytesMut,
    batch_buf: Vec<Record>,
    last_ptn: i32,
}

impl TopicProducer {
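    /// Produce a batch of messages to this producer's topic, returning the first and last
    /// offsets assigned to the batch.
    ///
    /// All messages are encoded into a single record batch and sent to one partition, chosen by
    /// rotating through the topic's partitions which currently have a leader (sticky
    /// partitioning per call).
    ///
    /// A rough usage sketch; the payload is illustrative only:
    ///
    /// ```ignore
    /// let msg = Message::new(None, Some(Bytes::from("hello")), Default::default());
    /// let (first_offset, last_offset) = producer.produce(&[msg]).await?;
    /// ```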
    pub async fn produce(&mut self, messages: &[Message]) -> ClientResult<(i64, i64)> {
        if messages.is_empty() {
            return Err(ClientError::ProducerMessagesEmpty);
        }
        self.batch_buf.clear();
        let mut cluster = self.cluster.load();
        if !*cluster.bootstrap.borrow() {
            let mut sig = cluster.bootstrap.clone();
            let _ = sig.wait_for(|val| *val).await;
            cluster = self.cluster.load();
        }
        let Some(topic_ptns) = cluster.topics.get(&self.topic) else {
            return Err(ClientError::UnknownTopic(self.topic.to_string()));
        };

        let Some((sticky_ptn, sticky_broker)) = topic_ptns
            .range((self.last_ptn + 1)..)
            .filter_map(|(ptn, meta)| meta.leader.clone().map(|leader| (ptn, leader)))
            .next()
            .or_else(|| topic_ptns.range(..).filter_map(|(ptn, meta)| meta.leader.clone().map(|leader| (ptn, leader))).next())
            .map(|(key, val)| (*key, val.clone()))
        else {
            return Err(ClientError::NoPartitionsAvailable(self.topic.to_string()));
        };
        self.last_ptn = sticky_ptn;

        let timestamp = chrono::Utc::now().timestamp_millis();
        for msg in messages.iter() {
            self.batch_buf.push(Record {
                transactional: false,
                control: false,
                partition_leader_epoch: 0,
                producer_id: 0,
                producer_epoch: 0,
                timestamp,
                timestamp_type: TimestampType::Creation,
                offset: 0,
                sequence: 0,
                key: msg.key.clone(),
                value: msg.value.clone(),
                headers: msg.headers.clone(),
            });
        }

        let res = RecordBatchEncoder::encode(&mut self.buf, self.batch_buf.iter(), &self.encode_opts).map_err(|err| ClientError::EncodingError(format!("{:?}", err)));
        self.batch_buf.clear();
        res?;

        let mut req = ProduceRequest::default();
        req.acks = self.acks as i16;
        req.timeout_ms = self.timeout_ms;
        let topic = req.topic_data.entry(self.topic.clone().into()).or_default();
        let mut ptn_data = PartitionProduceData::default();
        ptn_data.index = sticky_ptn;
        ptn_data.records = Some(self.buf.split().freeze());
        topic.partition_data.push(ptn_data);

        let uid = uuid::Uuid::new_v4();
        sticky_broker.conn.produce(uid, req, self.tx.clone()).await;
        let res = loop {
            let Some(res) = self.rx.recv().await else {
                unreachable!("both ends of channel are held, receiving None should not be possible")
            };
            if res.id == uid {
                break res;
            }
        };

        res.result
            .map_err(ClientError::BrokerError)
            .and_then(|res| {
                if let ResponseKind::ProduceResponse(inner) = res.1 {
                    Ok(inner)
                } else {
                    tracing::error!("expected broker to return a ProduceResponse, got: {:?}", res.1);
                    Err(ClientError::MalformedResponse)
                }
            })
            .and_then(|res| {
                res.responses
                    .iter()
                    .find(|topic| topic.0 .0 == self.topic)
                    .and_then(|val| {
                        val.1.partition_responses.first().map(|val| {
                            debug_assert!(!messages.is_empty(), "messages len should always be validated at start of function");
                            let last_offset = val.base_offset + (messages.len() - 1) as i64;
                            (val.base_offset, last_offset)
                        })
                    })
                    .ok_or(ClientError::MalformedResponse)
            })
    }
}

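/// The position to resolve when listing offsets for a topic partition: the earliest available
/// offset, the latest offset, or the offset for a given timestamp in milliseconds.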
pub enum ListOffsetsPosition {
    Earliest,
    Latest,
    Timestamp(i64),
}

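/// Await a broker response on the given channel, unpacking it into its response header and body
/// and mapping channel or broker failures into client errors.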
pub(crate) async fn unpack_broker_response(rx: oneshot::Receiver<BrokerResponse>) -> ClientResult<(ResponseHeader, ResponseKind)> {
    rx.await
        .map_err(|_| ClientError::Other("response channel dropped by broker, which should never happen".into()))?
        .result
        .map_err(ClientError::BrokerError)
}

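/// An administrative interface for topic management operations, created via [`Client::admin`].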
pub struct Admin {
    _client: Client,
}

impl Admin {
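    /// Create the topics specified in the given request.
    ///
    /// The request must specify at least one topic and is sent to the cluster controller; an
    /// error is returned if no controller is currently known.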
    pub async fn create_topics(&self, request: CreateTopicsRequest) -> ClientResult<CreateTopicsResponse> {
        if request.topics.is_empty() {
            return Err(ClientError::NoTopicsSpecified);
        }
        let cluster = self._client.get_cluster_metadata_cache().await?;
        let (tx, rx) = oneshot::channel();

        let Some(leader) = &cluster.controller else {
            return Err(ClientError::NoControllerFound);
        };
        let uid = uuid::Uuid::new_v4();
        leader.conn.create_topics(uid, request, tx).await;
        unpack_broker_response(rx).await.and_then(|(_, res)| {
            if let ResponseKind::CreateTopicsResponse(inner) = res {
                Ok(inner)
            } else {
                Err(ClientError::MalformedResponse)
            }
        })
    }

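    /// Delete the topics specified in the given request.
    ///
    /// The request must specify at least one topic and is sent to the cluster controller; an
    /// error is returned if no controller is currently known.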
    pub async fn delete_topics(&self, request: DeleteTopicsRequest) -> ClientResult<DeleteTopicsResponse> {
        if request.topics.is_empty() {
            return Err(ClientError::NoTopicsSpecified);
        }
        let cluster = self._client.get_cluster_metadata_cache().await?;
        let (tx, rx) = oneshot::channel();

        let Some(leader) = &cluster.controller else {
            return Err(ClientError::NoControllerFound);
        };
        let uid = uuid::Uuid::new_v4();
        leader.conn.delete_topics(uid, request, tx).await;
        unpack_broker_response(rx).await.and_then(|(_, res)| {
            if let ResponseKind::DeleteTopicsResponse(inner) = res {
                Ok(inner)
            } else {
                Err(ClientError::MalformedResponse)
            }
        })
    }
}