allora_core/patterns/aggregator.rs
//! Aggregator pattern: collects messages sharing a correlation header until a configured
//! completion size is reached, then emits a concatenated outbound text message.
//!
//! # Behavior
//! * Messages are grouped by the value of a specified correlation header (e.g. "corr").
//! * Messages without the correlation header are ignored: they are not stored and no
//!   `out_msg` is produced for them.
//! * When the number of messages in a group reaches `completion_size`, all messages are removed
//!   from the internal store and their text bodies are concatenated, in the order they were
//!   added, into `Exchange.out_msg`.
//! * Only groups where every message has a UTF-8 text payload (`Payload::Text`) are aggregated.
//!   A completed group containing any other payload kind (bytes, JSON, empty) produces no
//!   outbound message, and its messages are still removed from the store (i.e. discarded).
//! * Once a group completes it is removed, allowing subsequent batches with the same key.
//!
//! # Limitations
//! * Non-text payloads are ignored for aggregation; bytes-only or JSON-only groups will not produce
//!   an `out_msg`. Future enhancements could support bytes joining or JSON array creation.
//! * Concurrency relies on a single `Mutex<HashMap<..>>`; high-contention scenarios may warrant
//!   a sharded map or lock-free structure.
//! * No time-based or predicate-based completion—only a size threshold is supported. Incomplete
//!   groups are kept until they fill up or `clear_store` is called; there is no eviction.
//!
//! # Example
//! ```rust
//! use allora_core::{patterns::aggregator::Aggregator, route::Route, Exchange, Message};
//! let route = Route::new().add(Aggregator::new("corr", 2)).build();
//! let mut ex1 = Exchange::new(Message::from_text("A"));
//! ex1.in_msg.set_header("corr", "grp");
//! let rt = tokio::runtime::Runtime::new().unwrap();
//! rt.block_on(async { route.run(&mut ex1).await.unwrap(); });
//! assert!(ex1.out_msg.is_none());
//! let mut ex2 = Exchange::new(Message::from_text("B"));
//! ex2.in_msg.set_header("corr", "grp");
//! rt.block_on(async { route.run(&mut ex2).await.unwrap(); });
//! assert_eq!(ex2.out_msg.unwrap().body_text(), Some("AB"));
//! ```
33
34use crate::{error::Result, message::Message, processor::Processor, Exchange};
35use std::collections::HashMap;
36use std::fmt::Debug;
37use std::sync::{Arc, Mutex};
38
/// Size-based message aggregator keyed by a correlation header.
///
/// Inbound messages are grouped by the value of `correlation_header`; once a group holds
/// `completion_size` messages it is removed from the store and aggregated.
///
/// The store sits behind an `Arc`, so the derived `Clone` is cheap and every clone shares the
/// same partial groups: messages added through one clone count toward groups seen by the others.
#[derive(Debug, Clone)]
pub struct Aggregator {
    /// Partial (not yet complete) groups, keyed by correlation-header value, behind one mutex.
    store: Arc<Mutex<HashMap<String, Vec<Message>>>>,
    /// Header whose value selects the group an inbound message joins.
    correlation_header: String,
    /// Number of messages that completes a group.
    completion_size: usize,
}
45
46impl Aggregator {
47 /// Create a new `Aggregator`.
48 ///
49 /// * `correlation_header` – header key used to group messages (e.g. "corr" or "correlation_id").
50 /// * `completion_size` – number of messages required before aggregation triggers.
51 pub fn new<H: Into<String>>(correlation_header: H, completion_size: usize) -> Self {
52 Self {
53 store: Arc::new(Mutex::new(HashMap::new())),
54 correlation_header: correlation_header.into(),
55 completion_size,
56 }
57 }
58
59 /// Clear all stored partial groups. Intended for test isolation; not usually needed in production.
60 pub fn clear_store(&self) {
61 let mut guard = self.store.lock().unwrap();
62 guard.clear();
63 }
64
65 fn add(&self, exchange: &Exchange) -> Option<Vec<Message>> {
66 let key_opt = exchange.in_msg.header(&self.correlation_header);
67 let key = match key_opt {
68 Some(k) => k,
69 None => return None,
70 };
71 let mut guard = self.store.lock().unwrap();
72 let bucket = guard.entry(key.to_string()).or_default();
73 bucket.push(exchange.in_msg.clone());
74 if bucket.len() >= self.completion_size {
75 let completed = bucket.clone();
76 guard.remove(key);
77 Some(completed)
78 } else {
79 None
80 }
81 }
82}
83
84#[async_trait::async_trait]
85impl Processor for Aggregator {
86 async fn process(&self, exchange: &mut Exchange) -> Result<()> {
87 if let Some(completed) = self.add(exchange) {
88 if completed.iter().all(|m| m.body_text().is_some()) {
89 let concat = completed
90 .iter()
91 .map(|m| m.body_text().unwrap())
92 .collect::<String>();
93 exchange.out_msg = Some(Message::from_text(concat));
94 }
95 }
96 Ok(())
97 }
98}