Skip to main content

rio_rs/state/
local.rs

1use super::{StateLoader, StateSaver};
2use crate::errors::LoadStateError;
3use async_trait::async_trait;
4use dashmap::DashMap;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7
8/// `LocalState` is a state provider for testing purposes
9///
10/// It stores all the serialized states into a single `DashMap`
11#[derive(Debug, Default)]
12pub struct LocalState {
13    data: DashMap<(String, String, String), String>,
14}
15
16impl LocalState {
17    pub fn new() -> LocalState {
18        LocalState {
19            data: DashMap::new(),
20        }
21    }
22}
23
24#[async_trait]
25impl<T: DeserializeOwned> StateLoader<T> for LocalState {
26    async fn load(
27        &self,
28        object_kind: &str,
29        object_id: &str,
30        state_type: &str,
31    ) -> Result<T, LoadStateError> {
32        let object_kind = object_kind.to_string();
33        let object_id = object_id.to_string();
34        let state_type = state_type.to_string();
35        let k = (object_kind, object_id, state_type);
36
37        if let Some(x) = self.data.get(&k) {
38            Ok(serde_json::from_str(&x).expect("TODO"))
39        } else {
40            Err(LoadStateError::ObjectNotFound)
41        }
42    }
43}
44
45#[async_trait]
46impl<T: Serialize + Send + Sync> StateSaver<T> for LocalState {
47    async fn save(
48        &self,
49        object_kind: &str,
50        object_id: &str,
51        state_type: &str,
52        data: &T,
53    ) -> Result<(), LoadStateError> {
54        let object_kind = object_kind.to_string();
55        let object_id = object_id.to_string();
56        let state_type = state_type.to_string();
57        let k = (object_kind, object_id, state_type);
58        self.data
59            .insert(k, serde_json::to_string(&data).expect("TODO"));
60        Ok(())
61    }
62}
63
#[cfg(test)]
mod test {
    use rio_macros::TypeName;
    use serde::Deserialize;

    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    #[derive(TypeName, Debug, Serialize, Deserialize, PartialEq)]
    #[rio_path = "crate"]
    struct TestState {
        name: String,
    }

    /// A saved state can be loaded back unchanged.
    #[tokio::test]
    async fn sanity_check() -> TestResult {
        let local_state = LocalState::new();
        let state = TestState {
            name: "Foo".to_string(),
        };
        local_state.save("a", "1", "TestState", &state).await?;
        let new_state: TestState = local_state.load("a", "1", "TestState").await?;
        assert_eq!(state, new_state);
        Ok(())
    }

    /// Loading a key that was never saved reports `ObjectNotFound`.
    #[tokio::test]
    async fn load_missing_returns_object_not_found() {
        let local_state = LocalState::new();
        let result: Result<TestState, LoadStateError> =
            local_state.load("a", "1", "TestState").await;
        assert!(matches!(result, Err(LoadStateError::ObjectNotFound)));
    }

    /// Saving twice under the same key keeps only the most recent state.
    #[tokio::test]
    async fn save_overwrites_previous_state() -> TestResult {
        let local_state = LocalState::new();
        let first = TestState {
            name: "Foo".to_string(),
        };
        let second = TestState {
            name: "Bar".to_string(),
        };
        local_state.save("a", "1", "TestState", &first).await?;
        local_state.save("a", "1", "TestState", &second).await?;
        let loaded: TestState = local_state.load("a", "1", "TestState").await?;
        assert_eq!(second, loaded);
        Ok(())
    }
}