Skip to main content

rio_rs/state/
mod.rs

1#![doc = include_str!("README.md")]
2
3use std::ops::Deref;
4
5use crate::errors::LoadStateError;
6use crate::registry::IdentifiableType;
7use crate::{ServiceObject, WithId};
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11
12#[cfg(feature = "local")]
13pub mod local;
14
15#[cfg(feature = "redis")]
16pub mod redis;
17
18// #[cfg(feature = "sql")]
19// pub mod sql;
20
21#[cfg(feature = "sqlite")]
22pub mod sqlite;
23
24#[cfg(feature = "postgres")]
25pub mod postgres;
26
27/// Trait to define how to get and set states in and out of an object
28///
29/// One need to implement this trait for each state a object holds
30#[async_trait]
31pub trait State<T> {
32    fn get_state(&self) -> &T;
33    fn set_state(&mut self, value: T);
34
35    async fn load<S: StateLoader<T> + Sync + Send>(
36        &self,
37        state_loader: &S,
38        object_kind: &str,
39        object_id: &str,
40        state_type: &str,
41    ) -> Result<T, LoadStateError> {
42        state_loader.load(object_kind, object_id, state_type).await
43    }
44}
45
46/// The `StateLoader` defines an interface to load serialized state from a source
47///
48/// **important** This trait is not responsible for serializing it back to its
49/// original type
50///
51/// TODO use a reader type instead of String on load fn
52#[async_trait]
53pub trait StateLoader<T>: Send + Sync {
54    /// <div class="warning">
55    /// TODO
56    ///
57    /// This here can't be used right now as it depends on the type `T` when
58    /// called.
59    ///
60    /// This makes its usage quite cluncky (if not impossible) as you need to invoke it
61    /// with a target of this loader
62    /// </div>
63    async fn prepare(&self) {}
64
65    async fn load(
66        &self,
67        object_kind: &str,
68        object_id: &str,
69        state_type: &str,
70    ) -> Result<T, LoadStateError>;
71}
72
73/// Auto implement [StateLoader] for every type that derefs to a [StateLoader]
74///
75/// This way you can create a wrapper for a [StateLoader] and it will automatically
76/// get this implementation
77#[async_trait]
78impl<O, T, S> StateLoader<O> for T
79where
80    T: Deref<Target = S> + Send + Sync,
81    S: StateLoader<O>,
82    O: DeserializeOwned,
83{
84    async fn load(
85        &self,
86        object_kind: &str,
87        object_id: &str,
88        state_type: &str,
89    ) -> Result<O, LoadStateError> {
90        self.deref().load(object_kind, object_id, state_type).await
91    }
92}
93
94/// The `StateSave` defines an interface to save serialized data into a persistence
95/// backend (memory, sql server, etc)
96///
97/// **important** This trait is not responsible for serializing the state from
98/// its original type
99///
100/// TODO it sucks this needs to be generic over T (the type it is persisting)
101///      because then we are forced to do like `StateSaver::<TestState1>::prepare`
102#[async_trait]
103pub trait StateSaver<T>: Sync + Send {
104    async fn prepare(&self) {}
105
106    async fn save(
107        &self,
108        object_kind: &str,
109        object_id: &str,
110        state_type: &str,
111        data: &T,
112    ) -> Result<(), LoadStateError>;
113}
114
115/// Auto implement [StateSaver] for every type that derefs to a StateSaver
116///
117/// This way you can create a wrapper for a [StateSaver] and it will automatically
118/// get this implementation
119#[async_trait]
120impl<O, T, S> StateSaver<O> for T
121where
122    T: Deref<Target = S> + Send + Sync,
123    S: StateSaver<O>,
124    O: Serialize + Send + Sync,
125{
126    async fn save(
127        &self,
128        object_kind: &str,
129        object_id: &str,
130        state_type: &str,
131        data: &O,
132    ) -> Result<(), LoadStateError> {
133        self.deref()
134            .save(object_kind, object_id, state_type, data)
135            .await
136    }
137}
138
139/// Reponsible for managing states for a specific object
140///
141/// With this trait one can load/save individual states from an orig (Self) object
142#[async_trait]
143pub trait ObjectStateManager {
144    /// Load the state from the backend, deserialize it, and map it into the
145    /// right state
146    async fn load_state<T, S>(&mut self, state_loader: &S) -> Result<(), LoadStateError>
147    where
148        T: IdentifiableType + Serialize + DeserializeOwned + Default, // neends default cause of trait State
149        S: StateLoader<T> + Send + Sync,
150        Self: State<T> + IdentifiableType + WithId + Send + Sync,
151    {
152        let object_kind = Self::user_defined_type_id();
153        let object_id = self.id();
154        let state_type = T::user_defined_type_id();
155        let data: T = self
156            .load(state_loader, object_kind, object_id, state_type)
157            .await
158            .or(Err(LoadStateError::ObjectNotFound))?;
159
160        self.set_state(data);
161        Ok(())
162    }
163
164    /// Serialize the data out of the state and save it to the backend provided
165    async fn save_state<T, S>(&self, state_saver: &S) -> Result<(), LoadStateError>
166    where
167        T: IdentifiableType + Serialize + DeserializeOwned + Sync + Default, // Needs default cause of trait State
168        S: StateSaver<T>,
169        Self: State<T> + IdentifiableType + WithId + Send + Sync,
170    {
171        let object_kind = Self::user_defined_type_id();
172        let object_id = self.id();
173
174        let state_type = T::user_defined_type_id();
175        let state_value: &T = self.get_state();
176        state_saver
177            .save(object_kind, object_id, state_type, state_value)
178            .await?;
179        Ok(())
180    }
181}
182
183// If an struct implements ServiceObject, it gets ObjectStateManager out of the box
184impl<T> ObjectStateManager for T where T: ServiceObject {}
185
// The tests use the in-memory `local::LocalState` backend, which only exists
// when the `local` feature is enabled.
#[cfg(all(test, feature = "local"))]
mod test {
    use super::*;
    use rio_macros::{ManagedState, TypeName, WithId};
    use serde::Deserialize;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
    struct PersonState {
        name: String,
        age: u8,
    }

    // `Option<PersonState>` is used directly as a managed state below, and
    // `load_state`/`save_state` require `T: IdentifiableType`. A derive macro
    // can't be applied to the foreign `Option` type, so it is implemented by hand.
    impl IdentifiableType for Option<PersonState> {
        fn instance_type_id(&self) -> &'static str {
            "OptionPersonState"
        }
    }

    #[derive(Default, Debug, Serialize, Deserialize, PartialEq, TypeName)]
    #[rio_path = "crate"]
    struct LegalPersonState {
        legal_name: String,
        id_document: String,
    }

    /// Saving and loading a plain state through `LocalState` round-trips it.
    #[tokio::test]
    async fn sanity_check() -> TestResult {
        let local_state = local::LocalState::new();
        let state = PersonState {
            name: "Foo".to_string(),
            age: 21,
        };
        local_state.save("a", "1", "PersonState", &state).await?;
        let new_state: PersonState = local_state.load("a", "1", "PersonState").await?;
        assert_eq!(state, new_state);
        Ok(())
    }

    /// Managed states saved through `ObjectStateManager` are loaded back
    /// unchanged into a fresh object.
    #[tokio::test]
    async fn model_call() -> TestResult {
        #[derive(Debug, Default, WithId, TypeName, ManagedState)]
        #[rio_path = "crate"]
        struct Person {
            id: String,
            #[managed_state]
            person_state: Option<PersonState>,
            #[managed_state]
            legal_state: LegalPersonState,
        }
        impl ObjectStateManager for Person {}

        impl Person {
            async fn load_all_states(
                &mut self,
                state_loader: &local::LocalState,
            ) -> Result<(), LoadStateError> {
                self.load_state::<Option<PersonState>, _>(state_loader)
                    .await?;
                self.load_state::<LegalPersonState, _>(state_loader).await?;
                Ok(())
            }

            // Saving only reads the states, so a shared borrow is enough.
            // The state type `T` does not appear in `save_state`'s arguments,
            // so it has to be named with a turbofish.
            async fn save_all_states(
                &self,
                state_saver: &local::LocalState,
            ) -> Result<(), LoadStateError> {
                self.save_state::<Option<PersonState>, _>(state_saver)
                    .await?;
                self.save_state::<LegalPersonState, _>(state_saver).await?;
                Ok(())
            }
        }

        let local_state = local::LocalState::new();

        {
            let person = Person {
                person_state: Some(PersonState {
                    name: "Foo".to_string(),
                    age: 22,
                }),
                legal_state: LegalPersonState {
                    legal_name: "Foo Bla".to_string(),
                    id_document: "123.123.123-12".to_string(),
                },
                ..Default::default()
            };
            person.save_all_states(&local_state).await?;
        }
        {
            let mut person = Person::default();
            person.load_all_states(&local_state).await?;
            // Check the full round-trip, not just presence.
            assert_eq!(
                person.person_state,
                Some(PersonState {
                    name: "Foo".to_string(),
                    age: 22,
                })
            );
            assert_eq!(
                person.legal_state,
                LegalPersonState {
                    legal_name: "Foo Bla".to_string(),
                    id_document: "123.123.123-12".to_string(),
                }
            );
        }
        Ok(())
    }
}