Skip to main content

oxigdal_streaming/state/
operator_state.rs

1//! Operator state for stream processing.
2
3use crate::error::{Result, StreamingError};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8/// Operator state trait.
9pub trait OperatorState: Send + Sync {
10    /// Snapshot the state.
11    fn snapshot(&self) -> impl std::future::Future<Output = Result<Vec<u8>>> + Send;
12
13    /// Restore from a snapshot.
14    fn restore(&self, snapshot: &[u8]) -> impl std::future::Future<Output = Result<()>> + Send;
15}
16
17/// Broadcast state (shared across all parallel instances).
18pub struct BroadcastState {
19    state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
20}
21
22impl BroadcastState {
23    /// Create a new broadcast state.
24    pub fn new() -> Self {
25        Self {
26            state: Arc::new(RwLock::new(HashMap::new())),
27        }
28    }
29
30    /// Get a value.
31    pub async fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
32        self.state.read().await.get(key).cloned()
33    }
34
35    /// Put a value.
36    pub async fn put(&self, key: Vec<u8>, value: Vec<u8>) {
37        self.state.write().await.insert(key, value);
38    }
39
40    /// Remove a value.
41    pub async fn remove(&self, key: &[u8]) {
42        self.state.write().await.remove(key);
43    }
44
45    /// Check if a key exists.
46    pub async fn contains(&self, key: &[u8]) -> bool {
47        self.state.read().await.contains_key(key)
48    }
49
50    /// Clear all state.
51    pub async fn clear(&self) {
52        self.state.write().await.clear();
53    }
54
55    /// Get all keys.
56    pub async fn keys(&self) -> Vec<Vec<u8>> {
57        self.state.read().await.keys().cloned().collect()
58    }
59}
60
61impl Default for BroadcastState {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl OperatorState for BroadcastState {
68    async fn snapshot(&self) -> Result<Vec<u8>> {
69        let state = self.state.read().await;
70        // Use oxicode for binary serialization since JSON requires string keys
71        oxicode::encode_to_vec(&*state)
72            .map_err(|e| StreamingError::SerializationError(e.to_string()))
73    }
74
75    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
76        let (restored, _): (HashMap<Vec<u8>, Vec<u8>>, _) = oxicode::decode_from_slice(snapshot)
77            .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
78        *self.state.write().await = restored;
79        Ok(())
80    }
81}
82
83/// Union list state (list that is distributed across parallel instances).
84pub struct UnionListState {
85    state: Arc<RwLock<Vec<Vec<u8>>>>,
86}
87
88impl UnionListState {
89    /// Create a new union list state.
90    pub fn new() -> Self {
91        Self {
92            state: Arc::new(RwLock::new(Vec::new())),
93        }
94    }
95
96    /// Get all values.
97    pub async fn get(&self) -> Vec<Vec<u8>> {
98        self.state.read().await.clone()
99    }
100
101    /// Add a value.
102    pub async fn add(&self, value: Vec<u8>) {
103        self.state.write().await.push(value);
104    }
105
106    /// Add multiple values.
107    pub async fn add_all(&self, values: Vec<Vec<u8>>) {
108        self.state.write().await.extend(values);
109    }
110
111    /// Update with new values.
112    pub async fn update(&self, values: Vec<Vec<u8>>) {
113        *self.state.write().await = values;
114    }
115
116    /// Clear all values.
117    pub async fn clear(&self) {
118        self.state.write().await.clear();
119    }
120
121    /// Get the number of values.
122    pub async fn len(&self) -> usize {
123        self.state.read().await.len()
124    }
125
126    /// Check if empty.
127    pub async fn is_empty(&self) -> bool {
128        self.state.read().await.is_empty()
129    }
130}
131
132impl Default for UnionListState {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
138impl OperatorState for UnionListState {
139    async fn snapshot(&self) -> Result<Vec<u8>> {
140        let state = self.state.read().await;
141        Ok(serde_json::to_vec(&*state)?)
142    }
143
144    async fn restore(&self, snapshot: &[u8]) -> Result<()> {
145        let restored: Vec<Vec<u8>> = serde_json::from_slice(snapshot)?;
146        *self.state.write().await = restored;
147        Ok(())
148    }
149}
150
151/// Trait for checkpointable list state.
152pub trait ListCheckpointed {
153    /// Get the current state as a list.
154    fn snapshot_state(&self) -> impl std::future::Future<Output = Vec<Vec<u8>>> + Send;
155
156    /// Restore state from a list.
157    fn restore_state(
158        &self,
159        state: Vec<Vec<u8>>,
160    ) -> impl std::future::Future<Output = Result<()>> + Send;
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_broadcast_state() {
        let state = BroadcastState::new();

        state.put(vec![1], vec![42]).await;
        assert_eq!(state.get(&[1]).await, Some(vec![42]));

        assert!(state.contains(&[1]).await);
        assert!(!state.contains(&[2]).await);

        state.remove(&[1]).await;
        assert_eq!(state.get(&[1]).await, None);
    }

    #[tokio::test]
    async fn test_broadcast_state_keys_overwrite_and_clear() {
        let state = BroadcastState::new();

        state.put(vec![2], vec![20]).await;
        state.put(vec![1], vec![10]).await;
        // A second put on the same key overwrites rather than duplicates.
        state.put(vec![1], vec![11]).await;

        // Key order comes from a HashMap, so sort before comparing.
        let mut keys = state.keys().await;
        keys.sort();
        assert_eq!(keys, vec![vec![1], vec![2]]);
        assert_eq!(state.get(&[1]).await, Some(vec![11]));

        state.clear().await;
        assert!(state.keys().await.is_empty());
        assert!(!state.contains(&[1]).await);
    }

    #[tokio::test]
    async fn test_broadcast_state_snapshot() {
        let state = BroadcastState::new();

        state.put(vec![1], vec![42]).await;
        state.put(vec![2], vec![43]).await;

        let snapshot = state
            .snapshot()
            .await
            .expect("Failed to create snapshot of broadcast state");

        let state2 = BroadcastState::new();
        state2
            .restore(&snapshot)
            .await
            .expect("Failed to restore broadcast state from snapshot");

        assert_eq!(state2.get(&[1]).await, Some(vec![42]));
        assert_eq!(state2.get(&[2]).await, Some(vec![43]));
    }

    #[tokio::test]
    async fn test_broadcast_state_restore_replaces_existing() {
        let source = BroadcastState::new();
        source.put(vec![1], vec![42]).await;
        let snapshot = source
            .snapshot()
            .await
            .expect("Failed to create snapshot of broadcast state");

        let target = BroadcastState::new();
        target.put(vec![9], vec![99]).await;
        target
            .restore(&snapshot)
            .await
            .expect("Failed to restore broadcast state from snapshot");

        // Restore replaces the whole map; it does not merge into it.
        assert_eq!(target.get(&[1]).await, Some(vec![42]));
        assert_eq!(target.get(&[9]).await, None);
    }

    #[tokio::test]
    async fn test_broadcast_state_empty_snapshot_roundtrip() {
        let empty = BroadcastState::new();
        let snapshot = empty
            .snapshot()
            .await
            .expect("Failed to create snapshot of empty broadcast state");

        let target = BroadcastState::new();
        target.put(vec![1], vec![1]).await;
        target
            .restore(&snapshot)
            .await
            .expect("Failed to restore empty broadcast state snapshot");

        assert!(target.keys().await.is_empty());
    }

    #[tokio::test]
    async fn test_union_list_state() {
        let state = UnionListState::new();

        state.add(vec![1]).await;
        state.add(vec![2]).await;
        state.add(vec![3]).await;

        let values = state.get().await;
        assert_eq!(values, vec![vec![1], vec![2], vec![3]]);

        assert_eq!(state.len().await, 3);
        assert!(!state.is_empty().await);
    }

    #[tokio::test]
    async fn test_union_list_state_add_all_update_clear() {
        let state = UnionListState::new();

        state.add(vec![1]).await;
        state.add_all(vec![vec![2], vec![3]]).await;
        assert_eq!(state.get().await, vec![vec![1], vec![2], vec![3]]);

        state.update(vec![vec![9]]).await;
        assert_eq!(state.get().await, vec![vec![9]]);

        state.clear().await;
        assert_eq!(state.len().await, 0);
        assert!(state.is_empty().await);
    }

    #[tokio::test]
    async fn test_union_list_state_snapshot() {
        let state = UnionListState::new();

        state.add(vec![1]).await;
        state.add(vec![2]).await;

        let snapshot = state
            .snapshot()
            .await
            .expect("Failed to create snapshot of union list state");

        let state2 = UnionListState::new();
        state2
            .restore(&snapshot)
            .await
            .expect("Failed to restore union list state from snapshot");

        assert_eq!(state2.get().await, vec![vec![1], vec![2]]);
    }

    #[tokio::test]
    async fn test_union_list_state_restore_rejects_invalid_snapshot() {
        let state = UnionListState::new();
        state.add(vec![7]).await;

        assert!(state.restore(b"not json").await.is_err());
        // A failed restore must leave the existing contents untouched.
        assert_eq!(state.get().await, vec![vec![7]]);
    }
}