1use async_trait::async_trait;
17use coreon_core::{Exchange, Processor, Result};
18use std::{
19 collections::HashMap,
20 sync::{Arc, OnceLock},
21 time::{Duration, Instant},
22};
23use tokio::sync::Mutex;
24use tracing::{debug, warn};
25
26pub type CorrelationFn = dyn Fn(&Exchange) -> Option<String> + Send + Sync;
27pub type CombineFn = dyn Fn(Option<Exchange>, Exchange) -> Exchange + Send + Sync;
28
29struct GroupState {
30 combined: Exchange,
31 count: usize,
32 first_seen: Instant,
33}
34
35struct Inner {
36 correlation: Arc<CorrelationFn>,
37 combine: Arc<CombineFn>,
38 completion_size: Option<usize>,
39 completion_timeout: Option<Duration>,
40 sweep_interval: Duration,
41 groups: Mutex<HashMap<String, GroupState>>,
42 downstream: Arc<dyn Processor>,
43}
44
45pub struct Aggregator {
46 inner: Arc<Inner>,
47 eviction_started: OnceLock<()>,
48}
49
/// Completion settings for an [`Aggregator`].
///
/// All fields are optional; `AggregatorConfig::default()` leaves every one of
/// them unset.
#[derive(Debug, Clone, Default)]
pub struct AggregatorConfig {
    /// Flush a group downstream as soon as it has collected this many exchanges.
    pub completion_size: Option<usize>,
    /// Flush a group downstream once its first exchange is at least this old.
    pub completion_timeout: Option<Duration>,
    /// How often expired groups are looked for. When unset, a quarter of
    /// `completion_timeout` is used, or 100 ms if no timeout is set either;
    /// the effective interval is never shorter than 50 ms.
    pub sweep_interval: Option<Duration>,
}
58
59impl Aggregator {
60 pub fn new<C, M>(
61 correlation: C,
62 combine: M,
63 config: AggregatorConfig,
64 downstream: Arc<dyn Processor>,
65 ) -> Arc<Self>
66 where
67 C: Fn(&Exchange) -> Option<String> + Send + Sync + 'static,
68 M: Fn(Option<Exchange>, Exchange) -> Exchange + Send + Sync + 'static,
69 {
70 Self::from_arcs(Arc::new(correlation), Arc::new(combine), config, downstream)
71 }
72
73 pub fn from_arcs(
76 correlation: Arc<CorrelationFn>,
77 combine: Arc<CombineFn>,
78 config: AggregatorConfig,
79 downstream: Arc<dyn Processor>,
80 ) -> Arc<Self> {
81 let sweep = config
82 .sweep_interval
83 .or_else(|| config.completion_timeout.map(|t| t / 4))
84 .unwrap_or_else(|| Duration::from_millis(100))
85 .max(Duration::from_millis(50));
86 Arc::new(Self {
87 inner: Arc::new(Inner {
88 correlation,
89 combine,
90 completion_size: config.completion_size,
91 completion_timeout: config.completion_timeout,
92 sweep_interval: sweep,
93 groups: Mutex::new(HashMap::new()),
94 downstream,
95 }),
96 eviction_started: OnceLock::new(),
97 })
98 }
99
100 fn ensure_eviction_started(&self) {
101 if self.inner.completion_timeout.is_none() {
103 return;
104 }
105 self.eviction_started.get_or_init(|| {
106 let inner = self.inner.clone();
107 let sweep = inner.sweep_interval;
108 let timeout = inner
109 .completion_timeout
110 .expect("timeout checked above");
111 tokio::spawn(async move {
112 let mut ticker = tokio::time::interval(sweep);
113 ticker.tick().await; loop {
115 ticker.tick().await;
116 sweep_expired(&inner, timeout).await;
117 }
118 });
119 });
120 }
121}
122
123async fn sweep_expired(inner: &Arc<Inner>, timeout: Duration) {
124 let now = Instant::now();
125 let expired: Vec<(String, Exchange)> = {
126 let mut groups = inner.groups.lock().await;
127 let keys: Vec<String> = groups
128 .iter()
129 .filter(|(_, g)| now.duration_since(g.first_seen) >= timeout)
130 .map(|(k, _)| k.clone())
131 .collect();
132 keys.into_iter()
133 .filter_map(|k| groups.remove(&k).map(|g| (k, g.combined)))
134 .collect()
135 };
136 for (key, combined) in expired {
137 debug!(key = %key, "aggregator: timeout flush");
138 let mut ex = combined;
139 if let Err(e) = inner.downstream.process(&mut ex).await {
140 warn!(key = %key, error = %e, "aggregator: downstream error on flush");
141 }
142 }
143}
144
145#[async_trait]
146impl Processor for Aggregator {
147 async fn process(&self, exchange: &mut Exchange) -> Result<()> {
148 self.ensure_eviction_started();
149
150 let key = match (self.inner.correlation)(exchange) {
151 Some(k) => k,
152 None => {
153 return Ok(());
156 }
157 };
158
159 let incoming = std::mem::take(exchange);
162
163 let completed = {
165 let mut groups = self.inner.groups.lock().await;
166 let state = groups.remove(&key);
167 let (prev_count, prev_first, prev_combined) = match state {
168 Some(s) => (s.count, s.first_seen, Some(s.combined)),
169 None => (0, Instant::now(), None),
170 };
171 let combined = (self.inner.combine)(prev_combined, incoming);
172 let new_state = GroupState {
173 combined,
174 count: prev_count + 1,
175 first_seen: prev_first,
176 };
177 if matches!(self.inner.completion_size, Some(max) if new_state.count >= max) {
179 Some(new_state.combined)
180 } else {
181 groups.insert(key.clone(), new_state);
182 None
183 }
184 };
185
186 if let Some(mut combined) = completed {
187 debug!(key = %key, "aggregator: size flush");
188 self.inner.downstream.process(&mut combined).await?;
189 }
190 Ok(())
191 }
192}