1use super::{StateLoader, StateSaver};
2use crate::errors::LoadStateError;
3use async_trait::async_trait;
4use dashmap::DashMap;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7
8#[derive(Debug, Default)]
12pub struct LocalState {
13 data: DashMap<(String, String, String), String>,
14}
15
16impl LocalState {
17 pub fn new() -> LocalState {
18 LocalState {
19 data: DashMap::new(),
20 }
21 }
22}
23
24#[async_trait]
25impl<T: DeserializeOwned> StateLoader<T> for LocalState {
26 async fn load(
27 &self,
28 object_kind: &str,
29 object_id: &str,
30 state_type: &str,
31 ) -> Result<T, LoadStateError> {
32 let object_kind = object_kind.to_string();
33 let object_id = object_id.to_string();
34 let state_type = state_type.to_string();
35 let k = (object_kind, object_id, state_type);
36
37 if let Some(x) = self.data.get(&k) {
38 Ok(serde_json::from_str(&x).expect("TODO"))
39 } else {
40 Err(LoadStateError::ObjectNotFound)
41 }
42 }
43}
44
45#[async_trait]
46impl<T: Serialize + Send + Sync> StateSaver<T> for LocalState {
47 async fn save(
48 &self,
49 object_kind: &str,
50 object_id: &str,
51 state_type: &str,
52 data: &T,
53 ) -> Result<(), LoadStateError> {
54 let object_kind = object_kind.to_string();
55 let object_id = object_id.to_string();
56 let state_type = state_type.to_string();
57 let k = (object_kind, object_id, state_type);
58 self.data
59 .insert(k, serde_json::to_string(&data).expect("TODO"));
60 Ok(())
61 }
62}
63
#[cfg(test)]
mod test {
    use rio_macros::TypeName;
    use serde::Deserialize;

    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    #[derive(TypeName, Debug, Serialize, Deserialize, PartialEq)]
    #[rio_path = "crate"]
    struct TestState {
        name: String,
    }

    /// A saved state can be loaded back unchanged under the same key.
    #[tokio::test]
    async fn sanity_check() -> TestResult {
        let local_state = LocalState::new();
        let state = TestState {
            name: "Foo".to_string(),
        };
        local_state.save("a", "1", "TestState", &state).await?;
        let new_state: TestState = local_state.load("a", "1", "TestState").await?;
        assert_eq!(state, new_state);
        Ok(())
    }

    /// Loading a key that was never saved reports `ObjectNotFound`.
    #[tokio::test]
    async fn load_missing_returns_not_found() {
        let local_state = LocalState::new();
        let result: Result<TestState, LoadStateError> =
            local_state.load("a", "missing", "TestState").await;
        assert!(matches!(result, Err(LoadStateError::ObjectNotFound)));
    }

    /// Saving twice under the same key keeps only the most recent value.
    #[tokio::test]
    async fn save_overwrites_existing_state() -> TestResult {
        let local_state = LocalState::new();
        let first = TestState {
            name: "Foo".to_string(),
        };
        let second = TestState {
            name: "Bar".to_string(),
        };
        local_state.save("a", "1", "TestState", &first).await?;
        local_state.save("a", "1", "TestState", &second).await?;
        let loaded: TestState = local_state.load("a", "1", "TestState").await?;
        assert_eq!(second, loaded);
        Ok(())
    }
}