use crate::{Result, Task};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Reference-counted, `Send + Sync` closure used as an aggregation callback.
///
/// It receives a group name and that group's tasks (by value) and returns
/// `Ok(Some(task))`, `Ok(None)`, or an error.
type AggregatorFn = Arc<dyn Fn(&str, Vec<Task>) -> Result<Option<Task>> + Send + Sync>;
/// Tunable limits carried alongside an aggregator.
///
/// This type only stores the values; how each limit is enforced is up to the
/// code that consumes the configuration. It is a plain value type and can be
/// compared with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregatorConfig {
    /// Size limit; `10` by default.
    ///
    /// NOTE(review): presumably the maximum number of tasks collected per
    /// group before aggregation — confirm against the consumer.
    pub max_size: usize,
    /// Grace period; 30 seconds by default.
    ///
    /// NOTE(review): presumably how long to wait for further tasks after the
    /// most recent one — confirm against the consumer.
    pub grace_period: Duration,
    /// Maximum delay; 300 seconds by default.
    ///
    /// NOTE(review): presumably the upper bound on how long a group may wait
    /// in total — confirm against the consumer.
    pub max_delay: Duration,
}

impl Default for AggregatorConfig {
    /// Returns `max_size = 10`, `grace_period = 30s`, `max_delay = 300s`.
    fn default() -> Self {
        Self {
            max_size: 10,
            grace_period: Duration::from_secs(30),
            max_delay: Duration::from_secs(300),
        }
    }
}

impl AggregatorConfig {
    /// Creates a configuration holding the default values.
    ///
    /// Equivalent to [`AggregatorConfig::default`]; intended as the starting
    /// point of a builder chain.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with `max_size` replaced by `size`.
    #[must_use]
    pub fn max_size(mut self, size: usize) -> Self {
        self.max_size = size;
        self
    }

    /// Returns the configuration with `grace_period` replaced by `duration`.
    #[must_use]
    pub fn grace_period(mut self, duration: Duration) -> Self {
        self.grace_period = duration;
        self
    }

    /// Returns the configuration with `max_delay` replaced by `duration`.
    #[must_use]
    pub fn max_delay(mut self, duration: Duration) -> Self {
        self.max_delay = duration;
        self
    }
}
/// Turns the tasks of one group into at most one task.
///
/// Implementors must be `Send + Sync`.
pub trait Aggregator: Send + Sync {
/// Aggregates `tasks`, which are passed by value, for the group named `group`.
///
/// Returns `Ok(Some(task))` when a task is produced, `Ok(None)` when the
/// implementation produces nothing, or an error.
fn aggregate(&self, group: &str, tasks: Vec<Task>) -> Result<Option<Task>>;
}
/// Adapter that lets a plain closure act as an [`Aggregator`].
///
/// Besides the closure it stores an [`AggregatorConfig`], which can be read
/// back through [`GroupAggregatorFunc::config`].
pub struct GroupAggregatorFunc {
    config: AggregatorConfig,
    f: AggregatorFn,
}

impl GroupAggregatorFunc {
    /// Wraps `f` together with `AggregatorConfig::default()`.
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(&str, Vec<Task>) -> Result<Option<Task>> + Send + Sync + 'static,
    {
        Self::with_config(f, AggregatorConfig::default())
    }

    /// Wraps `f` together with the supplied `config`.
    pub fn with_config<F>(f: F, config: AggregatorConfig) -> Self
    where
        F: Fn(&str, Vec<Task>) -> Result<Option<Task>> + Send + Sync + 'static,
    {
        let f: AggregatorFn = Arc::new(f);
        Self { config, f }
    }

    /// Borrows the configuration stored with this aggregator.
    pub fn config(&self) -> &AggregatorConfig {
        &self.config
    }
}

impl Aggregator for GroupAggregatorFunc {
    /// Forwards `group` and `tasks` to the wrapped closure and returns its
    /// result unchanged.
    fn aggregate(&self, group: &str, tasks: Vec<Task>) -> Result<Option<Task>> {
        let callback = &*self.f;
        callback(group, tasks)
    }
}
/// Aggregator that keeps a single task from the group and drops the others.
///
/// NOTE(review): despite the name, the task that is kept is the **last**
/// element of the input vector, not the first. The behaviour is left as-is
/// because existing callers and tests may rely on it; either the name or the
/// selection should be revisited.
pub struct FirstTaskAggregator;

impl Aggregator for FirstTaskAggregator {
    /// Returns the final element of `tasks`, or `None` when `tasks` is empty.
    ///
    /// The group name is ignored and this implementation never returns an
    /// error.
    fn aggregate(&self, _group: &str, tasks: Vec<Task>) -> Result<Option<Task>> {
        let kept = tasks.into_iter().next_back();
        Ok(kept)
    }
}
/// Registry that maps group names to their [`Aggregator`] and holds a default
/// [`AggregatorConfig`].
pub struct AggregatorManager {
    aggregators: HashMap<String, Arc<dyn Aggregator>>,
    default_config: AggregatorConfig,
}

impl AggregatorManager {
    /// Creates an empty registry whose default configuration is
    /// `AggregatorConfig::default()`.
    pub fn new() -> Self {
        let aggregators = HashMap::new();
        let default_config = AggregatorConfig::default();
        Self {
            aggregators,
            default_config,
        }
    }

    /// Stores `aggregator` under `group`, replacing any aggregator already
    /// registered for that name.
    pub fn register(&mut self, group: &str, aggregator: Arc<dyn Aggregator>) {
        self.aggregators.insert(group.to_owned(), aggregator);
    }

    /// Wraps the closure `f` in a [`GroupAggregatorFunc`] (with the default
    /// configuration) and registers it under `group`.
    pub fn register_fn<F>(&mut self, group: &str, f: F)
    where
        F: Fn(&str, Vec<Task>) -> Result<Option<Task>> + Send + Sync + 'static,
    {
        self.register(group, Arc::new(GroupAggregatorFunc::new(f)));
    }

    /// Looks up the aggregator registered for `group`, if there is one.
    pub fn get(&self, group: &str) -> Option<&Arc<dyn Aggregator>> {
        self.aggregators.get(group)
    }

    /// Reports whether an aggregator is registered for `group`.
    pub fn has(&self, group: &str) -> bool {
        self.get(group).is_some()
    }

    /// Borrows the default configuration held by this manager.
    pub fn default_config(&self) -> &AggregatorConfig {
        &self.default_config
    }

    /// Replaces the default configuration held by this manager.
    pub fn set_default_config(&mut self, config: AggregatorConfig) {
        self.default_config = config;
    }
}

impl Default for AggregatorManager {
    /// Same as [`AggregatorManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixture task shared by the aggregator tests; only `id`
    /// differs between instances.
    fn sample_task(id: &str) -> Task {
        Task {
            id: id.to_string(),
            task_type: "test".to_string(),
            queue: "default".to_string(),
            payload: vec![1, 2, 3],
            options: Default::default(),
            status: Default::default(),
            created_at: 0,
            enqueued_at: None,
            processed_at: None,
            retry_cnt: 0,
            last_error: None,
        }
    }

    /// Every builder method overrides exactly its own field.
    #[test]
    fn test_aggregator_config() {
        let grace = Duration::from_secs(60);
        let delay = Duration::from_secs(600);
        let config = AggregatorConfig::new()
            .max_size(20)
            .grace_period(grace)
            .max_delay(delay);

        assert_eq!(config.max_size, 20);
        assert_eq!(config.grace_period, Duration::from_secs(60));
        assert_eq!(config.max_delay, Duration::from_secs(600));
    }

    /// A closure registered for a group is reported by `has`; other groups
    /// are not.
    #[test]
    fn test_aggregator_manager() {
        let mut manager = AggregatorManager::new();
        manager.register_fn("test_group", |_, tasks| Ok(tasks.into_iter().next()));

        assert!(manager.has("test_group"));
        assert!(!manager.has("unknown_group"));
    }

    /// `FirstTaskAggregator` yields the last task of the input vector.
    #[test]
    fn test_first_task_aggregator() {
        let batch = vec![sample_task("task1"), sample_task("task2")];

        let result = FirstTaskAggregator.aggregate("test", batch).unwrap();

        assert!(result.is_some());
        assert_eq!(result.unwrap().id, "task2");
    }

    /// A closure-backed aggregator can build a new task from the batch.
    #[test]
    fn test_group_aggregator_func() {
        let aggregator = GroupAggregatorFunc::new(|_group, tasks| {
            Ok(Some(Task::builder("batch").queue("default").payload(&tasks.len())?.build()?))
        });
        let batch = vec![sample_task("task1"), sample_task("task1")];

        let result = aggregator.aggregate("test", batch).unwrap();

        assert!(result.is_some());
    }
}