use crate::embedding::EmbeddingProvider;
use crate::service::*;
use adk_core::Result;
use async_trait::async_trait;
use chrono::Utc;
use mongodb::bson::{DateTime as BsonDateTime, Document, doc};
use mongodb::options::IndexOptions;
use mongodb::{Client, Database, IndexModel};
use std::sync::Arc;
use tracing::instrument;
/// MongoDB-backed implementation of [`MemoryService`].
///
/// Entries are stored in a `memory_entries` collection. When an
/// [`EmbeddingProvider`] is configured, entry text is embedded and stored
/// alongside each document so `search` can use Atlas Vector Search;
/// without a provider, `search` falls back to MongoDB `$text` search.
pub struct MongoMemoryService {
    // Handle to the target database; collections are created lazily by the driver.
    db: Database,
    // Optional embedding provider; `None` disables vector search in favor of text search.
    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
}
impl MongoMemoryService {
    /// Creates a service bound to `database_name` on the given client.
    ///
    /// Performs no I/O: `Client::database` only builds a handle. The
    /// `Result` return is currently infallible but keeps the constructor
    /// signature stable should validation be added later.
    pub fn new(
        client: Client,
        database_name: &str,
        embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
    ) -> Result<Self> {
        let db = client.database(database_name);
        Ok(Self { db, embedding_provider })
    }

    // Collection recording which schema migrations have been applied.
    const REGISTRY_COLLECTION: &'static str = "_adk_memory_migrations";
    // Ordered (version, description) list of migrations known to this build.
    const MONGO_MEMORY_MIGRATIONS: &'static [(i64, &'static str)] =
        &[(1, "create initial indexes")];

    /// Applies all pending schema migrations, recording each in the registry.
    ///
    /// Steps:
    /// 1. Ensure a unique index on the registry's `version` field (idempotent).
    /// 2. If nothing is recorded but `memory_entries` already exists, baseline
    ///    the database at the first known migration without re-running it.
    /// 3. Fail if the database is ahead of the versions compiled into this binary.
    /// 4. Run and record every migration newer than the current maximum.
    ///
    /// # Errors
    /// Returns `AdkError::memory` on any driver failure or on a schema
    /// version mismatch.
    pub async fn migrate(&self) -> Result<()> {
        self.db
            .collection::<Document>(Self::REGISTRY_COLLECTION)
            .create_index(
                IndexModel::builder()
                    .keys(doc! { "version": 1 })
                    .options(
                        IndexOptions::builder()
                            .unique(true)
                            .name("idx_migration_version_unique".to_string())
                            .build(),
                    )
                    .build(),
            )
            .await
            .map_err(|e| {
                adk_core::AdkError::memory(format!("migration registry creation failed: {e}"))
            })?;
        let mut max_applied = self.read_max_applied_version().await?;
        if max_applied == 0 {
            // Baseline path: a pre-migration deployment already has the data
            // collection, so record the first migration as applied instead of
            // re-running it against existing data.
            let existing = self.detect_existing_tables().await?;
            if existing {
                if let Some(&(version, description)) = Self::MONGO_MEMORY_MIGRATIONS.first() {
                    self.record_migration(version, description).await?;
                    max_applied = version;
                }
            }
        }
        let max_compiled = Self::MONGO_MEMORY_MIGRATIONS.last().map(|s| s.0).unwrap_or(0);
        if max_applied > max_compiled {
            // Database was migrated by a newer build; refuse to proceed.
            return Err(adk_core::AdkError::memory(format!(
                "schema version mismatch: database is at v{max_applied} \
                but code only knows up to v{max_compiled}. \
                Upgrade your ADK version."
            )));
        }
        for &(version, description) in Self::MONGO_MEMORY_MIGRATIONS {
            if version <= max_applied {
                continue;
            }
            run_mongo_memory_step(&self.db, version).await.map_err(|e| {
                adk_core::AdkError::memory(format!(
                    "{}",
                    crate::migration::MigrationError {
                        version,
                        description: description.to_string(),
                        cause: e.to_string(),
                    }
                ))
            })?;
            self.record_migration(version, description).await?;
        }
        Ok(())
    }

    /// Returns the highest applied migration version, or 0 when the
    /// registry collection does not exist yet.
    pub async fn schema_version(&self) -> Result<i64> {
        let collections =
            self.db.list_collection_names().await.map_err(|e| {
                adk_core::AdkError::memory(format!("schema version query failed: {e}"))
            })?;
        if !collections.contains(&Self::REGISTRY_COLLECTION.to_string()) {
            return Ok(0);
        }
        self.read_max_applied_version().await
    }

    /// Reads the highest `version` recorded in the registry (0 when empty).
    async fn read_max_applied_version(&self) -> Result<i64> {
        use mongodb::options::FindOneOptions;
        let registry = self.db.collection::<Document>(Self::REGISTRY_COLLECTION);
        // Sort descending so the single returned document carries the max version.
        let opts = FindOneOptions::builder().sort(doc! { "version": -1 }).build();
        let result = registry.find_one(doc! {}).with_options(opts).await.map_err(|e| {
            adk_core::AdkError::memory(format!("migration registry read failed: {e}"))
        })?;
        match result {
            Some(doc) => {
                // Records are written with an i64 `version` (see record_migration);
                // a type mismatch falls back to 0 rather than erroring.
                let version = doc.get_i64("version").unwrap_or(0);
                Ok(version)
            }
            None => Ok(0),
        }
    }

    /// Detects a pre-migration deployment by checking whether the
    /// `memory_entries` collection already exists.
    async fn detect_existing_tables(&self) -> Result<bool> {
        let collections =
            self.db.list_collection_names().await.map_err(|e| {
                adk_core::AdkError::memory(format!("baseline detection failed: {e}"))
            })?;
        Ok(collections.contains(&"memory_entries".to_string()))
    }

    /// Inserts a `{version, description, applied_at}` record into the
    /// registry. The unique index on `version` rejects duplicates.
    async fn record_migration(&self, version: i64, description: &str) -> Result<()> {
        let registry = self.db.collection::<Document>(Self::REGISTRY_COLLECTION);
        let now = BsonDateTime::from_millis(Utc::now().timestamp_millis());
        registry
            .insert_one(doc! {
                "version": version,
                "description": description,
                "applied_at": now,
            })
            .await
            .map_err(|e| {
                adk_core::AdkError::memory(format!(
                    "{}",
                    crate::migration::MigrationError {
                        version,
                        description: description.to_string(),
                        cause: format!("registry record failed: {e}"),
                    }
                ))
            })?;
        Ok(())
    }
}
/// Dispatches a single migration step by version number.
///
/// Returns an `AdkError::memory` for any version this build does not know,
/// which `migrate` wraps into a `MigrationError` with context.
async fn run_mongo_memory_step(db: &Database, version: i64) -> Result<()> {
    if version == 1 {
        return mongo_memory_v1(db).await;
    }
    Err(adk_core::AdkError::memory(format!("unknown migration version: {version}")))
}
/// Migration v1: creates the initial indexes on `memory_entries` —
/// a compound index on `(app_name, user_id)` for scoped queries and
/// deletes, and a `$text` index on `content_text` for keyword search.
async fn mongo_memory_v1(db: &Database) -> Result<()> {
    let collection = db.collection::<Document>("memory_entries");

    // Compound index used by search filters and delete_user/delete_session.
    let app_user_index = IndexModel::builder()
        .keys(doc! { "app_name": 1, "user_id": 1 })
        .options(
            IndexOptions::builder().name("idx_memory_entries_app_user".to_string()).build(),
        )
        .build();
    collection
        .create_index(app_user_index)
        .await
        .map_err(|e| adk_core::AdkError::memory(format!("index creation failed: {e}")))?;

    // Text index backing the non-vector search fallback.
    let text_index = IndexModel::builder()
        .keys(doc! { "content_text": "text" })
        .options(
            IndexOptions::builder().name("idx_memory_entries_text".to_string()).build(),
        )
        .build();
    collection
        .create_index(text_index)
        .await
        .map_err(|e| adk_core::AdkError::memory(format!("text index creation failed: {e}")))?;

    Ok(())
}
#[async_trait]
impl MemoryService for MongoMemoryService {
    /// Persists a session's entries into `memory_entries`.
    ///
    /// When an embedding provider is configured, each entry's extracted text
    /// is embedded and stored in an `embedding` array field for vector
    /// search. Empty texts are embedded as a single space because the
    /// provider may reject empty input. No-op when `entries` is empty.
    ///
    /// # Errors
    /// Returns `AdkError::memory` on embedding, serialization, or insert
    /// failure, or when the provider returns a mismatched vector count.
    #[instrument(skip_all, fields(app_name = %app_name, user_id = %user_id, session_id = %session_id, entry_count = entries.len()))]
    async fn add_session(
        &self,
        app_name: &str,
        user_id: &str,
        session_id: &str,
        entries: Vec<MemoryEntry>,
    ) -> Result<()> {
        if entries.is_empty() {
            return Ok(());
        }
        let collection = self.db.collection::<Document>("memory_entries");
        let texts: Vec<String> =
            entries.iter().map(|e| crate::text::extract_text(&e.content)).collect();
        let embeddings = if let Some(provider) = &self.embedding_provider {
            // Substitute a single space for empty texts so the provider
            // never sees an empty string.
            let non_empty_texts: Vec<String> = texts
                .iter()
                .map(|t| if t.is_empty() { " ".to_string() } else { t.clone() })
                .collect();
            let embs = provider.embed(&non_empty_texts).await.map_err(|e| {
                adk_core::AdkError::memory(format!("embedding generation failed: {e}"))
            })?;
            // Guard: indexing `embs[i]` below would panic if a misbehaving
            // provider returned fewer/more vectors than inputs.
            if embs.len() != entries.len() {
                return Err(adk_core::AdkError::memory(format!(
                    "embedding provider returned {} vectors for {} entries",
                    embs.len(),
                    entries.len()
                )));
            }
            Some(embs)
        } else {
            None
        };
        let mut documents = Vec::with_capacity(entries.len());
        for (i, entry) in entries.iter().enumerate() {
            // Round-trip content through JSON to get a stable BSON shape.
            let content_json = serde_json::to_value(&entry.content)
                .map_err(|e| adk_core::AdkError::memory(format!("serialization failed: {e}")))?;
            let content_bson = mongodb::bson::to_bson(&content_json)
                .map_err(|e| adk_core::AdkError::memory(format!("bson conversion failed: {e}")))?;
            let timestamp = BsonDateTime::from_millis(entry.timestamp.timestamp_millis());
            let mut document = doc! {
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "content": content_bson,
                "content_text": &texts[i],
                "author": &entry.author,
                "timestamp": timestamp,
            };
            if let Some(ref embs) = embeddings {
                // BSON has no f32 array; widen each component to Double.
                let embedding_vec: Vec<mongodb::bson::Bson> =
                    embs[i].iter().map(|&v| mongodb::bson::Bson::Double(v as f64)).collect();
                document.insert("embedding", embedding_vec);
            }
            documents.push(document);
        }
        collection
            .insert_many(documents)
            .await
            .map_err(|e| adk_core::AdkError::memory(format!("add_session failed: {e}")))?;
        Ok(())
    }

    /// Searches memories scoped to `req.app_name` / `req.user_id`.
    ///
    /// Uses Atlas Vector Search when an embedding provider is configured,
    /// otherwise falls back to MongoDB `$text` search sorted by text score.
    /// `req.limit` defaults to 10. Rows with missing/undecodable `content`
    /// are skipped rather than failing the whole search.
    #[instrument(skip_all, fields(app_name = %req.app_name, user_id = %req.user_id))]
    async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
        let collection = self.db.collection::<Document>("memory_entries");
        let limit = req.limit.unwrap_or(10) as i64;
        let docs = if let Some(ref provider) = self.embedding_provider {
            let query_embedding = provider
                .embed(std::slice::from_ref(&req.query))
                .await
                .map_err(|e| adk_core::AdkError::memory(format!("query embedding failed: {e}")))?;
            // Avoid panicking (`query_embedding[0]`) on a provider that
            // returns no vectors.
            let first_embedding = query_embedding.first().ok_or_else(|| {
                adk_core::AdkError::memory("embedding provider returned no query vector".to_string())
            })?;
            let query_vec: Vec<mongodb::bson::Bson> =
                first_embedding.iter().map(|&v| mongodb::bson::Bson::Double(v as f64)).collect();
            // Pre-filter INSIDE $vectorSearch instead of a $match stage after
            // it: a post-stage $match runs after the top-`limit` hits are
            // chosen across ALL users, so other users' entries could consume
            // the quota and this user's matches would be silently dropped
            // (and rows would leak into scoring). The filter fields must be
            // declared as "filter" fields in the Atlas Vector Search index.
            let pipeline = vec![doc! {
                "$vectorSearch": {
                    "index": "memory_embedding_index",
                    "path": "embedding",
                    "queryVector": &query_vec,
                    "numCandidates": 100,
                    "filter": {
                        "app_name": &req.app_name,
                        "user_id": &req.user_id,
                    },
                    "limit": limit,
                }
            }];
            let mut cursor = collection.aggregate(pipeline).await.map_err(|e| {
                let msg = e.to_string();
                // Surface a targeted hint when the Atlas index is missing,
                // since the driver error is opaque in that case.
                if msg.contains("PlanExecutor") || msg.contains("$vectorSearch") {
                    adk_core::AdkError::memory(
                        "vector search index not available: Atlas Vector Search index \
                        'memory_embedding_index' (with 'app_name' and 'user_id' declared \
                        as filter fields) must be created via Atlas UI/API"
                            .to_string(),
                    )
                } else {
                    adk_core::AdkError::memory(format!("search failed: {e}"))
                }
            })?;
            let mut results = Vec::new();
            while cursor
                .advance()
                .await
                .map_err(|e| adk_core::AdkError::memory(format!("search cursor failed: {e}")))?
            {
                let doc = cursor.deserialize_current().map_err(|e| {
                    adk_core::AdkError::memory(format!("search deserialization failed: {e}"))
                })?;
                results.push(doc);
            }
            results
        } else {
            // Text-search fallback: relies on the $text index created by
            // migration v1 and sorts by relevance score.
            let filter = doc! {
                "app_name": &req.app_name,
                "user_id": &req.user_id,
                "$text": { "$search": &req.query },
            };
            let mut cursor = collection
                .find(filter)
                .sort(doc! { "score": { "$meta": "textScore" } })
                .limit(limit)
                .await
                .map_err(|e| adk_core::AdkError::memory(format!("search failed: {e}")))?;
            let mut results = Vec::new();
            while cursor
                .advance()
                .await
                .map_err(|e| adk_core::AdkError::memory(format!("search cursor failed: {e}")))?
            {
                let doc = cursor.deserialize_current().map_err(|e| {
                    adk_core::AdkError::memory(format!("search deserialization failed: {e}"))
                })?;
                results.push(doc);
            }
            results
        };
        // Convert raw documents back into MemoryEntry values. Missing or
        // non-BSON `content` skips the row; content that is BSON but fails
        // Content deserialization degrades to an empty user message.
        let memories =
            docs.iter()
                .filter_map(|doc| {
                    let content_bson = doc.get("content")?;
                    let content_json: serde_json::Value =
                        mongodb::bson::from_bson(content_bson.clone()).ok()?;
                    let content: adk_core::Content =
                        serde_json::from_value(content_json).unwrap_or_else(|_| {
                            adk_core::Content { role: "user".to_string(), parts: vec![] }
                        });
                    let author = doc.get_str("author").unwrap_or("unknown").to_string();
                    let timestamp = doc
                        .get_datetime("timestamp")
                        .ok()
                        .map(|dt| {
                            chrono::DateTime::from_timestamp_millis(dt.timestamp_millis())
                                .unwrap_or_default()
                        })
                        .unwrap_or_default();
                    Some(MemoryEntry { content, author, timestamp })
                })
                .collect();
        Ok(SearchResponse { memories })
    }

    /// Deletes all memory entries for a user within an app.
    #[instrument(skip_all, fields(app_name = %app_name, user_id = %user_id))]
    async fn delete_user(&self, app_name: &str, user_id: &str) -> Result<()> {
        let collection = self.db.collection::<Document>("memory_entries");
        collection
            .delete_many(doc! { "app_name": app_name, "user_id": user_id })
            .await
            .map_err(|e| adk_core::AdkError::memory(format!("delete_user failed: {e}")))?;
        Ok(())
    }

    /// Deletes all memory entries belonging to one session of a user.
    #[instrument(skip_all, fields(app_name = %app_name, user_id = %user_id, session_id = %session_id))]
    async fn delete_session(&self, app_name: &str, user_id: &str, session_id: &str) -> Result<()> {
        let collection = self.db.collection::<Document>("memory_entries");
        collection
            .delete_many(doc! {
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
            })
            .await
            .map_err(|e| adk_core::AdkError::memory(format!("delete_session failed: {e}")))?;
        Ok(())
    }

    /// Verifies connectivity with a server `ping` command.
    #[instrument(skip_all)]
    async fn health_check(&self) -> Result<()> {
        self.db
            .run_command(doc! { "ping": 1 })
            .await
            .map_err(|e| adk_core::AdkError::memory(format!("health check failed: {e}")))?;
        Ok(())
    }
}