use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::types::{
Edge, Entity, Episode, EpisodicMemory, Memory, Namespace, ObservationMemory, ProceduralMemory,
SemanticMemory,
};
pub mod sqlite;
#[cfg(feature = "postgres")]
pub mod postgres;
#[cfg(feature = "postgres")]
pub use postgres::PostgresBackend;
/// Errors surfaced by the storage layer.
///
/// Variants marked `#[from]` get a `From` conversion generated by `thiserror`,
/// so the corresponding source errors can be propagated with `?`.
#[derive(Debug, thiserror::Error)]
pub enum StorageError {
/// An error reported by `rusqlite`.
#[error("SQLite error: {0}")]
Sqlite(#[from] rusqlite::Error),
/// A `serde_json` serialization or deserialization failure.
#[error("Serialization error: {0}")]
Serde(#[from] serde_json::Error),
/// A requested item was not found; the string carries a description of it.
#[error("Not found: {0}")]
NotFound(String),
/// An underlying `std::io::Error`.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// A free-form storage error message. In this file it is used by the default
/// `save_observation` implementation to report an unsupported operation.
#[error("Storage context: {0}")]
Context(String),
/// A mutex guarding storage state was poisoned; the string carries details.
#[error("Mutex lock poisoned: {0}")]
LockPoisoned(String),
}
/// Convenience alias for results whose error type is [`StorageError`].
pub type StorageResult<T> = Result<T, StorageError>;
/// Backend-agnostic persistence interface for namespaces, entities, episodes,
/// memories, graph edges and activity logs.
///
/// Implementors must be `Send + Sync`. Methods without a body are required;
/// methods with a body are default implementations that a backend may override
/// (typically with a more efficient query).
pub trait StorageTrait: Send + Sync {
    // ------------------------------------------------------------------
    // Namespaces
    // ------------------------------------------------------------------

    /// Persists the namespace `ns`.
    fn save_namespace(&self, ns: &Namespace) -> StorageResult<()>;

    /// Looks up a namespace by id; `Ok(None)` when there is no match.
    fn get_namespace(&self, id: Uuid) -> StorageResult<Option<Namespace>>;

    /// Looks up a namespace by name; `Ok(None)` when there is no match.
    fn get_namespace_by_name(&self, name: &str) -> StorageResult<Option<Namespace>>;

    // ------------------------------------------------------------------
    // Entities
    // ------------------------------------------------------------------

    /// Persists `entity`.
    fn save_entity(&self, entity: &Entity) -> StorageResult<()>;

    /// Looks up an entity by id; `Ok(None)` when there is no match.
    fn get_entity(&self, id: Uuid) -> StorageResult<Option<Entity>>;

    /// Looks up an entity by `name` within `namespace_id`; `Ok(None)` when
    /// there is no match.
    fn get_entity_by_name(&self, name: &str, namespace_id: Uuid) -> StorageResult<Option<Entity>>;

    // ------------------------------------------------------------------
    // Episodes
    // ------------------------------------------------------------------

    /// Persists `episode`.
    fn save_episode(&self, episode: &Episode) -> StorageResult<()>;

    /// Looks up an episode by id; `Ok(None)` when there is no match.
    fn get_episode(&self, id: Uuid) -> StorageResult<Option<Episode>>;

    /// Updates the stored record for `episode`.
    fn update_episode(&self, episode: &Episode) -> StorageResult<()>;

    // ------------------------------------------------------------------
    // Episodic memories
    // ------------------------------------------------------------------

    /// Persists the episodic memory `mem`.
    fn save_episodic(&self, mem: &EpisodicMemory) -> StorageResult<()>;

    /// Looks up an episodic memory by id; `Ok(None)` when there is no match.
    fn get_episodic(&self, id: Uuid) -> StorageResult<Option<EpisodicMemory>>;

    /// Lists episodic memories about the entity `about_entity`.
    ///
    /// `limit` is the caller's requested cap on the number of results.
    /// NOTE(review): result ordering is backend-defined — confirm.
    fn list_episodic_by_entity(
        &self,
        about_entity: Uuid,
        limit: usize,
    ) -> StorageResult<Vec<EpisodicMemory>>;

    /// Lists the episodic memories of `namespace_id` that belong to
    /// `episode_id`, sorted ascending by `event_time`, falling back to
    /// `timestamp` when `event_time` is `None`.
    ///
    /// This default implementation loads every memory in the namespace via
    /// [`StorageTrait::get_all_memories_by_namespace`] and filters in memory,
    /// so backends may want to override it with a targeted query.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_all_memories_by_namespace`.
    fn list_episodic_by_episode(
        &self,
        namespace_id: Uuid,
        episode_id: Uuid,
    ) -> StorageResult<Vec<EpisodicMemory>> {
        let all = self.get_all_memories_by_namespace(namespace_id)?;
        let mut out: Vec<EpisodicMemory> = all
            .into_iter()
            .filter_map(|m| match m {
                Memory::Episodic(e) if e.episode_id == episode_id => Some(e),
                _ => None,
            })
            .collect();
        // Stable sort: memories with equal keys keep the backend's order.
        out.sort_by_key(|e| e.event_time.unwrap_or(e.timestamp));
        Ok(out)
    }

    /// Stores new `stability` and `retrievability` values for the episodic
    /// memory `id`.
    ///
    /// NOTE(review): the name suggests backends also record the access itself
    /// (e.g. a counter or timestamp) — confirm against implementations.
    fn update_episodic_access(
        &self,
        id: Uuid,
        stability: f32,
        retrievability: f32,
    ) -> StorageResult<()>;

    // ------------------------------------------------------------------
    // Semantic memories
    // ------------------------------------------------------------------

    /// Persists the semantic memory `mem`.
    fn save_semantic(&self, mem: &SemanticMemory) -> StorageResult<()>;

    /// Looks up a semantic memory by id; `Ok(None)` when there is no match.
    fn get_semantic(&self, id: Uuid) -> StorageResult<Option<SemanticMemory>>;

    /// Lists semantic memories whose subject is the entity `subject`.
    ///
    /// `limit` is the caller's requested cap on the number of results.
    fn list_semantic_by_entity(
        &self,
        subject: Uuid,
        limit: usize,
    ) -> StorageResult<Vec<SemanticMemory>>;

    /// Marks the semantic memory `id` as invalidated; how invalidation is
    /// represented is backend-defined.
    fn invalidate_semantic(&self, id: Uuid) -> StorageResult<()>;

    // ------------------------------------------------------------------
    // Procedural memories
    // ------------------------------------------------------------------

    /// Persists the procedural memory `mem`.
    fn save_procedural(&self, mem: &ProceduralMemory) -> StorageResult<()>;

    /// Looks up a procedural memory by id; `Ok(None)` when there is no match.
    fn get_procedural(&self, id: Uuid) -> StorageResult<Option<ProceduralMemory>>;

    /// Stores new `reliability`, `trial_count` and `success_count` values for
    /// the procedural memory `id`.
    fn update_procedural_reliability(
        &self,
        id: Uuid,
        reliability: f32,
        trial_count: u32,
        success_count: u32,
    ) -> StorageResult<()>;

    // ------------------------------------------------------------------
    // Observation memories (optional; defaults describe an unsupported backend)
    // ------------------------------------------------------------------

    /// Persists the observation memory `_mem`.
    ///
    /// # Errors
    ///
    /// The default implementation always fails with
    /// [`StorageError::Context`], signalling that the backend does not
    /// support observations. Unlike the read/delete defaults below, a write
    /// is not silently dropped.
    fn save_observation(&self, _mem: &ObservationMemory) -> StorageResult<()> {
        Err(StorageError::Context(
            "save_observation not implemented on this backend".into(),
        ))
    }

    /// Looks up an observation memory by id.
    ///
    /// The default implementation always returns `Ok(None)`.
    fn get_observation(&self, _id: Uuid) -> StorageResult<Option<ObservationMemory>> {
        Ok(None)
    }

    /// Lists observation memories belonging to any of `_episode_ids`.
    ///
    /// The default implementation always returns an empty list.
    fn list_observations_by_episode_ids(
        &self,
        _episode_ids: &[Uuid],
        _limit: usize,
    ) -> StorageResult<Vec<ObservationMemory>> {
        Ok(Vec::new())
    }

    /// Deletes observation memories tied to `_episode_id`, returning a count.
    ///
    /// The default implementation deletes nothing and returns `Ok(0)`.
    fn delete_observations_by_episode(&self, _episode_id: Uuid) -> StorageResult<usize> {
        Ok(0)
    }

    /// Deletes observation memories tied to `_entity_id`, returning a count.
    ///
    /// The default implementation deletes nothing and returns `Ok(0)`.
    fn delete_observations_by_entity(&self, _entity_id: Uuid) -> StorageResult<usize> {
        Ok(0)
    }

    // ------------------------------------------------------------------
    // Search and bulk memory operations
    // ------------------------------------------------------------------

    /// Searches the memories of `namespace_id` for `query`.
    ///
    /// `limit` is the caller's requested cap on the number of results.
    /// NOTE(review): "fts" presumably means full-text search; the accepted
    /// query syntax is backend-defined — confirm.
    fn search_fts(
        &self,
        query: &str,
        namespace_id: Uuid,
        limit: usize,
    ) -> StorageResult<Vec<Memory>>;

    /// Like [`StorageTrait::search_fts`], additionally scoped to `entity_id`.
    fn search_fts_scoped(
        &self,
        query: &str,
        namespace_id: Uuid,
        entity_id: Uuid,
        limit: usize,
    ) -> StorageResult<Vec<Memory>>;

    /// Returns every memory stored in `namespace_id`.
    fn get_all_memories_by_namespace(&self, namespace_id: Uuid) -> StorageResult<Vec<Memory>>;

    /// Deletes the memories associated with `entity_id`, returning a count.
    fn delete_memories_by_entity(&self, entity_id: Uuid) -> StorageResult<usize>;

    /// Deletes the memory `id`. The default `purge_namespace` counts an
    /// `Ok(true)` result as one deleted memory.
    fn delete_memory_by_id(&self, id: Uuid) -> StorageResult<bool>;

    /// Deletes every memory in `namespace_id` and returns how many were
    /// actually deleted.
    ///
    /// This default implementation lists the namespace's memories and deletes
    /// them one by one. Every deletion is attempted even if an earlier one
    /// failed, so a single bad row does not block the rest of the purge.
    ///
    /// # Errors
    ///
    /// Propagates a failure to list the memories. If any individual deletion
    /// fails, the first such error is returned once all deletions have been
    /// attempted, so a partially failed purge is never reported as a success.
    fn purge_namespace(&self, namespace_id: Uuid) -> StorageResult<usize> {
        let memories = self.get_all_memories_by_namespace(namespace_id)?;
        let mut deleted = 0;
        let mut first_error = None;
        for mem in &memories {
            match self.delete_memory_by_id(mem.id()) {
                Ok(true) => deleted += 1,
                // `false` means nothing was deleted for this id; not an error.
                Ok(false) => {}
                Err(err) => {
                    // Keep sweeping, but remember the failure so the caller
                    // learns the purge was incomplete.
                    if first_error.is_none() {
                        first_error = Some(err);
                    }
                }
            }
        }
        match first_error {
            Some(err) => Err(err),
            None => Ok(deleted),
        }
    }

    /// Replaces the `predicate` and `object` of the semantic memory `id`.
    ///
    /// NOTE(review): `confidence` is optional; presumably `None` leaves the
    /// stored confidence unchanged — confirm against implementations.
    fn update_semantic_content(
        &self,
        id: Uuid,
        predicate: &str,
        object: &str,
        confidence: Option<f32>,
    ) -> StorageResult<()>;

    // ------------------------------------------------------------------
    // Entity graph
    // ------------------------------------------------------------------

    /// Deletes the entity `id`.
    ///
    /// NOTE(review): the `bool` presumably reports whether a row was removed
    /// — confirm.
    fn delete_entity(&self, id: Uuid) -> StorageResult<bool>;

    /// Lists the entities stored in `namespace_id`.
    fn list_entities_by_namespace(&self, namespace_id: Uuid) -> StorageResult<Vec<Entity>>;

    /// Persists `edge`.
    fn save_edge(&self, edge: &Edge) -> StorageResult<()>;

    /// Lists the edges associated with `entity_id`.
    fn get_edges_for_entity(&self, entity_id: Uuid) -> StorageResult<Vec<Edge>>;

    // ------------------------------------------------------------------
    // Statistics and activity log
    // ------------------------------------------------------------------

    /// Returns three memory counts for `namespace_id`.
    ///
    /// NOTE(review): the tuple order is presumably
    /// `(episodic, semantic, procedural)` — confirm against implementations.
    fn count_memories_by_namespace(
        &self,
        namespace_id: Uuid,
    ) -> StorageResult<(usize, usize, usize)>;

    /// Returns the number of entities in `namespace_id`.
    fn count_entities_by_namespace(&self, namespace_id: Uuid) -> StorageResult<usize>;

    /// Records an activity event of `event_type` with a JSON `detail` payload
    /// for `namespace_id`.
    fn log_activity(
        &self,
        namespace_id: Uuid,
        event_type: &str,
        detail: &serde_json::Value,
    ) -> StorageResult<()>;

    /// Returns activity aggregates for `namespace_id`.
    ///
    /// NOTE(review): `days` presumably selects a trailing window of that many
    /// days — confirm.
    fn get_activity_aggregates(
        &self,
        namespace_id: Uuid,
        days: u32,
    ) -> StorageResult<Vec<ActivityAggregate>>;

    /// Returns recent activity events for `namespace_id`.
    ///
    /// `limit` is the caller's requested cap on the number of results.
    fn get_recent_activity(
        &self,
        namespace_id: Uuid,
        limit: usize,
    ) -> StorageResult<Vec<ActivityEvent>>;
}
/// A single activity-log entry, as returned by
/// `StorageTrait::get_recent_activity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivityEvent {
/// Identifier of this event.
pub id: Uuid,
/// Event type label (a free-form string at this layer).
pub event_type: String,
/// Namespace the event belongs to.
pub namespace_id: Uuid,
/// Arbitrary JSON payload attached to the event.
pub detail_json: serde_json::Value,
/// UTC timestamp of the event.
pub created_at: DateTime<Utc>,
}
/// Activity counters for one date bucket, as returned by
/// `StorageTrait::get_activity_aggregates`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActivityAggregate {
/// Date of the bucket as a string.
/// NOTE(review): format is presumably `YYYY-MM-DD` — confirm against backends.
pub date: String,
/// Count of "recall" events — presumably per `date`; confirm.
pub recalls: usize,
/// Count of "remember" events — presumably per `date`; confirm.
pub remembers: usize,
/// Count of "observe" events — presumably per `date`; confirm.
pub observes: usize,
/// Count of "forget" events — presumably per `date`; confirm.
pub forgets: usize,
}