oxigdal_streaming/transformations/
aggregate.rs1use crate::core::stream::StreamElement;
4use crate::error::Result;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
/// Per-key accumulator storage: maps an element's optional key to its
/// serialized accumulator bytes. The `None` entry holds the accumulator
/// for unkeyed elements.
type KeyedState = HashMap<Option<Vec<u8>>, Vec<u8>>;
12
/// A combinable aggregation over byte-serialized values.
///
/// Accumulators and values are opaque `Vec<u8>` payloads; each
/// implementation defines its own encoding (the built-ins in this file
/// use little-endian `i64`). `Send + Sync` is required so a single
/// function instance can be shared across operator tasks behind an `Arc`.
pub trait AggregateFunction: Send + Sync {
    /// Returns the encoded identity ("empty") accumulator.
    fn create_accumulator(&self) -> Vec<u8>;

    /// Folds `value` into `accumulator` and returns the updated accumulator.
    fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8>;

    /// Converts an accumulator into the final encoded result.
    fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8>;

    /// Combines two partial accumulators into one.
    /// NOTE(review): not called by `AggregateOperator` in this file;
    /// presumably used for window/partition merges — confirm with callers.
    fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8>;
}
27
/// Keyed streaming aggregation operator.
///
/// Wraps an [`AggregateFunction`] and keeps one serialized accumulator per
/// element key behind an async `RwLock`, so updates and reads can be issued
/// from concurrent tasks.
pub struct AggregateOperator<F>
where
    F: AggregateFunction,
{
    // Shared aggregation logic, ref-counted so it can outlive `&self` borrows.
    aggregate_fn: Arc<F>,
    // Per-key accumulator state; the `None` key stores the unkeyed accumulator.
    state: Arc<RwLock<KeyedState>>,
}
36
37impl<F> AggregateOperator<F>
38where
39 F: AggregateFunction,
40{
41 pub fn new(aggregate_fn: F) -> Self {
43 Self {
44 aggregate_fn: Arc::new(aggregate_fn),
45 state: Arc::new(RwLock::new(HashMap::new())),
46 }
47 }
48
49 pub async fn process(&self, element: StreamElement) -> Result<StreamElement> {
51 let mut state = self.state.write().await;
52
53 let key = element.key.clone();
54 let current = state
55 .entry(key.clone())
56 .or_insert_with(|| self.aggregate_fn.create_accumulator());
57
58 let updated = self.aggregate_fn.add(current.clone(), element.data);
59 *current = updated.clone();
60
61 let result = self.aggregate_fn.get_result(updated);
62
63 Ok(StreamElement {
64 data: result,
65 event_time: element.event_time,
66 processing_time: element.processing_time,
67 key,
68 metadata: element.metadata,
69 })
70 }
71
72 pub async fn get_result(&self, key: Option<Vec<u8>>) -> Vec<u8> {
74 let state = self.state.read().await;
75 state
76 .get(&key)
77 .map(|acc| self.aggregate_fn.get_result(acc.clone()))
78 .unwrap_or_else(|| {
79 self.aggregate_fn
80 .get_result(self.aggregate_fn.create_accumulator())
81 })
82 }
83
84 pub async fn clear(&self) {
86 self.state.write().await.clear();
87 }
88}
89
/// Counts elements; the value payload is ignored.
/// Accumulator and result encoding: little-endian `i64`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CountAggregate;
93
94impl AggregateFunction for CountAggregate {
95 fn create_accumulator(&self) -> Vec<u8> {
96 0i64.to_le_bytes().to_vec()
97 }
98
99 fn add(&self, accumulator: Vec<u8>, _value: Vec<u8>) -> Vec<u8> {
100 let count = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
101 (count + 1).to_le_bytes().to_vec()
102 }
103
104 fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
105 accumulator
106 }
107
108 fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
109 let count1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
110 let count2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
111 (count1 + count2).to_le_bytes().to_vec()
112 }
113}
114
/// Sums element payloads interpreted as little-endian `i64` values.
/// Accumulator and result encoding: little-endian `i64`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SumAggregate;
118
119impl AggregateFunction for SumAggregate {
120 fn create_accumulator(&self) -> Vec<u8> {
121 0i64.to_le_bytes().to_vec()
122 }
123
124 fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
125 let acc = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
126 let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
127 (acc + val).to_le_bytes().to_vec()
128 }
129
130 fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
131 accumulator
132 }
133
134 fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
135 let sum1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
136 let sum2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
137 (sum1 + sum2).to_le_bytes().to_vec()
138 }
139}
140
/// Integer average of element payloads (little-endian `i64`).
/// Accumulator layout: `[sum: i64 LE][count: i64 LE]` (16 bytes);
/// result encoding: little-endian `i64` (truncating division).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AvgAggregate;
144
145impl AggregateFunction for AvgAggregate {
146 fn create_accumulator(&self) -> Vec<u8> {
147 let mut acc = Vec::new();
148 acc.extend_from_slice(&0i64.to_le_bytes());
149 acc.extend_from_slice(&0i64.to_le_bytes());
150 acc
151 }
152
153 fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
154 let sum = i64::from_le_bytes(accumulator[0..8].try_into().unwrap_or([0; 8]));
155 let count = i64::from_le_bytes(accumulator[8..16].try_into().unwrap_or([0; 8]));
156 let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
157
158 let mut result = Vec::new();
159 result.extend_from_slice(&(sum + val).to_le_bytes());
160 result.extend_from_slice(&(count + 1).to_le_bytes());
161 result
162 }
163
164 fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
165 let sum = i64::from_le_bytes(accumulator[0..8].try_into().unwrap_or([0; 8]));
166 let count = i64::from_le_bytes(accumulator[8..16].try_into().unwrap_or([0; 8]));
167
168 if count == 0 {
169 0i64.to_le_bytes().to_vec()
170 } else {
171 (sum / count).to_le_bytes().to_vec()
172 }
173 }
174
175 fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
176 let sum1 = i64::from_le_bytes(acc1[0..8].try_into().unwrap_or([0; 8]));
177 let count1 = i64::from_le_bytes(acc1[8..16].try_into().unwrap_or([0; 8]));
178 let sum2 = i64::from_le_bytes(acc2[0..8].try_into().unwrap_or([0; 8]));
179 let count2 = i64::from_le_bytes(acc2[8..16].try_into().unwrap_or([0; 8]));
180
181 let mut result = Vec::new();
182 result.extend_from_slice(&(sum1 + sum2).to_le_bytes());
183 result.extend_from_slice(&(count1 + count2).to_le_bytes());
184 result
185 }
186}
187
/// Tracks the minimum element payload (little-endian `i64`).
/// Empty accumulator is `i64::MAX` (the identity for `min`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MinAggregate;
191
192impl AggregateFunction for MinAggregate {
193 fn create_accumulator(&self) -> Vec<u8> {
194 i64::MAX.to_le_bytes().to_vec()
195 }
196
197 fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
198 let acc = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
199 let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
200 acc.min(val).to_le_bytes().to_vec()
201 }
202
203 fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
204 accumulator
205 }
206
207 fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
208 let min1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
209 let min2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
210 min1.min(min2).to_le_bytes().to_vec()
211 }
212}
213
/// Tracks the maximum element payload (little-endian `i64`).
/// Empty accumulator is `i64::MIN` (the identity for `max`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaxAggregate;
217
218impl AggregateFunction for MaxAggregate {
219 fn create_accumulator(&self) -> Vec<u8> {
220 i64::MIN.to_le_bytes().to_vec()
221 }
222
223 fn add(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
224 let acc = i64::from_le_bytes(accumulator.try_into().unwrap_or([0; 8]));
225 let val = i64::from_le_bytes(value.try_into().unwrap_or([0; 8]));
226 acc.max(val).to_le_bytes().to_vec()
227 }
228
229 fn get_result(&self, accumulator: Vec<u8>) -> Vec<u8> {
230 accumulator
231 }
232
233 fn merge(&self, acc1: Vec<u8>, acc2: Vec<u8>) -> Vec<u8> {
234 let max1 = i64::from_le_bytes(acc1.try_into().unwrap_or([0; 8]));
235 let max2 = i64::from_le_bytes(acc2.try_into().unwrap_or([0; 8]));
236 max1.max(max2).to_le_bytes().to_vec()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Runs every payload through `op` as a stream element, panicking if
    /// processing fails.
    async fn feed<F: AggregateFunction>(op: &AggregateOperator<F>, inputs: Vec<Vec<u8>>) {
        for data in inputs {
            op.process(StreamElement::new(data, Utc::now()))
                .await
                .expect("aggregate processing should succeed in test");
        }
    }

    /// Decodes a little-endian i64 result, defaulting to 0 on malformed bytes.
    fn decode(result: Vec<u8>) -> i64 {
        i64::from_le_bytes(result.try_into().unwrap_or([0; 8]))
    }

    /// Encodes each i64 as the little-endian payload the aggregates expect.
    fn payloads(values: &[i64]) -> Vec<Vec<u8>> {
        values.iter().map(|v| v.to_le_bytes().to_vec()).collect()
    }

    #[tokio::test]
    async fn test_count_aggregate() {
        let operator = AggregateOperator::new(CountAggregate);
        // Count ignores payload contents, so single raw bytes suffice.
        feed(&operator, (0u8..5).map(|i| vec![i]).collect()).await;
        assert_eq!(decode(operator.get_result(None).await), 5);
    }

    #[tokio::test]
    async fn test_sum_aggregate() {
        let operator = AggregateOperator::new(SumAggregate);
        feed(&operator, payloads(&[1, 2, 3, 4, 5])).await;
        assert_eq!(decode(operator.get_result(None).await), 15);
    }

    #[tokio::test]
    async fn test_avg_aggregate() {
        let operator = AggregateOperator::new(AvgAggregate);
        feed(&operator, payloads(&[1, 2, 3, 4, 5])).await;
        assert_eq!(decode(operator.get_result(None).await), 3);
    }

    #[tokio::test]
    async fn test_min_aggregate() {
        let operator = AggregateOperator::new(MinAggregate);
        feed(&operator, payloads(&[5, 2, 8, 1, 9])).await;
        assert_eq!(decode(operator.get_result(None).await), 1);
    }

    #[tokio::test]
    async fn test_max_aggregate() {
        let operator = AggregateOperator::new(MaxAggregate);
        feed(&operator, payloads(&[5, 2, 8, 1, 9])).await;
        assert_eq!(decode(operator.get_result(None).await), 9);
    }
}