oxigdal_streaming/transformations/
reduce.rs1use crate::core::stream::StreamElement;
4use crate::error::Result;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
/// Per-key operator state.
///
/// Maps an element's optional key bytes to the value bytes currently stored
/// for that key. The key type is `Option<Vec<u8>>`, so `None` is itself a
/// valid map key with its own entry.
type KeyedState = HashMap<Option<Vec<u8>>, Vec<u8>>;
11
/// A binary combining function over byte-encoded values.
///
/// Implementors must be `Send + Sync`.
pub trait ReduceFunction
where
    Self: Send + Sync,
{
    /// Combines `value1` and `value2` into a single byte-encoded value.
    ///
    /// Both arguments are taken by value, so implementations are free to
    /// reuse either buffer for the result.
    fn reduce(&self, value1: Vec<u8>, value2: Vec<u8>) -> Vec<u8>;
}
17
18pub struct ReduceOperator<F>
20where
21 F: ReduceFunction,
22{
23 reduce_fn: Arc<F>,
24 state: Arc<RwLock<KeyedState>>,
25}
26
27impl<F> ReduceOperator<F>
28where
29 F: ReduceFunction,
30{
31 pub fn new(reduce_fn: F) -> Self {
33 Self {
34 reduce_fn: Arc::new(reduce_fn),
35 state: Arc::new(RwLock::new(HashMap::new())),
36 }
37 }
38
39 pub async fn process(&self, element: StreamElement) -> Result<Option<StreamElement>> {
41 let mut state = self.state.write().await;
42
43 let key = element.key.clone();
44 let current = state.entry(key.clone()).or_insert_with(Vec::new);
45
46 if current.is_empty() {
47 *current = element.data;
48 Ok(None)
49 } else {
50 let reduced = self.reduce_fn.reduce(current.clone(), element.data);
51 *current = reduced.clone();
52
53 Ok(Some(StreamElement {
54 data: reduced,
55 event_time: element.event_time,
56 processing_time: element.processing_time,
57 key,
58 metadata: element.metadata,
59 }))
60 }
61 }
62
63 pub async fn get_state(&self, key: Option<Vec<u8>>) -> Option<Vec<u8>> {
65 self.state.read().await.get(&key).cloned()
66 }
67
68 pub async fn clear(&self) {
70 self.state.write().await.clear();
71 }
72}
73
/// A folding step over byte-encoded values.
///
/// Implementors must be `Send + Sync`.
pub trait FoldFunction
where
    Self: Send + Sync,
{
    /// Folds `value` into `accumulator` and returns the new accumulator.
    ///
    /// Both arguments are taken by value, so implementations may extend and
    /// return the accumulator buffer directly.
    fn fold(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8>;
}
79
80pub struct FoldOperator<F>
82where
83 F: FoldFunction,
84{
85 fold_fn: Arc<F>,
86 initial_value: Vec<u8>,
87 state: Arc<RwLock<KeyedState>>,
88}
89
90impl<F> FoldOperator<F>
91where
92 F: FoldFunction,
93{
94 pub fn new(fold_fn: F, initial_value: Vec<u8>) -> Self {
96 Self {
97 fold_fn: Arc::new(fold_fn),
98 initial_value,
99 state: Arc::new(RwLock::new(HashMap::new())),
100 }
101 }
102
103 pub async fn process(&self, element: StreamElement) -> Result<StreamElement> {
105 let mut state = self.state.write().await;
106
107 let key = element.key.clone();
108 let current = state
109 .entry(key.clone())
110 .or_insert_with(|| self.initial_value.clone());
111
112 let folded = self.fold_fn.fold(current.clone(), element.data);
113 *current = folded.clone();
114
115 Ok(StreamElement {
116 data: folded,
117 event_time: element.event_time,
118 processing_time: element.processing_time,
119 key,
120 metadata: element.metadata,
121 })
122 }
123
124 pub async fn get_state(&self, key: Option<Vec<u8>>) -> Vec<u8> {
126 self.state
127 .read()
128 .await
129 .get(&key)
130 .cloned()
131 .unwrap_or_else(|| self.initial_value.clone())
132 }
133
134 pub async fn clear(&self) {
136 self.state.write().await.clear();
137 }
138}
139
140pub struct ScanOperator<F>
142where
143 F: FoldFunction,
144{
145 fold_fn: Arc<F>,
146 initial_value: Vec<u8>,
147 state: Arc<RwLock<KeyedState>>,
148}
149
150impl<F> ScanOperator<F>
151where
152 F: FoldFunction,
153{
154 pub fn new(fold_fn: F, initial_value: Vec<u8>) -> Self {
156 Self {
157 fold_fn: Arc::new(fold_fn),
158 initial_value,
159 state: Arc::new(RwLock::new(HashMap::new())),
160 }
161 }
162
163 pub async fn process(&self, element: StreamElement) -> Result<Vec<StreamElement>> {
165 let mut state = self.state.write().await;
166
167 let key = element.key.clone();
168 let current = state
169 .entry(key.clone())
170 .or_insert_with(|| self.initial_value.clone());
171
172 let scanned = self.fold_fn.fold(current.clone(), element.data);
173 *current = scanned.clone();
174
175 Ok(vec![StreamElement {
176 data: scanned,
177 event_time: element.event_time,
178 processing_time: element.processing_time,
179 key,
180 metadata: element.metadata,
181 }])
182 }
183
184 pub async fn clear(&self) {
186 self.state.write().await.clear();
187 }
188}
189
190pub struct SumReduce;
192
193impl ReduceFunction for SumReduce {
194 fn reduce(&self, value1: Vec<u8>, value2: Vec<u8>) -> Vec<u8> {
195 let v1 = i64::from_le_bytes(value1.try_into().unwrap_or([0; 8]));
196 let v2 = i64::from_le_bytes(value2.try_into().unwrap_or([0; 8]));
197 (v1 + v2).to_le_bytes().to_vec()
198 }
199}
200
201pub struct ConcatFold;
203
204impl FoldFunction for ConcatFold {
205 fn fold(&self, mut accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
206 accumulator.extend(value);
207 accumulator
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds an unkeyed element whose payload is `value` as little-endian
    /// `i64` bytes, timestamped with the current time.
    fn i64_element(value: i64) -> StreamElement {
        StreamElement::new(value.to_le_bytes().to_vec(), Utc::now())
    }

    /// Decodes little-endian `i64` bytes; any payload that is not exactly
    /// 8 bytes long decodes as `0`.
    fn decode_i64(bytes: Vec<u8>) -> i64 {
        i64::from_le_bytes(bytes.try_into().unwrap_or([0; 8]))
    }

    /// The first element is absorbed silently; the second yields the sum.
    #[tokio::test]
    async fn test_reduce_operator() {
        let reducer = ReduceOperator::new(SumReduce);

        let first = i64_element(5);
        let second = i64_element(3);

        let first_output = reducer
            .process(first)
            .await
            .expect("Failed to process first element in reduce operator test");
        assert!(first_output.is_none());

        let second_output = reducer
            .process(second)
            .await
            .expect("Failed to process second element in reduce operator test");
        assert!(second_output.is_some());

        let emitted =
            second_output.expect("Result2 should contain a value after reduce operation");
        assert_eq!(decode_i64(emitted.data), 8);
    }

    /// Every element produces an output holding the accumulator so far.
    #[tokio::test]
    async fn test_fold_operator() {
        let folder = FoldOperator::new(ConcatFold, vec![]);

        let first = StreamElement::new(vec![1, 2], Utc::now());
        let second = StreamElement::new(vec![3, 4], Utc::now());

        let after_first = folder
            .process(first)
            .await
            .expect("Failed to process first element in fold operator test");
        assert_eq!(after_first.data, vec![1, 2]);

        let after_second = folder
            .process(second)
            .await
            .expect("Failed to process second element in fold operator test");
        assert_eq!(after_second.data, vec![1, 2, 3, 4]);
    }

    /// Every element produces exactly one output holding the accumulator so
    /// far.
    #[tokio::test]
    async fn test_scan_operator() {
        let scanner = ScanOperator::new(ConcatFold, vec![]);

        let first = StreamElement::new(vec![1, 2], Utc::now());
        let second = StreamElement::new(vec![3, 4], Utc::now());

        let after_first = scanner
            .process(first)
            .await
            .expect("Failed to process first element in scan operator test");
        assert_eq!(after_first.len(), 1);
        assert_eq!(after_first[0].data, vec![1, 2]);

        let after_second = scanner
            .process(second)
            .await
            .expect("Failed to process second element in scan operator test");
        assert_eq!(after_second.len(), 1);
        assert_eq!(after_second[0].data, vec![1, 2, 3, 4]);
    }

    /// State is tracked independently for each key.
    #[tokio::test]
    async fn test_reduce_with_keys() {
        let reducer = ReduceOperator::new(SumReduce);

        let key_one_a = i64_element(5).with_key(vec![1]);
        let key_one_b = i64_element(3).with_key(vec![1]);
        let key_two = i64_element(10).with_key(vec![2]);

        reducer
            .process(key_one_a)
            .await
            .expect("Failed to process first keyed element");
        reducer
            .process(key_one_b)
            .await
            .expect("Failed to process second keyed element");
        reducer
            .process(key_two)
            .await
            .expect("Failed to process third keyed element");

        let stored_one = reducer
            .get_state(Some(vec![1]))
            .await
            .expect("Failed to get state for key [1]");
        assert_eq!(decode_i64(stored_one), 8);

        let stored_two = reducer
            .get_state(Some(vec![2]))
            .await
            .expect("Failed to get state for key [2]");
        assert_eq!(decode_i64(stored_two), 10);
    }
}