allora_core/patterns/
aggregator.rs

1//! Aggregator pattern: collects messages sharing a correlation header until a configured
2//! completion size is reached, then emits a concatenated outbound text message.
3//!
4//! # Behavior
5//! * Messages are grouped by the value of a specified correlation header (e.g. "corr").
6//! * When the number of messages in a group reaches `completion_size`, all messages are removed
7//!   from the internal store and concatenated (text-only) into `Exchange.out_msg`.
8//! * Only groups where every message has a UTF-8 text payload (`Payload::Text`) are aggregated.
9//!   Mixed payload kinds (bytes, JSON, empty) or non-text groups are ignored (no outbound message).
10//! * Once a group completes it is removed, allowing subsequent batches with the same key.
11//!
12//! # Limitations
13//! * Non-text payloads are ignored for aggregation; bytes-only or JSON-only groups will not produce
14//!   an `out_msg`. Future enhancements could support bytes joining or JSON array creation.
15//! * Concurrency relies on a single `Mutex<HashMap<..>>`; high contention scenarios may warrant
16//!   a sharded map or lock-free structure.
17//! * No time-based or predicate-based completion—only size threshold is supported.
18//!
19//! # Example
20//! ```rust
21//! use allora_core::{patterns::aggregator::Aggregator, route::Route, Exchange, Message};
22//! let route = Route::new().add(Aggregator::new("corr", 2)).build();
23//! let mut ex1 = Exchange::new(Message::from_text("A"));
24//! ex1.in_msg.set_header("corr", "grp");
25//! let rt = tokio::runtime::Runtime::new().unwrap();
26//! rt.block_on(async { route.run(&mut ex1).await.unwrap(); });
27//! assert!(ex1.out_msg.is_none());
28//! let mut ex2 = Exchange::new(Message::from_text("B"));
29//! ex2.in_msg.set_header("corr", "grp");
30//! rt.block_on(async { route.run(&mut ex2).await.unwrap(); });
31//! assert_eq!(ex2.out_msg.unwrap().body_text(), Some("AB"));
32//! ```
33
34use crate::{error::Result, message::Message, processor::Processor, Exchange};
35use std::collections::HashMap;
36use std::fmt::Debug;
37use std::sync::{Arc, Mutex};
38
39#[derive(Debug, Clone)]
40pub struct Aggregator {
41    store: Arc<Mutex<HashMap<String, Vec<Message>>>>,
42    correlation_header: String,
43    completion_size: usize,
44}
45
46impl Aggregator {
47    /// Create a new `Aggregator`.
48    ///
49    /// * `correlation_header` – header key used to group messages (e.g. "corr" or "correlation_id").
50    /// * `completion_size` – number of messages required before aggregation triggers.
51    pub fn new<H: Into<String>>(correlation_header: H, completion_size: usize) -> Self {
52        Self {
53            store: Arc::new(Mutex::new(HashMap::new())),
54            correlation_header: correlation_header.into(),
55            completion_size,
56        }
57    }
58
59    /// Clear all stored partial groups. Intended for test isolation; not usually needed in production.
60    pub fn clear_store(&self) {
61        let mut guard = self.store.lock().unwrap();
62        guard.clear();
63    }
64
65    fn add(&self, exchange: &Exchange) -> Option<Vec<Message>> {
66        let key_opt = exchange.in_msg.header(&self.correlation_header);
67        let key = match key_opt {
68            Some(k) => k,
69            None => return None,
70        };
71        let mut guard = self.store.lock().unwrap();
72        let bucket = guard.entry(key.to_string()).or_default();
73        bucket.push(exchange.in_msg.clone());
74        if bucket.len() >= self.completion_size {
75            let completed = bucket.clone();
76            guard.remove(key);
77            Some(completed)
78        } else {
79            None
80        }
81    }
82}
83
84#[async_trait::async_trait]
85impl Processor for Aggregator {
86    async fn process(&self, exchange: &mut Exchange) -> Result<()> {
87        if let Some(completed) = self.add(exchange) {
88            if completed.iter().all(|m| m.body_text().is_some()) {
89                let concat = completed
90                    .iter()
91                    .map(|m| m.body_text().unwrap())
92                    .collect::<String>();
93                exchange.out_msg = Some(Message::from_text(concat));
94            }
95        }
96        Ok(())
97    }
98}