sage_runtime/persistence/
mod.rs1#[cfg(any(
19 feature = "persistence-sqlite",
20 feature = "persistence-postgres",
21 feature = "persistence-file"
22))]
23mod backends;
24
25#[cfg(feature = "persistence-sqlite")]
26pub use backends::SyncSqliteStore;
27#[cfg(feature = "persistence-postgres")]
28pub use backends::SyncPostgresStore;
29#[cfg(feature = "persistence-file")]
30pub use backends::SyncFileStore;
31
32use serde::{de::DeserializeOwned, Serialize};
33use std::collections::HashMap;
34use std::sync::{Arc, RwLock};
35
36pub trait CheckpointStore: Send + Sync {
41 fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value);
43
44 fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value>;
46
47 fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value>;
49
50 fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>);
52
53 fn exists_sync(&self, agent_key: &str) -> bool;
55}
56
57#[derive(Default)]
59pub struct MemoryCheckpointStore {
60 data: RwLock<HashMap<String, HashMap<String, serde_json::Value>>>,
61}
62
63impl MemoryCheckpointStore {
64 pub fn new() -> Self {
65 Self::default()
66 }
67}
68
69impl CheckpointStore for MemoryCheckpointStore {
70 fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value) {
71 let mut data = self.data.write().unwrap();
72 data.entry(agent_key.to_string())
73 .or_default()
74 .insert(field.to_string(), value);
75 }
76
77 fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value> {
78 self.data
79 .read()
80 .unwrap()
81 .get(agent_key)
82 .and_then(|fields| fields.get(field).cloned())
83 }
84
85 fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value> {
86 self.data
87 .read()
88 .unwrap()
89 .get(agent_key)
90 .cloned()
91 .unwrap_or_default()
92 }
93
94 fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>) {
95 let mut data = self.data.write().unwrap();
96 data.insert(agent_key.to_string(), fields.clone());
97 }
98
99 fn exists_sync(&self, agent_key: &str) -> bool {
100 self.data.read().unwrap().contains_key(agent_key)
101 }
102}
103
104pub struct Persisted<T> {
109 value: RwLock<T>,
110 store: Arc<dyn CheckpointStore>,
111 agent_key: String,
112 field_name: String,
113}
114
115impl<T: Clone + Serialize + DeserializeOwned + Default + Send> Persisted<T> {
116 pub fn new(
118 store: Arc<dyn CheckpointStore>,
119 agent_key: impl Into<String>,
120 field_name: impl Into<String>,
121 ) -> Self {
122 let agent_key = agent_key.into();
123 let field_name = field_name.into();
124
125 let value = store
127 .load_sync(&agent_key, &field_name)
128 .and_then(|v| serde_json::from_value(v).ok())
129 .unwrap_or_default();
130
131 Self {
132 value: RwLock::new(value),
133 store,
134 agent_key,
135 field_name,
136 }
137 }
138
139 pub fn with_initial(
141 store: Arc<dyn CheckpointStore>,
142 agent_key: impl Into<String>,
143 field_name: impl Into<String>,
144 initial: T,
145 ) -> Self {
146 let agent_key = agent_key.into();
147 let field_name = field_name.into();
148
149 let value = store
151 .load_sync(&agent_key, &field_name)
152 .and_then(|v| serde_json::from_value(v).ok())
153 .unwrap_or(initial);
154
155 Self {
156 value: RwLock::new(value),
157 store,
158 agent_key,
159 field_name,
160 }
161 }
162
163 pub fn get(&self) -> T {
165 self.value.read().unwrap().clone()
166 }
167
168 pub fn set(&self, new_value: T) {
170 *self.value.write().unwrap() = new_value.clone();
171 if let Ok(json) = serde_json::to_value(&new_value) {
172 self.store.save_sync(&self.agent_key, &self.field_name, json);
173 }
174 }
175
176 pub fn checkpoint(&self) {
178 let value = self.value.read().unwrap().clone();
179 if let Ok(json) = serde_json::to_value(&value) {
180 self.store.save_sync(&self.agent_key, &self.field_name, json);
181 }
182 }
183}
184
185pub fn agent_checkpoint_key(agent_name: &str, beliefs: &serde_json::Value) -> String {
187 use std::collections::hash_map::DefaultHasher;
188 use std::hash::{Hash, Hasher};
189
190 let mut hasher = DefaultHasher::new();
191 agent_name.hash(&mut hasher);
192 beliefs.to_string().hash(&mut hasher);
193 format!("{}_{:016x}", agent_name, hasher.finish())
194}
195
196pub fn checkpoint_all<S: CheckpointStore + ?Sized>(
198 store: &S,
199 agent_key: &str,
200 fields: Vec<(&str, serde_json::Value)>,
201) {
202 let map: HashMap<String, serde_json::Value> = fields
203 .into_iter()
204 .map(|(k, v)| (k.to_string(), v))
205 .collect();
206 store.save_all_sync(agent_key, &map);
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory store, shared via `Arc` the same way `Persisted` holds it.
    fn make_store() -> Arc<dyn CheckpointStore> {
        Arc::new(MemoryCheckpointStore::new())
    }

    #[test]
    fn memory_store_save_load() {
        let store = MemoryCheckpointStore::new();
        store.save_sync("agent1", "count", serde_json::json!(42));

        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn memory_store_missing_agent_is_empty() {
        let store = MemoryCheckpointStore::new();
        assert_eq!(store.load_sync("agent1", "count"), None);
        assert!(store.load_all_sync("agent1").is_empty());
        assert!(!store.exists_sync("agent1"));
    }

    #[test]
    fn checkpoint_all_replaces_agent_fields() {
        let store = MemoryCheckpointStore::new();
        store.save_sync("agent1", "stale", serde_json::json!(1));

        checkpoint_all(&store, "agent1", vec![("count", serde_json::json!(7))]);

        let all = store.load_all_sync("agent1");
        assert_eq!(all.len(), 1);
        assert_eq!(all.get("count"), Some(&serde_json::json!(7)));
        assert!(store.exists_sync("agent1"));
    }

    #[test]
    fn persisted_field_loads_from_checkpoint() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!(100));

        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 100);
    }

    #[test]
    fn persisted_field_defaults_when_no_checkpoint() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 0);
    }

    #[test]
    fn persisted_field_defaults_when_checkpoint_is_undeserializable() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!("not a number"));

        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 0);
    }

    #[test]
    fn persisted_with_initial_uses_initial_without_checkpoint() {
        let store = make_store();
        let field = Persisted::with_initial(store, "agent1", "count", 5_i64);
        assert_eq!(field.get(), 5);
    }

    #[test]
    fn persisted_with_initial_prefers_checkpoint() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!(9));

        let field = Persisted::with_initial(store, "agent1", "count", 5_i64);
        assert_eq!(field.get(), 9);
    }

    #[test]
    fn persisted_field_auto_checkpoints_on_set() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(Arc::clone(&store), "agent1", "count");

        field.set(42);

        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn persisted_checkpoint_rewrites_current_value() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(Arc::clone(&store), "agent1", "count");
        field.set(3);
        // Simulate the stored copy drifting from the in-memory value.
        store.save_sync("agent1", "count", serde_json::json!(100));

        field.checkpoint();

        assert_eq!(store.load_sync("agent1", "count"), Some(serde_json::json!(3)));
    }

    #[test]
    fn checkpoint_key_varies_with_beliefs() {
        let key1 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 1}));
        let key2 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 2}));
        assert_ne!(key1, key2);
    }

    #[test]
    fn checkpoint_key_varies_with_name() {
        let beliefs = serde_json::json!({"x": 1});
        assert_ne!(
            agent_checkpoint_key("Alpha", &beliefs),
            agent_checkpoint_key("Beta", &beliefs)
        );
    }

    #[test]
    fn checkpoint_key_is_deterministic_and_well_formed() {
        let beliefs = serde_json::json!({"x": 1});
        let key = agent_checkpoint_key("Agent", &beliefs);

        assert_eq!(key, agent_checkpoint_key("Agent", &beliefs));
        assert!(key.starts_with("Agent_"));
        assert_eq!(key.len(), "Agent_".len() + 16);
    }
}