1use std::marker::PhantomData;
4use std::os::raw::c_void;
5use std::pin::Pin;
6use std::ptr;
7use std::sync::{Arc, Mutex};
8use std::task::{Context, Poll, Waker};
9use std::time::Duration;
10
11use crate::log::trace;
12use futures_channel::oneshot;
13use futures_util::future::{self, Either, FutureExt};
14use futures_util::pin_mut;
15use futures_util::stream::{Stream, StreamExt};
16use slab::Slab;
17
18use rdkafka_sys as rdsys;
19use rdkafka_sys::types::*;
20
21use crate::client::{Client, EventPollResult, NativeQueue};
22use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
23use crate::consumer::base_consumer::{BaseConsumer, PartitionQueue};
24use crate::consumer::{
25 CommitMode, Consumer, ConsumerContext, ConsumerGroupMetadata, DefaultConsumerContext,
26 RebalanceProtocol,
27};
28use crate::error::{KafkaError, KafkaResult};
29use crate::groups::GroupList;
30use crate::message::BorrowedMessage;
31use crate::metadata::Metadata;
32use crate::topic_partition_list::{Offset, TopicPartitionList};
33use crate::util::{AsyncRuntime, DefaultRuntime, Timeout};
34
/// C callback handed to librdkafka; invoked when a watched queue transitions
/// from empty to non-empty. `opaque_ptr` is the `WakerSlab` pointer that was
/// registered via `enable_nonempty_callback`.
unsafe extern "C" fn native_message_queue_nonempty_cb(_: *mut RDKafka, opaque_ptr: *mut c_void) {
    // SAFETY: `enable_nonempty_callback` registered `opaque_ptr` as a pointer
    // derived from an `Arc<WakerSlab>`; the owning struct must keep that slab
    // alive until the callback is disabled (see `disable_nonempty_callback`).
    let wakers = &*(opaque_ptr as *const WakerSlab);
    wakers.wake_all();
}
39
/// Arranges for `native_message_queue_nonempty_cb` to be invoked with a
/// pointer to `wakers` whenever `queue` becomes non-empty.
///
/// # Safety
///
/// Only a raw pointer — not a strong `Arc` reference — is handed to
/// librdkafka, so the caller must ensure the `WakerSlab` outlives the
/// registration (i.e. until `disable_nonempty_callback` is called or the
/// queue itself is destroyed).
unsafe fn enable_nonempty_callback(queue: &NativeQueue, wakers: &Arc<WakerSlab>) {
    rdsys::rd_kafka_queue_cb_event_enable(
        queue.ptr(),
        Some(native_message_queue_nonempty_cb),
        Arc::as_ptr(wakers) as *mut c_void,
    )
}
47
/// Unregisters any previously installed queue-nonempty callback for `queue`,
/// after which librdkafka will no longer dereference the opaque pointer.
unsafe fn disable_nonempty_callback(queue: &NativeQueue) {
    rdsys::rd_kafka_queue_cb_event_enable(queue.ptr(), None, ptr::null_mut())
}
51
/// A collection of wakers — one slot per live `MessageStream` — that can all
/// be woken at once when new events may be available.
struct WakerSlab {
    // Slot index -> waker parked by the stream occupying that slot.
    // `None` means that stream has not parked a waker since the last wake.
    wakers: Mutex<Slab<Option<Waker>>>,
}
55
56impl WakerSlab {
57 fn new() -> WakerSlab {
58 WakerSlab {
59 wakers: Mutex::new(Slab::new()),
60 }
61 }
62
63 fn wake_all(&self) {
64 let mut wakers = self.wakers.lock().unwrap();
65 for (_, waker) in wakers.iter_mut() {
66 if let Some(waker) = waker.take() {
67 waker.wake();
68 }
69 }
70 }
71
72 fn register(&self) -> usize {
73 let mut wakers = self.wakers.lock().expect("lock poisoned");
74 wakers.insert(None)
75 }
76
77 fn unregister(&self, slot: usize) {
78 let mut wakers = self.wakers.lock().expect("lock poisoned");
79 wakers.remove(slot);
80 }
81
82 fn set_waker(&self, slot: usize, waker: Waker) {
83 let mut wakers = self.wakers.lock().expect("lock poisoned");
84 wakers[slot] = Some(waker);
85 }
86}
87
/// A stream of messages from a [`StreamConsumer`].
///
/// Borrows the consumer for its lifetime and polls either the consumer's main
/// queue or, when set, a single split partition queue.
pub struct MessageStream<'a, C: ConsumerContext> {
    wakers: &'a WakerSlab,
    consumer: &'a BaseConsumer<C>,
    // `Some` when this stream reads from a split partition queue rather than
    // the consumer's main queue.
    partition_queue: Option<&'a NativeQueue>,
    // This stream's slot in the waker slab; released in `Drop`.
    slot: usize,
}
97
98impl<'a, C: ConsumerContext> MessageStream<'a, C> {
99 fn new(wakers: &'a WakerSlab, consumer: &'a BaseConsumer<C>) -> MessageStream<'a, C> {
100 Self::new_with_optional_partition_queue(wakers, consumer, None)
101 }
102
103 fn new_with_partition_queue(
104 wakers: &'a WakerSlab,
105 consumer: &'a BaseConsumer<C>,
106 partition_queue: &'a NativeQueue,
107 ) -> MessageStream<'a, C> {
108 Self::new_with_optional_partition_queue(wakers, consumer, Some(partition_queue))
109 }
110
111 fn new_with_optional_partition_queue(
112 wakers: &'a WakerSlab,
113 consumer: &'a BaseConsumer<C>,
114 partition_queue: Option<&'a NativeQueue>,
115 ) -> MessageStream<'a, C> {
116 let slot = wakers.register();
117 MessageStream {
118 wakers,
119 consumer,
120 partition_queue,
121 slot,
122 }
123 }
124
125 fn poll(&self) -> EventPollResult<KafkaResult<BorrowedMessage<'a>>> {
126 if let Some(queue) = self.partition_queue {
127 self.consumer.poll_queue(queue, Duration::ZERO)
128 } else {
129 self.consumer
130 .poll_queue(self.consumer.get_queue(), Duration::ZERO)
131 }
132 }
133}
134
impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> {
    type Item = KafkaResult<BorrowedMessage<'a>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.poll() {
            EventPollResult::Event(message) => {
                // A message (or error) is ready; hand it to the caller.
                Poll::Ready(Some(message))
            }
            EventPollResult::EventConsumed => {
                // An event was consumed internally without yielding a
                // message. Request an immediate re-poll so we keep draining
                // the queue instead of parking.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            EventPollResult::None => {
                // Queue looked empty. Park our waker in the slab so the
                // queue-nonempty callback (or the periodic wake loop) can
                // wake us later.
                self.wakers.set_waker(self.slot, cx.waker().clone());

                // Poll again to close the race window: an event may have
                // arrived after the first poll but before the waker was
                // stored, in which case no wakeup would ever be delivered
                // for it.
                match self.poll() {
                    EventPollResult::Event(message) => Poll::Ready(Some(message)),
                    EventPollResult::EventConsumed => {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                    EventPollResult::None => Poll::Pending,
                }
            }
        }
    }
}
174
impl<C: ConsumerContext> Drop for MessageStream<'_, C> {
    fn drop(&mut self) {
        // Release this stream's waker slot so the slab does not grow without
        // bound as streams are created and dropped.
        self.wakers.unregister(self.slot);
    }
}
180
/// A high-level consumer that exposes Kafka messages as an async stream.
///
/// Dropping the consumer drops `_shutdown_trigger`, which signals the
/// background wake loop (spawned at construction) to exit.
#[must_use = "Consumer polling thread will stop immediately if unused"]
pub struct StreamConsumer<C = DefaultConsumerContext, R = DefaultRuntime>
where
    C: ConsumerContext,
{
    base: Arc<BaseConsumer<C>>,
    // Wakers for all streams created from this consumer's main queue.
    wakers: Arc<WakerSlab>,
    // Fires (by being dropped) when the consumer is dropped, stopping the
    // wake loop.
    _shutdown_trigger: oneshot::Sender<()>,
    // Pins the async runtime type parameter without storing a value of it.
    _runtime: PhantomData<R>,
}
205
206impl<R> FromClientConfig for StreamConsumer<DefaultConsumerContext, R>
207where
208 R: AsyncRuntime,
209{
210 fn from_config(config: &ClientConfig) -> KafkaResult<Self> {
211 StreamConsumer::from_config_and_context(config, DefaultConsumerContext)
212 }
213}
214
impl<C, R> FromClientConfigAndContext<C> for StreamConsumer<C, R>
where
    C: ConsumerContext + 'static,
    R: AsyncRuntime,
{
    /// Creates the consumer, installs the queue-nonempty callback, and spawns
    /// a background task that periodically wakes all streams (every half of
    /// `max.poll.interval.ms`) so the underlying consumer keeps getting
    /// polled even when no messages arrive.
    fn from_config_and_context(config: &ClientConfig, context: C) -> KafkaResult<Self> {
        let native_config = config.create_native_config()?;
        let poll_interval = {
            // The native config value may carry a trailing NUL; strip it
            // before parsing.
            let millis: u64 = native_config
                .get("max.poll.interval.ms")?
                .trim_end_matches(char::from(0))
                .parse()
                .expect("librdkafka validated config value is valid u64");
            Duration::from_millis(millis)
        };

        let base = Arc::new(BaseConsumer::new(config, native_config, context)?);
        let native_ptr = base.client().native_ptr() as usize;

        // Wake parked streams whenever the main queue becomes non-empty.
        // Registered before any stream can exist, so no wakeup is missed.
        let wakers = Arc::new(WakerSlab::new());
        unsafe { enable_nonempty_callback(base.get_queue(), &wakers) }

        // The wake loop runs until the shutdown trigger (owned by the
        // returned consumer) is dropped, tripping the tripwire.
        let (shutdown_trigger, shutdown_tripwire) = oneshot::channel();
        let mut shutdown_tripwire = shutdown_tripwire.fuse();
        R::spawn({
            let wakers = wakers.clone();
            async move {
                trace!("Starting stream consumer wake loop: 0x{:x}", native_ptr);
                loop {
                    let delay = R::delay_for(poll_interval / 2).fuse();
                    pin_mut!(delay);
                    // Either the delay elapses (wake everyone) or shutdown is
                    // signaled (exit the loop).
                    match future::select(&mut delay, &mut shutdown_tripwire).await {
                        Either::Left(_) => wakers.wake_all(),
                        Either::Right(_) => break,
                    }
                }
                trace!("Shut down stream consumer wake loop: 0x{:x}", native_ptr);
            }
        });

        Ok(StreamConsumer {
            base,
            wakers,
            _shutdown_trigger: shutdown_trigger,
            _runtime: PhantomData,
        })
    }
}
273
274impl<C, R> StreamConsumer<C, R>
275where
276 C: ConsumerContext + 'static,
277{
278 pub fn stream(&self) -> MessageStream<'_, C> {
291 MessageStream::new(&self.wakers, &self.base)
292 }
293
294 pub async fn recv(&self) -> Result<BorrowedMessage<'_>, KafkaError> {
315 self.stream()
316 .next()
317 .await
318 .expect("kafka streams never terminate")
319 }
320
321 pub fn split_partition_queue(
356 self: &Arc<Self>,
357 topic: &str,
358 partition: i32,
359 ) -> Option<StreamPartitionQueue<C, R>> {
360 self.base
361 .split_partition_queue(topic, partition)
362 .map(|queue| {
363 let wakers = Arc::new(WakerSlab::new());
364 unsafe { enable_nonempty_callback(&queue.queue, &wakers) };
365 StreamPartitionQueue {
366 queue,
367 wakers,
368 _consumer: self.clone(),
369 }
370 })
371 }
372}
373
// Every `Consumer` trait method delegates directly to the wrapped
// `BaseConsumer`; the stream layer adds no behavior here.
impl<C, R> Consumer<C> for StreamConsumer<C, R>
where
    C: ConsumerContext,
{
    fn client(&self) -> &Client<C> {
        self.base.client()
    }

    fn group_metadata(&self) -> Option<ConsumerGroupMetadata> {
        self.base.group_metadata()
    }

    fn subscribe(&self, topics: &[&str]) -> KafkaResult<()> {
        self.base.subscribe(topics)
    }

    fn unsubscribe(&self) {
        self.base.unsubscribe();
    }

    fn assign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.assign(assignment)
    }

    fn unassign(&self) -> KafkaResult<()> {
        self.base.unassign()
    }

    fn incremental_assign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.incremental_assign(assignment)
    }

    fn incremental_unassign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.incremental_unassign(assignment)
    }

    fn assignment_lost(&self) -> bool {
        self.base.assignment_lost()
    }

    fn seek<T: Into<Timeout>>(
        &self,
        topic: &str,
        partition: i32,
        offset: Offset,
        timeout: T,
    ) -> KafkaResult<()> {
        self.base.seek(topic, partition, offset, timeout)
    }

    fn seek_partitions<T: Into<Timeout>>(
        &self,
        topic_partition_list: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList> {
        self.base.seek_partitions(topic_partition_list, timeout)
    }

    fn commit(
        &self,
        topic_partition_list: &TopicPartitionList,
        mode: CommitMode,
    ) -> KafkaResult<()> {
        self.base.commit(topic_partition_list, mode)
    }

    fn commit_consumer_state(&self, mode: CommitMode) -> KafkaResult<()> {
        self.base.commit_consumer_state(mode)
    }

    fn commit_message(&self, message: &BorrowedMessage<'_>, mode: CommitMode) -> KafkaResult<()> {
        self.base.commit_message(message, mode)
    }

    fn store_offset(&self, topic: &str, partition: i32, offset: i64) -> KafkaResult<()> {
        self.base.store_offset(topic, partition, offset)
    }

    fn store_offset_from_message(&self, message: &BorrowedMessage<'_>) -> KafkaResult<()> {
        self.base.store_offset_from_message(message)
    }

    fn store_offsets(&self, tpl: &TopicPartitionList) -> KafkaResult<()> {
        self.base.store_offsets(tpl)
    }

    fn subscription(&self) -> KafkaResult<TopicPartitionList> {
        self.base.subscription()
    }

    fn assignment(&self) -> KafkaResult<TopicPartitionList> {
        self.base.assignment()
    }

    fn committed<T>(&self, timeout: T) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.committed(timeout)
    }

    fn committed_offsets<T>(
        &self,
        tpl: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
    {
        self.base.committed_offsets(tpl, timeout)
    }

    fn offsets_for_timestamp<T>(
        &self,
        timestamp: i64,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.offsets_for_timestamp(timestamp, timeout)
    }

    fn offsets_for_times<T>(
        &self,
        timestamps: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.offsets_for_times(timestamps, timeout)
    }

    fn position(&self) -> KafkaResult<TopicPartitionList> {
        self.base.position()
    }

    fn fetch_metadata<T>(&self, topic: Option<&str>, timeout: T) -> KafkaResult<Metadata>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.fetch_metadata(topic, timeout)
    }

    fn fetch_watermarks<T>(
        &self,
        topic: &str,
        partition: i32,
        timeout: T,
    ) -> KafkaResult<(i64, i64)>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.fetch_watermarks(topic, partition, timeout)
    }

    fn fetch_group_list<T>(&self, group: Option<&str>, timeout: T) -> KafkaResult<GroupList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.fetch_group_list(group, timeout)
    }

    fn pause(&self, partitions: &TopicPartitionList) -> KafkaResult<()> {
        self.base.pause(partitions)
    }

    fn resume(&self, partitions: &TopicPartitionList) -> KafkaResult<()> {
        self.base.resume(partitions)
    }

    fn rebalance_protocol(&self) -> RebalanceProtocol {
        self.base.rebalance_protocol()
    }
}
556
/// A handle for streaming messages from one split-off partition, created via
/// `StreamConsumer::split_partition_queue`.
pub struct StreamPartitionQueue<C, R = DefaultRuntime>
where
    C: ConsumerContext,
{
    queue: PartitionQueue<C>,
    // Wakers for streams created from this partition queue only.
    wakers: Arc<WakerSlab>,
    // Keeps the parent consumer alive while this handle exists.
    _consumer: Arc<StreamConsumer<C, R>>,
}
569
570impl<C, R> StreamPartitionQueue<C, R>
571where
572 C: ConsumerContext,
573{
574 pub fn stream(&self) -> MessageStream<'_, C> {
587 MessageStream::new_with_partition_queue(
588 &self.wakers,
589 &self._consumer.base,
590 &self.queue.queue,
591 )
592 }
593
594 pub async fn recv(&self) -> Result<BorrowedMessage<'_>, KafkaError> {
618 self.stream()
619 .next()
620 .await
621 .expect("kafka streams never terminate")
622 }
623}
624
impl<C, R> Drop for StreamPartitionQueue<C, R>
where
    C: ConsumerContext,
{
    fn drop(&mut self) {
        // Unregister the callback before the waker slab is dropped, so
        // librdkafka can never invoke it with a dangling pointer.
        unsafe { disable_nonempty_callback(&self.queue.queue) }
    }
}