// oxigdal_streaming/state/operator_state.rs

use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;

use tokio::sync::RwLock;

use crate::error::{Result, StreamingError};

8pub trait OperatorState: Send + Sync {
10 fn snapshot(&self) -> impl std::future::Future<Output = Result<Vec<u8>>> + Send;
12
13 fn restore(&self, snapshot: &[u8]) -> impl std::future::Future<Output = Result<()>> + Send;
15}
16
17pub struct BroadcastState {
19 state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
20}
21
22impl BroadcastState {
23 pub fn new() -> Self {
25 Self {
26 state: Arc::new(RwLock::new(HashMap::new())),
27 }
28 }
29
30 pub async fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
32 self.state.read().await.get(key).cloned()
33 }
34
35 pub async fn put(&self, key: Vec<u8>, value: Vec<u8>) {
37 self.state.write().await.insert(key, value);
38 }
39
40 pub async fn remove(&self, key: &[u8]) {
42 self.state.write().await.remove(key);
43 }
44
45 pub async fn contains(&self, key: &[u8]) -> bool {
47 self.state.read().await.contains_key(key)
48 }
49
50 pub async fn clear(&self) {
52 self.state.write().await.clear();
53 }
54
55 pub async fn keys(&self) -> Vec<Vec<u8>> {
57 self.state.read().await.keys().cloned().collect()
58 }
59}
60
61impl Default for BroadcastState {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl OperatorState for BroadcastState {
68 async fn snapshot(&self) -> Result<Vec<u8>> {
69 let state = self.state.read().await;
70 oxicode::encode_to_vec(&*state)
72 .map_err(|e| StreamingError::SerializationError(e.to_string()))
73 }
74
75 async fn restore(&self, snapshot: &[u8]) -> Result<()> {
76 let (restored, _): (HashMap<Vec<u8>, Vec<u8>>, _) = oxicode::decode_from_slice(snapshot)
77 .map_err(|e| StreamingError::SerializationError(e.to_string()))?;
78 *self.state.write().await = restored;
79 Ok(())
80 }
81}
82
83pub struct UnionListState {
85 state: Arc<RwLock<Vec<Vec<u8>>>>,
86}
87
88impl UnionListState {
89 pub fn new() -> Self {
91 Self {
92 state: Arc::new(RwLock::new(Vec::new())),
93 }
94 }
95
96 pub async fn get(&self) -> Vec<Vec<u8>> {
98 self.state.read().await.clone()
99 }
100
101 pub async fn add(&self, value: Vec<u8>) {
103 self.state.write().await.push(value);
104 }
105
106 pub async fn add_all(&self, values: Vec<Vec<u8>>) {
108 self.state.write().await.extend(values);
109 }
110
111 pub async fn update(&self, values: Vec<Vec<u8>>) {
113 *self.state.write().await = values;
114 }
115
116 pub async fn clear(&self) {
118 self.state.write().await.clear();
119 }
120
121 pub async fn len(&self) -> usize {
123 self.state.read().await.len()
124 }
125
126 pub async fn is_empty(&self) -> bool {
128 self.state.read().await.is_empty()
129 }
130}
131
132impl Default for UnionListState {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl OperatorState for UnionListState {
139 async fn snapshot(&self) -> Result<Vec<u8>> {
140 let state = self.state.read().await;
141 Ok(serde_json::to_vec(&*state)?)
142 }
143
144 async fn restore(&self, snapshot: &[u8]) -> Result<()> {
145 let restored: Vec<Vec<u8>> = serde_json::from_slice(snapshot)?;
146 *self.state.write().await = restored;
147 Ok(())
148 }
149}
150
151pub trait ListCheckpointed {
153 fn snapshot_state(&self) -> impl std::future::Future<Output = Vec<Vec<u8>>> + Send;
155
156 fn restore_state(
158 &self,
159 state: Vec<Vec<u8>>,
160 ) -> impl std::future::Future<Output = Result<()>> + Send;
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_broadcast_state() {
        let state = BroadcastState::new();

        state.put(vec![1], vec![42]).await;
        assert_eq!(state.get(&[1]).await, Some(vec![42]));

        assert!(state.contains(&[1]).await);
        assert!(!state.contains(&[2]).await);

        state.remove(&[1]).await;
        assert_eq!(state.get(&[1]).await, None);
    }

    #[tokio::test]
    async fn test_broadcast_state_keys_overwrite_and_clear() {
        let state = BroadcastState::new();

        state.put(vec![2], vec![20]).await;
        state.put(vec![1], vec![10]).await;
        // Overwriting an existing key replaces its value instead of adding an entry.
        state.put(vec![1], vec![11]).await;
        assert_eq!(state.get(&[1]).await, Some(vec![11]));

        // HashMap iteration order is unspecified, so compare the keys sorted.
        let mut keys = state.keys().await;
        keys.sort();
        assert_eq!(keys, vec![vec![1], vec![2]]);

        // Removing an absent key is a no-op.
        state.remove(&[3]).await;
        assert_eq!(state.keys().await.len(), 2);

        state.clear().await;
        assert!(state.keys().await.is_empty());
        assert!(!state.contains(&[1]).await);
    }

    #[tokio::test]
    async fn test_broadcast_state_snapshot() {
        let state = BroadcastState::new();

        state.put(vec![1], vec![42]).await;
        state.put(vec![2], vec![43]).await;

        let snapshot = state
            .snapshot()
            .await
            .expect("Failed to create snapshot of broadcast state");

        let state2 = BroadcastState::new();
        state2
            .restore(&snapshot)
            .await
            .expect("Failed to restore broadcast state from snapshot");

        assert_eq!(state2.get(&[1]).await, Some(vec![42]));
        assert_eq!(state2.get(&[2]).await, Some(vec![43]));
    }

    #[tokio::test]
    async fn test_broadcast_state_restore_replaces_existing_entries() {
        let source = BroadcastState::new();
        source.put(vec![1], vec![42]).await;
        let snapshot = source
            .snapshot()
            .await
            .expect("Failed to create snapshot of broadcast state");

        let target = BroadcastState::new();
        target.put(vec![9], vec![99]).await;
        target
            .restore(&snapshot)
            .await
            .expect("Failed to restore broadcast state from snapshot");

        // Restore replaces the map; it does not merge into it.
        assert_eq!(target.keys().await, vec![vec![1]]);
        assert_eq!(target.get(&[9]).await, None);
    }

    #[tokio::test]
    async fn test_broadcast_state_empty_round_trip() {
        let snapshot = BroadcastState::new()
            .snapshot()
            .await
            .expect("Failed to create snapshot of empty broadcast state");

        let state = BroadcastState::new();
        state.put(vec![1], vec![1]).await;
        state
            .restore(&snapshot)
            .await
            .expect("Failed to restore empty broadcast state");

        assert!(state.keys().await.is_empty());
    }

    #[tokio::test]
    async fn test_broadcast_state_restore_rejects_malformed_snapshot() {
        let state = BroadcastState::new();
        state.put(vec![1], vec![42]).await;

        assert!(state.restore(&[]).await.is_err());
        // A failed restore must leave the existing contents untouched.
        assert_eq!(state.get(&[1]).await, Some(vec![42]));
    }

    #[tokio::test]
    async fn test_union_list_state() {
        let state = UnionListState::new();

        state.add(vec![1]).await;
        state.add(vec![2]).await;
        state.add(vec![3]).await;

        let values = state.get().await;
        assert_eq!(values, vec![vec![1], vec![2], vec![3]]);

        assert_eq!(state.len().await, 3);
        assert!(!state.is_empty().await);
    }

    #[tokio::test]
    async fn test_union_list_state_bulk_operations() {
        let state = UnionListState::new();
        assert!(state.is_empty().await);

        state.add(vec![1]).await;
        state.add_all(vec![vec![2], vec![3]]).await;
        assert_eq!(state.get().await, vec![vec![1], vec![2], vec![3]]);

        // `update` replaces the whole list rather than appending to it.
        state.update(vec![vec![7]]).await;
        assert_eq!(state.get().await, vec![vec![7]]);

        state.clear().await;
        assert!(state.is_empty().await);
        assert_eq!(state.len().await, 0);
    }

    #[tokio::test]
    async fn test_union_list_state_snapshot() {
        let state = UnionListState::new();

        state.add(vec![1]).await;
        state.add(vec![2]).await;

        let snapshot = state
            .snapshot()
            .await
            .expect("Failed to create snapshot of union list state");

        let state2 = UnionListState::new();
        state2
            .restore(&snapshot)
            .await
            .expect("Failed to restore union list state from snapshot");

        assert_eq!(state2.get().await, vec![vec![1], vec![2]]);
    }

    #[tokio::test]
    async fn test_union_list_state_restore_replaces_existing_entries() {
        let source = UnionListState::new();
        source.add_all(vec![vec![1], vec![2]]).await;
        let snapshot = source
            .snapshot()
            .await
            .expect("Failed to create snapshot of union list state");

        let target = UnionListState::new();
        target.add(vec![9]).await;
        target
            .restore(&snapshot)
            .await
            .expect("Failed to restore union list state from snapshot");

        // Restore replaces the list; it does not append to it.
        assert_eq!(target.get().await, vec![vec![1], vec![2]]);
    }

    #[tokio::test]
    async fn test_union_list_state_restore_rejects_malformed_snapshot() {
        let state = UnionListState::new();
        state.add(vec![1]).await;

        assert!(state.restore(&[]).await.is_err());
        // A failed restore must leave the existing contents untouched.
        assert_eq!(state.get().await, vec![vec![1]]);
    }
}