1use crate::{error::Result, message::Message, message::Payload, processor::Processor, Exchange};
74use std::collections::HashMap;
75use std::fmt;
76use std::sync::{Arc, Mutex};
77use std::time::{Duration, Instant};
78
/// Decides when a correlation group is ready to be released by [`Aggregator`].
///
/// `Aggregator` holds the condition as an `Arc<dyn CompletionCondition>`, hence the
/// `Send + Sync` supertraits.
pub trait CompletionCondition: Send + Sync {
    /// Returns `true` when the group should be taken from the store and combined.
    ///
    /// As called by `Aggregator::process`, `group` is the snapshot returned by
    /// `GroupStore::append` (for `InMemoryGroupStore`: every message collected for the key so
    /// far, including the one that just arrived) and `first_seen` is the instant the store
    /// reports for the group. The check only happens when a message arrives for the key.
    fn is_complete(&self, group: &[Message], first_seen: Instant) -> bool;
}
92
93#[derive(Debug, Clone, Copy)]
95pub struct BySize(pub usize);
96
97impl CompletionCondition for BySize {
98 fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
99 group.len() >= self.0
100 }
101}
102
103#[derive(Debug, Clone, Copy)]
107pub struct ByTimeout(pub Duration);
108
109impl CompletionCondition for ByTimeout {
110 fn is_complete(&self, _group: &[Message], first_seen: Instant) -> bool {
111 first_seen.elapsed() >= self.0
112 }
113}
114
115pub struct ByPredicate<F: Fn(&[Message]) -> bool + Send + Sync>(pub F);
117
118impl<F> CompletionCondition for ByPredicate<F>
119where
120 F: Fn(&[Message]) -> bool + Send + Sync,
121{
122 fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
123 (self.0)(group)
124 }
125}
126
127pub struct ByWeight<F: Fn(&Message) -> u64 + Send + Sync> {
132 pub weight: F,
133 pub threshold: u64,
134}
135
136impl<F> CompletionCondition for ByWeight<F>
137where
138 F: Fn(&Message) -> u64 + Send + Sync,
139{
140 fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
141 group.iter().map(|m| (self.weight)(m)).sum::<u64>() >= self.threshold
142 }
143}
144
/// Merges a completed correlation group into a single output message.
///
/// `Aggregator` holds the strategy as an `Arc<dyn AggregationStrategy>`, hence the
/// `Send + Sync` supertraits.
pub trait AggregationStrategy: Send + Sync {
    /// Consumes the group's messages and returns the combined message.
    ///
    /// Returning `None` means "no output". `Aggregator::process` then leaves `out_msg`
    /// unset, and the group has already been removed from the store, so its messages are
    /// dropped.
    fn combine(&self, group: Vec<Message>) -> Option<Message>;
}
154
155#[derive(Debug, Clone, Copy, Default)]
158pub struct ConcatText;
159
160impl AggregationStrategy for ConcatText {
161 fn combine(&self, group: Vec<Message>) -> Option<Message> {
162 if !group.iter().all(|m| m.body_text().is_some()) {
163 return None;
164 }
165 let concat: String = group.iter().map(|m| m.body_text().unwrap()).collect();
166 Some(Message::from_text(concat))
167 }
168}
169
170#[derive(Debug, Clone, Copy, Default)]
177pub struct JsonArray;
178
179impl AggregationStrategy for JsonArray {
180 fn combine(&self, group: Vec<Message>) -> Option<Message> {
181 let arr: Vec<serde_json::Value> = group
182 .into_iter()
183 .map(|m| match m.payload {
184 Payload::Text(s) => serde_json::Value::String(s),
185 Payload::Bytes(b) => {
186 serde_json::Value::Array(b.into_iter().map(serde_json::Value::from).collect())
187 }
188 Payload::Json(v) => v,
189 Payload::Empty => serde_json::Value::Null,
190 })
191 .collect();
192 Some(Message::new(Payload::Json(serde_json::Value::Array(arr))))
193 }
194}
195
/// Storage for in-flight correlation groups, keyed by the correlation header value.
///
/// `Aggregator` shares the store as an `Arc<dyn GroupStore>`, hence the `Send + Sync`
/// supertraits. `InMemoryGroupStore` is the default implementation.
pub trait GroupStore: Send + Sync {
    /// Adds `msg` to the group for `key` and returns a snapshot of the group's messages
    /// together with the group's first-seen instant.
    ///
    /// `Aggregator::process` passes both values straight to its `CompletionCondition`.
    /// `InMemoryGroupStore` creates the group on first use and stamps `first_seen` then.
    fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant);
    /// Removes the group for `key` and returns its messages, or `None` if there is none.
    ///
    /// `Aggregator::process` calls this once a group is complete.
    fn take(&self, key: &str) -> Option<Vec<Message>>;
    /// Discards every stored group (used by `Aggregator::clear_store`).
    fn clear(&self);
}
211
/// One in-flight group held by `InMemoryGroupStore`.
struct InMemoryGroup {
    /// Messages appended so far, in append order.
    messages: Vec<Message>,
    /// When the group was created, i.e. when its first message was appended.
    first_seen: Instant,
}
216
217#[derive(Default)]
219pub struct InMemoryGroupStore {
220 inner: Mutex<HashMap<String, InMemoryGroup>>,
221}
222
223impl InMemoryGroupStore {
224 pub fn new() -> Self {
225 Self::default()
226 }
227}
228
229impl GroupStore for InMemoryGroupStore {
230 fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant) {
231 let mut guard = self.inner.lock().unwrap();
232 let entry = guard
233 .entry(key.to_string())
234 .or_insert_with(|| InMemoryGroup {
235 messages: Vec::new(),
236 first_seen: Instant::now(),
237 });
238 entry.messages.push(msg);
239 (entry.messages.clone(), entry.first_seen)
240 }
241 fn take(&self, key: &str) -> Option<Vec<Message>> {
242 let mut guard = self.inner.lock().unwrap();
243 guard.remove(key).map(|g| g.messages)
244 }
245 fn clear(&self) {
246 let mut guard = self.inner.lock().unwrap();
247 guard.clear();
248 }
249}
250
/// Processor that buffers messages by the value of a correlation header and, once the
/// group's `CompletionCondition` holds, emits one combined message as `out_msg`.
///
/// Cloning is cheap. The condition, strategy and store are shared through `Arc`, so all
/// clones see the same groups (e.g. `clear_store` on one clone affects the others).
#[derive(Clone)]
pub struct Aggregator {
    // Name of the header whose value selects the group; messages without it are ignored.
    correlation_header: String,
    // Decides when a group is complete.
    completion: Arc<dyn CompletionCondition>,
    // Combines a complete group into the output message.
    strategy: Arc<dyn AggregationStrategy>,
    // Holds groups between arrivals.
    store: Arc<dyn GroupStore>,
}
264
265impl fmt::Debug for Aggregator {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 f.debug_struct("Aggregator")
268 .field("correlation_header", &self.correlation_header)
269 .finish_non_exhaustive()
270 }
271}
272
273impl Aggregator {
274 pub fn new<H: Into<String>>(correlation_header: H, completion_size: usize) -> Self {
276 Self::with_completion(correlation_header, Arc::new(BySize(completion_size)))
277 }
278
279 pub fn with_completion<H: Into<String>>(
283 correlation_header: H,
284 completion: Arc<dyn CompletionCondition>,
285 ) -> Self {
286 Self {
287 correlation_header: correlation_header.into(),
288 completion,
289 strategy: Arc::new(ConcatText),
290 store: Arc::new(InMemoryGroupStore::new()),
291 }
292 }
293
294 pub fn weighted<H, F>(correlation_header: H, weight: F, threshold: u64) -> Self
296 where
297 H: Into<String>,
298 F: Fn(&Message) -> u64 + Send + Sync + 'static,
299 {
300 Self::with_completion(correlation_header, Arc::new(ByWeight { weight, threshold }))
301 }
302
303 pub fn timed<H: Into<String>>(correlation_header: H, dur: Duration) -> Self {
305 Self::with_completion(correlation_header, Arc::new(ByTimeout(dur)))
306 }
307
308 pub fn when<H, F>(correlation_header: H, predicate: F) -> Self
310 where
311 H: Into<String>,
312 F: Fn(&[Message]) -> bool + Send + Sync + 'static,
313 {
314 Self::with_completion(correlation_header, Arc::new(ByPredicate(predicate)))
315 }
316
317 pub fn with_strategy(mut self, strategy: Arc<dyn AggregationStrategy>) -> Self {
319 self.strategy = strategy;
320 self
321 }
322
323 pub fn with_store(mut self, store: Arc<dyn GroupStore>) -> Self {
325 self.store = store;
326 self
327 }
328
329 pub fn clear_store(&self) {
331 self.store.clear();
332 }
333}
334
335#[async_trait::async_trait]
336impl Processor for Aggregator {
337 async fn process(&self, exchange: &mut Exchange) -> Result<()> {
338 let key = match exchange.in_msg.header(&self.correlation_header) {
339 Some(k) => k.to_string(),
340 None => return Ok(()),
341 };
342 let (group, first_seen) = self.store.append(&key, exchange.in_msg.clone());
343 if self.completion.is_complete(&group, first_seen) {
344 if let Some(completed) = self.store.take(&key) {
345 if let Some(out) = self.strategy.combine(completed) {
346 exchange.out_msg = Some(out);
347 }
348 }
349 }
350 Ok(())
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{Exchange, Message, Payload};
    use crate::route::Route;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Runs `route` on `exchange` to completion on a fresh Tokio runtime, panicking on error.
    fn run(route: &Route, exchange: &mut Exchange) {
        tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(route.run(exchange))
            .unwrap();
    }

    /// Builds an exchange carrying `msg` whose header `header` is set to `key`.
    fn ex_with(header: &str, key: &str, msg: Message) -> Exchange {
        let mut e = Exchange::new(msg);
        e.in_msg.set_header(header, key);
        e
    }

    #[test]
    fn back_compat_size_two_concats_ab() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();
        let mut ex1 = ex_with("corr", "g", Message::from_text("A"));
        run(&route, &mut ex1);
        assert!(ex1.out_msg.is_none());
        let mut ex2 = ex_with("corr", "g", Message::from_text("B"));
        run(&route, &mut ex2);
        assert_eq!(ex2.out_msg.unwrap().body_text(), Some("AB"));
    }

    #[test]
    fn back_compat_three_messages() {
        let route = Route::new().add(Aggregator::new("corr", 3)).build();
        let mut last = None;
        for s in ["A", "B", "C"] {
            let mut ex = ex_with("corr", "123", Message::from_text(s));
            run(&route, &mut ex);
            last = Some(ex);
        }
        assert_eq!(last.unwrap().out_msg.unwrap().body_text(), Some("ABC"));
    }

    #[test]
    fn ignores_messages_without_correlation_header() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();
        for s in ["A", "B"] {
            let mut ex = Exchange::new(Message::from_text(s));
            run(&route, &mut ex);
            assert!(ex.out_msg.is_none());
        }
    }

    #[test]
    fn aggregates_multiple_batches_for_same_key() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();
        let mut ex1 = ex_with("corr", "same", Message::from_text("A"));
        run(&route, &mut ex1);
        assert!(ex1.out_msg.is_none());
        let mut ex2 = ex_with("corr", "same", Message::from_text("B"));
        run(&route, &mut ex2);
        assert_eq!(ex2.out_msg.as_ref().unwrap().body_text(), Some("AB"));
        // The completed group was taken, so C starts a fresh batch.
        let mut ex3 = ex_with("corr", "same", Message::from_text("C"));
        run(&route, &mut ex3);
        assert!(ex3.out_msg.is_none());
        let mut ex4 = ex_with("corr", "same", Message::from_text("D"));
        run(&route, &mut ex4);
        assert_eq!(ex4.out_msg.as_ref().unwrap().body_text(), Some("CD"));
    }

    #[test]
    fn concat_text_non_text_group_emits_nothing() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();
        let mut ex1 = ex_with("corr", "m", Message::new(Payload::Bytes(vec![0, 1])));
        run(&route, &mut ex1);
        let mut ex2 = ex_with("corr", "m", Message::new(Payload::Bytes(vec![2, 3])));
        run(&route, &mut ex2);
        assert!(ex2.out_msg.is_none());
    }

    #[test]
    fn clear_store_resets_groups() {
        let agg = Aggregator::new("corr", 2);
        let route = Route::new().add(agg.clone()).build();
        let mut ex1 = ex_with("corr", "x", Message::from_text("A"));
        run(&route, &mut ex1);
        agg.clear_store();
        let mut ex2 = ex_with("corr", "x", Message::from_text("B"));
        run(&route, &mut ex2);
        assert!(
            ex2.out_msg.is_none(),
            "clear_store should reset the group; B should be the first of a new batch"
        );
    }

    #[test]
    fn by_weight_completes_at_threshold() {
        let threshold: u64 = 7;
        let route = Route::new()
            .add(Aggregator::weighted(
                "block",
                |m: &Message| {
                    m.header("voting_power")
                        .and_then(|s| s.parse().ok())
                        .unwrap_or(0)
                },
                threshold,
            ))
            .build();

        for (vp, expect_out) in [(3u64, false), (3, false), (4, true)] {
            let mut ex = Exchange::new(Message::from_text(format!("vote-vp{vp}")));
            ex.in_msg.set_header("block", "h=42");
            ex.in_msg.set_header("voting_power", vp.to_string());
            run(&route, &mut ex);
            assert_eq!(
                ex.out_msg.is_some(),
                expect_out,
                "vp={vp}: expected out_msg={expect_out}"
            );
        }
    }

    #[test]
    fn by_weight_fires_exactly_at_threshold_boundary() {
        let route = Route::new()
            .add(Aggregator::weighted(
                "block",
                |m: &Message| {
                    m.header("voting_power")
                        .and_then(|s| s.parse().ok())
                        .unwrap_or(0)
                },
                6,
            ))
            .build();
        let mut ex1 = Exchange::new(Message::from_text("a"));
        ex1.in_msg.set_header("block", "h=1");
        ex1.in_msg.set_header("voting_power", "3");
        run(&route, &mut ex1);
        assert!(ex1.out_msg.is_none());
        let mut ex2 = Exchange::new(Message::from_text("b"));
        ex2.in_msg.set_header("block", "h=1");
        ex2.in_msg.set_header("voting_power", "3");
        run(&route, &mut ex2);
        assert!(ex2.out_msg.is_some(), "sum=6, threshold=6: should fire");
    }

    #[test]
    fn by_weight_isolated_per_key() {
        let route = Route::new()
            .add(Aggregator::weighted(
                "block",
                |m: &Message| {
                    m.header("voting_power")
                        .and_then(|s| s.parse().ok())
                        .unwrap_or(0)
                },
                4,
            ))
            .build();

        for (block, vp, expect) in [
            ("A", 2, false),
            ("B", 1, false),
            ("A", 2, true),
            ("B", 1, false),
        ] {
            let mut ex = Exchange::new(Message::from_text("v"));
            ex.in_msg.set_header("block", block);
            ex.in_msg.set_header("voting_power", vp.to_string());
            run(&route, &mut ex);
            assert_eq!(ex.out_msg.is_some(), expect, "block={block} vp={vp}");
        }
    }

    #[test]
    fn by_predicate_completes() {
        let route = Route::new()
            .add(Aggregator::when("corr", |g: &[Message]| {
                g.iter().any(|m| m.body_text() == Some("STOP"))
            }))
            .build();
        let mut ex1 = ex_with("corr", "x", Message::from_text("go"));
        run(&route, &mut ex1);
        assert!(ex1.out_msg.is_none());
        let mut ex2 = ex_with("corr", "x", Message::from_text("STOP"));
        run(&route, &mut ex2);
        assert_eq!(ex2.out_msg.as_ref().unwrap().body_text(), Some("goSTOP"));
    }

    #[test]
    fn by_timeout_lazy_completes_on_next_arrival() {
        // The deadline must comfortably exceed the gap between the first two runs. Each
        // `run` builds and tears down a multi-thread Tokio runtime, and a 40 ms deadline
        // made the "too soon" assertion flaky on loaded machines. The sleep is derived from
        // the deadline so the third arrival is always past it.
        let deadline = Duration::from_millis(250);
        let route = Route::new()
            .add(Aggregator::timed("corr", deadline))
            .build();
        let mut ex1 = ex_with("corr", "t", Message::from_text("A"));
        run(&route, &mut ex1);
        assert!(ex1.out_msg.is_none(), "first message: deadline not reached");
        let mut ex2 = ex_with("corr", "t", Message::from_text("B"));
        run(&route, &mut ex2);
        assert!(ex2.out_msg.is_none(), "B arrived too soon");
        std::thread::sleep(deadline + Duration::from_millis(50));
        let mut ex3 = ex_with("corr", "t", Message::from_text("C"));
        run(&route, &mut ex3);
        assert_eq!(ex3.out_msg.as_ref().unwrap().body_text(), Some("ABC"));
    }

    #[test]
    fn json_array_strategy_emits_array_of_mixed_payloads() {
        let route = Route::new()
            .add(Aggregator::new("corr", 4).with_strategy(Arc::new(JsonArray)))
            .build();
        let mut ex1 = ex_with("corr", "j", Message::from_text("hi"));
        run(&route, &mut ex1);
        let mut ex2 = ex_with("corr", "j", Message::new(Payload::Bytes(vec![1, 2])));
        run(&route, &mut ex2);
        let mut ex3 = ex_with(
            "corr",
            "j",
            Message::new(Payload::Json(serde_json::json!({"k": "v"}))),
        );
        run(&route, &mut ex3);
        let mut ex4 = ex_with("corr", "j", Message::new(Payload::Empty));
        run(&route, &mut ex4);

        let out = ex4
            .out_msg
            .expect("JsonArray must always emit on completion");
        let Payload::Json(serde_json::Value::Array(arr)) = out.payload else {
            panic!("JsonArray strategy must emit Payload::Json(Array)");
        };
        assert_eq!(arr.len(), 4);
        assert_eq!(arr[0], serde_json::Value::String("hi".into()));
        assert_eq!(arr[1], serde_json::json!([1, 2]));
        assert_eq!(arr[2], serde_json::json!({"k": "v"}));
        assert_eq!(arr[3], serde_json::Value::Null);
    }

    /// `GroupStore` wrapper that counts `append` and `take` calls on an in-memory store.
    struct CountingStore {
        inner: InMemoryGroupStore,
        appends: AtomicUsize,
        takes: AtomicUsize,
    }
    impl CountingStore {
        fn new() -> Self {
            Self {
                inner: InMemoryGroupStore::new(),
                appends: AtomicUsize::new(0),
                takes: AtomicUsize::new(0),
            }
        }
    }
    impl GroupStore for CountingStore {
        fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant) {
            self.appends.fetch_add(1, Ordering::SeqCst);
            self.inner.append(key, msg)
        }
        fn take(&self, key: &str) -> Option<Vec<Message>> {
            self.takes.fetch_add(1, Ordering::SeqCst);
            self.inner.take(key)
        }
        fn clear(&self) {
            self.inner.clear();
        }
    }

    #[test]
    fn custom_group_store_is_used() {
        let store = Arc::new(CountingStore::new());
        let route = Route::new()
            .add(Aggregator::new("corr", 2).with_store(store.clone()))
            .build();
        let mut ex1 = ex_with("corr", "k", Message::from_text("A"));
        run(&route, &mut ex1);
        let mut ex2 = ex_with("corr", "k", Message::from_text("B"));
        run(&route, &mut ex2);
        assert_eq!(ex2.out_msg.as_ref().unwrap().body_text(), Some("AB"));
        assert_eq!(store.appends.load(Ordering::SeqCst), 2);
        assert_eq!(store.takes.load(Ordering::SeqCst), 1);
    }
}