Skip to main content

oxigdal_streaming/transformations/
reduce.rs

1//! Reduce, fold, and scan operations.
2
3use crate::core::stream::StreamElement;
4use crate::error::Result;
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
/// Per-key operator state table.
///
/// The map key is an element's optional key bytes (`None` for unkeyed
/// elements); the map value is the bytes currently accumulated for that key.
type KeyedState = HashMap<Option<Vec<u8>>, Vec<u8>>;
11
/// Binary combiner over raw byte payloads, used by [`ReduceOperator`].
///
/// Implementors must be `Send + Sync`.
pub trait ReduceFunction: Send + Sync {
    /// Merge `value1` and `value2` into a single payload.
    ///
    /// [`ReduceOperator`] passes the value stored for the element's key as
    /// `value1` and the incoming element's data as `value2`.
    fn reduce(&self, value1: Vec<u8>, value2: Vec<u8>) -> Vec<u8>;
}
17
18/// Reduce operator.
19pub struct ReduceOperator<F>
20where
21    F: ReduceFunction,
22{
23    reduce_fn: Arc<F>,
24    state: Arc<RwLock<KeyedState>>,
25}
26
27impl<F> ReduceOperator<F>
28where
29    F: ReduceFunction,
30{
31    /// Create a new reduce operator.
32    pub fn new(reduce_fn: F) -> Self {
33        Self {
34            reduce_fn: Arc::new(reduce_fn),
35            state: Arc::new(RwLock::new(HashMap::new())),
36        }
37    }
38
39    /// Process an element.
40    pub async fn process(&self, element: StreamElement) -> Result<Option<StreamElement>> {
41        let mut state = self.state.write().await;
42
43        let key = element.key.clone();
44        let current = state.entry(key.clone()).or_insert_with(Vec::new);
45
46        if current.is_empty() {
47            *current = element.data;
48            Ok(None)
49        } else {
50            let reduced = self.reduce_fn.reduce(current.clone(), element.data);
51            *current = reduced.clone();
52
53            Ok(Some(StreamElement {
54                data: reduced,
55                event_time: element.event_time,
56                processing_time: element.processing_time,
57                key,
58                metadata: element.metadata,
59            }))
60        }
61    }
62
63    /// Get the current state for a key.
64    pub async fn get_state(&self, key: Option<Vec<u8>>) -> Option<Vec<u8>> {
65        self.state.read().await.get(&key).cloned()
66    }
67
68    /// Clear all state.
69    pub async fn clear(&self) {
70        self.state.write().await.clear();
71    }
72}
73
/// Accumulating combiner over raw byte payloads, used by [`FoldOperator`]
/// and [`ScanOperator`].
///
/// Implementors must be `Send + Sync`.
pub trait FoldFunction: Send + Sync {
    /// Absorb `value` into `accumulator` and return the updated accumulator.
    fn fold(&self, accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8>;
}
79
80/// Fold operator.
81pub struct FoldOperator<F>
82where
83    F: FoldFunction,
84{
85    fold_fn: Arc<F>,
86    initial_value: Vec<u8>,
87    state: Arc<RwLock<KeyedState>>,
88}
89
90impl<F> FoldOperator<F>
91where
92    F: FoldFunction,
93{
94    /// Create a new fold operator.
95    pub fn new(fold_fn: F, initial_value: Vec<u8>) -> Self {
96        Self {
97            fold_fn: Arc::new(fold_fn),
98            initial_value,
99            state: Arc::new(RwLock::new(HashMap::new())),
100        }
101    }
102
103    /// Process an element.
104    pub async fn process(&self, element: StreamElement) -> Result<StreamElement> {
105        let mut state = self.state.write().await;
106
107        let key = element.key.clone();
108        let current = state
109            .entry(key.clone())
110            .or_insert_with(|| self.initial_value.clone());
111
112        let folded = self.fold_fn.fold(current.clone(), element.data);
113        *current = folded.clone();
114
115        Ok(StreamElement {
116            data: folded,
117            event_time: element.event_time,
118            processing_time: element.processing_time,
119            key,
120            metadata: element.metadata,
121        })
122    }
123
124    /// Get the current state for a key.
125    pub async fn get_state(&self, key: Option<Vec<u8>>) -> Vec<u8> {
126        self.state
127            .read()
128            .await
129            .get(&key)
130            .cloned()
131            .unwrap_or_else(|| self.initial_value.clone())
132    }
133
134    /// Clear all state.
135    pub async fn clear(&self) {
136        self.state.write().await.clear();
137    }
138}
139
140/// Scan operator (like fold but emits intermediate results).
141pub struct ScanOperator<F>
142where
143    F: FoldFunction,
144{
145    fold_fn: Arc<F>,
146    initial_value: Vec<u8>,
147    state: Arc<RwLock<KeyedState>>,
148}
149
150impl<F> ScanOperator<F>
151where
152    F: FoldFunction,
153{
154    /// Create a new scan operator.
155    pub fn new(fold_fn: F, initial_value: Vec<u8>) -> Self {
156        Self {
157            fold_fn: Arc::new(fold_fn),
158            initial_value,
159            state: Arc::new(RwLock::new(HashMap::new())),
160        }
161    }
162
163    /// Process an element.
164    pub async fn process(&self, element: StreamElement) -> Result<Vec<StreamElement>> {
165        let mut state = self.state.write().await;
166
167        let key = element.key.clone();
168        let current = state
169            .entry(key.clone())
170            .or_insert_with(|| self.initial_value.clone());
171
172        let scanned = self.fold_fn.fold(current.clone(), element.data);
173        *current = scanned.clone();
174
175        Ok(vec![StreamElement {
176            data: scanned,
177            event_time: element.event_time,
178            processing_time: element.processing_time,
179            key,
180            metadata: element.metadata,
181        }])
182    }
183
184    /// Clear all state.
185    pub async fn clear(&self) {
186        self.state.write().await.clear();
187    }
188}
189
190/// Simple sum reduce function.
191pub struct SumReduce;
192
193impl ReduceFunction for SumReduce {
194    fn reduce(&self, value1: Vec<u8>, value2: Vec<u8>) -> Vec<u8> {
195        let v1 = i64::from_le_bytes(value1.try_into().unwrap_or([0; 8]));
196        let v2 = i64::from_le_bytes(value2.try_into().unwrap_or([0; 8]));
197        (v1 + v2).to_le_bytes().to_vec()
198    }
199}
200
201/// Simple concatenation fold function.
202pub struct ConcatFold;
203
204impl FoldFunction for ConcatFold {
205    fn fold(&self, mut accumulator: Vec<u8>, value: Vec<u8>) -> Vec<u8> {
206        accumulator.extend(value);
207        accumulator
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Build an unkeyed element whose data is `n` as 8 little-endian bytes.
    fn int_element(n: i64) -> StreamElement {
        StreamElement::new(n.to_le_bytes().to_vec(), Utc::now())
    }

    /// Decode 8 little-endian bytes as `i64`; any other length decodes to 0.
    fn as_i64(bytes: Vec<u8>) -> i64 {
        i64::from_le_bytes(bytes.try_into().unwrap_or([0; 8]))
    }

    /// The first element only seeds the state; the second emits the sum.
    #[tokio::test]
    async fn test_reduce_operator() {
        let operator = ReduceOperator::new(SumReduce);

        let first = int_element(5);
        let second = int_element(3);

        let seeded = operator
            .process(first)
            .await
            .expect("Failed to process first element in reduce operator test");
        assert!(seeded.is_none());

        let summed = operator
            .process(second)
            .await
            .expect("Failed to process second element in reduce operator test");
        assert!(summed.is_some());

        let emitted = summed.expect("Result2 should contain a value after reduce operation");
        assert_eq!(as_i64(emitted.data), 8);
    }

    /// Each fold call returns the accumulator extended by the new data.
    #[tokio::test]
    async fn test_fold_operator() {
        let operator = FoldOperator::new(ConcatFold, vec![]);

        let steps = [
            (
                StreamElement::new(vec![1, 2], Utc::now()),
                vec![1u8, 2],
                "Failed to process first element in fold operator test",
            ),
            (
                StreamElement::new(vec![3, 4], Utc::now()),
                vec![1u8, 2, 3, 4],
                "Failed to process second element in fold operator test",
            ),
        ];

        for (element, expected, failure) in steps {
            let output = operator.process(element).await.expect(failure);
            assert_eq!(output.data, expected);
        }
    }

    /// Each scan call returns exactly one element holding the running accumulator.
    #[tokio::test]
    async fn test_scan_operator() {
        let operator = ScanOperator::new(ConcatFold, vec![]);

        let steps = [
            (
                StreamElement::new(vec![1, 2], Utc::now()),
                vec![1u8, 2],
                "Failed to process first element in scan operator test",
            ),
            (
                StreamElement::new(vec![3, 4], Utc::now()),
                vec![1u8, 2, 3, 4],
                "Failed to process second element in scan operator test",
            ),
        ];

        for (element, expected, failure) in steps {
            let outputs = operator.process(element).await.expect(failure);
            assert_eq!(outputs.len(), 1);
            assert_eq!(outputs[0].data, expected);
        }
    }

    /// State is tracked independently for each key.
    #[tokio::test]
    async fn test_reduce_with_keys() {
        let operator = ReduceOperator::new(SumReduce);

        let inputs = [
            (
                int_element(5).with_key(vec![1]),
                "Failed to process first keyed element",
            ),
            (
                int_element(3).with_key(vec![1]),
                "Failed to process second keyed element",
            ),
            (
                int_element(10).with_key(vec![2]),
                "Failed to process third keyed element",
            ),
        ];

        for (element, failure) in inputs {
            operator.process(element).await.expect(failure);
        }

        let checks = [
            (vec![1u8], "Failed to get state for key [1]", 8i64),
            (vec![2u8], "Failed to get state for key [2]", 10i64),
        ];

        for (key, failure, expected) in checks {
            let stored = operator.get_state(Some(key)).await.expect(failure);
            assert_eq!(as_i64(stored), expected);
        }
    }
}