awaken_runtime/state/
persistence.rs1use std::sync::Arc;
2
3use crate::state::KeyScope;
4use awaken_contract::{StateError, UnknownKeyPolicy};
5
6use super::{PersistedState, StateMap, StateStore};
7
8impl StateStore {
9 pub fn export_persisted(&self) -> Result<PersistedState, StateError> {
10 let registry = self.registry.lock();
11 let state = self.inner.read();
12 let mut extensions = std::collections::HashMap::new();
13
14 for reg in registry.keys_by_type.values() {
15 if !reg.options.persistent {
16 continue;
17 }
18
19 if let Some(json) = (reg.export)(state.ext.as_ref()).map_err(|err| match err {
20 StateError::KeyEncode { key, message } => StateError::KeyEncode { key, message },
21 other => StateError::KeyEncode {
22 key: reg.key.clone(),
23 message: other.to_string(),
24 },
25 })? {
26 extensions.insert(reg.key.clone(), json);
27 }
28 }
29
30 Ok(PersistedState {
31 revision: state.revision,
32 extensions,
33 })
34 }
35
36 pub fn restore_persisted(
37 &self,
38 persisted: PersistedState,
39 unknown_policy: UnknownKeyPolicy,
40 ) -> Result<(), StateError> {
41 let registry = self.registry.lock();
42 let mut next_ext = StateMap::default();
43
44 for (key, json) in persisted.extensions {
45 let Some(reg) = registry.keys_by_name.get(&key) else {
46 match unknown_policy {
47 UnknownKeyPolicy::Error => return Err(StateError::UnknownKey { key }),
48 UnknownKeyPolicy::Skip => continue,
49 }
50 };
51
52 (reg.import)(&mut next_ext, json).map_err(|err| match err {
53 StateError::KeyDecode { key, message } => StateError::KeyDecode { key, message },
54 other => StateError::KeyDecode {
55 key: reg.key.clone(),
56 message: other.to_string(),
57 },
58 })?;
59 }
60
61 let mut state = self.inner.write();
62 state.ext = Arc::new(next_ext);
63 state.revision = persisted.revision;
64 Ok(())
65 }
66
67 pub fn restore_thread_scoped(
71 &self,
72 persisted: PersistedState,
73 unknown_policy: UnknownKeyPolicy,
74 ) -> Result<(), StateError> {
75 let registry = self.registry.lock();
76 let mut state = self.inner.write();
77 let ext = Arc::make_mut(&mut state.ext);
78
79 for (key, json) in persisted.extensions {
80 let Some(reg) = registry.keys_by_name.get(&key) else {
81 match unknown_policy {
82 UnknownKeyPolicy::Error => return Err(StateError::UnknownKey { key }),
83 UnknownKeyPolicy::Skip => continue,
84 }
85 };
86
87 if reg.scope != KeyScope::Thread {
88 continue;
89 }
90
91 (reg.import)(ext, json).map_err(|err| match err {
92 StateError::KeyDecode { key, message } => StateError::KeyDecode { key, message },
93 other => StateError::KeyDecode {
94 key: reg.key.clone(),
95 message: other.to_string(),
96 },
97 })?;
98 }
99
100 Ok(())
101 }
102
103 pub fn clear_run_scoped(&self) {
105 let registry = self.registry.lock();
106 let mut state = self.inner.write();
107 let ext = Arc::make_mut(&mut state.ext);
108
109 for reg in registry.keys_by_type.values() {
110 if reg.scope == KeyScope::Run {
111 (reg.clear)(ext);
112 }
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
    use crate::state::{StateKey, StateKeyOptions};
    use awaken_contract::UnknownKeyPolicy;

    /// Additive integer key; the test plugin registers it as persistent.
    struct PersistentCounter;

    impl StateKey for PersistentCounter {
        const KEY: &'static str = "test.persist_counter";
        type Value = i64;
        type Update = i64;

        fn apply(current: &mut Self::Value, delta: Self::Update) {
            *current += delta;
        }
    }

    /// Overwriting boolean key; the test plugin registers it as non-persistent.
    struct TransientFlag;

    impl StateKey for TransientFlag {
        const KEY: &'static str = "test.transient_flag";
        type Value = bool;
        type Update = bool;

        fn apply(current: &mut Self::Value, replacement: Self::Update) {
            *current = replacement;
        }
    }

    /// Plugin that registers both test keys with differing persistence flags.
    struct PersistenceTestPlugin;

    impl Plugin for PersistenceTestPlugin {
        fn descriptor(&self) -> PluginDescriptor {
            PluginDescriptor {
                name: "persistence-test-plugin",
            }
        }

        fn register(
            &self,
            registrar: &mut PluginRegistrar,
        ) -> Result<(), awaken_contract::StateError> {
            registrar.register_key::<PersistentCounter>(key_options(true))?;
            registrar.register_key::<TransientFlag>(key_options(false))?;
            Ok(())
        }
    }

    /// Default key options with only the `persistent` flag overridden.
    fn key_options(persistent: bool) -> StateKeyOptions {
        StateKeyOptions {
            persistent,
            ..Default::default()
        }
    }

    /// Fresh store with [`PersistenceTestPlugin`] already installed.
    fn store_with_plugin() -> StateStore {
        let store = StateStore::new();
        store.install_plugin(PersistenceTestPlugin).unwrap();
        store
    }

    #[test]
    fn export_import_roundtrip() {
        let source = store_with_plugin();

        let mut pending = source.begin_mutation();
        pending.update::<PersistentCounter>(42);
        source.commit(pending).unwrap();

        let snapshot = source.export_persisted().unwrap();

        // Restore into an independent store and read the value back.
        let target = store_with_plugin();
        target
            .restore_persisted(snapshot, UnknownKeyPolicy::Error)
            .unwrap();

        let restored = target.read::<PersistentCounter>().unwrap();
        assert_eq!(restored, 42);
    }

    #[test]
    fn export_skips_non_persistent_keys() {
        let store = store_with_plugin();

        let mut pending = store.begin_mutation();
        pending.update::<PersistentCounter>(10);
        pending.update::<TransientFlag>(true);
        store.commit(pending).unwrap();

        let snapshot = store.export_persisted().unwrap();
        let exported_keys = &snapshot.extensions;

        assert!(
            exported_keys.contains_key(PersistentCounter::KEY),
            "persistent key should be exported"
        );
        assert!(
            !exported_keys.contains_key(TransientFlag::KEY),
            "non-persistent key should NOT be exported"
        );
    }

    #[test]
    fn import_unknown_key_with_skip_policy() {
        let store = store_with_plugin();

        // A snapshot whose only entry refers to a key nobody registered.
        let persisted = PersistedState {
            revision: 5,
            extensions: std::collections::HashMap::from([(
                "unknown.key".to_string(),
                serde_json::json!("some_value"),
            )]),
        };

        let outcome = store.restore_persisted(persisted, UnknownKeyPolicy::Skip);
        assert!(
            outcome.is_ok(),
            "skip policy should not error on unknown keys"
        );
    }
}