use crate::base::Broker;
use crate::components::ComponentLifecycle;
use crate::error::Result;
use crate::task::Task;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
/// Combines all tasks collected for one group into a single task.
///
/// Implementations receive the group name together with every task that could be
/// rebuilt from a ready aggregation set, and return the task to enqueue in their place.
pub trait GroupAggregator: Send + Sync {
    /// Merges `tasks`, which all belong to `group`, into one task.
    ///
    /// # Errors
    ///
    /// Returning `Err` means no aggregated task is produced for this set; the
    /// aggregator logs the error.
    fn aggregate(&self, group: &str, tasks: Vec<Task>) -> Result<Task>;
}

/// Adapter that lets a plain closure or function act as a [`GroupAggregator`].
pub struct GroupAggregatorFunc<F> {
    func: F,
}

impl<F: Fn(&str, Vec<Task>) -> Result<Task> + Send + Sync> GroupAggregatorFunc<F> {
    /// Wraps `func` so it can be used wherever a [`GroupAggregator`] is expected.
    pub fn new(func: F) -> Self {
        GroupAggregatorFunc { func }
    }
}

impl<F: Fn(&str, Vec<Task>) -> Result<Task> + Send + Sync> GroupAggregator
    for GroupAggregatorFunc<F>
{
    /// Delegates directly to the wrapped closure.
    fn aggregate(&self, group: &str, tasks: Vec<Task>) -> Result<Task> {
        let callback = &self.func;
        callback(group, tasks)
    }
}
/// Settings for the background [`Aggregator`].
pub struct AggregatorConfig {
    /// How often the aggregator polls the broker for groups that are ready.
    pub interval: Duration,
    /// Names of the queues whose groups are inspected on every pass.
    pub queues: Vec<String>,
    /// Grace period forwarded to the broker's aggregation check.
    pub grace_period: Duration,
    /// Maximum delay forwarded to the broker's aggregation check.
    /// When `None`, the aggregator passes 30 seconds.
    pub max_delay: Option<Duration>,
    /// Maximum group size forwarded to the broker's aggregation check.
    /// When `None`, the aggregator passes 10.
    pub max_size: Option<usize>,
    /// Callback that merges a ready group into a single task.
    /// When `None`, no aggregated tasks are produced.
    pub group_aggregator: Option<Arc<dyn GroupAggregator>>,
}

impl Default for AggregatorConfig {
    /// Polls the `"default"` queue every 5 seconds with a 60 second grace period,
    /// no explicit delay/size limits and no group aggregator.
    fn default() -> Self {
        let queues = vec![String::from("default")];
        AggregatorConfig {
            interval: Duration::from_secs(5),
            grace_period: Duration::from_secs(60),
            queues,
            max_delay: None,
            max_size: None,
            group_aggregator: None,
        }
    }
}
/// Bookkeeping record for a task group seen in a queue.
///
/// This is the value type of the `Aggregator::groups` map.
#[derive(Debug, Clone)]
// Nothing reads these records yet; the type exists for the (currently unused)
// per-group map on `Aggregator`, so silence the unused-field lint.
#[allow(dead_code)]
struct GroupInfo {
    /// Name of the group.
    pub group: String,
    /// Queue the group belongs to.
    pub queue: String,
    /// Aggregation set associated with the group, if any.
    /// NOTE(review): presumably the broker's set id — nothing here populates it yet.
    pub set_id: Option<String>,
    /// Number of tasks recorded for the group.
    pub task_count: usize,
}
/// Background component that periodically asks the broker for task groups that
/// are ready and hands them to the configured [`GroupAggregator`].
pub struct Aggregator {
    /// Broker used to list groups, check/read/delete aggregation sets and enqueue.
    broker: Arc<dyn Broker>,
    /// Polling and aggregation settings.
    config: AggregatorConfig,
    /// Stop flag; read by the spawned polling loop, set by `shutdown`.
    done: Arc<AtomicBool>,
    /// Per-group records. NOTE(review): key meaning is not established in this file.
    // Created in `new` but never read yet; allowed until group tracking is wired up.
    #[allow(dead_code)]
    groups: Arc<RwLock<HashMap<String, GroupInfo>>>,
}
impl Aggregator {
pub fn new(broker: Arc<dyn Broker>, config: AggregatorConfig) -> Self {
Self {
broker,
config,
done: Arc::new(AtomicBool::new(false)),
groups: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn start(self: Arc<Self>) -> JoinHandle<()> {
tracing::info!("starting aggregator");
tokio::spawn(async move {
let mut interval = tokio::time::interval(self.config.interval);
loop {
interval.tick().await;
if self.done.load(Ordering::Relaxed) {
tracing::debug!("Aggregator: shutting down");
break;
}
if let Err(e) = self.aggregate().await {
tracing::error!("Aggregator error: {}", e);
}
}
})
}
async fn aggregate(&self) -> Result<()> {
for queue in &self.config.queues {
let groups = self.broker.list_groups(queue).await?;
for group in groups {
tracing::debug!("Aggregator: found group in queue {}: {:?}", queue, group);
if let Ok(Some(set_id)) = self
.broker
.aggregation_check(
queue,
&group,
self.config.grace_period,
self.config.max_delay.unwrap_or(Duration::from_secs(30)),
self.config.max_size.unwrap_or(10),
)
.await
{
tracing::debug!(
"Aggregator: found aggregation set ready for processing: queue={}, set_id={}",
queue,
set_id
);
match self
.broker
.read_aggregation_set(queue, &group, &set_id)
.await
{
Ok(task_messages) => {
let task_count = task_messages.len();
tracing::info!(
"Aggregator: processing {} tasks from aggregation set {} in queue {}",
task_count,
set_id,
queue
);
if let Some(aggregator) = &self.config.group_aggregator {
let mut tasks = Vec::new();
for task_msg in task_messages {
match Task::new_with_headers(
&task_msg.r#type,
&task_msg.payload,
task_msg.headers,
) {
Ok(task) => tasks.push(task),
Err(e) => {
tracing::warn!("Aggregator: failed to create task from message: {}", e);
}
}
}
if !tasks.is_empty() {
match aggregator.aggregate(&group, tasks) {
Ok(aggregated_task) => {
tracing::info!(
"Aggregator: aggregated {} tasks into task type '{}' for group '{}'",
task_count,
aggregated_task.get_type(),
group
);
let mut enqueue_task = aggregated_task.with_queue(queue);
if enqueue_task.options.group.is_none() {
enqueue_task = enqueue_task.with_group(&group);
}
if let Err(e) = self.broker.enqueue(&enqueue_task).await {
tracing::error!("Aggregator: failed to enqueue aggregated task: {}", e);
} else {
tracing::debug!(
"Aggregator: successfully enqueued aggregated task to queue '{}'",
queue
);
}
}
Err(e) => {
tracing::error!(
"Aggregator: failed to aggregate tasks for group '{}': {}",
group,
e
);
}
}
}
} else {
tracing::debug!(
"Aggregator: no GroupAggregator configured, tasks read but not processed"
);
}
}
Err(e) => {
tracing::warn!(
"Aggregator: failed to read aggregation set {}: {}",
set_id,
e
);
}
}
if let Err(e) = self
.broker
.delete_aggregation_set(queue, &group, &set_id)
.await
{
tracing::warn!(
"Aggregator: failed to close aggregation set {}: {}",
set_id,
e
);
}
}
}
}
Ok(())
}
pub fn shutdown(&self) {
self.done.store(true, Ordering::Relaxed);
}
pub fn is_done(&self) -> bool {
self.done.load(Ordering::Relaxed)
}
}
/// Exposes the aggregator through the generic component lifecycle interface.
///
/// Every method forwards to the inherent method of the same name.
impl ComponentLifecycle for Aggregator {
    fn start(self: Arc<Self>) -> JoinHandle<()> {
        // Type-relative paths resolve inherent items first, so this is the
        // inherent `Aggregator::start`, not a recursive call.
        Self::start(self)
    }

    fn shutdown(&self) {
        Self::shutdown(self)
    }

    fn is_done(&self) -> bool {
        Self::is_done(self)
    }
}
#[cfg(all(test, feature = "default"))]
mod tests {
    use super::*;
    use crate::backend::{RedisBroker, RedisConnectionType};

    /// Builds a broker for the Redis instance these tests expect on localhost:6379.
    async fn local_redis_broker() -> Arc<RedisBroker> {
        let connection = RedisConnectionType::single("redis://localhost:6379").unwrap();
        Arc::new(RedisBroker::new(connection).await.unwrap())
    }

    #[test]
    fn test_aggregator_config_default() {
        let AggregatorConfig {
            interval,
            queues,
            grace_period,
            max_delay,
            max_size,
            group_aggregator,
        } = AggregatorConfig::default();
        assert_eq!(interval, Duration::from_secs(5));
        assert_eq!(queues, vec![String::from("default")]);
        assert_eq!(grace_period, Duration::from_secs(60));
        assert!(max_delay.is_none());
        assert!(max_size.is_none());
        assert!(group_aggregator.is_none());
    }

    #[tokio::test]
    async fn test_aggregator_shutdown() {
        let aggregator = Aggregator::new(local_redis_broker().await, AggregatorConfig::default());
        assert!(!aggregator.is_done());
        aggregator.shutdown();
        assert!(aggregator.is_done());
    }

    #[test]
    fn test_group_aggregator_func() {
        let aggregator = GroupAggregatorFunc::new(|group: &str, tasks: Vec<Task>| {
            assert_eq!(group, "test-group");
            assert_eq!(tasks.len(), 3);
            Task::new("batch:process", b"aggregated")
        });
        let tasks = vec![
            Task::new("task1", b"payload1").unwrap(),
            Task::new("task2", b"payload2").unwrap(),
            Task::new("task3", b"payload3").unwrap(),
        ];
        let aggregated = aggregator
            .aggregate("test-group", tasks)
            .expect("aggregation should succeed");
        assert_eq!(aggregated.get_type(), "batch:process");
        assert_eq!(aggregated.get_payload(), b"aggregated");
    }

    #[tokio::test]
    async fn test_group_aggregator_with_config() {
        let broker = local_redis_broker().await;
        let combiner = Arc::new(GroupAggregatorFunc::new(|group: &str, tasks: Vec<Task>| {
            let summary = format!("Aggregated {} tasks from group {}", tasks.len(), group);
            Task::new("batch:process", summary.as_bytes())
        }));
        let config = AggregatorConfig {
            group_aggregator: Some(combiner),
            ..AggregatorConfig::default()
        };
        assert!(config.group_aggregator.is_some());
        let aggregator = Aggregator::new(broker, config);
        assert!(!aggregator.is_done());
    }
}