Skip to main content

forge_reasoning/
storage_sqlitegraph.rs

1//! SQLiteGraph storage implementation for checkpoints
2
3use std::cell::RefCell;
4use std::collections::HashMap;
5use std::path::Path;
6
7use chrono::Utc;
8use sqlitegraph::{GraphEntity, SqliteGraph};
9
10use crate::checkpoint::{CheckpointId, CheckpointSummary, SessionId, TemporalCheckpoint, DebugStateSnapshot, CheckpointTrigger, AutoTrigger};
11use crate::errors::{Result, StorageError};
12use crate::storage::CheckpointStorage;
13
14/// SQLiteGraph-based checkpoint storage (MVP version)
/// SQLiteGraph-based checkpoint storage (MVP version).
///
/// Each `store` call writes the checkpoint to a [`SqliteGraph`] as a `Checkpoint`
/// entity and mirrors it in an in-memory cache. The read paths (`get`,
/// `get_latest`, `list_by_session`, `list_by_tag`, `get_max_sequence`) are all
/// answered from that cache. [`SqliteGraphStorage::open`] fills it from disk.
pub struct SqliteGraphStorage {
    /// Persistent graph database; one `Checkpoint` entity is inserted per `store` call.
    graph: RefCell<SqliteGraph>,
    /// In-memory cache for query operations (MVP workaround), keyed by checkpoint ID.
    cache: RefCell<HashMap<CheckpointId, TemporalCheckpoint>>,
}

// Safety: We use RefCell for single-threaded interior mutability.
// For thread-safe usage, wrap in ThreadSafeStorage which uses Arc<Mutex<>>.
//
// NOTE(review): the `Sync` impl looks unsound as written.
// - `RefCell` is not thread-safe. Two threads sharing `&SqliteGraphStorage`
//   could call `store`/`get` at the same time and race on the borrow flags and
//   the cache map.
// - Wrapping in `Mutex<T>` only needs `T: Send` for the mutex to be `Sync`.
//   Confirm whether anything actually requires `SqliteGraphStorage: Sync`; if
//   nothing does, drop that impl.
// - `Send` is only sound if `SqliteGraph` itself may move between threads.
//   TODO confirm against the sqlitegraph crate.
unsafe impl Send for SqliteGraphStorage {}
unsafe impl Sync for SqliteGraphStorage {}
25
26impl SqliteGraphStorage {
27    /// Open or create storage at the given path
28    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
29        let graph = SqliteGraph::open(path)?;
30        let storage = Self {
31            graph: RefCell::new(graph),
32            cache: RefCell::new(HashMap::new()),
33        };
34        // Load existing checkpoints from disk
35        storage.load_from_disk()?;
36        Ok(storage)
37    }
38
39    /// Open with recovery - attempts to repair corrupted storage
40    pub fn open_with_recovery(path: impl AsRef<Path>) -> Result<Self> {
41        // Try normal open first
42        match Self::open(&path) {
43            Ok(storage) => Ok(storage),
44            Err(_) => {
45                // If that fails, try to create fresh storage
46                // In production, this would attempt actual recovery
47                tracing::warn!("Storage open failed, attempting recovery");
48                Self::open(path)
49            }
50        }
51    }
52
53    /// Create an in-memory storage (for testing)
54    pub fn in_memory() -> Result<Self> {
55        let graph = SqliteGraph::open_in_memory()?;
56        Ok(Self {
57            graph: RefCell::new(graph),
58            cache: RefCell::new(HashMap::new()),
59        })
60    }
61
62    /// Load all checkpoints from SQLite into cache
63    fn load_from_disk(&self) -> Result<()> {
64        let graph = self.graph.borrow();
65        let entity_ids = graph.list_entity_ids()
66            .map_err(|e| StorageError::RetrieveFailed(format!("Failed to load entity IDs: {}", e)))?;
67        
68        let mut cache = self.cache.borrow_mut();
69        cache.clear();
70        
71        for entity_id in entity_ids {
72            if let Ok(entity) = graph.get_entity(entity_id) {
73                if entity.kind == "Checkpoint" {
74                    if let Ok(checkpoint) = self.entity_to_checkpoint(&entity) {
75                        cache.insert(checkpoint.id, checkpoint);
76                    }
77                }
78            }
79        }
80        
81        Ok(())
82    }
83
84    /// Convert a GraphEntity to TemporalCheckpoint
85    fn entity_to_checkpoint(&self, entity: &GraphEntity) -> Result<TemporalCheckpoint> {
86        let data = &entity.data;
87        
88        let state_data = data.get("state_data")
89            .and_then(|v| v.as_str())
90            .ok_or_else(|| StorageError::RetrieveFailed("Missing state data".to_string()))?;
91        
92        let state: DebugStateSnapshot = serde_json::from_str(state_data)
93            .map_err(|e| StorageError::RetrieveFailed(format!("Failed to deserialize state: {}", e)))?;
94        
95        // Parse checkpoint ID from string
96        let id_str = data.get("id")
97            .and_then(|v| v.as_str())
98            .ok_or_else(|| StorageError::RetrieveFailed("Missing checkpoint ID".to_string()))?;
99        let checkpoint_id = parse_checkpoint_id(id_str)?;
100        
101        // Parse timestamp
102        let timestamp_str = data.get("timestamp")
103            .and_then(|v| v.as_str())
104            .ok_or_else(|| StorageError::RetrieveFailed("Missing timestamp".to_string()))?;
105        let timestamp = chrono::DateTime::parse_from_rfc3339(timestamp_str)
106            .map_err(|e| StorageError::RetrieveFailed(format!("Invalid timestamp: {}", e)))?
107            .with_timezone(&Utc);
108        
109        // Parse sequence number
110        let sequence_number = data.get("sequence_number")
111            .and_then(|v| v.as_u64())
112            .ok_or_else(|| StorageError::RetrieveFailed("Missing sequence number".to_string()))?;
113        
114        // Parse message
115        let message = data.get("message")
116            .and_then(|v: &serde_json::Value| v.as_str())
117            .unwrap_or("")
118            .to_string();
119        
120        // Parse tags
121        let tags = data.get("tags")
122            .and_then(|v: &serde_json::Value| v.as_array())
123            .map(|arr: &Vec<serde_json::Value>| arr.iter()
124                .filter_map(|v: &serde_json::Value| v.as_str().map(String::from))
125                .collect())
126            .unwrap_or_default();
127        
128        // Parse session ID
129        let session_id_str = data.get("session_id")
130            .and_then(|v| v.as_str())
131            .ok_or_else(|| StorageError::RetrieveFailed("Missing session ID".to_string()))?;
132        let session_id = parse_session_id(session_id_str)?;
133        
134        // Parse trigger
135        let trigger_str = data.get("trigger")
136            .and_then(|v| v.as_str())
137            .unwrap_or("manual");
138        let trigger = parse_trigger(trigger_str);
139        
140        // Parse checksum (may not exist for legacy checkpoints)
141        let checksum = data.get("checksum")
142            .and_then(|v| v.as_str())
143            .unwrap_or("")
144            .to_string();
145        
146        Ok(TemporalCheckpoint {
147            id: checkpoint_id,
148            timestamp,
149            sequence_number,
150            message,
151            tags,
152            state,
153            trigger,
154            session_id,
155            checksum,
156        })
157    }
158}
159
160impl CheckpointStorage for SqliteGraphStorage {
161    fn store(&self, checkpoint: &TemporalCheckpoint) -> Result<()> {
162        // Serialize state to JSON
163        let state_json = serde_json::to_string(&checkpoint.state)
164            .map_err(|e| StorageError::StoreFailed(format!("Failed to serialize state: {}", e)))?;
165
166        // Create checkpoint entity
167        let entity = GraphEntity {
168            id: 0,
169            kind: "Checkpoint".to_string(),
170            name: checkpoint.id.to_string(),
171            file_path: None,
172            data: serde_json::json!({
173                "id": checkpoint.id,
174                "timestamp": checkpoint.timestamp,
175                "sequence_number": checkpoint.sequence_number,
176                "message": checkpoint.message,
177                "tags": checkpoint.tags,
178                "trigger": format!("{}", checkpoint.trigger),
179                "session_id": checkpoint.session_id,
180                "state_data": state_json,
181                "checksum": checkpoint.checksum,
182            }),
183        };
184
185        // Insert into graph
186        let graph = self.graph.borrow();
187        let _entity_id = graph.insert_entity(&entity)
188            .map_err(|e| StorageError::StoreFailed(format!("Failed to insert: {}", e)))?;
189
190        // Also store in cache for easy retrieval
191        self.cache.borrow_mut().insert(checkpoint.id, checkpoint.clone());
192
193        tracing::debug!("Stored checkpoint {}", checkpoint.id);
194        Ok(())
195    }
196
197    fn get(&self, id: CheckpointId) -> Result<TemporalCheckpoint> {
198        // Try cache first
199        if let Some(cp) = self.cache.borrow().get(&id) {
200            return Ok(cp.clone());
201        }
202        
203        Err(StorageError::RetrieveFailed(format!("Checkpoint not found: {}", id)).into())
204    }
205
206    fn get_latest(&self, session_id: SessionId) -> Result<Option<TemporalCheckpoint>> {
207        let checkpoints = self.list_by_session(session_id)?;
208        
209        // Get the one with highest sequence number
210        let latest = checkpoints.iter()
211            .max_by_key(|c: &&CheckpointSummary| c.sequence_number);
212        
213        match latest {
214            Some(summary) => self.get(summary.id).map(Some),
215            None => Ok(None),
216        }
217    }
218
219    fn list_by_session(&self, session_id: SessionId) -> Result<Vec<CheckpointSummary>> {
220        let cache = self.cache.borrow();
221        let mut summaries = Vec::new();
222        
223        for (_, checkpoint) in cache.iter() {
224            if checkpoint.session_id == session_id {
225                summaries.push(CheckpointSummary {
226                    id: checkpoint.id,
227                    timestamp: checkpoint.timestamp,
228                    sequence_number: checkpoint.sequence_number,
229                    message: checkpoint.message.clone(),
230                    trigger: checkpoint.trigger.to_string(),
231                    tags: checkpoint.tags.clone(),
232                    has_notes: false,
233                });
234            }
235        }
236        
237        // Sort by sequence number
238        summaries.sort_by_key(|s: &CheckpointSummary| s.sequence_number);
239        
240        Ok(summaries)
241    }
242
243    fn list_by_tag(&self, tag: &str) -> Result<Vec<CheckpointSummary>> {
244        let cache = self.cache.borrow();
245        let mut summaries = Vec::new();
246        
247        for (_, checkpoint) in cache.iter() {
248            if checkpoint.tags.contains(&tag.to_string()) {
249                summaries.push(CheckpointSummary {
250                    id: checkpoint.id,
251                    timestamp: checkpoint.timestamp,
252                    sequence_number: checkpoint.sequence_number,
253                    message: checkpoint.message.clone(),
254                    trigger: checkpoint.trigger.to_string(),
255                    tags: checkpoint.tags.clone(),
256                    has_notes: false,
257                });
258            }
259        }
260        
261        // Sort by sequence number
262        summaries.sort_by_key(|s: &CheckpointSummary| s.sequence_number);
263        
264        Ok(summaries)
265    }
266
267    fn delete(&self, id: CheckpointId) -> Result<()> {
268        // Remove from cache
269        self.cache.borrow_mut().remove(&id);
270        
271        // Try to remove from SQLite (best effort)
272        // Note: This requires entity ID lookup which we don't track
273        // For MVP, cache removal is sufficient
274        
275        Ok(())
276    }
277
278    fn next_sequence(&self, _session_id: SessionId) -> Result<u64> {
279        Ok(0)
280    }
281
282    fn get_max_sequence(&self) -> Result<u64> {
283        let cache = self.cache.borrow();
284        let max_seq = cache.values()
285            .map(|cp| cp.sequence_number)
286            .max()
287            .unwrap_or(0);
288        Ok(max_seq)
289    }
290}
291
292// Helper functions for parsing
293
294fn parse_checkpoint_id(s: &str) -> Result<CheckpointId> {
295    let uuid = uuid::Uuid::parse_str(s)
296        .map_err(|e| StorageError::RetrieveFailed(format!("Invalid checkpoint ID: {}", e)))?;
297    Ok(CheckpointId(uuid))
298}
299
300fn parse_session_id(s: &str) -> Result<SessionId> {
301    let uuid = uuid::Uuid::parse_str(s)
302        .map_err(|e| StorageError::RetrieveFailed(format!("Invalid session ID: {}", e)))?;
303    Ok(SessionId(uuid))
304}
305
306fn parse_trigger(s: &str) -> CheckpointTrigger {
307    if s.starts_with("auto") {
308        CheckpointTrigger::Automatic(AutoTrigger::VerificationComplete)
309    } else if s == "scheduled" {
310        CheckpointTrigger::Scheduled
311    } else {
312        CheckpointTrigger::Manual
313    }
314}