use std::marker::PhantomData;
use std::os::raw::c_void;
use std::pin::Pin;
use std::ptr;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use crate::log::trace;
use futures_channel::oneshot;
use futures_util::future::{self, Either, FutureExt};
use futures_util::pin_mut;
use futures_util::stream::{Stream, StreamExt};
use slab::Slab;
use rdkafka_sys as rdsys;
use rdkafka_sys::types::*;
use crate::client::{Client, EventPollResult, NativeQueue};
use crate::config::{ClientConfig, FromClientConfig, FromClientConfigAndContext};
use crate::consumer::base_consumer::{BaseConsumer, PartitionQueue};
use crate::consumer::{
CommitMode, Consumer, ConsumerContext, ConsumerGroupMetadata, DefaultConsumerContext,
RebalanceProtocol,
};
use crate::error::{KafkaError, KafkaResult};
use crate::groups::GroupList;
use crate::message::BorrowedMessage;
use crate::metadata::Metadata;
use crate::topic_partition_list::{Offset, TopicPartitionList};
use crate::util::{AsyncRuntime, DefaultRuntime, Timeout};
/// C callback registered via `rd_kafka_queue_cb_event_enable`, invoked by
/// librdkafka when a watched queue becomes nonempty. Wakes every stream
/// parked in the associated `WakerSlab` so they re-poll the queue.
unsafe extern "C" fn native_message_queue_nonempty_cb(_: *mut RDKafka, opaque_ptr: *mut c_void) {
    // SAFETY: `opaque_ptr` was produced by `Arc::as_ptr` on an `Arc<WakerSlab>`
    // in `enable_nonempty_callback`; the registering owner keeps that Arc
    // alive for as long as the callback is enabled.
    let wakers = &*(opaque_ptr as *const WakerSlab);
    wakers.wake_all();
}
/// Installs `native_message_queue_nonempty_cb` on `queue`, passing a raw
/// pointer to the waker slab as the callback's opaque argument.
///
/// # Safety
///
/// The caller must ensure the `Arc<WakerSlab>` outlives the callback
/// registration, and must call `disable_nonempty_callback` before the
/// pointer handed to librdkafka could dangle.
unsafe fn enable_nonempty_callback(queue: &NativeQueue, wakers: &Arc<WakerSlab>) {
    rdsys::rd_kafka_queue_cb_event_enable(
        queue.ptr(),
        Some(native_message_queue_nonempty_cb),
        Arc::as_ptr(wakers) as *mut c_void,
    )
}
/// Removes any event callback previously installed on `queue`, after which
/// librdkafka will no longer dereference the opaque `WakerSlab` pointer.
///
/// # Safety
///
/// `queue` must be a live librdkafka queue handle.
unsafe fn disable_nonempty_callback(queue: &NativeQueue) {
    rdsys::rd_kafka_queue_cb_event_enable(queue.ptr(), None, ptr::null_mut())
}
/// A set of wakers, one slot per live `MessageStream`, shared between the
/// streams, the librdkafka nonempty callback, and the background wake loop.
struct WakerSlab {
    // Slot indices are handed out by `register` and owned by one stream each;
    // `None` means that stream has no waker currently parked.
    wakers: Mutex<Slab<Option<Waker>>>,
}
impl WakerSlab {
    /// Creates an empty slab with no registered streams.
    fn new() -> WakerSlab {
        WakerSlab {
            wakers: Mutex::new(Slab::new()),
        }
    }

    /// Wakes and consumes every parked waker. Called when the underlying
    /// queue may have become nonempty, and periodically by the wake loop.
    fn wake_all(&self) {
        // Consistency fix: the sibling methods use `.expect("lock poisoned")`
        // while this one used a bare `.unwrap()`. A poisoned lock means
        // another thread panicked while holding it; fail loudly with the
        // same message everywhere.
        let mut wakers = self.wakers.lock().expect("lock poisoned");
        for (_, waker) in wakers.iter_mut() {
            if let Some(waker) = waker.take() {
                waker.wake();
            }
        }
    }

    /// Reserves an empty slot for a new stream and returns its index.
    fn register(&self) -> usize {
        let mut wakers = self.wakers.lock().expect("lock poisoned");
        wakers.insert(None)
    }

    /// Releases a slot previously returned by `register`.
    fn unregister(&self, slot: usize) {
        let mut wakers = self.wakers.lock().expect("lock poisoned");
        wakers.remove(slot);
    }

    /// Parks `waker` in `slot`, replacing any waker already stored there.
    fn set_waker(&self, slot: usize, waker: Waker) {
        let mut wakers = self.wakers.lock().expect("lock poisoned");
        wakers[slot] = Some(waker);
    }
}
/// A stream of messages from a `StreamConsumer` (or one of its split
/// partition queues). Holds a registered slot in the shared `WakerSlab` for
/// its lifetime; the slot is released in `Drop`.
pub struct MessageStream<'a, C: ConsumerContext> {
    wakers: &'a WakerSlab,
    consumer: &'a BaseConsumer<C>,
    // `None`: poll the consumer's main queue. `Some`: poll only this
    // split-off partition queue.
    partition_queue: Option<&'a NativeQueue>,
    // Index of this stream's waker slot in `wakers`.
    slot: usize,
}
impl<'a, C: ConsumerContext> MessageStream<'a, C> {
    /// Builds a stream over the consumer's main message queue.
    fn new(wakers: &'a WakerSlab, consumer: &'a BaseConsumer<C>) -> MessageStream<'a, C> {
        Self::new_with_optional_partition_queue(wakers, consumer, None)
    }

    /// Builds a stream over a single split-off partition queue.
    fn new_with_partition_queue(
        wakers: &'a WakerSlab,
        consumer: &'a BaseConsumer<C>,
        partition_queue: &'a NativeQueue,
    ) -> MessageStream<'a, C> {
        Self::new_with_optional_partition_queue(wakers, consumer, Some(partition_queue))
    }

    /// Common constructor: registers a waker slot and assembles the stream.
    fn new_with_optional_partition_queue(
        wakers: &'a WakerSlab,
        consumer: &'a BaseConsumer<C>,
        partition_queue: Option<&'a NativeQueue>,
    ) -> MessageStream<'a, C> {
        MessageStream {
            slot: wakers.register(),
            wakers,
            consumer,
            partition_queue,
        }
    }

    /// Performs one non-blocking poll of whichever queue this stream covers.
    fn poll(&self) -> EventPollResult<KafkaResult<BorrowedMessage<'a>>> {
        let queue = match self.partition_queue {
            Some(queue) => queue,
            None => self.consumer.get_queue(),
        };
        self.consumer.poll_queue(queue, Duration::ZERO)
    }
}
impl<'a, C: ConsumerContext> Stream for MessageStream<'a, C> {
    type Item = KafkaResult<BorrowedMessage<'a>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Fast path: poll once without registering a waker.
        match self.poll() {
            EventPollResult::Event(message) => {
                // A message (or error) was ready.
                Poll::Ready(Some(message))
            }
            EventPollResult::EventConsumed => {
                // An event was consumed internally without surfacing a
                // message; ask to be polled again immediately rather than
                // parking, since more events may already be queued.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            EventPollResult::None => {
                // The queue looked empty. Park our waker, then poll once
                // more: a message could have arrived between the first poll
                // and the waker registration, in which case the nonempty
                // callback may have already fired and won't fire again for
                // that message. The second poll closes that race window.
                self.wakers.set_waker(self.slot, cx.waker().clone());
                match self.poll() {
                    EventPollResult::Event(message) => Poll::Ready(Some(message)),
                    EventPollResult::EventConsumed => {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                    EventPollResult::None => Poll::Pending,
                }
            }
        }
    }
}
impl<C: ConsumerContext> Drop for MessageStream<'_, C> {
    /// Releases this stream's waker slot so the slab does not accumulate
    /// dead entries.
    fn drop(&mut self) {
        self.wakers.unregister(self.slot);
    }
}
/// A high-level consumer exposing messages as an async `Stream`.
///
/// Dropping the consumer drops `_shutdown_trigger`, which signals the
/// background wake loop (spawned in `from_config_and_context`) to exit.
#[must_use = "Consumer polling thread will stop immediately if unused"]
pub struct StreamConsumer<C = DefaultConsumerContext, R = DefaultRuntime>
where
    C: ConsumerContext,
{
    base: Arc<BaseConsumer<C>>,
    wakers: Arc<WakerSlab>,
    // Held only so that dropping the consumer closes the channel and stops
    // the wake loop; never sent on.
    _shutdown_trigger: oneshot::Sender<()>,
    // Pins the async runtime choice at the type level without storing it.
    _runtime: PhantomData<R>,
}
impl<R> FromClientConfig for StreamConsumer<DefaultConsumerContext, R>
where
    R: AsyncRuntime,
{
    /// Creates a stream consumer with the default (no-op) consumer context.
    fn from_config(config: &ClientConfig) -> KafkaResult<Self> {
        StreamConsumer::from_config_and_context(config, DefaultConsumerContext)
    }
}
impl<C, R> FromClientConfigAndContext<C> for StreamConsumer<C, R>
where
    C: ConsumerContext + 'static,
    R: AsyncRuntime,
{
    /// Creates the consumer, wires the queue-nonempty callback to a fresh
    /// waker slab, and spawns a background task that periodically wakes all
    /// streams until the consumer is dropped.
    fn from_config_and_context(config: &ClientConfig, context: C) -> KafkaResult<Self> {
        let native_config = config.create_native_config()?;
        // Derive the wake-loop period from max.poll.interval.ms. The value
        // comes back from librdkafka as a NUL-padded string, hence the trim.
        let poll_interval = {
            let millis: u64 = native_config
                .get("max.poll.interval.ms")?
                .trim_end_matches(char::from(0))
                .parse()
                .expect("librdkafka validated config value is valid u64");
            Duration::from_millis(millis)
        };
        let base = Arc::new(BaseConsumer::new(config, native_config, context)?);
        // Captured as usize so the raw pointer isn't moved into the task;
        // used only for trace logging.
        let native_ptr = base.client().native_ptr() as usize;
        let wakers = Arc::new(WakerSlab::new());
        // SAFETY: `wakers` is stored in the returned StreamConsumer, which
        // owns the queue for as long as the callback is enabled.
        unsafe { enable_nonempty_callback(base.get_queue(), &wakers) }
        let (shutdown_trigger, shutdown_tripwire) = oneshot::channel();
        let mut shutdown_tripwire = shutdown_tripwire.fuse();
        // Background wake loop: wakes every parked stream twice per poll
        // interval so librdkafka housekeeping still happens even when no
        // messages arrive; exits when the shutdown channel closes.
        R::spawn({
            let wakers = wakers.clone();
            async move {
                trace!("Starting stream consumer wake loop: 0x{:x}", native_ptr);
                loop {
                    let delay = R::delay_for(poll_interval / 2).fuse();
                    pin_mut!(delay);
                    match future::select(&mut delay, &mut shutdown_tripwire).await {
                        Either::Left(_) => wakers.wake_all(),
                        Either::Right(_) => break,
                    }
                }
                trace!("Shut down stream consumer wake loop: 0x{:x}", native_ptr);
            }
        });
        Ok(StreamConsumer {
            base,
            wakers,
            _shutdown_trigger: shutdown_trigger,
            _runtime: PhantomData,
        })
    }
}
impl<C, R> StreamConsumer<C, R>
where
    C: ConsumerContext + 'static,
{
    /// Returns a stream over the consumer's main message queue.
    pub fn stream(&self) -> MessageStream<'_, C> {
        MessageStream::new(&self.wakers, &self.base)
    }

    /// Awaits the next message (or error).
    ///
    /// The underlying stream never yields `None`, hence the `expect`.
    pub async fn recv(&self) -> Result<BorrowedMessage<'_>, KafkaError> {
        self.stream()
            .next()
            .await
            .expect("kafka streams never terminate")
    }

    /// Splits the queue for `(topic, partition)` off the main queue into its
    /// own [`StreamPartitionQueue`] with a dedicated waker slab wired to that
    /// queue's nonempty callback. Returns `None` when the underlying
    /// `BaseConsumer` cannot split the queue.
    pub fn split_partition_queue(
        self: &Arc<Self>,
        topic: &str,
        partition: i32,
    ) -> Option<StreamPartitionQueue<C, R>> {
        self.base
            .split_partition_queue(topic, partition)
            .map(|queue| {
                let wakers = Arc::new(WakerSlab::new());
                // SAFETY: the StreamPartitionQueue below owns `wakers` and
                // disables the callback in its Drop impl, so the pointer
                // given to librdkafka stays valid while registered.
                unsafe { enable_nonempty_callback(&queue.queue, &wakers) };
                StreamPartitionQueue {
                    queue,
                    wakers,
                    // Keep the parent consumer alive as long as the split
                    // queue exists.
                    _consumer: self.clone(),
                }
            })
    }
}
// Every `Consumer` trait method simply delegates to the wrapped
// `BaseConsumer`; the stream consumer adds no behavior of its own here.
impl<C, R> Consumer<C> for StreamConsumer<C, R>
where
    C: ConsumerContext,
{
    fn client(&self) -> &Client<C> {
        self.base.client()
    }

    fn group_metadata(&self) -> Option<ConsumerGroupMetadata> {
        self.base.group_metadata()
    }

    // Subscription / assignment management.
    fn subscribe(&self, topics: &[&str]) -> KafkaResult<()> {
        self.base.subscribe(topics)
    }

    fn unsubscribe(&self) {
        self.base.unsubscribe();
    }

    fn assign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.assign(assignment)
    }

    fn unassign(&self) -> KafkaResult<()> {
        self.base.unassign()
    }

    fn incremental_assign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.incremental_assign(assignment)
    }

    fn incremental_unassign(&self, assignment: &TopicPartitionList) -> KafkaResult<()> {
        self.base.incremental_unassign(assignment)
    }

    fn assignment_lost(&self) -> bool {
        self.base.assignment_lost()
    }

    // Offset seeking.
    fn seek<T: Into<Timeout>>(
        &self,
        topic: &str,
        partition: i32,
        offset: Offset,
        timeout: T,
    ) -> KafkaResult<()> {
        self.base.seek(topic, partition, offset, timeout)
    }

    fn seek_partitions<T: Into<Timeout>>(
        &self,
        topic_partition_list: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList> {
        self.base.seek_partitions(topic_partition_list, timeout)
    }

    // Offset committing and storing.
    fn commit(
        &self,
        topic_partition_list: &TopicPartitionList,
        mode: CommitMode,
    ) -> KafkaResult<()> {
        self.base.commit(topic_partition_list, mode)
    }

    fn commit_consumer_state(&self, mode: CommitMode) -> KafkaResult<()> {
        self.base.commit_consumer_state(mode)
    }

    fn commit_message(&self, message: &BorrowedMessage<'_>, mode: CommitMode) -> KafkaResult<()> {
        self.base.commit_message(message, mode)
    }

    fn store_offset(&self, topic: &str, partition: i32, offset: i64) -> KafkaResult<()> {
        self.base.store_offset(topic, partition, offset)
    }

    fn store_offset_from_message(&self, message: &BorrowedMessage<'_>) -> KafkaResult<()> {
        self.base.store_offset_from_message(message)
    }

    fn store_offsets(&self, tpl: &TopicPartitionList) -> KafkaResult<()> {
        self.base.store_offsets(tpl)
    }

    // Introspection of current subscription/assignment/position.
    fn subscription(&self) -> KafkaResult<TopicPartitionList> {
        self.base.subscription()
    }

    fn assignment(&self) -> KafkaResult<TopicPartitionList> {
        self.base.assignment()
    }

    fn committed<T>(&self, timeout: T) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.committed(timeout)
    }

    fn committed_offsets<T>(
        &self,
        tpl: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
    {
        self.base.committed_offsets(tpl, timeout)
    }

    fn offsets_for_timestamp<T>(
        &self,
        timestamp: i64,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.offsets_for_timestamp(timestamp, timeout)
    }

    fn offsets_for_times<T>(
        &self,
        timestamps: TopicPartitionList,
        timeout: T,
    ) -> KafkaResult<TopicPartitionList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.offsets_for_times(timestamps, timeout)
    }

    fn position(&self) -> KafkaResult<TopicPartitionList> {
        self.base.position()
    }

    // Cluster metadata queries.
    fn fetch_metadata<T>(&self, topic: Option<&str>, timeout: T) -> KafkaResult<Metadata>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.fetch_metadata(topic, timeout)
    }

    fn fetch_watermarks<T>(
        &self,
        topic: &str,
        partition: i32,
        timeout: T,
    ) -> KafkaResult<(i64, i64)>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.fetch_watermarks(topic, partition, timeout)
    }

    fn fetch_group_list<T>(&self, group: Option<&str>, timeout: T) -> KafkaResult<GroupList>
    where
        T: Into<Timeout>,
        Self: Sized,
    {
        self.base.fetch_group_list(group, timeout)
    }

    // Flow control.
    fn pause(&self, partitions: &TopicPartitionList) -> KafkaResult<()> {
        self.base.pause(partitions)
    }

    fn resume(&self, partitions: &TopicPartitionList) -> KafkaResult<()> {
        self.base.resume(partitions)
    }

    fn rebalance_protocol(&self) -> RebalanceProtocol {
        self.base.rebalance_protocol()
    }
}
/// A message queue for a single partition, split off a [`StreamConsumer`]
/// and exposed with the same `stream`/`recv` interface. Has its own waker
/// slab wired to this queue's nonempty callback.
pub struct StreamPartitionQueue<C, R = DefaultRuntime>
where
    C: ConsumerContext,
{
    queue: PartitionQueue<C>,
    wakers: Arc<WakerSlab>,
    // Keeps the parent consumer alive while the split queue exists.
    _consumer: Arc<StreamConsumer<C, R>>,
}
impl<C, R> StreamPartitionQueue<C, R>
where
    C: ConsumerContext,
{
    /// Returns a stream over this single partition queue, backed by the
    /// queue's own waker slab.
    pub fn stream(&self) -> MessageStream<'_, C> {
        MessageStream::new_with_partition_queue(
            &self.wakers,
            &self._consumer.base,
            &self.queue.queue,
        )
    }

    /// Awaits the next message (or error) from this partition.
    ///
    /// The underlying stream never yields `None`, hence the `expect`.
    pub async fn recv(&self) -> Result<BorrowedMessage<'_>, KafkaError> {
        self.stream()
            .next()
            .await
            .expect("kafka streams never terminate")
    }
}
impl<C, R> Drop for StreamPartitionQueue<C, R>
where
    C: ConsumerContext,
{
    fn drop(&mut self) {
        // Unregister the nonempty callback before our WakerSlab Arc can be
        // dropped, so librdkafka never dereferences a dangling pointer.
        unsafe { disable_nonempty_callback(&self.queue.queue) }
    }
}