use std::collections::HashMap;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info};
use crate::backends::rabbitmq::client::RabbitMqClient;
use crate::backends::rabbitmq::consumer_group::{ConsumerGroup, ConsumerGroupConfig};
use crate::backends::rabbitmq::topology::RabbitMqTopologyDeclarer;
use crate::consumer::{HandlerTimeoutConfig, resolve_handler_timeout};
use crate::consumer_supervisor::ShutdownTally;
use crate::error::{Result, ShoveError};
use crate::handler::MessageHandler;
use crate::metrics;
use crate::topic::{SequencedTopic, Topic};
/// Registry of RabbitMQ consumer groups that all share one [`RabbitMqClient`].
///
/// Groups are added with `register` / `register_fifo`, started together with
/// `start_all`, and stopped together with `shutdown_all`.
pub struct ConsumerGroupRegistry {
    /// Registered groups, keyed by the queue name of the topic's topology.
    groups: HashMap<String, ConsumerGroup>,
    /// Client used to open channels and handed (cloned) to every new group.
    client: RabbitMqClient,
    /// Registry-wide handler timeout fed into `resolve_handler_timeout`
    /// together with each group's own setting; `None` until configured.
    pub(super) default_handler_timeout: Option<Duration>,
}
impl ConsumerGroupRegistry {
/// Creates an empty registry backed by `client`.
///
/// No groups are registered and no default handler timeout is set.
pub fn new(client: RabbitMqClient) -> Self {
    let groups = HashMap::new();
    Self {
        groups,
        client,
        default_handler_timeout: None,
    }
}
/// Sets the registry-wide default handler timeout (builder style).
///
/// The value is later passed to `resolve_handler_timeout` alongside each
/// group's own `handler_timeout` setting when a group is registered.
///
/// # Panics
///
/// Panics if `timeout` is zero.
pub fn with_default_handler_timeout(mut self, timeout: Duration) -> Self {
    // `Duration` is unsigned, so "greater than zero" is the same as "not zero".
    assert!(
        timeout > Duration::ZERO,
        "default_handler_timeout must be positive"
    );
    self.default_handler_timeout = Some(timeout);
    self
}
/// Registers a consumer group for topic `T`, handled by `H`.
///
/// The group is keyed by the queue name of `T::topology()`. Before the group
/// is created, the topic's topology is declared on a freshly opened channel.
/// The group receives a child of the client's shutdown token, a clone of the
/// client, `handler_factory`, and `ctx`.
///
/// The group's `handler_timeout` is replaced by
/// `HandlerTimeoutConfig::Set(..)` of whatever `resolve_handler_timeout`
/// returns for the configured value and the registry default.
///
/// The group is only stored here; it is started by `start_all`.
///
/// # Errors
///
/// * `ShoveError::Topology` if a group with the same queue name is already
///   registered (a backend error metric is recorded as well).
/// * Any error from opening the channel or declaring the topology.
pub async fn register<T, H>(
    &mut self,
    config: ConsumerGroupConfig,
    handler_factory: impl Fn() -> H + Send + Sync + 'static,
    ctx: H::Context,
) -> Result<()>
where
    T: Topic + 'static,
    H: MessageHandler<T> + 'static,
{
    // Pin the effective timeout on the config before it is handed to the group.
    let mut config = config;
    let resolved_timeout =
        resolve_handler_timeout(config.handler_timeout, self.default_handler_timeout);
    config.handler_timeout = HandlerTimeoutConfig::Set(resolved_timeout);

    let topology = T::topology();
    let group_name = topology.queue().to_string();

    // Reject duplicates before touching the broker.
    let already_registered = self.groups.contains_key(&group_name);
    if already_registered {
        metrics::record_backend_error(
            metrics::BackendLabel::RabbitMq,
            metrics::BackendErrorKind::Topology,
        );
        let message = format!("consumer group '{group_name}' is already registered");
        return Err(ShoveError::Topology(message));
    }

    let channel = self.client.create_channel().await?;
    let topology_declarer = RabbitMqTopologyDeclarer::new(channel);
    topology_declarer.declare(topology).await?;

    info!(group = %group_name, "registering consumer group");

    // Child token: cancelling the client's token also cancels this group.
    let child_shutdown = self.client.shutdown_token().child_token();
    let new_group = ConsumerGroup::new::<T, H>(
        group_name.clone(),
        group_name.clone(),
        config,
        self.client.clone(),
        child_shutdown,
        handler_factory,
        ctx,
    );
    self.groups.insert(group_name, new_group);
    Ok(())
}
pub async fn register_fifo<T, H>(
&mut self,
config: ConsumerGroupConfig,
handler_factory: impl Fn() -> H + Send + Sync + 'static,
ctx: H::Context,
) -> Result<()>
where
T: SequencedTopic + 'static,
H: MessageHandler<T> + 'static,
{
let mut config = config;
config.handler_timeout = HandlerTimeoutConfig::Set(resolve_handler_timeout(
config.handler_timeout,
self.default_handler_timeout,
));
let topology = T::topology();
let name = topology.queue().to_string();
if self.groups.contains_key(&name) {
return Err(ShoveError::Topology(format!(
"consumer group '{name}' is already registered"
)));
}
let channel = self.client.create_channel().await?;
let declarer = RabbitMqTopologyDeclarer::new(channel);
declarer.declare(topology).await?;
info!(group = %name, "registering FIFO consumer group");
let group_token = self.client.shutdown_token().child_token();
let group = ConsumerGroup::new_fifo::<T, H>(
name.clone(),
self.client.clone(),
config,
group_token,
handler_factory,
ctx,
);
self.groups.insert(name, group);
Ok(())
}
/// Calls `start` on every registered consumer group.
pub fn start_all(&mut self) {
    let count = self.groups.len();
    info!(count = count, "starting all consumer groups");
    self.groups.values_mut().for_each(|group| {
        group.start();
    });
}
pub fn groups(&self) -> &HashMap<String, ConsumerGroup> {
&self.groups
}
pub fn groups_mut(&mut self) -> &mut HashMap<String, ConsumerGroup> {
&mut self.groups
}
pub fn client_shutdown_token(&self) -> CancellationToken {
self.client.shutdown_token()
}
/// Shuts down every registered consumer group, discarding the tally.
///
/// Delegates to `shutdown_all_with_tally` and ignores its result.
pub async fn shutdown_all(&mut self) {
    let _discarded_tally = self.shutdown_all_with_tally().await;
}
/// Shuts down every registered consumer group and returns the combined tally.
///
/// Groups are shut down one after another; each group's
/// `shutdown_with_tally` result is folded into a single [`ShutdownTally`].
pub(crate) async fn shutdown_all_with_tally(&mut self) -> ShutdownTally {
    let group_count = self.groups.len();
    info!(count = group_count, "shutting down all consumer groups");

    let mut combined = ShutdownTally::default();
    for group in self.groups.values_mut() {
        let group_tally = group.shutdown_with_tally().await;
        combined.add(group_tally);
    }

    debug!(
        errors = combined.errors,
        panics = combined.panics,
        "all consumer groups shut down"
    );
    combined
}
}