1use std::collections::HashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{Mutex, RwLock};
8
9use chrono::Utc;
10
11use crate::checkpoint::{AutoTrigger, CheckpointId, CheckpointSummary, SessionId, TemporalCheckpoint};
12use crate::errors::{Result, ReasoningError};
13use crate::thread_safe::{ThreadSafeCheckpointManager, ThreadSafeStorage};
14
/// Per-session settings for automatic checkpointing.
///
/// NOTE(review): none of these fields is read in this file; the per-field
/// descriptions below are inferred from the field names and should be confirmed
/// against the code that consumes the config.
#[derive(Clone, Debug)]
pub struct AutoCheckpointConfig {
    /// Presumably the interval between automatic checkpoints, in seconds.
    pub interval_seconds: u64,
    /// Presumably enables an automatic checkpoint when an error occurs.
    pub on_error: bool,
    /// Presumably enables an automatic checkpoint on every tool call.
    pub on_tool_call: bool,
}

impl Default for AutoCheckpointConfig {
    /// Returns the default configuration: `interval_seconds = 300`,
    /// `on_error = true`, `on_tool_call = false`.
    fn default() -> Self {
        let interval_seconds = 300;
        let on_error = true;
        let on_tool_call = false;

        AutoCheckpointConfig {
            interval_seconds,
            on_error,
            on_tool_call,
        }
    }
}
32
/// Notification sent to the subscribers of a session.
///
/// Every variant carries the `session_id` that `CheckpointService::emit_event`
/// uses to pick the subscriber list the event is delivered to.
#[derive(Clone, Debug)]
pub enum CheckpointEvent {
    /// A checkpoint was created.
    Created {
        /// Identifier of the new checkpoint.
        checkpoint_id: CheckpointId,
        /// Session the event is routed to.
        session_id: SessionId,
        /// UTC time captured when the event was built.
        timestamp: chrono::DateTime<Utc>,
    },
    /// A checkpoint was restored.
    Restored {
        /// Identifier of the restored checkpoint.
        checkpoint_id: CheckpointId,
        /// Session the event is routed to.
        session_id: SessionId,
    },
    /// A checkpoint deletion was requested.
    Deleted {
        /// Identifier of the deleted checkpoint.
        checkpoint_id: CheckpointId,
        /// Session the event is routed to.
        session_id: SessionId,
    },
    /// A session's checkpoints were compacted.
    Compacted {
        /// Session the event is routed to.
        session_id: SessionId,
        /// Count reported by the emitter as left after compaction.
        remaining: usize,
    },
}
54
/// Command accepted by `CheckpointService::execute`.
#[derive(Clone, Debug)]
pub enum CheckpointCommand {
    /// Create a checkpoint in `session_id`.
    Create {
        /// Session that receives the checkpoint.
        session_id: SessionId,
        /// Message stored with the checkpoint.
        message: String,
        /// Tags for the checkpoint; when empty, `execute` uses the untagged
        /// creation path.
        tags: Vec<String>,
    },
    /// List the checkpoints of `session_id`.
    List {
        /// Session whose checkpoints are listed.
        session_id: SessionId,
    },
    /// Restore `checkpoint_id` within `session_id`.
    Restore {
        /// Session the restore runs against.
        session_id: SessionId,
        /// Checkpoint to restore.
        checkpoint_id: CheckpointId,
    },
    /// Delete `checkpoint_id`; no session is named, so `execute` tries every
    /// registered session.
    Delete {
        /// Checkpoint to delete.
        checkpoint_id: CheckpointId,
    },
    /// Compact the checkpoints of `session_id`.
    Compact {
        /// Session to compact.
        session_id: SessionId,
        /// Value passed through to the manager's `compact` call.
        keep_recent: usize,
    },
}
78
/// Outcome of a `CheckpointCommand` run through `CheckpointService::execute`.
#[derive(Clone, Debug)]
pub enum CommandResult {
    /// Identifier of the checkpoint that was created.
    Created(CheckpointId),
    /// Summaries returned by the session's listing.
    List(Vec<CheckpointSummary>),
    /// The checkpoint that was restored.
    Restored(TemporalCheckpoint),
    /// A delete command completed.
    Deleted,
    /// Value returned by the manager's `compact` call (bound to `deleted` in
    /// `execute`).
    Compacted(usize),
    /// Error message. NOTE(review): never constructed in this file; `execute`
    /// reports failures through `Err(..)` instead.
    Error(String),
}
89
/// Counters reported by `CheckpointService::metrics`.
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
pub struct ServiceMetrics {
    /// Sum of the checkpoint-list lengths over all registered sessions.
    pub total_checkpoints: usize,
    /// Number of sessions currently registered with the service.
    pub active_sessions: usize,
    /// NOTE(review): `metrics` fills this with the same value as
    /// `active_sessions` (the current size of the session map).
    pub total_sessions_created: usize,
}
97
/// Result of a service health check.
#[derive(Clone, Debug)]
pub struct HealthStatus {
    /// `true` when the check found no problem.
    pub healthy: bool,
    /// Human-readable description of the check outcome.
    pub message: String,
}
104
/// A note attached to a checkpoint through `CheckpointService::annotate`.
#[derive(Clone, Debug)]
pub struct CheckpointAnnotation {
    /// Free-form annotation text.
    pub note: String,
    /// Severity assigned to the note.
    pub severity: AnnotationSeverity,
    /// UTC timestamp supplied by whoever builds the annotation.
    pub timestamp: chrono::DateTime<Utc>,
}
112
/// Severity level carried by a `CheckpointAnnotation`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AnnotationSeverity {
    /// Informational note.
    Info,
    /// Warning-level note.
    Warning,
    /// Critical-level note.
    Critical,
}
120
/// A checkpoint bundled with the annotations recorded for it, as returned by
/// `CheckpointService::get_with_annotations`.
#[derive(Clone, Debug)]
pub struct AnnotatedCheckpoint {
    /// The checkpoint itself.
    pub checkpoint: TemporalCheckpoint,
    /// Annotations stored under the checkpoint's id (empty when there are none).
    pub annotations: Vec<CheckpointAnnotation>,
}
127
/// Service that creates, lists, restores and annotates checkpoints on top of a
/// shared storage handle, and forwards `CheckpointEvent`s to subscribers.
pub struct CheckpointService {
    /// Storage handle; cloned into a per-session manager on each operation.
    storage: ThreadSafeStorage,
    /// Sessions registered through `create_session`.
    sessions: RwLock<HashMap<SessionId, SessionInfo>>,
    /// Event senders registered through `subscribe`, keyed by session.
    subscribers: Mutex<HashMap<SessionId, Vec<tokio::sync::mpsc::Sender<CheckpointEvent>>>>,
    /// `true` until `stop` is called.
    running: RwLock<bool>,
    /// Annotations recorded through `annotate`, keyed by checkpoint id.
    annotations: RwLock<HashMap<CheckpointId, Vec<CheckpointAnnotation>>>,
    /// Last sequence number handed out; shared by all sessions.
    global_sequence: AtomicU64,
}
138
/// Per-session bookkeeping kept by `CheckpointService`.
struct SessionInfo {
    /// Auto-checkpoint settings; `None` until `enable_auto_checkpoint` sets them.
    auto_config: Option<AutoCheckpointConfig>,
}
142
143impl CheckpointService {
144 pub fn new(storage: ThreadSafeStorage) -> Self {
149 let initial_sequence = Self::find_max_sequence(&storage);
151
152 Self {
153 storage,
154 sessions: RwLock::new(HashMap::new()),
155 subscribers: Mutex::new(HashMap::new()),
156 running: RwLock::new(true),
157 annotations: RwLock::new(HashMap::new()),
158 global_sequence: AtomicU64::new(initial_sequence),
159 }
160 }
161
162 fn find_max_sequence(storage: &ThreadSafeStorage) -> u64 {
164 match storage.get_max_sequence() {
166 Ok(max_seq) => max_seq,
167 Err(_) => 0,
168 }
169 }
170
171 pub fn global_sequence(&self) -> u64 {
176 self.global_sequence.load(Ordering::SeqCst)
177 }
178
179 fn next_sequence(&self) -> u64 {
183 self.global_sequence.fetch_add(1, Ordering::SeqCst) + 1
184 }
185
186 pub fn is_running(&self) -> bool {
188 *self.running.read().unwrap()
189 }
190
191 pub fn stop(&self) {
193 *self.running.write().unwrap() = false;
194 }
195
196 pub fn create_session(&self, _name: &str) -> Result<SessionId> {
198 let session_id = SessionId::new();
199 let info = SessionInfo {
200 auto_config: None,
201 };
202
203 self.sessions.write().unwrap().insert(session_id, info);
204 Ok(session_id)
205 }
206
207 fn get_manager(&self, session_id: SessionId) -> ThreadSafeCheckpointManager {
209 ThreadSafeCheckpointManager::new(self.storage.clone(), session_id)
210 }
211
212 pub fn checkpoint(&self, session_id: &SessionId, message: impl Into<String>) -> Result<CheckpointId> {
214 if !self.is_running() {
215 return Err(ReasoningError::InvalidState("Service not running".to_string()));
216 }
217
218 let manager = self.get_manager(*session_id);
219 let seq = self.next_sequence();
220 let id = manager.checkpoint_with_sequence(message, seq)?;
221
222 self.emit_event(CheckpointEvent::Created {
224 checkpoint_id: id,
225 session_id: *session_id,
226 timestamp: Utc::now(),
227 });
228
229 Ok(id)
230 }
231
232 pub fn list_checkpoints(&self, session_id: &SessionId) -> Result<Vec<CheckpointSummary>> {
234 let manager = self.get_manager(*session_id);
235 manager.list()
236 }
237
238 pub fn restore(&self, session_id: &SessionId, checkpoint_id: &CheckpointId) -> Result<crate::checkpoint::DebugStateSnapshot> {
240 let manager = self.get_manager(*session_id);
241 let checkpoint = manager.get(checkpoint_id)?.ok_or_else(|| {
242 ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id))
243 })?;
244
245 let state = manager.restore(&checkpoint)?;
246
247 self.emit_event(CheckpointEvent::Restored {
248 checkpoint_id: *checkpoint_id,
249 session_id: *session_id,
250 });
251
252 Ok(state)
253 }
254
255 pub fn enable_auto_checkpoint(&self, session_id: &SessionId, config: AutoCheckpointConfig) -> Result<()> {
257 let mut sessions = self.sessions.write().unwrap();
258 if let Some(info) = sessions.get_mut(session_id) {
259 info.auto_config = Some(config);
260 Ok(())
261 } else {
262 Err(ReasoningError::NotFound(format!("Session {:?} not found", session_id)))
263 }
264 }
265
266 pub fn trigger_auto_checkpoint(&self, session_id: &SessionId, trigger: AutoTrigger) -> Result<Option<CheckpointId>> {
268 let manager = self.get_manager(*session_id);
269 let seq = self.next_sequence();
270 let result = manager.auto_checkpoint_with_sequence(trigger, seq)?;
271
272 if let Some(id) = result {
273 self.emit_event(CheckpointEvent::Created {
274 checkpoint_id: id,
275 session_id: *session_id,
276 timestamp: Utc::now(),
277 });
278 }
279
280 Ok(result)
281 }
282
283 pub fn subscribe(&self, session_id: &SessionId) -> Result<tokio::sync::mpsc::Receiver<CheckpointEvent>> {
285 let (tx, rx) = tokio::sync::mpsc::channel(100); let mut subscribers = self.subscribers.lock().unwrap();
288 subscribers.entry(*session_id).or_insert_with(Vec::new).push(tx);
289
290 Ok(rx)
291 }
292
293 fn emit_event(&self, event: CheckpointEvent) {
295 let session_id = match &event {
296 CheckpointEvent::Created { session_id, .. } => *session_id,
297 CheckpointEvent::Restored { session_id, .. } => *session_id,
298 CheckpointEvent::Deleted { session_id, .. } => *session_id,
299 CheckpointEvent::Compacted { session_id, .. } => *session_id,
300 };
301
302 let subscribers = self.subscribers.lock().unwrap();
303 if let Some(subs) = subscribers.get(&session_id) {
304 for tx in subs {
305 let _ = tx.try_send(event.clone());
307 }
308 }
309 }
310
311 pub fn execute(&self, command: CheckpointCommand) -> Result<CommandResult> {
313 match command {
314 CheckpointCommand::Create { session_id, message, tags } => {
315 let manager = self.get_manager(session_id);
316 let seq = self.next_sequence();
317 let id = if tags.is_empty() {
318 manager.checkpoint_with_sequence(message, seq)?
319 } else {
320 manager.checkpoint_with_tags_and_sequence(message, tags, seq)?
321 };
322
323 self.emit_event(CheckpointEvent::Created {
324 checkpoint_id: id,
325 session_id,
326 timestamp: Utc::now(),
327 });
328
329 Ok(CommandResult::Created(id))
330 }
331 CheckpointCommand::List { session_id } => {
332 let manager = self.get_manager(session_id);
333 let checkpoints = manager.list()?;
334 Ok(CommandResult::List(checkpoints))
335 }
336 CheckpointCommand::Restore { session_id, checkpoint_id } => {
337 let _checkpoint = self.restore(&session_id, &checkpoint_id)?;
338 let manager = self.get_manager(session_id);
340 let cp = manager.get(&checkpoint_id)?.ok_or_else(|| {
341 ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id))
342 })?;
343 Ok(CommandResult::Restored(cp))
344 }
345 CheckpointCommand::Delete { checkpoint_id } => {
346 let sessions = self.sessions.read().unwrap();
348 for session_id in sessions.keys() {
349 let manager = self.get_manager(*session_id);
350 let _ = manager.delete(&checkpoint_id);
351 }
352
353 self.emit_event(CheckpointEvent::Deleted {
354 checkpoint_id,
355 session_id: SessionId::new(), });
357
358 Ok(CommandResult::Deleted)
359 }
360 CheckpointCommand::Compact { session_id, keep_recent } => {
361 let manager = self.get_manager(session_id);
362 let deleted = manager.compact(keep_recent)?;
363
364 self.emit_event(CheckpointEvent::Compacted {
365 session_id,
366 remaining: keep_recent,
367 });
368
369 Ok(CommandResult::Compacted(deleted))
370 }
371 }
372 }
373
    /// Placeholder for persisting service state (going by the name).
    ///
    /// The body currently does no work and always returns `Ok(())`.
    pub fn sync_to_disk(&self) -> Result<()> {
        Ok(())
    }
380
381 pub fn annotate(&self, checkpoint_id: &CheckpointId, annotation: CheckpointAnnotation) -> Result<()> {
383 let sessions = self.sessions.read().unwrap();
385 let mut found = false;
386 for session_id in sessions.keys() {
387 let manager = self.get_manager(*session_id);
388 if manager.get(checkpoint_id)?.is_some() {
389 found = true;
390 break;
391 }
392 }
393
394 if !found {
395 return Err(ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id)));
396 }
397
398 let mut annotations = self.annotations.write().unwrap();
400 annotations.entry(*checkpoint_id).or_insert_with(Vec::new).push(annotation);
401
402 Ok(())
403 }
404
405 pub fn get_with_annotations(&self, checkpoint_id: &CheckpointId) -> Result<AnnotatedCheckpoint> {
407 let sessions = self.sessions.read().unwrap();
408 let annotations = self.annotations.read().unwrap();
409
410 for session_id in sessions.keys() {
411 let manager = self.get_manager(*session_id);
412 if let Some(checkpoint) = manager.get(checkpoint_id)? {
413 let checkpoint_annotations = annotations.get(checkpoint_id)
414 .cloned()
415 .unwrap_or_default();
416
417 return Ok(AnnotatedCheckpoint {
418 checkpoint,
419 annotations: checkpoint_annotations,
420 });
421 }
422 }
423 Err(ReasoningError::NotFound(format!("Checkpoint {} not found", checkpoint_id)))
424 }
425
426 pub fn metrics(&self) -> Result<ServiceMetrics> {
428 let sessions = self.sessions.read().unwrap();
429 let total_checkpoints: usize = sessions.keys()
430 .map(|session_id| {
431 let manager = self.get_manager(*session_id);
432 manager.list().map(|cps| cps.len()).unwrap_or(0)
433 })
434 .sum();
435
436 Ok(ServiceMetrics {
437 total_checkpoints,
438 active_sessions: sessions.len(),
439 total_sessions_created: sessions.len(),
440 })
441 }
442
443 pub fn health_check(&self) -> Result<HealthStatus> {
445 if !self.is_running() {
446 return Ok(HealthStatus {
447 healthy: false,
448 message: "Service is stopped".to_string(),
449 });
450 }
451
452 match self.storage.list_by_session(SessionId::new()) {
454 Ok(_) => Ok(HealthStatus {
455 healthy: true,
456 message: "Service is healthy".to_string(),
457 }),
458 Err(e) => Ok(HealthStatus {
459 healthy: false,
460 message: format!("Storage error: {}", e),
461 }),
462 }
463 }
464
465 pub fn list_by_sequence_range(&self, start_seq: u64, end_seq: u64) -> Result<Vec<CheckpointSummary>> {
470 let sessions = self.sessions.read().unwrap();
471 let mut all_checkpoints = Vec::new();
472
473 for session_id in sessions.keys() {
474 let manager = self.get_manager(*session_id);
475 let cps = manager.list()?;
476 for cp in cps {
477 if cp.sequence_number >= start_seq && cp.sequence_number <= end_seq {
478 all_checkpoints.push(cp);
479 }
480 }
481 }
482
483 all_checkpoints.sort_by_key(|cp| cp.sequence_number);
485 Ok(all_checkpoints)
486 }
487
488 pub fn export_all_checkpoints(&self) -> Result<String> {
493 let sessions = self.sessions.read().unwrap();
494 let mut all_checkpoints: Vec<TemporalCheckpoint> = Vec::new();
495
496 for session_id in sessions.keys() {
497 let manager = self.get_manager(*session_id);
498 let cps = manager.list()?;
499 for cp_summary in cps {
500 if let Ok(Some(cp)) = manager.get(&cp_summary.id) {
501 all_checkpoints.push(cp);
502 }
503 }
504 }
505
506 all_checkpoints.sort_by_key(|cp| cp.sequence_number);
508
509 let export = ExportData {
510 checkpoints: all_checkpoints,
511 global_sequence: self.global_sequence(),
512 exported_at: Utc::now(),
513 };
514
515 serde_json::to_string_pretty(&export)
516 .map_err(ReasoningError::Serialization)
517 }
518
519 pub fn import_checkpoints(&self, export_data: &str) -> Result<ImportResult> {
524 let export: ExportData = serde_json::from_str(export_data)
525 .map_err(ReasoningError::Serialization)?;
526
527 let mut imported = 0;
528 let mut skipped = 0;
529 let mut max_sequence = 0u64;
530
531 for checkpoint in export.checkpoints {
532 max_sequence = max_sequence.max(checkpoint.sequence_number);
534
535 let manager = self.get_manager(checkpoint.session_id);
537 match manager.get(&checkpoint.id) {
538 Ok(Some(_)) => {
539 skipped += 1;
541 }
542 _ => {
543 if let Err(e) = self.storage.store(&checkpoint) {
545 tracing::warn!("Failed to import checkpoint {}: {}", checkpoint.id, e);
546 } else {
547 imported += 1;
548 }
549 }
550 }
551 }
552
553 let current = self.global_sequence();
555 if max_sequence > current {
556 self.global_sequence.store(max_sequence, Ordering::SeqCst);
557 }
558
559 Ok(ImportResult { imported, skipped })
560 }
561
562 pub fn validate_checkpoint(&self, checkpoint_id: &CheckpointId) -> Result<bool> {
566 let cp = self.get_with_annotations(checkpoint_id)?;
567
568 if cp.checkpoint.checksum.is_empty() {
570 return Ok(true);
571 }
572
573 match cp.checkpoint.validate() {
574 Ok(()) => Ok(true),
575 Err(_) => Ok(false),
576 }
577 }
578
579 pub fn health_check_with_validation(&self) -> Result<HealthStatus> {
584 let basic = self.health_check()?;
586 if !basic.healthy {
587 return Ok(basic);
588 }
589
590 let sessions = self.sessions.read().unwrap();
592 let mut checked = 0;
593 let mut invalid = 0;
594
595 for session_id in sessions.keys() {
596 let manager = self.get_manager(*session_id);
597 if let Ok(cps) = manager.list() {
598 for cp in cps.iter().rev().take(5) {
600 checked += 1;
601 if let Ok(Some(checkpoint)) = manager.get(&cp.id) {
602 if !checkpoint.checksum.is_empty() {
603 if let Err(e) = checkpoint.validate() {
604 tracing::warn!("Checkpoint {} failed validation: {}", cp.id, e);
605 invalid += 1;
606 }
607 }
608 }
609 }
610 }
611 }
612
613 if invalid > 0 {
614 return Ok(HealthStatus {
615 healthy: false,
616 message: format!("{} of {} recent checkpoints failed validation", invalid, checked),
617 });
618 }
619
620 Ok(HealthStatus {
621 healthy: true,
622 message: format!("Service healthy, {} recent checkpoints validated", checked),
623 })
624 }
625
626 pub fn validate_all_checkpoints(&self) -> Result<ValidationReport> {
631 let sessions = self.sessions.read().unwrap();
632 let mut valid = 0;
633 let mut invalid = 0;
634 let mut skipped = 0;
635
636 for session_id in sessions.keys() {
637 let manager = self.get_manager(*session_id);
638 if let Ok(cps) = manager.list() {
639 for cp_summary in cps {
640 if let Ok(Some(cp)) = manager.get(&cp_summary.id) {
641 if cp.checksum.is_empty() {
642 skipped += 1;
644 } else {
645 match cp.validate() {
646 Ok(()) => valid += 1,
647 Err(e) => {
648 tracing::warn!("Checkpoint {} validation failed: {}", cp.id, e);
649 invalid += 1;
650 }
651 }
652 }
653 }
654 }
655 }
656 }
657
658 Ok(ValidationReport {
659 valid,
660 invalid,
661 skipped,
662 checked_at: Some(Utc::now()),
663 })
664 }
665
666 pub async fn get_hypothesis_state(
668 &self,
669 checkpoint_id: CheckpointId,
670 ) -> Result<Option<crate::hypothesis::types::HypothesisState>> {
671 let sessions = self.sessions.read().unwrap();
673 for session_id in sessions.keys() {
674 let manager = self.get_manager(*session_id);
675 if let Some(checkpoint) = manager.get(&checkpoint_id)? {
676 return Ok(checkpoint.state.hypothesis_state);
677 }
678 }
679 Ok(None)
680 }
681}
682
/// Tallies produced by `CheckpointService::validate_all_checkpoints`.
#[derive(Debug, Clone)]
pub struct ValidationReport {
    /// Checkpoints whose validation succeeded.
    pub valid: usize,
    /// Checkpoints whose validation failed.
    pub invalid: usize,
    /// Checkpoints not validated because their checksum was empty.
    pub skipped: usize,
    /// UTC time at which the report was built, when recorded.
    pub checked_at: Option<chrono::DateTime<Utc>>,
}
691
692impl ValidationReport {
693 pub fn total(&self) -> usize {
695 self.valid + self.invalid + self.skipped
696 }
697
698 pub fn all_valid(&self) -> bool {
700 self.invalid == 0
701 }
702}
703
/// JSON payload written by `export_all_checkpoints` and read by
/// `import_checkpoints`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct ExportData {
    /// Exported checkpoints (sorted by sequence number on export).
    checkpoints: Vec<TemporalCheckpoint>,
    /// Value of the service's global sequence counter at export time.
    global_sequence: u64,
    /// UTC time at which the export was produced.
    exported_at: chrono::DateTime<Utc>,
}
711
/// Counters returned by `CheckpointService::import_checkpoints`.
#[derive(Debug, Clone)]
pub struct ImportResult {
    /// Checkpoints written to storage by the import.
    pub imported: usize,
    /// Checkpoints left untouched because they already existed.
    pub skipped: usize,
}
718
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a fresh in-memory service is running, can register a session
    /// and can create a checkpoint whose id renders to a non-empty string.
    #[test]
    fn test_service_basic() {
        let storage = ThreadSafeStorage::in_memory().unwrap();
        let service = CheckpointService::new(storage);
        assert!(service.is_running());

        let session_id = service.create_session("test").unwrap();
        let checkpoint_id = service.checkpoint(&session_id, "Test").unwrap();

        let rendered = checkpoint_id.to_string();
        assert!(!rendered.is_empty());
    }
}