use super::{FlushResult, SnapshotQuery, StorageError};
use crate::models::{DecisionSnapshot, Snapshot};
use std::collections::HashMap;
/// Blocking (non-async) storage interface for snapshots and decisions.
///
/// Implementors must be `Send + Sync`, so a single backend value can be shared
/// across threads. All fallible operations report failures as [`StorageError`].
pub trait SyncStorageBackend: Send + Sync {
    /// Persists `snapshot` and returns the id it was stored under.
    ///
    /// The returned id is the key accepted by [`load`](Self::load) and
    /// [`delete`](Self::delete) in the implementations in this file.
    fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError>;
    /// Persists a single decision and returns the id it was stored under.
    fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError>;
    /// Loads the snapshot stored under `snapshot_id`.
    ///
    /// The implementations in this file report an unknown id as
    /// `StorageError::NotFound`.
    fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError>;
    /// Loads a decision by id.
    ///
    /// Both implementations in this file load the snapshot stored under
    /// `decision_id` and return its first decision, or `NotFound` if it has none.
    fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError>;
    /// Returns the snapshots matching `query`.
    ///
    /// NOTE(review): the in-memory implementation orders results newest-first and
    /// then applies `offset`/`limit`; confirm the SQLite backend's ordering matches.
    fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError>;
    /// Removes the snapshot stored under `snapshot_id`.
    ///
    /// Returns `Ok(true)` if something was removed and `Ok(false)` if the id was unknown.
    fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError>;
    /// Forces buffered state to durable storage and reports what is stored.
    ///
    /// Both implementations in this file report the total snapshot count as
    /// `snapshots_written` and `0` for `bytes_written`.
    fn flush(&self) -> Result<FlushResult, StorageError>;
    /// Checks that the backend is usable.
    ///
    /// Returns `Ok(true)` when healthy. The SQLite implementation returns an error
    /// if a trivial query fails.
    fn health_check(&self) -> Result<bool, StorageError>;
}
/// Synchronous [`SyncStorageBackend`] backed by SQLite.
///
/// Wraps `super::sqlite::SqliteBackend`. Most operations delegate to its
/// `*_internal` methods, while `delete`, `flush` and `health_check` lock its
/// `conn` directly. This type is only compiled with the `sqlite-storage` feature.
#[cfg(feature = "sqlite-storage")]
pub struct SyncSqliteBackend {
    /// The wrapped SQLite backend that owns the database connection.
    inner: super::sqlite::SqliteBackend,
}
#[cfg(feature = "sqlite-storage")]
impl SyncSqliteBackend {
    /// Builds a backend over the SQLite database at `path`.
    ///
    /// # Errors
    /// Propagates any error from `super::sqlite::SqliteBackend::new`.
    pub fn new(path: impl AsRef<std::path::Path>) -> Result<Self, StorageError> {
        Ok(Self {
            inner: super::sqlite::SqliteBackend::new(path)?,
        })
    }

    /// Builds a backend over a SQLite database created by
    /// `super::sqlite::SqliteBackend::in_memory` (useful for tests).
    ///
    /// # Errors
    /// Propagates any error from `super::sqlite::SqliteBackend::in_memory`.
    pub fn in_memory() -> Result<Self, StorageError> {
        Ok(Self {
            inner: super::sqlite::SqliteBackend::in_memory()?,
        })
    }
}
#[cfg(feature = "sqlite-storage")]
impl SyncStorageBackend for SyncSqliteBackend {
    /// Delegates to `SqliteBackend::save_internal`.
    fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
        self.inner.save_internal(snapshot)
    }

    /// Delegates to `SqliteBackend::save_decision_internal`.
    fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
        self.inner.save_decision_internal(decision)
    }

    /// Delegates to `SqliteBackend::load_internal`.
    fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
        self.inner.load_internal(snapshot_id)
    }

    /// Loads the snapshot stored under `decision_id` and returns its first decision.
    ///
    /// # Errors
    /// Propagates `load` errors. Returns `StorageError::NotFound` if the snapshot
    /// holds no decisions.
    fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
        self.load(decision_id)?
            .decisions
            .first()
            .cloned()
            .ok_or_else(|| StorageError::NotFound(format!("Decision {} not found", decision_id)))
    }

    /// Delegates to `SqliteBackend::query_internal`.
    fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
        self.inner.query_internal(query)
    }

    /// Deletes the row with id `snapshot_id` from the `snapshots` table.
    ///
    /// Returns `Ok(true)` if a row was removed.
    ///
    /// # Errors
    /// Returns `StorageError::ConnectionError` if the connection mutex is poisoned
    /// or the `DELETE` fails.
    fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError> {
        let conn = self.inner.conn.lock().map_err(|e| {
            StorageError::ConnectionError(format!("Connection mutex poisoned: {}", e))
        })?;
        let rows_affected = conn
            .execute(
                "DELETE FROM snapshots WHERE id = ?",
                rusqlite::params![snapshot_id],
            )
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to delete snapshot: {}", e))
            })?;
        Ok(rows_affected > 0)
    }

    /// Checkpoints and truncates the WAL, then reports the stored snapshot count.
    ///
    /// # Errors
    /// Returns `StorageError::ConnectionError` if the connection mutex is poisoned,
    /// the checkpoint fails, or the snapshot count cannot be read.
    fn flush(&self) -> Result<FlushResult, StorageError> {
        let conn = self.inner.conn.lock().map_err(|e| {
            StorageError::ConnectionError(format!("Connection mutex poisoned: {}", e))
        })?;
        // `PRAGMA wal_checkpoint` always yields one result row (busy, log,
        // checkpointed). `Connection::execute` rejects row-returning statements
        // with `ExecuteReturnedResults`, so the pragma must go through `query_row`.
        conn.query_row("PRAGMA wal_checkpoint(TRUNCATE)", [], |_row| Ok(()))
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to checkpoint WAL: {}", e))
            })?;
        let snapshot_count: i64 = conn
            .query_row("SELECT COUNT(*) FROM snapshots", [], |row| row.get(0))
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to count snapshots: {}", e))
            })?;
        Ok(FlushResult {
            // COUNT(*) is never negative, so the cast cannot change the sign.
            snapshots_written: snapshot_count as usize,
            bytes_written: 0,
            checkpoint_id: None,
        })
    }

    /// Runs `SELECT 1` on the connection; returns `Ok(true)` if it succeeds.
    ///
    /// # Errors
    /// Returns `StorageError::ConnectionError` if the connection mutex is poisoned
    /// or the query fails.
    fn health_check(&self) -> Result<bool, StorageError> {
        let conn = self.inner.conn.lock().map_err(|e| {
            StorageError::ConnectionError(format!("Connection mutex poisoned: {}", e))
        })?;
        let _: i64 = conn
            .query_row("SELECT 1", [], |row| row.get(0))
            .map_err(|e| StorageError::ConnectionError(format!("Health check failed: {}", e)))?;
        Ok(true)
    }
}
/// Non-persistent [`SyncStorageBackend`] that keeps snapshots in process memory.
///
/// Snapshots are stored in a `HashMap` keyed by snapshot id and guarded by a
/// `std::sync::Mutex`. Nothing is written anywhere else, so the contents are
/// lost when the value is dropped.
pub struct MemoryStorageBackend {
    /// Snapshot id -> stored snapshot.
    snapshots: std::sync::Mutex<HashMap<String, Snapshot>>,
}
impl MemoryStorageBackend {
pub fn new() -> Self {
Self {
snapshots: std::sync::Mutex::new(HashMap::new()),
}
}
}
impl Default for MemoryStorageBackend {
fn default() -> Self {
Self::new()
}
}
impl SyncStorageBackend for MemoryStorageBackend {
    /// Stores a clone of `snapshot` under `snapshot.metadata.snapshot_id`,
    /// replacing any previous entry, and returns that id.
    fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
        let snapshot_id = snapshot.metadata.snapshot_id.to_string();
        self.lock_snapshots()?
            .insert(snapshot_id.clone(), snapshot.clone());
        Ok(snapshot_id)
    }

    /// Wraps `decision` in a `SnapshotType::Decision` snapshot that reuses the
    /// decision's metadata, then saves it like any other snapshot.
    fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
        let snapshot = Snapshot {
            metadata: decision.metadata.clone(),
            decisions: vec![decision.clone()],
            snapshot_type: crate::models::SnapshotType::Decision,
        };
        self.save(&snapshot)
    }

    /// Returns a clone of the snapshot stored under `snapshot_id`.
    ///
    /// # Errors
    /// Returns `StorageError::NotFound` for an unknown id.
    fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
        self.lock_snapshots()?
            .get(snapshot_id)
            .cloned()
            .ok_or_else(|| StorageError::NotFound(format!("Snapshot {} not found", snapshot_id)))
    }

    /// Loads the snapshot stored under `decision_id` and returns its first decision.
    ///
    /// # Errors
    /// Propagates `load` errors. Returns `StorageError::NotFound` if the snapshot
    /// holds no decisions.
    fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
        self.load(decision_id)?
            .decisions
            .first()
            .cloned()
            .ok_or_else(|| StorageError::NotFound(format!("Decision {} not found", decision_id)))
    }

    /// Returns the snapshots accepted by `matches_query`, newest first, after
    /// skipping `query.offset` entries and keeping at most `query.limit`.
    fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
        let snapshots = self.lock_snapshots()?;
        // Sort borrowed matches so that only the requested page is cloned.
        let mut matching: Vec<&Snapshot> = snapshots
            .values()
            .filter(|snapshot| matches_query(snapshot, &query))
            .collect();
        matching.sort_by(|a, b| b.metadata.timestamp.cmp(&a.metadata.timestamp));
        let offset = query.offset.unwrap_or(0);
        let limit = query.limit.unwrap_or(usize::MAX);
        // `skip`/`take` avoid computing `offset + limit`, which overflows when no
        // limit is given (`usize::MAX`) and the offset is non-zero.
        Ok(matching
            .into_iter()
            .skip(offset)
            .take(limit)
            .cloned()
            .collect())
    }

    /// Removes the entry for `snapshot_id`; returns whether one existed.
    fn delete(&self, snapshot_id: &str) -> Result<bool, StorageError> {
        Ok(self.lock_snapshots()?.remove(snapshot_id).is_some())
    }

    /// Nothing to persist. Reports the number of stored snapshots as `snapshots_written`.
    fn flush(&self) -> Result<FlushResult, StorageError> {
        let snapshots_written = self.lock_snapshots()?.len();
        Ok(FlushResult {
            snapshots_written,
            bytes_written: 0,
            checkpoint_id: None,
        })
    }

    /// Always healthy: there is no external resource to probe.
    fn health_check(&self) -> Result<bool, StorageError> {
        Ok(true)
    }
}

impl MemoryStorageBackend {
    /// Locks the snapshot map. A poisoned mutex (an earlier panic while it was
    /// held) is reported as `StorageError::ConnectionError` instead of panicking.
    fn lock_snapshots(
        &self,
    ) -> Result<std::sync::MutexGuard<'_, HashMap<String, Snapshot>>, StorageError> {
        self.snapshots.lock().map_err(|e| {
            StorageError::ConnectionError(format!("Snapshot store mutex poisoned: {}", e))
        })
    }
}
/// Decides whether `snapshot` satisfies `query`.
///
/// The snapshot's timestamp must lie inside the optional inclusive
/// `start_time`..=`end_time` window. When any decision-level filter
/// (`function_name`, `module_name`, `model_name`, `tags`) is set, at least one
/// decision must satisfy all of them at once; a snapshot without decisions then
/// never matches. Without decision-level filters, only the time window applies.
fn matches_query(snapshot: &Snapshot, query: &SnapshotQuery) -> bool {
    if matches!(query.start_time, Some(start) if snapshot.metadata.timestamp < start) {
        return false;
    }
    if matches!(query.end_time, Some(end) if snapshot.metadata.timestamp > end) {
        return false;
    }

    let filters_decisions = query.function_name.is_some()
        || query.module_name.is_some()
        || query.model_name.is_some()
        || query.tags.is_some();
    if !filters_decisions {
        return true;
    }

    snapshot.decisions.iter().any(|decision| {
        let function_ok = query
            .function_name
            .as_ref()
            .map_or(true, |wanted| decision.function_name == *wanted);
        let module_ok = query
            .module_name
            .as_ref()
            .map_or(true, |wanted| decision.module_name.as_ref() == Some(wanted));
        // A model filter can only match decisions that record model parameters.
        let model_ok = query.model_name.as_ref().map_or(true, |wanted| {
            decision
                .model_parameters
                .as_ref()
                .map_or(false, |params| params.model_name == *wanted)
        });
        // Every requested tag must be present on the decision with the same value.
        let tags_ok = query.tags.as_ref().map_or(true, |wanted| {
            wanted
                .iter()
                .all(|(key, value)| decision.tags.get(key) == Some(value))
        });
        function_ok && module_ok && model_ok && tags_ok
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use serde_json::json;

    /// Builds a session snapshot holding one decision: `test_function` in
    /// `test_module`, with one input, one output, model `gpt-4` and tag `env=test`.
    fn create_test_snapshot() -> Snapshot {
        let decision = DecisionSnapshot::new("test_function")
            .with_module("test_module")
            .add_input(Input::new("test_input", json!("value"), "string"))
            .add_output(Output::new("test_output", json!("result"), "string"))
            .with_model_parameters(ModelParameters::new("gpt-4"))
            .add_tag("env", "test");
        let mut snapshot = Snapshot::new(SnapshotType::Session);
        snapshot.add_decision(decision);
        snapshot
    }

    /// Runs save -> load -> health check -> delete on `backend`.
    ///
    /// Asserts the round trip preserves the decision count and snapshot type,
    /// and that a deleted snapshot reports `NotFound`.
    fn assert_save_load_delete<B: SyncStorageBackend>(backend: &B) {
        let original = create_test_snapshot();
        let id = backend.save(&original).unwrap();

        let reloaded = backend.load(&id).unwrap();
        assert_eq!(original.decisions.len(), reloaded.decisions.len());
        assert_eq!(original.snapshot_type, reloaded.snapshot_type);

        assert!(backend.health_check().unwrap());
        assert!(backend.delete(&id).unwrap());
        assert!(matches!(backend.load(&id), Err(StorageError::NotFound(_))));
    }

    #[test]
    fn test_memory_backend_basic_operations() {
        assert_save_load_delete(&MemoryStorageBackend::new());
    }

    #[test]
    fn test_memory_backend_query_by_function_name() {
        let backend = MemoryStorageBackend::new();
        backend.save(&create_test_snapshot()).unwrap();

        let found = backend
            .query(SnapshotQuery::new().with_function_name("test_function"))
            .unwrap();

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].decisions[0].function_name, "test_function");
    }

    #[cfg(feature = "sqlite-storage")]
    #[test]
    fn test_sync_sqlite_backend() {
        assert_save_load_delete(&SyncSqliteBackend::in_memory().unwrap());
    }
}