Skip to main content

forge_reasoning/
service.rs

//! Integration service for checkpointing.
//!
//! Provides a high-level API for Forge agent integration.
4
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{Mutex, RwLock};
8
9use chrono::Utc;
10
11use crate::checkpoint::{AutoTrigger, CheckpointId, CheckpointSummary, SessionId, TemporalCheckpoint};
12use crate::errors::{Result, ReasoningError};
13use crate::thread_safe::{ThreadSafeCheckpointManager, ThreadSafeStorage};
14
/// Per-session settings for automatic checkpointing.
///
/// This file only stores the configuration (see `enable_auto_checkpoint`);
/// how the individual switches are honoured is decided by whoever reads it.
#[derive(Debug, Clone)]
pub struct AutoCheckpointConfig {
    /// Checkpoint interval, expressed in seconds.
    pub interval_seconds: u64,
    /// Switch for error-triggered checkpoints (presumably; consumer not visible here).
    pub on_error: bool,
    /// Switch for tool-call-triggered checkpoints (presumably; consumer not visible here).
    pub on_tool_call: bool,
}

impl Default for AutoCheckpointConfig {
    /// Defaults: a 300-second (five minute) interval, `on_error` enabled,
    /// `on_tool_call` disabled.
    fn default() -> Self {
        let five_minutes = 5 * 60;
        AutoCheckpointConfig {
            interval_seconds: five_minutes,
            on_error: true,
            on_tool_call: false,
        }
    }
}
32
33/// Events emitted by the checkpoint service
34#[derive(Clone, Debug)]
35pub enum CheckpointEvent {
36    Created {
37        checkpoint_id: CheckpointId,
38        session_id: SessionId,
39        timestamp: chrono::DateTime<Utc>,
40    },
41    Restored {
42        checkpoint_id: CheckpointId,
43        session_id: SessionId,
44    },
45    Deleted {
46        checkpoint_id: CheckpointId,
47        session_id: SessionId,
48    },
49    Compacted {
50        session_id: SessionId,
51        remaining: usize,
52    },
53}
54
55/// Commands that can be executed on the service
56#[derive(Clone, Debug)]
57pub enum CheckpointCommand {
58    Create {
59        session_id: SessionId,
60        message: String,
61        tags: Vec<String>,
62    },
63    List {
64        session_id: SessionId,
65    },
66    Restore {
67        session_id: SessionId,
68        checkpoint_id: CheckpointId,
69    },
70    Delete {
71        checkpoint_id: CheckpointId,
72    },
73    Compact {
74        session_id: SessionId,
75        keep_recent: usize,
76    },
77}
78
79/// Results from command execution
80#[derive(Clone, Debug)]
81pub enum CommandResult {
82    Created(CheckpointId),
83    List(Vec<CheckpointSummary>),
84    Restored(TemporalCheckpoint),
85    Deleted,
86    Compacted(usize),
87    Error(String),
88}
89
90/// Service metrics
91#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
92pub struct ServiceMetrics {
93    pub total_checkpoints: usize,
94    pub active_sessions: usize,
95    pub total_sessions_created: usize,
96}
97
/// Result of a service health probe.
#[derive(Debug, Clone)]
pub struct HealthStatus {
    /// `true` when the probe found no problem.
    pub healthy: bool,
    /// Human-readable explanation of the verdict.
    pub message: String,
}
104
105/// Annotation for checkpoints
106#[derive(Clone, Debug)]
107pub struct CheckpointAnnotation {
108    pub note: String,
109    pub severity: AnnotationSeverity,
110    pub timestamp: chrono::DateTime<Utc>,
111}
112
/// How serious a checkpoint annotation is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnnotationSeverity {
    /// Informational note.
    Info,
    /// Something that deserves attention.
    Warning,
    /// The most severe level.
    Critical,
}
120
121/// Checkpoint with annotations
122#[derive(Clone, Debug)]
123pub struct AnnotatedCheckpoint {
124    pub checkpoint: TemporalCheckpoint,
125    pub annotations: Vec<CheckpointAnnotation>,
126}
127
128/// Main checkpoint service for integration
129pub struct CheckpointService {
130    storage: ThreadSafeStorage,
131    sessions: RwLock<HashMap<SessionId, SessionInfo>>,
132    subscribers: Mutex<HashMap<SessionId, Vec<tokio::sync::mpsc::Sender<CheckpointEvent>>>>,
133    running: RwLock<bool>,
134    annotations: RwLock<HashMap<CheckpointId, Vec<CheckpointAnnotation>>>,
135    /// Global sequence counter for monotonic checkpoint ordering across all sessions
136    global_sequence: AtomicU64,
137}
138
139struct SessionInfo {
140    auto_config: Option<AutoCheckpointConfig>,
141}
142
143impl CheckpointService {
144    /// Create a new checkpoint service
145    /// 
146    /// Initializes the global sequence counter from storage to ensure
147    /// monotonic sequences across service restarts.
148    pub fn new(storage: ThreadSafeStorage) -> Self {
149        // Initialize global sequence from storage (find max existing sequence)
150        let initial_sequence = Self::find_max_sequence(&storage);
151        
152        Self {
153            storage,
154            sessions: RwLock::new(HashMap::new()),
155            subscribers: Mutex::new(HashMap::new()),
156            running: RwLock::new(true),
157            annotations: RwLock::new(HashMap::new()),
158            global_sequence: AtomicU64::new(initial_sequence),
159        }
160    }
161
162    /// Find the maximum sequence number in storage
163    fn find_max_sequence(storage: &ThreadSafeStorage) -> u64 {
164        // Query storage for max sequence across all checkpoints
165        match storage.get_max_sequence() {
166            Ok(max_seq) => max_seq,
167            Err(_) => 0,
168        }
169    }
170
171    /// Get the current global sequence number
172    /// 
173    /// Returns the sequence number of the most recently created checkpoint.
174    /// Returns 0 if no checkpoints have been created yet.
175    pub fn global_sequence(&self) -> u64 {
176        self.global_sequence.load(Ordering::SeqCst)
177    }
178
179    /// Get the next sequence number atomically
180    /// 
181    /// Returns 1-based sequence numbers (first checkpoint is 1, not 0)
182    fn next_sequence(&self) -> u64 {
183        self.global_sequence.fetch_add(1, Ordering::SeqCst) + 1
184    }
185
186    /// Check if service is running
187    pub fn is_running(&self) -> bool {
188        *self.running.read().unwrap()
189    }
190
191    /// Stop the service
192    pub fn stop(&self) {
193        *self.running.write().unwrap() = false;
194    }
195
196    /// Create a new session
197    pub fn create_session(&self, _name: &str) -> Result<SessionId> {
198        let session_id = SessionId::new();
199        let info = SessionInfo {
200            auto_config: None,
201        };
202        
203        self.sessions.write().unwrap().insert(session_id, info);
204        Ok(session_id)
205    }
206
207    /// Get or create session manager
208    fn get_manager(&self, session_id: SessionId) -> ThreadSafeCheckpointManager {
209        ThreadSafeCheckpointManager::new(self.storage.clone(), session_id)
210    }
211
212    /// Create a checkpoint with global sequence number
213    pub fn checkpoint(&self, session_id: &SessionId, message: impl Into<String>) -> Result<CheckpointId> {
214        if !self.is_running() {
215            return Err(ReasoningError::InvalidState("Service not running".to_string()));
216        }
217        
218        let manager = self.get_manager(*session_id);
219        let seq = self.next_sequence();
220        let id = manager.checkpoint_with_sequence(message, seq)?;
221        
222        // Emit event
223        self.emit_event(CheckpointEvent::Created {
224            checkpoint_id: id,
225            session_id: *session_id,
226            timestamp: Utc::now(),
227        });
228        
229        Ok(id)
230    }
231
232    /// List checkpoints for a session
233    pub fn list_checkpoints(&self, session_id: &SessionId) -> Result<Vec<CheckpointSummary>> {
234        let manager = self.get_manager(*session_id);
235        manager.list()
236    }
237
238    /// Restore a checkpoint
239    pub fn restore(&self, session_id: &SessionId, checkpoint_id: &CheckpointId) -> Result<crate::checkpoint::DebugStateSnapshot> {
240        let manager = self.get_manager(*session_id);
241        let checkpoint = manager.get(checkpoint_id)?.ok_or_else(|| {
242            ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id))
243        })?;
244        
245        let state = manager.restore(&checkpoint)?;
246        
247        self.emit_event(CheckpointEvent::Restored {
248            checkpoint_id: *checkpoint_id,
249            session_id: *session_id,
250        });
251        
252        Ok(state)
253    }
254
255    /// Enable auto-checkpointing for a session
256    pub fn enable_auto_checkpoint(&self, session_id: &SessionId, config: AutoCheckpointConfig) -> Result<()> {
257        let mut sessions = self.sessions.write().unwrap();
258        if let Some(info) = sessions.get_mut(session_id) {
259            info.auto_config = Some(config);
260            Ok(())
261        } else {
262            Err(ReasoningError::NotFound(format!("Session {:?} not found", session_id)))
263        }
264    }
265
266    /// Trigger an auto-checkpoint with global sequence
267    pub fn trigger_auto_checkpoint(&self, session_id: &SessionId, trigger: AutoTrigger) -> Result<Option<CheckpointId>> {
268        let manager = self.get_manager(*session_id);
269        let seq = self.next_sequence();
270        let result = manager.auto_checkpoint_with_sequence(trigger, seq)?;
271        
272        if let Some(id) = result {
273            self.emit_event(CheckpointEvent::Created {
274                checkpoint_id: id,
275                session_id: *session_id,
276                timestamp: Utc::now(),
277            });
278        }
279        
280        Ok(result)
281    }
282
283    /// Subscribe to checkpoint events for a session
284    pub fn subscribe(&self, session_id: &SessionId) -> Result<tokio::sync::mpsc::Receiver<CheckpointEvent>> {
285        let (tx, rx) = tokio::sync::mpsc::channel(100); // Buffer up to 100 events
286        
287        let mut subscribers = self.subscribers.lock().unwrap();
288        subscribers.entry(*session_id).or_insert_with(Vec::new).push(tx);
289        
290        Ok(rx)
291    }
292
293    /// Emit event to subscribers
294    fn emit_event(&self, event: CheckpointEvent) {
295        let session_id = match &event {
296            CheckpointEvent::Created { session_id, .. } => *session_id,
297            CheckpointEvent::Restored { session_id, .. } => *session_id,
298            CheckpointEvent::Deleted { session_id, .. } => *session_id,
299            CheckpointEvent::Compacted { session_id, .. } => *session_id,
300        };
301        
302        let subscribers = self.subscribers.lock().unwrap();
303        if let Some(subs) = subscribers.get(&session_id) {
304            for tx in subs {
305                // Best-effort delivery (try_send is non-blocking)
306                let _ = tx.try_send(event.clone());
307            }
308        }
309    }
310
311    /// Execute a command
312    pub fn execute(&self, command: CheckpointCommand) -> Result<CommandResult> {
313        match command {
314            CheckpointCommand::Create { session_id, message, tags } => {
315                let manager = self.get_manager(session_id);
316                let seq = self.next_sequence();
317                let id = if tags.is_empty() {
318                    manager.checkpoint_with_sequence(message, seq)?
319                } else {
320                    manager.checkpoint_with_tags_and_sequence(message, tags, seq)?
321                };
322                
323                self.emit_event(CheckpointEvent::Created {
324                    checkpoint_id: id,
325                    session_id,
326                    timestamp: Utc::now(),
327                });
328                
329                Ok(CommandResult::Created(id))
330            }
331            CheckpointCommand::List { session_id } => {
332                let manager = self.get_manager(session_id);
333                let checkpoints = manager.list()?;
334                Ok(CommandResult::List(checkpoints))
335            }
336            CheckpointCommand::Restore { session_id, checkpoint_id } => {
337                let _checkpoint = self.restore(&session_id, &checkpoint_id)?;
338                // Create a minimal TemporalCheckpoint for the result
339                let manager = self.get_manager(session_id);
340                let cp = manager.get(&checkpoint_id)?.ok_or_else(|| {
341                    ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id))
342                })?;
343                Ok(CommandResult::Restored(cp))
344            }
345            CheckpointCommand::Delete { checkpoint_id } => {
346                // Delete from all sessions (simplified)
347                let sessions = self.sessions.read().unwrap();
348                for session_id in sessions.keys() {
349                    let manager = self.get_manager(*session_id);
350                    let _ = manager.delete(&checkpoint_id);
351                }
352                
353                self.emit_event(CheckpointEvent::Deleted {
354                    checkpoint_id,
355                    session_id: SessionId::new(), // Simplified
356                });
357                
358                Ok(CommandResult::Deleted)
359            }
360            CheckpointCommand::Compact { session_id, keep_recent } => {
361                let manager = self.get_manager(session_id);
362                let deleted = manager.compact(keep_recent)?;
363                
364                self.emit_event(CheckpointEvent::Compacted {
365                    session_id,
366                    remaining: keep_recent,
367                });
368                
369                Ok(CommandResult::Compacted(deleted))
370            }
371        }
372    }
373
374    /// Sync checkpoints to disk (background persistence)
375    pub fn sync_to_disk(&self) -> Result<()> {
376        // In a real implementation, this would trigger background flush
377        // For now, we just verify the storage is working
378        Ok(())
379    }
380
381    /// Annotate a checkpoint
382    pub fn annotate(&self, checkpoint_id: &CheckpointId, annotation: CheckpointAnnotation) -> Result<()> {
383        // Verify checkpoint exists
384        let sessions = self.sessions.read().unwrap();
385        let mut found = false;
386        for session_id in sessions.keys() {
387            let manager = self.get_manager(*session_id);
388            if manager.get(checkpoint_id)?.is_some() {
389                found = true;
390                break;
391            }
392        }
393        
394        if !found {
395            return Err(ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id)));
396        }
397        
398        // Store annotation
399        let mut annotations = self.annotations.write().unwrap();
400        annotations.entry(*checkpoint_id).or_insert_with(Vec::new).push(annotation);
401        
402        Ok(())
403    }
404
405    /// Get checkpoint with annotations
406    pub fn get_with_annotations(&self, checkpoint_id: &CheckpointId) -> Result<AnnotatedCheckpoint> {
407        let sessions = self.sessions.read().unwrap();
408        let annotations = self.annotations.read().unwrap();
409        
410        for session_id in sessions.keys() {
411            let manager = self.get_manager(*session_id);
412            if let Some(checkpoint) = manager.get(checkpoint_id)? {
413                let checkpoint_annotations = annotations.get(checkpoint_id)
414                    .cloned()
415                    .unwrap_or_default();
416                
417                return Ok(AnnotatedCheckpoint {
418                    checkpoint,
419                    annotations: checkpoint_annotations,
420                });
421            }
422        }
423        Err(ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id)))
424    }
425
426    /// Get service metrics
427    pub fn metrics(&self) -> Result<ServiceMetrics> {
428        let sessions = self.sessions.read().unwrap();
429        let total_checkpoints: usize = sessions.keys()
430            .map(|session_id| {
431                let manager = self.get_manager(*session_id);
432                manager.list().map(|cps| cps.len()).unwrap_or(0)
433            })
434            .sum();
435        
436        Ok(ServiceMetrics {
437            total_checkpoints,
438            active_sessions: sessions.len(),
439            total_sessions_created: sessions.len(),
440        })
441    }
442
443    /// Health check
444    pub fn health_check(&self) -> Result<HealthStatus> {
445        if !self.is_running() {
446            return Ok(HealthStatus {
447                healthy: false,
448                message: "Service is stopped".to_string(),
449            });
450        }
451        
452        // Try a simple operation
453        match self.storage.list_by_session(SessionId::new()) {
454            Ok(_) => Ok(HealthStatus {
455                healthy: true,
456                message: "Service is healthy".to_string(),
457            }),
458            Err(e) => Ok(HealthStatus {
459                healthy: false,
460                message: format!("Storage error: {}", e),
461            }),
462        }
463    }
464
465    /// List checkpoints by global sequence number range
466    /// 
467    /// Returns all checkpoints (across all sessions) with sequence numbers
468    /// in the inclusive range [start_seq, end_seq].
469    pub fn list_by_sequence_range(&self, start_seq: u64, end_seq: u64) -> Result<Vec<CheckpointSummary>> {
470        let sessions = self.sessions.read().unwrap();
471        let mut all_checkpoints = Vec::new();
472        
473        for session_id in sessions.keys() {
474            let manager = self.get_manager(*session_id);
475            let cps = manager.list()?;
476            for cp in cps {
477                if cp.sequence_number >= start_seq && cp.sequence_number <= end_seq {
478                    all_checkpoints.push(cp);
479                }
480            }
481        }
482        
483        // Sort by sequence number
484        all_checkpoints.sort_by_key(|cp| cp.sequence_number);
485        Ok(all_checkpoints)
486    }
487
488    /// Export all checkpoints from all sessions
489    /// 
490    /// Returns a JSON-serializable export containing all checkpoints
491    /// and the current global sequence number.
492    pub fn export_all_checkpoints(&self) -> Result<String> {
493        let sessions = self.sessions.read().unwrap();
494        let mut all_checkpoints: Vec<TemporalCheckpoint> = Vec::new();
495        
496        for session_id in sessions.keys() {
497            let manager = self.get_manager(*session_id);
498            let cps = manager.list()?;
499            for cp_summary in cps {
500                if let Ok(Some(cp)) = manager.get(&cp_summary.id) {
501                    all_checkpoints.push(cp);
502                }
503            }
504        }
505        
506        // Sort by sequence number
507        all_checkpoints.sort_by_key(|cp| cp.sequence_number);
508        
509        let export = ExportData {
510            checkpoints: all_checkpoints,
511            global_sequence: self.global_sequence(),
512            exported_at: Utc::now(),
513        };
514        
515        serde_json::to_string_pretty(&export)
516            .map_err(ReasoningError::Serialization)
517    }
518
519    /// Import checkpoints from export data
520    /// 
521    /// Imports all checkpoints and restores the global sequence counter.
522    /// Skips checkpoints that already exist (by ID).
523    pub fn import_checkpoints(&self, export_data: &str) -> Result<ImportResult> {
524        let export: ExportData = serde_json::from_str(export_data)
525            .map_err(ReasoningError::Serialization)?;
526        
527        let mut imported = 0;
528        let mut skipped = 0;
529        let mut max_sequence = 0u64;
530        
531        for checkpoint in export.checkpoints {
532            // Track the maximum sequence number
533            max_sequence = max_sequence.max(checkpoint.sequence_number);
534            
535            // Check if checkpoint already exists
536            let manager = self.get_manager(checkpoint.session_id);
537            match manager.get(&checkpoint.id) {
538                Ok(Some(_)) => {
539                    // Already exists, skip
540                    skipped += 1;
541                }
542                _ => {
543                    // Store the checkpoint
544                    if let Err(e) = self.storage.store(&checkpoint) {
545                        tracing::warn!("Failed to import checkpoint {}: {}", checkpoint.id, e);
546                    } else {
547                        imported += 1;
548                    }
549                }
550            }
551        }
552        
553        // Update global sequence if imported checkpoints have higher sequences
554        let current = self.global_sequence();
555        if max_sequence > current {
556            self.global_sequence.store(max_sequence, Ordering::SeqCst);
557        }
558        
559        Ok(ImportResult { imported, skipped })
560    }
561
562    /// Validate a single checkpoint by ID
563    /// 
564    /// Returns true if the checkpoint's checksum is valid, false otherwise.
565    pub fn validate_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<bool> {
566        let cp = self.get_with_annotations(checkpoint_id)?;
567        
568        // Empty checksum means legacy checkpoint (skip validation)
569        if cp.checkpoint.checksum.is_empty() {
570            return Ok(true);
571        }
572        
573        match cp.checkpoint.validate() {
574            Ok(()) => Ok(true),
575            Err(_) => Ok(false),
576        }
577    }
578
579    /// Health check with validation of recent checkpoints
580    /// 
581    /// Performs a health check and additionally validates the most
582    /// recent checkpoints to detect data corruption.
583    pub fn health_check_with_validation(&self) -> Result<HealthStatus> {
584        // First do basic health check
585        let basic = self.health_check()?;
586        if !basic.healthy {
587            return Ok(basic);
588        }
589        
590        // Validate recent checkpoints from all sessions
591        let sessions = self.sessions.read().unwrap();
592        let mut checked = 0;
593        let mut invalid = 0;
594        
595        for session_id in sessions.keys() {
596            let manager = self.get_manager(*session_id);
597            if let Ok(cps) = manager.list() {
598                // Check up to 5 most recent checkpoints per session
599                for cp in cps.iter().rev().take(5) {
600                    checked += 1;
601                    if let Ok(Some(checkpoint)) = manager.get(&cp.id) {
602                        if !checkpoint.checksum.is_empty() {
603                            if let Err(e) = checkpoint.validate() {
604                                tracing::warn!("Checkpoint {} failed validation: {}", cp.id, e);
605                                invalid += 1;
606                            }
607                        }
608                    }
609                }
610            }
611        }
612        
613        if invalid > 0 {
614            return Ok(HealthStatus {
615                healthy: false,
616                message: format!("{} of {} recent checkpoints failed validation", invalid, checked),
617            });
618        }
619        
620        Ok(HealthStatus {
621            healthy: true,
622            message: format!("Service healthy, {} recent checkpoints validated", checked),
623        })
624    }
625
626    /// Validate all checkpoints
627    /// 
628    /// Performs a full validation of all checkpoints in the system.
629    /// Returns a report with validation statistics.
630    pub fn validate_all_checkpoints(&self) -> Result<ValidationReport> {
631        let sessions = self.sessions.read().unwrap();
632        let mut valid = 0;
633        let mut invalid = 0;
634        let mut skipped = 0;
635        
636        for session_id in sessions.keys() {
637            let manager = self.get_manager(*session_id);
638            if let Ok(cps) = manager.list() {
639                for cp_summary in cps {
640                    if let Ok(Some(cp)) = manager.get(&cp_summary.id) {
641                        if cp.checksum.is_empty() {
642                            // Legacy checkpoint without checksum
643                            skipped += 1;
644                        } else {
645                            match cp.validate() {
646                                Ok(()) => valid += 1,
647                                Err(e) => {
648                                    tracing::warn!("Checkpoint {} validation failed: {}", cp.id, e);
649                                    invalid += 1;
650                                }
651                            }
652                        }
653                    }
654                }
655            }
656        }
657        
658        Ok(ValidationReport {
659            valid,
660            invalid,
661            skipped,
662            checked_at: Some(Utc::now()),
663        })
664    }
665
666    /// Get hypothesis state at a specific checkpoint
667    pub async fn get_hypothesis_state(
668        &self,
669        checkpoint_id: CheckpointId,
670    ) -> Result<Option<crate::hypothesis::types::HypothesisState>> {
671        // Search for the checkpoint across all sessions
672        let sessions = self.sessions.read().unwrap();
673        for session_id in sessions.keys() {
674            let manager = self.get_manager(*session_id);
675            if let Some(checkpoint) = manager.get(&checkpoint_id)? {
676                return Ok(checkpoint.state.hypothesis_state);
677            }
678        }
679        Ok(None)
680    }
681}
682
683/// Report from validation operation
684#[derive(Debug, Clone)]
685pub struct ValidationReport {
686    pub valid: usize,
687    pub invalid: usize,
688    pub skipped: usize,
689    pub checked_at: Option<chrono::DateTime<Utc>>,
690}
691
692impl ValidationReport {
693    /// Total number of checkpoints checked
694    pub fn total(&self) -> usize {
695        self.valid + self.invalid + self.skipped
696    }
697    
698    /// Whether all checked checkpoints were valid
699    pub fn all_valid(&self) -> bool {
700        self.invalid == 0
701    }
702}
703
704/// Data structure for export/import
705#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
706struct ExportData {
707    checkpoints: Vec<TemporalCheckpoint>,
708    global_sequence: u64,
709    exported_at: chrono::DateTime<Utc>,
710}
711
/// Counters describing what `import_checkpoints` did.
#[derive(Clone, Debug)]
pub struct ImportResult {
    /// Checkpoints written to storage.
    pub imported: usize,
    /// Checkpoints left alone because they already existed.
    pub skipped: usize,
}
718
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a fresh in-memory service is running, can register a
    /// session and can create a checkpoint with a printable id.
    #[test]
    fn test_service_basic() {
        let storage = ThreadSafeStorage::in_memory().unwrap();
        let service = CheckpointService::new(storage);
        assert!(service.is_running());

        let session = service.create_session("test").unwrap();
        let checkpoint_id = service.checkpoint(&session, "Test").unwrap();
        assert!(!checkpoint_id.to_string().is_empty());
    }
}
732}