// forge_reasoning/storage_sqlitegraph.rs
use std::cell::RefCell;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Mutex, MutexGuard, PoisonError};

use chrono::Utc;
use sqlitegraph::{GraphEntity, SqliteGraph};

use crate::checkpoint::{
    AutoTrigger, CheckpointId, CheckpointSummary, CheckpointTrigger, DebugStateSnapshot, SessionId,
    TemporalCheckpoint,
};
use crate::errors::{Result, StorageError};
use crate::storage::CheckpointStorage;
13
14pub struct SqliteGraphStorage {
16 graph: RefCell<SqliteGraph>,
17 cache: RefCell<HashMap<CheckpointId, TemporalCheckpoint>>,
19}
20
21unsafe impl Send for SqliteGraphStorage {}
24unsafe impl Sync for SqliteGraphStorage {}
25
26impl SqliteGraphStorage {
27 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
29 let graph = SqliteGraph::open(path)?;
30 let storage = Self {
31 graph: RefCell::new(graph),
32 cache: RefCell::new(HashMap::new()),
33 };
34 storage.load_from_disk()?;
36 Ok(storage)
37 }
38
39 pub fn open_with_recovery(path: impl AsRef<Path>) -> Result<Self> {
41 match Self::open(&path) {
43 Ok(storage) => Ok(storage),
44 Err(_) => {
45 tracing::warn!("Storage open failed, attempting recovery");
48 Self::open(path)
49 }
50 }
51 }
52
53 pub fn in_memory() -> Result<Self> {
55 let graph = SqliteGraph::open_in_memory()?;
56 Ok(Self {
57 graph: RefCell::new(graph),
58 cache: RefCell::new(HashMap::new()),
59 })
60 }
61
62 fn load_from_disk(&self) -> Result<()> {
64 let graph = self.graph.borrow();
65 let entity_ids = graph.list_entity_ids()
66 .map_err(|e| StorageError::RetrieveFailed(format!("Failed to load entity IDs: {}", e)))?;
67
68 let mut cache = self.cache.borrow_mut();
69 cache.clear();
70
71 for entity_id in entity_ids {
72 if let Ok(entity) = graph.get_entity(entity_id) {
73 if entity.kind == "Checkpoint" {
74 if let Ok(checkpoint) = self.entity_to_checkpoint(&entity) {
75 cache.insert(checkpoint.id, checkpoint);
76 }
77 }
78 }
79 }
80
81 Ok(())
82 }
83
84 fn entity_to_checkpoint(&self, entity: &GraphEntity) -> Result<TemporalCheckpoint> {
86 let data = &entity.data;
87
88 let state_data = data.get("state_data")
89 .and_then(|v| v.as_str())
90 .ok_or_else(|| StorageError::RetrieveFailed("Missing state data".to_string()))?;
91
92 let state: DebugStateSnapshot = serde_json::from_str(state_data)
93 .map_err(|e| StorageError::RetrieveFailed(format!("Failed to deserialize state: {}", e)))?;
94
95 let id_str = data.get("id")
97 .and_then(|v| v.as_str())
98 .ok_or_else(|| StorageError::RetrieveFailed("Missing checkpoint ID".to_string()))?;
99 let checkpoint_id = parse_checkpoint_id(id_str)?;
100
101 let timestamp_str = data.get("timestamp")
103 .and_then(|v| v.as_str())
104 .ok_or_else(|| StorageError::RetrieveFailed("Missing timestamp".to_string()))?;
105 let timestamp = chrono::DateTime::parse_from_rfc3339(timestamp_str)
106 .map_err(|e| StorageError::RetrieveFailed(format!("Invalid timestamp: {}", e)))?
107 .with_timezone(&Utc);
108
109 let sequence_number = data.get("sequence_number")
111 .and_then(|v| v.as_u64())
112 .ok_or_else(|| StorageError::RetrieveFailed("Missing sequence number".to_string()))?;
113
114 let message = data.get("message")
116 .and_then(|v: &serde_json::Value| v.as_str())
117 .unwrap_or("")
118 .to_string();
119
120 let tags = data.get("tags")
122 .and_then(|v: &serde_json::Value| v.as_array())
123 .map(|arr: &Vec<serde_json::Value>| arr.iter()
124 .filter_map(|v: &serde_json::Value| v.as_str().map(String::from))
125 .collect())
126 .unwrap_or_default();
127
128 let session_id_str = data.get("session_id")
130 .and_then(|v| v.as_str())
131 .ok_or_else(|| StorageError::RetrieveFailed("Missing session ID".to_string()))?;
132 let session_id = parse_session_id(session_id_str)?;
133
134 let trigger_str = data.get("trigger")
136 .and_then(|v| v.as_str())
137 .unwrap_or("manual");
138 let trigger = parse_trigger(trigger_str);
139
140 let checksum = data.get("checksum")
142 .and_then(|v| v.as_str())
143 .unwrap_or("")
144 .to_string();
145
146 Ok(TemporalCheckpoint {
147 id: checkpoint_id,
148 timestamp,
149 sequence_number,
150 message,
151 tags,
152 state,
153 trigger,
154 session_id,
155 checksum,
156 })
157 }
158}
159
160impl CheckpointStorage for SqliteGraphStorage {
161 fn store(&self, checkpoint: &TemporalCheckpoint) -> Result<()> {
162 let state_json = serde_json::to_string(&checkpoint.state)
164 .map_err(|e| StorageError::StoreFailed(format!("Failed to serialize state: {}", e)))?;
165
166 let entity = GraphEntity {
168 id: 0,
169 kind: "Checkpoint".to_string(),
170 name: checkpoint.id.to_string(),
171 file_path: None,
172 data: serde_json::json!({
173 "id": checkpoint.id,
174 "timestamp": checkpoint.timestamp,
175 "sequence_number": checkpoint.sequence_number,
176 "message": checkpoint.message,
177 "tags": checkpoint.tags,
178 "trigger": format!("{}", checkpoint.trigger),
179 "session_id": checkpoint.session_id,
180 "state_data": state_json,
181 "checksum": checkpoint.checksum,
182 }),
183 };
184
185 let graph = self.graph.borrow();
187 let _entity_id = graph.insert_entity(&entity)
188 .map_err(|e| StorageError::StoreFailed(format!("Failed to insert: {}", e)))?;
189
190 self.cache.borrow_mut().insert(checkpoint.id, checkpoint.clone());
192
193 tracing::debug!("Stored checkpoint {}", checkpoint.id);
194 Ok(())
195 }
196
197 fn get(&self, id: CheckpointId) -> Result<TemporalCheckpoint> {
198 if let Some(cp) = self.cache.borrow().get(&id) {
200 return Ok(cp.clone());
201 }
202
203 Err(StorageError::RetrieveFailed(format!("Checkpoint not found: {}", id)).into())
204 }
205
206 fn get_latest(&self, session_id: SessionId) -> Result<Option<TemporalCheckpoint>> {
207 let checkpoints = self.list_by_session(session_id)?;
208
209 let latest = checkpoints.iter()
211 .max_by_key(|c: &&CheckpointSummary| c.sequence_number);
212
213 match latest {
214 Some(summary) => self.get(summary.id).map(Some),
215 None => Ok(None),
216 }
217 }
218
219 fn list_by_session(&self, session_id: SessionId) -> Result<Vec<CheckpointSummary>> {
220 let cache = self.cache.borrow();
221 let mut summaries = Vec::new();
222
223 for (_, checkpoint) in cache.iter() {
224 if checkpoint.session_id == session_id {
225 summaries.push(CheckpointSummary {
226 id: checkpoint.id,
227 timestamp: checkpoint.timestamp,
228 sequence_number: checkpoint.sequence_number,
229 message: checkpoint.message.clone(),
230 trigger: checkpoint.trigger.to_string(),
231 tags: checkpoint.tags.clone(),
232 has_notes: false,
233 });
234 }
235 }
236
237 summaries.sort_by_key(|s: &CheckpointSummary| s.sequence_number);
239
240 Ok(summaries)
241 }
242
243 fn list_by_tag(&self, tag: &str) -> Result<Vec<CheckpointSummary>> {
244 let cache = self.cache.borrow();
245 let mut summaries = Vec::new();
246
247 for (_, checkpoint) in cache.iter() {
248 if checkpoint.tags.contains(&tag.to_string()) {
249 summaries.push(CheckpointSummary {
250 id: checkpoint.id,
251 timestamp: checkpoint.timestamp,
252 sequence_number: checkpoint.sequence_number,
253 message: checkpoint.message.clone(),
254 trigger: checkpoint.trigger.to_string(),
255 tags: checkpoint.tags.clone(),
256 has_notes: false,
257 });
258 }
259 }
260
261 summaries.sort_by_key(|s: &CheckpointSummary| s.sequence_number);
263
264 Ok(summaries)
265 }
266
267 fn delete(&self, id: CheckpointId) -> Result<()> {
268 self.cache.borrow_mut().remove(&id);
270
271 Ok(())
276 }
277
278 fn next_sequence(&self, _session_id: SessionId) -> Result<u64> {
279 Ok(0)
280 }
281
282 fn get_max_sequence(&self) -> Result<u64> {
283 let cache = self.cache.borrow();
284 let max_seq = cache.values()
285 .map(|cp| cp.sequence_number)
286 .max()
287 .unwrap_or(0);
288 Ok(max_seq)
289 }
290}
291
292fn parse_checkpoint_id(s: &str) -> Result<CheckpointId> {
295 let uuid = uuid::Uuid::parse_str(s)
296 .map_err(|e| StorageError::RetrieveFailed(format!("Invalid checkpoint ID: {}", e)))?;
297 Ok(CheckpointId(uuid))
298}
299
300fn parse_session_id(s: &str) -> Result<SessionId> {
301 let uuid = uuid::Uuid::parse_str(s)
302 .map_err(|e| StorageError::RetrieveFailed(format!("Invalid session ID: {}", e)))?;
303 Ok(SessionId(uuid))
304}
305
306fn parse_trigger(s: &str) -> CheckpointTrigger {
307 if s.starts_with("auto") {
308 CheckpointTrigger::Automatic(AutoTrigger::VerificationComplete)
309 } else if s == "scheduled" {
310 CheckpointTrigger::Scheduled
311 } else {
312 CheckpointTrigger::Manual
313 }
314}