Skip to main content

memoir_core/graph/
enrich.rs

1//! Read-path enrichment: graph neighborhoods around search hits.
2//!
3//! After a vector search, a consumer can opt into a graph traversal that
4//! surfaces entities and relationships *related* to the hits — facts the vector
5//! search alone would miss. [`GraphStore::neighbors`](super::GraphStore::neighbors)
6//! seeds from the hit memories' entities (those whose `memory_pids` contains a
7//! hit pid) and walks current edges out to a bounded depth, returning a flat
8//! [`GraphContext`]. Traversal is scope-confined and reads only current edges
9//! (`valid_to = null`); superseded edges are history, not "related now".
10
11use std::collections::HashMap;
12
13use crate::memory::Scope;
14
15use super::{GraphError, GraphRow, GraphStore};
16
/// Upper bound on enrichment traversal depth — the top of the manifesto's
/// "1-2 hop" range.
///
/// Limits how far enrichment may walk away from a hit's entities. Past two hops
/// the related set balloons while its relevance to the original hit fades, so
/// the cap stops an opt-in enrichment from becoming an unbounded graph scan.
pub const MAX_ENRICHMENT_DEPTH: usize = 2;
23
/// Traversal depth used when a consumer opts in without choosing one.
///
/// A single hop — the entities directly related to a hit — is both the most
/// valuable case and the cheapest one. Going deeper is opt-in through the
/// depth knob.
pub const DEFAULT_ENRICHMENT_DEPTH: usize = 1;
29
/// A single entity node returned by read-path graph enrichment.
///
/// Entities are untyped in v1 (`:Entity`, ticket 0005), which is why the
/// canonical `name` is the only data carried; a type field can be added later.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GraphEntity {
    /// Canonical name of the entity node.
    pub name: String,
}
39
/// A single relationship returned by read-path graph enrichment.
///
/// Represents a current (non-superseded) edge between two entities. The
/// extractor's confidence travels with it so the consumer can weigh the fact.
#[derive(Clone, Debug, PartialEq)]
pub struct GraphRelationship {
    /// Name of the entity on the subject side of the edge.
    pub subject: String,
    /// Label of the relation.
    pub relation: String,
    /// Name of the entity on the object side of the edge.
    pub object: String,
    /// How confident the extractor was in this relationship, 0.0-1.0.
    pub confidence: f32,
}
55
56/// The graph neighborhood around a search's hits.
57///
58/// A property of *one* enriched search (attached to the result wrapper, not to
59/// any [`Memory`](crate::memory::Memory)). Flat, deduplicated lists; empty when
60/// enrichment was not requested or no graph is configured. Fields are public so
61/// later additions (entity type, edge validity) are additive via struct-update.
62#[derive(Debug, Clone, Default, PartialEq)]
63pub struct GraphContext {
64    /// Distinct related entities, including the seed entities.
65    pub entities: Vec<GraphEntity>,
66    /// Distinct current relationships among the neighborhood.
67    pub relationships: Vec<GraphRelationship>,
68}
69
70impl GraphContext {
71    /// Returns whether the context holds no entities or relationships.
72    pub fn is_empty(&self) -> bool {
73        self.entities.is_empty() && self.relationships.is_empty()
74    }
75}
76
77/// Backs [`GraphStore::neighbors`]; see that method for semantics.
78pub(super) async fn neighbors<G: GraphStore + ?Sized>(
79    store: &G,
80    seed_pids: &[&str],
81    scope: &Scope,
82    depth: usize,
83) -> Result<GraphContext, GraphError> {
84    if seed_pids.is_empty() {
85        return Ok(GraphContext::default());
86    }
87    let depth = depth.clamp(1, MAX_ENRICHMENT_DEPTH);
88
89    let mut params = HashMap::from([
90        ("agent_id".to_string(), scope.agent_id.clone().into()),
91        ("org_id".to_string(), scope.org_id.clone().into()),
92        ("user_id".to_string(), scope.user_id.clone().into()),
93    ]);
94    for (i, pid) in seed_pids.iter().enumerate() {
95        params.insert(format!("pid{i}"), (*pid).into());
96    }
97    let pid_list = (0..seed_pids.len())
98        .map(|i| format!("$pid{i}"))
99        .collect::<Vec<_>>()
100        .join(", ");
101
102    // Seed = entities in scope whose memory_pids intersects the hit pids. Walk
103    // current edges (valid_to null) out to `depth` hops, returning each edge's
104    // endpoints + properties. The depth is interpolated (it is a clamped
105    // integer, never user text), the rest binds as parameters.
106    let cypher = format!(
107        "MATCH (seed:Entity {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id}}) \
108         WHERE any(p IN seed.memory_pids WHERE p IN [{pid_list}]) \
109         MATCH (seed)-[r*1..{depth}]-(related:Entity) \
110         WITH seed, related, r \
111         UNWIND r AS edge \
112         WITH seed, related, edge WHERE edge.valid_to IS NULL \
113         RETURN startNode(edge).name AS subject, edge.relation AS relation, \
114                endNode(edge).name AS object, edge.confidence AS confidence, related.name AS related_name"
115    );
116
117    let rows = store.query(&cypher, &params).await?;
118    Ok(build_context(&rows))
119}
120
121/// Assembles a deduplicated [`GraphContext`] from traversal result rows.
122fn build_context(rows: &[GraphRow]) -> GraphContext {
123    let mut entities: Vec<GraphEntity> = Vec::new();
124    let mut relationships: Vec<GraphRelationship> = Vec::new();
125
126    for row in rows {
127        if let Some(name) = column(row, "related_name") {
128            let entity = GraphEntity { name: name.to_string() };
129            if !entities.contains(&entity) {
130                entities.push(entity);
131            }
132        }
133
134        let (Some(subject), Some(relation), Some(object)) =
135            (column(row, "subject"), column(row, "relation"), column(row, "object"))
136        else {
137            continue;
138        };
139        let confidence = column(row, "confidence").and_then(|c| c.parse().ok()).unwrap_or(1.0);
140        let relationship = GraphRelationship {
141            subject: subject.to_string(),
142            relation: relation.to_string(),
143            object: object.to_string(),
144            confidence,
145        };
146        if !relationships.contains(&relationship) {
147            relationships.push(relationship);
148        }
149    }
150
151    GraphContext { entities, relationships }
152}
153
154/// Returns the value of the column named `name` in a result row.
155fn column<'a>(row: &'a GraphRow, name: &str) -> Option<&'a str> {
156    row.iter()
157        .find(|(column, _)| column == name)
158        .map(|(_, value)| value.as_str())
159}
160
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::graph::{GraphParam, GraphRows};

    /// One recorded `query` invocation: the Cypher text and its bound params.
    type RecordedCall = (String, HashMap<String, GraphParam>);

    /// Fixed scope shared by every test.
    fn scope() -> Scope {
        Scope {
            agent_id: String::from("agent"),
            org_id: String::from("org"),
            user_id: String::from("user"),
        }
    }

    /// Builds a result row from `(column, value)` string pairs.
    fn row(pairs: &[(&str, &str)]) -> GraphRow {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Test double: answers every query with the staged rows and records the
    /// (cypher, params) of each call.
    #[derive(Default)]
    struct StagedStore {
        rows: Mutex<GraphRows>,
        calls: Mutex<Vec<RecordedCall>>,
    }

    impl StagedStore {
        /// Creates a store that will return `rows` for any query.
        fn with_rows(rows: GraphRows) -> Self {
            Self {
                rows: Mutex::new(rows),
                ..Self::default()
            }
        }

        /// Snapshot of every call recorded so far.
        fn calls(&self) -> Vec<RecordedCall> {
            let recorded = self.calls.lock().unwrap();
            recorded.clone()
        }
    }

    impl GraphStore for StagedStore {
        async fn ensure_graph(&self) -> Result<(), GraphError> {
            Ok(())
        }

        async fn query(
            &self,
            cypher: &str,
            params: &HashMap<String, GraphParam>,
        ) -> Result<GraphRows, GraphError> {
            let call = (cypher.to_owned(), params.clone());
            self.calls.lock().unwrap().push(call);
            let staged = self.rows.lock().unwrap().clone();
            Ok(staged)
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_return_empty_for_no_seeds() {
        let store = StagedStore::default();

        let context = neighbors(&store, &[], &scope(), 1).await.unwrap();

        assert!(context.is_empty());
        assert!(store.calls().is_empty(), "no seeds -> no query");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_bind_seeds_and_scope_as_params() {
        let store = StagedStore::default();
        neighbors(&store, &["mem1", "mem2"], &scope(), 1).await.unwrap();

        let calls = store.calls();
        let (cypher, params) = &calls[0];
        assert!(!cypher.contains("mem1"), "pids must not be interpolated");
        for (key, expected) in [("pid0", "mem1"), ("pid1", "mem2"), ("agent_id", "agent")] {
            assert_eq!(params.get(key), Some(&GraphParam::Str(expected.to_string())));
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_filter_current_edges_only() {
        let store = StagedStore::default();
        neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();

        let calls = store.calls();
        let (cypher, _) = &calls[0];
        assert!(cypher.contains("edge.valid_to IS NULL"));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_clamp_depth_into_range() {
        let store = StagedStore::default();
        neighbors(&store, &["mem1"], &scope(), 99).await.unwrap();

        let calls = store.calls();
        let expected_range = format!("*1..{MAX_ENRICHMENT_DEPTH}");
        assert!(
            calls[0].0.contains(&expected_range),
            "depth clamps to the max",
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_build_deduped_context_from_rows() {
        let works_at_acme = || {
            row(&[
                ("subject", "Alice"),
                ("relation", "works at"),
                ("object", "Acme"),
                ("confidence", "0.9"),
                ("related_name", "Acme"),
            ])
        };
        // The same row staged twice: the relationship and the entity must each
        // appear only once in the result.
        let store = StagedStore::with_rows(vec![works_at_acme(), works_at_acme()]);

        let context = neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();

        assert_eq!(context.relationships.len(), 1);
        assert_eq!(context.relationships[0].object, "Acme");
        assert_eq!(context.entities.len(), 1);
        assert_eq!(context.entities[0].name, "Acme");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_default_confidence_when_unparseable() {
        let unparseable = row(&[
            ("subject", "Alice"),
            ("relation", "knows"),
            ("object", "Bob"),
            ("confidence", "null"),
            ("related_name", "Bob"),
        ]);
        let store = StagedStore::with_rows(vec![unparseable]);

        let context = neighbors(&store, &["mem1"], &scope(), 1).await.unwrap();

        assert_eq!(context.relationships[0].confidence, 1.0);
    }
}