Skip to main content

sage_runtime/persistence/
mod.rs

1//! Persistence support for @persistent agent beliefs.
2//!
3//! This module provides the runtime support for persistent agent state:
4//! - `CheckpointStore` trait for storage backends
5//! - `Persisted<T>` wrapper for auto-checkpointing fields
6//! - `AgentCheckpoint` for managing agent-level persistence
7//!
8//! # Backends
9//!
10//! The following backends are available via feature flags:
11//! - `persistence-sqlite`: SQLite database (recommended for local development)
12//! - `persistence-postgres`: PostgreSQL (recommended for production)
13//! - `persistence-file`: JSON files (useful for debugging)
14//!
15//! Without any persistence feature, only `MemoryCheckpointStore` is available.
16
17// Sync adapters for async persistence backends
18#[cfg(any(
19    feature = "persistence-sqlite",
20    feature = "persistence-postgres",
21    feature = "persistence-file"
22))]
23mod backends;
24
25#[cfg(feature = "persistence-sqlite")]
26pub use backends::SyncSqliteStore;
27#[cfg(feature = "persistence-postgres")]
28pub use backends::SyncPostgresStore;
29#[cfg(feature = "persistence-file")]
30pub use backends::SyncFileStore;
31
32use serde::{de::DeserializeOwned, Serialize};
33use std::collections::HashMap;
34use std::sync::{Arc, RwLock};
35
36/// A checkpoint store for persisting agent state.
37///
38/// This is a re-export of the trait from sage-persistence, simplified
39/// for use in generated code.
40pub trait CheckpointStore: Send + Sync {
41    /// Save a field value synchronously (blocks on async).
42    fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value);
43
44    /// Load a field value synchronously.
45    fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value>;
46
47    /// Load all fields for an agent.
48    fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value>;
49
50    /// Save all fields atomically.
51    fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>);
52
53    /// Check if any checkpoint exists for an agent.
54    fn exists_sync(&self, agent_key: &str) -> bool;
55}
56
57/// In-memory checkpoint store for testing.
58#[derive(Default)]
59pub struct MemoryCheckpointStore {
60    data: RwLock<HashMap<String, HashMap<String, serde_json::Value>>>,
61}
62
63impl MemoryCheckpointStore {
64    pub fn new() -> Self {
65        Self::default()
66    }
67}
68
69impl CheckpointStore for MemoryCheckpointStore {
70    fn save_sync(&self, agent_key: &str, field: &str, value: serde_json::Value) {
71        let mut data = self.data.write().unwrap();
72        data.entry(agent_key.to_string())
73            .or_default()
74            .insert(field.to_string(), value);
75    }
76
77    fn load_sync(&self, agent_key: &str, field: &str) -> Option<serde_json::Value> {
78        self.data
79            .read()
80            .unwrap()
81            .get(agent_key)
82            .and_then(|fields| fields.get(field).cloned())
83    }
84
85    fn load_all_sync(&self, agent_key: &str) -> HashMap<String, serde_json::Value> {
86        self.data
87            .read()
88            .unwrap()
89            .get(agent_key)
90            .cloned()
91            .unwrap_or_default()
92    }
93
94    fn save_all_sync(&self, agent_key: &str, fields: &HashMap<String, serde_json::Value>) {
95        let mut data = self.data.write().unwrap();
96        data.insert(agent_key.to_string(), fields.clone());
97    }
98
99    fn exists_sync(&self, agent_key: &str) -> bool {
100        self.data.read().unwrap().contains_key(agent_key)
101    }
102}
103
104/// A wrapper for @persistent fields that auto-checkpoints on modification.
105///
106/// This provides interior mutability and automatic persistence when the
107/// value is modified via `set()`.
108pub struct Persisted<T> {
109    value: RwLock<T>,
110    store: Arc<dyn CheckpointStore>,
111    agent_key: String,
112    field_name: String,
113}
114
115impl<T: Clone + Serialize + DeserializeOwned + Default + Send> Persisted<T> {
116    /// Create a new persisted field, loading from checkpoint if available.
117    pub fn new(
118        store: Arc<dyn CheckpointStore>,
119        agent_key: impl Into<String>,
120        field_name: impl Into<String>,
121    ) -> Self {
122        let agent_key = agent_key.into();
123        let field_name = field_name.into();
124
125        // Try to load from checkpoint
126        let value = store
127            .load_sync(&agent_key, &field_name)
128            .and_then(|v| serde_json::from_value(v).ok())
129            .unwrap_or_default();
130
131        Self {
132            value: RwLock::new(value),
133            store,
134            agent_key,
135            field_name,
136        }
137    }
138
139    /// Create with an explicit initial value (used when no checkpoint exists).
140    pub fn with_initial(
141        store: Arc<dyn CheckpointStore>,
142        agent_key: impl Into<String>,
143        field_name: impl Into<String>,
144        initial: T,
145    ) -> Self {
146        let agent_key = agent_key.into();
147        let field_name = field_name.into();
148
149        // Try to load from checkpoint, fall back to initial
150        let value = store
151            .load_sync(&agent_key, &field_name)
152            .and_then(|v| serde_json::from_value(v).ok())
153            .unwrap_or(initial);
154
155        Self {
156            value: RwLock::new(value),
157            store,
158            agent_key,
159            field_name,
160        }
161    }
162
163    /// Get the current value.
164    pub fn get(&self) -> T {
165        self.value.read().unwrap().clone()
166    }
167
168    /// Set the value and checkpoint it.
169    pub fn set(&self, new_value: T) {
170        *self.value.write().unwrap() = new_value.clone();
171        if let Ok(json) = serde_json::to_value(&new_value) {
172            self.store.save_sync(&self.agent_key, &self.field_name, json);
173        }
174    }
175
176    /// Checkpoint the current value without modifying it.
177    pub fn checkpoint(&self) {
178        let value = self.value.read().unwrap().clone();
179        if let Ok(json) = serde_json::to_value(&value) {
180            self.store.save_sync(&self.agent_key, &self.field_name, json);
181        }
182    }
183}
184
185/// Helper to generate a unique checkpoint key for an agent instance.
186pub fn agent_checkpoint_key(agent_name: &str, beliefs: &serde_json::Value) -> String {
187    use std::collections::hash_map::DefaultHasher;
188    use std::hash::{Hash, Hasher};
189
190    let mut hasher = DefaultHasher::new();
191    agent_name.hash(&mut hasher);
192    beliefs.to_string().hash(&mut hasher);
193    format!("{}_{:016x}", agent_name, hasher.finish())
194}
195
196/// Helper to save all @persistent fields atomically before yield.
197pub fn checkpoint_all<S: CheckpointStore + ?Sized>(
198    store: &S,
199    agent_key: &str,
200    fields: Vec<(&str, serde_json::Value)>,
201) {
202    let map: HashMap<String, serde_json::Value> = fields
203        .into_iter()
204        .map(|(k, v)| (k.to_string(), v))
205        .collect();
206    store.save_all_sync(agent_key, &map);
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    fn make_store() -> Arc<dyn CheckpointStore> {
        Arc::new(MemoryCheckpointStore::new())
    }

    #[test]
    fn memory_store_save_load() {
        let store = MemoryCheckpointStore::new();
        store.save_sync("agent1", "count", serde_json::json!(42));

        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn memory_store_reports_existence_and_missing_fields() {
        let store = MemoryCheckpointStore::new();
        assert!(!store.exists_sync("agent1"));
        assert!(store.load_all_sync("agent1").is_empty());

        store.save_sync("agent1", "count", serde_json::json!(1));
        assert!(store.exists_sync("agent1"));
        assert_eq!(store.load_sync("agent1", "missing"), None);
        assert_eq!(store.load_sync("agent2", "count"), None);
    }

    #[test]
    fn persisted_field_loads_from_checkpoint() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!(100));

        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 100);
    }

    #[test]
    fn persisted_field_defaults_when_no_checkpoint() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");
        assert_eq!(field.get(), 0); // Default for i64
    }

    #[test]
    fn persisted_field_auto_checkpoints_on_set() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(Arc::clone(&store), "agent1", "count");

        field.set(42);

        // Verify it was persisted
        let loaded = store.load_sync("agent1", "count");
        assert_eq!(loaded, Some(serde_json::json!(42)));
    }

    #[test]
    fn persisted_field_set_updates_in_memory_value() {
        let store = make_store();
        let field: Persisted<i64> = Persisted::new(store, "agent1", "count");

        field.set(9);
        assert_eq!(field.get(), 9);
    }

    #[test]
    fn with_initial_uses_initial_when_no_checkpoint() {
        let store = make_store();
        let field = Persisted::with_initial(store, "agent1", "name", String::from("init"));
        assert_eq!(field.get(), "init");
    }

    #[test]
    fn with_initial_prefers_existing_checkpoint() {
        let store = make_store();
        store.save_sync("agent1", "name", serde_json::json!("saved"));

        let field = Persisted::with_initial(store, "agent1", "name", String::from("init"));
        assert_eq!(field.get(), "saved");
    }

    #[test]
    fn with_initial_falls_back_when_checkpoint_has_wrong_type() {
        let store = make_store();
        store.save_sync("agent1", "count", serde_json::json!("not a number"));

        let field: Persisted<i64> = Persisted::with_initial(store, "agent1", "count", 7);
        assert_eq!(field.get(), 7);
    }

    #[test]
    fn checkpoint_persists_current_value() {
        let store = make_store();
        let field: Persisted<i64> =
            Persisted::with_initial(Arc::clone(&store), "agent1", "count", 5);
        // Construction alone never writes to the store.
        assert_eq!(store.load_sync("agent1", "count"), None);

        field.checkpoint();
        assert_eq!(store.load_sync("agent1", "count"), Some(serde_json::json!(5)));
    }

    #[test]
    fn checkpoint_all_saves_every_field() {
        let store = MemoryCheckpointStore::new();
        checkpoint_all(
            &store,
            "agent1",
            vec![("a", serde_json::json!(1)), ("b", serde_json::json!("two"))],
        );

        let saved = store.load_all_sync("agent1");
        assert_eq!(saved.len(), 2);
        assert_eq!(saved.get("a"), Some(&serde_json::json!(1)));
        assert_eq!(saved.get("b"), Some(&serde_json::json!("two")));
    }

    #[test]
    fn checkpoint_key_varies_with_beliefs() {
        let key1 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 1}));
        let key2 = agent_checkpoint_key("Agent", &serde_json::json!({"x": 2}));
        assert_ne!(key1, key2);
    }

    #[test]
    fn checkpoint_key_is_deterministic_and_prefixed() {
        let beliefs = serde_json::json!({"x": 1});
        let key = agent_checkpoint_key("Agent", &beliefs);

        assert_eq!(key, agent_checkpoint_key("Agent", &beliefs));
        assert!(key.starts_with("Agent_"));
        assert_eq!(key.len(), "Agent_".len() + 16);
    }
}