Skip to main content

forge_reasoning/
checkpoint.rs

1//! Temporal Checkpointing - Core types and manager
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::cell::{Cell, RefCell};
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::rc::Rc;
9use uuid::Uuid;
10
11/// Compute SHA-256 checksum of checkpoint data
12pub fn compute_checksum(data: &[u8]) -> String {
13    use sha2::{Sha256, Digest};
14    let mut hasher = Sha256::new();
15    hasher.update(data);
16    format!("{:x}", hasher.finalize())
17}
18
19use crate::errors::Result;
20use crate::storage::CheckpointStorage;
21
22/// Unique identifier for a checkpoint
23#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
24pub struct CheckpointId(pub Uuid);
25
26impl CheckpointId {
27    pub fn new() -> Self {
28        Self(Uuid::new_v4())
29    }
30}
31
32impl Default for CheckpointId {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl std::fmt::Display for CheckpointId {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "{}", self.0)
41    }
42}
43
44/// Unique identifier for a debugging session
45#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
46pub struct SessionId(pub Uuid);
47
48impl SessionId {
49    pub fn new() -> Self {
50        Self(Uuid::new_v4())
51    }
52}
53
54impl Default for SessionId {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl std::fmt::Display for SessionId {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "{}", self.0)
63    }
64}
65
66/// A snapshot of complete debugging state at a point in time
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct TemporalCheckpoint {
69    pub id: CheckpointId,
70    pub timestamp: DateTime<Utc>,
71    pub sequence_number: u64,
72    pub message: String,
73    pub tags: Vec<String>,
74    pub state: DebugStateSnapshot,
75    pub trigger: CheckpointTrigger,
76    pub session_id: SessionId,
77    /// SHA-256 checksum for data integrity verification
78    pub checksum: String,
79}
80
81impl TemporalCheckpoint {
82    pub fn new(
83        sequence: u64,
84        message: impl Into<String>,
85        state: DebugStateSnapshot,
86        trigger: CheckpointTrigger,
87        session_id: SessionId,
88    ) -> Self {
89        let id = CheckpointId::new();
90        let timestamp = Utc::now();
91        let message = message.into();
92        
93        // Create checkpoint without checksum first
94        let mut checkpoint = Self {
95            id,
96            timestamp,
97            sequence_number: sequence,
98            message: message.clone(),
99            tags: Vec::new(),
100            state: state.clone(),
101            trigger: trigger.clone(),
102            session_id,
103            checksum: String::new(), // Temporary, will compute
104        };
105        
106        // Compute checksum from serialized data (excluding checksum itself)
107        checkpoint.checksum = checkpoint.compute_checksum();
108        
109        checkpoint
110    }
111    
112    /// Compute checksum of this checkpoint's data
113    fn compute_checksum(&self) -> String {
114        // Create a copy without checksum for serialization
115        let data_for_hash = CheckpointDataForHash {
116            id: self.id,
117            timestamp: self.timestamp,
118            sequence_number: self.sequence_number,
119            message: &self.message,
120            tags: &self.tags,
121            state: &self.state,
122            trigger: &self.trigger,
123            session_id: self.session_id,
124        };
125        
126        let json = serde_json::to_vec(&data_for_hash)
127            .unwrap_or_default();
128        compute_checksum(&json)
129    }
130    
131    /// Validate the checkpoint's checksum
132    pub fn validate(&self) -> crate::errors::Result<()> {
133        let expected = self.compute_checksum();
134        if self.checksum != expected {
135            return Err(crate::errors::ReasoningError::ValidationFailed(
136                format!("Checksum mismatch: expected {}, got {}", expected, self.checksum)
137            ));
138        }
139        Ok(())
140    }
141}
142
143/// Helper struct for computing checksum (excludes checksum field)
144#[derive(Serialize)]
145struct CheckpointDataForHash<'a> {
146    id: CheckpointId,
147    timestamp: DateTime<Utc>,
148    sequence_number: u64,
149    message: &'a str,
150    tags: &'a [String],
151    state: &'a DebugStateSnapshot,
152    trigger: &'a CheckpointTrigger,
153    session_id: SessionId,
154}
155
156/// Complete snapshot of debugging state
157#[derive(Clone, Debug, Default, Serialize, Deserialize)]
158pub struct DebugStateSnapshot {
159    pub session_id: SessionId,
160    pub started_at: DateTime<Utc>,
161    pub checkpoint_timestamp: DateTime<Utc>,
162    pub working_dir: Option<PathBuf>,
163    pub env_vars: HashMap<String, String>,
164    pub metrics: SessionMetrics,
165    /// Hypothesis state snapshot (optional for backward compatibility)
166    pub hypothesis_state: Option<crate::hypothesis::types::HypothesisState>,
167}
168
169/// What triggered this checkpoint
170#[derive(Clone, Debug, Serialize, Deserialize)]
171pub enum CheckpointTrigger {
172    Manual,
173    Automatic(AutoTrigger),
174    Scheduled,
175}
176
177impl std::fmt::Display for CheckpointTrigger {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        match self {
180            Self::Manual => write!(f, "manual"),
181            Self::Automatic(_) => write!(f, "auto"),
182            Self::Scheduled => write!(f, "scheduled"),
183        }
184    }
185}
186
187/// Types of automatic triggers
188#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
189pub enum AutoTrigger {
190    HypothesisStatusChange,
191    NewContradictionDetected,
192    VerificationComplete,
193    BranchSwitch,
194    GapFilled,
195    CodeModified,
196    SignificantTimePassed,
197    ContextCompactionWarning,
198}
199
/// Strategy used by `compact_with_policy` to decide which checkpoints survive.
#[derive(Debug, Clone)]
pub enum CompactionPolicy {
    /// Retain only the N checkpoints with the highest sequence numbers.
    KeepRecent(usize),
    /// Retain every checkpoint that carries at least one of these tags.
    PreserveTagged(Vec<String>),
    /// Retain the union of `KeepRecent(keep_recent)` and
    /// `PreserveTagged(preserve_tags)`.
    Hybrid {
        keep_recent: usize,
        preserve_tags: Vec<String>,
    },
}

impl Default for CompactionPolicy {
    /// Keeps the 100 most recent checkpoints.
    fn default() -> Self {
        Self::KeepRecent(100)
    }
}
216
217/// A user note/observation during debugging
218#[derive(Clone, Debug, Serialize, Deserialize)]
219pub struct DebugNote {
220    pub timestamp: DateTime<Utc>,
221    pub content: String,
222    pub tags: Vec<String>,
223}
224
225/// Verification result snapshot
226#[derive(Clone, Debug, Serialize, Deserialize)]
227pub struct VerificationResult {
228    pub name: String,
229    pub timestamp: DateTime<Utc>,
230    pub passed: bool,
231    pub output: Option<String>,
232    pub duration_ms: u64,
233}
234
235/// Session performance metrics
236#[derive(Clone, Debug, Default, Serialize, Deserialize)]
237pub struct SessionMetrics {
238    pub checkpoints_created: u64,
239    pub hypotheses_tested: u64,
240    pub verifications_run: u64,
241    pub gaps_filled: u64,
242}
243
244/// Summary of a checkpoint (for listing)
245#[derive(Clone, Debug, Serialize, Deserialize)]
246pub struct CheckpointSummary {
247    pub id: CheckpointId,
248    pub timestamp: DateTime<Utc>,
249    pub sequence_number: u64,
250    pub message: String,
251    pub trigger: String,
252    pub tags: Vec<String>,
253    pub has_notes: bool,
254}
255
256/// The main checkpoint manager
257pub struct TemporalCheckpointManager {
258    storage: Rc<dyn CheckpointStorage>,
259    session_id: SessionId,
260    sequence_counter: Cell<u64>,
261    last_checkpoint_time: RefCell<DateTime<Utc>>,
262}
263
264impl TemporalCheckpointManager {
265    /// Create a new checkpoint manager
266    pub fn new(storage: Rc<dyn CheckpointStorage>, session_id: SessionId) -> Self {
267        Self {
268            storage,
269            session_id,
270            sequence_counter: Cell::new(0),
271            last_checkpoint_time: RefCell::new(Utc::now()),
272        }
273    }
274
275    /// Create a manual checkpoint
276    pub fn checkpoint(&self, message: impl Into<String>) -> Result<CheckpointId> {
277        let seq = self.sequence_counter.get();
278        self.sequence_counter.set(seq + 1);
279        let state = self.capture_state()?;
280
281        let checkpoint = TemporalCheckpoint::new(
282            seq,
283            message,
284            state,
285            CheckpointTrigger::Manual,
286            self.session_id,
287        );
288
289        self.storage.store(&checkpoint)?;
290        self.update_last_checkpoint_time();
291
292        Ok(checkpoint.id)
293    }
294
295    /// Create an automatic checkpoint (if appropriate)
296    pub fn auto_checkpoint(&self, trigger: AutoTrigger) -> Result<Option<CheckpointId>> {
297        let should_checkpoint = match trigger {
298            AutoTrigger::SignificantTimePassed => {
299                let last = *self.last_checkpoint_time.borrow();
300                Utc::now().signed_duration_since(last).num_minutes() > 5
301            }
302            _ => true,
303        };
304
305        if !should_checkpoint {
306            return Ok(None);
307        }
308
309        let seq = self.sequence_counter.get();
310        self.sequence_counter.set(seq + 1);
311        let state = self.capture_state()?;
312
313        let checkpoint = TemporalCheckpoint::new(
314            seq,
315            format!("Auto: {:?}", trigger),
316            state,
317            CheckpointTrigger::Automatic(trigger),
318            self.session_id,
319        );
320
321        self.storage.store(&checkpoint)?;
322        self.update_last_checkpoint_time();
323
324        Ok(Some(checkpoint.id))
325    }
326
327    /// List all checkpoints for this session
328    pub fn list(&self) -> Result<Vec<CheckpointSummary>> {
329        self.storage.list_by_session(self.session_id)
330    }
331
332    /// Get a checkpoint by ID
333    pub fn get(&self, id: &CheckpointId) -> Result<Option<TemporalCheckpoint>> {
334        match self.storage.get(*id) {
335            Ok(cp) => Ok(Some(cp)),
336            Err(_) => Ok(None),
337        }
338    }
339
340    /// List checkpoints for a specific session
341    pub fn list_by_session(&self, session_id: &SessionId) -> Result<Vec<CheckpointSummary>> {
342        self.storage.list_by_session(*session_id)
343    }
344
345    /// List checkpoints with a specific tag
346    pub fn list_by_tag(&self, tag: &str) -> Result<Vec<CheckpointSummary>> {
347        self.storage.list_by_tag(tag)
348    }
349
350    /// Create a checkpoint with tags
351    pub fn checkpoint_with_tags(
352        &self,
353        message: impl Into<String>,
354        tags: Vec<String>,
355    ) -> Result<CheckpointId> {
356        let seq = self.sequence_counter.get();
357        self.sequence_counter.set(seq + 1);
358        let state = self.capture_state()?;
359
360        let mut checkpoint = TemporalCheckpoint::new(
361            seq,
362            message,
363            state,
364            CheckpointTrigger::Manual,
365            self.session_id,
366        );
367        checkpoint.tags = tags;
368
369        self.storage.store(&checkpoint)?;
370        self.update_last_checkpoint_time();
371
372        Ok(checkpoint.id)
373    }
374
375    /// Restore state from a checkpoint
376    pub fn restore(&self, checkpoint: &TemporalCheckpoint) -> Result<DebugStateSnapshot> {
377        // Validate checkpoint has valid state
378        if checkpoint.state.working_dir.is_none() {
379            return Err(crate::errors::ReasoningError::InvalidState(
380                "Checkpoint has no working directory".to_string()
381            ));
382        }
383        Ok(checkpoint.state.clone())
384    }
385
386    /// Get a summary of a checkpoint by ID
387    pub fn get_summary(&self, id: &CheckpointId) -> Result<Option<CheckpointSummary>> {
388        match self.storage.get(*id) {
389            Ok(cp) => Ok(Some(CheckpointSummary {
390                id: cp.id,
391                timestamp: cp.timestamp,
392                sequence_number: cp.sequence_number,
393                message: cp.message,
394                trigger: cp.trigger.to_string(),
395                tags: cp.tags,
396                has_notes: false,
397            })),
398            Err(_) => Ok(None),
399        }
400    }
401
402    /// Delete a checkpoint by ID
403    pub fn delete(&self, id: &CheckpointId) -> Result<()> {
404        self.storage.delete(*id)
405    }
406
407    /// Compact checkpoints, keeping only the most recent N
408    pub fn compact(&self, keep_recent: usize) -> Result<usize> {
409        self.compact_with_policy(CompactionPolicy::KeepRecent(keep_recent))
410    }
411
412    /// Compact checkpoints using a specific policy
413    pub fn compact_with_policy(&self, policy: CompactionPolicy) -> Result<usize> {
414        let all_checkpoints = self.storage.list_by_session(self.session_id)?;
415        
416        // Determine which checkpoints to keep
417        let ids_to_keep: std::collections::HashSet<CheckpointId> = match &policy {
418            CompactionPolicy::KeepRecent(n) => {
419                // Sort by sequence number, keep last N
420                let mut sorted = all_checkpoints.clone();
421                sorted.sort_by_key(|cp| cp.sequence_number);
422                sorted.iter().rev().take(*n).map(|cp| cp.id).collect()
423            }
424            CompactionPolicy::PreserveTagged(tags) => {
425                // Keep all checkpoints with any of the specified tags
426                all_checkpoints.iter()
427                    .filter(|cp| cp.tags.iter().any(|t| tags.contains(t)))
428                    .map(|cp| cp.id)
429                    .collect()
430            }
431            CompactionPolicy::Hybrid { keep_recent, preserve_tags } => {
432                // Keep recent + preserve tagged
433                let mut to_keep = std::collections::HashSet::new();
434                
435                // Add recent
436                let mut sorted = all_checkpoints.clone();
437                sorted.sort_by_key(|cp| cp.sequence_number);
438                for cp in sorted.iter().rev().take(*keep_recent) {
439                    to_keep.insert(cp.id);
440                }
441                
442                // Add tagged
443                for cp in &all_checkpoints {
444                    if cp.tags.iter().any(|t| preserve_tags.contains(t)) {
445                        to_keep.insert(cp.id);
446                    }
447                }
448                
449                to_keep
450            }
451        };
452        
453        // Delete checkpoints not in keep list
454        let mut deleted = 0;
455        for cp in &all_checkpoints {
456            if !ids_to_keep.contains(&cp.id) {
457                self.storage.delete(cp.id)?;
458                deleted += 1;
459            }
460        }
461        
462        Ok(deleted)
463    }
464
465    fn capture_state(&self) -> Result<DebugStateSnapshot> {
466        Ok(DebugStateSnapshot {
467            session_id: self.session_id,
468            started_at: Utc::now(),
469            checkpoint_timestamp: Utc::now(),
470            working_dir: std::env::current_dir().ok(),
471            env_vars: std::env::vars().collect(),
472            metrics: SessionMetrics::default(),
473            hypothesis_state: None, // Will be populated when hypothesis state is captured
474        })
475    }
476
477    fn update_last_checkpoint_time(&self) {
478        *self.last_checkpoint_time.borrow_mut() = Utc::now();
479    }
480}