use std::collections::{HashMap, HashSet};

use crate::memory::Scope;
use super::{GraphError, GraphRow, GraphStore};
/// Largest traversal depth `neighbors` will issue; larger requests are clamped
/// down to this value (and requests below 1 are raised to 1).
pub const MAX_ENRICHMENT_DEPTH: usize = 2;
/// Default traversal depth for enrichment. Not referenced in this module —
/// presumably used by callers choosing a depth for `neighbors`; confirm at call sites.
pub const DEFAULT_ENRICHMENT_DEPTH: usize = 1;
/// An entity node reached during graph enrichment, identified by its `name`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphEntity {
/// The entity's `name` property as returned by the graph store.
pub name: String,
}
/// A directed edge between two entities: `subject` --`relation`--> `object`.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphRelationship {
/// Name of the edge's start node.
pub subject: String,
/// The edge's `relation` label.
pub relation: String,
/// Name of the edge's end node.
pub object: String,
/// The edge's `confidence`; the row builder defaults it to `1.0` when the store
/// returns no usable value.
pub confidence: f32,
}
/// Entities and relationships gathered around a set of seed memories.
///
/// Both lists are de-duplicated by the row builder and keep first-seen order.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct GraphContext {
/// Distinct entities reached by the traversal.
pub entities: Vec<GraphEntity>,
/// Distinct relationships on the traversed paths.
pub relationships: Vec<GraphRelationship>,
}
impl GraphContext {
    /// Returns `true` when the context carries nothing: no entities and no relationships.
    pub fn is_empty(&self) -> bool {
        let no_entities = self.entities.is_empty();
        let no_relationships = self.relationships.is_empty();
        no_entities && no_relationships
    }
}
/// Fetches the graph neighbourhood of the entities linked to `seed_pids`.
///
/// Seed entities are the `Entity` nodes matching `scope` (on `agent_id`, `org_id`
/// and `user_id`) whose `memory_pids` list contains at least one of `seed_pids`.
/// From each seed the query follows undirected paths of 1 to `depth` hops and
/// returns every edge on those paths together with the entity the path reached.
///
/// Only paths made *entirely* of current edges (`valid_to IS NULL`) are followed,
/// so an entity reachable solely through an invalidated relationship is not
/// returned.
///
/// `depth` is clamped into `1..=MAX_ENRICHMENT_DEPTH`. Scope ids and seed pids are
/// bound as query parameters; only the clamped `depth` and the generated `$pidN`
/// placeholder names are formatted into the Cypher text.
///
/// Returns an empty [`GraphContext`] without querying when `seed_pids` is empty.
///
/// # Errors
///
/// Propagates any [`GraphError`] returned by `store.query`.
pub(super) async fn neighbors<G: GraphStore + ?Sized>(
    store: &G,
    seed_pids: &[&str],
    scope: &Scope,
    depth: usize,
) -> Result<GraphContext, GraphError> {
    if seed_pids.is_empty() {
        return Ok(GraphContext::default());
    }
    // Cypher does not accept parameters in variable-length bounds, so the depth is
    // interpolated; clamping keeps it a small integer we control.
    let depth = depth.clamp(1, MAX_ENRICHMENT_DEPTH);

    let mut params = HashMap::from([
        ("agent_id".to_string(), scope.agent_id.clone().into()),
        ("org_id".to_string(), scope.org_id.clone().into()),
        ("user_id".to_string(), scope.user_id.clone().into()),
    ]);
    for (i, pid) in seed_pids.iter().enumerate() {
        params.insert(format!("pid{i}"), (*pid).into());
    }
    let pid_list = (0..seed_pids.len())
        .map(|i| format!("$pid{i}"))
        .collect::<Vec<_>>()
        .join(", ");

    // The `all(...)` predicate rejects a whole path if any hop is invalidated; filtering
    // edges one by one after UNWIND would still surface the far entity of such a path.
    // DISTINCT collapses the identical rows produced when several paths share an edge.
    let cypher = format!(
        "MATCH (seed:Entity {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id}}) \
         WHERE any(p IN seed.memory_pids WHERE p IN [{pid_list}]) \
         MATCH (seed)-[r*1..{depth}]-(related:Entity) \
         WHERE all(edge IN r WHERE edge.valid_to IS NULL) \
         WITH seed, related, r \
         UNWIND r AS edge \
         RETURN DISTINCT startNode(edge).name AS subject, edge.relation AS relation, \
         endNode(edge).name AS object, edge.confidence AS confidence, related.name AS related_name"
    );
    let rows = store.query(&cypher, &params).await?;
    Ok(build_context(&rows))
}
fn build_context(rows: &[GraphRow]) -> GraphContext {
let mut entities: Vec<GraphEntity> = Vec::new();
let mut relationships: Vec<GraphRelationship> = Vec::new();
for row in rows {
if let Some(name) = column(row, "related_name") {
let entity = GraphEntity { name: name.to_string() };
if !entities.contains(&entity) {
entities.push(entity);
}
}
let (Some(subject), Some(relation), Some(object)) =
(column(row, "subject"), column(row, "relation"), column(row, "object"))
else {
continue;
};
let confidence = column(row, "confidence").and_then(|c| c.parse().ok()).unwrap_or(1.0);
let relationship = GraphRelationship {
subject: subject.to_string(),
relation: relation.to_string(),
object: object.to_string(),
confidence,
};
if !relationships.contains(&relationship) {
relationships.push(relationship);
}
}
GraphContext { entities, relationships }
}
/// Returns the value of the first `(column, value)` pair in `row` whose column
/// equals `name`, or `None` when no such column is present.
fn column<'a>(row: &'a GraphRow, name: &str) -> Option<&'a str> {
    row.iter()
        .find_map(|(key, value)| (key == name).then_some(value.as_str()))
}
#[cfg(test)]
mod tests {
// Unit tests for `neighbors` and the row builder, driven through an in-memory
// `GraphStore` stub that records every query and replays staged rows.
use std::sync::Mutex;
use super::*;
use crate::graph::{GraphParam, GraphRows};
/// Fixed scope shared by every test.
fn scope() -> Scope {
Scope {
agent_id: "agent".to_string(),
org_id: "org".to_string(),
user_id: "user".to_string(),
}
}
/// Builds a `GraphRow` from `(column, value)` string pairs.
fn row(pairs: &[(&str, &str)]) -> GraphRow {
pairs.iter().map(|(k, v)| (k.to_string(), v.to_string())).collect()
}
/// `GraphStore` stub: `query` appends `(cypher, params)` to `calls` and returns a
/// clone of `rows` (no rows for `StagedStore::default()`).
#[derive(Default)]
struct StagedStore {
rows: Mutex<GraphRows>,
calls: Mutex<Vec<(String, HashMap<String, GraphParam>)>>,
}
impl StagedStore {
/// Stub that answers every query with a clone of `rows`.
fn with_rows(rows: GraphRows) -> Self {
Self {
rows: Mutex::new(rows),
calls: Mutex::default(),
}
}
/// Snapshot of the recorded `(cypher, params)` calls, in call order.
fn calls(&self) -> Vec<(String, HashMap<String, GraphParam>)> {
self.calls.lock().unwrap().clone()
}
}
impl GraphStore for StagedStore {
// Nothing to set up for the stub.
async fn ensure_graph(&self) -> Result<(), GraphError> {
Ok(())
}
// Records the call (the lock guard is dropped before returning, never held
// across an await), then replays the staged rows.
async fn query(&self, cypher: &str, params: &HashMap<String, GraphParam>) -> Result<GraphRows, GraphError> {
self.calls.lock().unwrap().push((cypher.to_string(), params.clone()));
Ok(self.rows.lock().unwrap().clone())
}
}
/// An empty seed list short-circuits: empty context and no query issued.
#[tokio::test(flavor = "current_thread")]
async fn should_return_empty_for_no_seeds() {
let store = StagedStore::default();
let ctx = neighbors(&store, &[], &scope(), 1).await.unwrap();
assert!(ctx.is_empty());
assert!(store.calls().is_empty(), "no seeds -> no query");
}
/// Seed pids and scope ids travel as bound parameters, not inside the Cypher text.
#[tokio::test(flavor = "current_thread")]
async fn should_bind_seeds_and_scope_as_params() {
let store = StagedStore::default();
neighbors(&store, &["mem1", "mem2"], &scope(), 1).await.unwrap();
let (cypher, params) = &store.calls()[0];
assert!(!cypher.contains("mem1"), "pids must not be interpolated");
assert_eq!(params.get("pid0"), Some(&GraphParam::Str("mem1".to_string())));
assert_eq!(params.get("pid1"), Some(&GraphParam::Str("mem2".to_string())));
assert_eq!(params.get("agent_id"), Some(&GraphParam::Str("agent".to_string())));
}
/// The generated query restricts traversal to edges whose `valid_to` is unset.
#[tokio::test(flavor = "current_thread")]
async fn should_filter_current_edges_only() {
let store = StagedStore::default();
neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();
assert!(store.calls()[0].0.contains("edge.valid_to IS NULL"));
}
/// An oversized depth request is clamped to `MAX_ENRICHMENT_DEPTH` in the pattern.
#[tokio::test(flavor = "current_thread")]
async fn should_clamp_depth_into_range() {
let store = StagedStore::default();
neighbors(&store, &["mem1"], &scope(), 99).await.unwrap();
assert!(
store.calls()[0].0.contains(&format!("*1..{MAX_ENRICHMENT_DEPTH}")),
"depth clamps to the max",
);
}
/// Identical rows collapse to a single entity and a single relationship.
#[tokio::test(flavor = "current_thread")]
async fn should_build_deduped_context_from_rows() {
let store = StagedStore::with_rows(vec![
row(&[
("subject", "Alice"),
("relation", "works at"),
("object", "Acme"),
("confidence", "0.9"),
("related_name", "Acme"),
]),
row(&[
("subject", "Alice"),
("relation", "works at"),
("object", "Acme"),
("confidence", "0.9"),
("related_name", "Acme"),
]),
]);
let ctx = neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();
assert_eq!(ctx.relationships.len(), 1);
assert_eq!(ctx.relationships[0].object, "Acme");
assert_eq!(ctx.entities.len(), 1);
assert_eq!(ctx.entities[0].name, "Acme");
}
/// A confidence string that does not parse as a float falls back to 1.0.
#[tokio::test(flavor = "current_thread")]
async fn should_default_confidence_when_unparseable() {
let store = StagedStore::with_rows(vec![row(&[
("subject", "Alice"),
("relation", "knows"),
("object", "Bob"),
("confidence", "null"),
("related_name", "Bob"),
])]);
let ctx = neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();
assert_eq!(ctx.relationships[0].confidence, 1.0);
}
}