#![cfg(feature = "rabbitmq")]
use std::ops::RangeInclusive;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use crate::backend::ConsumerOptionsInner as ConsumerOptions;
use crate::backends::rabbitmq::client::RabbitMqClient;
use crate::backends::rabbitmq::consumer::RabbitMqConsumer;
use crate::consumer::{HandlerTimeoutConfig, resolve_handler_timeout};
use crate::consumer_supervisor::ShutdownTally;
use crate::handler::MessageHandler;
use crate::topic::{SequencedTopic, Topic};
use crate::{DEFAULT_MAX_MESSAGE_SIZE, DEFAULT_MAX_PENDING_PER_KEY};
/// Shared, type-erased factory that takes the `ConsumerOptions` for one
/// consumer and returns the `JoinHandle` of the task it spawned for it.
///
/// The closure is stored behind an `Arc` and must be `Send + Sync`.
type Spawner = Arc<dyn Fn(ConsumerOptions) -> JoinHandle<()> + Send + Sync>;
/// Settings for a `ConsumerGroup`: how many consumers it may run and the
/// per-consumer options handed to every consumer it spawns.
#[derive(Clone)]
pub struct ConsumerGroupConfig {
/// Prefetch count copied into each spawned consumer's options.
pub(crate) prefetch_count: u16,
/// Number of consumers spawned by `ConsumerGroup::start`, and the size
/// below which `ConsumerGroup::scale_down` refuses to shrink the group.
pub(crate) min_consumers: u16,
/// Size at which `ConsumerGroup::scale_up` refuses to grow the group.
pub(crate) max_consumers: u16,
/// Retry limit copied into each spawned consumer's options.
pub(crate) max_retries: u32,
/// Handler timeout setting; resolved through `resolve_handler_timeout`
/// before it is handed to a consumer.
pub(crate) handler_timeout: HandlerTimeoutConfig,
/// When `false`, consumers created through `ConsumerGroup::new` have their
/// prefetch count overridden to 1.
pub(crate) concurrent_processing: bool,
/// Copied into each spawned consumer's options as `max_pending_per_key`.
pub(crate) max_pending_per_key: Option<usize>,
/// Copied into each spawned consumer's options as `max_message_size`.
pub(crate) max_message_size: Option<usize>,
}
impl ConsumerGroupConfig {
    /// Creates a configuration whose consumer count is bounded by `range`.
    ///
    /// The start of the range becomes `min_consumers` and the end becomes
    /// `max_consumers`. Every other setting starts from its default: a
    /// prefetch count of 10, 10 retries, an inherited handler timeout,
    /// concurrent processing disabled, and the crate-level default limits for
    /// pending messages per key and for message size.
    ///
    /// # Panics
    ///
    /// Panics if the start of `range` is greater than its end.
    pub fn new(range: RangeInclusive<u16>) -> Self {
        let (min, max) = range.into_inner();
        assert!(
            min <= max,
            "min_consumers ({min}) must be <= max_consumers ({max})"
        );
        Self {
            min_consumers: min,
            max_consumers: max,
            prefetch_count: 10,
            max_retries: 10,
            concurrent_processing: false,
            handler_timeout: HandlerTimeoutConfig::Inherit,
            max_pending_per_key: Some(DEFAULT_MAX_PENDING_PER_KEY),
            max_message_size: Some(DEFAULT_MAX_MESSAGE_SIZE),
        }
    }

    // ---- builders -------------------------------------------------------

    /// Returns the configuration with `prefetch_count` replaced.
    pub fn with_prefetch_count(self, prefetch_count: u16) -> Self {
        Self {
            prefetch_count,
            ..self
        }
    }

    /// Returns the configuration with `max_retries` replaced.
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self {
            max_retries,
            ..self
        }
    }

    /// Returns the configuration with an explicit handler timeout, replacing
    /// the inherited setting.
    ///
    /// # Panics
    ///
    /// Panics if `timeout` is zero.
    pub fn with_handler_timeout(self, timeout: Duration) -> Self {
        assert!(!timeout.is_zero(), "handler_timeout must be positive");
        Self {
            handler_timeout: HandlerTimeoutConfig::Set(timeout),
            ..self
        }
    }

    /// Returns the configuration with concurrent processing switched on or
    /// off.
    pub fn with_concurrent_processing(self, concurrent: bool) -> Self {
        Self {
            concurrent_processing: concurrent,
            ..self
        }
    }

    // ---- accessors ------------------------------------------------------

    /// The configured prefetch count.
    pub fn prefetch_count(&self) -> u16 {
        self.prefetch_count
    }

    /// The lower bound on the number of consumers.
    pub fn min_consumers(&self) -> u16 {
        self.min_consumers
    }

    /// The upper bound on the number of consumers.
    pub fn max_consumers(&self) -> u16 {
        self.max_consumers
    }

    /// The configured retry limit.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// The handler timeout as resolved by `resolve_handler_timeout` with no
    /// registry-level default supplied.
    ///
    /// This always returns `Some`.
    pub fn handler_timeout(&self) -> Option<Duration> {
        let resolved = resolve_handler_timeout(self.handler_timeout, None);
        Some(resolved)
    }

    /// Whether concurrent processing is enabled.
    pub fn concurrent_processing(&self) -> bool {
        self.concurrent_processing
    }
}
impl Default for ConsumerGroupConfig {
fn default() -> Self {
Self::new(1..=4)
}
}
/// A named set of consumer tasks attached to one queue, which can be grown
/// and shrunk between the bounds in its `ConsumerGroupConfig`.
pub struct ConsumerGroup {
/// Name used in log fields and passed to consumers as their group name.
name: String,
/// Queue name reported by `queue()` and included in the start-up log.
queue: String,
/// Scaling bounds and per-consumer options.
config: ConsumerGroupConfig,
/// Factory that spawns the task for one consumer from its options.
spawner: Spawner,
/// One entry per tracked consumer: its child cancellation token, its
/// shared "processing" flag, and the handle of its spawned task.
consumers: Vec<(CancellationToken, Arc<AtomicBool>, JoinHandle<()>)>,
/// Parent token; every consumer token is a child of it, and it is
/// cancelled on shutdown.
group_token: CancellationToken,
/// Count of consumer tasks that ended with an error; read and reset on
/// shutdown.
error_count: Arc<AtomicUsize>,
/// Count of panics recorded by the spawner itself (used by the FIFO
/// spawner); read and reset on shutdown.
panic_count: Arc<AtomicUsize>,
}
impl ConsumerGroup {
/// Builds a consumer group for topic `T` whose consumers use handlers
/// produced by `handler_factory`.
///
/// No consumer is spawned here; this only prepares the spawner that
/// `spawn_one` invokes later. For every consumer the spawner:
///
/// * forces the prefetch count to 1 unless `config.concurrent_processing`
///   is enabled,
/// * creates a fresh handler and a `RabbitMqConsumer` over a clone of
///   `client`,
/// * runs the consumer on a new tokio task, logging and counting an error
///   if the run returns `Err`.
///
/// `ctx` is cloned once per spawned consumer.
pub fn new<T, H>(
    name: impl Into<String>,
    queue: impl Into<String>,
    config: ConsumerGroupConfig,
    client: RabbitMqClient,
    group_token: CancellationToken,
    handler_factory: impl Fn() -> H + Send + Sync + 'static,
    ctx: H::Context,
) -> Self
where
    T: Topic + 'static,
    H: MessageHandler<T> + 'static,
{
    let error_count = Arc::new(AtomicUsize::new(0));
    let task_errors = Arc::clone(&error_count);
    let sequential = !config.concurrent_processing;

    let spawner: Spawner = Arc::new(move |mut options: ConsumerOptions| {
        // Sequential groups override whatever prefetch count was requested.
        if sequential {
            options.prefetch_count = 1;
        }
        let handler = handler_factory();
        let consumer = RabbitMqConsumer::new(client.clone());
        let errors = Arc::clone(&task_errors);
        let ctx = ctx.clone();
        tokio::spawn(async move {
            let outcome = consumer.run_with_inner::<T, H>(handler, ctx, options).await;
            if let Err(e) = outcome {
                errors.fetch_add(1, Ordering::Relaxed);
                tracing::error!("consumer task exited with error: {e}");
            }
        })
    });

    let capacity = config.max_consumers as usize;
    Self {
        name: name.into(),
        queue: queue.into(),
        config,
        spawner,
        consumers: Vec::with_capacity(capacity),
        group_token,
        error_count,
        panic_count: Arc::new(AtomicUsize::new(0)),
    }
}
pub fn new_fifo<T, H>(
queue: impl Into<String>,
client: RabbitMqClient,
mut config: ConsumerGroupConfig,
group_token: CancellationToken,
handler_factory: impl Fn() -> H + Send + Sync + 'static,
ctx: H::Context,
) -> Self
where
T: SequencedTopic + 'static,
H: MessageHandler<T> + 'static,
{
let error_count = Arc::new(AtomicUsize::new(0));
let panic_count = Arc::new(AtomicUsize::new(0));
let ec_for_spawner = error_count.clone();
let pc_for_spawner = panic_count.clone();
config.min_consumers = 1;
config.max_consumers = 1;
let spawner: Spawner = Arc::new(move |options: ConsumerOptions| {
let handler = handler_factory();
let consumer = RabbitMqConsumer::new(client.clone());
let ec = ec_for_spawner.clone();
let pc = pc_for_spawner.clone();
let ctx = ctx.clone();
tokio::spawn(async move {
let handles = match consumer.spawn_fifo_shards::<T, H>(handler, ctx, options) {
Ok(h) => h,
Err(e) => {
ec.fetch_add(1, Ordering::Relaxed);
tracing::error!("FIFO registration failed: {e}");
return;
}
};
for handle in handles {
match handle.await {
Ok(Ok(())) => {}
Ok(Err(e)) => {
ec.fetch_add(1, Ordering::Relaxed);
tracing::error!("sequenced shard exited with error: {e}");
}
Err(e) if e.is_cancelled() => {}
Err(e) => {
pc.fetch_add(1, Ordering::Relaxed);
tracing::error!("sequenced shard panicked: {e}");
}
}
}
})
});
let queue_str: String = queue.into();
Self {
name: queue_str.clone(),
queue: queue_str,
consumers: Vec::with_capacity(1),
config,
spawner,
group_token,
error_count,
panic_count,
}
}
/// Spawns the initial `min_consumers` consumers.
///
/// Each call spawns that many consumers again; it does not check how many
/// are already running.
pub fn start(&mut self) {
    let initial_consumers = usize::from(self.config.min_consumers);
    info!(
        group = %self.name,
        queue = %self.queue,
        initial_consumers,
        "starting consumer group"
    );
    (0..initial_consumers).for_each(|_| self.spawn_one());
}
/// Adds one consumer unless the group already tracks `max_consumers`.
///
/// Returns `true` if a consumer was spawned and `false` if the group was
/// already at capacity.
pub fn scale_up(&mut self) -> bool {
    let limit = usize::from(self.config.max_consumers);
    if self.consumers.len() < limit {
        self.spawn_one();
        info!(
            group = %self.name,
            consumers = self.consumers.len(),
            "scaled up: spawned new consumer"
        );
        true
    } else {
        debug!(group = %self.name, max = self.config.max_consumers, "scale_up rejected: at max capacity");
        false
    }
}
/// Stops one idle consumer, provided the group is above `min_consumers`.
///
/// Returns `true` when a consumer was cancelled and removed from the group.
/// Returns `false`, leaving the group untouched, when the group is already
/// at its minimum size or when every tracked consumer has its `processing`
/// flag set.
///
/// Consumers are searched from the back, so the idle consumer with the
/// highest index is the one removed. Only its cancellation token is
/// triggered; its `JoinHandle` is dropped here, so the task is no longer
/// tracked by the group and is not awaited during shutdown.
pub fn scale_down(&mut self) -> bool {
    if self.consumers.len() <= self.config.min_consumers as usize {
        debug!(group = %self.name, min = self.config.min_consumers, "scale_down rejected: at min capacity");
        return false;
    }
    // `Acquire` pairs with a `Release` store of the flag, so a consumer that
    // has just marked itself busy is not mistaken for idle on the strength
    // of an unordered read. `Relaxed` provides no such pairing.
    let idle_index = self
        .consumers
        .iter()
        .rposition(|(_, processing, _)| !processing.load(Ordering::Acquire));
    let Some(index) = idle_index else {
        warn!(group = %self.name, "scale_down rejected: all consumers are busy");
        return false;
    };
    // `swap_remove` moves the last entry into the vacated slot instead of
    // shifting every later entry down.
    let (token, _, _handle) = self.consumers.swap_remove(index);
    token.cancel();
    info!(
        group = %self.name,
        consumers = self.consumers.len(),
        "scaled down: cancelled an idle consumer"
    );
    true
}
/// Number of consumers currently tracked by the group.
///
/// Consumers removed by `scale_down` or drained by shutdown are not counted.
pub fn active_consumers(&self) -> usize {
self.consumers.len()
}
/// Name of the queue this group was created for.
pub fn queue(&self) -> &str {
&self.queue
}
/// The configuration this group was created with.
pub fn config(&self) -> &ConsumerGroupConfig {
&self.config
}
/// Shuts the group down and waits for its tracked consumers to finish,
/// discarding the error/panic tally.
pub async fn shutdown(&mut self) {
    let _tally = self.shutdown_with_tally().await;
}
pub(crate) async fn shutdown_with_tally(&mut self) -> ShutdownTally {
info!(group = %self.name, consumers = self.consumers.len(), "shutting down consumer group");
self.group_token.cancel();
let mut panics = 0usize;
for (_token, _processing, handle) in self.consumers.drain(..) {
match handle.await {
Ok(()) => {}
Err(e) if e.is_cancelled() => {}
Err(e) => {
tracing::error!(error = %e, group = %self.name, "consumer task panicked");
panics += 1;
}
}
}
let errors = self.error_count.swap(0, Ordering::Relaxed);
let panics = panics + self.panic_count.swap(0, Ordering::Relaxed);
debug!(group = %self.name, errors, panics, "consumer group shutdown complete");
ShutdownTally { errors, panics }
}
/// Spawns one consumer and records it in `consumers`.
///
/// The consumer gets a child of the group token as its shutdown signal, a
/// fresh `processing` flag initialised to `false`, and options filled in
/// from the group configuration and name. This does not check
/// `max_consumers`; callers are responsible for the bound.
fn spawn_one(&mut self) {
    let child_token = self.group_token.child_token();
    let processing = Arc::new(AtomicBool::new(false));

    let cfg = &self.config;
    let mut options = ConsumerOptions::defaults_with_shutdown(child_token.clone());
    options.processing = Arc::clone(&processing);
    options.consumer_group = Some(Arc::from(self.name.as_str()));
    options.prefetch_count = cfg.prefetch_count;
    options.max_retries = cfg.max_retries;
    options.handler_timeout = Some(resolve_handler_timeout(cfg.handler_timeout, None));
    options.max_pending_per_key = cfg.max_pending_per_key;
    options.max_message_size = cfg.max_message_size;

    let handle = (self.spawner)(options);
    // The new entry lands at the current length, which is its index.
    let consumer_index = self.consumers.len();
    self.consumers.push((child_token, processing, handle));
    debug!(group = %self.name, consumer_index, "spawned consumer");
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consumer::DEFAULT_HANDLER_TIMEOUT;

    /// Builds a group whose "consumers" are tasks that simply wait for their
    /// shutdown token, so scaling logic can be tested without a broker.
    fn test_group(config: ConsumerGroupConfig) -> ConsumerGroup {
        let spawner: Spawner = Arc::new(|options: ConsumerOptions| {
            tokio::spawn(async move {
                options.shutdown.cancelled().await;
            })
        });
        let capacity = config.max_consumers as usize;
        ConsumerGroup {
            name: "test-group".into(),
            queue: "test-queue".into(),
            config,
            spawner,
            consumers: Vec::with_capacity(capacity),
            group_token: CancellationToken::new(),
            error_count: Arc::new(AtomicUsize::new(0)),
            panic_count: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// The 1..=4 configuration most tests start from.
    fn default_config() -> ConsumerGroupConfig {
        ConsumerGroupConfig::new(1..=4)
    }

    // ---- scaling behaviour (needs a runtime because consumers are tasks) ----

    #[tokio::test]
    async fn start_spawns_min_consumers() {
        let mut group = test_group(ConsumerGroupConfig::new(3..=5));
        group.start();
        assert_eq!(group.active_consumers(), 3);
        group.shutdown().await;
    }

    #[tokio::test]
    async fn start_with_zero_min_spawns_nothing() {
        let mut group = test_group(ConsumerGroupConfig::new(0..=4));
        group.start();
        assert_eq!(group.active_consumers(), 0);
        group.shutdown().await;
    }

    #[tokio::test]
    async fn scale_up_adds_one_consumer() {
        let mut group = test_group(default_config());
        group.start();
        assert_eq!(group.active_consumers(), 1);
        assert!(group.scale_up());
        assert_eq!(group.active_consumers(), 2);
        group.shutdown().await;
    }

    #[tokio::test]
    async fn scale_up_rejected_at_max() {
        let mut group = test_group(ConsumerGroupConfig::new(2..=2));
        group.start();
        assert_eq!(group.active_consumers(), 2);
        assert!(!group.scale_up());
        assert_eq!(group.active_consumers(), 2);
        group.shutdown().await;
    }

    #[tokio::test]
    async fn scale_down_removes_one_consumer() {
        let mut group = test_group(default_config());
        group.start();
        group.scale_up();
        assert_eq!(group.active_consumers(), 2);
        assert!(group.scale_down());
        assert_eq!(group.active_consumers(), 1);
        group.shutdown().await;
    }

    #[tokio::test]
    async fn scale_down_rejected_at_min() {
        let mut group = test_group(default_config());
        group.start();
        assert_eq!(group.active_consumers(), 1);
        assert!(!group.scale_down());
        assert_eq!(group.active_consumers(), 1);
        group.shutdown().await;
    }

    #[tokio::test]
    async fn scale_down_skips_busy_consumers() {
        let mut group = test_group(ConsumerGroupConfig::new(0..=3));
        group.scale_up();
        group.scale_up();
        group.scale_up();
        assert_eq!(group.active_consumers(), 3);
        // Mark every consumer as busy.
        group
            .consumers
            .iter()
            .for_each(|(_, busy, _)| busy.store(true, Ordering::Release));
        assert!(!group.scale_down());
        assert_eq!(group.active_consumers(), 3);
        group.shutdown().await;
    }

    #[tokio::test]
    async fn scale_down_picks_idle_when_some_busy() {
        let mut group = test_group(ConsumerGroupConfig::new(0..=3));
        group.scale_up();
        group.scale_up();
        group.scale_up();
        assert_eq!(group.active_consumers(), 3);
        // Only the middle consumer is left idle.
        group.consumers[0].1.store(true, Ordering::Release);
        group.consumers[2].1.store(true, Ordering::Release);
        let idle_token_ptr = Arc::as_ptr(&group.consumers[1].1);
        assert!(group.scale_down());
        assert_eq!(group.active_consumers(), 2);
        for (_, remaining, _) in &group.consumers {
            assert_ne!(Arc::as_ptr(remaining), idle_token_ptr);
        }
        group.shutdown().await;
    }

    #[tokio::test]
    async fn scale_down_cancels_token() {
        let mut group = test_group(ConsumerGroupConfig::new(0..=2));
        group.scale_up();
        group.scale_up();
        let doomed_token = group.consumers[1].0.clone();
        assert!(!doomed_token.is_cancelled());
        group.scale_down();
        assert!(doomed_token.is_cancelled());
        group.shutdown().await;
    }

    #[tokio::test]
    async fn shutdown_cancels_group_token() {
        let mut group = test_group(default_config());
        let observed = group.group_token.clone();
        group.start();
        group.scale_up();
        assert!(!observed.is_cancelled());
        group.shutdown().await;
        assert!(observed.is_cancelled());
        assert_eq!(group.active_consumers(), 0);
    }

    #[tokio::test]
    async fn spawned_consumers_start_idle() {
        let mut group = test_group(default_config());
        group.scale_up();
        let (_, busy, _) = &group.consumers[0];
        assert!(!busy.load(Ordering::Acquire));
        group.shutdown().await;
    }

    // ---- accessors ----

    #[test]
    fn queue_returns_configured_queue() {
        let group = test_group(default_config());
        assert_eq!(group.queue(), "test-queue");
    }

    #[test]
    fn config_returns_reference() {
        let custom = ConsumerGroupConfig::new(2..=8)
            .with_prefetch_count(5)
            .with_max_retries(3)
            .with_handler_timeout(Duration::from_secs(30));
        let group = test_group(custom);
        let seen = group.config();
        assert_eq!(seen.min_consumers(), 2);
        assert_eq!(seen.max_consumers(), 8);
        assert_eq!(seen.prefetch_count(), 5);
        assert_eq!(seen.max_retries(), 3);
        assert_eq!(seen.handler_timeout(), Some(Duration::from_secs(30)));
    }

    // ---- ConsumerGroupConfig construction and builders ----

    #[test]
    fn new_with_valid_range() {
        let cfg = ConsumerGroupConfig::new(2..=8);
        assert_eq!(cfg.min_consumers(), 2);
        assert_eq!(cfg.max_consumers(), 8);
    }

    #[test]
    fn new_sets_defaults() {
        let cfg = ConsumerGroupConfig::new(1..=4);
        assert_eq!(cfg.prefetch_count(), 10);
        assert_eq!(cfg.max_retries(), 10);
        assert_eq!(cfg.handler_timeout(), Some(DEFAULT_HANDLER_TIMEOUT));
    }

    #[test]
    fn new_with_equal_min_max() {
        let cfg = ConsumerGroupConfig::new(3..=3);
        assert_eq!(cfg.min_consumers(), 3);
        assert_eq!(cfg.max_consumers(), 3);
    }

    #[test]
    #[should_panic]
    #[allow(clippy::reversed_empty_ranges)]
    fn new_panics_if_min_greater_than_max() {
        let _ = ConsumerGroupConfig::new(5..=2);
    }

    #[test]
    fn with_prefetch_count_sets_value() {
        let cfg = ConsumerGroupConfig::new(1..=4).with_prefetch_count(25);
        assert_eq!(cfg.prefetch_count(), 25);
    }

    #[test]
    fn with_max_retries_sets_value() {
        let cfg = ConsumerGroupConfig::new(1..=4).with_max_retries(5);
        assert_eq!(cfg.max_retries(), 5);
    }

    #[test]
    fn with_handler_timeout_sets_value() {
        let cfg = ConsumerGroupConfig::new(1..=4).with_handler_timeout(Duration::from_secs(60));
        assert_eq!(cfg.handler_timeout(), Some(Duration::from_secs(60)));
    }

    #[test]
    fn builder_chaining_sets_all_values() {
        let cfg = ConsumerGroupConfig::new(1..=5)
            .with_prefetch_count(20)
            .with_max_retries(3)
            .with_handler_timeout(Duration::from_secs(30));
        assert_eq!(cfg.min_consumers(), 1);
        assert_eq!(cfg.max_consumers(), 5);
        assert_eq!(cfg.prefetch_count(), 20);
        assert_eq!(cfg.max_retries(), 3);
        assert_eq!(cfg.handler_timeout(), Some(Duration::from_secs(30)));
    }

    #[test]
    fn concurrent_processing_defaults_to_false() {
        let cfg = ConsumerGroupConfig::new(1..=4);
        assert!(!cfg.concurrent_processing());
    }

    #[test]
    fn with_concurrent_processing_sets_value() {
        let cfg = ConsumerGroupConfig::new(1..=4).with_concurrent_processing(true);
        assert!(cfg.concurrent_processing());
    }

    #[test]
    fn with_concurrent_processing_false_explicit() {
        let cfg = ConsumerGroupConfig::new(1..=4)
            .with_concurrent_processing(true)
            .with_concurrent_processing(false);
        assert!(!cfg.concurrent_processing());
    }

    #[test]
    fn builder_chaining_with_concurrent_processing() {
        let cfg = ConsumerGroupConfig::new(1..=8)
            .with_prefetch_count(20)
            .with_max_retries(3)
            .with_handler_timeout(Duration::from_secs(30))
            .with_concurrent_processing(true);
        assert_eq!(cfg.min_consumers(), 1);
        assert_eq!(cfg.max_consumers(), 8);
        assert_eq!(cfg.prefetch_count(), 20);
        assert_eq!(cfg.max_retries(), 3);
        assert_eq!(cfg.handler_timeout(), Some(Duration::from_secs(30)));
        assert!(cfg.concurrent_processing());
    }

    // ---- handler-timeout resolution ----

    #[test]
    fn inherit_config_uses_library_default_with_no_registry_default() {
        let cfg = ConsumerGroupConfig::new(1..=4);
        let resolved = resolve_handler_timeout(cfg.handler_timeout, None);
        assert_eq!(resolved, DEFAULT_HANDLER_TIMEOUT);
    }

    #[test]
    fn inherit_config_uses_registry_default_when_set() {
        let cfg = ConsumerGroupConfig::new(1..=4);
        let resolved = resolve_handler_timeout(cfg.handler_timeout, Some(Duration::from_secs(45)));
        assert_eq!(resolved, Duration::from_secs(45));
    }

    #[test]
    fn with_handler_timeout_beats_registry_default() {
        let cfg = ConsumerGroupConfig::new(1..=4).with_handler_timeout(Duration::from_secs(5));
        let resolved = resolve_handler_timeout(cfg.handler_timeout, Some(Duration::from_secs(45)));
        assert_eq!(resolved, Duration::from_secs(5));
    }

    #[test]
    #[should_panic(expected = "handler_timeout must be positive")]
    fn with_handler_timeout_zero_panics() {
        let _ = ConsumerGroupConfig::new(1..=4).with_handler_timeout(Duration::ZERO);
    }
}