use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use chrono::Utc;
use petgraph::graph::{DiGraph, EdgeIndex, NodeIndex};
use petgraph::visit::EdgeRef;
use uuid::Uuid;
use crate::storage::StorageTrait;
use crate::types::Edge;
/// Relationship category carried by an [`Edge`] (its `edge_type` field).
///
/// [`edge_type_alignment`] uses the variant to score an edge against a query
/// intent during beam search. `Entity` is the `Default` variant.
#[derive(Debug, Clone, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum EdgeType {
    Temporal,
    Causal,
    #[default]
    Entity,
    Semantic,
    Supersedes,
}
/// Scores how well an edge of kind `edge_type` matches the query `intent`.
///
/// The recognised intents are `"recall"`, `"action"`, `"question"` and `"code"`.
/// Any pairing not listed with a higher value scores the baseline `0.3`. That
/// includes every `Supersedes` edge and any unrecognised intent string.
///
/// | edge type  | recall | action | question | code |
/// |------------|--------|--------|----------|------|
/// | Temporal   | 0.9    | 0.3    | 0.3      | 0.3  |
/// | Causal     | 0.3    | 0.9    | 0.3      | 0.6  |
/// | Entity     | 0.3    | 0.3    | 0.8      | 0.7  |
/// | Semantic   | 0.7    | 0.3    | 0.7      | 0.3  |
/// | Supersedes | 0.3    | 0.3    | 0.3      | 0.3  |
pub fn edge_type_alignment(edge_type: &EdgeType, intent: &str) -> f32 {
    const BASELINE: f32 = 0.3;
    match edge_type {
        EdgeType::Temporal => match intent {
            "recall" => 0.9,
            _ => BASELINE,
        },
        EdgeType::Causal => match intent {
            "action" => 0.9,
            "code" => 0.6,
            _ => BASELINE,
        },
        EdgeType::Entity => match intent {
            "question" => 0.8,
            "code" => 0.7,
            _ => BASELINE,
        },
        EdgeType::Semantic => match intent {
            "recall" | "question" => 0.7,
            _ => BASELINE,
        },
        EdgeType::Supersedes => BASELINE,
    }
}
/// Decays `base_confidence` exponentially so that it halves every `half_life`
/// units of `age_days`, i.e. `base_confidence * 2^(-age_days / half_life)`.
///
/// Inputs are not validated. A negative `age_days` yields a value above
/// `base_confidence`. A `half_life` of zero gives `0.0` for positive ages and
/// NaN for zero age.
pub fn edge_confidence_at(base_confidence: f32, age_days: f32, half_life: f32) -> f32 {
    let ln_two = 2.0_f32.ln();
    let exponent = -age_days * ln_two / half_life;
    let decay_factor = exponent.exp();
    base_confidence * decay_factor
}
/// Directed graph of memory/entity ids.
///
/// Each petgraph node stores its [`Uuid`]. Each petgraph edge stores its `f32`
/// traversal weight. The complete [`Edge`] record (type, weight, validity
/// timestamps, metadata) is kept alongside, keyed by the petgraph edge index.
pub struct MemoryGraph {
    /// Underlying storage: node weight = id, edge weight = traversal weight.
    graph: DiGraph<Uuid, f32>,
    /// Reverse lookup from an id to its petgraph node index.
    node_map: HashMap<Uuid, NodeIndex>,
    /// Full edge records, keyed by the petgraph index of the matching edge.
    edge_meta: HashMap<EdgeIndex, Edge>,
}
impl MemoryGraph {
pub fn new() -> Self {
Self {
graph: DiGraph::new(),
node_map: HashMap::new(),
edge_meta: HashMap::new(),
}
}
pub fn add_node(&mut self, id: Uuid) {
if !self.node_map.contains_key(&id) {
let idx = self.graph.add_node(id);
self.node_map.insert(id, idx);
}
}
pub fn add_edge(&mut self, from: Uuid, to: Uuid, weight: f32) {
self.add_node(from);
self.add_node(to);
let from_idx = self.node_map[&from];
let to_idx = self.node_map[&to];
let edge_idx = self.graph.add_edge(from_idx, to_idx, weight);
let mut edge = Edge::new(from, to, "");
edge.weight = weight;
self.edge_meta.insert(edge_idx, edge);
}
pub fn add_edge_with_meta(&mut self, edge: Edge) {
self.add_node(edge.source);
self.add_node(edge.target);
let from_idx = self.node_map[&edge.source];
let to_idx = self.node_map[&edge.target];
let edge_idx = self.graph.add_edge(from_idx, to_idx, edge.weight);
self.edge_meta.insert(edge_idx, edge);
}
pub fn invalidate_edge(&mut self, from: Uuid, to: Uuid, superseded_by: Option<Uuid>) {
let (Some(&from_idx), Some(&to_idx)) = (self.node_map.get(&from), self.node_map.get(&to))
else {
return;
};
let edge_indices: Vec<EdgeIndex> = self
.graph
.edges_connecting(from_idx, to_idx)
.map(|e| e.id())
.collect();
for edge_idx in edge_indices {
if let Some(meta) = self.edge_meta.get_mut(&edge_idx)
&& meta.invalid_at.is_none()
{
meta.invalid_at = Some(Utc::now());
meta.superseded_by = superseded_by;
}
}
}
pub fn get_valid_edges(&self, entity_id: Uuid) -> Vec<&Edge> {
let Some(&node_idx) = self.node_map.get(&entity_id) else {
return Vec::new();
};
self.graph
.edges(node_idx)
.filter_map(|edge_ref| self.edge_meta.get(&edge_ref.id()))
.filter(|meta| meta.invalid_at.is_none())
.collect()
}
pub fn get_edge_history(&self, entity_id: Uuid) -> Vec<&Edge> {
let Some(&node_idx) = self.node_map.get(&entity_id) else {
return Vec::new();
};
let mut result: Vec<&Edge> = self
.graph
.edges(node_idx)
.filter_map(|edge_ref| self.edge_meta.get(&edge_ref.id()))
.collect();
result.sort_by_key(|e| e.valid_at);
result
}
pub fn traverse(&self, start: Uuid, max_depth: usize) -> Vec<(Uuid, f32)> {
let Some(&start_idx) = self.node_map.get(&start) else {
return Vec::new();
};
let mut visited: HashMap<NodeIndex, usize> = HashMap::new();
let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
visited.insert(start_idx, 0);
queue.push_back((start_idx, 0));
let mut results: Vec<(Uuid, f32)> = Vec::new();
while let Some((current, depth)) = queue.pop_front() {
if depth >= max_depth {
continue;
}
for edge_ref in self.graph.edges(current) {
if let Some(meta) = self.edge_meta.get(&edge_ref.id())
&& meta.invalid_at.is_some()
{
continue;
}
let neighbor = edge_ref.target();
if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(neighbor) {
let next_depth = depth + 1;
e.insert(next_depth);
queue.push_back((neighbor, next_depth));
let score = 1.0_f32 / (1.0 + next_depth as f32);
let id = self.graph[neighbor];
results.push((id, score));
}
}
}
results
}
pub fn build_from_storage(storage: &dyn StorageTrait, namespace_id: Uuid) -> Self {
let mut graph = MemoryGraph::new();
let Ok(entities) = storage.list_entities_by_namespace(namespace_id) else {
return graph;
};
for entity in &entities {
graph.add_node(entity.id);
if let Ok(memories) = storage.list_episodic_by_entity(entity.id, usize::MAX) {
for mem in memories {
graph.add_edge(entity.id, mem.id, 1.0);
}
}
if let Ok(edges) = storage.get_edges_for_entity(entity.id) {
for edge in edges {
graph.add_edge_with_meta(edge);
}
}
}
for entity in &entities {
if let Ok(sem_mems) = storage.list_semantic_by_entity(entity.id, usize::MAX) {
for mem in sem_mems {
graph.add_edge(entity.id, mem.id, mem.confidence);
}
}
}
graph
}
pub fn beam_search(
&self,
start: Uuid,
intent: &str,
beam_width: usize,
max_depth: usize,
) -> Vec<(Uuid, f32)> {
#[derive(PartialEq)]
struct Candidate(f32, NodeIndex);
impl Eq for Candidate {}
impl PartialOrd for Candidate {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Candidate {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0
.partial_cmp(&other.0)
.unwrap_or(std::cmp::Ordering::Equal)
}
}
let Some(&start_idx) = self.node_map.get(&start) else {
return Vec::new();
};
let mut visited = HashSet::new();
visited.insert(start_idx);
let mut scores: HashMap<NodeIndex, f32> = HashMap::new();
let mut beam = vec![(start_idx, 1.0_f32)];
for _depth in 0..max_depth {
let mut heap: BinaryHeap<Candidate> = BinaryHeap::new();
for &(current, parent_score) in &beam {
for edge_ref in self.graph.edges(current) {
let neighbor = edge_ref.target();
if visited.contains(&neighbor) {
continue;
}
if let Some(meta) = self.edge_meta.get(&edge_ref.id())
&& meta.invalid_at.is_some()
{
continue;
}
let meta = self.edge_meta.get(&edge_ref.id());
let type_alignment =
meta.map_or(0.3, |m| edge_type_alignment(&m.edge_type, intent));
let edge_weight = meta.map_or(*edge_ref.weight(), |m| m.weight);
let temporal_confidence = meta.map_or(1.0, |m| {
let age_days = (Utc::now() - m.valid_at).num_seconds() as f32 / 86400.0;
let half_life = m
.metadata
.get("half_life")
.and_then(serde_json::Value::as_f64)
.unwrap_or(90.0) as f32;
edge_confidence_at(1.0, age_days, half_life)
});
let transition_score =
(0.4 * type_alignment + 0.4 * edge_weight + 0.2 * temporal_confidence)
.exp();
let accumulated = parent_score * transition_score;
heap.push(Candidate(accumulated, neighbor));
}
}
let mut next_beam = Vec::new();
let mut count = 0;
while let Some(Candidate(score, node_idx)) = heap.pop() {
if count >= beam_width {
break;
}
if visited.insert(node_idx) {
scores.insert(node_idx, score);
next_beam.push((node_idx, score));
count += 1;
}
}
if next_beam.is_empty() {
break;
}
beam = next_beam;
}
let mut results: Vec<(Uuid, f32)> = scores
.iter()
.map(|(&node_idx, &score)| (self.graph[node_idx], score))
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results
}
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
pub fn edge_count(&self) -> usize {
self.graph.edge_count()
}
}
impl Default for MemoryGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_add_and_traverse() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        graph.add_node(a);
        graph.add_node(b);
        graph.add_node(c);
        graph.add_edge(a, b, 1.0);
        graph.add_edge(b, c, 1.0);
        let results = graph.traverse(a, 3);
        // Exactly b and c are reachable; `>= 2` would hide duplicate results.
        assert_eq!(results.len(), 2, "should find exactly b and c");
        let b_score = results.iter().find(|(id, _)| *id == b).unwrap().1;
        let c_score = results.iter().find(|(id, _)| *id == c).unwrap().1;
        assert!(
            b_score > c_score,
            "b (depth 1) should score higher than c (depth 2)"
        );
    }

    #[test]
    fn test_graph_empty_traverse() {
        let graph = MemoryGraph::new();
        let results = graph.traverse(Uuid::new_v4(), 3);
        assert!(results.is_empty());
    }

    #[test]
    fn test_graph_traverse_unknown_start() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        graph.add_node(a);
        let results = graph.traverse(Uuid::new_v4(), 3);
        assert!(results.is_empty());
    }

    #[test]
    fn test_graph_node_edge_counts() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        graph.add_edge(a, b, 0.5);
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_graph_duplicate_node_ignored() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        graph.add_node(a);
        graph.add_node(a);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_graph_max_depth_respected() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        graph.add_edge(a, b, 1.0);
        graph.add_edge(b, c, 1.0);
        let results = graph.traverse(a, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, b);
    }

    #[test]
    fn test_graph_score_formula() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        graph.add_edge(a, b, 1.0);
        let results = graph.traverse(a, 2);
        assert_eq!(results.len(), 1);
        // depth 1 => 1 / (1 + 1) = 0.5, exactly representable.
        assert!((results[0].1 - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_add_edge_records_metadata() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        graph.add_edge(a, b, 0.7);
        let valid = graph.get_valid_edges(a);
        assert_eq!(valid.len(), 1);
        assert_eq!(valid[0].source, a);
        assert_eq!(valid[0].target, b);
        assert!((valid[0].weight - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_edge_queries_unknown_entity() {
        let graph = MemoryGraph::new();
        let unknown = Uuid::new_v4();
        assert!(graph.get_valid_edges(unknown).is_empty());
        assert!(graph.get_edge_history(unknown).is_empty());
    }

    #[test]
    fn test_invalidate_unknown_edge_is_noop() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        graph.add_edge(a, b, 1.0);
        graph.invalidate_edge(a, Uuid::new_v4(), None);
        graph.invalidate_edge(Uuid::new_v4(), b, None);
        assert_eq!(graph.get_valid_edges(a).len(), 1);
        assert_eq!(graph.traverse(a, 1).len(), 1);
    }

    #[test]
    fn test_graph_build_from_storage() {
        use crate::storage::sqlite::SqliteBackend;
        use crate::types::{Entity, EntityKind, Episode, EpisodicMemory, Namespace};
        let dir = tempfile::tempdir().unwrap();
        let storage = SqliteBackend::open(dir.path()).unwrap();
        let ns = Namespace::new("graph-test-ns");
        storage.save_namespace(&ns).unwrap();
        let mut entity = Entity::new("graph-agent", EntityKind::Agent);
        entity.namespace_id = ns.id;
        storage.save_entity(&entity).unwrap();
        let episode = Episode::new(ns.id, vec![entity.id]);
        storage.save_episode(&episode).unwrap();
        let mem = EpisodicMemory::new(ns.id, episode.id, entity.id, entity.id, "graph content");
        storage.save_episodic(&mem).unwrap();
        let graph = MemoryGraph::build_from_storage(&storage, ns.id);
        assert!(graph.node_count() >= 2);
        assert!(graph.edge_count() >= 1);
    }

    #[test]
    fn test_edge_temporal_validity() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        graph.add_edge(a, b, 1.0);
        graph.add_edge(a, c, 1.0);
        let results = graph.traverse(a, 2);
        assert_eq!(results.len(), 2);
        graph.invalidate_edge(a, b, None);
        let results = graph.traverse(a, 2);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, c);
        let valid = graph.get_valid_edges(a);
        assert_eq!(valid.len(), 1);
        assert_eq!(valid[0].target, c);
    }

    #[test]
    fn test_edge_supersession() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let edge_ab = Edge::new(a, b, "works_at");
        graph.add_edge_with_meta(edge_ab);
        let edge_ac = Edge::new(a, c, "works_at");
        let superseding_id = edge_ac.id;
        graph.add_edge_with_meta(edge_ac);
        graph.invalidate_edge(a, b, Some(superseding_id));
        let results = graph.traverse(a, 2);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, c);
        let valid = graph.get_valid_edges(a);
        assert_eq!(valid.len(), 1);
        assert_eq!(valid[0].target, c);
    }

    #[test]
    fn test_edge_history() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let edge_ab = Edge::new(a, b, "works_at");
        graph.add_edge_with_meta(edge_ab);
        let edge_ac = Edge::new(a, c, "works_at");
        let superseding_id = edge_ac.id;
        graph.add_edge_with_meta(edge_ac);
        graph.invalidate_edge(a, b, Some(superseding_id));
        let history = graph.get_edge_history(a);
        assert_eq!(
            history.len(),
            2,
            "should have both current and superseded edges"
        );
        let targets: Vec<Uuid> = history.iter().map(|e| e.target).collect();
        assert!(
            targets.contains(&b),
            "history should contain superseded edge to B"
        );
        assert!(
            targets.contains(&c),
            "history should contain current edge to C"
        );
        let invalidated = history.iter().find(|e| e.target == b).unwrap();
        assert!(invalidated.invalid_at.is_some());
        assert_eq!(invalidated.superseded_by, Some(superseding_id));
        let current = history.iter().find(|e| e.target == c).unwrap();
        assert!(current.invalid_at.is_none());
    }

    #[test]
    fn test_edge_type_alignment() {
        assert!(edge_type_alignment(&EdgeType::Temporal, "recall") > 0.3);
        assert!(edge_type_alignment(&EdgeType::Causal, "action") > 0.3);
        assert!(edge_type_alignment(&EdgeType::Entity, "question") > 0.3);
        assert!((edge_type_alignment(&EdgeType::Temporal, "action") - 0.3).abs() < 0.01);
    }

    #[test]
    fn test_edge_confidence_decays() {
        let fresh = edge_confidence_at(1.0, 0.0, 90.0);
        let old = edge_confidence_at(1.0, 180.0, 90.0);
        assert!((fresh - 1.0).abs() < 0.01);
        assert!((old - 0.25).abs() < 0.1);
        assert!(fresh > old);
    }

    #[test]
    fn test_edge_confidence_half_life() {
        // One half-life halves the confidence; zero age leaves it unchanged.
        assert!((edge_confidence_at(1.0, 90.0, 90.0) - 0.5).abs() < 1e-4);
        assert!((edge_confidence_at(0.8, 0.0, 30.0) - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_beam_search_basic() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        graph.add_node(a);
        graph.add_node(b);
        graph.add_node(c);
        let mut edge_ab = Edge::new(a, b, "caused");
        edge_ab.edge_type = EdgeType::Causal;
        edge_ab.weight = 0.8;
        graph.add_edge_with_meta(edge_ab);
        let mut edge_ac = Edge::new(a, c, "mentioned");
        edge_ac.edge_type = EdgeType::Entity;
        edge_ac.weight = 0.5;
        graph.add_edge_with_meta(edge_ac);
        let results = graph.beam_search(a, "action", 5, 2);
        assert_eq!(results.len(), 2);
        let b_score = results
            .iter()
            .find(|(id, _)| *id == b)
            .map(|(_, s)| *s)
            .unwrap_or(0.0);
        let c_score = results
            .iter()
            .find(|(id, _)| *id == c)
            .map(|(_, s)| *s)
            .unwrap_or(0.0);
        assert!(
            b_score > c_score,
            "Causal edge should score higher for action intent: b={b_score} c={c_score}"
        );
    }

    #[test]
    fn test_beam_search_unknown_start() {
        let graph = MemoryGraph::new();
        assert!(graph.beam_search(Uuid::new_v4(), "recall", 3, 2).is_empty());
    }

    #[test]
    fn test_beam_search_respects_beam_width() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        graph.add_edge(a, Uuid::new_v4(), 1.0);
        graph.add_edge(a, Uuid::new_v4(), 1.0);
        graph.add_edge(a, Uuid::new_v4(), 1.0);
        let results = graph.beam_search(a, "recall", 1, 1);
        assert_eq!(results.len(), 1, "beam width 1 must keep a single node");
    }

    #[test]
    fn test_beam_search_skips_invalidated_edges() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        graph.add_edge(a, b, 1.0);
        graph.add_edge(a, c, 1.0);
        graph.invalidate_edge(a, b, None);
        let results = graph.beam_search(a, "recall", 5, 2);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, c);
    }

    #[test]
    fn test_beam_search_results_sorted_descending() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        let d = Uuid::new_v4();
        graph.add_edge(a, b, 1.0);
        graph.add_edge(b, c, 0.5);
        graph.add_edge(a, d, 0.2);
        let results = graph.beam_search(a, "recall", 5, 3);
        assert_eq!(results.len(), 3);
        assert!(
            results.windows(2).all(|w| w[0].1 >= w[1].1),
            "results must be ordered best first"
        );
    }

    #[test]
    fn test_invalidated_edge_blocks_transitive_traversal() {
        let mut graph = MemoryGraph::new();
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let c = Uuid::new_v4();
        graph.add_edge(a, b, 1.0);
        graph.add_edge(b, c, 1.0);
        assert_eq!(graph.traverse(a, 3).len(), 2);
        graph.invalidate_edge(a, b, None);
        assert!(
            graph.traverse(a, 3).is_empty(),
            "invalidated edge should block transitive traversal"
        );
    }
}