Skip to main content

codetether_agent/cognition/
persistence.rs

1//! Persistence — atomic save/load of cognition state with evidence preservation.
2
3use std::collections::{HashMap, VecDeque};
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use tokio::sync::RwLock;
10
11use super::beliefs::Belief;
12use super::{
13    AttentionItem, GlobalWorkspace, MemorySnapshot, PersonaRuntimeState, Proposal, ThoughtEvent,
14};
15
16/// Schema-versioned persisted cognition state.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct PersistedCognitionState {
19    pub schema_version: u32,
20    pub persisted_at: DateTime<Utc>,
21    pub personas: HashMap<String, PersonaRuntimeState>,
22    pub proposals: HashMap<String, Proposal>,
23    pub beliefs: HashMap<String, Belief>,
24    pub attention_queue: Vec<AttentionItem>,
25    pub workspace: GlobalWorkspace,
26    /// Events referenced by beliefs/proposals + last N tail events.
27    pub evidence_events: Vec<ThoughtEvent>,
28    pub recent_snapshots: Vec<MemorySnapshot>,
29}
30
/// Layout version written into every saved state and required on load.
const SCHEMA_VERSION: u32 = 1;
/// Number of trailing events persisted regardless of whether they are referenced.
const TAIL_EVENTS: usize = 200;
/// Upper bound on the number of memory snapshots written to disk.
const MAX_RECENT_SNAPSHOTS: usize = 50;
/// Largest state file (in bytes) that will be read back; 100 MB.
const MAX_STATE_FILE_BYTES: u64 = 100 * 1024 * 1024;
35
36/// Get the persistence file path.
37fn state_path() -> PathBuf {
38    let base =
39        crate::config::Config::data_dir().unwrap_or_else(|| PathBuf::from("/tmp/codetether-agent"));
40    base.join("cognition").join("state.json")
41}
42
43/// Collect event IDs referenced by beliefs and proposals.
44fn referenced_event_ids(
45    beliefs: &HashMap<String, Belief>,
46    proposals: &HashMap<String, Proposal>,
47) -> std::collections::HashSet<String> {
48    let mut ids = std::collections::HashSet::new();
49    for belief in beliefs.values() {
50        for ref_id in &belief.evidence_refs {
51            ids.insert(ref_id.clone());
52        }
53    }
54    for proposal in proposals.values() {
55        for ref_id in &proposal.evidence_refs {
56            ids.insert(ref_id.clone());
57        }
58    }
59    ids
60}
61
62/// Trim event payloads for storage (remove very large fields).
63fn trim_event_for_storage(event: &ThoughtEvent) -> ThoughtEvent {
64    let mut trimmed = event.clone();
65    // Trim thinking field if present
66    if let Some(thinking) = trimmed.payload.get("thinking").and_then(|v| v.as_str())
67        && thinking.len() > 500
68    {
69        let short = &thinking[..500];
70        trimmed.payload["thinking"] = serde_json::Value::String(format!("{}...", short));
71    }
72    trimmed
73}
74
75/// Save cognition state atomically (tmp + rename).
76pub async fn save_state(
77    personas: &Arc<RwLock<HashMap<String, PersonaRuntimeState>>>,
78    proposals: &Arc<RwLock<HashMap<String, Proposal>>>,
79    beliefs: &Arc<RwLock<HashMap<String, Belief>>>,
80    attention_queue: &Arc<RwLock<Vec<AttentionItem>>>,
81    workspace: &Arc<RwLock<GlobalWorkspace>>,
82    events: &Arc<RwLock<VecDeque<ThoughtEvent>>>,
83    snapshots: &Arc<RwLock<VecDeque<MemorySnapshot>>>,
84) -> Result<(), String> {
85    let personas_snap = personas.read().await.clone();
86    let proposals_snap = proposals.read().await.clone();
87    let beliefs_snap = beliefs.read().await.clone();
88    let attention_snap = attention_queue.read().await.clone();
89    let workspace_snap = workspace.read().await.clone();
90    let events_snap = events.read().await.clone();
91    let snapshots_snap = snapshots.read().await.clone();
92
93    // Collect referenced events + tail
94    let ref_ids = referenced_event_ids(&beliefs_snap, &proposals_snap);
95    let mut evidence_events: Vec<ThoughtEvent> = Vec::new();
96    let mut seen = std::collections::HashSet::new();
97
98    // Referenced events first
99    for event in events_snap.iter() {
100        if ref_ids.contains(&event.id) && seen.insert(event.id.clone()) {
101            evidence_events.push(trim_event_for_storage(event));
102        }
103    }
104
105    // Then tail events
106    let tail_start = events_snap.len().saturating_sub(TAIL_EVENTS);
107    for event in events_snap.iter().skip(tail_start) {
108        if seen.insert(event.id.clone()) {
109            evidence_events.push(trim_event_for_storage(event));
110        }
111    }
112
113    let state = PersistedCognitionState {
114        schema_version: SCHEMA_VERSION,
115        persisted_at: Utc::now(),
116        personas: personas_snap,
117        proposals: proposals_snap,
118        beliefs: beliefs_snap,
119        attention_queue: attention_snap,
120        workspace: workspace_snap,
121        evidence_events,
122        recent_snapshots: snapshots_snap
123            .into_iter()
124            .rev()
125            .take(MAX_RECENT_SNAPSHOTS)
126            .collect(),
127    };
128
129    let path = state_path();
130    let dir = path.parent().unwrap();
131
132    // Create directory if needed
133    if let Err(e) = tokio::fs::create_dir_all(dir).await {
134        return Err(format!("Failed to create persistence directory: {}", e));
135    }
136
137    // Serialize
138    let json = match serde_json::to_string_pretty(&state) {
139        Ok(j) => j,
140        Err(e) => return Err(format!("Failed to serialize state: {}", e)),
141    };
142
143    // Atomic write: tmp + rename
144    let tmp_path = path.with_extension("json.tmp");
145    if let Err(e) = tokio::fs::write(&tmp_path, &json).await {
146        return Err(format!("Failed to write temp file: {}", e));
147    }
148    if let Err(e) = tokio::fs::rename(&tmp_path, &path).await {
149        return Err(format!("Failed to rename temp file: {}", e));
150    }
151
152    tracing::debug!(path = %path.display(), "Cognition state persisted");
153    Ok(())
154}
155
156/// Load persisted cognition state, if available.
157pub fn load_state() -> Option<PersistedCognitionState> {
158    let path = state_path();
159    if !path.exists() {
160        return None;
161    }
162
163    // Guard against bloated state files that would OOM the process
164    match std::fs::metadata(&path) {
165        Ok(meta) if meta.len() > MAX_STATE_FILE_BYTES => {
166            tracing::warn!(
167                size_mb = meta.len() / (1024 * 1024),
168                max_mb = MAX_STATE_FILE_BYTES / (1024 * 1024),
169                "Persisted cognition state too large, starting fresh"
170            );
171            let _ = std::fs::rename(&path, path.with_extension("json.bloated"));
172            return None;
173        }
174        Err(e) => {
175            tracing::warn!(error = %e, "Failed to stat persisted cognition state");
176            return None;
177        }
178        _ => {}
179    }
180
181    let data = match std::fs::read_to_string(&path) {
182        Ok(d) => d,
183        Err(e) => {
184            tracing::warn!(error = %e, "Failed to read persisted cognition state");
185            return None;
186        }
187    };
188
189    match serde_json::from_str::<PersistedCognitionState>(&data) {
190        Ok(state) => {
191            if state.schema_version != SCHEMA_VERSION {
192                tracing::warn!(
193                    persisted_version = state.schema_version,
194                    current_version = SCHEMA_VERSION,
195                    "Schema version mismatch, starting fresh"
196                );
197                return None;
198            }
199            tracing::info!(
200                persisted_at = %state.persisted_at,
201                personas = state.personas.len(),
202                beliefs = state.beliefs.len(),
203                "Loaded persisted cognition state"
204            );
205            Some(state)
206        }
207        Err(e) => {
208            tracing::warn!(error = %e, "Corrupt persisted cognition state, starting fresh");
209            None
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty state survives a JSON serialize/deserialize cycle.
    #[test]
    fn round_trip_serialize() {
        let original = PersistedCognitionState {
            schema_version: SCHEMA_VERSION,
            persisted_at: Utc::now(),
            personas: HashMap::new(),
            proposals: HashMap::new(),
            beliefs: HashMap::new(),
            attention_queue: Vec::new(),
            workspace: GlobalWorkspace::default(),
            evidence_events: Vec::new(),
            recent_snapshots: Vec::new(),
        };

        let encoded = serde_json::to_string(&original).expect("should serialize");
        let decoded: PersistedCognitionState =
            serde_json::from_str(&encoded).expect("should deserialize");

        assert_eq!(decoded.schema_version, SCHEMA_VERSION);
    }

    /// Evidence refs on a belief show up in the referenced-id set.
    #[test]
    fn referenced_events_collected() {
        use super::super::beliefs::{Belief, BeliefStatus};
        use chrono::Duration;

        let now = Utc::now();
        let belief = Belief {
            id: "b1".to_string(),
            belief_key: "test".to_string(),
            claim: "test".to_string(),
            confidence: 0.8,
            evidence_refs: vec!["evt-1".to_string(), "evt-2".to_string()],
            asserted_by: "p1".to_string(),
            confirmed_by: Vec::new(),
            contested_by: Vec::new(),
            contradicts: Vec::new(),
            created_at: now,
            updated_at: now,
            review_after: now + Duration::hours(1),
            status: BeliefStatus::Active,
        };
        let beliefs = HashMap::from([("b1".to_string(), belief)]);
        let proposals = HashMap::new();

        let ids = referenced_event_ids(&beliefs, &proposals);

        assert_eq!(ids.len(), 2);
        assert!(ids.contains("evt-1"));
        assert!(ids.contains("evt-2"));
    }
}