use async_trait::async_trait;
use mongodb::bson::{doc, Document};
use mongodb::{Client, Collection};
use crate::error::{AgentError, Result};
use crate::memory::{MemoryRecord, MemoryStore};
#[cfg(feature = "fastembed")]
use fastembed::{InitOptions, TextEmbedding};
#[cfg(feature = "fastembed")]
use tokio::sync::OnceCell;
/// MongoDB-backed implementation of [`MemoryStore`].
///
/// Each memory record is kept as one BSON document in a single collection.
pub struct MongoStore {
/// Handle to the collection holding the memory documents.
collection: Collection<Document>,
/// Local text-embedding model, created on the first call to `embed` and
/// reused afterwards. Only present when the `fastembed` feature is enabled.
#[cfg(feature = "fastembed")]
embedder: OnceCell<TextEmbedding>,
}
impl MongoStore {
pub async fn new(connection_string: &str, database: &str, collection: &str) -> Result<Self> {
let client = Client::with_uri_str(connection_string)
.await
.map_err(|e| AgentError::MemoryError(format!("Failed to connect to MongoDB: {}", e)))?;
let db = client.database(database);
let collection = db.collection::<Document>(collection);
collection
.create_index(
mongodb::IndexModel::builder()
.keys(doc! { "session_id": 1, "timestamp": -1 })
.build(),
)
.await
.map_err(|e| AgentError::MemoryError(format!("Failed to create index: {}", e)))?;
Ok(Self {
collection,
#[cfg(feature = "fastembed")]
embedder: OnceCell::new(),
})
}
pub async fn create_vector_index(&self, index_name: &str) -> Result<()> {
tracing::info!(
"Vector search index {} would be created through Atlas",
index_name
);
Ok(())
}
}
#[async_trait]
impl MemoryStore for MongoStore {
    /// Inserts `record`, or replaces the stored document with the same id.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::SerializationError`] or
    /// [`AgentError::MemoryError`] when the metadata cannot be converted to
    /// BSON, and [`AgentError::MemoryError`] when the write fails.
    async fn store(&self, record: MemoryRecord) -> Result<()> {
        // The stringified id serves both as the document `_id` and as the
        // upsert filter, so format it once.
        let id = record.id.to_string();

        let mut doc = doc! {
            "_id": id.as_str(),
            "session_id": &record.session_id,
            "role": &record.role,
            "content": &record.content,
            "importance": record.importance,
            "timestamp": mongodb::bson::DateTime::from_system_time(record.timestamp.into()),
        };

        if let Some(metadata) = &record.metadata {
            // Metadata goes through a JSON value before being turned into BSON.
            let json = serde_json::to_value(metadata).map_err(AgentError::SerializationError)?;
            let bson = mongodb::bson::to_bson(&json).map_err(|e| {
                AgentError::MemoryError(format!("Failed to convert metadata: {}", e))
            })?;
            doc.insert("metadata", bson);
        }

        if let Some(embedding) = &record.embedding {
            doc.insert("embedding", embedding);
        }

        self.collection
            .replace_one(doc! { "_id": id }, &doc)
            .upsert(true)
            .await
            .map_err(|e| AgentError::MemoryError(format!("Failed to store memory: {}", e)))?;
        Ok(())
    }

    /// Returns up to `limit` records of `session_id`, newest first.
    ///
    /// A `limit` of zero yields an empty vector without querying the database.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the query fails, the cursor
    /// cannot be advanced, or a document cannot be converted into a record.
    async fn retrieve(&self, session_id: &str, limit: usize) -> Result<Vec<MemoryRecord>> {
        // MongoDB treats a limit of 0 as "no limit", which would return the
        // whole session instead of nothing.
        if limit == 0 {
            return Ok(Vec::new());
        }

        let filter = doc! { "session_id": session_id };
        let options = mongodb::options::FindOptions::builder()
            .sort(doc! { "timestamp": -1 })
            // Clamp instead of casting: a wrapping `as i64` could turn a huge
            // `usize` into a negative limit.
            .limit(i64::try_from(limit).unwrap_or(i64::MAX))
            .build();

        let mut cursor = self
            .collection
            .find(filter)
            .with_options(options)
            .await
            .map_err(|e| AgentError::MemoryError(format!("Failed to retrieve memories: {}", e)))?;

        let mut records = Vec::new();
        while cursor
            .advance()
            .await
            .map_err(|e| AgentError::MemoryError(format!("Failed to advance cursor: {}", e)))?
        {
            let doc = cursor
                .deserialize_current()
                .map_err(|e| AgentError::MemoryError(format!("Failed to deserialize: {}", e)))?;
            records.push(document_to_memory_record(&doc)?);
        }
        Ok(records)
    }

    /// Ranks the session's most recent records by cosine similarity to
    /// `query_embedding` and returns the best `limit` of them.
    ///
    /// The ranking is computed in-process over at most the 1000 newest
    /// records of the session; records without an embedding are skipped.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`MemoryStore::retrieve`].
    async fn search(
        &self,
        session_id: &str,
        query_embedding: Vec<f32>,
        limit: usize,
    ) -> Result<Vec<MemoryRecord>> {
        /// Number of most-recent records fetched as ranking candidates.
        const CANDIDATE_POOL: usize = 1000;

        let candidates = self.retrieve(session_id, CANDIDATE_POOL).await?;

        let mut scored: Vec<(f32, MemoryRecord)> = candidates
            .into_iter()
            .filter_map(|record| {
                let score = record
                    .embedding
                    .as_ref()
                    .map(|embedding| super::cosine_similarity(&query_embedding, embedding))?;
                // A NaN score carries no ranking information; push it to the
                // bottom instead of letting it break the sort.
                let score = if score.is_nan() {
                    f32::NEG_INFINITY
                } else {
                    score
                };
                Some((score, record))
            })
            .collect();

        // `total_cmp` is a total order, so sorting can never panic the way
        // `partial_cmp(..).unwrap()` does on NaN.
        scored.sort_by(|a, b| b.0.total_cmp(&a.0));

        Ok(scored
            .into_iter()
            .take(limit)
            .map(|(_, record)| record)
            .collect())
    }

    /// Computes an embedding vector for the given text.
    ///
    /// With the `fastembed` feature enabled the model is created on first use
    /// and cached for later calls. Without the feature an empty vector is
    /// returned.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::MemoryError`] when the model cannot be
    /// initialised, when embedding fails, or when the model returns no vector.
    // NOTE(review): model initialisation and `embed` appear to be synchronous
    // calls executed directly on the async task; confirm whether they should
    // be moved to a blocking thread.
    async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
        #[cfg(feature = "fastembed")]
        {
            let embedder = self
                .embedder
                .get_or_try_init(|| async {
                    TextEmbedding::try_new(InitOptions::default())
                        .map_err(|e| AgentError::MemoryError(e.to_string()))
                })
                .await?;
            let embeddings = embedder
                .embed(vec![_text], None)
                .map_err(|e| AgentError::MemoryError(e.to_string()))?;
            // Move the single result out rather than indexing and cloning;
            // an empty result becomes an error instead of a panic.
            embeddings.into_iter().next().ok_or_else(|| {
                AgentError::MemoryError("Embedding model returned no vectors".to_string())
            })
        }
        #[cfg(not(feature = "fastembed"))]
        Ok(Vec::new())
    }

    /// No-op: `store` writes each record straight to the collection, so this
    /// store keeps nothing buffered.
    async fn flush(&self) -> Result<()> {
        Ok(())
    }
}
/// Rebuilds a [`MemoryRecord`] from a stored BSON document.
///
/// `_id`, `session_id`, `role` and `content` are mandatory. The remaining
/// fields are read leniently: `importance` falls back to `0.5`, `timestamp`
/// falls back to the current time, and `metadata` / `embedding` become `None`
/// when absent or not convertible.
///
/// # Errors
///
/// Returns [`AgentError::MemoryError`] when `_id` is missing or is not a
/// UUID string, or when one of the mandatory string fields cannot be read.
fn document_to_memory_record(doc: &Document) -> Result<MemoryRecord> {
    // Reads a mandatory string field as an owned `String`.
    let required_str = |key: &str| -> Result<String> {
        doc.get_str(key)
            .map(str::to_owned)
            .map_err(|_| AgentError::MemoryError(format!("Missing {}", key)))
    };

    let id = match doc.get_str("_id").map(uuid::Uuid::parse_str) {
        Ok(Ok(parsed)) => parsed,
        _ => {
            return Err(AgentError::MemoryError(
                "Missing or invalid _id".to_string(),
            ))
        }
    };

    let session_id = required_str("session_id")?;
    let role = required_str("role")?;
    let content = required_str("content")?;

    let importance = doc.get_f64("importance").map_or(0.5_f32, |value| value as f32);

    let timestamp = doc.get_datetime("timestamp").map_or_else(
        |_| chrono::Utc::now(),
        |stored| chrono::DateTime::from(stored.to_system_time()),
    );

    let metadata = match doc.get_document("metadata") {
        Ok(raw) => mongodb::bson::from_bson(mongodb::bson::Bson::Document(raw.clone())).ok(),
        Err(_) => None,
    };

    // All-or-nothing: one element that is not a double discards the whole vector.
    let embedding = doc.get_array("embedding").ok().and_then(|values| {
        let mut vector = Vec::with_capacity(values.len());
        for value in values {
            vector.push(value.as_f64()? as f32);
        }
        Some(vector)
    });

    Ok(MemoryRecord {
        id,
        session_id,
        role,
        content,
        importance,
        timestamp,
        metadata,
        embedding,
    })
}