1use std::path::Path;
7use std::sync::{Arc, Mutex};
8
9use chrono::Utc;
10
11use crate::checkpoint::{
12 CheckpointId, CheckpointSummary, CompactionPolicy, DebugStateSnapshot,
13 SessionId, TemporalCheckpoint
14};
15use crate::errors::Result;
16use crate::storage::CheckpointStorage;
17use crate::SqliteGraphStorage;
18
19pub struct ThreadSafeStorage {
21 inner: Arc<Mutex<Box<dyn CheckpointStorage>>>,
22}
23
24impl ThreadSafeStorage {
25 pub fn new<S: CheckpointStorage + 'static>(storage: S) -> Self {
27 Self {
28 inner: Arc::new(Mutex::new(Box::new(storage))),
29 }
30 }
31
32 pub fn in_memory() -> Result<Self> {
34 let storage = SqliteGraphStorage::in_memory()?;
35 Ok(Self::new(storage))
36 }
37
38 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
40 let storage = SqliteGraphStorage::open(path)?;
41 Ok(Self::new(storage))
42 }
43
44 pub fn store(&self, checkpoint: &TemporalCheckpoint) -> Result<()> {
46 let storage = self.inner.lock().expect("Storage lock poisoned");
47 storage.store(checkpoint)
48 }
49
50 pub fn get(&self, id: CheckpointId) -> Result<TemporalCheckpoint> {
52 let storage = self.inner.lock().expect("Storage lock poisoned");
53 storage.get(id)
54 }
55
56 pub fn get_latest(&self, session_id: SessionId) -> Result<Option<TemporalCheckpoint>> {
58 let storage = self.inner.lock().expect("Storage lock poisoned");
59 storage.get_latest(session_id)
60 }
61
62 pub fn list_by_session(&self, session_id: SessionId) -> Result<Vec<CheckpointSummary>> {
64 let storage = self.inner.lock().expect("Storage lock poisoned");
65 storage.list_by_session(session_id)
66 }
67
68 pub fn list_by_tag(&self, tag: &str) -> Result<Vec<CheckpointSummary>> {
70 let storage = self.inner.lock().expect("Storage lock poisoned");
71 storage.list_by_tag(tag)
72 }
73
74 pub fn delete(&self, id: CheckpointId) -> Result<()> {
76 let storage = self.inner.lock().expect("Storage lock poisoned");
77 storage.delete(id)
78 }
79
80 pub fn get_max_sequence(&self) -> Result<u64> {
82 let storage = self.inner.lock().expect("Storage lock poisoned");
83 storage.get_max_sequence()
84 }
85}
86
87impl Clone for ThreadSafeStorage {
88 fn clone(&self) -> Self {
89 Self {
90 inner: Arc::clone(&self.inner),
91 }
92 }
93}
94
95unsafe impl Send for ThreadSafeStorage {}
97unsafe impl Sync for ThreadSafeStorage {}
98
99pub struct ThreadSafeCheckpointManager {
103 storage: ThreadSafeStorage,
104 session_id: SessionId,
105 sequence_counter: Mutex<u64>,
106 last_checkpoint_time: Mutex<chrono::DateTime<Utc>>,
107}
108
109impl ThreadSafeCheckpointManager {
110 pub fn new(storage: ThreadSafeStorage, session_id: SessionId) -> Self {
112 Self {
113 storage,
114 session_id,
115 sequence_counter: Mutex::new(0),
116 last_checkpoint_time: Mutex::new(Utc::now()),
117 }
118 }
119
120 pub fn checkpoint(&self, message: impl Into<String>) -> Result<CheckpointId> {
122 let seq = {
123 let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
124 *counter += 1;
125 *counter
126 };
127 self.checkpoint_with_sequence(message, seq)
128 }
129
130 pub fn checkpoint_with_sequence(
132 &self,
133 message: impl Into<String>,
134 sequence: u64,
135 ) -> Result<CheckpointId> {
136 let state = self.capture_state()?;
137
138 let checkpoint = TemporalCheckpoint::new(
139 sequence,
140 message,
141 state,
142 crate::checkpoint::CheckpointTrigger::Manual,
143 self.session_id,
144 );
145
146 self.storage.store(&checkpoint)?;
147 self.update_last_checkpoint_time();
148
149 let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
151 *counter = (*counter).max(sequence);
152
153 Ok(checkpoint.id)
154 }
155
156 pub fn checkpoint_with_tags(
158 &self,
159 message: impl Into<String>,
160 tags: Vec<String>,
161 ) -> Result<CheckpointId> {
162 let seq = {
163 let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
164 *counter += 1;
165 *counter
166 };
167 self.checkpoint_with_tags_and_sequence(message, tags, seq)
168 }
169
170 pub fn checkpoint_with_tags_and_sequence(
172 &self,
173 message: impl Into<String>,
174 tags: Vec<String>,
175 sequence: u64,
176 ) -> Result<CheckpointId> {
177 let state = self.capture_state()?;
178
179 let mut checkpoint = TemporalCheckpoint::new(
180 sequence,
181 message,
182 state,
183 crate::checkpoint::CheckpointTrigger::Manual,
184 self.session_id,
185 );
186 checkpoint.tags = tags;
187
188 self.storage.store(&checkpoint)?;
189 self.update_last_checkpoint_time();
190
191 let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
193 *counter = (*counter).max(sequence);
194
195 Ok(checkpoint.id)
196 }
197
198 pub fn auto_checkpoint(&self, trigger: crate::checkpoint::AutoTrigger) -> Result<Option<CheckpointId>> {
200 let should_checkpoint = match trigger {
201 crate::checkpoint::AutoTrigger::SignificantTimePassed => {
202 let last = *self.last_checkpoint_time.lock().expect("Time lock poisoned");
203 Utc::now().signed_duration_since(last).num_minutes() > 5
204 }
205 _ => true,
206 };
207
208 if !should_checkpoint {
209 return Ok(None);
210 }
211
212 let seq = {
213 let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
214 *counter += 1;
215 *counter
216 };
217
218 self.auto_checkpoint_with_sequence(trigger, seq)
219 }
220
221 pub fn auto_checkpoint_with_sequence(
223 &self,
224 trigger: crate::checkpoint::AutoTrigger,
225 sequence: u64,
226 ) -> Result<Option<CheckpointId>> {
227 let state = self.capture_state()?;
228
229 let checkpoint = TemporalCheckpoint::new(
230 sequence,
231 format!("Auto: {:?}", trigger),
232 state,
233 crate::checkpoint::CheckpointTrigger::Automatic(trigger),
234 self.session_id,
235 );
236
237 self.storage.store(&checkpoint)?;
238 self.update_last_checkpoint_time();
239
240 let mut counter = self.sequence_counter.lock().expect("Counter poisoned");
242 *counter = (*counter).max(sequence);
243
244 Ok(Some(checkpoint.id))
245 }
246
247 pub fn list(&self) -> Result<Vec<CheckpointSummary>> {
249 self.storage.list_by_session(self.session_id)
250 }
251
252 pub fn get(&self, id: &CheckpointId) -> Result<Option<TemporalCheckpoint>> {
254 match self.storage.get(*id) {
255 Ok(cp) => Ok(Some(cp)),
256 Err(_) => Ok(None),
257 }
258 }
259
260 pub fn list_by_session(&self, session_id: &SessionId) -> Result<Vec<CheckpointSummary>> {
262 self.storage.list_by_session(*session_id)
263 }
264
265 pub fn list_by_tag(&self, tag: &str) -> Result<Vec<CheckpointSummary>> {
267 self.storage.list_by_tag(tag)
268 }
269
270 pub fn delete(&self, id: &CheckpointId) -> Result<()> {
272 self.storage.delete(*id)
273 }
274
275 pub fn compact(&self, keep_recent: usize) -> Result<usize> {
277 self.compact_with_policy(CompactionPolicy::KeepRecent(keep_recent))
278 }
279
280 pub fn compact_with_policy(&self, policy: CompactionPolicy) -> Result<usize> {
282 let all_checkpoints = self.storage.list_by_session(self.session_id)?;
283
284 let ids_to_keep: std::collections::HashSet<CheckpointId> = match &policy {
286 CompactionPolicy::KeepRecent(n) => {
287 let mut sorted = all_checkpoints.clone();
288 sorted.sort_by_key(|cp| cp.sequence_number);
289 sorted.iter().rev().take(*n).map(|cp| cp.id).collect()
290 }
291 CompactionPolicy::PreserveTagged(tags) => {
292 all_checkpoints.iter()
293 .filter(|cp| cp.tags.iter().any(|t| tags.contains(t)))
294 .map(|cp| cp.id)
295 .collect()
296 }
297 CompactionPolicy::Hybrid { keep_recent, preserve_tags } => {
298 let mut to_keep = std::collections::HashSet::new();
299
300 let mut sorted = all_checkpoints.clone();
301 sorted.sort_by_key(|cp| cp.sequence_number);
302 for cp in sorted.iter().rev().take(*keep_recent) {
303 to_keep.insert(cp.id);
304 }
305
306 for cp in &all_checkpoints {
307 if cp.tags.iter().any(|t| preserve_tags.contains(t)) {
308 to_keep.insert(cp.id);
309 }
310 }
311
312 to_keep
313 }
314 };
315
316 let mut deleted = 0;
318 for cp in &all_checkpoints {
319 if !ids_to_keep.contains(&cp.id) {
320 self.storage.delete(cp.id)?;
321 deleted += 1;
322 }
323 }
324
325 Ok(deleted)
326 }
327
328 pub fn restore(&self, checkpoint: &TemporalCheckpoint) -> Result<DebugStateSnapshot> {
330 if checkpoint.state.working_dir.is_none() {
331 return Err(crate::errors::ReasoningError::InvalidState(
332 "Checkpoint has no working directory".to_string()
333 ));
334 }
335 Ok(checkpoint.state.clone())
336 }
337
338 pub fn get_summary(&self, id: &CheckpointId) -> Result<Option<CheckpointSummary>> {
340 match self.storage.get(*id) {
341 Ok(cp) => Ok(Some(CheckpointSummary {
342 id: cp.id,
343 timestamp: cp.timestamp,
344 sequence_number: cp.sequence_number,
345 message: cp.message,
346 trigger: cp.trigger.to_string(),
347 tags: cp.tags,
348 has_notes: false,
349 })),
350 Err(_) => Ok(None),
351 }
352 }
353
354 fn capture_state(&self) -> Result<DebugStateSnapshot> {
355 Ok(DebugStateSnapshot {
356 session_id: self.session_id,
357 started_at: Utc::now(),
358 checkpoint_timestamp: Utc::now(),
359 working_dir: std::env::current_dir().ok(),
360 env_vars: std::env::vars().collect(),
361 metrics: crate::checkpoint::SessionMetrics::default(),
362 hypothesis_state: None, })
364 }
365
366 fn update_last_checkpoint_time(&self) {
367 *self.last_checkpoint_time.lock().expect("Time lock poisoned") = Utc::now();
368 }
369}
370
371unsafe impl Send for ThreadSafeCheckpointManager {}
373unsafe impl Sync for ThreadSafeCheckpointManager {}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Opening an in-memory store and issuing a query must not panic.
    #[test]
    fn test_thread_safe_storage_creation() {
        let handle = ThreadSafeStorage::in_memory().unwrap();
        // Only exercises the call path; the query result is deliberately ignored.
        let _ = handle.list_by_session(SessionId::new());
    }

    /// A fresh manager can record a checkpoint and return a printable id.
    #[test]
    fn test_thread_safe_manager_creation() {
        let mgr = ThreadSafeCheckpointManager::new(
            ThreadSafeStorage::in_memory().unwrap(),
            SessionId::new(),
        );

        let checkpoint_id = mgr.checkpoint("Test").unwrap();
        let rendered = checkpoint_id.to_string();
        assert!(!rendered.is_empty());
    }
}