// oxigdal_streaming/transformations/aggregate.rs

1//! Aggregate operations for streaming data.
2
3use crate::core::stream::StreamElement;
4use crate::error::Result;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
/// Type alias for keyed state storage (key -> accumulator value bytes).
/// The `None` key holds the accumulator for un-keyed elements.
type KeyedState = HashMap<Option<Vec<u8>>, Vec<u8>>;
12
13/// Trait for aggregate functions.
/// Trait for aggregate functions.
///
/// Accumulators and values are opaque little-endian byte buffers
/// (`Vec<u8>`); each implementation defines its own encoding. Every
/// method consumes its buffers by value, so a caller that needs to keep
/// an accumulator after calling `add`/`get_result` must clone it first.
pub trait AggregateFunction: Send + Sync {
    /// Create initial accumulator.
    fn create_accumulator(&self) -> Vec<u8>;

    /// Add a value to the accumulator, returning the updated accumulator.
    fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8>;

    /// Get the result from the accumulator.
    fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8>;

    /// Merge two accumulators into a single combined accumulator.
    fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8>;
}
27
28/// Aggregate operator.
/// Aggregate operator.
///
/// Applies an [`AggregateFunction`] to stream elements, maintaining one
/// accumulator per element key (the `None` slot serves un-keyed elements).
pub struct AggregateOperator<F>
where
    F: AggregateFunction,
{
    // The user-supplied aggregate function, behind an Arc for shared access.
    aggregate_fn: Arc<F>,
    // Per-key accumulator bytes, guarded by an async RwLock.
    state: Arc<RwLock<KeyedState>>,
}
36
37impl<F> AggregateOperator<F>
38where
39    F: AggregateFunction,
40{
41    /// Create a new aggregate operator.
42    pub fn new(aggregate_fn: F) -> Self {
43        Self {
44            aggregate_fn: Arc::new(aggregate_fn),
45            state: Arc::new(RwLock::new(HashMap::new())),
46        }
47    }
48
49    /// Process an element.
50    pub async fn process(&self, element: StreamElement) -> Result<StreamElement> {
51        let mut state = self.state.write().await;
52
53        let key = element.key.clone();
54        let current = state
55            .entry(key.clone())
56            .or_insert_with(|| self.aggregate_fn.create_accumulator());
57
58        let updated = self.aggregate_fn.add(current.clone(), element.data);
59        *current = updated.clone();
60
61        let result = self.aggregate_fn.get_result(updated);
62
63        Ok(StreamElement {
64            data: result,
65            event_time: element.event_time,
66            processing_time: element.processing_time,
67            key,
68            metadata: element.metadata,
69        })
70    }
71
72    /// Get the current result for a key.
73    pub async fn get_result(&self, key: Option<Vec<u8>>) -> Vec<u8> {
74        let state = self.state.read().await;
75        state
76            .get(&key)
77            .map(|acc| self.aggregate_fn.get_result(acc.clone()))
78            .unwrap_or_else(|| {
79                self.aggregate_fn
80                    .get_result(self.aggregate_fn.create_accumulator())
81            })
82    }
83
84    /// Clear all state.
85    pub async fn clear(&self) {
86        self.state.write().await.clear();
87    }
88}
89
90/// Count aggregate.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct CountAggregate;
93
94impl AggregateFunction for CountAggregate {
95    fn create_accumulator(&self) -> Vec<u8> {
96        0i64.to_le_bytes().to_vec()
97    }
98
99    fn add(&self, accumulator: Vec<u8>, _value: Vec<u8>) -> Vec<u8> {
100        let count = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
101        (count + 1).to_le_bytes().to_vec()
102    }
103
104    fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
105        accumulator
106    }
107
108    fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
109        let count1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
110        let count2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
111        (count1 + count2).to_le_bytes().to_vec()
112    }
113}
114
115/// Sum aggregate.
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct SumAggregate;
118
119impl AggregateFunction for SumAggregate {
120    fn create_accumulator(&self) -> Vec<u8> {
121        0i64.to_le_bytes().to_vec()
122    }
123
124    fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
125        let acc = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
126        let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
127        (acc + val).to_le_bytes().to_vec()
128    }
129
130    fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
131        accumulator
132    }
133
134    fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
135        let sum1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
136        let sum2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
137        (sum1 + sum2).to_le_bytes().to_vec()
138    }
139}
140
141/// Average aggregate (stores sum and count).
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct AvgAggregate;
144
145impl AggregateFunction for AvgAggregate {
146    fn create_accumulator(&self) -> Vec<u8> {
147        let mut acc = Vec::new();
148        acc.extend_from_slice(&0i64.to_le_bytes());
149        acc.extend_from_slice(&0i64.to_le_bytes());
150        acc
151    }
152
153    fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
154        let sum = i64::from_le_bytes(accumulator[0..8].try_into().unwrap_or([0; 8]));
155        let count = i64::from_le_bytes(accumulator[8..16].try_into().unwrap_or([0; 8]));
156        let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
157
158        let mut result = Vec::new();
159        result.extend_from_slice(&(sum + val).to_le_bytes());
160        result.extend_from_slice(&(count + 1).to_le_bytes());
161        result
162    }
163
164    fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
165        let sum = i64::from_le_bytes(accumulator[0..8].try_into().unwrap_or([0; 8]));
166        let count = i64::from_le_bytes(accumulator[8..16].try_into().unwrap_or([0; 8]));
167
168        if count == 0 {
169            0i64.to_le_bytes().to_vec()
170        } else {
171            (sum / count).to_le_bytes().to_vec()
172        }
173    }
174
175    fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
176        let sum1 = i64::from_le_bytes(acc1[0..8].try_into().unwrap_or([0; 8]));
177        let count1 = i64::from_le_bytes(acc1[8..16].try_into().unwrap_or([0; 8]));
178        let sum2 = i64::from_le_bytes(acc2[0..8].try_into().unwrap_or([0; 8]));
179        let count2 = i64::from_le_bytes(acc2[8..16].try_into().unwrap_or([0; 8]));
180
181        let mut result = Vec::new();
182        result.extend_from_slice(&(sum1 + sum2).to_le_bytes());
183        result.extend_from_slice(&(count1 + count2).to_le_bytes());
184        result
185    }
186}
187
188/// Min aggregate.
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct MinAggregate;
191
192impl AggregateFunction for MinAggregate {
193    fn create_accumulator(&self) -> Vec<u8> {
194        i64::MAX.to_le_bytes().to_vec()
195    }
196
197    fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
198        let acc = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
199        let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
200        acc.min(val).to_le_bytes().to_vec()
201    }
202
203    fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
204        accumulator
205    }
206
207    fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
208        let min1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
209        let min2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
210        min1.min(min2).to_le_bytes().to_vec()
211    }
212}
213
214/// Max aggregate.
215#[derive(Debug, Clone, Serialize, Deserialize)]
216pub struct MaxAggregate;
217
218impl AggregateFunction for MaxAggregate {
219    fn create_accumulator(&self) -> Vec<u8> {
220        i64::MIN.to_le_bytes().to_vec()
221    }
222
223    fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
224        let acc = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
225        let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
226        acc.max(val).to_le_bytes().to_vec()
227    }
228
229    fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
230        accumulator
231    }
232
233    fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
234        let max1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
235        let max2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
236        max1.max(max2).to_le_bytes().to_vec()
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Run every payload through the operator, failing the test on error.
    async fn feed<F: AggregateFunction>(operator: &AggregateOperator<F>, payloads: Vec<Vec<u8>>) {
        for data in payloads {
            operator
                .process(StreamElement::new(data, Utc::now()))
                .await
                .expect("aggregate processing should succeed in test");
        }
    }

    /// Decode the operator's un-keyed result as a little-endian i64.
    async fn unkeyed_result<F: AggregateFunction>(operator: &AggregateOperator<F>) -> i64 {
        let bytes = operator.get_result(None).await;
        i64::from_le_bytes(bytes.try_into().unwrap_or([0; 8]))
    }

    /// Encode i64s as the little-endian payloads the numeric aggregates expect.
    fn payloads(values: &[i64]) -> Vec<Vec<u8>> {
        values.iter().map(|v| v.to_le_bytes().to_vec()).collect()
    }

    #[tokio::test]
    async fn test_count_aggregate() {
        let operator = AggregateOperator::new(CountAggregate);
        feed(&operator, (0u8..5).map(|i| vec![i]).collect()).await;
        assert_eq!(unkeyed_result(&operator).await, 5);
    }

    #[tokio::test]
    async fn test_sum_aggregate() {
        let operator = AggregateOperator::new(SumAggregate);
        feed(&operator, payloads(&[1, 2, 3, 4, 5])).await;
        assert_eq!(unkeyed_result(&operator).await, 15);
    }

    #[tokio::test]
    async fn test_avg_aggregate() {
        let operator = AggregateOperator::new(AvgAggregate);
        feed(&operator, payloads(&[1, 2, 3, 4, 5])).await;
        assert_eq!(unkeyed_result(&operator).await, 3);
    }

    #[tokio::test]
    async fn test_min_aggregate() {
        let operator = AggregateOperator::new(MinAggregate);
        feed(&operator, payloads(&[5, 2, 8, 1, 9])).await;
        assert_eq!(unkeyed_result(&operator).await, 1);
    }

    #[tokio::test]
    async fn test_max_aggregate() {
        let operator = AggregateOperator::new(MaxAggregate);
        feed(&operator, payloads(&[5, 2, 8, 1, 9])).await;
        assert_eq!(unkeyed_result(&operator).await, 9);
    }
}