use async_trait::async_trait;
use coreon_core::{Exchange, Processor, Result};
use std::{
collections::HashMap,
sync::{Arc, OnceLock},
time::{Duration, Instant},
};
use tokio::sync::Mutex;
use tracing::{debug, warn};
/// Extracts the correlation key that selects the aggregation group of an
/// exchange. When it returns `None` the aggregator returns early and leaves
/// the exchange untouched: it is neither grouped nor sent downstream.
pub type CorrelationFn = dyn Fn(&Exchange) -> Option<String> + Send + Sync;
/// Folds an incoming exchange into the running result of its group. The first
/// argument is the previous combined exchange, or `None` when the incoming
/// exchange is the first one seen for its key.
pub type CombineFn = dyn Fn(Option<Exchange>, Exchange) -> Exchange + Send + Sync;
/// Accumulated state of one correlation group that has not been flushed yet.
struct GroupState {
/// Result of applying the combine function to every exchange of the group so far.
combined: Exchange,
/// Number of exchanges that have been folded into `combined`.
count: usize,
/// When the first exchange of the group arrived; the timeout is measured from here.
first_seen: Instant,
}
/// State shared between the `Aggregator` handle and the background sweep task.
struct Inner {
/// Maps an exchange to its group key.
correlation: Arc<CorrelationFn>,
/// Merges an incoming exchange into the group's running result.
combine: Arc<CombineFn>,
/// A group is flushed as soon as its exchange count reaches this value, if set.
completion_size: Option<usize>,
/// A group is flushed by the sweep once it is at least this old, if set.
completion_timeout: Option<Duration>,
/// Period of the background sweep that looks for expired groups.
sweep_interval: Duration,
/// Pending groups, keyed by correlation key.
groups: Mutex<HashMap<String, GroupState>>,
/// Receives every completed (flushed) group.
downstream: Arc<dyn Processor>,
}
/// Processor that groups exchanges by correlation key, merges each group with
/// a combine function and forwards the merged exchange downstream once the
/// group reaches the configured size or exceeds the configured timeout.
pub struct Aggregator {
/// Shared state; also cloned into the background sweep task.
inner: Arc<Inner>,
/// Set the first time the sweep task is spawned so that it is spawned at most once.
eviction_started: OnceLock<()>,
}
/// Completion settings for an `Aggregator`.
///
/// The default value has every field set to `None`: groups are then never
/// flushed by size or by timeout.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AggregatorConfig {
    /// Flush a group as soon as it has combined at least this many exchanges.
    /// `None` disables size-based completion.
    pub completion_size: Option<usize>,
    /// Flush a group once this much time has passed since its first exchange.
    /// `None` disables timeout-based completion; no sweep task is started.
    pub completion_timeout: Option<Duration>,
    /// How often expired groups are looked for. When `None`, a quarter of
    /// `completion_timeout` is used, or 100 ms if that is unset as well. The
    /// effective interval is never shorter than 50 ms.
    pub sweep_interval: Option<Duration>,
}
impl Aggregator {
    /// Builds an aggregator from plain closures.
    ///
    /// * `correlation` - yields the group key of an exchange, or `None` to
    ///   leave the exchange out of aggregation.
    /// * `combine` - merges an incoming exchange into the previous result of
    ///   its group (`None` for the first exchange of a group).
    /// * `config` - completion size / timeout / sweep settings.
    /// * `downstream` - receives every completed group.
    ///
    /// Both closures are wrapped in an `Arc` and handed to
    /// [`Aggregator::from_arcs`].
    pub fn new<C, M>(
        correlation: C,
        combine: M,
        config: AggregatorConfig,
        downstream: Arc<dyn Processor>,
    ) -> Arc<Self>
    where
        C: Fn(&Exchange) -> Option<String> + Send + Sync + 'static,
        M: Fn(Option<Exchange>, Exchange) -> Exchange + Send + Sync + 'static,
    {
        Self::from_arcs(Arc::new(correlation), Arc::new(combine), config, downstream)
    }

    /// Builds an aggregator from already shared correlation/combine functions.
    ///
    /// The sweep interval is `config.sweep_interval` when given, otherwise a
    /// quarter of `config.completion_timeout`, otherwise 100 ms; the result is
    /// raised to at least 50 ms.
    pub fn from_arcs(
        correlation: Arc<CorrelationFn>,
        combine: Arc<CombineFn>,
        config: AggregatorConfig,
        downstream: Arc<dyn Processor>,
    ) -> Arc<Self> {
        let sweep = config
            .sweep_interval
            .or_else(|| config.completion_timeout.map(|t| t / 4))
            .unwrap_or_else(|| Duration::from_millis(100))
            .max(Duration::from_millis(50));
        Arc::new(Self {
            inner: Arc::new(Inner {
                correlation,
                combine,
                completion_size: config.completion_size,
                completion_timeout: config.completion_timeout,
                sweep_interval: sweep,
                groups: Mutex::new(HashMap::new()),
                downstream,
            }),
            eviction_started: OnceLock::new(),
        })
    }

    /// Spawns the background task that flushes timed-out groups, at most once.
    ///
    /// Does nothing when no completion timeout is configured. Uses
    /// `tokio::spawn`, so it has to run inside a Tokio runtime.
    fn ensure_eviction_started(&self) {
        let Some(timeout) = self.inner.completion_timeout else {
            return;
        };
        self.eviction_started.get_or_init(|| {
            let inner = Arc::clone(&self.inner);
            let sweep = inner.sweep_interval;
            tokio::spawn(async move {
                let mut ticker = tokio::time::interval(sweep);
                // The first tick of a Tokio interval completes immediately;
                // consume it so the first sweep happens one full period later.
                ticker.tick().await;
                loop {
                    ticker.tick().await;
                    sweep_expired(&inner, timeout).await;
                    // `Inner` is only ever held by the owning `Aggregator` and by
                    // this task. A strong count of 1 therefore means the
                    // aggregator is gone and no new group can appear. Stop once
                    // the remaining groups have been flushed so the task, the
                    // group map and the downstream processor are released
                    // instead of being kept alive forever.
                    if Arc::strong_count(&inner) == 1 {
                        let drained = inner.groups.lock().await.is_empty();
                        if drained {
                            break;
                        }
                    }
                }
            });
        });
    }
}
/// Removes every group whose age has reached `timeout` and sends its combined
/// exchange downstream.
///
/// Expired groups are taken out of the map while the lock is held; the
/// downstream calls happen after the lock has been released. A downstream
/// error is logged and does not stop the remaining flushes.
async fn sweep_expired(inner: &Arc<Inner>, timeout: Duration) {
    let expired: Vec<(String, Exchange)> = {
        let mut groups = inner.groups.lock().await;
        // Take the timestamp only after the lock is held: a group created while
        // this task was waiting for the lock would otherwise have a `first_seen`
        // later than `now`. The saturating form additionally guarantees that such
        // an ordering yields a zero age instead of a possible panic.
        let now = Instant::now();
        let keys: Vec<String> = groups
            .iter()
            .filter(|(_, g)| now.saturating_duration_since(g.first_seen) >= timeout)
            .map(|(k, _)| k.clone())
            .collect();
        keys.into_iter()
            .filter_map(|k| groups.remove(&k).map(|g| (k, g.combined)))
            .collect()
    };
    for (key, combined) in expired {
        debug!(key = %key, "aggregator: timeout flush");
        let mut ex = combined;
        if let Err(e) = inner.downstream.process(&mut ex).await {
            warn!(key = %key, error = %e, "aggregator: downstream error on flush");
        }
    }
}
#[async_trait]
impl Processor for Aggregator {
    /// Folds `exchange` into the group selected by the correlation function.
    ///
    /// When the correlation function yields no key the exchange is left as it
    /// is and `Ok(())` is returned. Otherwise the exchange is moved out of
    /// `exchange` (leaving its `Default` value behind) and combined with the
    /// group's previous result. If the group has then reached the configured
    /// completion size, the combined exchange is passed downstream and any
    /// downstream error is returned; if not, the group is stored again.
    async fn process(&self, exchange: &mut Exchange) -> Result<()> {
        self.ensure_eviction_started();

        let Some(key) = (self.inner.correlation)(exchange) else {
            return Ok(());
        };
        let incoming = std::mem::take(exchange);

        // The lock only covers the map update; the downstream call below runs
        // after the guard has been dropped.
        let ready = {
            let mut groups = self.inner.groups.lock().await;
            let (seen, first_seen, previous) = groups.remove(&key).map_or_else(
                || (0, Instant::now(), None),
                |g| (g.count, g.first_seen, Some(g.combined)),
            );
            let merged = (self.inner.combine)(previous, incoming);
            let count = seen + 1;
            if self.inner.completion_size.is_some_and(|limit| count >= limit) {
                Some(merged)
            } else {
                groups.insert(
                    key.clone(),
                    GroupState {
                        combined: merged,
                        count,
                        first_seen,
                    },
                );
                None
            }
        };

        match ready {
            Some(mut combined) => {
                debug!(key = %key, "aggregator: size flush");
                self.inner.downstream.process(&mut combined).await?;
                Ok(())
            }
            None => Ok(()),
        }
    }
}