use async_trait::async_trait;
use std::sync::Mutex;
use super::{MemoryFilter, MemoryFragment, MemoryId, MemorySource, MemorySubstrate};
use crate::agent::result::AgentError;
/// Agent memory substrate backed by a `trueno_rag` SQLite index.
///
/// Memories are stored as single-chunk documents that are searched through the
/// index's full-text search; key/value state is kept in the index metadata.
pub struct TruenoMemory {
    /// Underlying SQLite index holding memory documents and metadata entries.
    index: trueno_rag::sqlite::SqliteIndex,
    /// Numeric suffix for the next generated `trueno-{n}` memory id. Wrapped in a
    /// mutex because ids are generated through `&self`.
    next_id: Mutex<u64>,
}
impl TruenoMemory {
    /// Metadata key under which the next memory-id counter is persisted.
    const NEXT_ID_KEY: &'static str = "memory_next_id";

    /// Opens the SQLite-backed memory store at `path`.
    ///
    /// The id counter is restored from the `memory_next_id` metadata entry
    /// (starting at 1 when the entry is absent), so generated ids keep increasing
    /// across reopens.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if the index cannot be opened, or if the
    /// stored counter cannot be read or is not a valid `u64`. Falling back to 1 in
    /// those cases would re-issue ids that may already be in use.
    pub fn open(path: impl AsRef<std::path::Path>) -> Result<Self, AgentError> {
        let index = trueno_rag::sqlite::SqliteIndex::open(path)
            .map_err(|e| AgentError::Memory(format!("open failed: {e}")))?;
        let stored = index
            .get_metadata(Self::NEXT_ID_KEY)
            .map_err(|e| AgentError::Memory(format!("read next id: {e}")))?;
        let next_id = match stored {
            Some(raw) => raw
                .parse::<u64>()
                .map_err(|e| AgentError::Memory(format!("corrupt next id {raw:?}: {e}")))?,
            None => 1,
        };
        Ok(Self { index, next_id: Mutex::new(next_id) })
    }

    /// Opens a fresh, non-persistent in-memory store; ids start at 1.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if the in-memory index cannot be created.
    pub fn open_in_memory() -> Result<Self, AgentError> {
        let index = trueno_rag::sqlite::SqliteIndex::open_in_memory()
            .map_err(|e| AgentError::Memory(format!("in-memory open failed: {e}")))?;
        Ok(Self { index, next_id: Mutex::new(1) })
    }

    /// Allocates the next memory id (`trueno-{n}`) and persists the advanced counter.
    ///
    /// The lock is held across the metadata write so that the persisted counter
    /// advances in the same order as the ids handed out.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if the mutex is poisoned or the counter
    /// cannot be persisted. In the latter case no id is consumed.
    fn gen_id(&self) -> Result<String, AgentError> {
        let mut next = self.next_id.lock().map_err(|e| AgentError::Memory(format!("lock: {e}")))?;
        let current = *next;
        let following = current + 1;
        // Persist first: if the write fails the in-memory counter is untouched, so a
        // later reopen cannot resume below an id that was already handed out.
        self.index
            .set_metadata(Self::NEXT_ID_KEY, &following.to_string())
            .map_err(|e| AgentError::Memory(format!("persist next id: {e}")))?;
        *next = following;
        Ok(format!("trueno-{current}"))
    }

    /// Index document id for a memory: `"{agent_id}:{memory_id}"`.
    fn doc_id(agent_id: &str, memory_id: &str) -> String {
        format!("{agent_id}:{memory_id}")
    }

    /// Metadata key for an agent-scoped key/value entry: `"kv:{agent_id}:{key}"`.
    fn kv_key(agent_id: &str, key: &str) -> String {
        format!("kv:{agent_id}:{key}")
    }

    /// Number of chunks in the underlying index (`remember` stores one chunk per memory).
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if the index query fails.
    pub fn fragment_count(&self) -> Result<usize, AgentError> {
        self.index.chunk_count().map_err(|e| AgentError::Memory(format!("chunk count: {e}")))
    }
}
#[async_trait]
impl MemorySubstrate for TruenoMemory {
    /// Stores `content` as a new memory owned by `agent_id` and returns its id.
    ///
    /// The memory is indexed as a single-chunk document `"{agent_id}:{memory_id}"`.
    /// A metadata record (source, creation time, owning agent) is written alongside
    /// it. `recall` uses the record to report and filter by source, and `forget`
    /// uses it to locate the document from the bare memory id. The embedding is
    /// ignored: this substrate only uses full-text search.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if id generation, the record write, or the
    /// document insert fails.
    async fn remember(
        &self,
        agent_id: &str,
        content: &str,
        source: MemorySource,
        _embedding: Option<&[f32]>,
    ) -> Result<MemoryId, AgentError> {
        let memory_id = self.gen_id()?;
        let doc_id = Self::doc_id(agent_id, &memory_id);
        let label = source_label(&source);

        // Write the record before indexing: if this fails, nothing is searchable
        // yet. If the insert below fails, the orphaned record is harmless.
        let record = MemoryRecord::encode(label, chrono::Utc::now(), agent_id);
        self.index
            .set_metadata(&record_key(&memory_id), &record)
            .map_err(|e| AgentError::Memory(format!("record write failed: {e}")))?;

        let chunk_id = format!("{doc_id}:0");
        let chunks = vec![(chunk_id, content.to_string())];
        self.index
            .insert_document(&doc_id, Some(label), Some(agent_id), content, &chunks, None)
            .map_err(|e| AgentError::Memory(format!("insert failed: {e}")))?;
        Ok(memory_id)
    }

    /// Full-text search over stored memories, returning at most `limit` fragments.
    ///
    /// - Relevance is each hit's score divided by the highest score among the raw
    ///   search hits (0 when that maximum is not positive).
    /// - `filter.agent_id` keeps only memories owned by exactly that agent.
    /// - `filter.source` keeps only memories with that source.
    /// - Memories without a metadata record are treated as `Conversation`, with
    ///   `created_at` set to the current time.
    /// - A blank query or a zero `limit` yields no results.
    /// - The query embedding is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if the search or a record lookup fails.
    async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<MemoryFilter>,
        _query_embedding: Option<&[f32]>,
    ) -> Result<Vec<MemoryFragment>, AgentError> {
        if limit == 0 || query.trim().is_empty() {
            return Ok(Vec::new());
        }
        // Filters run after the FTS query, so over-fetch to leave room for rejected
        // hits. `saturating_mul` keeps a huge `limit` from overflowing.
        let search_limit = if filter.is_some() { limit.saturating_mul(4) } else { limit };
        let results = self
            .index
            .search_fts(query, search_limit)
            .map_err(|e| AgentError::Memory(format!("search failed: {e}")))?;
        let max_score = results.iter().map(|r| r.score).fold(0.0_f64, f64::max);

        let wanted_agent: Option<String> = filter
            .as_ref()
            .and_then(|f| f.agent_id.as_ref())
            .map(|aid| aid.to_string());
        let wanted_source: Option<&'static str> =
            filter.as_ref().and_then(|f| f.source.as_ref()).map(source_label);

        let mut fragments = Vec::new();
        for hit in results {
            if fragments.len() >= limit {
                break;
            }
            // Doc ids are "{agent_id}:{memory_id}". Generated memory ids never contain
            // ':' but agent ids may, so split on the LAST colon.
            let (owner, memory_id) = match hit.doc_id.rsplit_once(':') {
                Some((owner, mid)) => (Some(owner), mid.to_string()),
                None => (None, hit.doc_id.clone()),
            };
            if let Some(agent) = wanted_agent.as_deref() {
                if owner != Some(agent) {
                    continue;
                }
            }

            let raw_record = self
                .index
                .get_metadata(&record_key(&memory_id))
                .map_err(|e| AgentError::Memory(format!("get_metadata: {e}")))?;
            // Memories stored without a record fall back to the values this
            // substrate has always reported.
            let (source, created_at) = match raw_record.as_deref().and_then(MemoryRecord::parse) {
                Some(record) => (record.source, record.created_at),
                None => (MemorySource::Conversation, chrono::Utc::now()),
            };
            if let Some(label) = wanted_source {
                if source_label(&source) != label {
                    continue;
                }
            }

            // Relevance is a normalized ratio, so f32 precision is sufficient.
            #[allow(clippy::cast_possible_truncation)]
            let relevance = if max_score > 0.0 { (hit.score / max_score) as f32 } else { 0.0 };
            fragments.push(MemoryFragment {
                id: memory_id,
                content: hit.content,
                source,
                relevance_score: relevance,
                created_at,
            });
        }
        Ok(fragments)
    }

    /// Stores `value` as JSON under the agent-scoped metadata key `kv:{agent_id}:{key}`.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if serialization or the metadata write fails.
    async fn set(
        &self,
        agent_id: &str,
        key: &str,
        value: serde_json::Value,
    ) -> Result<(), AgentError> {
        let kv_key = Self::kv_key(agent_id, key);
        let serialized = serde_json::to_string(&value)
            .map_err(|e| AgentError::Memory(format!("serialize: {e}")))?;
        self.index
            .set_metadata(&kv_key, &serialized)
            .map_err(|e| AgentError::Memory(format!("set_metadata: {e}")))?;
        Ok(())
    }

    /// Loads the JSON value previously stored by `set`, or `None` if absent.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if the metadata read or JSON parsing fails.
    async fn get(
        &self,
        agent_id: &str,
        key: &str,
    ) -> Result<Option<serde_json::Value>, AgentError> {
        let kv_key = Self::kv_key(agent_id, key);
        let stored = self
            .index
            .get_metadata(&kv_key)
            .map_err(|e| AgentError::Memory(format!("get_metadata: {e}")))?;
        match stored {
            Some(s) => {
                let value = serde_json::from_str(&s)
                    .map_err(|e| AgentError::Memory(format!("deserialize: {e}")))?;
                Ok(Some(value))
            }
            None => Ok(None),
        }
    }

    /// Removes the memory identified by `id`, as returned by `remember` or `recall`.
    ///
    /// Memories are indexed under `"{agent_id}:{id}"`, so the owning agent is
    /// resolved from the metadata record written by `remember`. When no record
    /// exists, `id` itself is used as the document id.
    ///
    /// Removal is best-effort: any error from `remove_document` is ignored, so
    /// repeated calls stay harmless.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::Memory`] if the record lookup fails.
    async fn forget(&self, id: MemoryId) -> Result<(), AgentError> {
        let stored = self
            .index
            .get_metadata(&record_key(&id))
            .map_err(|e| AgentError::Memory(format!("get_metadata: {e}")))?;
        let doc_id = match stored.as_deref().and_then(MemoryRecord::parse) {
            Some(record) => Self::doc_id(record.agent_id, &id),
            None => id,
        };
        let _ = self.index.remove_document(&doc_id);
        Ok(())
    }
}

/// Metadata key of the bookkeeping record written by `remember` for `memory_id`.
///
/// The `memory_record:` prefix is distinct from the `kv:` namespace used by
/// `set`/`get` and from the `memory_next_id` counter key.
fn record_key(memory_id: &str) -> String {
    format!("memory_record:{memory_id}")
}

/// Stable label for a [`MemorySource`]. It is used as the document source column
/// and inside memory records.
fn source_label(source: &MemorySource) -> &'static str {
    match source {
        MemorySource::Conversation => "conversation",
        MemorySource::ToolResult => "tool_result",
        MemorySource::System => "system",
        MemorySource::User => "user",
    }
}

/// Inverse of [`source_label`]; returns `None` for unknown labels.
fn parse_source_label(label: &str) -> Option<MemorySource> {
    match label {
        "conversation" => Some(MemorySource::Conversation),
        "tool_result" => Some(MemorySource::ToolResult),
        "system" => Some(MemorySource::System),
        "user" => Some(MemorySource::User),
        _ => None,
    }
}

/// Per-memory bookkeeping stored in index metadata.
///
/// The encoded form is `"{source_label}\t{created_at_rfc3339}\t{agent_id}"`.
/// The agent id comes last so it may itself contain tabs.
struct MemoryRecord<'a> {
    source: MemorySource,
    created_at: chrono::DateTime<chrono::Utc>,
    agent_id: &'a str,
}

impl<'a> MemoryRecord<'a> {
    /// Serializes a record in the format understood by [`MemoryRecord::parse`].
    fn encode(label: &str, created_at: chrono::DateTime<chrono::Utc>, agent_id: &str) -> String {
        let created_at = created_at.to_rfc3339();
        format!("{label}\t{created_at}\t{agent_id}")
    }

    /// Parses an encoded record; returns `None` if any field is missing or malformed.
    fn parse(raw: &'a str) -> Option<Self> {
        let mut fields = raw.splitn(3, '\t');
        let source = parse_source_label(fields.next()?)?;
        let created_at = chrono::DateTime::parse_from_rfc3339(fields.next()?)
            .ok()?
            .with_timezone(&chrono::Utc);
        let agent_id = fields.next()?;
        Some(Self { source, created_at, agent_id })
    }
}
// Unit tests for this module are kept in the sibling file `trueno_tests.rs`
// and compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "trueno_tests.rs"]
mod tests;