use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::backend::ConsumerOptionsInner;
use crate::consumer::{HandlerTimeoutConfig, resolve_handler_timeout};
use crate::consumer_supervisor::{SupervisorOutcome, tally_join_result};
use crate::error::Result;
use crate::handler::MessageHandler;
use crate::topic::{SequencedTopic, Topic};
use crate::backend::consumer::ConsumerImpl;
use super::client::RedisClient;
use super::consumer::RedisConsumer;
use super::topology::RedisTopologyDeclarer;
/// Type-erased, heap-pinned future that a queued consumer task resolves to.
type BoxFuture = Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>;
/// Deferred constructor for a consumer future; called exactly once, when the
/// task is spawned onto a `JoinSet`.
type TaskFactory = Box<dyn FnOnce() -> BoxFuture + Send + 'static>;
/// Per-group settings passed to `RedisConsumerGroupRegistry::register` and
/// `RedisConsumerGroupRegistry::register_fifo`.
#[derive(Debug, Clone)]
pub struct RedisConsumerGroupConfig {
// Number of consumers spawned for the group; `new` clamps it to at least 1.
consumer_count: u16,
// Either an explicit timeout or `Inherit`; resolved against the registry-wide
// default when the group is registered.
pub(crate) handler_timeout: HandlerTimeoutConfig,
}
impl RedisConsumerGroupConfig {
    /// Creates a config for a group of `consumer_count` consumers.
    ///
    /// A `consumer_count` of zero is clamped to one, so a registered group
    /// always spawns at least one consumer. The handler timeout starts as
    /// "inherit": at registration it resolves to the registry-wide default
    /// when one is set, otherwise to the library default.
    #[must_use]
    pub fn new(consumer_count: u16) -> Self {
        Self {
            consumer_count: consumer_count.max(1),
            handler_timeout: HandlerTimeoutConfig::Inherit,
        }
    }

    /// Sets an explicit handler timeout for this group, overriding any
    /// registry-wide default.
    ///
    /// # Panics
    ///
    /// Panics if `timeout` is zero.
    #[must_use]
    pub fn with_handler_timeout(mut self, timeout: Duration) -> Self {
        assert!(!timeout.is_zero(), "handler_timeout must be positive");
        self.handler_timeout = HandlerTimeoutConfig::Set(timeout);
        self
    }

    /// Number of consumers the group will spawn (always at least 1).
    pub fn consumer_count(&self) -> u16 {
        self.consumer_count
    }

    /// Handler timeout this config resolves to on its own, ignoring any
    /// registry default.
    ///
    /// The result is the explicit value from
    /// [`with_handler_timeout`](Self::with_handler_timeout) when one was set,
    /// and the library default otherwise. Currently this always returns
    /// `Some`.
    pub fn handler_timeout(&self) -> Option<Duration> {
        Some(resolve_handler_timeout(self.handler_timeout, None))
    }
}
impl Default for RedisConsumerGroupConfig {
fn default() -> Self {
Self::new(1)
}
}
/// Collects Redis consumer groups and supervises them as one unit.
///
/// Groups are queued with `register` / `register_fifo` and spawned by
/// `start_all` or `run_until_timeout`. Every spawned consumer shares the
/// registry's shutdown token.
pub struct RedisConsumerGroupRegistry {
// Cloned into the topology declarer and into every consumer.
client: RedisClient,
// Deferred consumer tasks built at registration; emptied by `start_all`.
tasks: Vec<TaskFactory>,
// Shutdown token handed to every consumer's options.
shutdown: CancellationToken,
// Registry-wide handler timeout used by groups whose config inherits.
pub(super) default_handler_timeout: Option<Duration>,
}
impl RedisConsumerGroupRegistry {
    /// Creates an empty registry bound to `client`.
    ///
    /// The registry gets a fresh shutdown token and no registry-wide handler
    /// timeout default.
    pub fn new(client: RedisClient) -> Self {
        Self {
            client,
            tasks: Vec::new(),
            shutdown: CancellationToken::new(),
            default_handler_timeout: None,
        }
    }

    /// Sets the handler timeout applied to every group whose config did not
    /// set its own via [`RedisConsumerGroupConfig::with_handler_timeout`].
    ///
    /// # Panics
    ///
    /// Panics if `timeout` is zero.
    #[must_use]
    pub fn with_default_handler_timeout(mut self, timeout: Duration) -> Self {
        assert!(
            !timeout.is_zero(),
            "default_handler_timeout must be positive"
        );
        self.default_handler_timeout = Some(timeout);
        self
    }

    /// Returns a clone of the shutdown token shared by every consumer this
    /// registry spawns.
    ///
    /// Cancelling it makes [`Self::run_until_timeout`] stop waiting for its
    /// signal and start draining.
    pub fn broker_shutdown_token(&self) -> CancellationToken {
        self.shutdown.clone()
    }

    /// Resolves `config`'s handler timeout against this registry's default.
    fn effective_handler_timeout(&self, config: &RedisConsumerGroupConfig) -> Duration {
        resolve_handler_timeout(config.handler_timeout, self.default_handler_timeout)
    }

    /// Declares `T`'s topology and queues `config.consumer_count()` standard
    /// consumers for it.
    ///
    /// `factory` is called once per consumer, immediately, to build that
    /// consumer's handler. `ctx` is cloned into each consumer. Nothing runs
    /// until [`Self::start_all`] or [`Self::run_until_timeout`] is called.
    ///
    /// # Errors
    ///
    /// Returns the topology-declaration error. In that case no consumers
    /// are queued.
    pub async fn register<T, H>(
        &mut self,
        config: RedisConsumerGroupConfig,
        factory: impl Fn() -> H + Send + Sync + 'static,
        ctx: H::Context,
    ) -> Result<()>
    where
        T: Topic + 'static,
        H: MessageHandler<T> + 'static,
    {
        let topology = T::topology();
        let declarer = RedisTopologyDeclarer::new(self.client.clone());
        declarer.declare(topology).await?;
        // `Duration` is `Copy`: each `move` closure below captures its own copy.
        let handler_timeout = self.effective_handler_timeout(&config);
        for _ in 0..config.consumer_count() {
            let client = self.client.clone();
            let shutdown = self.shutdown.clone();
            let handler = factory();
            let ctx = ctx.clone();
            let task: TaskFactory = Box::new(move || {
                Box::pin(async move {
                    let consumer = RedisConsumer::new(client);
                    let mut options = ConsumerOptionsInner::defaults_with_shutdown(shutdown);
                    options.handler_timeout = Some(handler_timeout);
                    consumer.run::<T, H>(handler, ctx, options).await
                })
            });
            self.tasks.push(task);
        }
        Ok(())
    }

    /// Same as [`Self::register`], but queues FIFO consumers (via
    /// `run_fifo`) for a sequenced topic.
    ///
    /// # Errors
    ///
    /// Returns the topology-declaration error. In that case no consumers
    /// are queued.
    pub async fn register_fifo<T, H>(
        &mut self,
        config: RedisConsumerGroupConfig,
        factory: impl Fn() -> H + Send + Sync + 'static,
        ctx: H::Context,
    ) -> Result<()>
    where
        T: SequencedTopic + 'static,
        H: MessageHandler<T> + 'static,
    {
        let topology = T::topology();
        let declarer = RedisTopologyDeclarer::new(self.client.clone());
        declarer.declare(topology).await?;
        let handler_timeout = self.effective_handler_timeout(&config);
        for _ in 0..config.consumer_count() {
            let client = self.client.clone();
            let shutdown = self.shutdown.clone();
            let handler = factory();
            let ctx = ctx.clone();
            let task: TaskFactory = Box::new(move || {
                Box::pin(async move {
                    let consumer = RedisConsumer::new(client);
                    let mut options = ConsumerOptionsInner::defaults_with_shutdown(shutdown);
                    options.handler_timeout = Some(handler_timeout);
                    consumer.run_fifo::<T, H>(handler, ctx, options).await
                })
            });
            self.tasks.push(task);
        }
        Ok(())
    }

    /// Spawns every queued consumer onto `set`, leaving the queue empty.
    pub fn start_all(&mut self, set: &mut JoinSet<Result<()>>) {
        for factory in self.tasks.drain(..) {
            set.spawn(factory());
        }
    }

    /// Runs all queued consumers until shutdown, then drains them.
    ///
    /// The method waits until either `signal` resolves or the shutdown token
    /// is cancelled. If `signal` fired first, the token is cancelled. The
    /// method then waits up to `drain_timeout` for every consumer to finish.
    /// Consumers still running after that are aborted, and their join
    /// results are tallied as well.
    pub async fn run_until_timeout<S>(
        mut self,
        signal: S,
        drain_timeout: Duration,
    ) -> SupervisorOutcome
    where
        S: Future<Output = ()> + Send + 'static,
    {
        let mut set: JoinSet<Result<()>> = JoinSet::new();
        self.start_all(&mut set);
        let shutdown = self.shutdown.clone();
        tokio::select! {
            _ = shutdown.cancelled() => {}
            _ = signal => { shutdown.cancel(); }
        }
        let mut errors = 0usize;
        let mut panics = 0usize;
        let drain = async {
            while let Some(res) = set.join_next().await {
                tally_join_result(res, &mut errors, &mut panics);
            }
        };
        match tokio::time::timeout(drain_timeout, drain).await {
            Ok(()) => SupervisorOutcome {
                errors,
                panics,
                timed_out: false,
            },
            Err(_) => {
                // `as_millis` yields a u128; saturate instead of silently truncating.
                let timeout_ms = drain_timeout.as_millis().min(u128::from(u64::MAX)) as u64;
                tracing::warn!(
                    timeout_ms,
                    "RedisConsumerGroupRegistry: drain timed out; aborting surviving tasks"
                );
                set.abort_all();
                while let Some(res) = set.join_next().await {
                    tally_join_result(res, &mut errors, &mut panics);
                }
                SupervisorOutcome {
                    errors,
                    panics,
                    timed_out: true,
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::consumer::DEFAULT_HANDLER_TIMEOUT;

    #[test]
    fn config_default_handler_timeout_is_library_default() {
        // No explicit timeout and no registry default: the library default wins.
        let resolved = RedisConsumerGroupConfig::new(1).handler_timeout();
        assert_eq!(resolved, Some(DEFAULT_HANDLER_TIMEOUT));
    }

    #[test]
    fn with_handler_timeout_round_trips() {
        let seven = Duration::from_secs(7);
        let resolved = RedisConsumerGroupConfig::new(1)
            .with_handler_timeout(seven)
            .handler_timeout();
        assert_eq!(resolved, Some(seven));
    }

    #[test]
    fn config_inherit_resolves_to_registry_default_when_set() {
        let registry_default = Duration::from_secs(45);
        let inherited = RedisConsumerGroupConfig::new(1).handler_timeout;
        let resolved = resolve_handler_timeout(inherited, Some(registry_default));
        assert_eq!(resolved, registry_default);
    }

    #[test]
    fn with_handler_timeout_beats_registry_default() {
        let explicit = Duration::from_secs(5);
        let configured = RedisConsumerGroupConfig::new(1)
            .with_handler_timeout(explicit)
            .handler_timeout;
        let resolved = resolve_handler_timeout(configured, Some(Duration::from_secs(45)));
        assert_eq!(resolved, explicit);
    }

    #[test]
    #[should_panic(expected = "handler_timeout must be positive")]
    fn with_handler_timeout_zero_panics() {
        let _ = RedisConsumerGroupConfig::new(1).with_handler_timeout(Duration::ZERO);
    }

    #[test]
    fn config_consumer_count() {
        assert_eq!(RedisConsumerGroupConfig::new(4).consumer_count(), 4);
    }

    #[test]
    fn config_default_count_is_one() {
        assert_eq!(RedisConsumerGroupConfig::default().consumer_count(), 1);
    }

    #[test]
    fn config_zero_clamped_to_one() {
        // Zero is clamped up so a group always has at least one consumer.
        assert_eq!(RedisConsumerGroupConfig::new(0).consumer_count(), 1);
    }

    #[test]
    fn config_large_consumer_count() {
        assert_eq!(
            RedisConsumerGroupConfig::new(u16::MAX).consumer_count(),
            u16::MAX
        );
    }

    #[test]
    fn config_builder_chain_consumer_count_accessible() {
        let original = RedisConsumerGroupConfig::new(8);
        let copy = original.clone();
        for cfg in [&original, &copy] {
            assert_eq!(cfg.consumer_count(), 8);
        }
    }
}