use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use crate::backend::ConsumerOptionsInner;
use crate::consumer_supervisor::{SupervisorOutcome, tally_join_result};
use crate::error::Result;
use crate::handler::MessageHandler;
use crate::topic::{SequencedTopic, Topic};
use crate::backend::consumer::ConsumerImpl;
use super::client::RedisClient;
use super::consumer::RedisConsumer;
use super::topology::RedisTopologyDeclarer;
/// Heap-allocated, pinned, `Send` future that a single consumer task runs to completion.
type BoxFuture = Pin<Box<dyn Send + Future<Output = Result<()>>>>;
/// Deferred constructor for one consumer task; called once (it is `FnOnce`) when the
/// task is spawned by `start_all`.
type TaskFactory = Box<dyn Send + FnOnce() -> BoxFuture>;
/// Sizing options for a group of Redis consumers registered through
/// `RedisConsumerGroupRegistry::register`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisConsumerGroupConfig {
    /// Number of consumer tasks to queue for the topic; never `0` (see [`Self::new`]).
    consumer_count: u16,
}

impl RedisConsumerGroupConfig {
    /// Creates a config requesting `consumer_count` consumers.
    ///
    /// A request for `0` consumers is clamped to `1`, so registering a topic always
    /// queues at least one consumer task.
    #[must_use]
    pub fn new(consumer_count: u16) -> Self {
        Self {
            consumer_count: consumer_count.max(1),
        }
    }

    /// Returns the number of consumer tasks this config requests (always `>= 1`).
    #[must_use]
    pub fn consumer_count(&self) -> u16 {
        self.consumer_count
    }
}

impl Default for RedisConsumerGroupConfig {
    /// A single consumer.
    fn default() -> Self {
        Self::new(1)
    }
}
/// Collects Redis consumer tasks for one or more topics and supervises them until
/// shutdown.
///
/// Consumers are queued by [`register`](Self::register) and
/// [`register_fifo`](Self::register_fifo) but are not spawned until
/// [`start_all`](Self::start_all) runs (which [`run_until_timeout`](Self::run_until_timeout)
/// calls for you).
pub struct RedisConsumerGroupRegistry {
    /// Client cloned into every topology declarer and every consumer.
    client: RedisClient,
    /// Deferred task constructors; emptied by `start_all`.
    tasks: Vec<TaskFactory>,
    /// Registry-wide token, handed to every consumer's options as its shutdown token.
    shutdown: CancellationToken,
}

impl RedisConsumerGroupRegistry {
    /// Creates an empty registry whose consumers will all use `client`.
    pub fn new(client: RedisClient) -> Self {
        Self {
            client,
            tasks: Vec::new(),
            shutdown: CancellationToken::new(),
        }
    }

    /// Returns a clone of the registry-wide shutdown token.
    ///
    /// The same token is passed to every registered consumer. Cancelling it also makes
    /// [`run_until_timeout`](Self::run_until_timeout) stop waiting for its signal and
    /// begin draining.
    pub fn broker_shutdown_token(&self) -> CancellationToken {
        self.shutdown.clone()
    }

    /// Declares `T`'s topology and queues `config.consumer_count()` consumer tasks for it.
    ///
    /// `factory` is called once per consumer, at registration time, so every consumer
    /// gets its own handler; `ctx` is cloned into each consumer. Nothing runs until
    /// [`start_all`](Self::start_all) spawns the queued tasks.
    ///
    /// # Errors
    ///
    /// Returns the topology-declaration error; in that case no consumers are queued.
    pub async fn register<T, H>(
        &mut self,
        config: RedisConsumerGroupConfig,
        factory: impl Fn() -> H + Send + Sync + 'static,
        ctx: H::Context,
    ) -> Result<()>
    where
        T: Topic + 'static,
        H: MessageHandler<T> + 'static,
    {
        let topology = T::topology();
        let declarer = RedisTopologyDeclarer::new(self.client.clone());
        declarer.declare(topology).await?;

        // Lossless widening; avoids an `as` cast.
        let n = usize::from(config.consumer_count());
        self.tasks.reserve(n);
        for _ in 0..n {
            let client = self.client.clone();
            let shutdown = self.shutdown.clone();
            let handler = factory();
            let ctx = ctx.clone();
            let task: TaskFactory = Box::new(move || {
                Box::pin(async move {
                    let consumer = RedisConsumer::new(client);
                    let options = ConsumerOptionsInner::defaults_with_shutdown(shutdown);
                    consumer.run::<T, H>(handler, ctx, options).await
                })
            });
            self.tasks.push(task);
        }
        Ok(())
    }

    /// Declares `T`'s topology and queues exactly one consumer that runs via `run_fifo`.
    ///
    /// Unlike [`register`](Self::register) this takes no group config; a single task is
    /// queued. NOTE(review): presumably a single consumer is what preserves ordering for
    /// a `SequencedTopic` — confirm before adding a consumer-count parameter.
    ///
    /// # Errors
    ///
    /// Returns the topology-declaration error; in that case nothing is queued.
    pub async fn register_fifo<T, H>(
        &mut self,
        factory: impl Fn() -> H + Send + Sync + 'static,
        ctx: H::Context,
    ) -> Result<()>
    where
        T: SequencedTopic + 'static,
        H: MessageHandler<T> + 'static,
    {
        let topology = T::topology();
        let declarer = RedisTopologyDeclarer::new(self.client.clone());
        declarer.declare(topology).await?;

        let client = self.client.clone();
        let shutdown = self.shutdown.clone();
        let handler = factory();
        // `ctx` is owned and used by exactly one task, so it is moved rather than cloned.
        let task: TaskFactory = Box::new(move || {
            Box::pin(async move {
                let consumer = RedisConsumer::new(client);
                let options = ConsumerOptionsInner::defaults_with_shutdown(shutdown);
                consumer.run_fifo::<T, H>(handler, ctx, options).await
            })
        });
        self.tasks.push(task);
        Ok(())
    }

    /// Spawns every queued consumer onto `set`, leaving no queued tasks behind.
    ///
    /// A later call spawns only consumers registered since the previous call.
    pub fn start_all(&mut self, set: &mut JoinSet<Result<()>>) {
        for factory in self.tasks.drain(..) {
            set.spawn(factory());
        }
    }

    /// Runs every queued consumer until shutdown, then drains them within `drain_timeout`.
    ///
    /// Shutdown begins when `signal` completes (which cancels the shared token) or when
    /// the token from [`broker_shutdown_token`](Self::broker_shutdown_token) is cancelled.
    /// Task results are then collected for up to `drain_timeout`. If time runs out, the
    /// remaining tasks are aborted and their join results are collected as well.
    ///
    /// NOTE(review): consumers that exit before shutdown are only observed during the
    /// drain phase. If every consumer exits early, this still waits for `signal`.
    pub async fn run_until_timeout<S>(
        mut self,
        signal: S,
        drain_timeout: Duration,
    ) -> SupervisorOutcome
    where
        S: Future<Output = ()> + Send + 'static,
    {
        let mut set: JoinSet<Result<()>> = JoinSet::new();
        self.start_all(&mut set);

        let shutdown = self.shutdown.clone();
        tokio::select! {
            _ = shutdown.cancelled() => {}
            // Forward the external signal to every consumer through the shared token.
            _ = signal => { shutdown.cancel(); }
        }

        let mut errors = 0usize;
        let mut panics = 0usize;
        let drain = async {
            while let Some(res) = set.join_next().await {
                tally_join_result(res, &mut errors, &mut panics);
            }
        };
        match tokio::time::timeout(drain_timeout, drain).await {
            Ok(()) => SupervisorOutcome {
                errors,
                panics,
                timed_out: false,
            },
            Err(_) => {
                // `as_millis` returns u128; saturate rather than silently truncate via `as`.
                let timeout_ms = u64::try_from(drain_timeout.as_millis()).unwrap_or(u64::MAX);
                tracing::warn!(
                    timeout_ms,
                    "RedisConsumerGroupRegistry: drain timed out; aborting surviving tasks"
                );
                set.abort_all();
                // Collect the join results of the aborted tasks as well.
                while let Some(res) = set.join_next().await {
                    tally_join_result(res, &mut errors, &mut panics);
                }
                SupervisorOutcome {
                    errors,
                    panics,
                    timed_out: true,
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config for `requested` consumers and reads back the stored count.
    fn count_for(requested: u16) -> u16 {
        RedisConsumerGroupConfig::new(requested).consumer_count()
    }

    #[test]
    fn config_consumer_count() {
        assert_eq!(count_for(4), 4);
    }

    #[test]
    fn config_default_count_is_one() {
        assert_eq!(RedisConsumerGroupConfig::default().consumer_count(), 1);
    }

    #[test]
    fn config_zero_clamped_to_one() {
        assert_eq!(count_for(0), 1);
    }

    #[test]
    fn config_large_consumer_count() {
        assert_eq!(count_for(u16::MAX), u16::MAX);
    }

    #[test]
    fn config_builder_chain_consumer_count_accessible() {
        let original = RedisConsumerGroupConfig::new(8);
        let duplicate = original.clone();
        assert_eq!(original.consumer_count(), 8);
        assert_eq!(duplicate.consumer_count(), 8);
    }
}