use std::collections::{HashMap, HashSet};

use crate::memory::Scope;

use super::{GraphError, GraphRow, GraphStore};
16
/// Largest number of hops `neighbors` will traverse away from a seed entity.
///
/// Requested depths above this value are clamped down to it.
pub const MAX_ENRICHMENT_DEPTH: usize = 2;

/// Default enrichment depth: a single hop from each seed entity.
pub const DEFAULT_ENRICHMENT_DEPTH: usize = 1;

// `neighbors` clamps the requested depth with `depth.clamp(1, MAX_ENRICHMENT_DEPTH)`,
// and `usize::clamp` panics when min > max. Fail the build instead of panicking at
// runtime if these constants are ever edited out of range.
const _: () = assert!(MAX_ENRICHMENT_DEPTH >= 1);
const _: () = assert!(
    DEFAULT_ENRICHMENT_DEPTH >= 1 && DEFAULT_ENRICHMENT_DEPTH <= MAX_ENRICHMENT_DEPTH
);
29
/// An entity node reached during graph enrichment.
///
/// Entities are identified purely by name, so two values with the same name
/// compare equal.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GraphEntity {
    /// The node's `name` property (returned by the query as `related_name`).
    pub name: String,
}
39
/// A single edge of the enrichment graph, expressed as a
/// `subject --relation--> object` triple plus its confidence score.
///
/// Only `PartialEq` is derived because `f32` has no total equality.
#[derive(Clone, Debug, PartialEq)]
pub struct GraphRelationship {
    /// Name of the edge's start node.
    pub subject: String,
    /// The edge's `relation` property, e.g. `"works at"`.
    pub relation: String,
    /// Name of the edge's end node.
    pub object: String,
    /// The edge's `confidence` score.
    pub confidence: f32,
}
55
56#[derive(Debug, Clone, Default, PartialEq)]
63pub struct GraphContext {
64 pub entities: Vec<GraphEntity>,
66 pub relationships: Vec<GraphRelationship>,
68}
69
70impl GraphContext {
71 pub fn is_empty(&self) -> bool {
73 self.entities.is_empty() && self.relationships.is_empty()
74 }
75}
76
77pub(super) async fn neighbors<G: GraphStore + ?Sized>(
79 store: &G,
80 seed_pids: &[&str],
81 scope: &Scope,
82 depth: usize,
83) -> Result<GraphContext, GraphError> {
84 if seed_pids.is_empty() {
85 return Ok(GraphContext::default());
86 }
87 let depth = depth.clamp(1, MAX_ENRICHMENT_DEPTH);
88
89 let mut params = HashMap::from([
90 ("agent_id".to_string(), scope.agent_id.clone().into()),
91 ("org_id".to_string(), scope.org_id.clone().into()),
92 ("user_id".to_string(), scope.user_id.clone().into()),
93 ]);
94 for (i, pid) in seed_pids.iter().enumerate() {
95 params.insert(format!("pid{i}"), (*pid).into());
96 }
97 let pid_list = (0..seed_pids.len())
98 .map(|i| format!("$pid{i}"))
99 .collect::<Vec<_>>()
100 .join(", ");
101
102 let cypher = format!(
107 "MATCH (seed:Entity {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id}}) \
108 WHERE any(p IN seed.memory_pids WHERE p IN [{pid_list}]) \
109 MATCH (seed)-[r*1..{depth}]-(related:Entity) \
110 WITH seed, related, r \
111 UNWIND r AS edge \
112 WITH seed, related, edge WHERE edge.valid_to IS NULL \
113 RETURN startNode(edge).name AS subject, edge.relation AS relation, \
114 endNode(edge).name AS object, edge.confidence AS confidence, related.name AS related_name"
115 );
116
117 let rows = store.query(&cypher, ¶ms).await?;
118 Ok(build_context(&rows))
119}
120
121fn build_context(rows: &[GraphRow]) -> GraphContext {
123 let mut entities: Vec<GraphEntity> = Vec::new();
124 let mut relationships: Vec<GraphRelationship> = Vec::new();
125
126 for row in rows {
127 if let Some(name) = column(row, "related_name") {
128 let entity = GraphEntity { name: name.to_string() };
129 if !entities.contains(&entity) {
130 entities.push(entity);
131 }
132 }
133
134 let (Some(subject), Some(relation), Some(object)) =
135 (column(row, "subject"), column(row, "relation"), column(row, "object"))
136 else {
137 continue;
138 };
139 let confidence = column(row, "confidence").and_then(|c| c.parse().ok()).unwrap_or(1.0);
140 let relationship = GraphRelationship {
141 subject: subject.to_string(),
142 relation: relation.to_string(),
143 object: object.to_string(),
144 confidence,
145 };
146 if !relationships.contains(&relationship) {
147 relationships.push(relationship);
148 }
149 }
150
151 GraphContext { entities, relationships }
152}
153
154fn column<'a>(row: &'a GraphRow, name: &str) -> Option<&'a str> {
156 row.iter()
157 .find(|(column, _)| column == name)
158 .map(|(_, value)| value.as_str())
159}
160
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::graph::{GraphParam, GraphRows};

    /// One recorded `query` invocation: the Cypher text and its bound parameters.
    type RecordedCall = (String, HashMap<String, GraphParam>);

    /// Fixed tenant used by every test.
    fn scope() -> Scope {
        Scope {
            agent_id: "agent".to_string(),
            org_id: "org".to_string(),
            user_id: "user".to_string(),
        }
    }

    /// Builds a result row from `(column, value)` pairs.
    fn row(pairs: &[(&str, &str)]) -> GraphRow {
        pairs
            .iter()
            .map(|&(column, value)| (column.to_owned(), value.to_owned()))
            .collect()
    }

    /// The string parameter the query is expected to bind for `value`.
    fn str_param(value: &str) -> GraphParam {
        GraphParam::Str(value.to_owned())
    }

    /// Test double for [`GraphStore`]: records every query it receives and
    /// answers each one with the same canned rows.
    #[derive(Default)]
    struct StagedStore {
        rows: Mutex<GraphRows>,
        calls: Mutex<Vec<RecordedCall>>,
    }

    impl StagedStore {
        fn with_rows(rows: GraphRows) -> Self {
            Self {
                rows: Mutex::new(rows),
                ..Self::default()
            }
        }

        /// Snapshot of every query issued so far, oldest first.
        fn calls(&self) -> Vec<RecordedCall> {
            self.calls.lock().unwrap().clone()
        }

        /// Cypher text of the first query issued; panics if none was issued.
        fn first_cypher(&self) -> String {
            self.calls()
                .into_iter()
                .next()
                .map(|(cypher, _)| cypher)
                .expect("no query was issued")
        }
    }

    impl GraphStore for StagedStore {
        async fn ensure_graph(&self) -> Result<(), GraphError> {
            Ok(())
        }

        async fn query(
            &self,
            cypher: &str,
            params: &HashMap<String, GraphParam>,
        ) -> Result<GraphRows, GraphError> {
            let call = (cypher.to_owned(), params.clone());
            self.calls.lock().unwrap().push(call);
            let canned = self.rows.lock().unwrap().clone();
            Ok(canned)
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_return_empty_for_no_seeds() {
        let store = StagedStore::default();

        let context = neighbors(&store, &[], &scope(), 1).await.unwrap();

        assert!(context.is_empty());
        assert!(store.calls().is_empty(), "no seeds -> no query");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_bind_seeds_and_scope_as_params() {
        let store = StagedStore::default();

        neighbors(&store, &["mem1", "mem2"], &scope(), 1).await.unwrap();

        let calls = store.calls();
        let (cypher, params) = &calls[0];
        assert!(!cypher.contains("mem1"), "pids must not be interpolated");
        for (key, expected) in [("pid0", "mem1"), ("pid1", "mem2"), ("agent_id", "agent")] {
            assert_eq!(params.get(key), Some(&str_param(expected)), "param {key}");
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_filter_current_edges_only() {
        let store = StagedStore::default();

        neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();

        assert!(store.first_cypher().contains("edge.valid_to IS NULL"));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_clamp_depth_into_range() {
        let store = StagedStore::default();

        neighbors(&store, &["mem1"], &scope(), 99).await.unwrap();

        let bound = format!("*1..{MAX_ENRICHMENT_DEPTH}");
        assert!(store.first_cypher().contains(&bound), "depth clamps to the max");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_build_deduped_context_from_rows() {
        // The same edge reported twice must collapse into a single entry.
        let alice_at_acme = row(&[
            ("subject", "Alice"),
            ("relation", "works at"),
            ("object", "Acme"),
            ("confidence", "0.9"),
            ("related_name", "Acme"),
        ]);
        let store = StagedStore::with_rows(vec![alice_at_acme; 2]);

        let context = neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();

        assert_eq!(context.relationships.len(), 1);
        assert_eq!(context.relationships[0].object, "Acme");
        assert_eq!(context.entities.len(), 1);
        assert_eq!(context.entities[0].name, "Acme");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_default_confidence_when_unparseable() {
        let store = StagedStore::with_rows(vec![row(&[
            ("subject", "Alice"),
            ("relation", "knows"),
            ("object", "Bob"),
            ("confidence", "null"),
            ("related_name", "Bob"),
        ])]);

        let context = neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();

        assert_eq!(context.relationships[0].confidence, 1.0);
    }
}