use crate::embedding::EmbeddingProvider;
use crate::service::*;
use adk_core::Result;
use async_trait::async_trait;
use chrono::DateTime;
use neo4rs::Graph;
use std::collections::HashSet;
use std::sync::Arc;
use tracing::instrument;
/// Memory service backed by a Neo4j graph database.
///
/// Entries are stored as `MemoryEntry` nodes. When an embedding provider is
/// configured, entries are stored with an embedding vector and `search` goes
/// through the vector index; otherwise `search` uses the full-text index.
pub struct Neo4jMemoryService {
/// Neo4j handle used for every query issued by this service.
graph: Graph,
/// Optional embedding backend; `None` disables vector storage and vector search.
embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
}
impl Neo4jMemoryService {
/// Node label under which applied schema migrations are recorded
/// (one node per applied version).
const REGISTRY_LABEL: &'static str = "_AdkMemoryMigration";
/// Compiled-in schema migrations as `(version, description, cypher statements)`.
///
/// `migrate` runs the statements of every entry whose version is greater than
/// the highest recorded version, and takes the highest known version from the
/// last entry, so entries are expected to be listed in ascending version order.
const NEO4J_MEMORY_MIGRATIONS: &'static [(i64, &'static str, &'static [&'static str])] = &[(
1,
"create initial constraints and indexes",
&[
// Uniqueness constraint on `MemoryEntry.id`; its name is also what
// `detect_existing_tables` looks for to recognise a pre-registry schema.
"CREATE CONSTRAINT memory_entry_unique IF NOT EXISTS \
FOR (m:MemoryEntry) REQUIRE (m.id) IS UNIQUE",
// Composite index on the properties used to scope entries to an app/user.
"CREATE INDEX memory_app_user IF NOT EXISTS \
FOR (m:MemoryEntry) ON (m.app_name, m.user_id)",
// Full-text index queried by `search` when no embedding provider is set.
"CREATE FULLTEXT INDEX memory_content IF NOT EXISTS \
FOR (m:MemoryEntry) ON EACH [m.content_text]",
],
)];
/// Wraps an existing Neo4j handle in a memory service.
///
/// `embedding_provider` is optional: pass `None` to run without embeddings.
/// No query is issued here; the call currently always returns `Ok`.
pub fn new(
    graph: Graph,
    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
) -> adk_core::Result<Self> {
    let service = Self { graph, embedding_provider };
    Ok(service)
}
pub fn graph(&self) -> &Graph {
&self.graph
}
/// Brings the database schema up to the newest compiled-in migration.
///
/// Steps, in order:
/// 1. Ensure the uniqueness constraint on the registry node's `version`.
/// 2. Read the highest recorded version; if nothing is recorded but the
///    `memory_entry_unique` constraint already exists, record the first
///    compiled migration as applied without re-running it.
/// 3. Refuse to continue if the database is ahead of this build.
/// 4. Run and record every migration newer than the recorded version.
/// 5. If an embedding provider is configured, create the vector index sized
///    to the provider's dimensions.
///
/// # Errors
/// Returns a memory error if any query fails or if the recorded schema
/// version is newer than the newest compiled-in migration.
pub async fn migrate(&self) -> adk_core::Result<()> {
    // Step 1: the registry constraint must exist before versions are recorded.
    let registry_constraint = format!(
        "CREATE CONSTRAINT {}_version_unique IF NOT EXISTS \
         FOR (m:{}) REQUIRE (m.version) IS UNIQUE",
        Self::REGISTRY_LABEL.to_lowercase(),
        Self::REGISTRY_LABEL,
    );
    if let Err(e) = self.graph.run(neo4rs::query(&registry_constraint)).await {
        return Err(adk_core::AdkError::memory(format!(
            "migration registry creation failed: {e}"
        )));
    }

    // Step 2: baseline a schema that predates the registry.
    let recorded = self.read_max_applied_version().await?;
    let max_applied = if recorded == 0 && self.detect_existing_tables().await? {
        match Self::NEO4J_MEMORY_MIGRATIONS.first() {
            Some(&(version, description, _)) => {
                self.record_migration(version, description).await?;
                version
            }
            None => recorded,
        }
    } else {
        recorded
    };

    // Step 3: a database newer than this build is not safe to touch.
    let max_compiled = Self::NEO4J_MEMORY_MIGRATIONS.last().map_or(0, |step| step.0);
    if max_applied > max_compiled {
        return Err(adk_core::AdkError::memory(format!(
            "schema version mismatch: database is at v{max_applied} \
             but code only knows up to v{max_compiled}. \
             Upgrade your ADK version."
        )));
    }

    // Step 4: apply whatever is still pending, recording each version only
    // after all of its statements succeeded.
    let pending = Self::NEO4J_MEMORY_MIGRATIONS.iter().filter(|step| step.0 > max_applied);
    for &(version, description, cypher_statements) in pending {
        for cypher in cypher_statements {
            if let Err(e) = self.graph.run(neo4rs::query(cypher)).await {
                let failure = crate::migration::MigrationError {
                    version,
                    description: description.to_string(),
                    cause: e.to_string(),
                };
                return Err(adk_core::AdkError::memory(failure.to_string()));
            }
        }
        self.record_migration(version, description).await?;
    }

    // Step 5: the vector index depends on the provider's dimensionality, so it
    // is created outside the versioned migrations.
    if let Some(provider) = self.embedding_provider.as_ref() {
        let dims = provider.dimensions();
        let vector_index_query = format!(
            "CREATE VECTOR INDEX memory_embedding IF NOT EXISTS \
             FOR (m:MemoryEntry) ON (m.embedding) \
             OPTIONS {{indexConfig: {{`vector.dimensions`: {dims}, \
             `vector.similarity_function`: 'cosine'}}}}"
        );
        if let Err(e) = self.graph.run(neo4rs::query(&vector_index_query)).await {
            return Err(adk_core::AdkError::memory(format!(
                "migration failed: vector index creation failed: {e}"
            )));
        }
    }

    Ok(())
}
/// Returns the highest migration version recorded in the database,
/// or `0` when none is recorded.
///
/// # Errors
/// Returns a memory error if the registry query fails.
pub async fn schema_version(&self) -> Result<i64> {
    let version = self.read_max_applied_version().await?;
    Ok(version)
}
/// Reads the largest `version` stored on registry nodes.
///
/// Yields `0` when the query returns no row or when `max_v` cannot be read
/// as an `i64` (for example when no registry node exists yet).
///
/// # Errors
/// Returns a memory error if executing the query or fetching its row fails.
async fn read_max_applied_version(&self) -> Result<i64> {
    let query_str =
        format!("OPTIONAL MATCH (m:{}) RETURN max(m.version) AS max_v", Self::REGISTRY_LABEL);

    let mut row_stream = match self.graph.execute(neo4rs::query(&query_str)).await {
        Ok(stream) => stream,
        Err(e) => {
            return Err(adk_core::AdkError::memory(format!(
                "migration registry read failed: {e}"
            )));
        }
    };

    let first_row = row_stream.next().await.map_err(|e| {
        adk_core::AdkError::memory(format!("migration registry read failed: {e}"))
    })?;

    // A missing row and an unreadable/null `max_v` both mean "nothing applied".
    let version = first_row.and_then(|row| row.get::<i64>("max_v").ok()).unwrap_or(0);
    Ok(version)
}
/// Reports whether the `memory_entry_unique` constraint already exists.
///
/// `migrate` uses this to recognise a schema that was created before the
/// migration registry was populated.
///
/// # Errors
/// Returns a memory error if the constraint lookup fails.
async fn detect_existing_tables(&self) -> Result<bool> {
    let probe = neo4rs::query(
        "SHOW CONSTRAINTS YIELD name WHERE name = 'memory_entry_unique' RETURN name",
    );
    let mut row_stream = self
        .graph
        .execute(probe)
        .await
        .map_err(|e| adk_core::AdkError::memory(format!("baseline detection failed: {e}")))?;

    // Any returned row means the constraint is present.
    match row_stream.next().await {
        Ok(first_row) => Ok(first_row.is_some()),
        Err(e) => Err(adk_core::AdkError::memory(format!("baseline detection failed: {e}"))),
    }
}
/// Records `version` as applied by creating a registry node carrying the
/// version, its description, and the database's current `datetime()`.
///
/// # Errors
/// Returns a memory error wrapping a `MigrationError` if the write fails.
async fn record_migration(&self, version: i64, description: &str) -> Result<()> {
    let query_str = format!(
        "CREATE (m:{} {{version: $version, description: $description, applied_at: datetime()}})",
        Self::REGISTRY_LABEL,
    );
    let insert = neo4rs::query(&query_str)
        .param("version", version)
        .param("description", description.to_string());

    match self.graph.run(insert).await {
        Ok(_) => Ok(()),
        Err(e) => {
            let failure = crate::migration::MigrationError {
                version,
                description: description.to_string(),
                cause: format!("registry record failed: {e}"),
            };
            Err(adk_core::AdkError::memory(failure.to_string()))
        }
    }
}
}
#[async_trait]
impl MemoryService for Neo4jMemoryService {
/// Stores the entries of one session as `MemoryEntry` nodes.
///
/// Inside a single transaction this merges the `MemorySession` node, creates
/// one `MemoryEntry` per input entry linked to it via `FROM_SESSION`, and
/// chains consecutive entries with `FOLLOWS`. When an embedding provider is
/// configured, each entry is stored with its embedding vector.
///
/// An empty `entries` list is a no-op.
///
/// # Errors
/// Returns a memory error if embedding generation fails, if the provider
/// returns a different number of embeddings than entries, if content
/// serialization fails, or if any database operation fails.
#[instrument(skip_all, fields(app_name = %app_name, user_id = %user_id, session_id = %session_id, entry_count = entries.len()))]
async fn add_session(
    &self,
    app_name: &str,
    user_id: &str,
    session_id: &str,
    entries: Vec<MemoryEntry>,
) -> Result<()> {
    if entries.is_empty() {
        return Ok(());
    }

    let texts: Vec<String> =
        entries.iter().map(|e| crate::text::extract_text(&e.content)).collect();

    let embeddings = if let Some(provider) = &self.embedding_provider {
        // Empty texts are replaced by a single space before being embedded.
        let non_empty_texts: Vec<String> = texts
            .iter()
            .map(|t| if t.is_empty() { " ".to_string() } else { t.clone() })
            .collect();
        let vectors = provider.embed(&non_empty_texts).await.map_err(|e| {
            adk_core::AdkError::memory(format!("embedding generation failed: {e}"))
        })?;
        // Validate up front so indexing by entry position below cannot panic
        // and no transaction is opened for a request that cannot succeed.
        if vectors.len() != entries.len() {
            return Err(adk_core::AdkError::memory(format!(
                "embedding generation failed: expected {} embeddings, got {}",
                entries.len(),
                vectors.len()
            )));
        }
        Some(vectors)
    } else {
        None
    };

    let mut txn = self
        .graph
        .start_txn()
        .await
        .map_err(|e| adk_core::AdkError::memory(format!("transaction failed: {e}")))?;

    txn.run(
        neo4rs::query(
            "MERGE (:MemorySession {session_id: $session_id, \
             app_name: $app_name, user_id: $user_id})",
        )
        .param("session_id", session_id.to_string())
        .param("app_name", app_name.to_string())
        .param("user_id", user_id.to_string()),
    )
    .await
    .map_err(|e| adk_core::AdkError::memory(format!("add_session failed: {e}")))?;

    // The two statements differ only in whether the `embedding` property is set.
    let create_entry_cypher = if embeddings.is_some() {
        "MATCH (s:MemorySession {session_id: $session_id, \
         app_name: $app_name, user_id: $user_id}) \
         CREATE (s)-[:FROM_SESSION]->(e:MemoryEntry { \
         id: $id, app_name: $app_name, user_id: $user_id, \
         session_id: $session_id, content: $content, \
         content_text: $content_text, author: $author, \
         timestamp: $timestamp, embedding: $embedding \
         })"
    } else {
        "MATCH (s:MemorySession {session_id: $session_id, \
         app_name: $app_name, user_id: $user_id}) \
         CREATE (s)-[:FROM_SESSION]->(e:MemoryEntry { \
         id: $id, app_name: $app_name, user_id: $user_id, \
         session_id: $session_id, content: $content, \
         content_text: $content_text, author: $author, \
         timestamp: $timestamp \
         })"
    };

    // NOTE(review): ids are `{session_id}_{index}` and entries are CREATEd, so
    // adding the same session id twice presumably trips the uniqueness
    // constraint on `MemoryEntry.id` — confirm this is the intended contract.
    let mut entry_ids: Vec<String> = Vec::with_capacity(entries.len());
    for (i, entry) in entries.iter().enumerate() {
        let entry_id = format!("{session_id}_{i}");
        entry_ids.push(entry_id.clone());

        let content_json = serde_json::to_string(&entry.content)
            .map_err(|e| adk_core::AdkError::memory(format!("serialization failed: {e}")))?;

        let mut create_entry = neo4rs::query(create_entry_cypher)
            .param("session_id", session_id.to_string())
            .param("app_name", app_name.to_string())
            .param("user_id", user_id.to_string())
            .param("id", entry_id)
            .param("content", content_json)
            .param("content_text", texts[i].clone())
            .param("author", entry.author.clone())
            .param("timestamp", entry.timestamp.to_rfc3339());

        if let Some(embs) = &embeddings {
            // In bounds: the embedding count was checked against `entries` above.
            let embedding_f64: Vec<f64> = embs[i].iter().map(|&v| v as f64).collect();
            create_entry = create_entry.param("embedding", embedding_f64);
        }

        txn.run(create_entry)
            .await
            .map_err(|e| adk_core::AdkError::memory(format!("add_session failed: {e}")))?;
    }

    // Link each entry to its successor to preserve conversation order.
    for pair in entry_ids.windows(2) {
        txn.run(
            neo4rs::query(
                "MATCH (prev:MemoryEntry {id: $prev_id}) \
                 MATCH (curr:MemoryEntry {id: $curr_id}) \
                 CREATE (prev)-[:FOLLOWS]->(curr)",
            )
            .param("prev_id", pair[0].clone())
            .param("curr_id", pair[1].clone()),
        )
        .await
        .map_err(|e| {
            adk_core::AdkError::memory(format!("add_session failed: FOLLOWS creation: {e}"))
        })?;
    }

    txn.commit()
        .await
        .map_err(|e| adk_core::AdkError::memory(format!("commit failed: {e}")))?;
    Ok(())
}
#[instrument(skip_all, fields(app_name = %req.app_name, user_id = %req.user_id))]
async fn search(&self, req: SearchRequest) -> Result<SearchResponse> {
let limit = req.limit.unwrap_or(10) as i64;
let results = if let Some(ref provider) = self.embedding_provider {
let query_embedding = provider
.embed(std::slice::from_ref(&req.query))
.await
.map_err(|e| adk_core::AdkError::memory(format!("query embedding failed: {e}")))?;
let query_vec: Vec<f64> = query_embedding[0].iter().map(|&v| v as f64).collect();
let mut row_stream = self
.graph
.execute(
neo4rs::query(
"CALL db.index.vector.queryNodes('memory_embedding', $limit, \
$query_embedding) \
YIELD node, score \
WHERE node.app_name = $app_name AND node.user_id = $user_id \
OPTIONAL MATCH (node)-[:FOLLOWS]-(adjacent:MemoryEntry) \
RETURN node.id AS id, node.content AS content, \
node.author AS author, node.timestamp AS timestamp, \
score, \
collect(adjacent.id) AS adj_ids, \
collect(adjacent.content) AS adj_contents, \
collect(adjacent.author) AS adj_authors, \
collect(adjacent.timestamp) AS adj_timestamps \
ORDER BY score DESC",
)
.param("limit", limit)
.param("query_embedding", query_vec)
.param("app_name", req.app_name.clone())
.param("user_id", req.user_id.clone()),
)
.await
.map_err(|e| adk_core::AdkError::memory(format!("search failed: {e}")))?;
let mut entries = Vec::new();
let mut seen_ids: HashSet<String> = HashSet::new();
while let Some(row) = row_stream
.next()
.await
.map_err(|e| adk_core::AdkError::memory(format!("search failed: {e}")))?
{
if let Some(entry) = row_to_memory_entry(&row) {
let id = row.get::<String>("id").unwrap_or_default();
if seen_ids.insert(id) {
entries.push(entry);
}
}
collect_adjacent_entries(&row, &mut seen_ids, &mut entries);
}
entries
} else {
let mut row_stream = self
.graph
.execute(
neo4rs::query(
"CALL db.index.fulltext.queryNodes('memory_content', $query) \
YIELD node, score \
WHERE node.app_name = $app_name AND node.user_id = $user_id \
OPTIONAL MATCH (node)-[:FOLLOWS]-(adjacent:MemoryEntry) \
RETURN node.id AS id, node.content AS content, \
node.author AS author, node.timestamp AS timestamp, \
score, \
collect(adjacent.id) AS adj_ids, \
collect(adjacent.content) AS adj_contents, \
collect(adjacent.author) AS adj_authors, \
collect(adjacent.timestamp) AS adj_timestamps \
ORDER BY score DESC \
LIMIT $limit",
)
.param("query", req.query.clone())
.param("app_name", req.app_name.clone())
.param("user_id", req.user_id.clone())
.param("limit", limit),
)
.await
.map_err(|e| adk_core::AdkError::memory(format!("search failed: {e}")))?;
let mut entries = Vec::new();
let mut seen_ids: HashSet<String> = HashSet::new();
while let Some(row) = row_stream
.next()
.await
.map_err(|e| adk_core::AdkError::memory(format!("search failed: {e}")))?
{
if let Some(entry) = row_to_memory_entry(&row) {
let id = row.get::<String>("id").unwrap_or_default();
if seen_ids.insert(id) {
entries.push(entry);
}
}
collect_adjacent_entries(&row, &mut seen_ids, &mut entries);
}
entries
};
Ok(SearchResponse { memories: results })
}
/// Deletes all stored memories of one user within an app.
///
/// Runs two separate statements: the first detaches and deletes every
/// matching `MemoryEntry`, the second deletes the user's `MemorySession`
/// nodes that have no outgoing `FROM_SESSION` relationship left.
///
/// # Errors
/// Returns a memory error if either statement fails.
#[instrument(skip_all, fields(app_name = %app_name, user_id = %user_id))]
async fn delete_user(&self, app_name: &str, user_id: &str) -> Result<()> {
    // (statement, error prefix) pairs, executed in order.
    let steps = [
        (
            "MATCH (e:MemoryEntry {app_name: $app_name, user_id: $user_id}) \
             DETACH DELETE e",
            "delete_user failed",
        ),
        (
            "MATCH (s:MemorySession {app_name: $app_name, user_id: $user_id}) \
             WHERE NOT (s)-[:FROM_SESSION]->() \
             DELETE s",
            "delete_user cleanup failed",
        ),
    ];

    for (cypher, failure) in steps {
        let statement = neo4rs::query(cypher)
            .param("app_name", app_name.to_string())
            .param("user_id", user_id.to_string());
        self.graph
            .run(statement)
            .await
            .map_err(|e| adk_core::AdkError::memory(format!("{failure}: {e}")))?;
    }

    Ok(())
}
/// Deletes the stored memories of a single session.
///
/// First detaches and deletes the session's `MemoryEntry` nodes, then, as a
/// separate statement, deletes the matching `MemorySession` node if it has no
/// outgoing `FROM_SESSION` relationship left.
///
/// # Errors
/// Returns a memory error if either statement fails.
#[instrument(skip_all, fields(app_name = %app_name, user_id = %user_id, session_id = %session_id))]
async fn delete_session(&self, app_name: &str, user_id: &str, session_id: &str) -> Result<()> {
    let remove_entries = neo4rs::query(
        "MATCH (e:MemoryEntry {app_name: $app_name, user_id: $user_id, session_id: $session_id}) \
         DETACH DELETE e",
    )
    .param("app_name", app_name.to_string())
    .param("user_id", user_id.to_string())
    .param("session_id", session_id.to_string());

    let remove_orphaned_session = neo4rs::query(
        "MATCH (s:MemorySession {session_id: $session_id, app_name: $app_name, user_id: $user_id}) \
         WHERE NOT (s)-[:FROM_SESSION]->() \
         DELETE s",
    )
    .param("app_name", app_name.to_string())
    .param("user_id", user_id.to_string())
    .param("session_id", session_id.to_string());

    if let Err(e) = self.graph.run(remove_entries).await {
        return Err(adk_core::AdkError::memory(format!("delete_session failed: {e}")));
    }
    if let Err(e) = self.graph.run(remove_orphaned_session).await {
        return Err(adk_core::AdkError::memory(format!("delete_session cleanup failed: {e}")));
    }
    Ok(())
}
/// Checks connectivity by executing `RETURN 1`; the result rows are not read.
///
/// # Errors
/// Returns a memory error if the query cannot be executed.
#[instrument(skip_all)]
async fn health_check(&self) -> Result<()> {
    match self.graph.execute(neo4rs::query("RETURN 1")).await {
        Ok(_stream) => Ok(()),
        Err(e) => Err(adk_core::AdkError::memory(format!("health check failed: {e}"))),
    }
}
}
/// Builds a `MemoryEntry` from the `content`, `author` and `timestamp`
/// columns of a search result row.
///
/// Returns `None` when the row has no readable `content` string. Every other
/// problem is tolerated with a fallback:
/// - content that is not valid JSON becomes an empty `"user"` content,
/// - a missing author becomes `"unknown"`,
/// - a missing or non-RFC 3339 timestamp becomes the default `DateTime<Utc>`.
fn row_to_memory_entry(row: &neo4rs::Row) -> Option<MemoryEntry> {
    let content_str = row.get::<String>("content").ok()?;
    let content: adk_core::Content = serde_json::from_str(&content_str)
        .unwrap_or_else(|_| adk_core::Content { role: "user".to_string(), parts: vec![] });
    let author = row.get::<String>("author").unwrap_or_else(|_| "unknown".to_string());
    let timestamp_str = row.get::<String>("timestamp").unwrap_or_default();
    // Timestamps are stored as RFC 3339 strings and normalised to UTC on read.
    let timestamp = DateTime::parse_from_rfc3339(&timestamp_str)
        .map(|dt| dt.with_timezone(&chrono::Utc))
        .unwrap_or_default();
    Some(MemoryEntry { content, author, timestamp })
}
/// Appends the entries adjacent to a search hit to `entries`.
///
/// Reads the parallel `adj_ids`, `adj_contents`, `adj_authors` and
/// `adj_timestamps` list columns of `row` (a column that cannot be read is
/// treated as empty) and pushes one `MemoryEntry` per id not yet present in
/// `seen_ids`, registering the id as seen.
///
/// Fields are matched to ids by list position. A missing or unparsable field
/// falls back to an empty `"user"` content, the author `"unknown"`, or the
/// default `DateTime<Utc>`.
fn collect_adjacent_entries(
    row: &neo4rs::Row,
    seen_ids: &mut HashSet<String>,
    entries: &mut Vec<MemoryEntry>,
) {
    let adj_ids: Vec<String> = row.get("adj_ids").unwrap_or_default();
    let adj_contents: Vec<String> = row.get("adj_contents").unwrap_or_default();
    let adj_authors: Vec<String> = row.get("adj_authors").unwrap_or_default();
    let adj_timestamps: Vec<String> = row.get("adj_timestamps").unwrap_or_default();

    // NOTE(review): positional matching assumes all four lists have the same
    // length and order; presumably they would drift apart if an adjacent node
    // lacked one of the properties — verify against the producing query.
    for (i, adj_id) in adj_ids.iter().enumerate() {
        // Skip entries already emitted as a hit or as another hit's neighbour.
        if !seen_ids.insert(adj_id.clone()) {
            continue;
        }
        let content_str = adj_contents.get(i).cloned().unwrap_or_default();
        let content: adk_core::Content = serde_json::from_str(&content_str)
            .unwrap_or_else(|_| adk_core::Content { role: "user".to_string(), parts: vec![] });
        let author = adj_authors.get(i).cloned().unwrap_or_else(|| "unknown".to_string());
        let timestamp_str = adj_timestamps.get(i).cloned().unwrap_or_default();
        let timestamp = DateTime::parse_from_rfc3339(&timestamp_str)
            .map(|dt| dt.with_timezone(&chrono::Utc))
            .unwrap_or_default();
        entries.push(MemoryEntry { content, author, timestamp });
    }
}