Skip to main content

codetether_agent/cognition/
persistence.rs

1//! Persistence — atomic save/load of cognition state with evidence preservation.
2
3use std::collections::{HashMap, VecDeque};
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use tokio::sync::RwLock;
10
11use super::beliefs::Belief;
12use super::{
13    AttentionItem, GlobalWorkspace, MemorySnapshot, PersonaRuntimeState, Proposal, ThoughtEvent,
14};
15
/// Schema-versioned persisted cognition state.
///
/// This is the document that `save_state` serializes to JSON and that
/// `load_state` parses back. `load_state` only returns it when
/// `schema_version` equals the current `SCHEMA_VERSION`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedCognitionState {
    /// Format version of this document, set to `SCHEMA_VERSION` on save.
    pub schema_version: u32,
    /// UTC timestamp taken when the state was assembled for saving.
    pub persisted_at: DateTime<Utc>,
    /// Persona runtime states by string key (presumably the persona id — confirm against caller).
    pub personas: HashMap<String, PersonaRuntimeState>,
    /// Proposals by string key (presumably the proposal id — confirm against caller).
    pub proposals: HashMap<String, Proposal>,
    /// Beliefs by string key (presumably the belief id — confirm against caller).
    pub beliefs: HashMap<String, Belief>,
    /// Copy of the attention queue at save time, in its in-memory order.
    pub attention_queue: Vec<AttentionItem>,
    /// Copy of the global workspace at save time.
    pub workspace: GlobalWorkspace,
    /// Events referenced by beliefs/proposals + last N tail events.
    pub evidence_events: Vec<ThoughtEvent>,
    /// Up to `MAX_RECENT_SNAPSHOTS` memory snapshots, written by `save_state`
    /// in reverse of the in-memory queue order (presumably newest first — confirm).
    pub recent_snapshots: Vec<MemorySnapshot>,
}
30
/// Version stamped into every saved state; `load_state` discards files whose version differs.
const SCHEMA_VERSION: u32 = 1;
/// Number of trailing events from the event log that `save_state` always keeps,
/// in addition to events referenced by beliefs or proposals.
const TAIL_EVENTS: usize = 200;
/// Maximum number of memory snapshots written per save.
const MAX_RECENT_SNAPSHOTS: usize = 50;
/// Largest state file `load_state` will read; larger files are renamed aside and ignored.
const MAX_STATE_FILE_BYTES: u64 = 100 * 1024 * 1024; // 100 MiB
35
36/// Get the persistence file path.
37fn state_path() -> PathBuf {
38    let base = directories::ProjectDirs::from("com", "codetether", "codetether")
39        .map(|d| d.data_local_dir().to_path_buf())
40        .unwrap_or_else(|| PathBuf::from("/tmp/codetether"));
41    base.join("cognition").join("state.json")
42}
43
44/// Collect event IDs referenced by beliefs and proposals.
45fn referenced_event_ids(
46    beliefs: &HashMap<String, Belief>,
47    proposals: &HashMap<String, Proposal>,
48) -> std::collections::HashSet<String> {
49    let mut ids = std::collections::HashSet::new();
50    for belief in beliefs.values() {
51        for ref_id in &belief.evidence_refs {
52            ids.insert(ref_id.clone());
53        }
54    }
55    for proposal in proposals.values() {
56        for ref_id in &proposal.evidence_refs {
57            ids.insert(ref_id.clone());
58        }
59    }
60    ids
61}
62
63/// Trim event payloads for storage (remove very large fields).
64fn trim_event_for_storage(event: &ThoughtEvent) -> ThoughtEvent {
65    let mut trimmed = event.clone();
66    // Trim thinking field if present
67    if let Some(thinking) = trimmed.payload.get("thinking").and_then(|v| v.as_str()) {
68        if thinking.len() > 500 {
69            let short = &thinking[..500];
70            trimmed.payload["thinking"] = serde_json::Value::String(format!("{}...", short));
71        }
72    }
73    trimmed
74}
75
76/// Save cognition state atomically (tmp + rename).
77pub async fn save_state(
78    personas: &Arc<RwLock<HashMap<String, PersonaRuntimeState>>>,
79    proposals: &Arc<RwLock<HashMap<String, Proposal>>>,
80    beliefs: &Arc<RwLock<HashMap<String, Belief>>>,
81    attention_queue: &Arc<RwLock<Vec<AttentionItem>>>,
82    workspace: &Arc<RwLock<GlobalWorkspace>>,
83    events: &Arc<RwLock<VecDeque<ThoughtEvent>>>,
84    snapshots: &Arc<RwLock<VecDeque<MemorySnapshot>>>,
85) -> Result<(), String> {
86    let personas_snap = personas.read().await.clone();
87    let proposals_snap = proposals.read().await.clone();
88    let beliefs_snap = beliefs.read().await.clone();
89    let attention_snap = attention_queue.read().await.clone();
90    let workspace_snap = workspace.read().await.clone();
91    let events_snap = events.read().await.clone();
92    let snapshots_snap = snapshots.read().await.clone();
93
94    // Collect referenced events + tail
95    let ref_ids = referenced_event_ids(&beliefs_snap, &proposals_snap);
96    let mut evidence_events: Vec<ThoughtEvent> = Vec::new();
97    let mut seen = std::collections::HashSet::new();
98
99    // Referenced events first
100    for event in events_snap.iter() {
101        if ref_ids.contains(&event.id) && seen.insert(event.id.clone()) {
102            evidence_events.push(trim_event_for_storage(event));
103        }
104    }
105
106    // Then tail events
107    let tail_start = events_snap.len().saturating_sub(TAIL_EVENTS);
108    for event in events_snap.iter().skip(tail_start) {
109        if seen.insert(event.id.clone()) {
110            evidence_events.push(trim_event_for_storage(event));
111        }
112    }
113
114    let state = PersistedCognitionState {
115        schema_version: SCHEMA_VERSION,
116        persisted_at: Utc::now(),
117        personas: personas_snap,
118        proposals: proposals_snap,
119        beliefs: beliefs_snap,
120        attention_queue: attention_snap,
121        workspace: workspace_snap,
122        evidence_events,
123        recent_snapshots: snapshots_snap
124            .into_iter()
125            .rev()
126            .take(MAX_RECENT_SNAPSHOTS)
127            .collect(),
128    };
129
130    let path = state_path();
131    let dir = path.parent().unwrap();
132
133    // Create directory if needed
134    if let Err(e) = tokio::fs::create_dir_all(dir).await {
135        return Err(format!("Failed to create persistence directory: {}", e));
136    }
137
138    // Serialize
139    let json = match serde_json::to_string_pretty(&state) {
140        Ok(j) => j,
141        Err(e) => return Err(format!("Failed to serialize state: {}", e)),
142    };
143
144    // Atomic write: tmp + rename
145    let tmp_path = path.with_extension("json.tmp");
146    if let Err(e) = tokio::fs::write(&tmp_path, &json).await {
147        return Err(format!("Failed to write temp file: {}", e));
148    }
149    if let Err(e) = tokio::fs::rename(&tmp_path, &path).await {
150        return Err(format!("Failed to rename temp file: {}", e));
151    }
152
153    tracing::debug!(path = %path.display(), "Cognition state persisted");
154    Ok(())
155}
156
157/// Load persisted cognition state, if available.
158pub fn load_state() -> Option<PersistedCognitionState> {
159    let path = state_path();
160    if !path.exists() {
161        return None;
162    }
163
164    // Guard against bloated state files that would OOM the process
165    match std::fs::metadata(&path) {
166        Ok(meta) if meta.len() > MAX_STATE_FILE_BYTES => {
167            tracing::warn!(
168                size_mb = meta.len() / (1024 * 1024),
169                max_mb = MAX_STATE_FILE_BYTES / (1024 * 1024),
170                "Persisted cognition state too large, starting fresh"
171            );
172            let _ = std::fs::rename(&path, path.with_extension("json.bloated"));
173            return None;
174        }
175        Err(e) => {
176            tracing::warn!(error = %e, "Failed to stat persisted cognition state");
177            return None;
178        }
179        _ => {}
180    }
181
182    let data = match std::fs::read_to_string(&path) {
183        Ok(d) => d,
184        Err(e) => {
185            tracing::warn!(error = %e, "Failed to read persisted cognition state");
186            return None;
187        }
188    };
189
190    match serde_json::from_str::<PersistedCognitionState>(&data) {
191        Ok(state) => {
192            if state.schema_version != SCHEMA_VERSION {
193                tracing::warn!(
194                    persisted_version = state.schema_version,
195                    current_version = SCHEMA_VERSION,
196                    "Schema version mismatch, starting fresh"
197                );
198                return None;
199            }
200            tracing::info!(
201                persisted_at = %state.persisted_at,
202                personas = state.personas.len(),
203                beliefs = state.beliefs.len(),
204                "Loaded persisted cognition state"
205            );
206            Some(state)
207        }
208        Err(e) => {
209            tracing::warn!(error = %e, "Corrupt persisted cognition state, starting fresh");
210            None
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty state survives a JSON encode/decode cycle with its schema version intact.
    #[test]
    fn round_trip_serialize() {
        let original = PersistedCognitionState {
            schema_version: SCHEMA_VERSION,
            persisted_at: Utc::now(),
            personas: HashMap::new(),
            proposals: HashMap::new(),
            beliefs: HashMap::new(),
            attention_queue: Vec::new(),
            workspace: GlobalWorkspace::default(),
            evidence_events: Vec::new(),
            recent_snapshots: Vec::new(),
        };

        let encoded = serde_json::to_string(&original).expect("should serialize");
        let decoded: PersistedCognitionState =
            serde_json::from_str(&encoded).expect("should deserialize");

        assert_eq!(decoded.schema_version, SCHEMA_VERSION);
    }

    /// Evidence refs of a single belief are all reported, and nothing else is.
    #[test]
    fn referenced_events_collected() {
        use super::super::beliefs::{Belief, BeliefStatus};
        use chrono::Duration;

        let now = Utc::now();
        let belief = Belief {
            id: "b1".to_string(),
            belief_key: "test".to_string(),
            claim: "test".to_string(),
            confidence: 0.8,
            evidence_refs: vec!["evt-1".to_string(), "evt-2".to_string()],
            asserted_by: "p1".to_string(),
            confirmed_by: Vec::new(),
            contested_by: Vec::new(),
            contradicts: Vec::new(),
            created_at: now,
            updated_at: now,
            review_after: now + Duration::hours(1),
            status: BeliefStatus::Active,
        };
        let beliefs = HashMap::from([("b1".to_string(), belief)]);
        let proposals = HashMap::new();

        let ids = referenced_event_ids(&beliefs, &proposals);

        for expected in &["evt-1", "evt-2"] {
            assert!(ids.contains(*expected));
        }
        assert_eq!(ids.len(), 2);
    }
}