use std::collections::HashMap;
use chrono::{DateTime, FixedOffset};
use crate::embedding::{EmbeddingError, EmbeddingModel};
use crate::memory::Scope;
use super::{
Edge, EdgeResolver, EntityResolver, ExistingEdge, GraphError, GraphParam, GraphStore, Resolution, ResolveError,
Triple, TripleSet,
};
/// Node label given to every entity vertex this module merges or matches.
const ENTITY_LABEL: &str = "Entity";
/// Relationship type used when a relation string sanitizes to an empty label
/// (i.e. it contains no ASCII letters or digits).
const FALLBACK_RELATION_LABEL: &str = "RELATED_TO";
/// Per-commit metadata applied to every node and edge written by `commit_triples`.
#[derive(Debug, Clone)]
pub struct CommitContext {
/// Tenant scope; its `agent_id`, `org_id` and `user_id` are bound as query
/// parameters on every write and are part of every node match key.
pub scope: Scope,
/// Identifier of the source memory; added (without duplicates) to the
/// `memory_pids` list of every node and edge the commit touches.
pub memory_pid: String,
/// Commit timestamp. It becomes `first_seen_at` on newly created nodes, seeds the
/// candidate edge's `valid_from`, and is written as `valid_to` on closed edges.
pub valid_from: DateTime<FixedOffset>,
}
/// Failure raised while committing triples; each variant wraps the error from one
/// stage of the pipeline (entity resolution, edge resolution, embedding, write).
#[derive(Debug, thiserror::Error)]
pub enum CommitError {
/// The entity resolver failed to resolve a subject or object string.
#[error("entity resolution failed: {0}")]
EntityResolution(#[from] ResolveError),
/// The edge resolver failed to decide which edges to open or close.
#[error("edge resolution failed: {0}")]
EdgeResolution(#[from] super::EdgeError),
/// Embedding an entity name failed.
#[error("node embedding failed: {0}")]
Embed(#[from] EmbeddingError),
/// A Cypher query against the graph store failed.
#[error("graph write failed: {0}")]
Write(#[from] GraphError),
}
/// Commits every triple in `triples` to the graph, strictly one after another.
///
/// Triples are handled in iteration order. A triple that `commit_one` skips
/// (blank endpoint or self-loop) does not count toward the result. Nothing here
/// rolls back, so when an error aborts the loop, writes from earlier triples remain.
///
/// # Returns
/// The number of triples that resulted in an edge write.
///
/// # Errors
/// Propagates the first [`CommitError`] raised while committing a triple.
pub(super) async fn commit_triples<G, EM, ER, EdgeR>(
    store: &G,
    embedder: &EM,
    entities: &ER,
    edges: &EdgeR,
    ctx: &CommitContext,
    triples: &TripleSet,
) -> Result<usize, CommitError>
where
    G: GraphStore + ?Sized,
    EM: EmbeddingModel,
    ER: EntityResolver,
    EdgeR: EdgeResolver,
{
    let mut written: usize = 0;
    for triple in triples.iter() {
        let accepted = commit_one(store, embedder, entities, edges, ctx, triple).await?;
        // A committed triple contributes one, a skipped triple zero.
        written += usize::from(accepted);
    }
    Ok(written)
}
/// Resolves, writes and links the two entities of a single triple.
///
/// Returns `Ok(false)` without writing to the graph store when either endpoint is
/// blank after trimming, or when both endpoints resolve to the same key. Otherwise
/// it does the following and returns `Ok(true)`:
/// - upserts the subject node, then the object node;
/// - asks the edge resolver what to do with the candidate edge;
/// - closes every edge the resolver lists in `close`;
/// - upserts the edge it returns as `open`.
///
/// # Errors
/// Any resolver, embedding or store failure is returned as a [`CommitError`].
async fn commit_one<G, EM, ER, EdgeR>(
    store: &G,
    embedder: &EM,
    entities: &ER,
    edges: &EdgeR,
    ctx: &CommitContext,
    triple: &Triple,
) -> Result<bool, CommitError>
where
    G: GraphStore + ?Sized,
    EM: EmbeddingModel,
    ER: EntityResolver,
    EdgeR: EdgeResolver,
{
    let has_blank_endpoint = [&triple.subject, &triple.object]
        .into_iter()
        .any(|endpoint| endpoint.trim().is_empty());
    if has_blank_endpoint {
        return Ok(false);
    }

    let subject = entities.resolve(&ctx.scope, &triple.subject).await?;
    let object = entities.resolve(&ctx.scope, &triple.object).await?;
    let (subject_key, object_key) = (resolution_key(&subject), resolution_key(&object));
    // Both sides resolved to the same entity: a self-loop carries no information.
    if subject_key == object_key {
        return Ok(false);
    }

    upsert_node(store, embedder, ctx, &subject).await?;
    upsert_node(store, embedder, ctx, &object).await?;

    // The keys are not needed after this point, so they move into the edge.
    let candidate = Edge {
        subject_key,
        relation: triple.relation.clone(),
        object_key,
        confidence: triple.confidence,
        valid_from: ctx.valid_from,
    };
    let plan = edges.resolve(&ctx.scope, candidate).await?;

    for stale in &plan.close {
        close_edge(store, ctx, stale).await?;
    }
    upsert_edge(store, ctx, &plan.open).await?;
    Ok(true)
}
/// Returns the entity name a resolution points at; this name is used as the
/// node's match key in the graph.
fn resolution_key(resolution: &Resolution) -> String {
    // Both variants carry a `name`, so a single irrefutable pattern covers them.
    let (Resolution::Existing { name, .. } | Resolution::New { name }) = resolution;
    name.clone()
}
/// Creates or updates the entity node for `resolution` within `ctx.scope`.
///
/// The node is matched on `(agent_id, org_id, user_id, name)`.
/// - On creation it gets `first_seen_at = ctx.valid_from`, the JSON-encoded
///   embedding of its name, and `memory_pids = [ctx.memory_pid]`.
/// - On match, `ctx.memory_pid` is appended to `memory_pids` only if not already
///   present.
///
/// Note: the name is embedded on every call, even though the embedding is only
/// written when the node is newly created.
///
/// # Errors
/// Returns [`CommitError::Embed`] if embedding fails, and [`CommitError::Write`]
/// if the graph query fails.
async fn upsert_node<G: GraphStore + ?Sized, EM: EmbeddingModel>(
    store: &G,
    embedder: &EM,
    ctx: &CommitContext,
    resolution: &Resolution,
) -> Result<(), CommitError> {
    let name = resolution_key(resolution);
    let embedding = embedder.embed(&name).await?;
    // The embedding is bound as a JSON string parameter rather than a list value.
    let embedding_json = serde_json::to_string(&embedding).expect("serializing Vec<f32> to JSON cannot fail");
    // Only the constant label is interpolated; every user-derived value is a bound parameter.
    let cypher = format!(
        "MERGE (e:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $name}}) \
         ON CREATE SET e.first_seen_at = $now, e.embedding = $embedding, e.memory_pids = [$pid] \
         ON MATCH SET e.memory_pids = \
         CASE WHEN $pid IN e.memory_pids THEN e.memory_pids ELSE e.memory_pids + $pid END"
    );
    let mut params = scope_params(&ctx.scope);
    params.insert("name".to_string(), name.into());
    params.insert("pid".to_string(), ctx.memory_pid.clone().into());
    params.insert("now".to_string(), ctx.valid_from.to_rfc3339().into());
    params.insert("embedding".to_string(), embedding_json.into());
    store.query(&cypher, &params).await?;
    Ok(())
}
/// Creates or updates the relationship described by `edge` between two existing
/// entity nodes in `ctx.scope`.
///
/// The relationship type is `sanitize_relation_label(&edge.relation)`. The edge is
/// keyed by its endpoints, that type, the exact `relation` text, and `valid_from`.
/// `relation` is part of the MERGE key because `close_edge` finds edges by that
/// property. Without it, two relations that sanitize to the same type (e.g.
/// "works at" and "works-at") would collapse into one edge carrying the first
/// relation text, and a later close by the other text would silently miss.
/// - On creation the edge gets `confidence` and `memory_pids = [ctx.memory_pid]`.
/// - On match the pid is appended if not already present.
///
/// Both endpoint nodes are MATCHed, so if either is missing nothing is written.
///
/// # Errors
/// Returns [`CommitError::Write`] if the graph query fails.
async fn upsert_edge<G: GraphStore + ?Sized>(store: &G, ctx: &CommitContext, edge: &Edge) -> Result<(), CommitError> {
    // Relationship types are interpolated with `format!`, so only the sanitized label is spliced in.
    let label = sanitize_relation_label(&edge.relation);
    let cypher = format!(
        "MATCH (s:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $subject}}) \
         MATCH (o:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $object}}) \
         MERGE (s)-[r:{label} {{relation: $relation, valid_from: $valid_from}}]->(o) \
         ON CREATE SET r.confidence = $confidence, r.memory_pids = [$pid] \
         ON MATCH SET r.memory_pids = \
         CASE WHEN $pid IN r.memory_pids THEN r.memory_pids ELSE r.memory_pids + $pid END"
    );
    let mut params = scope_params(&ctx.scope);
    params.insert("subject".to_string(), edge.subject_key.clone().into());
    params.insert("object".to_string(), edge.object_key.clone().into());
    params.insert("relation".to_string(), edge.relation.clone().into());
    params.insert("valid_from".to_string(), edge.valid_from.to_rfc3339().into());
    params.insert("confidence".to_string(), GraphParam::Float(edge.confidence.into()));
    params.insert("pid".to_string(), ctx.memory_pid.clone().into());
    store.query(&cypher, &params).await?;
    Ok(())
}
/// Ends the validity of one currently-open edge identified by `target`.
///
/// The edge is matched by three things: its endpoint names within `ctx.scope`, its
/// `relation` property, and its `valid_from` timestamp. Only edges whose
/// `valid_to` is still null are touched, and `valid_to` is set to
/// `ctx.valid_from`. If no edge matches, the query writes nothing.
///
/// # Errors
/// Returns [`CommitError::Write`] if the graph query fails.
async fn close_edge<G: GraphStore + ?Sized>(
    store: &G,
    ctx: &CommitContext,
    target: &ExistingEdge,
) -> Result<(), CommitError> {
    let cypher = format!(
        "MATCH (s:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $subject}}) \
         -[r]->(o:{ENTITY_LABEL} {{agent_id: $agent_id, org_id: $org_id, user_id: $user_id, name: $object}}) \
         WHERE r.relation = $relation AND r.valid_from = $valid_from AND r.valid_to IS NULL \
         SET r.valid_to = $valid_to"
    );
    let mut params = scope_params(&ctx.scope);
    params.insert("subject".to_string(), target.subject_key.clone().into());
    params.insert("object".to_string(), target.object_key.clone().into());
    params.insert("relation".to_string(), target.relation.clone().into());
    params.insert("valid_from".to_string(), target.valid_from.to_rfc3339().into());
    // The new commit's timestamp marks when the old edge stopped being valid.
    params.insert("valid_to".to_string(), ctx.valid_from.to_rfc3339().into());
    store.query(&cypher, &params).await?;
    Ok(())
}
/// Builds the base parameter map for a query: the three tenant identifiers of
/// `scope`, bound as `agent_id`, `org_id` and `user_id`.
fn scope_params(scope: &Scope) -> HashMap<String, GraphParam> {
    let mut params: HashMap<String, GraphParam> = HashMap::with_capacity(3);
    params.insert("agent_id".to_string(), scope.agent_id.clone().into());
    params.insert("org_id".to_string(), scope.org_id.clone().into());
    params.insert("user_id".to_string(), scope.user_id.clone().into());
    params
}
/// Converts a free-text relation into a relationship type that is safe to splice
/// into a Cypher query string.
///
/// The label is interpolated with `format!` rather than bound as a parameter, so
/// this whitelist is what keeps user text out of the query syntax:
/// - ASCII letters and digits are upper-cased and kept.
/// - Every run of other characters (spaces, punctuation, non-ASCII text) between
///   kept characters becomes a single `_`.
/// - Such runs at either end are dropped.
///
/// Unquoted Cypher identifiers may not begin with a digit. A label that would start
/// with one (e.g. from `"2nd cousin of"`) is therefore prefixed with `REL_`, so the
/// query stays syntactically valid. Input with no ASCII alphanumerics at all maps to
/// `FALLBACK_RELATION_LABEL`.
fn sanitize_relation_label(relation: &str) -> String {
    // Letter-led prefix for labels whose first kept character is a digit.
    const DIGIT_LEAD_PREFIX: &str = "REL_";

    let mut label = String::with_capacity(relation.len() + DIGIT_LEAD_PREFIX.len());
    // Set after a separator run. The `_` is only written once another kept
    // character follows, so the label can never end with `_`.
    let mut separator_pending = false;
    for ch in relation.chars() {
        if ch.is_ascii_alphanumeric() {
            if separator_pending {
                label.push('_');
                separator_pending = false;
            }
            label.push(ch.to_ascii_uppercase());
        } else if !label.is_empty() {
            separator_pending = true;
        }
    }

    if label.is_empty() {
        return FALLBACK_RELATION_LABEL.to_string();
    }
    if label.starts_with(|c: char| c.is_ascii_digit()) {
        label.insert_str(0, DIGIT_LEAD_PREFIX);
    }
    label
}
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use crate::graph::{
        EntityVector, ExactStringResolver, GraphRows, InMemoryEntityCatalog, NaiveAppendResolver,
    };

    /// Embedding model that returns the same 3-dimensional vector for any text.
    struct StubEmbedding;

    impl EmbeddingModel for StubEmbedding {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>, EmbeddingError> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn dimensions(&self) -> usize {
            3
        }
    }

    fn scope() -> Scope {
        Scope {
            agent_id: "agent".to_string(),
            org_id: "org".to_string(),
            user_id: "user".to_string(),
        }
    }

    fn now() -> DateTime<FixedOffset> {
        DateTime::parse_from_rfc3339("2026-06-06T00:00:00Z").expect("valid date")
    }

    fn ctx() -> CommitContext {
        CommitContext {
            scope: scope(),
            memory_pid: "mem1".to_string(),
            valid_from: now(),
        }
    }

    /// Graph store that records every query and its parameters instead of executing it.
    #[derive(Default)]
    struct RecordingStore {
        calls: Mutex<Vec<(String, HashMap<String, GraphParam>)>>,
    }

    impl RecordingStore {
        fn calls(&self) -> Vec<(String, HashMap<String, GraphParam>)> {
            self.calls.lock().expect("recording store poisoned").clone()
        }
    }

    impl GraphStore for RecordingStore {
        async fn ensure_graph(&self) -> Result<(), GraphError> {
            Ok(())
        }

        async fn query(&self, cypher: &str, params: &HashMap<String, GraphParam>) -> Result<GraphRows, GraphError> {
            self.calls
                .lock()
                .expect("recording store poisoned")
                .push((cypher.to_string(), params.clone()));
            Ok(GraphRows::new())
        }
    }

    /// Builds a `TripleSet` holding a single triple with confidence 0.9.
    fn one_triple(subject: &str, relation: &str, object: &str) -> TripleSet {
        serde_json::from_value(serde_json::json!({
            "triples": [{ "subject": subject, "relation": relation, "object": object, "confidence": 0.9 }]
        }))
        .expect("valid triple json")
    }

    /// Wraps a string in the `GraphParam` variant used for bound string values.
    fn str_param(value: &str) -> GraphParam {
        GraphParam::Str(value.to_string())
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_commit_two_nodes_and_one_edge() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();
        let committed = commit_triples(&store, &StubEmbedding, &entities, &edges, &ctx(), &one_triple("Alice", "works at", "Acme"))
            .await
            .unwrap();
        assert_eq!(committed, 1);
        let calls = store.calls();
        assert_eq!(calls.len(), 3);
        // Subject node first, then object node, then the edge.
        assert!(calls[0].0.contains("MERGE (e:Entity"));
        assert_eq!(calls[0].1.get("name"), Some(&str_param("Alice")));
        assert!(calls[1].0.contains("MERGE (e:Entity"));
        assert_eq!(calls[1].1.get("name"), Some(&str_param("Acme")));
        assert!(calls[2].0.contains(":WORKS_AT"));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_bind_user_values_as_params_not_interpolate() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();
        let injection = r#"Acme"}) DETACH DELETE n //"#;
        commit_triples(&store, &StubEmbedding, &entities, &edges, &ctx(), &one_triple("Alice", "works at", injection))
            .await
            .unwrap();
        let calls = store.calls();
        for (cypher, _) in &calls {
            assert!(!cypher.contains("DETACH DELETE"), "user value leaked into query string");
        }
        let injected = GraphParam::Str(injection.to_string());
        assert!(
            calls.iter().any(|(_, params)| params.values().any(|v| *v == injected)),
            "the injection value must ride as a bound param somewhere",
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_tag_every_write_with_scope_and_pid() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();
        commit_triples(&store, &StubEmbedding, &entities, &edges, &ctx(), &one_triple("Alice", "knows", "Bob"))
            .await
            .unwrap();
        let calls = store.calls();
        // Guard against the loop below passing vacuously.
        assert!(!calls.is_empty(), "expected at least one write");
        for (_, params) in calls {
            assert_eq!(params.get("agent_id"), Some(&str_param("agent")));
            assert_eq!(params.get("org_id"), Some(&str_param("org")));
            assert_eq!(params.get("user_id"), Some(&str_param("user")));
            assert_eq!(params.get("pid"), Some(&str_param("mem1")));
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_skip_self_loop_triple() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();
        let committed = commit_triples(&store, &StubEmbedding, &entities, &edges, &ctx(), &one_triple("Alice", "is", "Alice"))
            .await
            .unwrap();
        assert_eq!(committed, 0);
        assert!(store.calls().is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_skip_triple_with_blank_entity() {
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(InMemoryEntityCatalog::new());
        let edges = NaiveAppendResolver::new();
        let committed = commit_triples(&store, &StubEmbedding, &entities, &edges, &ctx(), &one_triple("Alice", "works at", " "))
            .await
            .unwrap();
        assert_eq!(committed, 0);
        assert!(store.calls().is_empty(), "blank entity must write nothing");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_merge_to_existing_node_when_entity_resolves() {
        let catalog = InMemoryEntityCatalog::new();
        catalog.insert(
            &scope(),
            EntityVector {
                key: "Alice".to_string(),
                name: "Alice".to_string(),
                embedding: vec![1.0, 0.0],
            },
        );
        let store = RecordingStore::default();
        let entities = ExactStringResolver::new(catalog);
        let edges = NaiveAppendResolver::new();
        commit_triples(&store, &StubEmbedding, &entities, &edges, &ctx(), &one_triple("Alice", "likes", "Tea"))
            .await
            .unwrap();
        let alice = str_param("Alice");
        let calls = store.calls();
        // The first write is the subject node; it must target the catalogued name.
        let (first_cypher, first_params) = calls.first().expect("subject node merged");
        assert!(first_cypher.contains("MERGE (e:Entity"), "first write must merge the subject node");
        assert_eq!(first_params.get("name"), Some(&alice));
        let alice_merges = calls
            .iter()
            .filter(|(c, p)| c.contains("MERGE (e:Entity") && p.get("name") == Some(&alice))
            .count();
        assert_eq!(alice_merges, 1, "resolved entity must be merged exactly once");
    }

    #[test]
    fn should_sanitize_relation_into_safe_label() {
        assert_eq!(sanitize_relation_label("works at"), "WORKS_AT");
        assert_eq!(sanitize_relation_label("lives-in"), "LIVES_IN");
        assert_eq!(sanitize_relation_label(" prefers "), "PREFERS");
    }

    #[test]
    fn should_collapse_punctuation_runs_in_label() {
        assert_eq!(sanitize_relation_label("blocked//by"), "BLOCKED_BY");
        assert_eq!(sanitize_relation_label("a & b"), "A_B");
    }

    #[test]
    fn should_fall_back_when_relation_has_no_alphanumerics() {
        assert_eq!(sanitize_relation_label("!!!"), FALLBACK_RELATION_LABEL);
        assert_eq!(sanitize_relation_label(""), FALLBACK_RELATION_LABEL);
    }

    #[test]
    fn should_not_let_injection_survive_label_sanitization() {
        let label = sanitize_relation_label(r#"FOO]->() DETACH DELETE n //"#);
        assert!(!label.contains(']'));
        assert!(!label.contains(' '));
        assert!(!label.contains('-'));
    }
}