Skip to main content

awaken_runtime/state/
persistence.rs

1use std::sync::Arc;
2
3use crate::state::KeyScope;
4use awaken_contract::{StateError, UnknownKeyPolicy};
5
6use super::{PersistedState, StateMap, StateStore};
7
8impl StateStore {
9    pub fn export_persisted(&self) -> Result<PersistedState, StateError> {
10        let registry = self.registry.lock();
11        let state = self.inner.read();
12        let mut extensions = std::collections::HashMap::new();
13
14        for reg in registry.keys_by_type.values() {
15            if !reg.options.persistent {
16                continue;
17            }
18
19            if let Some(json) = (reg.export)(state.ext.as_ref()).map_err(|err| match err {
20                StateError::KeyEncode { key, message } => StateError::KeyEncode { key, message },
21                other => StateError::KeyEncode {
22                    key: reg.key.clone(),
23                    message: other.to_string(),
24                },
25            })? {
26                extensions.insert(reg.key.clone(), json);
27            }
28        }
29
30        Ok(PersistedState {
31            revision: state.revision,
32            extensions,
33        })
34    }
35
36    pub fn restore_persisted(
37        &self,
38        persisted: PersistedState,
39        unknown_policy: UnknownKeyPolicy,
40    ) -> Result<(), StateError> {
41        let registry = self.registry.lock();
42        let mut next_ext = StateMap::default();
43
44        for (key, json) in persisted.extensions {
45            let Some(reg) = registry.keys_by_name.get(&key) else {
46                match unknown_policy {
47                    UnknownKeyPolicy::Error => return Err(StateError::UnknownKey { key }),
48                    UnknownKeyPolicy::Skip => continue,
49                }
50            };
51
52            (reg.import)(&mut next_ext, json).map_err(|err| match err {
53                StateError::KeyDecode { key, message } => StateError::KeyDecode { key, message },
54                other => StateError::KeyDecode {
55                    key: reg.key.clone(),
56                    message: other.to_string(),
57                },
58            })?;
59        }
60
61        let mut state = self.inner.write();
62        state.ext = Arc::new(next_ext);
63        state.revision = persisted.revision;
64        Ok(())
65    }
66
67    /// Restore only `Thread`-scoped keys from a persisted state snapshot.
68    ///
69    /// Run-scoped keys in `persisted` are ignored. Unknown keys follow `unknown_policy`.
70    pub fn restore_thread_scoped(
71        &self,
72        persisted: PersistedState,
73        unknown_policy: UnknownKeyPolicy,
74    ) -> Result<(), StateError> {
75        let registry = self.registry.lock();
76        let mut state = self.inner.write();
77        let ext = Arc::make_mut(&mut state.ext);
78
79        for (key, json) in persisted.extensions {
80            let Some(reg) = registry.keys_by_name.get(&key) else {
81                match unknown_policy {
82                    UnknownKeyPolicy::Error => return Err(StateError::UnknownKey { key }),
83                    UnknownKeyPolicy::Skip => continue,
84                }
85            };
86
87            if reg.scope != KeyScope::Thread {
88                continue;
89            }
90
91            (reg.import)(ext, json).map_err(|err| match err {
92                StateError::KeyDecode { key, message } => StateError::KeyDecode { key, message },
93                other => StateError::KeyDecode {
94                    key: reg.key.clone(),
95                    message: other.to_string(),
96                },
97            })?;
98        }
99
100        Ok(())
101    }
102
103    /// Clear all `Run`-scoped keys, preserving `Thread`-scoped keys.
104    pub fn clear_run_scoped(&self) {
105        let registry = self.registry.lock();
106        let mut state = self.inner.write();
107        let ext = Arc::make_mut(&mut state.ext);
108
109        for reg in registry.keys_by_type.values() {
110            if reg.scope == KeyScope::Run {
111                (reg.clear)(ext);
112            }
113        }
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::plugins::{Plugin, PluginDescriptor, PluginRegistrar};
    use crate::state::{StateKey, StateKeyOptions};
    use awaken_contract::UnknownKeyPolicy;

    /// Persistent, additive counter.
    struct PersistentCounter;

    impl StateKey for PersistentCounter {
        const KEY: &'static str = "test.persist_counter";
        type Value = i64;
        type Update = i64;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value += update;
        }
    }

    /// Non-persistent flag; must never appear in an export.
    struct TransientFlag;

    impl StateKey for TransientFlag {
        const KEY: &'static str = "test.transient_flag";
        type Value = bool;
        type Update = bool;

        fn apply(value: &mut Self::Value, update: Self::Update) {
            *value = update;
        }
    }

    struct PersistenceTestPlugin;

    impl Plugin for PersistenceTestPlugin {
        fn descriptor(&self) -> PluginDescriptor {
            PluginDescriptor {
                name: "persistence-test-plugin",
            }
        }

        fn register(
            &self,
            registrar: &mut PluginRegistrar,
        ) -> Result<(), awaken_contract::StateError> {
            registrar.register_key::<PersistentCounter>(StateKeyOptions {
                persistent: true,
                ..Default::default()
            })?;
            registrar.register_key::<TransientFlag>(StateKeyOptions {
                persistent: false,
                ..Default::default()
            })?;
            Ok(())
        }
    }

    /// Build a store with the test plugin installed and `PersistentCounter` set to `value`.
    fn store_with_counter(value: i64) -> StateStore {
        let store = StateStore::new();
        store.install_plugin(PersistenceTestPlugin).unwrap();
        let mut batch = store.begin_mutation();
        batch.update::<PersistentCounter>(value);
        store.commit(batch).unwrap();
        store
    }

    #[test]
    fn export_import_roundtrip() {
        let store = StateStore::new();
        store.install_plugin(PersistenceTestPlugin).unwrap();

        let mut batch = store.begin_mutation();
        batch.update::<PersistentCounter>(42);
        store.commit(batch).unwrap();

        let exported = store.export_persisted().unwrap();

        // Create a new store, install the same plugin, and restore into it.
        let store2 = StateStore::new();
        store2.install_plugin(PersistenceTestPlugin).unwrap();
        store2
            .restore_persisted(exported, UnknownKeyPolicy::Error)
            .unwrap();

        let val = store2.read::<PersistentCounter>().unwrap();
        assert_eq!(val, 42);
    }

    #[test]
    fn export_skips_non_persistent_keys() {
        let store = StateStore::new();
        store.install_plugin(PersistenceTestPlugin).unwrap();

        let mut batch = store.begin_mutation();
        batch.update::<PersistentCounter>(10);
        batch.update::<TransientFlag>(true);
        store.commit(batch).unwrap();

        let exported = store.export_persisted().unwrap();

        // Only the persistent key should be in the export.
        assert!(
            exported.extensions.contains_key(PersistentCounter::KEY),
            "persistent key should be exported"
        );
        assert!(
            !exported.extensions.contains_key(TransientFlag::KEY),
            "non-persistent key should NOT be exported"
        );
    }

    #[test]
    fn import_unknown_key_with_skip_policy() {
        // Start from a genuine export so the known entry uses the real encoding,
        // then add an entry no plugin registered.
        let mut persisted = store_with_counter(7).export_persisted().unwrap();
        persisted
            .extensions
            .insert("unknown.key".to_string(), serde_json::json!("some_value"));

        let store = StateStore::new();
        store.install_plugin(PersistenceTestPlugin).unwrap();

        // Skip must ignore the unknown entry and still import the known ones.
        store
            .restore_persisted(persisted, UnknownKeyPolicy::Skip)
            .expect("skip policy should not error on unknown keys");
        assert_eq!(
            store.read::<PersistentCounter>().unwrap(),
            7,
            "known keys must still be restored under the skip policy"
        );
    }

    #[test]
    fn import_unknown_key_with_error_policy_leaves_store_untouched() {
        let mut persisted = store_with_counter(99).export_persisted().unwrap();
        persisted
            .extensions
            .insert("unknown.key".to_string(), serde_json::json!("some_value"));

        let store = store_with_counter(7);
        let result = store.restore_persisted(persisted, UnknownKeyPolicy::Error);

        assert!(
            matches!(
                result,
                Err(awaken_contract::StateError::UnknownKey { .. })
            ),
            "error policy should reject unknown keys"
        );
        assert_eq!(
            store.read::<PersistentCounter>().unwrap(),
            7,
            "a failed restore must not modify the store"
        );
    }
}
238}