Skip to main content

oxigdal_streaming/state/
keyed_state.rs

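//! Keyed state for stream processing.
//!
//! Provides value, list, map, reducing, and aggregating state handles. Each
//! handle is bound to a `(namespace, key)` pair and persists its data through a
//! shared `StateBackend`.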
2
3use crate::error::Result;
4use crate::state::backend::StateBackend;
5use std::sync::Arc;
6
7/// Keyed state trait.
8pub trait KeyedState: Send + Sync {
9    /// Get the state key.
10    fn key(&self) -> &[u8];
11
12    /// Clear the state.
13    fn clear(&self) -> impl std::future::Future<Output = Result<()>> + Send;
14}
15
/// Value state: stores a single opaque byte value per key.
///
/// All operations require `B: StateBackend`; that bound is stated on the
/// `impl` blocks that actually use the backend.
pub struct ValueState<B> {
    /// Shared storage backend the value is read from and written to.
    backend: Arc<B>,
    /// Namespace prefix used when building the backend key.
    namespace: String,
    /// The key this handle is bound to.
    key: Vec<u8>,
}
25
26impl<B> ValueState<B>
27where
28    B: StateBackend,
29{
30    /// Create a new value state.
31    pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>) -> Self {
32        Self {
33            backend,
34            namespace,
35            key,
36        }
37    }
38
39    /// Get the value.
40    pub async fn get(&self) -> Result<Option<Vec<u8>>> {
41        let state_key = self.make_state_key();
42        self.backend.get(&state_key).await
43    }
44
45    /// Set the value.
46    pub async fn set(&self, value: Vec<u8>) -> Result<()> {
47        let state_key = self.make_state_key();
48        self.backend.put(&state_key, &value).await
49    }
50
51    /// Update the value using a function.
52    pub async fn update<F>(&self, f: F) -> Result<()>
53    where
54        F: FnOnce(Option<Vec<u8>>) -> Vec<u8>,
55    {
56        let current = self.get().await?;
57        let new_value = f(current);
58        self.set(new_value).await
59    }
60
61    fn make_state_key(&self) -> Vec<u8> {
62        let mut state_key = Vec::new();
63        state_key.extend_from_slice(self.namespace.as_bytes());
64        state_key.push(b':');
65        state_key.extend_from_slice(&self.key);
66        state_key
67    }
68}
69
70impl<B> KeyedState for ValueState<B>
71where
72    B: StateBackend,
73{
74    fn key(&self) -> &[u8] {
75        &self.key
76    }
77
78    async fn clear(&self) -> Result<()> {
79        let state_key = self.make_state_key();
80        self.backend.delete(&state_key).await
81    }
82}
83
/// List state: stores an ordered list of opaque byte values per key.
///
/// All operations require `B: StateBackend`; that bound is stated on the
/// `impl` blocks that actually use the backend.
pub struct ListState<B> {
    /// Shared storage backend the list is read from and written to.
    backend: Arc<B>,
    /// Namespace prefix used when building the backend key.
    namespace: String,
    /// The key this handle is bound to.
    key: Vec<u8>,
}
93
94impl<B> ListState<B>
95where
96    B: StateBackend,
97{
98    /// Create a new list state.
99    pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>) -> Self {
100        Self {
101            backend,
102            namespace,
103            key,
104        }
105    }
106
107    /// Get all values in the list.
108    pub async fn get(&self) -> Result<Vec<Vec<u8>>> {
109        let state_key = self.make_state_key();
110        if let Some(data) = self.backend.get(&state_key).await? {
111            Ok(serde_json::from_slice(&data)?)
112        } else {
113            Ok(Vec::new())
114        }
115    }
116
117    /// Add a value to the list.
118    pub async fn add(&self, value: Vec<u8>) -> Result<()> {
119        let mut list = self.get().await?;
120        list.push(value);
121        self.set_list(list).await
122    }
123
124    /// Add multiple values to the list.
125    pub async fn add_all(&self, values: Vec<Vec<u8>>) -> Result<()> {
126        let mut list = self.get().await?;
127        list.extend(values);
128        self.set_list(list).await
129    }
130
131    /// Update the entire list.
132    pub async fn update(&self, values: Vec<Vec<u8>>) -> Result<()> {
133        self.set_list(values).await
134    }
135
136    fn set_list(&self, list: Vec<Vec<u8>>) -> impl std::future::Future<Output = Result<()>> + Send {
137        let state_key = self.make_state_key();
138        let backend = self.backend.clone();
139        async move {
140            let data = serde_json::to_vec(&list)?;
141            backend.put(&state_key, &data).await
142        }
143    }
144
145    fn make_state_key(&self) -> Vec<u8> {
146        let mut state_key = Vec::new();
147        state_key.extend_from_slice(self.namespace.as_bytes());
148        state_key.push(b':');
149        state_key.extend_from_slice(&self.key);
150        state_key
151    }
152}
153
154impl<B> KeyedState for ListState<B>
155where
156    B: StateBackend,
157{
158    fn key(&self) -> &[u8] {
159        &self.key
160    }
161
162    async fn clear(&self) -> Result<()> {
163        let state_key = self.make_state_key();
164        self.backend.delete(&state_key).await
165    }
166}
167
/// Map state: stores key-value pairs (both opaque bytes) per key.
///
/// All operations require `B: StateBackend`; that bound is stated on the
/// `impl` blocks that actually use the backend.
pub struct MapState<B> {
    /// Shared storage backend the map entries are read from and written to.
    backend: Arc<B>,
    /// Namespace prefix used when building backend keys.
    namespace: String,
    /// The key this handle is bound to.
    key: Vec<u8>,
}
177
178impl<B> MapState<B>
179where
180    B: StateBackend,
181{
182    /// Create a new map state.
183    pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>) -> Self {
184        Self {
185            backend,
186            namespace,
187            key,
188        }
189    }
190
191    /// Get a value from the map.
192    pub async fn get(&self, map_key: &[u8]) -> Result<Option<Vec<u8>>> {
193        let state_key = self.make_state_key(map_key);
194        self.backend.get(&state_key).await
195    }
196
197    /// Put a value into the map.
198    pub async fn put(&self, map_key: &[u8], value: Vec<u8>) -> Result<()> {
199        let state_key = self.make_state_key(map_key);
200        self.backend.put(&state_key, &value).await
201    }
202
203    /// Remove a key from the map.
204    pub async fn remove(&self, map_key: &[u8]) -> Result<()> {
205        let state_key = self.make_state_key(map_key);
206        self.backend.delete(&state_key).await
207    }
208
209    /// Check if the map contains a key.
210    pub async fn contains(&self, map_key: &[u8]) -> Result<bool> {
211        let state_key = self.make_state_key(map_key);
212        self.backend.contains(&state_key).await
213    }
214
215    fn make_state_key(&self, map_key: &[u8]) -> Vec<u8> {
216        let mut state_key = Vec::new();
217        state_key.extend_from_slice(self.namespace.as_bytes());
218        state_key.push(b':');
219        state_key.extend_from_slice(&self.key);
220        state_key.push(b':');
221        state_key.extend_from_slice(map_key);
222        state_key
223    }
224}
225
226impl<B> KeyedState for MapState<B>
227where
228    B: StateBackend,
229{
230    fn key(&self) -> &[u8] {
231        &self.key
232    }
233
234    async fn clear(&self) -> Result<()> {
235        Ok(())
236    }
237}
238
239/// Reducing state.
240pub struct ReducingState<B, F>
241where
242    B: StateBackend,
243    F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
244{
245    value_state: ValueState<B>,
246    reduce_fn: Arc<F>,
247}
248
249impl<B, F> ReducingState<B, F>
250where
251    B: StateBackend,
252    F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
253{
254    /// Create a new reducing state.
255    pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>, reduce_fn: F) -> Self {
256        Self {
257            value_state: ValueState::new(backend, namespace, key),
258            reduce_fn: Arc::new(reduce_fn),
259        }
260    }
261
262    /// Get the reduced value.
263    pub async fn get(&self) -> Result<Option<Vec<u8>>> {
264        self.value_state.get().await
265    }
266
267    /// Add a value (will be reduced with existing value).
268    pub async fn add(&self, value: Vec<u8>) -> Result<()> {
269        let reduce_fn = self.reduce_fn.clone();
270        self.value_state
271            .update(move |current| {
272                if let Some(existing) = current {
273                    reduce_fn(existing, value)
274                } else {
275                    value
276                }
277            })
278            .await
279    }
280}
281
282impl<B, F> KeyedState for ReducingState<B, F>
283where
284    B: StateBackend,
285    F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
286{
287    fn key(&self) -> &[u8] {
288        self.value_state.key()
289    }
290
291    async fn clear(&self) -> Result<()> {
292        self.value_state.clear().await
293    }
294}
295
296/// Aggregating state.
297pub struct AggregatingState<B, F>
298where
299    B: StateBackend,
300    F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
301{
302    value_state: ValueState<B>,
303    aggregate_fn: Arc<F>,
304}
305
306impl<B, F> AggregatingState<B, F>
307where
308    B: StateBackend,
309    F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
310{
311    /// Create a new aggregating state.
312    pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>, aggregate_fn: F) -> Self {
313        Self {
314            value_state: ValueState::new(backend, namespace, key),
315            aggregate_fn: Arc::new(aggregate_fn),
316        }
317    }
318
319    /// Get the aggregated value.
320    pub async fn get(&self) -> Result<Option<Vec<u8>>> {
321        self.value_state.get().await
322    }
323
324    /// Add a value (will be aggregated with existing value).
325    pub async fn add(&self, value: Vec<u8>) -> Result<()> {
326        let aggregate_fn = self.aggregate_fn.clone();
327        self.value_state
328            .update(move |current| {
329                if let Some(existing) = current {
330                    aggregate_fn(existing, value)
331                } else {
332                    value
333                }
334            })
335            .await
336    }
337}
338
339impl<B, F> KeyedState for AggregatingState<B, F>
340where
341    B: StateBackend,
342    F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
343{
344    fn key(&self) -> &[u8] {
345        self.value_state.key()
346    }
347
348    async fn clear(&self) -> Result<()> {
349        self.value_state.clear().await
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::backend::MemoryStateBackend;

    #[tokio::test]
    async fn test_value_state() {
        let backend = Arc::new(MemoryStateBackend::new());
        let state = ValueState::new(backend, "test".to_string(), vec![1]);

        state
            .set(vec![42])
            .await
            .expect("Failed to set value in value state");
        let value = state
            .get()
            .await
            .expect("Failed to get value from value state");
        assert_eq!(value, Some(vec![42]));

        state.clear().await.expect("Failed to clear value state");
        let value = state.get().await.expect("Failed to get value after clear");
        assert_eq!(value, None);
    }

    #[tokio::test]
    async fn test_value_state_update() {
        let backend = Arc::new(MemoryStateBackend::new());
        let state = ValueState::new(backend, "test".to_string(), vec![1]);

        // The first update sees no stored value.
        state
            .update(|current| {
                assert!(current.is_none(), "Expected no stored value yet");
                vec![1]
            })
            .await
            .expect("Failed to initialise value via update");

        // The second update receives the value written by the first.
        state
            .update(|current| {
                let mut bytes = current.expect("Expected a stored value");
                bytes.push(2);
                bytes
            })
            .await
            .expect("Failed to update existing value");

        assert_eq!(
            state.get().await.expect("Failed to get updated value"),
            Some(vec![1, 2])
        );
    }

    #[tokio::test]
    async fn test_list_state() {
        let backend = Arc::new(MemoryStateBackend::new());
        let state = ListState::new(backend, "test".to_string(), vec![1]);

        state
            .add(vec![1])
            .await
            .expect("Failed to add first item to list state");
        state
            .add(vec![2])
            .await
            .expect("Failed to add second item to list state");
        state
            .add(vec![3])
            .await
            .expect("Failed to add third item to list state");

        let list = state
            .get()
            .await
            .expect("Failed to get list from list state");
        assert_eq!(list, vec![vec![1], vec![2], vec![3]]);
    }

    #[tokio::test]
    async fn test_list_state_add_all_update_and_clear() {
        let backend = Arc::new(MemoryStateBackend::new());
        let state = ListState::new(backend, "test".to_string(), vec![1]);

        assert!(state
            .get()
            .await
            .expect("Failed to get empty list")
            .is_empty());

        state
            .add_all(vec![vec![1], vec![2]])
            .await
            .expect("Failed to add_all to list state");
        state
            .add(vec![3])
            .await
            .expect("Failed to add after add_all");
        assert_eq!(
            state.get().await.expect("Failed to get list after add_all"),
            vec![vec![1], vec![2], vec![3]]
        );

        state
            .update(vec![vec![9]])
            .await
            .expect("Failed to replace list contents");
        assert_eq!(
            state.get().await.expect("Failed to get list after update"),
            vec![vec![9]]
        );

        state.clear().await.expect("Failed to clear list state");
        assert!(state
            .get()
            .await
            .expect("Failed to get list after clear")
            .is_empty());
    }

    #[tokio::test]
    async fn test_map_state() {
        let backend = Arc::new(MemoryStateBackend::new());
        let state = MapState::new(backend, "test".to_string(), vec![1]);

        state
            .put(b"key1", vec![1])
            .await
            .expect("Failed to put key1 in map state");
        state
            .put(b"key2", vec![2])
            .await
            .expect("Failed to put key2 in map state");

        assert_eq!(
            state
                .get(b"key1")
                .await
                .expect("Failed to get key1 from map state"),
            Some(vec![1])
        );
        assert_eq!(
            state
                .get(b"key2")
                .await
                .expect("Failed to get key2 from map state"),
            Some(vec![2])
        );

        assert!(
            state
                .contains(b"key1")
                .await
                .expect("Failed to check if map contains key1")
        );

        state
            .remove(b"key1")
            .await
            .expect("Failed to remove key1 from map state");
        assert!(
            !state
                .contains(b"key1")
                .await
                .expect("Failed to check if map contains key1 after removal")
        );
    }

    #[tokio::test]
    async fn test_map_state_is_scoped_by_key() {
        let backend = Arc::new(MemoryStateBackend::new());
        let first = MapState::new(Arc::clone(&backend), "test".to_string(), vec![1]);
        let second = MapState::new(backend, "test".to_string(), vec![2]);

        first
            .put(b"shared", vec![1])
            .await
            .expect("Failed to put into first map");

        assert_eq!(
            second
                .get(b"shared")
                .await
                .expect("Failed to get from second map"),
            None
        );
        assert!(
            !second
                .contains(b"shared")
                .await
                .expect("Failed to check second map")
        );
    }

    #[tokio::test]
    async fn test_reducing_state() {
        let backend = Arc::new(MemoryStateBackend::new());
        let state = ReducingState::new(backend, "test".to_string(), vec![1], |a, b| {
            let v1 = i64::from_le_bytes(a.try_into().unwrap_or([0; 8]));
            let v2 = i64::from_le_bytes(b.try_into().unwrap_or([0; 8]));
            (v1 + v2).to_le_bytes().to_vec()
        });

        state
            .add(5i64.to_le_bytes().to_vec())
            .await
            .expect("Failed to add first value to reducing state");
        state
            .add(3i64.to_le_bytes().to_vec())
            .await
            .expect("Failed to add second value to reducing state");

        let result = state
            .get()
            .await
            .expect("Failed to get value from reducing state")
            .expect("Expected Some value from reducing state");
        let value = i64::from_le_bytes(result.try_into().unwrap_or([0; 8]));
        assert_eq!(value, 8);
    }

    #[tokio::test]
    async fn test_aggregating_state() {
        let backend = Arc::new(MemoryStateBackend::new());
        let state = AggregatingState::new(backend, "test".to_string(), vec![7], |a, b| {
            let mut joined = a;
            joined.extend(b);
            joined
        });

        assert_eq!(state.key(), &[7u8][..]);
        assert_eq!(
            state.get().await.expect("Failed to get empty aggregate"),
            None
        );

        state
            .add(vec![1])
            .await
            .expect("Failed to add first value to aggregating state");
        state
            .add(vec![2, 3])
            .await
            .expect("Failed to add second value to aggregating state");
        assert_eq!(
            state.get().await.expect("Failed to get aggregate"),
            Some(vec![1, 2, 3])
        );

        state
            .clear()
            .await
            .expect("Failed to clear aggregating state");
        assert_eq!(
            state.get().await.expect("Failed to get aggregate after clear"),
            None
        );
    }
}
473            .expect("Expected Some value from reducing state");
474        let value = i64::from_le_bytes(result.try_into().unwrap_or([0; 8]));
475        assert_eq!(value, 8);
476    }
477}