1use super::{FlushResult, SnapshotQuery, StorageError};
7use crate::models::{DecisionSnapshot, Snapshot};
8use std::collections::HashMap;
9
10pub trait SyncStorageBackend: Send + Sync {
12 fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError>;
14
15 fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError>;
17
18 fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError>;
20
21 fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError>;
23
24 fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError>;
26
27 fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError>;
29
30 fn flush(&self) -> Result<FlushResult, StorageError>;
32
33 fn health_check(&self) -> Result<bool, StorageError>;
35}
36
/// Synchronous [`SyncStorageBackend`] backed by SQLite.
///
/// This is a thin adapter over `super::sqlite::SqliteBackend`. Trait methods
/// either forward to that backend's `*_internal` methods or run SQL on its
/// `conn` directly.
#[cfg(feature = "sqlite-storage")]
pub struct SyncSqliteBackend {
    inner: super::sqlite::SqliteBackend,
}

#[cfg(feature = "sqlite-storage")]
impl SyncSqliteBackend {
    /// Builds a backend over the SQLite database at `path`.
    ///
    /// # Errors
    /// Propagates any error from `SqliteBackend::new`.
    pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self, StorageError> {
        Ok(Self {
            inner: super::sqlite::SqliteBackend::new(path)?,
        })
    }

    /// Builds a backend over an in-memory SQLite database.
    ///
    /// # Errors
    /// Propagates any error from `SqliteBackend::in_memory`.
    pub fn in_memory() -> Result<Self, StorageError> {
        Ok(Self {
            inner: super::sqlite::SqliteBackend::in_memory()?,
        })
    }
}
57
#[cfg(feature = "sqlite-storage")]
impl SyncStorageBackend for SyncSqliteBackend {
    /// Delegates to `SqliteBackend::save_internal`.
    fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
        self.inner.save_internal(snapshot)
    }

    /// Delegates to `SqliteBackend::save_decision_internal`.
    fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
        self.inner.save_decision_internal(decision)
    }

    /// Delegates to `SqliteBackend::load_internal`.
    fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
        self.inner.load_internal(snapshot_id)
    }

    /// Loads the snapshot stored under `decision_id` and returns its first decision.
    ///
    /// # Errors
    /// Propagates errors from [`Self::load`]. Returns `StorageError::NotFound`
    /// if that snapshot holds no decisions.
    fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
        let snapshot = self.load(decision_id)?;
        snapshot.decisions.first().cloned().ok_or_else(|| {
            StorageError::NotFound(format!("Decision {} not found", decision_id))
        })
    }

    /// Delegates to `SqliteBackend::query_internal`.
    fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
        self.inner.query_internal(query)
    }

    /// Deletes the row in `snapshots` whose `id` equals `snapshot_id`.
    ///
    /// Returns `Ok(true)` if at least one row was deleted.
    fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError> {
        let conn = self.inner.conn.lock().map_err(connection_lock_error)?;

        let rows_affected = conn
            .execute(
                "DELETE FROM snapshots WHERE id = ?",
                rusqlite::params![snapshot_id],
            )
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to delete snapshot: {}", e))
            })?;

        Ok(rows_affected > 0)
    }

    /// Checkpoints and truncates the WAL, then reports the total number of
    /// rows in `snapshots`.
    ///
    /// `bytes_written` is always 0 and `checkpoint_id` is always `None`.
    fn flush(&self) -> Result<FlushResult, StorageError> {
        let conn = self.inner.conn.lock().map_err(connection_lock_error)?;

        // `PRAGMA wal_checkpoint` always yields one row (busy, log, checkpointed).
        // rusqlite's `execute` rejects row-returning statements with
        // `ExecuteReturnedResults`, so the pragma must go through `query_row`.
        // A non-zero `busy` flag (checkpoint could not finish) is tolerated as
        // best-effort, as before.
        let _busy: i64 = conn
            .query_row("PRAGMA wal_checkpoint(TRUNCATE)", [], |row| row.get(0))
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to checkpoint WAL: {}", e))
            })?;

        // Surface a failing count instead of silently reporting 0.
        let snapshot_count: i64 = conn
            .query_row("SELECT COUNT(*) FROM snapshots", [], |row| row.get(0))
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to count snapshots: {}", e))
            })?;

        Ok(FlushResult {
            // COUNT(*) is never negative, so this cast cannot wrap.
            snapshots_written: snapshot_count as usize,
            bytes_written: 0,
            checkpoint_id: None,
        })
    }

    /// Runs `SELECT 1` on the connection; returns `Ok(true)` if it succeeds.
    fn health_check(&self) -> Result<bool, StorageError> {
        let conn = self.inner.conn.lock().map_err(connection_lock_error)?;

        let _: i64 = conn
            .query_row("SELECT 1", [], |row| row.get(0))
            .map_err(|e| StorageError::ConnectionError(format!("Health check failed: {}", e)))?;

        Ok(true)
    }
}

/// Turns a failed (poisoned) connection-lock acquisition into a
/// [`StorageError`], so a panic on another thread does not cascade into
/// panics here.
#[cfg(feature = "sqlite-storage")]
fn connection_lock_error<E: std::fmt::Display>(err: E) -> StorageError {
    StorageError::ConnectionError(format!("SQLite connection lock poisoned: {}", err))
}
136
137pub struct MemoryStorageBackend {
139 snapshots: std::sync::Mutex<HashMap<String, Snapshot>>,
140}
141
142impl MemoryStorageBackend {
143 pub fn new() -> Self {
145 Self {
146 snapshots: std::sync::Mutex::new(HashMap::new()),
147 }
148 }
149}
150
151impl Default for MemoryStorageBackend {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157impl SyncStorageBackend for MemoryStorageBackend {
158 fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
159 let snapshot_id = snapshot.metadata.snapshot_id.to_string();
160 let mut snapshots = self.snapshots.lock().unwrap();
161 snapshots.insert(snapshot_id.clone(), snapshot.clone());
162 Ok(snapshot_id)
163 }
164
165 fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
166 let snapshot = Snapshot {
168 metadata: decision.metadata.clone(),
169 decisions: vec![decision.clone()],
170 snapshot_type: crate::models::SnapshotType::Decision,
171 };
172 self.save(&snapshot)
173 }
174
175 fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
176 let snapshots = self.snapshots.lock().unwrap();
177 snapshots
178 .get(snapshot_id)
179 .cloned()
180 .ok_or_else(|| StorageError::NotFound(format!("Snapshot {} not found", snapshot_id)))
181 }
182
183 fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
184 let snapshot = self.load(decision_id)?;
185 if let Some(decision) = snapshot.decisions.first() {
186 Ok(decision.clone())
187 } else {
188 Err(StorageError::NotFound(format!(
189 "Decision {} not found",
190 decision_id
191 )))
192 }
193 }
194
195 fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
196 let snapshots = self.snapshots.lock().unwrap();
197 let mut results = Vec::new();
198
199 for (_, snapshot) in snapshots.iter() {
200 if matches_query(snapshot, &query) {
201 results.push(snapshot.clone());
202 }
203 }
204
205 results.sort_by(|a, b| b.metadata.timestamp.cmp(&a.metadata.timestamp));
207
208 let offset = query.offset.unwrap_or(0);
210 let limit = query.limit.unwrap_or(usize::MAX);
211
212 let end = std::cmp::min(offset + limit, results.len());
213 if offset < results.len() {
214 Ok(results[offset..end].to_vec())
215 } else {
216 Ok(Vec::new())
217 }
218 }
219
220 fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError> {
221 let mut snapshots = self.snapshots.lock().unwrap();
222 Ok(snapshots.remove(snapshot_id).is_some())
223 }
224
225 fn flush(&self) -> Result<FlushResult, StorageError> {
226 let snapshots = self.snapshots.lock().unwrap();
227 Ok(FlushResult {
228 snapshots_written: snapshots.len(),
229 bytes_written: 0, checkpoint_id: None,
231 })
232 }
233
234 fn health_check(&self) -> Result<bool, StorageError> {
235 Ok(true)
237 }
238}
239
240fn matches_query(snapshot: &Snapshot, query: &SnapshotQuery) -> bool {
242 if let Some(start_time) = query.start_time {
244 if snapshot.metadata.timestamp < start_time {
245 return false;
246 }
247 }
248
249 if let Some(end_time) = query.end_time {
250 if snapshot.metadata.timestamp > end_time {
251 return false;
252 }
253 }
254
255 if query.function_name.is_some()
257 || query.module_name.is_some()
258 || query.model_name.is_some()
259 || query.tags.is_some()
260 {
261 for decision in &snapshot.decisions {
262 if let Some(function_name) = &query.function_name {
263 if decision.function_name != *function_name {
264 continue;
265 }
266 }
267
268 if let Some(module_name) = &query.module_name {
269 if decision.module_name.as_ref() != Some(module_name) {
270 continue;
271 }
272 }
273
274 if let Some(model_name) = &query.model_name {
275 if let Some(model_params) = &decision.model_parameters {
276 if model_params.model_name != *model_name {
277 continue;
278 }
279 } else {
280 continue;
281 }
282 }
283
284 if let Some(query_tags) = &query.tags {
285 let mut all_tags_match = true;
286 for (key, value) in query_tags {
287 if decision.tags.get(key) != Some(value) {
288 all_tags_match = false;
289 break;
290 }
291 }
292 if !all_tags_match {
293 continue;
294 }
295 }
296
297 return true;
299 }
300
301 false
303 } else {
304 true
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use serde_json::json;

    /// Builds a session snapshot holding one decision: function
    /// `test_function` in `test_module`, model `gpt-4`, tag `env=test`, with a
    /// single input and output.
    fn create_test_snapshot() -> Snapshot {
        let input = Input::new("test_input", json!("value"), "string");
        let output = Output::new("test_output", json!("result"), "string");
        let model_params = ModelParameters::new("gpt-4");

        let decision = DecisionSnapshot::new("test_function")
            .with_module("test_module")
            .add_input(input)
            .add_output(output)
            .with_model_parameters(model_params)
            .add_tag("env", "test");

        let mut snapshot = Snapshot::new(SnapshotType::Session);
        snapshot.add_decision(decision);
        snapshot
    }

    #[test]
    fn test_memory_backend_basic_operations() {
        let backend = MemoryStorageBackend::new();
        let snapshot = create_test_snapshot();

        let snapshot_id = backend.save(&snapshot).unwrap();
        let loaded_snapshot = backend.load(&snapshot_id).unwrap();

        assert_eq!(snapshot.decisions.len(), loaded_snapshot.decisions.len());
        assert_eq!(snapshot.snapshot_type, loaded_snapshot.snapshot_type);

        assert!(backend.health_check().unwrap());

        assert!(backend.delete(&snapshot_id).unwrap());

        let result = backend.load(&snapshot_id);
        assert!(matches!(result, Err(StorageError::NotFound(_))));
    }

    #[test]
    fn test_memory_backend_query_by_function_name() {
        let backend = MemoryStorageBackend::new();
        let snapshot = create_test_snapshot();
        backend.save(&snapshot).unwrap();

        let query = SnapshotQuery::new().with_function_name("test_function");
        let results = backend.query(query).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].decisions[0].function_name, "test_function");
    }

    #[test]
    fn test_memory_backend_query_excludes_non_matching_function() {
        let backend = MemoryStorageBackend::new();
        backend.save(&create_test_snapshot()).unwrap();

        let query = SnapshotQuery::new().with_function_name("other_function");
        assert!(backend.query(query).unwrap().is_empty());
    }

    #[test]
    fn test_memory_backend_decision_round_trip() {
        let backend = MemoryStorageBackend::new();
        let decision = create_test_snapshot().decisions[0].clone();

        let decision_id = backend.save_decision(&decision).unwrap();
        let loaded = backend.load_decision(&decision_id).unwrap();

        assert_eq!(loaded.function_name, decision.function_name);
    }

    #[test]
    fn test_memory_backend_missing_ids() {
        let backend = MemoryStorageBackend::new();

        assert!(matches!(
            backend.load_decision("does-not-exist"),
            Err(StorageError::NotFound(_))
        ));
        assert!(!backend.delete("does-not-exist").unwrap());
    }

    #[cfg(feature = "sqlite-storage")]
    #[test]
    fn test_sync_sqlite_backend() {
        let backend = SyncSqliteBackend::in_memory().unwrap();
        let snapshot = create_test_snapshot();

        let snapshot_id = backend.save(&snapshot).unwrap();
        let loaded_snapshot = backend.load(&snapshot_id).unwrap();

        assert_eq!(snapshot.decisions.len(), loaded_snapshot.decisions.len());
        assert_eq!(snapshot.snapshot_type, loaded_snapshot.snapshot_type);

        assert!(backend.health_check().unwrap());

        assert!(backend.delete(&snapshot_id).unwrap());

        let result = backend.load(&snapshot_id);
        assert!(matches!(result, Err(StorageError::NotFound(_))));
    }
}