1use std::collections::HashMap;
19
20use chrono::{DateTime, FixedOffset};
21
22use crate::embedding::{EmbeddingError, EmbeddingModel};
23use crate::memory::Scope;
24
25use super::{
26 Edge, EdgeResolver, EntityResolver, ExistingEdge, GraphError, GraphParam, GraphStore, Resolution, ResolveError,
27 Triple, TripleSet,
28};
29
/// Node label interpolated into every Cypher pattern in this module that
/// merges or matches an entity node.
const ENTITY_LABEL: &str = "Entity";

/// Relationship label returned by `sanitize_relation_label` when a relation
/// contains no ASCII alphanumeric characters at all.
const FALLBACK_RELATION_LABEL: &str = "RELATED_TO";
39
40#[derive(Debug, Clone)]
48pub struct CommitContext {
49 pub scope: Scope,
51 pub memory_pid: String,
53 pub valid_from: DateTime<FixedOffset>,
55}
56
57#[derive(Debug, thiserror::Error)]
59pub enum CommitError {
60 #[error("entity resolution failed: {0}")]
62 EntityResolution(#[from] ResolveError),
63
64 #[error("edge resolution failed: {0}")]
66 EdgeResolution(#[from] super::EdgeError),
67
68 #[error("node embedding failed: {0}")]
70 Embed(#[from] EmbeddingError),
71
72 #[error("graph write failed: {0}")]
74 Write(#[from] GraphError),
75}
76
77pub(super) async fn commit_triples<G, EM, ER, EdgeR>(
91 store: &G,
92 embedder: &EM,
93 entities: &ER,
94 edges: &EdgeR,
95 ctx: &CommitContext,
96 triples: &TripleSet,
97) -> Result<usize, CommitError>
98where
99 G: GraphStore + ?Sized,
100 EM: EmbeddingModel,
101 ER: EntityResolver,
102 EdgeR: EdgeResolver,
103{
104 let mut committed = 0;
105 for triple in triples.iter() {
106 if commit_one(store, embedder, entities, edges, ctx, triple).await? {
107 committed += 1;
108 }
109 }
110 Ok(committed)
111}
112
113async fn commit_one<G, EM, ER, EdgeR>(
115 store: &G,
116 embedder: &EM,
117 entities: &ER,
118 edges: &EdgeR,
119 ctx: &CommitContext,
120 triple: &Triple,
121) -> Result<bool, CommitError>
122where
123 G: GraphStore + ?Sized,
124 EM: EmbeddingModel,
125 ER: EntityResolver,
126 EdgeR: EdgeResolver,
127{
128 if triple.subject.trim().is_empty() || triple.object.trim().is_empty() {
132 return Ok(false);
133 }
134
135 let subject = entities.resolve(&ctx.scope, &triple.subject).await?;
136 let object = entities.resolve(&ctx.scope, &triple.object).await?;
137
138 let subject_key = resolution_key(&subject);
139 let object_key = resolution_key(&object);
140 if subject_key == object_key {
141 return Ok(false);
142 }
143
144 upsert_node(store, embedder, ctx, &subject).await?;
145 upsert_node(store, embedder, ctx, &object).await?;
146
147 let edge = Edge {
148 subject_key: subject_key.clone(),
149 relation: triple.relation.clone(),
150 object_key: object_key.clone(),
151 confidence: triple.confidence,
152 valid_from: ctx.valid_from,
153 };
154 let resolution = edges.resolve(&ctx.scope, edge).await?;
155
156 for closed in &resolution.close {
157 close_edge(store, ctx, closed).await?;
158 }
159 upsert_edge(store, ctx, &resolution.open).await?;
160
161 Ok(true)
162}
163
164fn resolution_key(resolution: &Resolution) -> String {
170 match resolution {
171 Resolution::Existing { name, .. } | Resolution::New { name } => name.clone(),
172 }
173}
174
175async fn upsert_node<G: GraphStore + ?Sized, EM: EmbeddingModel>(
182 store: &G,
183 embedder: &EM,
184 ctx: &CommitContext,
185 resolution: &Resolution,
186) -> Result<(), CommitError> {
187 let name = resolution_key(resolution);
188 let embedding = embedder.embed(&name).await?;
189 let embedding_json = serde_json::to_string(&embedding).expect("serializing Vec<f32> to JSON cannot fail");
190
191 let cypher = format!(
192 "MERGE (e:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $name}}) \
193 ON CREATE SET e.first_seen_at = $now, e.embedding = $embedding, e.memory_pids = [$pid] \
194 ON MATCH SET e.memory_pids = \
195 CASE WHEN $pid IN e.memory_pids THEN e.memory_pids ELSE e.memory_pids + $pid END"
196 );
197
198 let mut params = scope_params(&ctx.scope);
199 params.insert("name".to_string(), name.into());
200 params.insert("pid".to_string(), ctx.memory_pid.clone().into());
201 params.insert("now".to_string(), ctx.valid_from.to_rfc3339().into());
202 params.insert("embedding".to_string(), embedding_json.into());
203
204 store.query(&cypher, ¶ms).await?;
205 Ok(())
206}
207
208async fn upsert_edge<G: GraphStore + ?Sized>(store: &G, ctx: &CommitContext, edge: &Edge) -> Result<(), CommitError> {
217 let label = sanitize_relation_label(&edge.relation);
218 let cypher = format!(
219 "MATCH (s:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $subject}}) \
220 MATCH (o:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $object}}) \
221 MERGE (s)-[r:{label} {{valid_from: $valid_from}}]->(o) \
222 ON CREATE SET r.confidence = $confidence, \
223 r.relation = $relation, r.memory_pids = [$pid] \
224 ON MATCH SET r.memory_pids = \
225 CASE WHEN $pid IN r.memory_pids THEN r.memory_pids ELSE r.memory_pids + $pid END"
226 );
227
228 let mut params = scope_params(&ctx.scope);
229 params.insert("subject".to_string(), edge.subject_key.clone().into());
230 params.insert("object".to_string(), edge.object_key.clone().into());
231 params.insert("relation".to_string(), edge.relation.clone().into());
232 params.insert("valid_from".to_string(), edge.valid_from.to_rfc3339().into());
233 params.insert("confidence".to_string(), GraphParam::Float(edge.confidence.into()));
234 params.insert("pid".to_string(), ctx.memory_pid.clone().into());
235
236 store.query(&cypher, ¶ms).await?;
237 Ok(())
238}
239
240async fn close_edge<G: GraphStore + ?Sized>(
247 store: &G,
248 ctx: &CommitContext,
249 target: &ExistingEdge,
250) -> Result<(), CommitError> {
251 let cypher = format!(
252 "MATCH (s:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $subject}}) \
253 -[r]->(o:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $object}}) \
254 WHERE r.relation = $relation AND r.valid_from = $valid_from AND r.valid_to IS NULL \
255 SET r.valid_to = $valid_to"
256 );
257
258 let mut params = scope_params(&ctx.scope);
259 params.insert("subject".to_string(), target.subject_key.clone().into());
260 params.insert("object".to_string(), target.object_key.clone().into());
261 params.insert("relation".to_string(), target.relation.clone().into());
262 params.insert("valid_from".to_string(), target.valid_from.to_rfc3339().into());
263 params.insert("valid_to".to_string(), ctx.valid_from.to_rfc3339().into());
264
265 store.query(&cypher, ¶ms).await?;
266 Ok(())
267}
268
269fn scope_params(scope: &Scope) -> HashMap<String, GraphParam> {
271 HashMap::from([
272 ("agent_id".to_string(), scope.agent_id.clone().into()),
273 ("org_id".to_string(), scope.org_id.clone().into()),
274 ("user_id".to_string(), scope.user_id.clone().into()),
275 ])
276}
277
278fn sanitize_relation_label(relation: &str) -> String {
287 let mut label = String::with_capacity(relation.len());
288 let mut prev_underscore = false;
289 for ch in relation.chars() {
290 if ch.is_ascii_alphanumeric() {
291 label.extend(ch.to_uppercase());
292 prev_underscore = false;
293 } else if !prev_underscore && !label.is_empty() {
294 label.push('_');
295 prev_underscore = true;
296 }
297 }
298 let trimmed = label.trim_end_matches('_');
299 if trimmed.is_empty() {
300 FALLBACK_RELATION_LABEL.to_string()
301 } else {
302 trimmed.to_string()
303 }
304}
305
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::graph::{
        EntityVector, ExactStringResolver, GraphRows, InMemoryEntityCatalog, NaiveAppendResolver,
    };

    /// Embedding model that returns the same fixed three-element vector for
    /// every input, so no real model is needed.
    struct StubEmbedding;

    impl EmbeddingModel for StubEmbedding {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, EmbeddingError> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn dimensions(&self) -> usize {
            3
        }
    }

    /// Scope shared by all tests.
    fn scope() -> Scope {
        Scope {
            agent_id: "agent".to_string(),
            org_id: "org".to_string(),
            user_id: "user".to_string(),
        }
    }

    /// Fixed timestamp so the tests are deterministic.
    fn now() -> DateTime<FixedOffset> {
        DateTime::parse_from_rfc3339("2026-06-06T00:00:00Z").expect("valid date")
    }

    /// Commit context built from the shared scope, a fixed pid and `now()`.
    fn ctx() -> CommitContext {
        CommitContext {
            scope: scope(),
            memory_pid: "mem1".to_string(),
            valid_from: now(),
        }
    }

    /// Graph store that records every query and its parameters instead of
    /// executing anything, and always answers with empty rows.
    #[derive(Default)]
    struct RecordingStore {
        calls: Mutex<Vec<(String, HashMap<String, GraphParam>)>>,
    }

    impl RecordingStore {
        /// Snapshot of all `(cypher, params)` pairs recorded so far, in call order.
        fn calls(&self) -> Vec<(String, HashMap<String, GraphParam>)> {
            self.calls.lock().expect("recording store poisoned").clone()
        }
    }

    impl GraphStore for RecordingStore {
        async fn ensure_graph(&self) -> Result<(), GraphError> {
            Ok(())
        }

        async fn query(
            &self,
            cypher: &str,
            params: &HashMap<String, GraphParam>,
        ) -> Result<GraphRows, GraphError> {
            self.calls
                .lock()
                .expect("recording store poisoned")
                .push((cypher.to_string(), params.clone()));
            Ok(GraphRows::new())
        }
    }

    /// Builds a `TripleSet` holding exactly one triple with confidence 0.9.
    fn one_triple(subject: &str, relation: &str, object: &str) -> TripleSet {
        serde_json::from_value(serde_json::json!({
            "triples": [{ "subject": subject, "relation": relation, "object": object, "confidence": 0.9 }]
        }))
        .expect("valid triple json")
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_commit_two_nodes_and_one_edge() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();

        let committed = commit_triples(
            &store,
            &StubEmbedding,
            &entities,
            &edges,
            &ctx(),
            &one_triple("Alice", "works at", "Acme"),
        )
        .await
        .unwrap();

        assert_eq!(committed, 1);
        let calls = store.calls();
        assert_eq!(calls.len(), 3);
        // Both endpoint nodes are merged before the edge is written.
        assert!(calls[0].0.contains("MERGE (e:Entity"));
        assert!(calls[1].0.contains("MERGE (e:Entity"));
        assert!(calls[2].0.contains(":WORKS_AT"));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_bind_user_values_as_params_not_interpolate() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();

        let injection = r#"Acme"}) DETACH DELETE n //"#;
        commit_triples(
            &store,
            &StubEmbedding,
            &entities,
            &edges,
            &ctx(),
            &one_triple("Alice", "works at", injection),
        )
        .await
        .unwrap();

        let calls = store.calls();
        for (cypher, _) in &calls {
            assert!(!cypher.contains("DETACH DELETE"), "user value leaked into query string");
        }
        let injected = GraphParam::Str(injection.to_string());
        assert!(
            calls.iter().any(|(_, params)| params.values().any(|v| *v == injected)),
            "the injection value must ride as a bound param somewhere",
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_tag_every_write_with_scope_and_pid() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();

        commit_triples(
            &store,
            &StubEmbedding,
            &entities,
            &edges,
            &ctx(),
            &one_triple("Alice", "knows", "Bob"),
        )
        .await
        .unwrap();

        let calls = store.calls();
        assert!(!calls.is_empty(), "the commit must issue at least one write");
        for (_, params) in calls {
            // All three scope ids must be bound, not just the agent id.
            assert_eq!(params.get("agent_id"), Some(&GraphParam::Str("agent".to_string())));
            assert_eq!(params.get("org_id"), Some(&GraphParam::Str("org".to_string())));
            assert_eq!(params.get("user_id"), Some(&GraphParam::Str("user".to_string())));
            assert_eq!(params.get("pid"), Some(&GraphParam::Str("mem1".to_string())));
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_skip_self_loop_triple() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();

        let committed = commit_triples(
            &store,
            &StubEmbedding,
            &entities,
            &edges,
            &ctx(),
            &one_triple("Alice", "is", "Alice"),
        )
        .await
        .unwrap();

        assert_eq!(committed, 0);
        assert!(store.calls().is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_skip_triple_with_blank_entity() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();

        let committed = commit_triples(
            &store,
            &StubEmbedding,
            &entities,
            &edges,
            &ctx(),
            &one_triple("Alice", "works at", " "),
        )
        .await
        .unwrap();

        assert_eq!(committed, 0);
        assert!(store.calls().is_empty(), "blank entity must write nothing");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_merge_to_existing_node_when_entity_resolves() {
        let catalog = InMemoryEntityCatalog::new();
        catalog.insert(
            &scope(),
            EntityVector {
                key: "Alice".to_string(),
                name: "Alice".to_string(),
                embedding: vec![1.0, 0.0],
            },
        );
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(catalog);
        let edges = NaiveAppendResolver::new();

        commit_triples(
            &store,
            &StubEmbedding,
            &entities,
            &edges,
            &ctx(),
            &one_triple("Alice", "likes", "Tea"),
        )
        .await
        .unwrap();

        // The already-catalogued subject must be merged under its name exactly
        // once: present, and not duplicated.
        let alice = GraphParam::Str("Alice".to_string());
        let alice_merges = store
            .calls()
            .into_iter()
            .filter(|(c, p)| c.contains("MERGE (e:Entity") && p.get("name") == Some(&alice))
            .count();
        assert_eq!(alice_merges, 1, "subject node merged exactly once");
    }

    #[test]
    fn should_sanitize_relation_into_safe_label() {
        assert_eq!(sanitize_relation_label("works at"), "WORKS_AT");
        assert_eq!(sanitize_relation_label("lives-in"), "LIVES_IN");
        assert_eq!(sanitize_relation_label(" prefers "), "PREFERS");
    }

    #[test]
    fn should_collapse_punctuation_runs_in_label() {
        assert_eq!(sanitize_relation_label("blocked//by"), "BLOCKED_BY");
        assert_eq!(sanitize_relation_label("a & b"), "A_B");
    }

    #[test]
    fn should_fall_back_when_relation_has_no_alphanumerics() {
        assert_eq!(sanitize_relation_label("!!!"), FALLBACK_RELATION_LABEL);
        assert_eq!(sanitize_relation_label(""), FALLBACK_RELATION_LABEL);
    }

    #[test]
    fn should_not_let_injection_survive_label_sanitization() {
        let label = sanitize_relation_label(r#"FOO]->() DETACH DELETE n //"#);
        assert!(!label.contains(']'));
        assert!(!label.contains(' '));
        assert!(!label.contains('-'));
    }
}