use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex};
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
/// Kind of relationship carried by a [`GraphEdge`].
///
/// Serde uses `snake_case` names (`related_to`, `shares_entity`, `precedes`,
/// `linked_by`), which are the same strings `EdgeType::as_str` writes to the
/// SQLite `edge_type` column.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EdgeType {
/// Embedding similarity: `build_edges_for_new_memory` creates it when cosine
/// similarity reaches `RELATED_TO_THRESHOLD`. Also the fallback for any
/// unrecognised string read back from the database.
RelatedTo,
/// Created by `build_edges_for_new_memory` when both memories carry at least
/// one common `entity:`-prefixed tag.
SharesEntity,
/// Temporal order: `build_edges_for_new_memory` points it from the
/// earlier-created memory to the later one (on a tie, from the new memory).
Precedes,
/// Not produced anywhere in this file; presumably an explicit link stored by
/// callers through `upsert_edge` — TODO confirm its meaning.
LinkedBy,
}
impl EdgeType {
pub fn as_str(&self) -> &'static str {
match self {
EdgeType::RelatedTo => "related_to",
EdgeType::SharesEntity => "shares_entity",
EdgeType::Precedes => "precedes",
EdgeType::LinkedBy => "linked_by",
}
}
fn from_str(s: &str) -> Self {
match s {
"shares_entity" => EdgeType::SharesEntity,
"precedes" => EdgeType::Precedes,
"linked_by" => EdgeType::LinkedBy,
_ => EdgeType::RelatedTo,
}
}
}
/// A typed edge between two memories, as stored in the `edges` table.
///
/// Rows are keyed by `(from_id, to_id, edge_type)`; writing the same key
/// again through `upsert_edge` replaces the previous row.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphEdge {
/// Source memory id.
pub from_id: String,
/// Target memory id.
pub to_id: String,
/// Relationship kind.
pub edge_type: EdgeType,
/// Edge strength. `build_edges_for_new_memory` stores the cosine similarity
/// for `RelatedTo` edges and `1.0` for the other kinds it creates.
pub weight: f32,
/// Creation time, stored in SQLite via an `as i64` cast.
/// `build_edges_for_new_memory` uses seconds since the UNIX epoch.
pub created_at: u64,
/// Partition key; `traverse`, `shortest_path` and `export` only follow edges
/// within a single namespace.
pub namespace: String,
}
/// A memory reached by `MemoryGraphEngine::traverse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
/// Id of the reached memory.
pub memory_id: String,
/// Breadth-first hop count from the traversal root; the root itself is 0.
pub depth: u32,
/// Edges (in either direction) linking this node to the node it was
/// discovered from. Empty for the root.
pub incoming_edges: Vec<GraphEdge>,
}
/// Snapshot of every edge in one namespace, produced by
/// `MemoryGraphEngine::export`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphExport {
/// Namespace that was exported.
pub namespace: String,
/// Number of distinct memory ids appearing as either endpoint of `edges`.
pub node_count: usize,
/// Number of entries in `edges`.
pub edge_count: usize,
/// All edges stored under `namespace`.
pub edges: Vec<GraphEdge>,
}
/// Minimum cosine similarity (compared with `>=`) at which
/// `build_edges_for_new_memory` adds a `RelatedTo` edge.
const RELATED_TO_THRESHOLD: f32 = 0.85;
/// Edge budget for one `build_edges_for_new_memory` call: `RelatedTo` and
/// `SharesEntity` edges count against it, `Precedes` edges do not.
const MAX_EDGES_PER_MEMORY: usize = 50;
/// One row of the `audit_events` table, as returned by
/// `MemoryGraphEngine::query_audit_events`.
///
/// Optional fields are omitted from serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEvent {
/// Row id assigned by SQLite (`AUTOINCREMENT`).
pub id: i64,
/// Event name. Tests use dotted names such as `memory.stored`; the full set
/// is defined by callers.
pub event_type: String,
/// Agent the event is attributed to; filterable in `query_audit_events`.
pub agent_id: String,
/// Related memory, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub memory_id: Option<String>,
/// Related session, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
/// Optional importance score supplied by the inserter.
#[serde(skip_serializing_if = "Option::is_none")]
pub importance: Option<f32>,
/// Event time as supplied on insert. Units are caller-defined (tests use
/// millisecond-scale values) — TODO confirm.
pub timestamp: u64,
}
/// Input for `MemoryGraphEngine::insert_audit_event`.
///
/// Mirrors [`AuditEvent`] without the `id`, which SQLite assigns on insert.
#[derive(Debug, Clone)]
pub struct AuditEventInsert {
/// Event name (e.g. `memory.stored` in the tests).
pub event_type: String,
/// Agent the event is attributed to.
pub agent_id: String,
/// Related memory, if any.
pub memory_id: Option<String>,
/// Related session, if any.
pub session_id: Option<String>,
/// Optional importance score.
pub importance: Option<f32>,
/// Event time; stored as a SQLite INTEGER via an `as i64` cast.
pub timestamp: u64,
}
/// SQLite-backed store for the memory knowledge graph (`edges` table) and the
/// audit log (`audit_events` table).
///
/// Cloning is cheap: every clone shares the same connection behind one
/// mutex, so database access is serialized across clones.
#[derive(Clone)]
pub struct MemoryGraphEngine {
/// The single shared SQLite connection.
conn: Arc<Mutex<Connection>>,
}
impl MemoryGraphEngine {
pub fn open(path: &str) -> Result<Self, rusqlite::Error> {
let conn = Connection::open(path)?;
conn.execute_batch(
"PRAGMA journal_mode=WAL;
PRAGMA synchronous=NORMAL;
CREATE TABLE IF NOT EXISTS edges (
from_id TEXT NOT NULL,
to_id TEXT NOT NULL,
edge_type TEXT NOT NULL,
weight REAL NOT NULL DEFAULT 1.0,
created_at INTEGER NOT NULL,
namespace TEXT NOT NULL,
PRIMARY KEY (from_id, to_id, edge_type)
);
CREATE INDEX IF NOT EXISTS idx_edges_from ON edges(from_id);
CREATE INDEX IF NOT EXISTS idx_edges_to ON edges(to_id);
CREATE INDEX IF NOT EXISTS idx_edges_ns ON edges(namespace);
CREATE TABLE IF NOT EXISTS audit_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
event_type TEXT NOT NULL,
agent_id TEXT NOT NULL,
memory_id TEXT,
session_id TEXT,
importance REAL,
timestamp INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_audit_agent ON audit_events(agent_id);
CREATE INDEX IF NOT EXISTS idx_audit_type ON audit_events(event_type);
CREATE INDEX IF NOT EXISTS idx_audit_ts ON audit_events(timestamp);",
)?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
pub fn upsert_edge(&self, edge: &GraphEdge) -> Result<(), rusqlite::Error> {
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
conn.execute(
"INSERT OR REPLACE INTO edges
(from_id, to_id, edge_type, weight, created_at, namespace)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
edge.from_id,
edge.to_id,
edge.edge_type.as_str(),
edge.weight,
edge.created_at as i64,
edge.namespace,
],
)?;
Ok(())
}
pub fn remove_memory(&self, memory_id: &str) -> Result<(), rusqlite::Error> {
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
conn.execute(
"DELETE FROM edges WHERE from_id = ?1 OR to_id = ?1",
params![memory_id],
)?;
Ok(())
}
pub fn get_edges(&self, memory_id: &str) -> Vec<GraphEdge> {
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
let mut stmt = match conn.prepare(
"SELECT from_id, to_id, edge_type, weight, created_at, namespace
FROM edges
WHERE from_id = ?1 OR to_id = ?1",
) {
Ok(s) => s,
Err(_) => return Vec::new(),
};
stmt.query_map(params![memory_id], row_to_edge)
.map(|rows| rows.filter_map(|r| r.ok()).collect())
.unwrap_or_default()
}
pub fn traverse(&self, root_id: &str, max_depth: u32, namespace: &str) -> Vec<GraphNode> {
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
let mut visited: HashSet<String> = HashSet::new();
let mut queue: VecDeque<(String, u32)> = VecDeque::new();
let mut result: Vec<GraphNode> = Vec::new();
visited.insert(root_id.to_string());
queue.push_back((root_id.to_string(), 0));
result.push(GraphNode {
memory_id: root_id.to_string(),
depth: 0,
incoming_edges: Vec::new(),
});
while let Some((current, depth)) = queue.pop_front() {
if depth >= max_depth {
continue;
}
let mut stmt = match conn.prepare(
"SELECT from_id, to_id, edge_type, weight, created_at, namespace
FROM edges
WHERE (from_id = ?1 OR to_id = ?1) AND namespace = ?2",
) {
Ok(s) => s,
Err(_) => continue,
};
let edges: Vec<GraphEdge> = stmt
.query_map(params![current, namespace], row_to_edge)
.map(|rows| rows.filter_map(|r| r.ok()).collect())
.unwrap_or_default();
let mut neighbor_edges: HashMap<String, Vec<GraphEdge>> = HashMap::new();
for edge in &edges {
let neighbor = if edge.from_id == current {
edge.to_id.clone()
} else {
edge.from_id.clone()
};
if !visited.contains(&neighbor) {
neighbor_edges
.entry(neighbor)
.or_default()
.push(edge.clone());
}
}
for (neighbor, inc_edges) in neighbor_edges {
visited.insert(neighbor.clone());
queue.push_back((neighbor.clone(), depth + 1));
result.push(GraphNode {
memory_id: neighbor,
depth: depth + 1,
incoming_edges: inc_edges,
});
}
}
result
}
pub fn shortest_path(
&self,
from_id: &str,
to_id: &str,
namespace: &str,
) -> Option<Vec<String>> {
if from_id == to_id {
return Some(vec![from_id.to_string()]);
}
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
let mut visited: HashSet<String> = HashSet::new();
let mut queue: VecDeque<Vec<String>> = VecDeque::new();
visited.insert(from_id.to_string());
queue.push_back(vec![from_id.to_string()]);
while let Some(path) = queue.pop_front() {
let current = path.last().unwrap();
let mut stmt = conn
.prepare(
"SELECT from_id, to_id FROM edges
WHERE (from_id = ?1 OR to_id = ?1) AND namespace = ?2",
)
.ok()?;
let neighbors: Vec<String> = stmt
.query_map(params![current, namespace], |row| {
let from: String = row.get(0)?;
let to: String = row.get(1)?;
Ok((from, to))
})
.ok()?
.filter_map(|r| r.ok())
.map(|(from, to)| if from == *current { to } else { from })
.collect();
for neighbor in neighbors {
if visited.contains(&neighbor) {
continue;
}
let mut new_path = path.clone();
new_path.push(neighbor.clone());
if neighbor == to_id {
return Some(new_path);
}
visited.insert(neighbor);
queue.push_back(new_path);
}
}
None
}
pub fn export(&self, namespace: &str) -> GraphExport {
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
let edges: Vec<GraphEdge> = {
let mut stmt = match conn.prepare(
"SELECT from_id, to_id, edge_type, weight, created_at, namespace
FROM edges WHERE namespace = ?1",
) {
Ok(s) => s,
Err(_) => {
return GraphExport {
namespace: namespace.to_string(),
node_count: 0,
edge_count: 0,
edges: Vec::new(),
}
}
};
stmt.query_map(params![namespace], row_to_edge)
.map(|rows| rows.filter_map(|r| r.ok()).collect())
.unwrap_or_default()
};
let mut nodes: HashSet<String> = HashSet::new();
for e in &edges {
nodes.insert(e.from_id.clone());
nodes.insert(e.to_id.clone());
}
GraphExport {
namespace: namespace.to_string(),
node_count: nodes.len(),
edge_count: edges.len(),
edges,
}
}
pub fn insert_audit_event(&self, event: &AuditEventInsert) -> Result<(), rusqlite::Error> {
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
conn.execute(
"INSERT INTO audit_events
(event_type, agent_id, memory_id, session_id, importance, timestamp)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
event.event_type,
event.agent_id,
event.memory_id,
event.session_id,
event.importance,
event.timestamp as i64,
],
)?;
Ok(())
}
pub fn query_audit_events(
&self,
agent_id: Option<&str>,
event_type: Option<&str>,
from_ts: Option<u64>,
to_ts: Option<u64>,
limit: usize,
) -> Vec<AuditEvent> {
let conn = self.conn.lock().unwrap_or_else(|p| p.into_inner());
let limit = limit.min(10_000) as i64;
let mut stmt = match conn.prepare(
"SELECT id, event_type, agent_id, memory_id, session_id, importance, timestamp
FROM audit_events
WHERE (?1 IS NULL OR agent_id = ?1)
AND (?2 IS NULL OR event_type = ?2)
AND (?3 IS NULL OR timestamp >= ?3)
AND (?4 IS NULL OR timestamp <= ?4)
ORDER BY timestamp DESC
LIMIT ?5",
) {
Ok(s) => s,
Err(_) => return Vec::new(),
};
let from_ts_val = from_ts.map(|v| v as i64);
let to_ts_val = to_ts.map(|v| v as i64);
stmt.query_map(
params![agent_id, event_type, from_ts_val, to_ts_val, limit],
|row| {
Ok(AuditEvent {
id: row.get(0)?,
event_type: row.get(1)?,
agent_id: row.get(2)?,
memory_id: row.get(3)?,
session_id: row.get(4)?,
importance: row.get(5)?,
timestamp: row.get::<_, i64>(6)? as u64,
})
},
)
.map(|rows| rows.filter_map(|r| r.ok()).collect())
.unwrap_or_default()
}
pub fn build_edges_for_new_memory(
&self,
new_id: &str,
new_embedding: &[f32],
new_tags: &[String],
new_created_at: u64,
namespace: &str,
existing: &[(String, Vec<f32>, Vec<String>, u64)], ) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let new_entity_tags: HashSet<&str> = new_tags
.iter()
.filter(|t| t.starts_with("entity:"))
.map(|t| t.as_str())
.collect();
let mut edge_count = 0usize;
for (other_id, other_embedding, other_tags, other_created_at) in existing {
if other_id == new_id || edge_count >= MAX_EDGES_PER_MEMORY {
break;
}
let similarity = cosine_similarity(new_embedding, other_embedding);
if similarity >= RELATED_TO_THRESHOLD {
let _ = self.upsert_edge(&GraphEdge {
from_id: new_id.to_string(),
to_id: other_id.clone(),
edge_type: EdgeType::RelatedTo,
weight: similarity,
created_at: now,
namespace: namespace.to_string(),
});
edge_count += 1;
if *other_created_at < new_created_at {
let _ = self.upsert_edge(&GraphEdge {
from_id: other_id.clone(),
to_id: new_id.to_string(),
edge_type: EdgeType::Precedes,
weight: 1.0,
created_at: now,
namespace: namespace.to_string(),
});
} else {
let _ = self.upsert_edge(&GraphEdge {
from_id: new_id.to_string(),
to_id: other_id.clone(),
edge_type: EdgeType::Precedes,
weight: 1.0,
created_at: now,
namespace: namespace.to_string(),
});
}
}
let other_entity_tags: HashSet<&str> = other_tags
.iter()
.filter(|t| t.starts_with("entity:"))
.map(|t| t.as_str())
.collect();
if !new_entity_tags.is_empty()
&& new_entity_tags
.intersection(&other_entity_tags)
.next()
.is_some()
{
let _ = self.upsert_edge(&GraphEdge {
from_id: new_id.to_string(),
to_id: other_id.clone(),
edge_type: EdgeType::SharesEntity,
weight: 1.0,
created_at: now,
namespace: namespace.to_string(),
});
edge_count += 1;
}
}
}
}
/// Decodes a row selected as
/// `from_id, to_id, edge_type, weight, created_at, namespace` (in exactly that
/// column order) into a [`GraphEdge`].
///
/// Unknown `edge_type` strings decode via `EdgeType::from_str`; `created_at`
/// is read as a signed SQLite integer and cast back to `u64`.
fn row_to_edge(row: &rusqlite::Row<'_>) -> rusqlite::Result<GraphEdge> {
    let from_id: String = row.get(0)?;
    let to_id: String = row.get(1)?;
    let kind: String = row.get(2)?;
    let weight: f32 = row.get(3)?;
    let created_at: i64 = row.get(4)?;
    let namespace: String = row.get(5)?;
    Ok(GraphEdge {
        from_id,
        to_id,
        edge_type: EdgeType::from_str(&kind),
        weight,
        created_at: created_at as u64,
        namespace,
    })
}
/// Cosine similarity of two embedding vectors, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, either vector
/// has zero magnitude, or the result is not finite (e.g. a NaN component).
/// Sums are accumulated in `f64`: squaring `f32` components would overflow to
/// infinity around 1e19 (giving NaN) and underflow to zero around 1e-20
/// (giving 0.0).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    let similarity = dot / denom;
    if similarity.is_finite() {
        similarity.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}
/// Opens the process-wide graph database.
///
/// Uses `<DAKERA_DATA_DIR>/graph.db` when `DAKERA_DATA_DIR` is set to a
/// non-blank value; otherwise (unset, empty, or whitespace-only) uses an
/// in-memory database. An empty value previously resolved to `/graph.db` at
/// the filesystem root. If the chosen database cannot be opened, a warning is
/// logged and an in-memory database is used instead, so nothing is persisted.
///
/// # Panics
/// Only if SQLite cannot create an in-memory database, which indicates a
/// broken SQLite build rather than a runtime condition.
pub fn open_from_env() -> Arc<MemoryGraphEngine> {
    let path = std::env::var("DAKERA_DATA_DIR")
        .ok()
        .filter(|dir| !dir.trim().is_empty())
        .map(|dir| {
            std::path::Path::new(&dir)
                .join("graph.db")
                .to_string_lossy()
                .into_owned()
        })
        .unwrap_or_else(|| ":memory:".to_string());
    match MemoryGraphEngine::open(&path) {
        Ok(engine) => {
            tracing::info!(path = %path, "CE-5: memory knowledge graph opened");
            Arc::new(engine)
        }
        Err(e) => {
            tracing::warn!(error = %e, "CE-5: failed to open graph.db, falling back to :memory:");
            Arc::new(MemoryGraphEngine::open(":memory:").expect("in-memory sqlite must work"))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh engine on its own in-memory database, isolated per test.
    fn test_engine() -> MemoryGraphEngine {
        MemoryGraphEngine::open(":memory:").unwrap()
    }

    /// Unit-length vector with `dim` equal components.
    ///
    /// Every positive `seed` produces the same direction, so any two such
    /// embeddings have cosine similarity 1.0. A zero seed produces the zero
    /// vector.
    fn dummy_embedding(seed: f32, dim: usize) -> Vec<f32> {
        let v = vec![seed / 10.0; dim];
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm == 0.0 {
            v
        } else {
            v.iter().map(|x| x / norm).collect()
        }
    }

    /// Stores a `RelatedTo` edge `from -> to` in namespace `ns`.
    fn link(g: &MemoryGraphEngine, from: &str, to: &str, ns: &str) {
        g.upsert_edge(&GraphEdge {
            from_id: from.into(),
            to_id: to.into(),
            edge_type: EdgeType::RelatedTo,
            weight: 0.9,
            created_at: 1000,
            namespace: ns.into(),
        })
        .unwrap();
    }

    #[test]
    fn test_upsert_and_get_edges() {
        let g = test_engine();
        g.upsert_edge(&GraphEdge {
            from_id: "mem_a".into(),
            to_id: "mem_b".into(),
            edge_type: EdgeType::RelatedTo,
            weight: 0.9,
            created_at: 1000,
            namespace: "ns1".into(),
        })
        .unwrap();
        let edges = g.get_edges("mem_a");
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].to_id, "mem_b");
        assert_eq!(edges[0].edge_type, EdgeType::RelatedTo);
    }

    #[test]
    fn test_edge_type_str_roundtrip() {
        for kind in [
            EdgeType::RelatedTo,
            EdgeType::SharesEntity,
            EdgeType::Precedes,
            EdgeType::LinkedBy,
        ] {
            assert_eq!(EdgeType::from_str(kind.as_str()), kind);
        }
        // Unknown strings deliberately fall back to RelatedTo.
        assert_eq!(EdgeType::from_str("bogus"), EdgeType::RelatedTo);
    }

    #[test]
    fn test_bfs_traversal() {
        let g = test_engine();
        let ns = "test_ns";
        for (from, to) in [("mem_a", "mem_b"), ("mem_b", "mem_c")] {
            g.upsert_edge(&GraphEdge {
                from_id: from.into(),
                to_id: to.into(),
                edge_type: EdgeType::RelatedTo,
                weight: 0.9,
                created_at: 1000,
                namespace: ns.into(),
            })
            .unwrap();
        }
        let nodes = g.traverse("mem_a", 3, ns);
        let ids: Vec<&str> = nodes.iter().map(|n| n.memory_id.as_str()).collect();
        assert!(ids.contains(&"mem_a"));
        assert!(ids.contains(&"mem_b"));
        assert!(ids.contains(&"mem_c"));
    }

    #[test]
    fn test_traverse_respects_max_depth() {
        let g = test_engine();
        let ns = "depth_ns";
        link(&g, "a", "b", ns);
        link(&g, "b", "c", ns);

        let one_hop = g.traverse("a", 1, ns);
        let ids: Vec<&str> = one_hop.iter().map(|n| n.memory_id.as_str()).collect();
        assert_eq!(ids, vec!["a", "b"]);
        assert_eq!(one_hop[1].depth, 1);
        assert_eq!(one_hop[1].incoming_edges.len(), 1);

        // Depth 0 returns only the root, with no edges.
        let root_only = g.traverse("a", 0, ns);
        assert_eq!(root_only.len(), 1);
        assert_eq!(root_only[0].memory_id, "a");
        assert!(root_only[0].incoming_edges.is_empty());
    }

    #[test]
    fn test_traverse_scoped_to_namespace() {
        let g = test_engine();
        link(&g, "a", "b", "ns_x");
        link(&g, "a", "c", "ns_y");
        let nodes = g.traverse("a", 3, "ns_x");
        let ids: Vec<&str> = nodes.iter().map(|n| n.memory_id.as_str()).collect();
        assert_eq!(ids, vec!["a", "b"]);
    }

    #[test]
    fn test_shortest_path() {
        let g = test_engine();
        let ns = "test_ns2";
        for (from, to) in [("ma", "mb"), ("mb", "mc"), ("ma", "mc")] {
            g.upsert_edge(&GraphEdge {
                from_id: from.into(),
                to_id: to.into(),
                edge_type: EdgeType::RelatedTo,
                weight: 0.9,
                created_at: 1000,
                namespace: ns.into(),
            })
            .unwrap();
        }
        let path = g.shortest_path("ma", "mc", ns).unwrap();
        assert_eq!(path.len(), 2);
        assert_eq!(path[0], "ma");
        assert_eq!(path[1], "mc");
    }

    #[test]
    fn test_shortest_path_unreachable_trivial_and_reverse() {
        let g = test_engine();
        let ns = "sp_ns";
        link(&g, "a", "b", ns);
        link(&g, "c", "d", ns);
        assert_eq!(g.shortest_path("a", "d", ns), None);
        assert_eq!(g.shortest_path("a", "a", ns), Some(vec!["a".to_string()]));
        // Edges are followed against their stored direction too.
        assert_eq!(
            g.shortest_path("b", "a", ns),
            Some(vec!["b".to_string(), "a".to_string()])
        );
    }

    #[test]
    fn test_export_counts() {
        let g = test_engine();
        link(&g, "a", "b", "exp");
        link(&g, "b", "c", "exp");
        link(&g, "x", "y", "other");
        let export = g.export("exp");
        assert_eq!(export.namespace, "exp");
        assert_eq!(export.edge_count, 2);
        assert_eq!(export.edges.len(), 2);
        assert_eq!(export.node_count, 3);
    }

    #[test]
    fn test_build_edges_for_new_memory() {
        let g = test_engine();
        let ns = "build_test";
        let dim = 4;
        let emb_a = dummy_embedding(1.0, dim);
        let emb_b = dummy_embedding(2.0, dim);
        let emb_new = dummy_embedding(1.5, dim);
        g.build_edges_for_new_memory(
            "mem_new",
            &emb_new,
            &[],
            2000,
            ns,
            &[
                ("mem_a".into(), emb_a, vec![], 1000),
                ("mem_b".into(), emb_b, vec![], 1500),
            ],
        );
        let edges = g.get_edges("mem_new");
        assert!(!edges.is_empty());
    }

    #[test]
    fn test_build_edges_precedes_direction() {
        let g = test_engine();
        let emb = dummy_embedding(1.0, 4);
        g.build_edges_for_new_memory(
            "new",
            &emb,
            &[],
            2000,
            "dir",
            &[("old".into(), emb.clone(), vec![], 1000)],
        );
        let edges = g.get_edges("new");
        assert!(edges
            .iter()
            .any(|e| e.edge_type == EdgeType::Precedes && e.from_id == "old" && e.to_id == "new"));
    }

    #[test]
    fn test_build_edges_shares_entity() {
        let g = test_engine();
        // Orthogonal embeddings: only the shared entity tag can link them.
        g.build_edges_for_new_memory(
            "new",
            &[1.0, 0.0],
            &["entity:alice".to_string()],
            2000,
            "ent",
            &[(
                "other".into(),
                vec![0.0, 1.0],
                vec!["entity:alice".to_string(), "topic:x".to_string()],
                1000,
            )],
        );
        let edges = g.get_edges("new");
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].edge_type, EdgeType::SharesEntity);
        assert_eq!(edges[0].to_id, "other");
    }

    #[test]
    fn test_cosine_similarity_basics() {
        assert!((cosine_similarity(&[1.0, 0.0], &[2.0, 0.0]) - 1.0).abs() < 1e-6);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 3.0]).abs() < 1e-6);
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn test_remove_memory() {
        let g = test_engine();
        g.upsert_edge(&GraphEdge {
            from_id: "del_me".into(),
            to_id: "other".into(),
            edge_type: EdgeType::RelatedTo,
            weight: 0.9,
            created_at: 0,
            namespace: "ns".into(),
        })
        .unwrap();
        g.remove_memory("del_me").unwrap();
        assert!(g.get_edges("del_me").is_empty());
    }

    #[test]
    fn test_audit_event_insert_and_query() {
        let g = test_engine();
        let insert = AuditEventInsert {
            event_type: "memory.stored".to_string(),
            agent_id: "agent-1".to_string(),
            memory_id: Some("mem_abc".to_string()),
            session_id: Some("sess_xyz".to_string()),
            importance: Some(0.8),
            timestamp: 1_700_000_000_000,
        };
        g.insert_audit_event(&insert).unwrap();
        let events = g.query_audit_events(None, None, None, None, 10);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event_type, "memory.stored");
        assert_eq!(events[0].agent_id, "agent-1");
        assert_eq!(events[0].memory_id.as_deref(), Some("mem_abc"));
        assert_eq!(events[0].session_id.as_deref(), Some("sess_xyz"));
        assert!((events[0].importance.unwrap() - 0.8).abs() < 1e-5);
        assert_eq!(events[0].timestamp, 1_700_000_000_000);
    }

    #[test]
    fn test_audit_query_filter_by_agent() {
        let g = test_engine();
        for i in 0..5u64 {
            g.insert_audit_event(&AuditEventInsert {
                event_type: "memory.recalled".to_string(),
                agent_id: if i < 3 { "agent-a" } else { "agent-b" }.to_string(),
                memory_id: None,
                session_id: None,
                importance: None,
                timestamp: 1_000 + i,
            })
            .unwrap();
        }
        let for_a = g.query_audit_events(Some("agent-a"), None, None, None, 100);
        assert_eq!(for_a.len(), 3);
        let for_b = g.query_audit_events(Some("agent-b"), None, None, None, 100);
        assert_eq!(for_b.len(), 2);
    }

    #[test]
    fn test_audit_query_filter_by_event_type() {
        let g = test_engine();
        g.insert_audit_event(&AuditEventInsert {
            event_type: "memory.stored".to_string(),
            agent_id: "ag".to_string(),
            memory_id: None,
            session_id: None,
            importance: None,
            timestamp: 1,
        })
        .unwrap();
        g.insert_audit_event(&AuditEventInsert {
            event_type: "session.started".to_string(),
            agent_id: "ag".to_string(),
            memory_id: None,
            session_id: None,
            importance: None,
            timestamp: 2,
        })
        .unwrap();
        let stored = g.query_audit_events(None, Some("memory.stored"), None, None, 10);
        assert_eq!(stored.len(), 1);
        let sessions = g.query_audit_events(None, Some("session.started"), None, None, 10);
        assert_eq!(sessions.len(), 1);
    }

    #[test]
    fn test_audit_query_time_range() {
        let g = test_engine();
        for ts in [100u64, 200, 300, 400, 500] {
            g.insert_audit_event(&AuditEventInsert {
                event_type: "ev".to_string(),
                agent_id: "ag".to_string(),
                memory_id: None,
                session_id: None,
                importance: None,
                timestamp: ts,
            })
            .unwrap();
        }
        // Both bounds are inclusive: 200, 300 and 400 match.
        let events = g.query_audit_events(None, None, Some(200), Some(400), 100);
        assert_eq!(events.len(), 3);
    }

    #[test]
    fn test_audit_query_newest_first() {
        let g = test_engine();
        for ts in [10u64, 30, 20] {
            g.insert_audit_event(&AuditEventInsert {
                event_type: "ev".to_string(),
                agent_id: "ag".to_string(),
                memory_id: None,
                session_id: None,
                importance: None,
                timestamp: ts,
            })
            .unwrap();
        }
        let events = g.query_audit_events(None, None, None, None, 10);
        let stamps: Vec<u64> = events.iter().map(|e| e.timestamp).collect();
        assert_eq!(stamps, vec![30, 20, 10]);
    }

    #[test]
    fn test_audit_query_limit() {
        let g = test_engine();
        for i in 0..20u64 {
            g.insert_audit_event(&AuditEventInsert {
                event_type: "ev".to_string(),
                agent_id: "ag".to_string(),
                memory_id: None,
                session_id: None,
                importance: None,
                timestamp: i,
            })
            .unwrap();
        }
        let events = g.query_audit_events(None, None, None, None, 5);
        assert_eq!(events.len(), 5);
    }
}