use std::collections::HashMap;
use crate::memory::Scope;
use super::{GraphError, GraphParam, GraphStore};
pub(super) async fn forget_pids<G: GraphStore + ?Sized>(store: &G, pids: &[&str]) -> Result<(), GraphError> {
for pid in pids {
forget_one_pid(store, pid).await?;
}
Ok(())
}
async fn forget_one_pid<G: GraphStore + ?Sized>(store: &G, pid: &str) -> Result<(), GraphError> {
let params = HashMap::from([("pid".to_string(), GraphParam::from(pid))]);
let edge_cypher = "MATCH ()-[r]->() WHERE $pid IN r.memory_pids \
SET r.memory_pids = [p IN r.memory_pids WHERE p <> $pid] \
WITH r WHERE size(r.memory_pids) = 0 \
DELETE r";
store.query(edge_cypher, ¶ms).await?;
let node_cypher = "MATCH (n:Entity) WHERE $pid IN n.memory_pids \
SET n.memory_pids = [p IN n.memory_pids WHERE p <> $pid] \
WITH n WHERE size(n.memory_pids) = 0 AND NOT (n)--() \
DELETE n";
store.query(node_cypher, ¶ms).await?;
Ok(())
}
pub(super) async fn forget_scope<G: GraphStore + ?Sized>(store: &G, scope: &Scope) -> Result<(), GraphError> {
let cypher = "MATCH (n:Entity {agent_id: $agent_id, org_id: $org_id, user_id: $user_id}) DETACH DELETE n";
store.query(cypher, &scope_params(scope)).await?;
Ok(())
}
fn scope_params(scope: &Scope) -> HashMap<String, GraphParam> {
HashMap::from([
("agent_id".to_string(), scope.agent_id.clone().into()),
("org_id".to_string(), scope.org_id.clone().into()),
("user_id".to_string(), scope.user_id.clone().into()),
])
}
#[cfg(test)]
mod tests {
use std::sync::Mutex;
use super::*;
use crate::graph::{GraphRows, GraphStore};
fn scope() -> Scope {
Scope {
agent_id: "agent".to_string(),
org_id: "org".to_string(),
user_id: "user".to_string(),
}
}
#[derive(Default)]
struct RecordingStore {
calls: Mutex<Vec<(String, HashMap<String, GraphParam>)>>,
}
impl RecordingStore {
fn calls(&self) -> Vec<(String, HashMap<String, GraphParam>)> {
self.calls.lock().expect("recording store poisoned").clone()
}
}
impl GraphStore for RecordingStore {
async fn ensure_graph(&self) -> Result<(), GraphError> {
Ok(())
}
async fn query(&self, cypher: &str, params: &HashMap<String, GraphParam>) -> Result<GraphRows, GraphError> {
self.calls
.lock()
.expect("recording store poisoned")
.push((cypher.to_string(), params.clone()));
Ok(GraphRows::new())
}
}
#[tokio::test(flavor = "current_thread")]
async fn should_issue_edge_then_node_statements_per_pid() {
let store = RecordingStore::default();
forget_pids(&store, &["mem1"]).await.unwrap();
let calls = store.calls();
assert_eq!(calls.len(), 2);
assert!(calls[0].0.contains("-[r]->"), "edges decremented first");
assert!(calls[1].0.contains("(n:Entity"), "nodes decremented second");
}
#[tokio::test(flavor = "current_thread")]
async fn should_bind_pid_as_param_without_scope_guard() {
let store = RecordingStore::default();
forget_pids(&store, &["mem1"]).await.unwrap();
for (cypher, params) in store.calls() {
assert!(!cypher.contains("mem1"), "pid must not be interpolated");
assert_eq!(params.get("pid"), Some(&GraphParam::Str("mem1".to_string())));
assert!(!cypher.contains("agent_id"), "pid path is not scope-confined");
}
}
#[tokio::test(flavor = "current_thread")]
async fn should_guard_node_delete_on_empty_pids_and_no_edges() {
let store = RecordingStore::default();
forget_pids(&store, &["mem1"]).await.unwrap();
let node_call = &store.calls()[1].0;
assert!(node_call.contains("size(n.memory_pids) = 0"));
assert!(
node_call.contains("NOT (n)--()"),
"node kept if a surviving edge joins it"
);
}
#[tokio::test(flavor = "current_thread")]
async fn should_issue_two_statements_per_pid_for_multiple_pids() {
let store = RecordingStore::default();
forget_pids(&store, &["mem1", "mem2"]).await.unwrap();
assert_eq!(store.calls().len(), 4, "edge+node per pid");
}
#[tokio::test(flavor = "current_thread")]
async fn should_no_op_for_empty_pid_list() {
let store = RecordingStore::default();
forget_pids(&store, &[]).await.unwrap();
assert!(store.calls().is_empty());
}
#[tokio::test(flavor = "current_thread")]
async fn should_detach_delete_whole_scope_on_scope_forget() {
let store = RecordingStore::default();
forget_scope(&store, &scope()).await.unwrap();
let calls = store.calls();
assert_eq!(calls.len(), 1);
assert!(calls[0].0.contains("DETACH DELETE"));
assert_eq!(calls[0].1.get("agent_id"), Some(&GraphParam::Str("agent".to_string())));
}
}