1#![doc = include_str!("README.md")]
2
3use std::ops::Deref;
4
5use crate::errors::LoadStateError;
6use crate::registry::IdentifiableType;
7use crate::{ServiceObject, WithId};
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11
12#[cfg(feature = "local")]
13pub mod local;
14
15#[cfg(feature = "redis")]
16pub mod redis;
17
18#[cfg(feature = "sqlite")]
22pub mod sqlite;
23
24#[cfg(feature = "postgres")]
25pub mod postgres;
26
27#[async_trait]
31pub trait State<T> {
32 fn get_state(&self) -> &T;
33 fn set_state(&mut self, value: T);
34
35 async fn load<S: StateLoader<T> + Sync + Send>(
36 &self,
37 state_loader: &S,
38 object_kind: &str,
39 object_id: &str,
40 state_type: &str,
41 ) -> Result<T, LoadStateError> {
42 state_loader.load(object_kind, object_id, state_type).await
43 }
44}
45
46#[async_trait]
53pub trait StateLoader<T>: Send + Sync {
54 async fn prepare(&self) {}
64
65 async fn load(
66 &self,
67 object_kind: &str,
68 object_id: &str,
69 state_type: &str,
70 ) -> Result<T, LoadStateError>;
71}
72
73#[async_trait]
78impl<O, T, S> StateLoader<O> for T
79where
80 T: Deref<Target = S> + Send + Sync,
81 S: StateLoader<O>,
82 O: DeserializeOwned,
83{
84 async fn load(
85 &self,
86 object_kind: &str,
87 object_id: &str,
88 state_type: &str,
89 ) -> Result<O, LoadStateError> {
90 self.deref().load(object_kind, object_id, state_type).await
91 }
92}
93
94#[async_trait]
103pub trait StateSaver<T>: Sync + Send {
104 async fn prepare(&self) {}
105
106 async fn save(
107 &self,
108 object_kind: &str,
109 object_id: &str,
110 state_type: &str,
111 data: &T,
112 ) -> Result<(), LoadStateError>;
113}
114
115#[async_trait]
120impl<O, T, S> StateSaver<O> for T
121where
122 T: Deref<Target = S> + Send + Sync,
123 S: StateSaver<O>,
124 O: Serialize + Send + Sync,
125{
126 async fn save(
127 &self,
128 object_kind: &str,
129 object_id: &str,
130 state_type: &str,
131 data: &O,
132 ) -> Result<(), LoadStateError> {
133 self.deref()
134 .save(object_kind, object_id, state_type, data)
135 .await
136 }
137}
138
139#[async_trait]
143pub trait ObjectStateManager {
144 async fn load_state<T, S>(&mut self, state_loader: &S) -> Result<(), LoadStateError>
147 where
148 T: IdentifiableType + Serialize + DeserializeOwned + Default, S: StateLoader<T> + Send + Sync,
150 Self: State<T> + IdentifiableType + WithId + Send + Sync,
151 {
152 let object_kind = Self::user_defined_type_id();
153 let object_id = self.id();
154 let state_type = T::user_defined_type_id();
155 let data: T = self
156 .load(state_loader, object_kind, object_id, state_type)
157 .await
158 .or(Err(LoadStateError::ObjectNotFound))?;
159
160 self.set_state(data);
161 Ok(())
162 }
163
164 async fn save_state<T, S>(&self, state_saver: &S) -> Result<(), LoadStateError>
166 where
167 T: IdentifiableType + Serialize + DeserializeOwned + Sync + Default, S: StateSaver<T>,
169 Self: State<T> + IdentifiableType + WithId + Send + Sync,
170 {
171 let object_kind = Self::user_defined_type_id();
172 let object_id = self.id();
173
174 let state_type = T::user_defined_type_id();
175 let state_value: &T = self.get_state();
176 state_saver
177 .save(object_kind, object_id, state_type, state_value)
178 .await?;
179 Ok(())
180 }
181}
182
183impl<T> ObjectStateManager for T where T: ServiceObject {}
185
// Both tests use `local::LocalState`, which only exists when the `local`
// feature is enabled, so the module is compiled only in that configuration.
#[cfg(all(test, feature = "local"))]
mod test {
    use super::*;
    use rio_macros::{ManagedState, TypeName, WithId};
    use serde::Deserialize;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
    struct PersonState {
        name: String,
        age: u8,
    }

    impl IdentifiableType for Option<PersonState> {
        fn instance_type_id(&self) -> &'static str {
            "OptionPersonState"
        }
    }

    #[derive(Default, Debug, Serialize, Deserialize, PartialEq, TypeName)]
    #[rio_path = "crate"]
    struct LegalPersonState {
        legal_name: String,
        id_document: String,
    }

    /// A value saved through `LocalState` loads back unchanged.
    #[tokio::test]
    async fn sanity_check() -> TestResult {
        let local_state = local::LocalState::new();
        let state = PersonState {
            name: "Foo".to_string(),
            age: 21,
        };
        local_state.save("a", "1", "PersonState", &state).await?;
        let new_state: PersonState = local_state.load("a", "1", "PersonState").await?;
        assert_eq!(state, new_state);
        Ok(())
    }

    /// Every managed state of an object round-trips through
    /// `ObjectStateManager::save_state` / `load_state`.
    #[tokio::test]
    async fn model_call() -> TestResult {
        #[derive(Debug, Default, WithId, TypeName, ManagedState)]
        #[rio_path = "crate"]
        struct Person {
            id: String,
            #[managed_state]
            person_state: Option<PersonState>,
            #[managed_state]
            legal_state: LegalPersonState,
        }
        impl ObjectStateManager for Person {}

        impl Person {
            async fn load_all_states(
                &mut self,
                state_loader: &local::LocalState,
            ) -> Result<(), LoadStateError> {
                self.load_state::<Option<PersonState>, _>(state_loader)
                    .await?;
                self.load_state::<LegalPersonState, _>(state_loader).await?;
                Ok(())
            }

            // Saving only reads the managed fields, so a shared borrow is enough.
            async fn save_all_states(
                &self,
                state_saver: &local::LocalState,
            ) -> Result<(), LoadStateError> {
                self.save_state::<Option<PersonState>, _>(state_saver)
                    .await?;
                self.save_state::<LegalPersonState, _>(state_saver).await?;
                Ok(())
            }
        }

        // Fixture builders: the states derive neither `Clone` nor `Copy`, so
        // the expected values are rebuilt for each comparison.
        fn expected_person_state() -> Option<PersonState> {
            Some(PersonState {
                name: "Foo".to_string(),
                age: 22,
            })
        }

        fn expected_legal_state() -> LegalPersonState {
            LegalPersonState {
                legal_name: "Foo Bla".to_string(),
                id_document: "123.123.123-12".to_string(),
            }
        }

        let local_state = local::LocalState::new();

        {
            let person = Person {
                person_state: expected_person_state(),
                legal_state: expected_legal_state(),
                ..Default::default()
            };
            person.save_all_states(&local_state).await?;
        }
        {
            let mut person = Person::default();
            person.load_all_states(&local_state).await?;
            // Compare whole values so that every field is checked.
            assert_eq!(person.person_state, expected_person_state());
            assert_eq!(person.legal_state, expected_legal_state());
        }
        Ok(())
    }
}