use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc;
use crate::context::{ContextAssembly, ContextBudget, ContextItem, PRIORITY_LEARNING};
use crate::embeddings::EmbeddingBackend;
use crate::error::{MindCoreError, Result};
use crate::memory::MemoryStore;
use crate::memory::store::StoreResult;
use crate::scoring::CompositeScorer;
use crate::search::builder::SearchBuilder;
use crate::storage::Database;
use crate::storage::migrations;
use crate::traits::{MemoryRecord, ScoringStrategy};
/// Facade that bundles a primary database, an optional global database, a
/// [`MemoryStore`], a scoring strategy, and an optional embedding backend
/// behind one handle for records of type `T`.
///
/// Instances are created through [`MemoryEngine::builder`].
pub struct MemoryEngine<T: MemoryRecord> {
/// Primary database; every store/get/update/delete/search/count call in
/// this file is executed against it.
db: Database,
/// Optional second database, opened and migrated by the builder when a
/// global path is configured. Only exposed via `global_database()`; no
/// operation in this file routes queries to it.
global_db: Option<Database>,
/// Record store that performs the CRUD operations on a supplied database.
store: MemoryStore<T>,
/// Scoring strategy handed to every `SearchBuilder` created by `search()`.
scoring: Arc<dyn ScoringStrategy>,
/// Embedding backend handed to `SearchBuilder` when present.
embedding: Option<Arc<dyn EmbeddingBackend>>,
}
impl<T: MemoryRecord> MemoryEngine<T> {
    /// Starts configuring a new engine.
    ///
    /// Call [`MemoryEngineBuilder::build`] on the returned builder to obtain
    /// the engine.
    pub fn builder() -> MemoryEngineBuilder<T> {
        MemoryEngineBuilder::new()
    }

    /// Stores `record` in the primary database.
    ///
    /// # Errors
    /// Propagates any error reported by the underlying `MemoryStore`.
    pub fn store(&self, record: &T) -> Result<StoreResult> {
        self.store.store(&self.db, record)
    }

    /// Loads the record with the given `id` from the primary database.
    ///
    /// The `Option` in the return type is whatever the underlying
    /// `MemoryStore` yields for that id.
    ///
    /// # Errors
    /// Propagates any error reported by the underlying `MemoryStore`.
    pub fn get(&self, id: i64) -> Result<Option<T>> {
        self.store.get(&self.db, id)
    }

    /// Replaces the stored record identified by `id` with `record`.
    ///
    /// # Errors
    /// Propagates any error reported by the underlying `MemoryStore`.
    pub fn update(&self, id: i64, record: &T) -> Result<()> {
        self.store.update(&self.db, id, record)
    }

    /// Deletes the record identified by `id` from the primary database.
    ///
    /// Returns the flag reported by the underlying `MemoryStore`
    /// (presumably whether a row was actually removed — confirm in
    /// `MemoryStore::delete`).
    ///
    /// # Errors
    /// Propagates any error reported by the underlying `MemoryStore`.
    pub fn delete(&self, id: i64) -> Result<bool> {
        self.store.delete(&self.db, id)
    }

    /// Creates a search over the primary database for `query`.
    ///
    /// The returned builder is pre-configured with this engine's scoring
    /// strategy and, when one was configured, its embedding backend. Nothing
    /// is executed until the caller runs the builder.
    pub fn search(&self, query: &str) -> SearchBuilder<'_, T> {
        let mut builder = SearchBuilder::new(&self.db, query)
            .with_scoring(Arc::clone(&self.scoring));
        if let Some(ref embedding) = self.embedding {
            builder = builder.with_embedding(Arc::clone(embedding));
        }
        builder
    }

    /// Returns the configured embedding backend, if any.
    pub fn embedding_backend(&self) -> Option<&dyn EmbeddingBackend> {
        self.embedding.as_deref()
    }

    /// Returns the record count reported by the store for the primary database.
    ///
    /// # Errors
    /// Propagates any error reported by the underlying `MemoryStore`.
    pub fn count(&self) -> Result<u64> {
        self.store.count(&self.db)
    }

    /// Builds a [`ContextAssembly`] for `query`.
    ///
    /// Runs a search limited to 50 results, loads the text, type and category
    /// of each hit from the `memories` table, turns every hit into a
    /// [`ContextItem`] with priority [`PRIORITY_LEARNING`], and hands the
    /// candidates to [`ContextAssembly::assemble`] together with `budget`.
    ///
    /// A hit whose row no longer exists at lookup time is skipped.
    ///
    /// # Errors
    /// Returns an error if the search fails or if any row lookup fails for a
    /// reason other than the row being absent. Such lookup failures are
    /// propagated rather than dropped, so a failing database cannot silently
    /// yield a truncated or empty context.
    pub fn assemble_context(
        &self,
        query: &str,
        budget: &ContextBudget,
    ) -> Result<ContextAssembly> {
        // Upper bound on the number of search hits considered as candidates.
        let results = self.search(query).limit(50).execute()?;

        let mut candidates: Vec<ContextItem> = Vec::with_capacity(results.len());
        for sr in results.iter() {
            let item = self.db.with_reader(|conn| {
                let row = conn.query_row(
                    "SELECT searchable_text, memory_type, category FROM memories WHERE id = ?1",
                    [sr.memory_id],
                    |row| {
                        Ok((
                            row.get::<_, String>(0)?,
                            row.get::<_, String>(1)?,
                            row.get::<_, Option<String>>(2)?,
                        ))
                    },
                );
                match row {
                    Ok((text, type_str, category)) => {
                        // Unrecognised type strings fall back to `Episodic`.
                        let memory_type = crate::traits::MemoryType::from_str(&type_str)
                            .unwrap_or(crate::traits::MemoryType::Episodic);
                        // Estimate before moving `text` into the item so no
                        // copy of the content is needed.
                        let estimated_tokens = budget.estimate_tokens(&text);
                        Ok(Some(ContextItem {
                            memory_id: sr.memory_id,
                            content: text,
                            priority: PRIORITY_LEARNING,
                            estimated_tokens,
                            relevance_score: sr.score,
                            memory_type,
                            category,
                        }))
                    }
                    // The row vanished between search and lookup: skip it.
                    Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
                    Err(e) => Err(e.into()),
                }
            })?;
            if let Some(item) = item {
                candidates.push(item);
            }
        }

        Ok(ContextAssembly::assemble(candidates, budget))
    }

    /// Returns the primary database handle.
    pub fn database(&self) -> &Database {
        &self.db
    }

    /// Returns the global database handle, if one was configured.
    pub fn global_database(&self) -> Option<&Database> {
        self.global_db.as_ref()
    }
}
/// Debug output lists only the `db` field; the remaining fields are left out
/// of the representation.
impl<T: MemoryRecord> std::fmt::Debug for MemoryEngine<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("MemoryEngine");
        repr.field("db", &self.db);
        repr.finish()
    }
}
/// Builder for [`MemoryEngine`]; collects optional settings that `build()`
/// turns into an engine.
pub struct MemoryEngineBuilder<T: MemoryRecord> {
/// Path of the primary database file. When `None`, `build()` opens an
/// in-memory database instead.
database_path: Option<String>,
/// Path of the global database file. When `None`, the engine has no
/// global database.
global_database_path: Option<String>,
/// Scoring strategy. When `None`, `build()` uses `CompositeScorer::empty()`.
scoring: Option<Arc<dyn ScoringStrategy>>,
/// Optional embedding backend passed through to the engine unchanged.
embedding: Option<Arc<dyn EmbeddingBackend>>,
/// Ties the builder to the record type `T` without storing a value of it.
_phantom: PhantomData<T>,
}
impl<T: MemoryRecord> MemoryEngineBuilder<T> {
fn new() -> Self {
Self {
database_path: None,
global_database_path: None,
scoring: None,
embedding: None,
_phantom: PhantomData,
}
}
pub fn database(mut self, path: impl Into<String>) -> Self {
self.database_path = Some(path.into());
self
}
pub fn global_database(mut self, path: impl Into<String>) -> Self {
self.global_database_path = Some(path.into());
self
}
pub fn scoring(mut self, strategy: impl ScoringStrategy + 'static) -> Self {
self.scoring = Some(Arc::new(strategy));
self
}
pub fn embedding_backend(mut self, backend: impl EmbeddingBackend + 'static) -> Self {
self.embedding = Some(Arc::new(backend));
self
}
pub fn build(self) -> Result<MemoryEngine<T>> {
let db = match &self.database_path {
Some(path) => {
if let Some(parent) = Path::new(path).parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent).map_err(|e| {
MindCoreError::Migration(format!(
"failed to create database directory {}: {e}",
parent.display()
))
})?;
}
}
Database::open(path)?
}
None => Database::open_in_memory()?,
};
db.with_writer(|conn| {
migrations::migrate(conn)?;
Ok(())
})?;
let global_db = match &self.global_database_path {
Some(path) => {
if let Some(parent) = Path::new(path).parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent).map_err(|e| {
MindCoreError::Migration(format!(
"failed to create global database directory {}: {e}",
parent.display()
))
})?;
}
}
let gdb = Database::open(path)?;
gdb.with_writer(|conn| {
migrations::migrate(conn)?;
Ok(())
})?;
Some(gdb)
}
None => None,
};
let scoring = self
.scoring
.unwrap_or_else(|| Arc::new(CompositeScorer::empty()));
Ok(MemoryEngine {
db,
global_db,
store: MemoryStore::new(),
scoring,
embedding: self.embedding,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::MemoryType;
    use chrono::Utc;

    /// Minimal record type used to exercise the engine.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestMem {
        id: Option<i64>,
        text: String,
        created_at: chrono::DateTime<Utc>,
    }

    impl MemoryRecord for TestMem {
        fn id(&self) -> Option<i64> {
            self.id
        }

        fn searchable_text(&self) -> String {
            self.text.clone()
        }

        fn memory_type(&self) -> MemoryType {
            MemoryType::Semantic
        }

        fn created_at(&self) -> chrono::DateTime<Utc> {
            self.created_at
        }
    }

    /// Builds an unsaved record carrying `text`, timestamped now.
    fn mem(text: &str) -> TestMem {
        TestMem {
            id: None,
            text: String::from(text),
            created_at: Utc::now(),
        }
    }

    /// Builds an engine backed by an in-memory database.
    fn in_memory_engine() -> MemoryEngine<TestMem> {
        MemoryEngine::<TestMem>::builder().build().expect("build")
    }

    /// Stores a fresh record with `text` and returns its id, panicking unless
    /// the store reports it as newly added.
    fn store_new(engine: &MemoryEngine<TestMem>, text: &str) -> i64 {
        match engine.store(&mem(text)).expect("store") {
            StoreResult::Added(id) => id,
            _ => panic!("expected Added"),
        }
    }

    #[test]
    fn builder_in_memory() {
        assert!(MemoryEngine::<TestMem>::builder().build().is_ok());
    }

    #[test]
    fn builder_with_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_file = dir.path().join("test.db");
        let built = MemoryEngine::<TestMem>::builder()
            .database(db_file.to_string_lossy().into_owned())
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn builder_creates_parent_dirs() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_file = dir.path().join("deep/nested/dir/test.db");
        let built = MemoryEngine::<TestMem>::builder()
            .database(db_file.to_string_lossy().into_owned())
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn store_and_get_via_engine() {
        let engine = in_memory_engine();
        let id = store_new(&engine, "hello from engine");

        let fetched = engine.get(id).expect("get");
        assert!(fetched.is_some());
        let fetched_text = fetched.as_ref().map(|r| r.text.as_str());
        assert_eq!(fetched_text, Some("hello from engine"));
    }

    #[test]
    fn update_via_engine() {
        let engine = in_memory_engine();
        let id = store_new(&engine, "original");

        let replacement = TestMem {
            id: Some(id),
            ..mem("updated")
        };
        engine.update(id, &replacement).expect("update");

        let reloaded = engine.get(id).expect("get").expect("not found");
        assert_eq!(reloaded.text, "updated");
    }

    #[test]
    fn delete_via_engine() {
        let engine = in_memory_engine();
        let id = store_new(&engine, "to delete");

        let removed = engine.delete(id).expect("delete");
        assert!(removed);
        let after = engine.get(id).expect("get");
        assert!(after.is_none());
    }

    #[test]
    fn search_via_engine() {
        let engine = in_memory_engine();
        for text in ["authentication error JWT", "database connection timeout"] {
            engine.store(&mem(text)).expect("store");
        }

        let hits = engine.search("authentication").execute().expect("search");
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn count_via_engine() {
        let engine = in_memory_engine();
        assert_eq!(engine.count().expect("count"), 0);

        for text in ["one", "two"] {
            engine.store(&mem(text)).expect("store");
        }
        assert_eq!(engine.count().expect("count"), 2);
    }

    #[test]
    fn dedup_via_engine() {
        let engine = in_memory_engine();
        let first = engine.store(&mem("same text")).expect("store 1");
        let second = engine.store(&mem("same text")).expect("store 2");

        assert!(matches!(first, StoreResult::Added(_)));
        assert!(matches!(second, StoreResult::Duplicate(_)));
        assert_eq!(engine.count().expect("count"), 1);
    }
}