use std::future::Future;
use std::ops::RangeInclusive;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use crate::backend::ConsumerOptionsInner;
use crate::consumer::{
DEFAULT_MAX_MESSAGE_SIZE, DEFAULT_MAX_PENDING_PER_KEY, HandlerTimeoutConfig,
resolve_handler_timeout,
};
use crate::consumer_supervisor::{ShutdownTally, SupervisorOutcome};
use crate::error::{Result, ShoveError};
use crate::handler::MessageHandler;
use crate::topic::{SequencedTopic, Topic};
use crate::backend::consumer::ConsumerImpl;
use super::client::RedisClient;
use super::consumer::RedisConsumer;
use super::reaper::spawn_reaper;
use super::topology::RedisTopologyDeclarer;
/// Shared factory that turns a fully populated [`ConsumerOptionsInner`] into a
/// running consumer task and returns that task's handle.
pub(crate) type Spawner = Arc<dyn Fn(ConsumerOptionsInner) -> JoinHandle<()> + Sync + Send>;
/// Settings for one Redis consumer group, built with [`RedisConsumerGroupConfig::new`]
/// and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RedisConsumerGroupConfig {
// Requested prefetch window. `RedisConsumerGroup::new` only honors it when
// `concurrent_processing` is set (otherwise it forces 1); FIFO groups use it
// raised to at least 1.
prefetch_count: u16,
// Copied into each consumer's `ConsumerOptionsInner::max_retries`.
max_retries: u32,
// Number of consumers spawned by `start()`; `scale_down` never goes below it.
pub(crate) min_consumers: u16,
// Upper bound enforced by `scale_up`.
pub(crate) max_consumers: u16,
// Selects `run_concurrent` instead of `run`; rejected by `register_fifo`.
concurrent_processing: bool,
// `Inherit` by default; the registry replaces it with `Set(..)` at registration
// time using its own default (see `resolve_handler_timeout`).
pub(crate) handler_timeout: HandlerTimeoutConfig,
// Copied into consumer options; presumably caps pending messages per key — confirm.
max_pending_per_key: Option<usize>,
// Copied into consumer options; presumably a payload size limit in bytes — confirm.
max_message_size: Option<usize>,
}
impl RedisConsumerGroupConfig {
    /// Creates a configuration whose consumer pool is bounded by `range`.
    ///
    /// Both ends of the range are raised to at least 1, so `0..=0` yields a
    /// single-consumer group. All other settings start at their defaults:
    /// prefetch 10, 10 retries, sequential processing, an inherited handler
    /// timeout, and the library defaults for the per-key pending limit and
    /// the maximum message size.
    ///
    /// # Panics
    ///
    /// Panics if the clamped lower bound exceeds the clamped upper bound.
    pub fn new(range: RangeInclusive<u16>) -> Self {
        let (lower, upper) = range.into_inner();
        let (min, max) = (lower.max(1), upper.max(1));
        assert!(
            min <= max,
            "min_consumers ({min}) must be <= max_consumers ({max})"
        );
        Self {
            min_consumers: min,
            max_consumers: max,
            prefetch_count: 10,
            max_retries: 10,
            concurrent_processing: false,
            handler_timeout: HandlerTimeoutConfig::Inherit,
            max_pending_per_key: Some(DEFAULT_MAX_PENDING_PER_KEY),
            max_message_size: Some(DEFAULT_MAX_MESSAGE_SIZE),
        }
    }

    /// Sets the requested prefetch window.
    pub fn with_prefetch_count(self, prefetch_count: u16) -> Self {
        Self {
            prefetch_count,
            ..self
        }
    }

    /// Sets the retry limit passed to every consumer.
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self {
            max_retries,
            ..self
        }
    }

    /// Enables or disables concurrent message processing.
    pub fn with_concurrent_processing(self, concurrent: bool) -> Self {
        Self {
            concurrent_processing: concurrent,
            ..self
        }
    }

    /// Pins an explicit handler timeout, overriding any registry default.
    ///
    /// # Panics
    ///
    /// Panics if `timeout` is zero.
    pub fn with_handler_timeout(self, timeout: Duration) -> Self {
        assert!(!timeout.is_zero(), "handler_timeout must be positive");
        Self {
            handler_timeout: HandlerTimeoutConfig::Set(timeout),
            ..self
        }
    }

    /// Sets the per-key pending limit passed to every consumer.
    pub fn with_max_pending_per_key(self, limit: usize) -> Self {
        Self {
            max_pending_per_key: Some(limit),
            ..self
        }
    }

    /// Sets the maximum message size passed to every consumer.
    pub fn with_max_message_size(self, max: usize) -> Self {
        Self {
            max_message_size: Some(max),
            ..self
        }
    }

    /// Requested prefetch window.
    pub fn prefetch_count(&self) -> u16 {
        self.prefetch_count
    }

    /// Retry limit passed to consumers.
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Lower bound of the consumer pool (at least 1).
    pub fn min_consumers(&self) -> u16 {
        self.min_consumers
    }

    /// Upper bound of the consumer pool (at least 1).
    pub fn max_consumers(&self) -> u16 {
        self.max_consumers
    }

    /// Whether consumers process messages concurrently.
    pub fn concurrent_processing(&self) -> bool {
        self.concurrent_processing
    }

    /// Effective handler timeout, resolved without any registry default.
    ///
    /// Always `Some`; an inherited setting resolves to the library default.
    pub fn handler_timeout(&self) -> Option<Duration> {
        let resolved = resolve_handler_timeout(self.handler_timeout, None);
        Some(resolved)
    }

    /// Per-key pending limit passed to consumers.
    pub fn max_pending_per_key(&self) -> Option<usize> {
        self.max_pending_per_key
    }

    /// Maximum message size passed to consumers.
    pub fn max_message_size(&self) -> Option<usize> {
        self.max_message_size
    }
}
impl Default for RedisConsumerGroupConfig {
fn default() -> Self {
Self::new(1..=4)
}
}
/// A scalable pool of consumer tasks for one queue, plus the reaper task that
/// serves the group's streams.
pub struct RedisConsumerGroup {
// Group/queue name; used in logs and passed to consumers as `consumer_group`.
pub(crate) queue: String,
pub(crate) config: RedisConsumerGroupConfig,
// Spawns one consumer task from a populated `ConsumerOptionsInner`.
pub(crate) spawner: Spawner,
// One entry per live consumer: (its child cancellation token, its shared
// "processing" flag read by `scale_down`, its task handle).
pub(crate) consumers: Vec<(CancellationToken, Arc<AtomicBool>, JoinHandle<()>)>,
// Parent of every consumer token; cancelling it stops the whole group.
pub(crate) group_token: CancellationToken,
// Incremented by spawned tasks that exit with an error; harvested into `ShutdownTally`.
pub(crate) error_count: Arc<AtomicUsize>,
// Incremented by the FIFO supervisor when a shard panics; harvested into `ShutdownTally`.
pub(crate) panic_count: Arc<AtomicUsize>,
// Spawns the reaper task; invoked at most once per `start` cycle.
pub(crate) reaper_factory: ReaperFactory,
// Handle of the running reaper, if `start` has spawned one.
pub(crate) reaper_handle: Option<JoinHandle<()>>,
}
/// Shared factory that spawns a group's reaper task and returns its handle.
pub(crate) type ReaperFactory = Arc<dyn Fn() -> JoinHandle<()> + Sync + Send>;
/// Builds the factory that spawns the reaper for `streams` of this group.
///
/// The reaper receives the group's handler timeout (in milliseconds) as its
/// `min_idle_ms` and runs every `max(handler_timeout, 30s)`. It is driven by a
/// clone of `group_token`, so cancelling the group also stops the reaper.
fn build_reaper_factory(
    client: &RedisClient,
    streams: Vec<String>,
    config: &RedisConsumerGroupConfig,
    group_token: &CancellationToken,
) -> ReaperFactory {
    let group = client.group().to_string();
    let handler_timeout = resolve_handler_timeout(config.handler_timeout, None);
    // `as_millis()` is a u128. A plain `as u64` cast would wrap very large timeouts
    // to an arbitrary (possibly tiny) value. NOTE(review): `min_idle_ms` is
    // presumably the idle age at which pending entries are reclaimed, so wrapping
    // could reclaim messages still being handled. Saturate instead.
    let min_idle_ms = u64::try_from(handler_timeout.as_millis()).unwrap_or(u64::MAX);
    let interval = Duration::from_millis(min_idle_ms.max(30_000));
    let client = client.clone();
    let token = group_token.clone();
    Arc::new(move || {
        spawn_reaper(
            client.clone(),
            streams.clone(),
            group.clone(),
            interval,
            min_idle_ms,
            token.clone(),
        )
    })
}
impl RedisConsumerGroup {
/// Builds an unsequenced consumer group for topic `T`.
///
/// Each spawned consumer gets a fresh handler from `handler_factory` and a
/// clone of `ctx`. It runs `run_concurrent` when the config enables concurrent
/// processing, otherwise the sequential `run`. The sequential path is always
/// given a prefetch window of 1. A consumer that exits with an error bumps
/// `error_count`. The reaper covers the topic's single queue stream.
pub fn new<T, H>(
    queue: impl Into<String>,
    client: RedisClient,
    config: RedisConsumerGroupConfig,
    group_token: CancellationToken,
    handler_factory: impl Fn() -> H + Send + Sync + 'static,
    ctx: H::Context,
) -> Self
where
    T: Topic + 'static,
    H: MessageHandler<T> + 'static,
{
    let concurrent = config.concurrent_processing();
    // Only the concurrent runner can use more than one in-flight delivery.
    let effective_prefetch = if !concurrent {
        1
    } else {
        config.prefetch_count().max(1)
    };
    let error_count = Arc::new(AtomicUsize::new(0));
    let spawner: Spawner = {
        let errors = Arc::clone(&error_count);
        let client = client.clone();
        Arc::new(move |mut options: ConsumerOptionsInner| {
            let handler = handler_factory();
            let consumer = RedisConsumer::new(client.clone());
            let ctx = ctx.clone();
            let errors = Arc::clone(&errors);
            options.prefetch_count = effective_prefetch;
            tokio::spawn(async move {
                let outcome = if !concurrent {
                    <RedisConsumer as ConsumerImpl>::run::<T, H>(&consumer, handler, ctx, options)
                        .await
                } else {
                    consumer.run_concurrent::<T, H>(handler, ctx, options).await
                };
                if let Err(e) = outcome {
                    errors.fetch_add(1, Ordering::Relaxed);
                    tracing::error!("consumer task exited with error: {e}");
                }
            })
        })
    };
    let reaper_factory = build_reaper_factory(
        &client,
        vec![T::topology().queue().to_string()],
        &config,
        &group_token,
    );
    Self {
        queue: queue.into(),
        consumers: Vec::with_capacity(config.max_consumers() as usize),
        config,
        spawner,
        group_token,
        error_count,
        panic_count: Arc::new(AtomicUsize::new(0)),
        reaper_factory,
        reaper_handle: None,
    }
}
pub fn new_fifo<T, H>(
queue: impl Into<String>,
client: RedisClient,
mut config: RedisConsumerGroupConfig,
group_token: CancellationToken,
handler_factory: impl Fn() -> H + Send + Sync + 'static,
ctx: H::Context,
) -> Self
where
T: SequencedTopic + 'static,
H: MessageHandler<T> + 'static,
{
config.min_consumers = 1;
config.max_consumers = 1;
let prefetch = config.prefetch_count().max(1);
let error_count = Arc::new(AtomicUsize::new(0));
let panic_count = Arc::new(AtomicUsize::new(0));
let ec_for_spawner = error_count.clone();
let pc_for_spawner = panic_count.clone();
let client_for_spawner = client.clone();
let ctx_for_spawner = ctx;
let spawner: Spawner = Arc::new(move |mut options: ConsumerOptionsInner| {
let handler = handler_factory();
let consumer = RedisConsumer::new(client_for_spawner.clone());
let ctx = ctx_for_spawner.clone();
let ec = ec_for_spawner.clone();
let pc = pc_for_spawner.clone();
options.prefetch_count = prefetch;
tokio::spawn(async move {
let handles = match <RedisConsumer as ConsumerImpl>::spawn_fifo_shards::<T, H>(
&consumer, handler, ctx, options,
)
.await
{
Ok(h) => h,
Err(e) => {
ec.fetch_add(1, Ordering::Relaxed);
tracing::error!("FIFO registration failed: {e}");
return;
}
};
for handle in handles {
match handle.await {
Ok(Ok(())) => {}
Ok(Err(e)) => {
ec.fetch_add(1, Ordering::Relaxed);
tracing::error!("sequenced shard exited with error: {e}");
}
Err(e) if e.is_cancelled() => {}
Err(e) => {
pc.fetch_add(1, Ordering::Relaxed);
tracing::error!("sequenced shard panicked: {e}");
}
}
}
})
});
let topology = T::topology();
let n_shards = topology
.sequencing()
.expect("new_fifo requires sequencing config")
.routing_shards();
let reaper_streams: Vec<String> = (0..n_shards)
.map(|idx| RedisTopologyDeclarer::shard_stream_name(topology.queue(), idx))
.collect();
let reaper_factory = build_reaper_factory(&client, reaper_streams, &config, &group_token);
Self {
queue: queue.into(),
consumers: Vec::with_capacity(1),
config,
spawner,
group_token,
error_count,
panic_count,
reaper_factory,
reaper_handle: None,
}
}
/// Brings the group up to `min_consumers` and spawns the reaper if it is not running.
///
/// Idempotent: only the consumers missing below `min_consumers` are spawned.
/// Calling it again, e.g. `start_all()` followed by `run_until_timeout()` (which
/// calls `start_all()` itself), must not double the pool or exceed
/// `max_consumers`. NOTE(review): for FIFO groups a second call would presumably
/// start a second set of shard consumers and break per-key ordering.
pub fn start(&mut self) {
    let target = self.config.min_consumers() as usize;
    let missing = target.saturating_sub(self.consumers.len());
    info!(
        group = %self.queue,
        queue = %self.queue,
        initial_consumers = target,
        "starting consumer group"
    );
    for _ in 0..missing {
        self.spawn_one();
    }
    if self.reaper_handle.is_none() {
        self.reaper_handle = Some((self.reaper_factory)());
    }
}
/// Adds one consumer unless the group is already at `max_consumers`.
///
/// Returns `true` if a consumer was spawned and `false` if the request was
/// rejected because the pool is full.
pub fn scale_up(&mut self) -> bool {
    let ceiling = self.config.max_consumers();
    if self.consumers.len() < ceiling as usize {
        self.spawn_one();
        info!(
            group = %self.queue,
            consumers = self.consumers.len(),
            "scaled up: spawned new consumer"
        );
        return true;
    }
    debug!(group = %self.queue, max = ceiling, "scale_up rejected: at max capacity");
    false
}
/// Cancels one idle consumer unless the group is already at `min_consumers`.
///
/// The most recently spawned consumer whose `processing` flag is clear is
/// chosen. Returns `false` if the pool is at its floor or every consumer
/// reports busy. The removed task's handle is dropped rather than joined.
/// NOTE(review): `drain_into` therefore never waits for a consumer retired here.
pub fn scale_down(&mut self) -> bool {
    let floor = self.config.min_consumers();
    if self.consumers.len() <= floor as usize {
        debug!(group = %self.queue, min = floor, "scale_down rejected: at min capacity");
        return false;
    }
    let victim = self
        .consumers
        .iter()
        .rposition(|(_, busy, _)| !busy.load(Ordering::Relaxed));
    match victim {
        Some(index) => {
            let (token, _, _handle) = self.consumers.swap_remove(index);
            token.cancel();
            info!(
                group = %self.queue,
                consumers = self.consumers.len(),
                "scaled down: cancelled an idle consumer"
            );
            true
        }
        None => {
            warn!(group = %self.queue, "scale_down rejected: all consumers are busy");
            false
        }
    }
}
/// Number of consumer tasks currently tracked by this group.
pub fn active_consumers(&self) -> usize {
    Vec::len(&self.consumers)
}
/// Name this group was built with; also handed to consumers as `consumer_group`.
pub fn queue(&self) -> &str {
    self.queue.as_str()
}
pub fn config(&self) -> &RedisConsumerGroupConfig {
&self.config
}
pub async fn shutdown(&mut self) {
let _ = self.shutdown_with_tally().await;
}
pub(crate) async fn shutdown_with_tally(&mut self) -> ShutdownTally {
let mut tally = ShutdownTally::default();
self.drain_into(&mut tally).await;
debug!(
group = %self.queue,
errors = tally.errors,
panics = tally.panics,
"consumer group shutdown complete"
);
tally
}
/// Cancels the group and joins all of its tasks, adding error/panic counts to `tally`.
///
/// Cancelling `group_token` reaches every consumer, since their tokens are child
/// tokens of it. For groups built by `new`/`new_fifo` it also reaches the reaper,
/// which holds a clone of it. Consumer handles are then awaited one by one; a
/// `JoinError` that is not a cancellation is counted as a panic. The reaper's
/// join result is ignored.
pub(crate) async fn drain_into(&mut self, tally: &mut ShutdownTally) {
info!(
group = %self.queue,
consumers = self.consumers.len(),
"shutting down consumer group"
);
self.group_token.cancel();
// Harvest the counters *before* awaiting any task. Callers may wrap this future
// in a timeout (see `run_until_timeout`), and the counts recorded so far must
// survive if the future is dropped mid-drain.
tally.errors += self.error_count.swap(0, Ordering::Relaxed);
tally.panics += self.panic_count.swap(0, Ordering::Relaxed);
while let Some((_token, _processing, handle)) = self.consumers.pop() {
match handle.await {
Ok(()) => {}
Err(e) if e.is_cancelled() => {}
Err(e) => {
tracing::error!(error = %e, group = %self.queue, "consumer task panicked");
tally.panics += 1;
}
}
}
if let Some(h) = self.reaper_handle.take() {
let _ = h.await;
}
// Pick up anything recorded while the tasks were winding down.
tally.errors += self.error_count.swap(0, Ordering::Relaxed);
tally.panics += self.panic_count.swap(0, Ordering::Relaxed);
}
/// Escalation path after a drain timeout: cancel the group, abort every
/// consumer task and the reaper, then join them so the aborts complete.
///
/// Non-cancellation join errors from consumers are counted as panics. The
/// error/panic counters are added to `tally` at the end.
pub(crate) async fn abort_remaining_into(&mut self, tally: &mut ShutdownTally) {
    self.group_token.cancel();
    self.consumers
        .iter()
        .for_each(|(_, _, handle)| handle.abort());
    if let Some(reaper) = self.reaper_handle.as_ref() {
        reaper.abort();
    }
    while let Some((_cancel, _busy, handle)) = self.consumers.pop() {
        let Err(e) = handle.await else { continue };
        if e.is_cancelled() {
            continue;
        }
        tracing::error!(
            error = %e,
            group = %self.queue,
            "consumer task panicked during abort escalation"
        );
        tally.panics += 1;
    }
    if let Some(reaper) = self.reaper_handle.take() {
        let _ = reaper.await;
    }
    tally.errors += self.error_count.swap(0, Ordering::Relaxed);
    tally.panics += self.panic_count.swap(0, Ordering::Relaxed);
}
/// Spawns one consumer under a child of the group token and records it.
///
/// The group's settings are copied into a fresh `ConsumerOptionsInner` and
/// handed to the spawner. The child token, the shared `processing` flag, and
/// the task handle are kept, so the consumer can later be cancelled, checked
/// for idleness, or joined.
fn spawn_one(&mut self) {
    let shutdown = self.group_token.child_token();
    let busy = Arc::new(AtomicBool::new(false));
    let options = {
        let cfg = &self.config;
        let mut opts = ConsumerOptionsInner::defaults_with_shutdown(shutdown.clone());
        opts.prefetch_count = cfg.prefetch_count();
        opts.max_retries = cfg.max_retries();
        opts.handler_timeout = Some(resolve_handler_timeout(cfg.handler_timeout, None));
        opts.processing = Arc::clone(&busy);
        opts.consumer_group = Some(Arc::from(self.queue.as_str()));
        opts.max_pending_per_key = cfg.max_pending_per_key();
        opts.max_message_size = cfg.max_message_size();
        opts
    };
    let task = (self.spawner)(options);
    self.consumers.push((shutdown, busy, task));
    let consumer_index = self.consumers.len() - 1;
    debug!(group = %self.queue, consumer_index, "spawned consumer");
}
}
/// Owns every registered consumer group and coordinates their start-up and shutdown.
pub struct RedisConsumerGroupRegistry {
// Groups keyed by the topic's queue name.
pub(crate) groups: std::collections::HashMap<String, RedisConsumerGroup>,
// `None` only for the test-only registry built by `from_groups`.
client: Option<RedisClient>,
// Root token; groups registered through this registry get child tokens of it.
shutdown: CancellationToken,
// Applied to groups whose config inherits the handler timeout.
pub(super) default_handler_timeout: Option<Duration>,
}
impl RedisConsumerGroupRegistry {
/// Creates an empty registry that declares topology and runs consumers through `client`.
pub fn new(client: RedisClient) -> Self {
    let shutdown = CancellationToken::new();
    Self {
        client: Some(client),
        groups: Default::default(),
        shutdown,
        default_handler_timeout: None,
    }
}
/// Test-only constructor wrapping pre-built groups; it has no Redis client,
/// so `register`/`register_fifo` fail on it.
#[cfg(test)]
pub(crate) fn from_groups(
    groups: std::collections::HashMap<String, RedisConsumerGroup>,
) -> Self {
    Self {
        client: None,
        default_handler_timeout: None,
        shutdown: CancellationToken::new(),
        groups,
    }
}
/// Sets the handler timeout used by groups whose config inherits it.
///
/// # Panics
///
/// Panics if `timeout` is zero.
pub fn with_default_handler_timeout(mut self, timeout: Duration) -> Self {
    assert!(
        timeout > Duration::ZERO,
        "default_handler_timeout must be positive"
    );
    self.default_handler_timeout = Some(timeout);
    self
}
pub fn broker_shutdown_token(&self) -> CancellationToken {
self.shutdown.clone()
}
pub async fn register<T, H>(
&mut self,
config: RedisConsumerGroupConfig,
factory: impl Fn() -> H + Send + Sync + 'static,
ctx: H::Context,
) -> Result<()>
where
T: Topic + 'static,
H: MessageHandler<T> + 'static,
{
let mut config = config;
config.handler_timeout = HandlerTimeoutConfig::Set(resolve_handler_timeout(
config.handler_timeout,
self.default_handler_timeout,
));
let topology = T::topology();
let name = topology.queue().to_string();
if self.groups.contains_key(&name) {
return Err(ShoveError::Topology(format!(
"consumer group '{name}' is already registered"
)));
}
let client = self.client.as_ref().ok_or_else(|| {
ShoveError::Topology("registry has no client (test-only registry)".into())
})?;
let declarer = RedisTopologyDeclarer::new(client.clone());
declarer.declare(topology).await?;
let group_token = self.shutdown.child_token();
let group = RedisConsumerGroup::new::<T, H>(
name.clone(),
client.clone(),
config,
group_token,
factory,
ctx,
);
self.groups.insert(name, group);
Ok(())
}
pub async fn register_fifo<T, H>(
&mut self,
config: RedisConsumerGroupConfig,
factory: impl Fn() -> H + Send + Sync + 'static,
ctx: H::Context,
) -> Result<()>
where
T: SequencedTopic + 'static,
H: MessageHandler<T> + 'static,
{
if config.concurrent_processing() {
return Err(ShoveError::Topology(format!(
"topic '{}' is sequenced; `concurrent_processing` on a FIFO consumer would \
break per-key ordering. Drop `with_concurrent_processing(true)` or use \
`register` for unsequenced topics.",
T::topology().queue(),
)));
}
let mut config = config;
config.handler_timeout = HandlerTimeoutConfig::Set(resolve_handler_timeout(
config.handler_timeout,
self.default_handler_timeout,
));
let topology = T::topology();
let name = topology.queue().to_string();
if self.groups.contains_key(&name) {
return Err(ShoveError::Topology(format!(
"consumer group '{name}' is already registered"
)));
}
let client = self.client.as_ref().ok_or_else(|| {
ShoveError::Topology("registry has no client (test-only registry)".into())
})?;
let declarer = RedisTopologyDeclarer::new(client.clone());
declarer.declare(topology).await?;
let group_token = self.shutdown.child_token();
let group = RedisConsumerGroup::new_fifo::<T, H>(
name.clone(),
client.clone(),
config,
group_token,
factory,
ctx,
);
self.groups.insert(name, group);
Ok(())
}
pub fn start_all(&mut self) {
info!(count = self.groups.len(), "starting all consumer groups");
for group in self.groups.values_mut() {
group.start();
}
}
pub fn groups(&self) -> &std::collections::HashMap<String, RedisConsumerGroup> {
&self.groups
}
pub fn groups_mut(&mut self) -> &mut std::collections::HashMap<String, RedisConsumerGroup> {
&mut self.groups
}
pub async fn shutdown_all(&mut self) {
let _ = self.shutdown_all_with_tally().await;
}
pub(crate) async fn shutdown_all_with_tally(&mut self) -> ShutdownTally {
let mut tally = ShutdownTally::default();
self.drain_all_into(&mut tally).await;
tally
}
/// Drains every group, adding their errors and panics to `tally`.
///
/// All group tokens are cancelled up front, and only then are the groups joined
/// one after another. Previously each group was cancelled only when its turn came,
/// so on the `shutdown_all` path later groups kept consuming new messages while
/// earlier ones drained, and total drain time was the sum instead of the maximum.
pub(crate) async fn drain_all_into(&mut self, tally: &mut ShutdownTally) {
    info!(
        count = self.groups.len(),
        "shutting down all consumer groups"
    );
    for group in self.groups.values() {
        group.group_token.cancel();
    }
    for group in self.groups.values_mut() {
        group.drain_into(tally).await;
    }
    debug!(
        errors = tally.errors,
        panics = tally.panics,
        "all consumer groups shut down"
    );
}
/// Abort-escalation counterpart of `drain_all_into`: aborts and joins every
/// group's surviving tasks, adding their counts to `tally`.
pub(crate) async fn abort_all_remaining_into(&mut self, tally: &mut ShutdownTally) {
    for (_name, group) in self.groups.iter_mut() {
        group.abort_remaining_into(tally).await;
    }
}
/// Starts every group, waits for a shutdown trigger, then drains with a deadline.
///
/// Shutdown is triggered either by `signal` completing (it runs as a spawned
/// task) or by cancellation of the token from `broker_shutdown_token()`. Groups
/// are then drained for at most `drain_timeout`. If the deadline elapses, all
/// surviving tasks are aborted and the outcome reports `timed_out: true`.
pub async fn run_until_timeout<S>(
    mut self,
    signal: S,
    drain_timeout: Duration,
) -> SupervisorOutcome
where
    S: Future<Output = ()> + Send + 'static,
{
    self.start_all();
    let shutdown = self.shutdown.clone();
    let mut signal_task = tokio::spawn(signal);
    tokio::select! {
        _ = shutdown.cancelled() => {}
        res = &mut signal_task => {
            // Even a failed or panicked signal task means "shut down".
            let _ = res;
            shutdown.cancel();
        }
    }
    // If shutdown came through the token, the signal task is still pending.
    // Dropping its JoinHandle would only detach it and leave it running after we
    // return. Aborting a task that already finished is a no-op.
    signal_task.abort();
    let mut tally = ShutdownTally::default();
    match tokio::time::timeout(drain_timeout, self.drain_all_into(&mut tally)).await {
        Ok(()) => SupervisorOutcome {
            errors: tally.errors,
            panics: tally.panics,
            timed_out: false,
        },
        Err(_) => {
            tracing::warn!(
                timeout_ms = drain_timeout.as_millis() as u64,
                "drain timeout elapsed; aborting surviving consumer tasks"
            );
            self.abort_all_remaining_into(&mut tally).await;
            SupervisorOutcome {
                errors: tally.errors,
                panics: tally.panics,
                timed_out: true,
            }
        }
    }
}
}
// Unit tests for the config builder and the group lifecycle
// (start / scale / shutdown / abort). Lifecycle tests use stub spawners and
// reapers, so no Redis client is involved.
#[cfg(test)]
mod tests {
use super::*;
use crate::consumer::DEFAULT_HANDLER_TIMEOUT;
#[test]
fn config_default_handler_timeout_is_library_default() {
let cfg = RedisConsumerGroupConfig::new(1..=1);
assert_eq!(cfg.handler_timeout(), Some(DEFAULT_HANDLER_TIMEOUT));
}
#[test]
fn with_handler_timeout_round_trips() {
let cfg = RedisConsumerGroupConfig::new(1..=1).with_handler_timeout(Duration::from_secs(7));
assert_eq!(cfg.handler_timeout(), Some(Duration::from_secs(7)));
}
// An un-set (Inherit) timeout picks up the registry-level default.
#[test]
fn config_inherit_resolves_to_registry_default_when_set() {
let cfg = RedisConsumerGroupConfig::new(1..=1);
assert_eq!(
resolve_handler_timeout(cfg.handler_timeout, Some(Duration::from_secs(45))),
Duration::from_secs(45),
);
}
// An explicit per-group timeout wins over the registry default.
#[test]
fn with_handler_timeout_beats_registry_default() {
let cfg = RedisConsumerGroupConfig::new(1..=1).with_handler_timeout(Duration::from_secs(5));
assert_eq!(
resolve_handler_timeout(cfg.handler_timeout, Some(Duration::from_secs(45))),
Duration::from_secs(5),
);
}
#[test]
#[should_panic(expected = "handler_timeout must be positive")]
fn with_handler_timeout_zero_panics() {
let _ = RedisConsumerGroupConfig::new(1..=1).with_handler_timeout(Duration::ZERO);
}
#[test]
fn config_consumer_range() {
let cfg = RedisConsumerGroupConfig::new(2..=4);
assert_eq!(cfg.min_consumers(), 2);
assert_eq!(cfg.max_consumers(), 4);
}
#[test]
fn config_default_range_is_one_to_four() {
let cfg = RedisConsumerGroupConfig::default();
assert_eq!(cfg.min_consumers(), 1);
assert_eq!(cfg.max_consumers(), 4);
}
// `new` raises both bounds to at least 1.
#[test]
fn config_zero_clamped_to_one() {
let cfg = RedisConsumerGroupConfig::new(0..=0);
assert_eq!(cfg.min_consumers(), 1);
assert_eq!(cfg.max_consumers(), 1);
}
#[test]
fn config_large_consumer_count() {
let cfg = RedisConsumerGroupConfig::new(u16::MAX..=u16::MAX);
assert_eq!(cfg.max_consumers(), u16::MAX);
}
#[test]
fn config_default_prefetch_count_is_ten() {
let cfg = RedisConsumerGroupConfig::default();
assert_eq!(cfg.prefetch_count(), 10);
}
#[test]
fn config_default_concurrent_processing_is_false() {
let cfg = RedisConsumerGroupConfig::default();
assert!(!cfg.concurrent_processing());
}
#[test]
fn with_prefetch_count_round_trips() {
let cfg = RedisConsumerGroupConfig::new(1..=1).with_prefetch_count(64);
assert_eq!(cfg.prefetch_count(), 64);
}
#[test]
fn default_max_retries_is_ten() {
let cfg = RedisConsumerGroupConfig::default();
assert_eq!(cfg.max_retries(), 10);
}
#[test]
fn with_max_retries_round_trips() {
let cfg = RedisConsumerGroupConfig::new(1..=1).with_max_retries(7);
assert_eq!(cfg.max_retries(), 7);
}
#[test]
fn with_concurrent_processing_round_trips() {
let cfg = RedisConsumerGroupConfig::new(1..=1).with_concurrent_processing(true);
assert!(cfg.concurrent_processing());
}
#[test]
fn config_default_max_pending_per_key_is_library_default() {
use crate::consumer::DEFAULT_MAX_PENDING_PER_KEY;
let cfg = RedisConsumerGroupConfig::new(1..=1);
assert_eq!(cfg.max_pending_per_key(), Some(DEFAULT_MAX_PENDING_PER_KEY));
}
#[test]
fn config_default_max_message_size_is_library_default() {
use crate::consumer::DEFAULT_MAX_MESSAGE_SIZE;
let cfg = RedisConsumerGroupConfig::new(1..=1);
assert_eq!(cfg.max_message_size(), Some(DEFAULT_MAX_MESSAGE_SIZE));
}
#[test]
fn with_max_pending_per_key_round_trips() {
let cfg = RedisConsumerGroupConfig::new(1..=1).with_max_pending_per_key(500);
assert_eq!(cfg.max_pending_per_key(), Some(500));
}
#[test]
fn with_max_message_size_round_trips() {
let cfg = RedisConsumerGroupConfig::new(1..=1).with_max_message_size(1024);
assert_eq!(cfg.max_message_size(), Some(1024));
}
// Chained builder calls must not clobber one another.
#[test]
fn builder_chain_preserves_all_fields() {
use crate::consumer::{DEFAULT_MAX_MESSAGE_SIZE, DEFAULT_MAX_PENDING_PER_KEY};
let cfg = RedisConsumerGroupConfig::new(2..=8)
.with_prefetch_count(32)
.with_concurrent_processing(true)
.with_handler_timeout(Duration::from_secs(3))
.with_max_pending_per_key(200)
.with_max_message_size(4096);
assert_eq!(cfg.min_consumers(), 2);
assert_eq!(cfg.max_consumers(), 8);
assert_eq!(cfg.prefetch_count(), 32);
assert!(cfg.concurrent_processing());
assert_eq!(cfg.handler_timeout(), Some(Duration::from_secs(3)));
assert_eq!(cfg.max_pending_per_key(), Some(200));
assert_eq!(cfg.max_message_size(), Some(4096));
// Guard that the chosen values differ from the defaults, so the round-trip
// assertions above are meaningful.
assert_ne!(Some(200), Some(DEFAULT_MAX_PENDING_PER_KEY));
assert_ne!(Some(4096usize), Some(DEFAULT_MAX_MESSAGE_SIZE));
}
#[test]
#[should_panic(expected = "min_consumers")]
#[allow(clippy::reversed_empty_ranges)]
fn config_min_greater_than_max_panics() {
let _ = RedisConsumerGroupConfig::new(5..=2);
}
// Lifecycle tests driven through stub spawners and reapers.
mod group {
use super::*;
use crate::backend::ConsumerOptionsInner;
fn default_config() -> RedisConsumerGroupConfig {
RedisConsumerGroupConfig::new(1..=4)
}
// Builds a group whose consumers just wait for their shutdown token and
// whose reaper finishes immediately.
fn test_group(config: RedisConsumerGroupConfig) -> RedisConsumerGroup {
let group_token = CancellationToken::new();
let spawner: Spawner = Arc::new(|options: ConsumerOptionsInner| {
tokio::spawn(async move {
options.shutdown.cancelled().await;
})
});
let reaper_factory: ReaperFactory = Arc::new(|| tokio::spawn(async {}));
RedisConsumerGroup {
queue: "test-queue".into(),
consumers: Vec::with_capacity(config.max_consumers() as usize),
config,
spawner,
group_token,
error_count: Arc::new(AtomicUsize::new(0)),
panic_count: Arc::new(AtomicUsize::new(0)),
reaper_factory,
reaper_handle: None,
}
}
// The spawner captures the options it receives, to verify spawn_one's wiring.
#[tokio::test]
async fn spawn_one_threads_max_pending_per_key_to_options() {
use std::sync::Mutex;
let captured: Arc<Mutex<Option<Option<usize>>>> = Arc::new(Mutex::new(None));
let cap = captured.clone();
let config = RedisConsumerGroupConfig::new(1..=1).with_max_pending_per_key(777);
let mut group = test_group(config);
group.spawner = Arc::new(move |options: ConsumerOptionsInner| {
*cap.lock().unwrap() = Some(options.max_pending_per_key);
tokio::spawn(async move { options.shutdown.cancelled().await })
});
group.start();
assert_eq!(*captured.lock().unwrap(), Some(Some(777)));
group.shutdown().await;
}
#[tokio::test]
async fn spawn_one_threads_max_message_size_to_options() {
use std::sync::Mutex;
let captured: Arc<Mutex<Option<Option<usize>>>> = Arc::new(Mutex::new(None));
let cap = captured.clone();
let config = RedisConsumerGroupConfig::new(1..=1).with_max_message_size(2048);
let mut group = test_group(config);
group.spawner = Arc::new(move |options: ConsumerOptionsInner| {
*cap.lock().unwrap() = Some(options.max_message_size);
tokio::spawn(async move { options.shutdown.cancelled().await })
});
group.start();
assert_eq!(*captured.lock().unwrap(), Some(Some(2048)));
group.shutdown().await;
}
#[tokio::test]
async fn start_spawns_min_consumers() {
let mut group = test_group(RedisConsumerGroupConfig::new(3..=5));
group.start();
assert_eq!(group.active_consumers(), 3);
group.shutdown().await;
}
#[tokio::test]
async fn scale_up_adds_one_consumer() {
let mut group = test_group(default_config());
group.start();
assert_eq!(group.active_consumers(), 1);
assert!(group.scale_up());
assert_eq!(group.active_consumers(), 2);
group.shutdown().await;
}
#[tokio::test]
async fn scale_up_rejected_at_max() {
let mut group = test_group(RedisConsumerGroupConfig::new(2..=2));
group.start();
assert_eq!(group.active_consumers(), 2);
assert!(!group.scale_up());
assert_eq!(group.active_consumers(), 2);
group.shutdown().await;
}
#[tokio::test]
async fn scale_down_removes_one_consumer() {
let mut group = test_group(RedisConsumerGroupConfig::new(1..=4));
group.start();
assert!(group.scale_up());
assert_eq!(group.active_consumers(), 2);
assert!(group.scale_down());
assert_eq!(group.active_consumers(), 1);
group.shutdown().await;
}
#[tokio::test]
async fn scale_down_rejected_at_min() {
let mut group = test_group(RedisConsumerGroupConfig::new(1..=4));
group.start();
assert!(!group.scale_down());
assert_eq!(group.active_consumers(), 1);
group.shutdown().await;
}
// Counts reaper-factory invocations across repeated start() calls.
#[tokio::test]
async fn reaper_is_spawned_exactly_once_per_group() {
let spawn_count = Arc::new(AtomicUsize::new(0));
let counter = spawn_count.clone();
let reaper_factory: ReaperFactory = Arc::new(move || {
counter.fetch_add(1, Ordering::Relaxed);
tokio::spawn(async {})
});
let mut group = test_group(RedisConsumerGroupConfig::new(1..=4));
group.reaper_factory = reaper_factory;
assert_eq!(spawn_count.load(Ordering::Relaxed), 0);
group.start();
assert_eq!(spawn_count.load(Ordering::Relaxed), 1);
group.start();
assert_eq!(
spawn_count.load(Ordering::Relaxed),
1,
"second start() must not respawn the reaper"
);
group.shutdown().await;
}
#[tokio::test]
async fn shutdown_joins_the_reaper_handle() {
let reaper_factory: ReaperFactory = Arc::new(|| tokio::spawn(async {}));
let mut group = test_group(RedisConsumerGroupConfig::new(1..=4));
group.reaper_factory = reaper_factory;
group.start();
assert!(group.reaper_handle.is_some());
group.shutdown().await;
assert!(
group.reaper_handle.is_none(),
"shutdown_with_tally must take() the reaper handle"
);
}
// Builds a group whose consumers and reaper never finish (they ignore
// cancellation), to exercise the timeout and abort paths.
fn hanging_test_group(config: RedisConsumerGroupConfig) -> RedisConsumerGroup {
let mut group = test_group(config);
group.spawner = Arc::new(|_options: ConsumerOptionsInner| {
tokio::spawn(async {
std::future::pending::<()>().await;
})
});
group.reaper_factory = Arc::new(|| {
tokio::spawn(async {
std::future::pending::<()>().await;
})
});
group
}
// drain_into harvests the counters before joining any task, so the values
// land in the tally even though the drain itself is cut short.
#[tokio::test]
async fn drain_into_timeout_preserves_atomics_in_tally() {
let mut group = hanging_test_group(RedisConsumerGroupConfig::new(2..=2));
group.start();
assert_eq!(group.active_consumers(), 2);
group.error_count.store(7, Ordering::Relaxed);
group.panic_count.store(2, Ordering::Relaxed);
let mut tally = ShutdownTally::default();
let result =
tokio::time::timeout(Duration::from_millis(50), group.drain_into(&mut tally)).await;
assert!(result.is_err(), "drain must time out on hanging consumers");
assert_eq!(tally.errors, 7);
assert_eq!(tally.panics, 2);
}
// Mirrors run_until_timeout's escalation: a timed-out drain followed by abort.
#[tokio::test]
async fn abort_remaining_into_kills_hanging_consumers_and_reaper() {
let mut group = hanging_test_group(RedisConsumerGroupConfig::new(2..=2));
group.start();
assert!(group.reaper_handle.is_some());
group.error_count.store(5, Ordering::Relaxed);
group.panic_count.store(1, Ordering::Relaxed);
let mut tally = ShutdownTally::default();
let _ =
tokio::time::timeout(Duration::from_millis(50), group.drain_into(&mut tally)).await;
group.abort_remaining_into(&mut tally).await;
assert_eq!(group.active_consumers(), 0);
assert!(
group.reaper_handle.is_none(),
"abort_remaining_into must take() the reaper handle"
);
assert_eq!(tally.errors, 5);
assert_eq!(tally.panics, 1);
}
// The spawner copies `busy_flag` into each new consumer's `processing` flag
// at spawn time, so consumers spawned while it is true stay "busy".
#[tokio::test]
async fn scale_down_skips_busy_consumers() {
let mut group = test_group(RedisConsumerGroupConfig::new(1..=4));
let busy_flag = Arc::new(AtomicBool::new(true));
let flag = busy_flag.clone();
let spawner: Spawner = Arc::new(move |options: ConsumerOptionsInner| {
options
.processing
.store(flag.load(Ordering::Relaxed), Ordering::Relaxed);
tokio::spawn(async move {
options.shutdown.cancelled().await;
})
});
group.spawner = spawner;
group.start();
group.scale_up();
assert!(!group.scale_down(), "all consumers reporting busy");
busy_flag.store(false, Ordering::Relaxed);
group.scale_up();
assert!(group.scale_down());
group.shutdown().await;
}
}
}