1use crate::{error::Result, message::Message, message::Payload, processor::Processor, Exchange};
74use std::collections::HashMap;
75use std::fmt;
76use std::sync::{Arc, Mutex};
77use std::time::{Duration, Instant};
78
79pub trait CompletionCondition: Send + Sync {
88 fn is_complete(&self, group: &[Message], first_seen: Instant) -> bool;
91}
92
/// Completion rule keyed on message count: the wrapped value is the number of
/// messages a group must reach before it counts as complete.
#[derive(Clone, Copy, Debug)]
pub struct BySize(
    /// Minimum group length required for completion.
    pub usize,
);
96
97impl CompletionCondition for BySize {
98 fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
99 group.len() >= self.0
100 }
101}
102
/// Completion rule keyed on age: the wrapped duration is how long must have
/// passed since the group's `first_seen` instant.
#[derive(Clone, Copy, Debug)]
pub struct ByTimeout(
    /// Minimum elapsed time before the group counts as complete.
    pub Duration,
);
108
109impl CompletionCondition for ByTimeout {
110 fn is_complete(&self, _group: &[Message], first_seen: Instant) -> bool {
111 first_seen.elapsed() >= self.0
112 }
113}
114
115pub struct ByPredicate<F: Fn(&[Message]) -> bool + Send + Sync>(pub F);
117
118impl<F> CompletionCondition for ByPredicate<F>
119where
120 F: Fn(&[Message]) -> bool + Send + Sync,
121{
122 fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
123 (self.0)(group)
124 }
125}
126
127pub struct ByWeight<F: Fn(&Message) -> u64 + Send + Sync> {
132 pub weight: F,
133 pub threshold: u64,
134}
135
136impl<F> CompletionCondition for ByWeight<F>
137where
138 F: Fn(&Message) -> u64 + Send + Sync,
139{
140 fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
141 group.iter().map(|m| (self.weight)(m)).sum::<u64>() >= self.threshold
142 }
143}
144
145pub trait AggregationStrategy: Send + Sync {
152 fn combine(&self, group: Vec<Message>) -> Option<Message>;
153}
154
/// Aggregation strategy that joins the text bodies of a group into one text
/// message.
#[derive(Clone, Copy, Debug, Default)]
pub struct ConcatText;
159
160impl AggregationStrategy for ConcatText {
161 fn combine(&self, group: Vec<Message>) -> Option<Message> {
162 if !group.iter().all(|m| m.body_text().is_some()) {
163 return None;
164 }
165 let concat: String = group.iter().map(|m| m.body_text().unwrap()).collect();
166 Some(Message::from_text(concat))
167 }
168}
169
/// Aggregation strategy that emits the group as a single JSON array, one
/// element per message.
#[derive(Clone, Copy, Debug, Default)]
pub struct JsonArray;
178
/// Aggregation strategy that discards the group's contents and emits a
/// default message as a completion signal.
#[derive(Clone, Copy, Debug, Default)]
pub struct EmitSignal;
184
185impl AggregationStrategy for EmitSignal {
186 fn combine(&self, _group: Vec<Message>) -> Option<Message> {
187 Some(Message::default())
188 }
189}
190
191impl AggregationStrategy for JsonArray {
192 fn combine(&self, group: Vec<Message>) -> Option<Message> {
193 let arr: Vec<serde_json::Value> = group
194 .into_iter()
195 .map(|m| match m.payload {
196 Payload::Text(s) => serde_json::Value::String(s),
197 Payload::Bytes(b) => {
198 serde_json::Value::Array(b.into_iter().map(serde_json::Value::from).collect())
199 }
200 Payload::Json(v) => v,
201 Payload::Empty => serde_json::Value::Null,
202 })
203 .collect();
204 Some(Message::new(Payload::Json(serde_json::Value::Array(arr))))
205 }
206}
207
208pub trait GroupStore: Send + Sync {
215 fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant);
218 fn take(&self, key: &str) -> Option<Vec<Message>>;
220 fn clear(&self);
222}
223
224struct InMemoryGroup {
225 messages: Vec<Message>,
226 first_seen: Instant,
227}
228
229#[derive(Default)]
231pub struct InMemoryGroupStore {
232 inner: Mutex<HashMap<String, InMemoryGroup>>,
233}
234
235impl InMemoryGroupStore {
236 pub fn new() -> Self {
237 Self::default()
238 }
239}
240
241impl GroupStore for InMemoryGroupStore {
242 fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant) {
243 let mut guard = self.inner.lock().unwrap();
244 let entry = guard
245 .entry(key.to_string())
246 .or_insert_with(|| InMemoryGroup {
247 messages: Vec::new(),
248 first_seen: Instant::now(),
249 });
250 entry.messages.push(msg);
251 (entry.messages.clone(), entry.first_seen)
252 }
253 fn take(&self, key: &str) -> Option<Vec<Message>> {
254 let mut guard = self.inner.lock().unwrap();
255 guard.remove(key).map(|g| g.messages)
256 }
257 fn clear(&self) {
258 let mut guard = self.inner.lock().unwrap();
259 guard.clear();
260 }
261}
262
263#[derive(Clone)]
270pub struct Aggregator {
271 correlation_header: String,
272 completion: Arc<dyn CompletionCondition>,
273 strategy: Arc<dyn AggregationStrategy>,
274 store: Arc<dyn GroupStore>,
275}
276
277impl fmt::Debug for Aggregator {
278 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
279 f.debug_struct("Aggregator")
280 .field("correlation_header", &self.correlation_header)
281 .finish_non_exhaustive()
282 }
283}
284
285impl Aggregator {
286 pub fn new<H: Into<String>>(correlation_header: H, completion_size: usize) -> Self {
288 Self::with_completion(correlation_header, Arc::new(BySize(completion_size)))
289 }
290
291 pub fn with_completion<H: Into<String>>(
295 correlation_header: H,
296 completion: Arc<dyn CompletionCondition>,
297 ) -> Self {
298 Self {
299 correlation_header: correlation_header.into(),
300 completion,
301 strategy: Arc::new(ConcatText),
302 store: Arc::new(InMemoryGroupStore::new()),
303 }
304 }
305
306 pub fn weighted<H, F>(correlation_header: H, weight: F, threshold: u64) -> Self
308 where
309 H: Into<String>,
310 F: Fn(&Message) -> u64 + Send + Sync + 'static,
311 {
312 Self::with_completion(correlation_header, Arc::new(ByWeight { weight, threshold }))
313 }
314
315 pub fn timed<H: Into<String>>(correlation_header: H, dur: Duration) -> Self {
317 Self::with_completion(correlation_header, Arc::new(ByTimeout(dur)))
318 }
319
320 pub fn when<H, F>(correlation_header: H, predicate: F) -> Self
322 where
323 H: Into<String>,
324 F: Fn(&[Message]) -> bool + Send + Sync + 'static,
325 {
326 Self::with_completion(correlation_header, Arc::new(ByPredicate(predicate)))
327 }
328
329 pub fn with_strategy(mut self, strategy: Arc<dyn AggregationStrategy>) -> Self {
331 self.strategy = strategy;
332 self
333 }
334
335 pub fn with_store(mut self, store: Arc<dyn GroupStore>) -> Self {
337 self.store = store;
338 self
339 }
340
341 pub fn clear_store(&self) {
343 self.store.clear();
344 }
345}
346
347#[async_trait::async_trait]
348impl Processor for Aggregator {
349 async fn process(&self, exchange: &mut Exchange) -> Result<()> {
350 let key = match exchange.in_msg.header(&self.correlation_header) {
351 Some(k) => k.to_string(),
352 None => return Ok(()),
353 };
354 let (group, first_seen) = self.store.append(&key, exchange.in_msg.clone());
355 if self.completion.is_complete(&group, first_seen) {
356 if let Some(completed) = self.store.take(&key) {
357 if let Some(out) = self.strategy.combine(completed) {
358 exchange.out_msg = Some(out);
359 }
360 }
361 }
362 Ok(())
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{Exchange, Message, Payload};
    use crate::route::Route;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Runs `route` on `exchange` using a fresh Tokio runtime; panics on error.
    fn run(route: &Route, exchange: &mut Exchange) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(route.run(exchange)).unwrap();
    }

    /// Wraps `msg` in an exchange whose inbound message carries `header = key`.
    fn ex_with(header: &str, key: &str, msg: Message) -> Exchange {
        let mut exchange = Exchange::new(msg);
        exchange.in_msg.set_header(header, key);
        exchange
    }

    /// Builds a correlated exchange, runs it through `route`, and returns it.
    fn send(route: &Route, header: &str, key: &str, msg: Message) -> Exchange {
        let mut exchange = ex_with(header, key, msg);
        run(route, &mut exchange);
        exchange
    }

    /// Weight function used by the weighted tests: the `voting_power` header
    /// parsed as a number, or 0 when it is missing or unparsable.
    fn voting_power(m: &Message) -> u64 {
        m.header("voting_power")
            .and_then(|raw| raw.parse().ok())
            .unwrap_or(0)
    }

    /// Sends `msg` for `block` with the given voting power and returns the
    /// processed exchange.
    fn cast_vote(route: &Route, block: &str, power: u64, msg: Message) -> Exchange {
        let mut exchange = Exchange::new(msg);
        exchange.in_msg.set_header("block", block);
        exchange.in_msg.set_header("voting_power", power.to_string());
        run(route, &mut exchange);
        exchange
    }

    #[test]
    fn back_compat_size_two_concats_ab() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();

        let first = send(&route, "corr", "g", Message::from_text("A"));
        assert!(first.out_msg.is_none());

        let second = send(&route, "corr", "g", Message::from_text("B"));
        assert_eq!(second.out_msg.unwrap().body_text(), Some("AB"));
    }

    #[test]
    fn back_compat_three_messages() {
        let route = Route::new().add(Aggregator::new("corr", 3)).build();

        let mut processed = Vec::new();
        for body in ["A", "B", "C"] {
            processed.push(send(&route, "corr", "123", Message::from_text(body)));
        }

        let last = processed.pop().unwrap();
        assert_eq!(last.out_msg.unwrap().body_text(), Some("ABC"));
    }

    #[test]
    fn ignores_messages_without_correlation_header() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();

        for body in ["A", "B"] {
            let mut uncorrelated = Exchange::new(Message::from_text(body));
            run(&route, &mut uncorrelated);
            assert!(uncorrelated.out_msg.is_none());
        }
    }

    #[test]
    fn aggregates_multiple_batches_for_same_key() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();

        // First batch: A + B.
        let a = send(&route, "corr", "same", Message::from_text("A"));
        assert!(a.out_msg.is_none());
        let b = send(&route, "corr", "same", Message::from_text("B"));
        assert_eq!(b.out_msg.as_ref().unwrap().body_text(), Some("AB"));

        // Second batch under the same key starts from scratch: C + D.
        let c = send(&route, "corr", "same", Message::from_text("C"));
        assert!(c.out_msg.is_none());
        let d = send(&route, "corr", "same", Message::from_text("D"));
        assert_eq!(d.out_msg.as_ref().unwrap().body_text(), Some("CD"));
    }

    #[test]
    fn concat_text_non_text_group_emits_nothing() {
        let route = Route::new().add(Aggregator::new("corr", 2)).build();

        let _first = send(&route, "corr", "m", Message::new(Payload::Bytes(vec![0, 1])));
        let second = send(&route, "corr", "m", Message::new(Payload::Bytes(vec![2, 3])));

        assert!(second.out_msg.is_none());
    }

    #[test]
    fn clear_store_resets_groups() {
        let agg = Aggregator::new("corr", 2);
        // The clone shares the store with `agg`, so clearing via `agg` is
        // visible to the route.
        let route = Route::new().add(agg.clone()).build();

        let _first = send(&route, "corr", "x", Message::from_text("A"));
        agg.clear_store();
        let second = send(&route, "corr", "x", Message::from_text("B"));

        assert!(
            second.out_msg.is_none(),
            "clear_store should reset the group; B should be the first of a new batch"
        );
    }

    #[test]
    fn by_weight_completes_at_threshold() {
        let threshold: u64 = 7;
        let route = Route::new()
            .add(Aggregator::weighted("block", voting_power, threshold))
            .build();

        let votes = [(3u64, false), (3, false), (4, true)];
        for (vp, expect_out) in votes {
            let body = Message::from_text(format!("vote-vp{vp}"));
            let ex = cast_vote(&route, "h=42", vp, body);
            assert_eq!(
                ex.out_msg.is_some(),
                expect_out,
                "vp={vp}: expected out_msg={expect_out}"
            );
        }
    }

    #[test]
    fn by_weight_fires_exactly_at_threshold_boundary() {
        let route = Route::new()
            .add(Aggregator::weighted("block", voting_power, 6))
            .build();

        let first = cast_vote(&route, "h=1", 3, Message::from_text("a"));
        assert!(first.out_msg.is_none());

        let second = cast_vote(&route, "h=1", 3, Message::from_text("b"));
        assert!(second.out_msg.is_some(), "sum=6, threshold=6: should fire");
    }

    #[test]
    fn by_weight_isolated_per_key() {
        let route = Route::new()
            .add(Aggregator::weighted("block", voting_power, 4))
            .build();

        let votes = [
            ("A", 2, false),
            ("B", 1, false),
            ("A", 2, true),
            ("B", 1, false),
        ];
        for (block, vp, expect) in votes {
            let ex = cast_vote(&route, block, vp, Message::from_text("v"));
            assert_eq!(ex.out_msg.is_some(), expect, "block={block} vp={vp}");
        }
    }

    #[test]
    fn by_predicate_completes() {
        let saw_stop = |g: &[Message]| g.iter().any(|m| m.body_text() == Some("STOP"));
        let route = Route::new().add(Aggregator::when("corr", saw_stop)).build();

        let first = send(&route, "corr", "x", Message::from_text("go"));
        assert!(first.out_msg.is_none());

        let second = send(&route, "corr", "x", Message::from_text("STOP"));
        assert_eq!(second.out_msg.as_ref().unwrap().body_text(), Some("goSTOP"));
    }

    #[test]
    fn by_timeout_lazy_completes_on_next_arrival() {
        let route = Route::new()
            .add(Aggregator::timed("corr", Duration::from_millis(40)))
            .build();

        let first = send(&route, "corr", "t", Message::from_text("A"));
        assert!(first.out_msg.is_none(), "first message: deadline not reached");

        let second = send(&route, "corr", "t", Message::from_text("B"));
        assert!(second.out_msg.is_none(), "B arrived too soon");

        // Let the deadline pass; completion is only noticed on the next arrival.
        std::thread::sleep(Duration::from_millis(60));

        let third = send(&route, "corr", "t", Message::from_text("C"));
        assert_eq!(third.out_msg.as_ref().unwrap().body_text(), Some("ABC"));
    }

    #[test]
    fn json_array_strategy_emits_array_of_mixed_payloads() {
        let route = Route::new()
            .add(Aggregator::new("corr", 4).with_strategy(Arc::new(JsonArray)))
            .build();

        let _text = send(&route, "corr", "j", Message::from_text("hi"));
        let _bytes = send(&route, "corr", "j", Message::new(Payload::Bytes(vec![1, 2])));
        let _json = send(
            &route,
            "corr",
            "j",
            Message::new(Payload::Json(serde_json::json!({"k": "v"}))),
        );
        let last = send(&route, "corr", "j", Message::new(Payload::Empty));

        let out = last
            .out_msg
            .expect("JsonArray must always emit on completion");
        let Payload::Json(serde_json::Value::Array(arr)) = out.payload else {
            panic!("JsonArray strategy must emit Payload::Json(Array)");
        };

        assert_eq!(arr.len(), 4);
        assert_eq!(arr[0], serde_json::Value::String("hi".into()));
        assert_eq!(arr[1], serde_json::json!([1, 2]));
        assert_eq!(arr[2], serde_json::json!({"k": "v"}));
        assert_eq!(arr[3], serde_json::Value::Null);
    }

    /// Store wrapper that counts `append` and `take` calls before delegating
    /// to an in-memory store.
    struct CountingStore {
        inner: InMemoryGroupStore,
        appends: AtomicUsize,
        takes: AtomicUsize,
    }

    impl CountingStore {
        fn new() -> Self {
            CountingStore {
                inner: InMemoryGroupStore::new(),
                appends: AtomicUsize::new(0),
                takes: AtomicUsize::new(0),
            }
        }
    }

    impl GroupStore for CountingStore {
        fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant) {
            self.appends.fetch_add(1, Ordering::SeqCst);
            self.inner.append(key, msg)
        }

        fn take(&self, key: &str) -> Option<Vec<Message>> {
            self.takes.fetch_add(1, Ordering::SeqCst);
            self.inner.take(key)
        }

        fn clear(&self) {
            self.inner.clear();
        }
    }

    #[test]
    fn custom_group_store_is_used() {
        let store = Arc::new(CountingStore::new());
        let route = Route::new()
            .add(Aggregator::new("corr", 2).with_store(store.clone()))
            .build();

        let _first = send(&route, "corr", "k", Message::from_text("A"));
        let second = send(&route, "corr", "k", Message::from_text("B"));

        assert_eq!(second.out_msg.as_ref().unwrap().body_text(), Some("AB"));
        assert_eq!(store.appends.load(Ordering::SeqCst), 2);
        assert_eq!(store.takes.load(Ordering::SeqCst), 1);
    }
}