use crate::{error::Result, message::Message, message::Payload, processor::Processor, Exchange};
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Decides whether the messages collected for one correlation key are ready to be emitted.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
pub trait CompletionCondition: Send + Sync {
    /// Returns `true` when the group should be treated as complete.
    ///
    /// * `group` - the messages collected so far for the group.
    /// * `first_seen` - the instant recorded for the group; the in-memory store in this file
    ///   sets it when the group's first message is appended.
    fn is_complete(&self, group: &[Message], first_seen: Instant) -> bool;
}
/// Completion condition that fires once a group holds at least `self.0` messages.
#[derive(Debug, Clone, Copy)]
pub struct BySize(pub usize);

impl CompletionCondition for BySize {
    /// Returns `true` when the number of collected messages has reached the configured size.
    ///
    /// The timestamp argument is ignored.
    fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
        let BySize(required) = *self;
        required <= group.len()
    }
}
/// Completion condition that fires once the wrapped duration has elapsed since `first_seen`.
///
/// The check happens only when `is_complete` is called; nothing in this type schedules a timer.
#[derive(Debug, Clone, Copy)]
pub struct ByTimeout(pub Duration);

impl CompletionCondition for ByTimeout {
    /// Returns `true` when the time elapsed since `first_seen` is at least the configured
    /// duration. The group contents are ignored.
    fn is_complete(&self, _group: &[Message], first_seen: Instant) -> bool {
        let ByTimeout(limit) = *self;
        let age = first_seen.elapsed();
        age >= limit
    }
}
/// Completion condition backed by a caller-supplied predicate over the whole group.
pub struct ByPredicate<F: Fn(&[Message]) -> bool + Send + Sync>(pub F);

impl<F> CompletionCondition for ByPredicate<F>
where
    F: Fn(&[Message]) -> bool + Send + Sync,
{
    /// Returns whatever the wrapped predicate returns for `group`.
    ///
    /// The timestamp argument is ignored.
    fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
        let ByPredicate(predicate) = self;
        predicate(group)
    }
}
/// Completion condition that fires once the summed per-message weight reaches a threshold.
pub struct ByWeight<F: Fn(&Message) -> u64 + Send + Sync> {
    /// Computes the weight contributed by a single message.
    pub weight: F,
    /// Minimum total weight at which the group is considered complete.
    pub threshold: u64,
}

impl<F> CompletionCondition for ByWeight<F>
where
    F: Fn(&Message) -> u64 + Send + Sync,
{
    /// Returns `true` when the total weight of `group` is at least `threshold`.
    ///
    /// The total is accumulated with saturating addition: a plain `sum` would panic on
    /// overflow in debug builds and wrap around in release builds, which could make a
    /// heavily weighted group look incomplete. A saturated total (`u64::MAX`) always
    /// satisfies the threshold. The timestamp argument is ignored.
    fn is_complete(&self, group: &[Message], _first_seen: Instant) -> bool {
        let total = group
            .iter()
            .fold(0u64, |acc, msg| acc.saturating_add((self.weight)(msg)));
        total >= self.threshold
    }
}
/// Turns a completed group of messages into at most one outgoing message.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
pub trait AggregationStrategy: Send + Sync {
    /// Consumes `group` and returns the combined message, or `None` when the strategy
    /// produces no output for this group.
    fn combine(&self, group: Vec<Message>) -> Option<Message>;
}
/// Aggregation strategy that concatenates the text bodies of a group, in order.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConcatText;

impl AggregationStrategy for ConcatText {
    /// Joins every message's text body into one text message.
    ///
    /// Returns `None` as soon as a message without a text body is found. An empty group
    /// yields a message with an empty text body.
    fn combine(&self, group: Vec<Message>) -> Option<Message> {
        // Collecting into `Option<String>` stops at the first `None`, so the group is
        // traversed once and no `unwrap()` is needed.
        let concat = group
            .iter()
            .map(|m| m.body_text())
            .collect::<Option<String>>()?;
        Some(Message::from_text(concat))
    }
}
#[derive(Debug, Clone, Copy, Default)]
pub struct JsonArray;
impl AggregationStrategy for JsonArray {
fn combine(&self, group: Vec<Message>) -> Option<Message> {
let arr: Vec<serde_json::Value> = group
.into_iter()
.map(|m| match m.payload {
Payload::Text(s) => serde_json::Value::String(s),
Payload::Bytes(b) => {
serde_json::Value::Array(b.into_iter().map(serde_json::Value::from).collect())
}
Payload::Json(v) => v,
Payload::Empty => serde_json::Value::Null,
})
.collect();
Some(Message::new(Payload::Json(serde_json::Value::Array(arr))))
}
}
/// Storage for in-flight correlation groups, keyed by correlation value.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
pub trait GroupStore: Send + Sync {
    /// Adds `msg` to the group stored under `key`.
    ///
    /// Returns the group's messages together with an `Instant` for the group. The in-memory
    /// implementation in this file returns a copy of the full group (including `msg`) and
    /// the time the group was created.
    fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant);

    /// Removes the group stored under `key` and returns its messages, or `None` when no
    /// such group exists.
    fn take(&self, key: &str) -> Option<Vec<Message>>;

    /// Discards every stored group.
    fn clear(&self);
}
/// One in-flight group held by the in-memory store.
struct InMemoryGroup {
    /// Messages collected so far, in arrival order.
    messages: Vec<Message>,
    /// When the group entry was created.
    first_seen: Instant,
}
/// Group store that keeps every group in a mutex-protected `HashMap` in process memory.
#[derive(Default)]
pub struct InMemoryGroupStore {
    /// Groups keyed by correlation value.
    inner: Mutex<HashMap<String, InMemoryGroup>>,
}

impl InMemoryGroupStore {
    /// Creates a store with no groups.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(HashMap::new()),
        }
    }
}
impl GroupStore for InMemoryGroupStore {
    /// Appends `msg` to the group for `key`, creating the group (stamped with the current
    /// time) if it does not exist yet.
    ///
    /// Returns a clone of the whole group plus the group's creation instant. Panics if the
    /// internal mutex is poisoned.
    fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant) {
        let mut groups = self.inner.lock().unwrap();
        let group = groups.entry(key.to_owned()).or_insert_with(|| InMemoryGroup {
            messages: Vec::new(),
            first_seen: Instant::now(),
        });
        group.messages.push(msg);
        let snapshot = group.messages.clone();
        (snapshot, group.first_seen)
    }

    /// Removes the group for `key` and hands back its messages, or `None` if there is no
    /// such group. Panics if the internal mutex is poisoned.
    fn take(&self, key: &str) -> Option<Vec<Message>> {
        let removed = self.inner.lock().unwrap().remove(key)?;
        Some(removed.messages)
    }

    /// Drops every stored group. Panics if the internal mutex is poisoned.
    fn clear(&self) {
        self.inner.lock().unwrap().clear();
    }
}
/// Processor that collects messages sharing a correlation header value and emits one
/// combined message when the group completes.
///
/// Cloning copies the header name and shares the completion condition, strategy and store
/// through their `Arc`s.
#[derive(Clone)]
pub struct Aggregator {
    /// Name of the header whose value identifies the group a message belongs to.
    correlation_header: String,
    /// Decides when a group is complete.
    completion: Arc<dyn CompletionCondition>,
    /// Combines a completed group into the outgoing message.
    strategy: Arc<dyn AggregationStrategy>,
    /// Holds the in-flight groups.
    store: Arc<dyn GroupStore>,
}
impl fmt::Debug for Aggregator {
    /// Formats only `correlation_header`; the remaining fields are trait objects and are
    /// shown as elided (`..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Aggregator");
        out.field("correlation_header", &self.correlation_header);
        out.finish_non_exhaustive()
    }
}
impl Aggregator {
    /// Creates an aggregator that completes a group once it holds `completion_size` messages.
    ///
    /// Uses the same defaults as [`Aggregator::with_completion`].
    pub fn new<H: Into<String>>(correlation_header: H, completion_size: usize) -> Self {
        let by_size: Arc<dyn CompletionCondition> = Arc::new(BySize(completion_size));
        Self::with_completion(correlation_header, by_size)
    }

    /// Creates an aggregator with an explicit completion condition.
    ///
    /// The strategy defaults to `ConcatText` and the store to a fresh `InMemoryGroupStore`.
    pub fn with_completion<H: Into<String>>(
        correlation_header: H,
        completion: Arc<dyn CompletionCondition>,
    ) -> Self {
        let strategy: Arc<dyn AggregationStrategy> = Arc::new(ConcatText);
        let store: Arc<dyn GroupStore> = Arc::new(InMemoryGroupStore::new());
        Self {
            correlation_header: correlation_header.into(),
            completion,
            strategy,
            store,
        }
    }

    /// Creates an aggregator that completes a group once the summed `weight` of its
    /// messages reaches `threshold`.
    pub fn weighted<H, F>(correlation_header: H, weight: F, threshold: u64) -> Self
    where
        H: Into<String>,
        F: Fn(&Message) -> u64 + Send + Sync + 'static,
    {
        let condition = ByWeight { weight, threshold };
        Self::with_completion(correlation_header, Arc::new(condition))
    }

    /// Creates an aggregator whose groups complete once `dur` has elapsed since the group's
    /// recorded start.
    pub fn timed<H: Into<String>>(correlation_header: H, dur: Duration) -> Self {
        let condition = ByTimeout(dur);
        Self::with_completion(correlation_header, Arc::new(condition))
    }

    /// Creates an aggregator whose groups complete when `predicate` returns `true` for the
    /// collected messages.
    pub fn when<H, F>(correlation_header: H, predicate: F) -> Self
    where
        H: Into<String>,
        F: Fn(&[Message]) -> bool + Send + Sync + 'static,
    {
        let condition = ByPredicate(predicate);
        Self::with_completion(correlation_header, Arc::new(condition))
    }

    /// Replaces the aggregation strategy and returns the updated aggregator.
    pub fn with_strategy(self, strategy: Arc<dyn AggregationStrategy>) -> Self {
        Self { strategy, ..self }
    }

    /// Replaces the group store and returns the updated aggregator.
    pub fn with_store(self, store: Arc<dyn GroupStore>) -> Self {
        Self { store, ..self }
    }

    /// Discards every group currently held by the store.
    pub fn clear_store(&self) {
        self.store.clear();
    }
}
#[async_trait::async_trait]
impl Processor for Aggregator {
    /// Adds the exchange's inbound message to its correlation group and, if that completes
    /// the group, stores the combined message in `exchange.out_msg`.
    ///
    /// Messages without the correlation header are ignored. When the group completes but
    /// the strategy returns `None`, the group has already been removed from the store and
    /// `out_msg` is left untouched. Always returns `Ok(())`.
    async fn process(&self, exchange: &mut Exchange) -> Result<()> {
        let Some(key) = exchange
            .in_msg
            .header(&self.correlation_header)
            .map(|value| value.to_string())
        else {
            return Ok(());
        };

        let (snapshot, first_seen) = self.store.append(&key, exchange.in_msg.clone());
        if !self.completion.is_complete(&snapshot, first_seen) {
            return Ok(());
        }

        // The combined output is built from what `take` returns, not from the snapshot.
        let combined = self
            .store
            .take(&key)
            .and_then(|completed| self.strategy.combine(completed));
        if let Some(out) = combined {
            exchange.out_msg = Some(out);
        }
        Ok(())
    }
}
// Unit tests for the aggregator: each test builds a route containing a single `Aggregator`
// and pushes exchanges through it one at a time.
#[cfg(test)]
mod tests {
use super::*;
use crate::message::{Exchange, Message, Payload};
use crate::route::Route;
use std::sync::atomic::{AtomicUsize, Ordering};
// Drives `route.run(exchange)` to completion on a freshly created Tokio runtime,
// panicking if the runtime cannot be built or the route returns an error.
fn run(route: &Route, exchange: &mut Exchange) {
tokio::runtime::Runtime::new()
.unwrap()
.block_on(route.run(exchange))
.unwrap();
}
// Builds an exchange around `msg` and sets `header` to `key` on its inbound message.
fn ex_with(header: &str, key: &str, msg: Message) -> Exchange {
let mut e = Exchange::new(msg);
e.in_msg.set_header(header, key);
e
}
// Size 2: the first message produces no output, the second emits the concatenation "AB".
#[test]
fn back_compat_size_two_concats_ab() {
let route = Route::new().add(Aggregator::new("corr", 2)).build();
let mut ex1 = ex_with("corr", "g", Message::from_text("A"));
run(&route, &mut ex1);
assert!(ex1.out_msg.is_none());
let mut ex2 = ex_with("corr", "g", Message::from_text("B"));
run(&route, &mut ex2);
assert_eq!(ex2.out_msg.unwrap().body_text(), Some("AB"));
}
// Size 3: only the last exchange is inspected and must carry "ABC".
#[test]
fn back_compat_three_messages() {
let route = Route::new().add(Aggregator::new("corr", 3)).build();
let mut last = None;
for s in ["A", "B", "C"] {
let mut ex = ex_with("corr", "123", Message::from_text(s));
run(&route, &mut ex);
last = Some(ex);
}
assert_eq!(last.unwrap().out_msg.unwrap().body_text(), Some("ABC"));
}
// Exchanges that lack the correlation header never produce an output message.
#[test]
fn ignores_messages_without_correlation_header() {
let route = Route::new().add(Aggregator::new("corr", 2)).build();
for s in ["A", "B"] {
let mut ex = Exchange::new(Message::from_text(s));
run(&route, &mut ex);
assert!(ex.out_msg.is_none());
}
}
// After a group completes, the same key starts a new, independent batch ("AB" then "CD").
#[test]
fn aggregates_multiple_batches_for_same_key() {
let route = Route::new().add(Aggregator::new("corr", 2)).build();
let mut ex1 = ex_with("corr", "same", Message::from_text("A"));
run(&route, &mut ex1);
assert!(ex1.out_msg.is_none());
let mut ex2 = ex_with("corr", "same", Message::from_text("B"));
run(&route, &mut ex2);
assert_eq!(ex2.out_msg.as_ref().unwrap().body_text(), Some("AB"));
let mut ex3 = ex_with("corr", "same", Message::from_text("C"));
run(&route, &mut ex3);
assert!(ex3.out_msg.is_none());
let mut ex4 = ex_with("corr", "same", Message::from_text("D"));
run(&route, &mut ex4);
assert_eq!(ex4.out_msg.as_ref().unwrap().body_text(), Some("CD"));
}
// With the default `ConcatText` strategy, a completed group of byte payloads emits nothing.
#[test]
fn concat_text_non_text_group_emits_nothing() {
let route = Route::new().add(Aggregator::new("corr", 2)).build();
let mut ex1 = ex_with("corr", "m", Message::new(Payload::Bytes(vec![0, 1])));
run(&route, &mut ex1);
let mut ex2 = ex_with("corr", "m", Message::new(Payload::Bytes(vec![2, 3])));
run(&route, &mut ex2);
assert!(ex2.out_msg.is_none());
}
// `clear_store` on the original aggregator also resets the clone placed in the route,
// so "B" becomes the first message of a new group.
#[test]
fn clear_store_resets_groups() {
let agg = Aggregator::new("corr", 2);
let route = Route::new().add(agg.clone()).build();
let mut ex1 = ex_with("corr", "x", Message::from_text("A"));
run(&route, &mut ex1);
agg.clear_store();
let mut ex2 = ex_with("corr", "x", Message::from_text("B"));
run(&route, &mut ex2);
assert!(
ex2.out_msg.is_none(),
"clear_store should reset the group; B should be the first of a new batch"
);
}
// Weights 3 + 3 stay below the threshold of 7; adding 4 (total 10) completes the group.
// A missing or unparsable "voting_power" header counts as weight 0.
#[test]
fn by_weight_completes_at_threshold() {
let threshold: u64 = 7;
let route = Route::new()
.add(Aggregator::weighted(
"block",
|m: &Message| {
m.header("voting_power")
.and_then(|s| s.parse().ok())
.unwrap_or(0)
},
threshold,
))
.build();
for (vp, expect_out) in [(3u64, false), (3, false), (4, true)] {
let mut ex = Exchange::new(Message::from_text(format!("vote-vp{vp}")));
ex.in_msg.set_header("block", "h=42");
ex.in_msg.set_header("voting_power", vp.to_string());
run(&route, &mut ex);
assert_eq!(
ex.out_msg.is_some(),
expect_out,
"vp={vp}: expected out_msg={expect_out}"
);
}
}
// A total weight exactly equal to the threshold (3 + 3 == 6) completes the group.
#[test]
fn by_weight_fires_exactly_at_threshold_boundary() {
let route = Route::new()
.add(Aggregator::weighted(
"block",
|m: &Message| {
m.header("voting_power")
.and_then(|s| s.parse().ok())
.unwrap_or(0)
},
6,
))
.build();
let mut ex1 = Exchange::new(Message::from_text("a"));
ex1.in_msg.set_header("block", "h=1");
ex1.in_msg.set_header("voting_power", "3");
run(&route, &mut ex1);
assert!(ex1.out_msg.is_none());
let mut ex2 = Exchange::new(Message::from_text("b"));
ex2.in_msg.set_header("block", "h=1");
ex2.in_msg.set_header("voting_power", "3");
run(&route, &mut ex2);
assert!(ex2.out_msg.is_some(), "sum=6, threshold=6: should fire");
}
// Weights accumulate per key: block "A" reaches 4 and fires, block "B" only reaches 2.
#[test]
fn by_weight_isolated_per_key() {
let route = Route::new()
.add(Aggregator::weighted(
"block",
|m: &Message| {
m.header("voting_power")
.and_then(|s| s.parse().ok())
.unwrap_or(0)
},
4,
))
.build();
for (block, vp, expect) in [
("A", 2, false),
("B", 1, false),
("A", 2, true),
("B", 1, false),
] {
let mut ex = Exchange::new(Message::from_text("v"));
ex.in_msg.set_header("block", block);
ex.in_msg.set_header("voting_power", vp.to_string());
run(&route, &mut ex);
assert_eq!(ex.out_msg.is_some(), expect, "block={block} vp={vp}");
}
}
// The predicate completes the group as soon as any message body equals "STOP".
#[test]
fn by_predicate_completes() {
let route = Route::new()
.add(Aggregator::when("corr", |g: &[Message]| {
g.iter().any(|m| m.body_text() == Some("STOP"))
}))
.build();
let mut ex1 = ex_with("corr", "x", Message::from_text("go"));
run(&route, &mut ex1);
assert!(ex1.out_msg.is_none());
let mut ex2 = ex_with("corr", "x", Message::from_text("STOP"));
run(&route, &mut ex2);
assert_eq!(ex2.out_msg.as_ref().unwrap().body_text(), Some("goSTOP"));
}
// The timeout is only evaluated when a message arrives: after sleeping 60 ms (longer than
// the 40 ms timeout) the third message completes the group and "ABC" is emitted.
#[test]
fn by_timeout_lazy_completes_on_next_arrival() {
let route = Route::new()
.add(Aggregator::timed("corr", Duration::from_millis(40)))
.build();
let mut ex1 = ex_with("corr", "t", Message::from_text("A"));
run(&route, &mut ex1);
assert!(ex1.out_msg.is_none(), "first message: deadline not reached");
let mut ex2 = ex_with("corr", "t", Message::from_text("B"));
run(&route, &mut ex2);
assert!(ex2.out_msg.is_none(), "B arrived too soon");
std::thread::sleep(Duration::from_millis(60));
let mut ex3 = ex_with("corr", "t", Message::from_text("C"));
run(&route, &mut ex3);
assert_eq!(ex3.out_msg.as_ref().unwrap().body_text(), Some("ABC"));
}
// `JsonArray` maps text, bytes, JSON and empty payloads to a string, a number array,
// the JSON value itself and null respectively, preserving arrival order.
#[test]
fn json_array_strategy_emits_array_of_mixed_payloads() {
let route = Route::new()
.add(Aggregator::new("corr", 4).with_strategy(Arc::new(JsonArray)))
.build();
let mut ex1 = ex_with("corr", "j", Message::from_text("hi"));
run(&route, &mut ex1);
let mut ex2 = ex_with("corr", "j", Message::new(Payload::Bytes(vec![1, 2])));
run(&route, &mut ex2);
let mut ex3 = ex_with(
"corr",
"j",
Message::new(Payload::Json(serde_json::json!({"k": "v"}))),
);
run(&route, &mut ex3);
let mut ex4 = ex_with("corr", "j", Message::new(Payload::Empty));
run(&route, &mut ex4);
let out = ex4
.out_msg
.expect("JsonArray must always emit on completion");
let Payload::Json(serde_json::Value::Array(arr)) = out.payload else {
panic!("JsonArray strategy must emit Payload::Json(Array)");
};
assert_eq!(arr.len(), 4);
assert_eq!(arr[0], serde_json::Value::String("hi".into()));
assert_eq!(arr[1], serde_json::json!([1, 2]));
assert_eq!(arr[2], serde_json::json!({"k": "v"}));
assert_eq!(arr[3], serde_json::Value::Null);
}
// Test double: delegates to an `InMemoryGroupStore` while counting `append` and `take` calls.
struct CountingStore {
inner: InMemoryGroupStore,
appends: AtomicUsize,
takes: AtomicUsize,
}
impl CountingStore {
// Creates a counting store with an empty inner store and both counters at zero.
fn new() -> Self {
Self {
inner: InMemoryGroupStore::new(),
appends: AtomicUsize::new(0),
takes: AtomicUsize::new(0),
}
}
}
impl GroupStore for CountingStore {
// Counts the call, then forwards to the inner store.
fn append(&self, key: &str, msg: Message) -> (Vec<Message>, Instant) {
self.appends.fetch_add(1, Ordering::SeqCst);
self.inner.append(key, msg)
}
// Counts the call, then forwards to the inner store.
fn take(&self, key: &str) -> Option<Vec<Message>> {
self.takes.fetch_add(1, Ordering::SeqCst);
self.inner.take(key)
}
// Forwards to the inner store; clears are not counted.
fn clear(&self) {
self.inner.clear();
}
}
// A store supplied through `with_store` is the one the aggregator uses:
// two messages cause two appends and exactly one take.
#[test]
fn custom_group_store_is_used() {
let store = Arc::new(CountingStore::new());
let route = Route::new()
.add(Aggregator::new("corr", 2).with_store(store.clone()))
.build();
let mut ex1 = ex_with("corr", "k", Message::from_text("A"));
run(&route, &mut ex1);
let mut ex2 = ex_with("corr", "k", Message::from_text("B"));
run(&route, &mut ex2);
assert_eq!(ex2.out_msg.as_ref().unwrap().body_text(), Some("AB"));
assert_eq!(store.appends.load(Ordering::SeqCst), 2);
assert_eq!(store.takes.load(Ordering::SeqCst), 1);
}
}