use crate::{
errors::{DanubeError, Result},
retry_manager::RetryManager,
topic_consumer::TopicConsumer,
DanubeClient,
};
use danube_core::message::StreamMessage;
use futures::{future::join_all, StreamExt};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use tracing::{error, info, warn};
/// Capacity of the bounded channel returned by `Consumer::receive`; partition receive
/// loops wait in `send().await` while the channel is full.
const RECEIVE_CHANNEL_BUFFER: usize = 100;
/// Time, in milliseconds, that `Consumer::close` sleeps after stopping the partition
/// consumers and aborting the receive tasks.
const GRACEFUL_CLOSE_DELAY_MS: u64 = 100;
/// Subscription type a [`Consumer`] requests for each partition it subscribes to.
///
/// When no type is given, `Consumer::new` falls back to [`SubType::Shared`]. The selected
/// variant is passed unchanged to every partition's `TopicConsumer`.
/// NOTE(review): per-variant delivery semantics are defined outside this module — document
/// them here once confirmed.
///
/// Derives `PartialEq`, `Eq` and `Hash` so the type can be compared (e.g. in configuration
/// checks or tests) and used as a map/set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubType {
    Exclusive,
    Shared,
    FailOver,
    KeyShared,
}
/// Consumer for a topic that may be split into several partitions.
///
/// After `subscribe`, it holds one `TopicConsumer` per partition. `receive` fans the
/// messages of all partitions into a single channel.
#[derive(Debug)]
pub struct Consumer {
    /// Client handle; cloned into every partition consumer.
    client: DanubeClient,
    /// Topic as given by the user; `subscribe` resolves its partitions via the lookup service.
    topic_name: String,
    consumer_name: String,
    /// Partition consumers keyed by their partition topic name (`get_topic_name()`);
    /// `ack`/`nack` route a message by its `msg_id.topic_name` through this map.
    consumers: HashMap<String, Arc<Mutex<TopicConsumer>>>,
    subscription: String,
    subscription_type: SubType,
    /// Retry/backoff settings used to build the `RetryManager` in `receive`.
    consumer_options: ConsumerOptions,
    /// Key filter patterns passed to every partition consumer.
    key_filters: Vec<String>,
    /// Set by `close`; checked by every partition receive loop.
    shutdown: Arc<AtomicBool>,
    /// Handles of the receive loops spawned by `receive`; aborted by `close`.
    task_handles: Vec<JoinHandle<()>>,
}
impl Consumer {
pub(crate) fn new(
client: DanubeClient,
topic_name: String,
consumer_name: String,
subscription: String,
sub_type: Option<SubType>,
consumer_options: ConsumerOptions,
key_filters: Vec<String>,
) -> Self {
let subscription_type = sub_type.unwrap_or(SubType::Shared);
Consumer {
client,
topic_name,
consumer_name,
consumers: HashMap::new(),
subscription,
subscription_type,
consumer_options,
key_filters,
shutdown: Arc::new(AtomicBool::new(false)),
task_handles: Vec::new(),
}
}
pub async fn subscribe(&mut self) -> Result<()> {
let partitions = self
.client
.lookup_service
.topic_partitions(&self.client.uri, &self.topic_name)
.await?;
let mut tasks = Vec::new();
for topic_partition in partitions {
let topic_name = topic_partition.clone();
let consumer_name = self.consumer_name.clone();
let subscription = self.subscription.clone();
let subscription_type = self.subscription_type.clone();
let consumer_options = self.consumer_options.clone();
let key_filters = self.key_filters.clone();
let client = self.client.clone();
let task = tokio::spawn(async move {
let mut topic_consumer = TopicConsumer::new(
client,
topic_name,
consumer_name,
subscription,
Some(subscription_type),
key_filters,
consumer_options,
);
match topic_consumer.subscribe().await {
Ok(_) => Ok(topic_consumer),
Err(e) => Err(e),
}
});
tasks.push(task);
}
let results = join_all(tasks).await;
let mut topic_consumers = HashMap::new();
for result in results {
match result {
Ok(Ok(consumer)) => {
topic_consumers.insert(
consumer.get_topic_name().to_string(),
Arc::new(Mutex::new(consumer)),
);
}
Ok(Err(e)) => return Err(e),
Err(e) => return Err(DanubeError::Unrecoverable(e.to_string())),
}
}
if topic_consumers.is_empty() {
return Err(DanubeError::Unrecoverable(
"No partitions found".to_string(),
));
}
self.consumers.extend(topic_consumers.into_iter());
Ok(())
}
pub async fn receive(&mut self) -> Result<mpsc::Receiver<StreamMessage>> {
let (tx, rx) = mpsc::channel(RECEIVE_CHANNEL_BUFFER);
let retry_manager = RetryManager::new(
self.consumer_options.max_retries,
self.consumer_options.base_backoff_ms,
self.consumer_options.max_backoff_ms,
);
for (_, consumer) in &self.consumers {
let broker_stop = {
let locked = consumer.lock().await;
Arc::clone(&locked.stop_signal)
};
let handle = tokio::spawn(partition_receive_loop(
Arc::clone(consumer),
tx.clone(),
retry_manager.clone(),
self.shutdown.clone(),
broker_stop,
));
self.task_handles.push(handle);
}
Ok(rx)
}
pub async fn ack(&mut self, message: &StreamMessage) -> Result<()> {
let topic_name = message.msg_id.topic_name.clone();
let topic_consumer = self.consumers.get_mut(&topic_name);
if let Some(topic_consumer) = topic_consumer {
let mut topic_consumer = topic_consumer.lock().await;
let _ = topic_consumer
.send_ack(
message.request_id,
message.msg_id.clone(),
&self.subscription,
)
.await?;
}
Ok(())
}
pub async fn nack(
&mut self,
message: &StreamMessage,
delay_ms: Option<u64>,
reason: Option<String>,
) -> Result<()> {
let topic_name = message.msg_id.topic_name.clone();
let topic_consumer = self.consumers.get_mut(&topic_name);
if let Some(topic_consumer) = topic_consumer {
let mut topic_consumer = topic_consumer.lock().await;
let _ = topic_consumer
.send_nack(
message.request_id,
message.msg_id.clone(),
&self.subscription,
delay_ms,
reason,
)
.await?;
}
Ok(())
}
pub async fn close(&mut self) {
self.shutdown.store(true, Ordering::SeqCst);
for (_, topic_consumer) in self.consumers.iter() {
let locked = topic_consumer.lock().await;
locked.stop();
}
for handle in self.task_handles.drain(..) {
handle.abort();
}
tokio::time::sleep(std::time::Duration::from_millis(GRACEFUL_CLOSE_DELAY_MS)).await;
}
}
/// Receive loop for a single partition consumer; `Consumer::receive` runs it as its own
/// Tokio task.
///
/// It repeatedly opens a message stream with `TopicConsumer::receive` and forwards every
/// message into `tx`. The result of opening the stream is handled in this order:
/// - `Ok(stream)`: the retry counter is reset. Messages are forwarded until the stream ends
///   or yields an error, or until `shutdown` / `broker_stop` is seen set. If `broker_stop`
///   was set, it is cleared and the partition is resubscribed.
/// - `Err(DanubeError::Unrecoverable(_))`: the partition is resubscribed immediately.
/// - An error that `retry_manager.is_retryable_error` accepts: sleep for the computed
///   backoff. Once the failure count exceeds `retry_manager.max_retries()`, resubscribe
///   instead.
/// - Any other error: logged, and the loop exits.
///
/// The loop also exits when `shutdown` is set, when the receiving half of `tx` has been
/// dropped, or when a resubscription attempt fails.
async fn partition_receive_loop(
    consumer: Arc<Mutex<TopicConsumer>>,
    tx: mpsc::Sender<StreamMessage>,
    retry_manager: RetryManager,
    shutdown: Arc<AtomicBool>,
    broker_stop: Arc<AtomicBool>,
) {
    // Retryable failures since the last successful stream open or resubscription.
    let mut attempts = 0;
    loop {
        if shutdown.load(Ordering::SeqCst) {
            return;
        }
        // The lock is held only while the stream is being opened, so `ack`/`nack`/`close`
        // can take it while messages are being consumed.
        let stream_result = {
            let mut locked = consumer.lock().await;
            locked.receive().await
        };
        match stream_result {
            Ok(mut stream) => {
                attempts = 0;
                // The flags are re-checked only between messages; a `next()` that is still
                // waiting is not interrupted by them (`Consumer::close` also aborts this task).
                while !shutdown.load(Ordering::SeqCst) && !broker_stop.load(Ordering::Relaxed) {
                    match stream.next().await {
                        Some(Ok(stream_message)) => {
                            let message: StreamMessage = stream_message.into();
                            // Sending fails only when the receiver was dropped: nobody is
                            // consuming any more, so end this task.
                            if tx.send(message).await.is_err() {
                                return; }
                        }
                        // A stream error or end of stream leaves this inner loop. Unless
                        // `shutdown` or `broker_stop` is set, the outer loop then reopens the
                        // stream.
                        // NOTE(review): no backoff is applied on this path. If reopening keeps
                        // succeeding while the stream fails at once, the loop spins — confirm
                        // this is acceptable.
                        Some(Err(e)) => {
                            warn!(error = %e, "error receiving message");
                            break; }
                        None => break, }
                }
                if shutdown.load(Ordering::SeqCst) {
                    return;
                }
                if broker_stop.load(Ordering::Relaxed) {
                    // Clear the flag first so the next stream is not ended immediately by
                    // the stale signal in the `while` condition above.
                    broker_stop.store(false, Ordering::Relaxed);
                    warn!("broker signaled topic close, triggering resubscription");
                    match resubscribe(&consumer).await {
                        Ok(_) => {
                            info!("resubscription successful after broker close signal");
                            continue;
                        }
                        Err(e) => {
                            error!(error = ?e, "resubscription failed after broker close signal");
                            return;
                        }
                    }
                }
            }
            Err(ref error) if matches!(error, DanubeError::Unrecoverable(_)) => {
                if shutdown.load(Ordering::SeqCst) {
                    return;
                }
                warn!(error = ?error, "unrecoverable error, attempting resubscription");
                match resubscribe(&consumer).await {
                    Ok(_) => {
                        info!("resubscription successful after unrecoverable error");
                        attempts = 0;
                        continue;
                    }
                    Err(e) => {
                        error!(error = ?e, "resubscription failed after unrecoverable error");
                        return;
                    }
                }
            }
            Err(error) if retry_manager.is_retryable_error(&error) => {
                if shutdown.load(Ordering::SeqCst) {
                    return;
                }
                attempts += 1;
                // Too many retries in a row: fall back to a full resubscription.
                if attempts > retry_manager.max_retries() {
                    warn!("max retries exceeded, attempting resubscription");
                    match resubscribe(&consumer).await {
                        Ok(_) => {
                            info!("resubscription successful");
                            attempts = 0;
                            continue;
                        }
                        Err(e) => {
                            error!(error = ?e, "resubscription failed");
                            return;
                        }
                    }
                }
                // `attempts` is at least 1 here, so the first retry uses backoff step 0.
                let backoff = retry_manager.calculate_backoff(attempts - 1);
                tokio::time::sleep(backoff).await;
            }
            Err(error) => {
                error!(error = ?error, "non-retryable error in consumer receive");
                return;
            }
        }
    }
}
/// Runs `TopicConsumer::subscribe` again on `consumer`, holding its lock for the duration.
///
/// # Errors
///
/// Returns whatever error `subscribe` reports; the success value is discarded.
async fn resubscribe(consumer: &Arc<Mutex<TopicConsumer>>) -> Result<()> {
    let mut guard = consumer.lock().await;
    guard.subscribe().await.map(|_| ())
}
/// Builder for [`Consumer`].
///
/// The topic, consumer name and subscription are required; `build` fails if any of them is
/// missing. When no subscription type is set, the consumer uses `SubType::Shared`.
#[derive(Debug, Clone)]
pub struct ConsumerBuilder {
    /// Client handle the built consumer will use.
    client: DanubeClient,
    /// Topic to consume from (required).
    topic: Option<String>,
    /// Name of this consumer (required).
    consumer_name: Option<String>,
    /// Subscription name (required).
    subscription: Option<String>,
    /// Optional subscription type; `None` means `SubType::Shared`.
    subscription_type: Option<SubType>,
    /// Retry/backoff settings; `ConsumerOptions::default()` unless overridden.
    consumer_options: ConsumerOptions,
    /// Key filter patterns accumulated by `with_key_filter` / `with_key_filters`.
    key_filters: Vec<String>,
}
impl ConsumerBuilder {
    /// Starts a builder bound to (a clone of) `client`.
    ///
    /// No topic, consumer name, subscription or subscription type is set yet. Options are
    /// `ConsumerOptions::default()` and the key filter list is empty.
    pub fn new(client: &DanubeClient) -> Self {
        Self {
            client: client.clone(),
            topic: None,
            consumer_name: None,
            subscription: None,
            subscription_type: None,
            consumer_options: ConsumerOptions::default(),
            key_filters: Vec::new(),
        }
    }

    /// Sets the topic to consume from (required).
    pub fn with_topic(self, topic: impl Into<String>) -> Self {
        Self {
            topic: Some(topic.into()),
            ..self
        }
    }

    /// Sets the consumer name (required).
    pub fn with_consumer_name(self, consumer_name: impl Into<String>) -> Self {
        Self {
            consumer_name: Some(consumer_name.into()),
            ..self
        }
    }

    /// Sets the subscription name (required).
    pub fn with_subscription(self, subscription_name: impl Into<String>) -> Self {
        Self {
            subscription: Some(subscription_name.into()),
            ..self
        }
    }

    /// Sets the subscription type; without it the consumer uses `SubType::Shared`.
    pub fn with_subscription_type(self, subscription_type: SubType) -> Self {
        Self {
            subscription_type: Some(subscription_type),
            ..self
        }
    }

    /// Replaces the retry/backoff options.
    pub fn with_options(self, options: ConsumerOptions) -> Self {
        Self {
            consumer_options: options,
            ..self
        }
    }

    /// Appends one key filter pattern to those already configured.
    pub fn with_key_filter(mut self, pattern: impl Into<String>) -> Self {
        let pattern: String = pattern.into();
        self.key_filters.push(pattern);
        self
    }

    /// Appends every pattern in `patterns`, in order, to the key filters already configured.
    pub fn with_key_filters(mut self, mut patterns: Vec<String>) -> Self {
        self.key_filters.append(&mut patterns);
        self
    }

    /// Builds the [`Consumer`]. The returned consumer is not yet subscribed.
    ///
    /// # Errors
    ///
    /// Returns `DanubeError::Unrecoverable` naming the first missing required field,
    /// checked in the order topic, consumer name, subscription.
    pub fn build(self) -> Result<Consumer> {
        let topic = match self.topic {
            Some(topic) => topic,
            None => {
                return Err(DanubeError::Unrecoverable(
                    "topic is required to build a Consumer".into(),
                ))
            }
        };
        let consumer_name = match self.consumer_name {
            Some(name) => name,
            None => {
                return Err(DanubeError::Unrecoverable(
                    "consumer name is required to build a Consumer".into(),
                ))
            }
        };
        let subscription = match self.subscription {
            Some(subscription) => subscription,
            None => {
                return Err(DanubeError::Unrecoverable(
                    "subscription is required to build a Consumer".into(),
                ))
            }
        };

        let consumer = Consumer::new(
            self.client,
            topic,
            consumer_name,
            subscription,
            self.subscription_type,
            self.consumer_options,
            self.key_filters,
        );
        Ok(consumer)
    }
}
/// Retry and backoff settings for a consumer's partition receive loops.
///
/// The values are passed to `RetryManager::new` when the consumer starts receiving.
/// NOTE(review): the derived `Default` sets every field to 0; confirm how `RetryManager`
/// interprets zero retries and zero backoff.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct ConsumerOptions {
    /// Retry budget handed to `RetryManager::new`. The receive loop resubscribes once
    /// retryable failures exceed `RetryManager::max_retries()`.
    pub max_retries: usize,
    /// Base backoff in milliseconds, handed to `RetryManager::new`.
    pub base_backoff_ms: u64,
    /// Backoff limit in milliseconds, handed to `RetryManager::new` (presumably a cap on the
    /// computed backoff — confirm).
    pub max_backoff_ms: u64,
}

impl ConsumerOptions {
    /// Builds options from explicit values.
    ///
    /// The struct is `#[non_exhaustive]`, so code outside this crate cannot use a struct
    /// literal and must go through this constructor (or `Default`).
    pub fn new(max_retries: usize, base_backoff_ms: u64, max_backoff_ms: u64) -> Self {
        Self {
            max_backoff_ms,
            base_backoff_ms,
            max_retries,
        }
    }
}