Skip to main content

codetether_agent/cognition/
persistence.rs

1//! Persistence — atomic save/load of cognition state with evidence preservation.
2
3use std::collections::{HashMap, VecDeque};
4use std::path::PathBuf;
5use std::sync::Arc;
6
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use tokio::sync::RwLock;
10
11use super::beliefs::Belief;
12use super::{
13    AttentionItem, GlobalWorkspace, MemorySnapshot, PersonaRuntimeState, Proposal, ThoughtEvent,
14};
15
16/// Schema-versioned persisted cognition state.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct PersistedCognitionState {
19    pub schema_version: u32,
20    pub persisted_at: DateTime<Utc>,
21    pub personas: HashMap<String, PersonaRuntimeState>,
22    pub proposals: HashMap<String, Proposal>,
23    pub beliefs: HashMap<String, Belief>,
24    pub attention_queue: Vec<AttentionItem>,
25    pub workspace: GlobalWorkspace,
26    /// Events referenced by beliefs/proposals + last N tail events.
27    pub evidence_events: Vec<ThoughtEvent>,
28    pub recent_snapshots: Vec<MemorySnapshot>,
29}
30
/// Version stamped into every persisted state; a persisted file carrying a
/// different version is discarded on load.
const SCHEMA_VERSION: u32 = 1;
/// How many of the most recent events are persisted in addition to the events
/// referenced by beliefs and proposals.
const TAIL_EVENTS: usize = 200;
33
34/// Get the persistence file path.
35fn state_path() -> PathBuf {
36    let base = directories::ProjectDirs::from("com", "codetether", "codetether")
37        .map(|d| d.data_local_dir().to_path_buf())
38        .unwrap_or_else(|| PathBuf::from("/tmp/codetether"));
39    base.join("cognition").join("state.json")
40}
41
42/// Collect event IDs referenced by beliefs and proposals.
43fn referenced_event_ids(
44    beliefs: &HashMap<String, Belief>,
45    proposals: &HashMap<String, Proposal>,
46) -> std::collections::HashSet<String> {
47    let mut ids = std::collections::HashSet::new();
48    for belief in beliefs.values() {
49        for ref_id in &belief.evidence_refs {
50            ids.insert(ref_id.clone());
51        }
52    }
53    for proposal in proposals.values() {
54        for ref_id in &proposal.evidence_refs {
55            ids.insert(ref_id.clone());
56        }
57    }
58    ids
59}
60
61/// Trim event payloads for storage (remove very large fields).
62fn trim_event_for_storage(event: &ThoughtEvent) -> ThoughtEvent {
63    let mut trimmed = event.clone();
64    // Trim thinking field if present
65    if let Some(thinking) = trimmed.payload.get("thinking").and_then(|v| v.as_str()) {
66        if thinking.len() > 500 {
67            let short = &thinking[..500];
68            trimmed.payload["thinking"] = serde_json::Value::String(format!("{}...", short));
69        }
70    }
71    trimmed
72}
73
74/// Save cognition state atomically (tmp + rename).
75pub async fn save_state(
76    personas: &Arc<RwLock<HashMap<String, PersonaRuntimeState>>>,
77    proposals: &Arc<RwLock<HashMap<String, Proposal>>>,
78    beliefs: &Arc<RwLock<HashMap<String, Belief>>>,
79    attention_queue: &Arc<RwLock<Vec<AttentionItem>>>,
80    workspace: &Arc<RwLock<GlobalWorkspace>>,
81    events: &Arc<RwLock<VecDeque<ThoughtEvent>>>,
82    snapshots: &Arc<RwLock<VecDeque<MemorySnapshot>>>,
83) -> Result<(), String> {
84    let personas_snap = personas.read().await.clone();
85    let proposals_snap = proposals.read().await.clone();
86    let beliefs_snap = beliefs.read().await.clone();
87    let attention_snap = attention_queue.read().await.clone();
88    let workspace_snap = workspace.read().await.clone();
89    let events_snap = events.read().await.clone();
90    let snapshots_snap = snapshots.read().await.clone();
91
92    // Collect referenced events + tail
93    let ref_ids = referenced_event_ids(&beliefs_snap, &proposals_snap);
94    let mut evidence_events: Vec<ThoughtEvent> = Vec::new();
95    let mut seen = std::collections::HashSet::new();
96
97    // Referenced events first
98    for event in events_snap.iter() {
99        if ref_ids.contains(&event.id) && seen.insert(event.id.clone()) {
100            evidence_events.push(trim_event_for_storage(event));
101        }
102    }
103
104    // Then tail events
105    let tail_start = events_snap.len().saturating_sub(TAIL_EVENTS);
106    for event in events_snap.iter().skip(tail_start) {
107        if seen.insert(event.id.clone()) {
108            evidence_events.push(trim_event_for_storage(event));
109        }
110    }
111
112    let state = PersistedCognitionState {
113        schema_version: SCHEMA_VERSION,
114        persisted_at: Utc::now(),
115        personas: personas_snap,
116        proposals: proposals_snap,
117        beliefs: beliefs_snap,
118        attention_queue: attention_snap,
119        workspace: workspace_snap,
120        evidence_events,
121        recent_snapshots: snapshots_snap.into_iter().collect(),
122    };
123
124    let path = state_path();
125    let dir = path.parent().unwrap();
126
127    // Create directory if needed
128    if let Err(e) = tokio::fs::create_dir_all(dir).await {
129        return Err(format!("Failed to create persistence directory: {}", e));
130    }
131
132    // Serialize
133    let json = match serde_json::to_string_pretty(&state) {
134        Ok(j) => j,
135        Err(e) => return Err(format!("Failed to serialize state: {}", e)),
136    };
137
138    // Atomic write: tmp + rename
139    let tmp_path = path.with_extension("json.tmp");
140    if let Err(e) = tokio::fs::write(&tmp_path, &json).await {
141        return Err(format!("Failed to write temp file: {}", e));
142    }
143    if let Err(e) = tokio::fs::rename(&tmp_path, &path).await {
144        return Err(format!("Failed to rename temp file: {}", e));
145    }
146
147    tracing::debug!(path = %path.display(), "Cognition state persisted");
148    Ok(())
149}
150
151/// Load persisted cognition state, if available.
152pub fn load_state() -> Option<PersistedCognitionState> {
153    let path = state_path();
154    if !path.exists() {
155        return None;
156    }
157
158    let data = match std::fs::read_to_string(&path) {
159        Ok(d) => d,
160        Err(e) => {
161            tracing::warn!(error = %e, "Failed to read persisted cognition state");
162            return None;
163        }
164    };
165
166    match serde_json::from_str::<PersistedCognitionState>(&data) {
167        Ok(state) => {
168            if state.schema_version != SCHEMA_VERSION {
169                tracing::warn!(
170                    persisted_version = state.schema_version,
171                    current_version = SCHEMA_VERSION,
172                    "Schema version mismatch, starting fresh"
173                );
174                return None;
175            }
176            tracing::info!(
177                persisted_at = %state.persisted_at,
178                personas = state.personas.len(),
179                beliefs = state.beliefs.len(),
180                "Loaded persisted cognition state"
181            );
182            Some(state)
183        }
184        Err(e) => {
185            tracing::warn!(error = %e, "Corrupt persisted cognition state, starting fresh");
186            None
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty state keeps its schema version across a JSON round trip.
    #[test]
    fn round_trip_serialize() {
        let original = PersistedCognitionState {
            schema_version: SCHEMA_VERSION,
            persisted_at: Utc::now(),
            personas: HashMap::new(),
            proposals: HashMap::new(),
            beliefs: HashMap::new(),
            attention_queue: Vec::new(),
            workspace: GlobalWorkspace::default(),
            evidence_events: Vec::new(),
            recent_snapshots: Vec::new(),
        };

        let encoded = serde_json::to_string(&original).expect("should serialize");
        let decoded =
            serde_json::from_str::<PersistedCognitionState>(&encoded).expect("should deserialize");

        assert_eq!(decoded.schema_version, SCHEMA_VERSION);
    }

    /// Evidence refs of a single belief are all reported, and nothing else.
    #[test]
    fn referenced_events_collected() {
        use super::super::beliefs::{Belief, BeliefStatus};
        use chrono::Duration;

        let now = Utc::now();
        let belief = Belief {
            id: "b1".to_string(),
            belief_key: "test".to_string(),
            claim: "test".to_string(),
            confidence: 0.8,
            evidence_refs: vec!["evt-1".to_string(), "evt-2".to_string()],
            asserted_by: "p1".to_string(),
            confirmed_by: Vec::new(),
            contested_by: Vec::new(),
            contradicts: Vec::new(),
            created_at: now,
            updated_at: now,
            review_after: now + Duration::hours(1),
            status: BeliefStatus::Active,
        };

        let beliefs: HashMap<String, Belief> =
            std::iter::once(("b1".to_string(), belief)).collect();
        let proposals = HashMap::new();

        let ids = referenced_event_ids(&beliefs, &proposals);

        assert!(ids.contains("evt-1"));
        assert!(ids.contains("evt-2"));
        assert_eq!(ids.len(), 2);
    }
}
246}