1use crate::error::Result;
4use crate::state::backend::StateBackend;
5use std::sync::Arc;
6
7pub trait KeyedState: Send + Sync {
9 fn key(&self) -> &[u8];
11
12 fn clear(&self) -> impl std::future::Future<Output = Result<()>> + Send;
14}
15
16pub struct ValueState<B>
18where
19 B: StateBackend,
20{
21 backend: Arc<B>,
22 namespace: String,
23 key: Vec<u8>,
24}
25
26impl<B> ValueState<B>
27where
28 B: StateBackend,
29{
30 pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>) -> Self {
32 Self {
33 backend,
34 namespace,
35 key,
36 }
37 }
38
39 pub async fn get(&self) -> Result<Option<Vec<u8>>> {
41 let state_key = self.make_state_key();
42 self.backend.get(&state_key).await
43 }
44
45 pub async fn set(&self, value: Vec<u8>) -> Result<()> {
47 let state_key = self.make_state_key();
48 self.backend.put(&state_key, &value).await
49 }
50
51 pub async fn update<F>(&self, f: F) -> Result<()>
53 where
54 F: FnOnce(Option<Vec<u8>>) -> Vec<u8>,
55 {
56 let current = self.get().await?;
57 let new_value = f(current);
58 self.set(new_value).await
59 }
60
61 fn make_state_key(&self) -> Vec<u8> {
62 let mut state_key = Vec::new();
63 state_key.extend_from_slice(self.namespace.as_bytes());
64 state_key.push(b':');
65 state_key.extend_from_slice(&self.key);
66 state_key
67 }
68}
69
70impl<B> KeyedState for ValueState<B>
71where
72 B: StateBackend,
73{
74 fn key(&self) -> &[u8] {
75 &self.key
76 }
77
78 async fn clear(&self) -> Result<()> {
79 let state_key = self.make_state_key();
80 self.backend.delete(&state_key).await
81 }
82}
83
84pub struct ListState<B>
86where
87 B: StateBackend,
88{
89 backend: Arc<B>,
90 namespace: String,
91 key: Vec<u8>,
92}
93
94impl<B> ListState<B>
95where
96 B: StateBackend,
97{
98 pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>) -> Self {
100 Self {
101 backend,
102 namespace,
103 key,
104 }
105 }
106
107 pub async fn get(&self) -> Result<Vec<Vec<u8>>> {
109 let state_key = self.make_state_key();
110 if let Some(data) = self.backend.get(&state_key).await? {
111 Ok(serde_json::from_slice(&data)?)
112 } else {
113 Ok(Vec::new())
114 }
115 }
116
117 pub async fn add(&self, value: Vec<u8>) -> Result<()> {
119 let mut list = self.get().await?;
120 list.push(value);
121 self.set_list(list).await
122 }
123
124 pub async fn add_all(&self, values: Vec<Vec<u8>>) -> Result<()> {
126 let mut list = self.get().await?;
127 list.extend(values);
128 self.set_list(list).await
129 }
130
131 pub async fn update(&self, values: Vec<Vec<u8>>) -> Result<()> {
133 self.set_list(values).await
134 }
135
136 fn set_list(&self, list: Vec<Vec<u8>>) -> impl std::future::Future<Output = Result<()>> + Send {
137 let state_key = self.make_state_key();
138 let backend = self.backend.clone();
139 async move {
140 let data = serde_json::to_vec(&list)?;
141 backend.put(&state_key, &data).await
142 }
143 }
144
145 fn make_state_key(&self) -> Vec<u8> {
146 let mut state_key = Vec::new();
147 state_key.extend_from_slice(self.namespace.as_bytes());
148 state_key.push(b':');
149 state_key.extend_from_slice(&self.key);
150 state_key
151 }
152}
153
154impl<B> KeyedState for ListState<B>
155where
156 B: StateBackend,
157{
158 fn key(&self) -> &[u8] {
159 &self.key
160 }
161
162 async fn clear(&self) -> Result<()> {
163 let state_key = self.make_state_key();
164 self.backend.delete(&state_key).await
165 }
166}
167
168pub struct MapState<B>
170where
171 B: StateBackend,
172{
173 backend: Arc<B>,
174 namespace: String,
175 key: Vec<u8>,
176}
177
178impl<B> MapState<B>
179where
180 B: StateBackend,
181{
182 pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>) -> Self {
184 Self {
185 backend,
186 namespace,
187 key,
188 }
189 }
190
191 pub async fn get(&self, map_key: &[u8]) -> Result<Option<Vec<u8>>> {
193 let state_key = self.make_state_key(map_key);
194 self.backend.get(&state_key).await
195 }
196
197 pub async fn put(&self, map_key: &[u8], value: Vec<u8>) -> Result<()> {
199 let state_key = self.make_state_key(map_key);
200 self.backend.put(&state_key, &value).await
201 }
202
203 pub async fn remove(&self, map_key: &[u8]) -> Result<()> {
205 let state_key = self.make_state_key(map_key);
206 self.backend.delete(&state_key).await
207 }
208
209 pub async fn contains(&self, map_key: &[u8]) -> Result<bool> {
211 let state_key = self.make_state_key(map_key);
212 self.backend.contains(&state_key).await
213 }
214
215 fn make_state_key(&self, map_key: &[u8]) -> Vec<u8> {
216 let mut state_key = Vec::new();
217 state_key.extend_from_slice(self.namespace.as_bytes());
218 state_key.push(b':');
219 state_key.extend_from_slice(&self.key);
220 state_key.push(b':');
221 state_key.extend_from_slice(map_key);
222 state_key
223 }
224}
225
226impl<B> KeyedState for MapState<B>
227where
228 B: StateBackend,
229{
230 fn key(&self) -> &[u8] {
231 &self.key
232 }
233
234 async fn clear(&self) -> Result<()> {
235 Ok(())
236 }
237}
238
239pub struct ReducingState<B, F>
241where
242 B: StateBackend,
243 F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
244{
245 value_state: ValueState<B>,
246 reduce_fn: Arc<F>,
247}
248
249impl<B, F> ReducingState<B, F>
250where
251 B: StateBackend,
252 F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
253{
254 pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>, reduce_fn: F) -> Self {
256 Self {
257 value_state: ValueState::new(backend, namespace, key),
258 reduce_fn: Arc::new(reduce_fn),
259 }
260 }
261
262 pub async fn get(&self) -> Result<Option<Vec<u8>>> {
264 self.value_state.get().await
265 }
266
267 pub async fn add(&self, value: Vec<u8>) -> Result<()> {
269 let reduce_fn = self.reduce_fn.clone();
270 self.value_state
271 .update(move |current| {
272 if let Some(existing) = current {
273 reduce_fn(existing, value)
274 } else {
275 value
276 }
277 })
278 .await
279 }
280}
281
282impl<B, F> KeyedState for ReducingState<B, F>
283where
284 B: StateBackend,
285 F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
286{
287 fn key(&self) -> &[u8] {
288 self.value_state.key()
289 }
290
291 async fn clear(&self) -> Result<()> {
292 self.value_state.clear().await
293 }
294}
295
296pub struct AggregatingState<B, F>
298where
299 B: StateBackend,
300 F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
301{
302 value_state: ValueState<B>,
303 aggregate_fn: Arc<F>,
304}
305
306impl<B, F> AggregatingState<B, F>
307where
308 B: StateBackend,
309 F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
310{
311 pub fn new(backend: Arc<B>, namespace: String, key: Vec<u8>, aggregate_fn: F) -> Self {
313 Self {
314 value_state: ValueState::new(backend, namespace, key),
315 aggregate_fn: Arc::new(aggregate_fn),
316 }
317 }
318
319 pub async fn get(&self) -> Result<Option<Vec<u8>>> {
321 self.value_state.get().await
322 }
323
324 pub async fn add(&self, value: Vec<u8>) -> Result<()> {
326 let aggregate_fn = self.aggregate_fn.clone();
327 self.value_state
328 .update(move |current| {
329 if let Some(existing) = current {
330 aggregate_fn(existing, value)
331 } else {
332 value
333 }
334 })
335 .await
336 }
337}
338
339impl<B, F> KeyedState for AggregatingState<B, F>
340where
341 B: StateBackend,
342 F: Fn(Vec<u8>, Vec<u8>) -> Vec<u8> + Send + Sync,
343{
344 fn key(&self) -> &[u8] {
345 self.value_state.key()
346 }
347
348 async fn clear(&self) -> Result<()> {
349 self.value_state.clear().await
350 }
351}
352
353#[cfg(test)]
354mod tests {
355 use super::*;
356 use crate::state::backend::MemoryStateBackend;
357
358 #[tokio::test]
359 async fn test_value_state() {
360 let backend = Arc::new(MemoryStateBackend::new());
361 let state = ValueState::new(backend, "test".to_string(), vec![1]);
362
363 state
364 .set(vec![42])
365 .await
366 .expect("Failed to set value in value state");
367 let value = state
368 .get()
369 .await
370 .expect("Failed to get value from value state");
371 assert_eq!(value, Some(vec![42]));
372
373 state.clear().await.expect("Failed to clear value state");
374 let value = state.get().await.expect("Failed to get value after clear");
375 assert_eq!(value, None);
376 }
377
378 #[tokio::test]
379 async fn test_list_state() {
380 let backend = Arc::new(MemoryStateBackend::new());
381 let state = ListState::new(backend, "test".to_string(), vec![1]);
382
383 state
384 .add(vec![1])
385 .await
386 .expect("Failed to add first item to list state");
387 state
388 .add(vec![2])
389 .await
390 .expect("Failed to add second item to list state");
391 state
392 .add(vec![3])
393 .await
394 .expect("Failed to add third item to list state");
395
396 let list = state
397 .get()
398 .await
399 .expect("Failed to get list from list state");
400 assert_eq!(list, vec![vec![1], vec![2], vec![3]]);
401 }
402
403 #[tokio::test]
404 async fn test_map_state() {
405 let backend = Arc::new(MemoryStateBackend::new());
406 let state = MapState::new(backend, "test".to_string(), vec![1]);
407
408 state
409 .put(b"key1", vec![1])
410 .await
411 .expect("Failed to put key1 in map state");
412 state
413 .put(b"key2", vec![2])
414 .await
415 .expect("Failed to put key2 in map state");
416
417 assert_eq!(
418 state
419 .get(b"key1")
420 .await
421 .expect("Failed to get key1 from map state"),
422 Some(vec![1])
423 );
424 assert_eq!(
425 state
426 .get(b"key2")
427 .await
428 .expect("Failed to get key2 from map state"),
429 Some(vec![2])
430 );
431
432 assert!(
433 state
434 .contains(b"key1")
435 .await
436 .expect("Failed to check if map contains key1")
437 );
438
439 state
440 .remove(b"key1")
441 .await
442 .expect("Failed to remove key1 from map state");
443 assert!(
444 !state
445 .contains(b"key1")
446 .await
447 .expect("Failed to check if map contains key1 after removal")
448 );
449 }
450
451 #[tokio::test]
452 async fn test_reducing_state() {
453 let backend = Arc::new(MemoryStateBackend::new());
454 let state = ReducingState::new(backend, "test".to_string(), vec![1], |a, b| {
455 let v1 = i64::from_le_bytes(a.try_into().unwrap_or([0; 8]));
456 let v2 = i64::from_le_bytes(b.try_into().unwrap_or([0; 8]));
457 (v1 + v2).to_le_bytes().to_vec()
458 });
459
460 state
461 .add(5i64.to_le_bytes().to_vec())
462 .await
463 .expect("Failed to add first value to reducing state");
464 state
465 .add(3i64.to_le_bytes().to_vec())
466 .await
467 .expect("Failed to add second value to reducing state");
468
469 let result = state
470 .get()
471 .await
472 .expect("Failed to get value from reducing state")
473 .expect("Expected Some value from reducing state");
474 let value = i64::from_le_bytes(result.try_into().unwrap_or([0; 8]));
475 assert_eq!(value, 8);
476 }
477}