Skip to main content

forge_reasoning/
thread_safe.rs

//! Thread-safe implementations for concurrent checkpointing.
//!
//! Provides `Arc<Mutex<_>>`-based wrappers for checkpoint storage and the
//! checkpoint manager so checkpoint operations can be issued from multiple threads.
5
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8
9use chrono::Utc;
10
11use crate::checkpoint::{
12    CheckpointId, CheckpointSummary, CompactionPolicy, DebugStateSnapshot, 
13    SessionId, TemporalCheckpoint
14};
15use crate::errors::Result;
16use crate::storage::CheckpointStorage;
17use crate::SqliteGraphStorage;
18
19/// Thread-safe wrapper around any CheckpointStorage
20pub struct ThreadSafeStorage {
21    inner: Arc<Mutex<Box<dyn CheckpointStorage>>>,
22}
23
24impl ThreadSafeStorage {
25    /// Create from existing storage
26    pub fn new<S: CheckpointStorage + 'static>(storage: S) -> Self {
27        Self {
28            inner: Arc::new(Mutex::new(Box::new(storage))),
29        }
30    }
31
32    /// Create in-memory thread-safe storage
33    pub fn in_memory() -> Result<Self> {
34        let storage = SqliteGraphStorage::in_memory()?;
35        Ok(Self::new(storage))
36    }
37
38    /// Create file-based thread-safe storage
39    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
40        let storage = SqliteGraphStorage::open(path)?;
41        Ok(Self::new(storage))
42    }
43
44    /// Store a checkpoint
45    pub fn store(&self, checkpoint: &TemporalCheckpoint) -> Result<()> {
46        let storage = self.inner.lock().expect("Storage lock poisoned");
47        storage.store(checkpoint)
48    }
49
50    /// Get checkpoint by ID
51    pub fn get(&self, id: CheckpointId) -> Result<TemporalCheckpoint> {
52        let storage = self.inner.lock().expect("Storage lock poisoned");
53        storage.get(id)
54    }
55
56    /// Get latest checkpoint for session
57    pub fn get_latest(&self, session_id: SessionId) -> Result<Option<TemporalCheckpoint>> {
58        let storage = self.inner.lock().expect("Storage lock poisoned");
59        storage.get_latest(session_id)
60    }
61
62    /// List checkpoints by session
63    pub fn list_by_session(&self, session_id: SessionId) -> Result<Vec<CheckpointSummary>> {
64        let storage = self.inner.lock().expect("Storage lock poisoned");
65        storage.list_by_session(session_id)
66    }
67
68    /// List checkpoints by tag
69    pub fn list_by_tag(&self, tag: &str) -> Result<Vec<CheckpointSummary>> {
70        let storage = self.inner.lock().expect("Storage lock poisoned");
71        storage.list_by_tag(tag)
72    }
73
74    /// Delete checkpoint
75    pub fn delete(&self, id: CheckpointId) -> Result<()> {
76        let storage = self.inner.lock().expect("Storage lock poisoned");
77        storage.delete(id)
78    }
79
80    /// Get maximum sequence number across all checkpoints
81    pub fn get_max_sequence(&self) -> Result<u64> {
82        let storage = self.inner.lock().expect("Storage lock poisoned");
83        storage.get_max_sequence()
84    }
85}
86
87impl Clone for ThreadSafeStorage {
88    fn clone(&self) -> Self {
89        Self {
90            inner: Arc::clone(&self.inner),
91        }
92    }
93}
94
95// Safety: ThreadSafeStorage uses Arc<Mutex<>> internally
96unsafe impl Send for ThreadSafeStorage {}
97unsafe impl Sync for ThreadSafeStorage {}
98
99/// Thread-safe checkpoint manager
100/// 
101/// Wraps operations in Mutex for concurrent access
102pub struct ThreadSafeCheckpointManager {
103    storage: ThreadSafeStorage,
104    session_id: SessionId,
105    sequence_counter: Mutex<u64>,
106    last_checkpoint_time: Mutex<chrono::DateTime<Utc>>,
107}
108
109impl ThreadSafeCheckpointManager {
110    /// Create a new thread-safe manager
111    pub fn new(storage: ThreadSafeStorage, session_id: SessionId) -> Self {
112        Self {
113            storage,
114            session_id,
115            sequence_counter: Mutex::new(0),
116            last_checkpoint_time: Mutex::new(Utc::now()),
117        }
118    }
119
120    /// Create a manual checkpoint with auto-generated sequence
121    pub fn checkpoint(&self, message: impl Into<String>) -> Result<CheckpointId> {
122        let seq = {
123            let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
124            *counter += 1;
125            *counter
126        };
127        self.checkpoint_with_sequence(message, seq)
128    }
129
130    /// Create a checkpoint with a specific sequence number (for global sequencing)
131    pub fn checkpoint_with_sequence(
132        &self,
133        message: impl Into<String>,
134        sequence: u64,
135    ) -> Result<CheckpointId> {
136        let state = self.capture_state()?;
137
138        let checkpoint = TemporalCheckpoint::new(
139            sequence,
140            message,
141            state,
142            crate::checkpoint::CheckpointTrigger::Manual,
143            self.session_id,
144        );
145
146        self.storage.store(&checkpoint)?;
147        self.update_last_checkpoint_time();
148
149        // Update local counter to track sequences
150        let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
151        *counter = (*counter).max(sequence);
152
153        Ok(checkpoint.id)
154    }
155
156    /// Create a checkpoint with tags and auto-generated sequence
157    pub fn checkpoint_with_tags(
158        &self,
159        message: impl Into<String>,
160        tags: Vec<String>,
161    ) -> Result<CheckpointId> {
162        let seq = {
163            let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
164            *counter += 1;
165            *counter
166        };
167        self.checkpoint_with_tags_and_sequence(message, tags, seq)
168    }
169
170    /// Create a checkpoint with tags and specific sequence number
171    pub fn checkpoint_with_tags_and_sequence(
172        &self,
173        message: impl Into<String>,
174        tags: Vec<String>,
175        sequence: u64,
176    ) -> Result<CheckpointId> {
177        let state = self.capture_state()?;
178
179        let mut checkpoint = TemporalCheckpoint::new(
180            sequence,
181            message,
182            state,
183            crate::checkpoint::CheckpointTrigger::Manual,
184            self.session_id,
185        );
186        checkpoint.tags = tags;
187
188        self.storage.store(&checkpoint)?;
189        self.update_last_checkpoint_time();
190
191        // Update local counter to track sequences
192        let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
193        *counter = (*counter).max(sequence);
194
195        Ok(checkpoint.id)
196    }
197
198    /// Create an automatic checkpoint with auto-generated sequence
199    pub fn auto_checkpoint(&self, trigger: crate::checkpoint::AutoTrigger) -> Result<Option<CheckpointId>> {
200        let should_checkpoint = match trigger {
201            crate::checkpoint::AutoTrigger::SignificantTimePassed => {
202                let last = *self.last_checkpoint_time.lock().expect("Time lock poisoned");
203                Utc::now().signed_duration_since(last).num_minutes() > 5
204            }
205            _ => true,
206        };
207
208        if !should_checkpoint {
209            return Ok(None);
210        }
211
212        let seq = {
213            let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
214            *counter += 1;
215            *counter
216        };
217        
218        self.auto_checkpoint_with_sequence(trigger, seq)
219    }
220
221    /// Create an automatic checkpoint with specific sequence number
222    pub fn auto_checkpoint_with_sequence(
223        &self,
224        trigger: crate::checkpoint::AutoTrigger,
225        sequence: u64,
226    ) -> Result<Option<CheckpointId>> {
227        let state = self.capture_state()?;
228
229        let checkpoint = TemporalCheckpoint::new(
230            sequence,
231            format!("Auto: {:?}", trigger),
232            state,
233            crate::checkpoint::CheckpointTrigger::Automatic(trigger),
234            self.session_id,
235        );
236
237        self.storage.store(&checkpoint)?;
238        self.update_last_checkpoint_time();
239
240        // Update local counter to track sequences
241        let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
242        *counter = (*counter).max(sequence);
243
244        Ok(Some(checkpoint.id))
245    }
246
247    /// List all checkpoints for this session
248    pub fn list(&self) -> Result<Vec<CheckpointSummary>> {
249        self.storage.list_by_session(self.session_id)
250    }
251
252    /// Get a checkpoint by ID
253    pub fn get(&self, id: &CheckpointId) -> Result<Option<TemporalCheckpoint>> {
254        match self.storage.get(*id) {
255            Ok(cp) => Ok(Some(cp)),
256            Err(_) => Ok(None),
257        }
258    }
259
260    /// List checkpoints for a specific session
261    pub fn list_by_session(&self, session_id: &SessionId) -> Result<Vec<CheckpointSummary>> {
262        self.storage.list_by_session(*session_id)
263    }
264
265    /// List checkpoints with a specific tag
266    pub fn list_by_tag(&self, tag: &str) -> Result<Vec<CheckpointSummary>> {
267        self.storage.list_by_tag(tag)
268    }
269
270    /// Delete a checkpoint
271    pub fn delete(&self, id: &CheckpointId) -> Result<()> {
272        self.storage.delete(*id)
273    }
274
275    /// Compact checkpoints
276    pub fn compact(&self, keep_recent: usize) -> Result<usize> {
277        self.compact_with_policy(CompactionPolicy::KeepRecent(keep_recent))
278    }
279
280    /// Compact with policy
281    pub fn compact_with_policy(&self, policy: CompactionPolicy) -> Result<usize> {
282        let all_checkpoints = self.storage.list_by_session(self.session_id)?;
283        
284        // Determine which checkpoints to keep
285        let ids_to_keep: std::collections::HashSet<CheckpointId> = match &policy {
286            CompactionPolicy::KeepRecent(n) => {
287                let mut sorted = all_checkpoints.clone();
288                sorted.sort_by_key(|cp| cp.sequence_number);
289                sorted.iter().rev().take(*n).map(|cp| cp.id).collect()
290            }
291            CompactionPolicy::PreserveTagged(tags) => {
292                all_checkpoints.iter()
293                    .filter(|cp| cp.tags.iter().any(|t| tags.contains(t)))
294                    .map(|cp| cp.id)
295                    .collect()
296            }
297            CompactionPolicy::Hybrid { keep_recent, preserve_tags } => {
298                let mut to_keep = std::collections::HashSet::new();
299                
300                let mut sorted = all_checkpoints.clone();
301                sorted.sort_by_key(|cp| cp.sequence_number);
302                for cp in sorted.iter().rev().take(*keep_recent) {
303                    to_keep.insert(cp.id);
304                }
305                
306                for cp in &all_checkpoints {
307                    if cp.tags.iter().any(|t| preserve_tags.contains(t)) {
308                        to_keep.insert(cp.id);
309                    }
310                }
311                
312                to_keep
313            }
314        };
315        
316        // Delete checkpoints not in keep list
317        let mut deleted = 0;
318        for cp in &all_checkpoints {
319            if !ids_to_keep.contains(&cp.id) {
320                self.storage.delete(cp.id)?;
321                deleted += 1;
322            }
323        }
324        
325        Ok(deleted)
326    }
327
328    /// Restore state from checkpoint
329    pub fn restore(&self, checkpoint: &TemporalCheckpoint) -> Result<DebugStateSnapshot> {
330        if checkpoint.state.working_dir.is_none() {
331            return Err(crate::errors::ReasoningError::InvalidState(
332                "Checkpoint has no working directory".to_string()
333            ));
334        }
335        Ok(checkpoint.state.clone())
336    }
337
338    /// Get summary by ID
339    pub fn get_summary(&self, id: &CheckpointId) -> Result<Option<CheckpointSummary>> {
340        match self.storage.get(*id) {
341            Ok(cp) => Ok(Some(CheckpointSummary {
342                id: cp.id,
343                timestamp: cp.timestamp,
344                sequence_number: cp.sequence_number,
345                message: cp.message,
346                trigger: cp.trigger.to_string(),
347                tags: cp.tags,
348                has_notes: false,
349            })),
350            Err(_) => Ok(None),
351        }
352    }
353
354    fn capture_state(&self) -> Result<DebugStateSnapshot> {
355        Ok(DebugStateSnapshot {
356            session_id: self.session_id,
357            started_at: Utc::now(),
358            checkpoint_timestamp: Utc::now(),
359            working_dir: std::env::current_dir().ok(),
360            env_vars: std::env::vars().collect(),
361            metrics: crate::checkpoint::SessionMetrics::default(),
362            hypothesis_state: None, // Will be populated when hypothesis state is captured
363        })
364    }
365
366    fn update_last_checkpoint_time(&self) {
367        *self.last_checkpoint_time.lock().expect("Time lock poisoned") = Utc::now();
368    }
369}
370
371// Safety: ThreadSafeCheckpointManager uses Mutex internally
372unsafe impl Send for ThreadSafeCheckpointManager {}
373unsafe impl Sync for ThreadSafeCheckpointManager {}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Opening in-memory storage and querying an unknown session must not panic.
    #[test]
    fn test_thread_safe_storage_creation() {
        let store = ThreadSafeStorage::in_memory().unwrap();
        let unknown_session = SessionId::new();
        // The result is deliberately ignored: only the absence of a panic matters.
        let _ = store.list_by_session(unknown_session);
    }

    /// A freshly built manager can create a checkpoint whose ID renders non-empty.
    #[test]
    fn test_thread_safe_manager_creation() {
        let manager = ThreadSafeCheckpointManager::new(
            ThreadSafeStorage::in_memory().unwrap(),
            SessionId::new(),
        );

        let checkpoint_id = manager.checkpoint("Test").unwrap();
        let rendered = checkpoint_id.to_string();
        assert!(!rendered.is_empty());
    }
}