Skip to main content

llm_agent_runtime/
graph.rs

//! # Module: Graph
//!
//! ## Responsibility
//! Provides an in-memory knowledge graph with typed entities and directed,
//! typed relationships. Mirrors the public API of `mem-graph`.
//!
//! ## Guarantees
//! - Thread-safe: `GraphStore` wraps its state in `Arc<Mutex<_>>`
//! - BFS/DFS traversal and shortest-path queries follow edge direction
//! - Non-panicking: all `GraphStore` query and mutation methods return `Result`
//!
//! ## NOT Responsible For
//! - Persistence to disk or to an external store
//! - Graph sharding / distributed graphs
15
16use crate::error::AgentRuntimeError;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20use std::sync::{Arc, Mutex};
21
22// ── Lock recovery helper ───────────────────────────────────────────────────────
23
24/// Recover from a poisoned mutex by logging a warning and returning the inner
25/// value. This is safe because we never leave shared data in a partially-written
26/// state across an await or panic boundary in this module.
27fn recover_lock<'a, T>(
28    result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
29    ctx: &str,
30) -> std::sync::MutexGuard<'a, T>
31where
32    T: ?Sized,
33{
34    match result {
35        Ok(guard) => guard,
36        Err(poisoned) => {
37            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
38            poisoned.into_inner()
39        }
40    }
41}
42
43// ── OrdF32 newtype ─────────────────────────────────────────────────────────────
44
/// Newtype wrapper that gives `f32` a total order so it can live in a
/// `BinaryHeap` (which requires `Ord`).
///
/// Both ordering and equality follow [`f32::total_cmp`] (IEEE 754 `totalOrder`):
/// `-NaN < -inf < ... < -0.0 < +0.0 < ... < +inf < +NaN`. A positive NaN such
/// as `f32::NAN` therefore sorts above every other value, and `NaN == NaN`.
///
/// Equality is derived from the same comparison so the `Eq`/`Ord` contracts
/// hold. A `partial_cmp(..).unwrap_or(..)` fallback would make
/// `NaN.cmp(&NaN)` answer `Greater` in both directions. A derived
/// `PartialEq` would also disagree with `cmp` on NaN and signed zero.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Keep `==` consistent with `Ord::cmp`, as `Eq` requires.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
65
66// ── EntityId ──────────────────────────────────────────────────────────────────
67
68/// Stable identifier for a graph entity.
69#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
70pub struct EntityId(pub String);
71
72impl EntityId {
73    /// Create a new `EntityId` from any string-like value.
74    pub fn new(id: impl Into<String>) -> Self {
75        Self(id.into())
76    }
77}
78
79impl std::fmt::Display for EntityId {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "{}", self.0)
82    }
83}
84
85// ── Entity ────────────────────────────────────────────────────────────────────
86
87/// A node in the knowledge graph.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct Entity {
90    /// Unique identifier.
91    pub id: EntityId,
92    /// Human-readable label (e.g. "Person", "Concept").
93    pub label: String,
94    /// Arbitrary key-value properties.
95    pub properties: HashMap<String, Value>,
96}
97
98impl Entity {
99    /// Construct a new entity with no properties.
100    pub fn new(id: impl Into<String>, label: impl Into<String>) -> Self {
101        Self {
102            id: EntityId::new(id),
103            label: label.into(),
104            properties: HashMap::new(),
105        }
106    }
107
108    /// Construct a new entity with the given properties.
109    pub fn with_properties(
110        id: impl Into<String>,
111        label: impl Into<String>,
112        properties: HashMap<String, Value>,
113    ) -> Self {
114        Self {
115            id: EntityId::new(id),
116            label: label.into(),
117            properties,
118        }
119    }
120}
121
122// ── Relationship ──────────────────────────────────────────────────────────────
123
124/// A directed, typed edge between two entities.
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct Relationship {
127    /// Source entity.
128    pub from: EntityId,
129    /// Target entity.
130    pub to: EntityId,
131    /// Relationship type label (e.g. "KNOWS", "PART_OF").
132    pub kind: String,
133    /// Optional weight for weighted-graph use cases.
134    pub weight: f32,
135}
136
137impl Relationship {
138    /// Construct a new relationship with the given kind and weight.
139    pub fn new(
140        from: impl Into<String>,
141        to: impl Into<String>,
142        kind: impl Into<String>,
143        weight: f32,
144    ) -> Self {
145        Self {
146            from: EntityId::new(from),
147            to: EntityId::new(to),
148            kind: kind.into(),
149            weight,
150        }
151    }
152}
153
154// ── MemGraphError (mirrors upstream) ─────────────────────────────────────────
155
156/// Graph-specific errors, mirrors `mem-graph::MemGraphError`.
157#[derive(Debug, thiserror::Error)]
158pub enum MemGraphError {
159    /// The requested entity was not found.
160    #[error("Entity '{0}' not found")]
161    EntityNotFound(String),
162
163    /// A relationship between the two entities already exists with the same kind.
164    #[error("Relationship '{kind}' from '{from}' to '{to}' already exists")]
165    DuplicateRelationship {
166        /// Source entity ID.
167        from: String,
168        /// Target entity ID.
169        to: String,
170        /// Relationship kind label.
171        kind: String,
172    },
173
174    /// Generic internal error.
175    #[error("Graph internal error: {0}")]
176    Internal(String),
177}
178
179impl From<MemGraphError> for AgentRuntimeError {
180    fn from(e: MemGraphError) -> Self {
181        AgentRuntimeError::Graph(e.to_string())
182    }
183}
184
185// ── GraphStore ────────────────────────────────────────────────────────────────
186
187/// In-memory knowledge graph supporting entities, relationships, BFS/DFS,
188/// shortest-path, weighted shortest-path, and graph analytics.
189///
190/// ## Guarantees
191/// - Thread-safe via `Arc<Mutex<_>>`
192/// - BFS/DFS are non-recursive (stack-safe)
193/// - Shortest-path is hop-count based (BFS)
194/// - Weighted shortest-path uses Dijkstra's algorithm
195#[derive(Debug, Clone)]
196pub struct GraphStore {
197    inner: Arc<Mutex<GraphInner>>,
198}
199
200#[derive(Debug)]
201struct GraphInner {
202    entities: HashMap<EntityId, Entity>,
203    relationships: Vec<Relationship>,
204}
205
206impl GraphStore {
207    /// Create a new, empty graph store.
208    pub fn new() -> Self {
209        Self {
210            inner: Arc::new(Mutex::new(GraphInner {
211                entities: HashMap::new(),
212                relationships: Vec::new(),
213            })),
214        }
215    }
216
217    /// Add an entity to the graph.
218    ///
219    /// If an entity with the same ID already exists, it is replaced.
220    pub fn add_entity(&self, entity: Entity) -> Result<(), AgentRuntimeError> {
221        let mut inner = recover_lock(self.inner.lock(), "add_entity");
222        inner.entities.insert(entity.id.clone(), entity);
223        Ok(())
224    }
225
226    /// Retrieve an entity by ID.
227    pub fn get_entity(&self, id: &EntityId) -> Result<Entity, AgentRuntimeError> {
228        let inner = recover_lock(self.inner.lock(), "get_entity");
229        inner
230            .entities
231            .get(id)
232            .cloned()
233            .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))
234    }
235
236    /// Add a directed relationship between two existing entities.
237    ///
238    /// Both source and target entities must already exist in the graph.
239    pub fn add_relationship(&self, rel: Relationship) -> Result<(), AgentRuntimeError> {
240        let mut inner = recover_lock(self.inner.lock(), "add_relationship");
241
242        if !inner.entities.contains_key(&rel.from) {
243            return Err(AgentRuntimeError::Graph(format!(
244                "source entity '{}' not found",
245                rel.from.0
246            )));
247        }
248        if !inner.entities.contains_key(&rel.to) {
249            return Err(AgentRuntimeError::Graph(format!(
250                "target entity '{}' not found",
251                rel.to.0
252            )));
253        }
254
255        // Reject duplicate (from, to, kind) triples — the DuplicateRelationship
256        // error variant existed but was never raised, silently allowing duplicate
257        // edges that corrupt relationship_count() and BFS/DFS result counts.
258        let duplicate = inner
259            .relationships
260            .iter()
261            .any(|r| r.from == rel.from && r.to == rel.to && r.kind == rel.kind);
262        if duplicate {
263            return Err(AgentRuntimeError::Graph(
264                MemGraphError::DuplicateRelationship {
265                    from: rel.from.0.clone(),
266                    to: rel.to.0.clone(),
267                    kind: rel.kind.clone(),
268                }
269                .to_string(),
270            ));
271        }
272
273        inner.relationships.push(rel);
274        Ok(())
275    }
276
277    /// Remove an entity and all relationships involving it.
278    pub fn remove_entity(&self, id: &EntityId) -> Result<(), AgentRuntimeError> {
279        let mut inner = recover_lock(self.inner.lock(), "remove_entity");
280
281        if inner.entities.remove(id).is_none() {
282            return Err(AgentRuntimeError::Graph(format!(
283                "entity '{}' not found",
284                id.0
285            )));
286        }
287        inner.relationships.retain(|r| &r.from != id && &r.to != id);
288        Ok(())
289    }
290
291    /// Return all direct neighbours of the given entity (BFS layer 1).
292    fn neighbours(relationships: &[Relationship], id: &EntityId) -> Vec<EntityId> {
293        relationships
294            .iter()
295            .filter(|r| &r.from == id)
296            .map(|r| r.to.clone())
297            .collect()
298    }
299
300    /// Breadth-first search starting from `start`.
301    ///
302    /// Returns entity IDs in BFS discovery order (not including the start node).
303    #[tracing::instrument(skip(self))]
304    pub fn bfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
305        let inner = recover_lock(self.inner.lock(), "bfs");
306
307        if !inner.entities.contains_key(start) {
308            return Err(AgentRuntimeError::Graph(format!(
309                "start entity '{}' not found",
310                start.0
311            )));
312        }
313
314        let mut visited: HashSet<EntityId> = HashSet::new();
315        let mut queue: VecDeque<EntityId> = VecDeque::new();
316        let mut result: Vec<EntityId> = Vec::new();
317
318        visited.insert(start.clone());
319        queue.push_back(start.clone());
320
321        while let Some(current) = queue.pop_front() {
322            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
323            for neighbour in neighbours {
324                if visited.insert(neighbour.clone()) {
325                    result.push(neighbour.clone());
326                    queue.push_back(neighbour);
327                }
328            }
329        }
330
331        tracing::debug!("BFS visited {} nodes", result.len());
332        Ok(result)
333    }
334
335    /// Depth-first search starting from `start`.
336    ///
337    /// Returns entity IDs in DFS discovery order (not including the start node).
338    #[tracing::instrument(skip(self))]
339    pub fn dfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
340        let inner = recover_lock(self.inner.lock(), "dfs");
341
342        if !inner.entities.contains_key(start) {
343            return Err(AgentRuntimeError::Graph(format!(
344                "start entity '{}' not found",
345                start.0
346            )));
347        }
348
349        let mut visited: HashSet<EntityId> = HashSet::new();
350        let mut stack: Vec<EntityId> = Vec::new();
351        let mut result: Vec<EntityId> = Vec::new();
352
353        visited.insert(start.clone());
354        stack.push(start.clone());
355
356        while let Some(current) = stack.pop() {
357            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
358            for neighbour in neighbours {
359                if visited.insert(neighbour.clone()) {
360                    result.push(neighbour.clone());
361                    stack.push(neighbour);
362                }
363            }
364        }
365
366        tracing::debug!("DFS visited {} nodes", result.len());
367        Ok(result)
368    }
369
370    /// Find the shortest path (by hop count) between `from` and `to`.
371    ///
372    /// # Returns
373    /// - `Some(path)` — ordered list of `EntityId`s from `from` to `to` (inclusive)
374    /// - `None` — no path exists
375    #[tracing::instrument(skip(self))]
376    pub fn shortest_path(
377        &self,
378        from: &EntityId,
379        to: &EntityId,
380    ) -> Result<Option<Vec<EntityId>>, AgentRuntimeError> {
381        let inner = recover_lock(self.inner.lock(), "shortest_path");
382
383        if !inner.entities.contains_key(from) {
384            return Err(AgentRuntimeError::Graph(format!(
385                "source entity '{}' not found",
386                from.0
387            )));
388        }
389        if !inner.entities.contains_key(to) {
390            return Err(AgentRuntimeError::Graph(format!(
391                "target entity '{}' not found",
392                to.0
393            )));
394        }
395
396        if from == to {
397            return Ok(Some(vec![from.clone()]));
398        }
399
400        let mut visited: HashSet<EntityId> = HashSet::new();
401        let mut queue: VecDeque<Vec<EntityId>> = VecDeque::new();
402
403        visited.insert(from.clone());
404        queue.push_back(vec![from.clone()]);
405
406        while let Some(path) = queue.pop_front() {
407            let current = match path.last() {
408                Some(c) => c.clone(),
409                None => continue,
410            };
411
412            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
413
414            for neighbour in neighbours {
415                if &neighbour == to {
416                    let mut full_path = path.clone();
417                    full_path.push(neighbour);
418                    return Ok(Some(full_path));
419                }
420                if visited.insert(neighbour.clone()) {
421                    let mut new_path = path.clone();
422                    new_path.push(neighbour.clone());
423                    queue.push_back(new_path);
424                }
425            }
426        }
427
428        Ok(None)
429    }
430
431    /// Find the shortest weighted path between `from` and `to` using Dijkstra's algorithm.
432    ///
433    /// Uses `Relationship::weight` as edge cost. Negative weights are not supported
434    /// and will cause this method to return an error.
435    ///
436    /// # Returns
437    /// - `Ok(Some((path, total_weight)))` — the shortest path and its total weight
438    /// - `Ok(None)` — no path exists between `from` and `to`
439    /// - `Err(...)` — either entity not found, or a negative edge weight was encountered
440    pub fn shortest_path_weighted(
441        &self,
442        from: &EntityId,
443        to: &EntityId,
444    ) -> Result<Option<(Vec<EntityId>, f32)>, AgentRuntimeError> {
445        let inner = recover_lock(self.inner.lock(), "shortest_path_weighted");
446
447        if !inner.entities.contains_key(from) {
448            return Err(AgentRuntimeError::Graph(format!(
449                "source entity '{}' not found",
450                from.0
451            )));
452        }
453        if !inner.entities.contains_key(to) {
454            return Err(AgentRuntimeError::Graph(format!(
455                "target entity '{}' not found",
456                to.0
457            )));
458        }
459
460        // Validate: no negative weights
461        for rel in &inner.relationships {
462            if rel.weight < 0.0 {
463                return Err(AgentRuntimeError::Graph(format!(
464                    "negative weight {:.4} on edge '{}' -> '{}'",
465                    rel.weight, rel.from.0, rel.to.0
466                )));
467            }
468        }
469
470        if from == to {
471            return Ok(Some((vec![from.clone()], 0.0)));
472        }
473
474        // Dijkstra using a max-heap with negated weights to simulate a min-heap.
475        // Heap entries: (negated_cost, node_id)
476        let mut dist: HashMap<EntityId, f32> = HashMap::new();
477        let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
478        // BinaryHeap is a max-heap; negate weights to get min-heap behaviour.
479        let mut heap: BinaryHeap<(OrdF32, EntityId)> = BinaryHeap::new();
480
481        dist.insert(from.clone(), 0.0);
482        heap.push((OrdF32(-0.0), from.clone()));
483
484        while let Some((OrdF32(neg_cost), current)) = heap.pop() {
485            let cost = -neg_cost;
486
487            // Skip stale entries.
488            if let Some(&best) = dist.get(&current) {
489                if cost > best {
490                    continue;
491                }
492            }
493
494            if &current == to {
495                // Reconstruct path in reverse.
496                let mut path = vec![to.clone()];
497                let mut node = to.clone();
498                while let Some(p) = prev.get(&node) {
499                    path.push(p.clone());
500                    node = p.clone();
501                }
502                path.reverse();
503                return Ok(Some((path, cost)));
504            }
505
506            for rel in inner.relationships.iter().filter(|r| &r.from == &current) {
507                let next_cost = cost + rel.weight;
508                let entry = dist.entry(rel.to.clone()).or_insert(f32::INFINITY);
509                if next_cost < *entry {
510                    *entry = next_cost;
511                    prev.insert(rel.to.clone(), current.clone());
512                    heap.push((OrdF32(-next_cost), rel.to.clone()));
513                }
514            }
515        }
516
517        Ok(None)
518    }
519
520    /// Compute the transitive closure: all entities reachable from `start`.
521    pub fn transitive_closure(
522        &self,
523        start: &EntityId,
524    ) -> Result<HashSet<EntityId>, AgentRuntimeError> {
525        let reachable = self.bfs(start)?;
526        let mut set: HashSet<EntityId> = reachable.into_iter().collect();
527        set.insert(start.clone());
528        Ok(set)
529    }
530
531    /// Return the number of entities in the graph.
532    pub fn entity_count(&self) -> Result<usize, AgentRuntimeError> {
533        let inner = recover_lock(self.inner.lock(), "entity_count");
534        Ok(inner.entities.len())
535    }
536
537    /// Return the number of relationships in the graph.
538    pub fn relationship_count(&self) -> Result<usize, AgentRuntimeError> {
539        let inner = recover_lock(self.inner.lock(), "relationship_count");
540        Ok(inner.relationships.len())
541    }
542
543    /// Compute normalized degree centrality for each entity.
544    /// Degree = (in_degree + out_degree) / (n - 1), or 0.0 if n <= 1.
545    pub fn degree_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
546        let inner = recover_lock(self.inner.lock(), "degree_centrality");
547        let n = inner.entities.len();
548
549        let mut out_degree: HashMap<EntityId, usize> = HashMap::new();
550        let mut in_degree: HashMap<EntityId, usize> = HashMap::new();
551
552        for id in inner.entities.keys() {
553            out_degree.insert(id.clone(), 0);
554            in_degree.insert(id.clone(), 0);
555        }
556
557        for rel in &inner.relationships {
558            *out_degree.entry(rel.from.clone()).or_insert(0) += 1;
559            *in_degree.entry(rel.to.clone()).or_insert(0) += 1;
560        }
561
562        let denom = if n <= 1 { 1.0 } else { (n - 1) as f32 };
563        let mut result = HashMap::new();
564        for id in inner.entities.keys() {
565            let od = *out_degree.get(id).unwrap_or(&0);
566            let id_ = *in_degree.get(id).unwrap_or(&0);
567            let centrality = if n <= 1 {
568                0.0
569            } else {
570                (od + id_) as f32 / denom
571            };
572            result.insert(id.clone(), centrality);
573        }
574
575        Ok(result)
576    }
577
578    /// Compute normalized betweenness centrality for each entity.
579    /// Uses Brandes' algorithm with hop-count BFS.
580    ///
581    /// # Complexity
582    /// O(V * E) time. Not suitable for very large graphs (>1000 nodes).
583    pub fn betweenness_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
584        let inner = recover_lock(self.inner.lock(), "betweenness_centrality");
585        let n = inner.entities.len();
586        let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
587
588        let mut centrality: HashMap<EntityId, f32> =
589            nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
590
591        for source in &nodes {
592            // BFS to find shortest path counts and predecessors.
593            let mut stack: Vec<EntityId> = Vec::new();
594            let mut predecessors: HashMap<EntityId, Vec<EntityId>> =
595                nodes.iter().map(|id| (id.clone(), vec![])).collect();
596            let mut sigma: HashMap<EntityId, f32> =
597                nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
598            let mut dist: HashMap<EntityId, i64> =
599                nodes.iter().map(|id| (id.clone(), -1i64)).collect();
600
601            *sigma.entry(source.clone()).or_insert(0.0) = 1.0;
602            *dist.entry(source.clone()).or_insert(-1) = 0;
603
604            let mut queue: VecDeque<EntityId> = VecDeque::new();
605            queue.push_back(source.clone());
606
607            while let Some(v) = queue.pop_front() {
608                stack.push(v.clone());
609                let d_v = *dist.get(&v).unwrap_or(&0);
610                let sigma_v = *sigma.get(&v).unwrap_or(&0.0);
611                for rel in inner.relationships.iter().filter(|r| &r.from == &v) {
612                    let w = &rel.to;
613                    let d_w = dist.get(w).copied().unwrap_or(-1);
614                    if d_w < 0 {
615                        queue.push_back(w.clone());
616                        *dist.entry(w.clone()).or_insert(-1) = d_v + 1;
617                    }
618                    if dist.get(w).copied().unwrap_or(-1) == d_v + 1 {
619                        *sigma.entry(w.clone()).or_insert(0.0) += sigma_v;
620                        predecessors.entry(w.clone()).or_default().push(v.clone());
621                    }
622                }
623            }
624
625            let mut delta: HashMap<EntityId, f32> =
626                nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
627
628            while let Some(w) = stack.pop() {
629                let delta_w = *delta.get(&w).unwrap_or(&0.0);
630                let sigma_w = *sigma.get(&w).unwrap_or(&1.0);
631                for v in predecessors.get(&w).cloned().unwrap_or_default() {
632                    let sigma_v = *sigma.get(&v).unwrap_or(&1.0);
633                    let contribution = (sigma_v / sigma_w) * (1.0 + delta_w);
634                    *delta.entry(v.clone()).or_insert(0.0) += contribution;
635                }
636                if &w != source {
637                    *centrality.entry(w.clone()).or_insert(0.0) += delta_w;
638                }
639            }
640        }
641
642        // Normalize by 2 / ((n-1) * (n-2)) for directed graphs.
643        if n > 2 {
644            let norm = 2.0 / (((n - 1) * (n - 2)) as f32);
645            for v in centrality.values_mut() {
646                *v *= norm;
647            }
648        } else {
649            for v in centrality.values_mut() {
650                *v = 0.0;
651            }
652        }
653
654        Ok(centrality)
655    }
656
657    /// Detect communities using label propagation.
658    /// Each entity starts as its own community. In each iteration, each entity
659    /// adopts the most frequent label among its neighbours.
660    /// Returns a map of entity ID → community ID (usize).
661    pub fn label_propagation_communities(
662        &self,
663        max_iterations: usize,
664    ) -> Result<HashMap<EntityId, usize>, AgentRuntimeError> {
665        let inner = recover_lock(self.inner.lock(), "label_propagation_communities");
666        let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
667
668        // Assign each node a unique initial label (index in nodes vec).
669        let mut labels: HashMap<EntityId, usize> = nodes
670            .iter()
671            .enumerate()
672            .map(|(i, id)| (id.clone(), i))
673            .collect();
674
675        for _ in 0..max_iterations {
676            let mut changed = false;
677            // Iterate in a stable order.
678            for node in &nodes {
679                // Collect neighbour labels.
680                let neighbour_labels: Vec<usize> = inner
681                    .relationships
682                    .iter()
683                    .filter(|r| &r.from == node || &r.to == node)
684                    .map(|r| {
685                        if &r.from == node {
686                            labels.get(&r.to).copied().unwrap_or(0)
687                        } else {
688                            labels.get(&r.from).copied().unwrap_or(0)
689                        }
690                    })
691                    .collect();
692
693                if neighbour_labels.is_empty() {
694                    continue;
695                }
696
697                // Find the most frequent label.
698                let mut freq: HashMap<usize, usize> = HashMap::new();
699                for &lbl in &neighbour_labels {
700                    *freq.entry(lbl).or_insert(0) += 1;
701                }
702                let best = freq
703                    .into_iter()
704                    .max_by_key(|&(_, count)| count)
705                    .map(|(lbl, _)| lbl);
706
707                if let Some(new_label) = best {
708                    let current = labels.entry(node.clone()).or_insert(0);
709                    if *current != new_label {
710                        *current = new_label;
711                        changed = true;
712                    }
713                }
714            }
715
716            if !changed {
717                break;
718            }
719        }
720
721        Ok(labels)
722    }
723
724    /// Extract a subgraph containing only the specified entities and the
725    /// relationships between them.
726    pub fn subgraph(&self, node_ids: &[EntityId]) -> Result<GraphStore, AgentRuntimeError> {
727        let inner = recover_lock(self.inner.lock(), "subgraph");
728        let id_set: HashSet<&EntityId> = node_ids.iter().collect();
729
730        let new_store = GraphStore::new();
731
732        for id in node_ids {
733            let entity = inner
734                .entities
735                .get(id)
736                .ok_or_else(|| {
737                    AgentRuntimeError::Graph(format!("entity '{}' not found", id.0))
738                })?
739                .clone();
740            // We hold inner lock; call directly on the new store's inner.
741            let mut new_inner = recover_lock(new_store.inner.lock(), "subgraph:add_entity");
742            new_inner.entities.insert(entity.id.clone(), entity);
743        }
744
745        for rel in inner.relationships.iter() {
746            if id_set.contains(&rel.from) && id_set.contains(&rel.to) {
747                let mut new_inner =
748                    recover_lock(new_store.inner.lock(), "subgraph:add_relationship");
749                new_inner.relationships.push(rel.clone());
750            }
751        }
752
753        Ok(new_store)
754    }
755}
756
757impl Default for GraphStore {
758    fn default() -> Self {
759        Self::new()
760    }
761}
762
763// ── Tests ─────────────────────────────────────────────────────────────────────
764
765#[cfg(test)]
766mod tests {
767    use super::*;
768
769    fn make_graph() -> GraphStore {
770        GraphStore::new()
771    }
772
773    fn add(g: &GraphStore, id: &str) {
774        g.add_entity(Entity::new(id, "Node")).unwrap();
775    }
776
777    fn link(g: &GraphStore, from: &str, to: &str) {
778        g.add_relationship(Relationship::new(from, to, "CONNECTS", 1.0))
779            .unwrap();
780    }
781
782    fn link_w(g: &GraphStore, from: &str, to: &str, weight: f32) {
783        g.add_relationship(Relationship::new(from, to, "CONNECTS", weight))
784            .unwrap();
785    }
786
787    // ── EntityId ──────────────────────────────────────────────────────────────
788
789    #[test]
790    fn test_entity_id_equality() {
791        assert_eq!(EntityId::new("a"), EntityId::new("a"));
792        assert_ne!(EntityId::new("a"), EntityId::new("b"));
793    }
794
795    #[test]
796    fn test_entity_id_display() {
797        let id = EntityId::new("hello");
798        assert_eq!(id.to_string(), "hello");
799    }
800
801    // ── Entity ────────────────────────────────────────────────────────────────
802
803    #[test]
804    fn test_entity_new_has_empty_properties() {
805        let e = Entity::new("e1", "Person");
806        assert!(e.properties.is_empty());
807    }
808
809    #[test]
810    fn test_entity_with_properties_stores_props() {
811        let mut props = HashMap::new();
812        props.insert("age".into(), Value::Number(42.into()));
813        let e = Entity::with_properties("e1", "Person", props);
814        assert!(e.properties.contains_key("age"));
815    }
816
817    // ── GraphStore basic ops ──────────────────────────────────────────────────
818
819    #[test]
820    fn test_graph_add_entity_increments_count() {
821        let g = make_graph();
822        add(&g, "a");
823        assert_eq!(g.entity_count().unwrap(), 1);
824    }
825
826    #[test]
827    fn test_graph_get_entity_returns_entity() {
828        let g = make_graph();
829        g.add_entity(Entity::new("e1", "Person")).unwrap();
830        let e = g.get_entity(&EntityId::new("e1")).unwrap();
831        assert_eq!(e.label, "Person");
832    }
833
834    #[test]
835    fn test_graph_get_entity_missing_returns_error() {
836        let g = make_graph();
837        assert!(g.get_entity(&EntityId::new("ghost")).is_err());
838    }
839
840    #[test]
841    fn test_graph_add_relationship_increments_count() {
842        let g = make_graph();
843        add(&g, "a");
844        add(&g, "b");
845        link(&g, "a", "b");
846        assert_eq!(g.relationship_count().unwrap(), 1);
847    }
848
849    #[test]
850    fn test_graph_add_relationship_missing_source_fails() {
851        let g = make_graph();
852        add(&g, "b");
853        let result = g.add_relationship(Relationship::new("ghost", "b", "X", 1.0));
854        assert!(result.is_err());
855    }
856
857    #[test]
858    fn test_graph_add_relationship_missing_target_fails() {
859        let g = make_graph();
860        add(&g, "a");
861        let result = g.add_relationship(Relationship::new("a", "ghost", "X", 1.0));
862        assert!(result.is_err());
863    }
864
865    #[test]
866    fn test_graph_remove_entity_removes_relationships() {
867        let g = make_graph();
868        add(&g, "a");
869        add(&g, "b");
870        link(&g, "a", "b");
871        g.remove_entity(&EntityId::new("a")).unwrap();
872        assert_eq!(g.entity_count().unwrap(), 1);
873        assert_eq!(g.relationship_count().unwrap(), 0);
874    }
875
876    #[test]
877    fn test_graph_remove_entity_missing_returns_error() {
878        let g = make_graph();
879        assert!(g.remove_entity(&EntityId::new("ghost")).is_err());
880    }
881
882    // ── BFS ───────────────────────────────────────────────────────────────────
883
884    #[test]
885    fn test_bfs_finds_direct_neighbours() {
886        let g = make_graph();
887        add(&g, "a");
888        add(&g, "b");
889        add(&g, "c");
890        link(&g, "a", "b");
891        link(&g, "a", "c");
892        let visited = g.bfs(&EntityId::new("a")).unwrap();
893        assert_eq!(visited.len(), 2);
894    }
895
896    #[test]
897    fn test_bfs_traverses_chain() {
898        let g = make_graph();
899        add(&g, "a");
900        add(&g, "b");
901        add(&g, "c");
902        add(&g, "d");
903        link(&g, "a", "b");
904        link(&g, "b", "c");
905        link(&g, "c", "d");
906        let visited = g.bfs(&EntityId::new("a")).unwrap();
907        assert_eq!(visited.len(), 3);
908        assert_eq!(visited[0], EntityId::new("b"));
909    }
910
911    #[test]
912    fn test_bfs_handles_isolated_node() {
913        let g = make_graph();
914        add(&g, "a");
915        let visited = g.bfs(&EntityId::new("a")).unwrap();
916        assert!(visited.is_empty());
917    }
918
919    #[test]
920    fn test_bfs_missing_start_returns_error() {
921        let g = make_graph();
922        assert!(g.bfs(&EntityId::new("ghost")).is_err());
923    }
924
925    // ── DFS ───────────────────────────────────────────────────────────────────
926
927    #[test]
928    fn test_dfs_visits_all_reachable_nodes() {
929        let g = make_graph();
930        add(&g, "a");
931        add(&g, "b");
932        add(&g, "c");
933        add(&g, "d");
934        link(&g, "a", "b");
935        link(&g, "a", "c");
936        link(&g, "b", "d");
937        let visited = g.dfs(&EntityId::new("a")).unwrap();
938        assert_eq!(visited.len(), 3);
939    }
940
941    #[test]
942    fn test_dfs_handles_isolated_node() {
943        let g = make_graph();
944        add(&g, "a");
945        let visited = g.dfs(&EntityId::new("a")).unwrap();
946        assert!(visited.is_empty());
947    }
948
949    #[test]
950    fn test_dfs_missing_start_returns_error() {
951        let g = make_graph();
952        assert!(g.dfs(&EntityId::new("ghost")).is_err());
953    }
954
955    // ── Shortest path ─────────────────────────────────────────────────────────
956
957    #[test]
958    fn test_shortest_path_direct_connection() {
959        let g = make_graph();
960        add(&g, "a");
961        add(&g, "b");
962        link(&g, "a", "b");
963        let path = g
964            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
965            .unwrap();
966        assert_eq!(path, Some(vec![EntityId::new("a"), EntityId::new("b")]));
967    }
968
969    #[test]
970    fn test_shortest_path_multi_hop() {
971        let g = make_graph();
972        add(&g, "a");
973        add(&g, "b");
974        add(&g, "c");
975        link(&g, "a", "b");
976        link(&g, "b", "c");
977        let path = g
978            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
979            .unwrap();
980        assert_eq!(path.as_ref().map(|p| p.len()), Some(3));
981    }
982
983    #[test]
984    fn test_shortest_path_returns_none_for_disconnected() {
985        let g = make_graph();
986        add(&g, "a");
987        add(&g, "b");
988        let path = g
989            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
990            .unwrap();
991        assert_eq!(path, None);
992    }
993
994    #[test]
995    fn test_shortest_path_same_node_returns_single_element() {
996        let g = make_graph();
997        add(&g, "a");
998        let path = g
999            .shortest_path(&EntityId::new("a"), &EntityId::new("a"))
1000            .unwrap();
1001        assert_eq!(path, Some(vec![EntityId::new("a")]));
1002    }
1003
1004    #[test]
1005    fn test_shortest_path_missing_source_returns_error() {
1006        let g = make_graph();
1007        add(&g, "b");
1008        assert!(g
1009            .shortest_path(&EntityId::new("ghost"), &EntityId::new("b"))
1010            .is_err());
1011    }
1012
1013    #[test]
1014    fn test_shortest_path_missing_target_returns_error() {
1015        let g = make_graph();
1016        add(&g, "a");
1017        assert!(g
1018            .shortest_path(&EntityId::new("a"), &EntityId::new("ghost"))
1019            .is_err());
1020    }
1021
1022    // ── Transitive closure ────────────────────────────────────────────────────
1023
1024    #[test]
1025    fn test_transitive_closure_includes_start() {
1026        let g = make_graph();
1027        add(&g, "a");
1028        add(&g, "b");
1029        link(&g, "a", "b");
1030        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
1031        assert!(closure.contains(&EntityId::new("a")));
1032        assert!(closure.contains(&EntityId::new("b")));
1033    }
1034
1035    #[test]
1036    fn test_transitive_closure_isolated_node_contains_only_self() {
1037        let g = make_graph();
1038        add(&g, "a");
1039        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
1040        assert_eq!(closure.len(), 1);
1041    }
1042
1043    // ── MemGraphError conversion ──────────────────────────────────────────────
1044
1045    #[test]
1046    fn test_mem_graph_error_converts_to_runtime_error() {
1047        let e = MemGraphError::EntityNotFound("x".into());
1048        let re: AgentRuntimeError = e.into();
1049        assert!(matches!(re, AgentRuntimeError::Graph(_)));
1050    }
1051
1052    // ── Weighted shortest path ────────────────────────────────────────────────
1053
1054    #[test]
1055    fn test_shortest_path_weighted_simple() {
1056        // a --(1.0)--> b --(2.0)--> c
1057        // a --(10.0)--> c  (direct but heavier)
1058        let g = make_graph();
1059        add(&g, "a");
1060        add(&g, "b");
1061        add(&g, "c");
1062        link_w(&g, "a", "b", 1.0);
1063        link_w(&g, "b", "c", 2.0);
1064        g.add_relationship(Relationship::new("a", "c", "DIRECT", 10.0))
1065            .unwrap();
1066
1067        let result = g
1068            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("c"))
1069            .unwrap();
1070        assert!(result.is_some());
1071        let (path, weight) = result.unwrap();
1072        // The cheapest path is a -> b -> c with total weight 3.0
1073        assert_eq!(
1074            path,
1075            vec![EntityId::new("a"), EntityId::new("b"), EntityId::new("c")]
1076        );
1077        assert!((weight - 3.0).abs() < 1e-5);
1078    }
1079
1080    #[test]
1081    fn test_shortest_path_weighted_returns_none_for_disconnected() {
1082        let g = make_graph();
1083        add(&g, "a");
1084        add(&g, "b");
1085        let result = g
1086            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"))
1087            .unwrap();
1088        assert!(result.is_none());
1089    }
1090
1091    #[test]
1092    fn test_shortest_path_weighted_same_node() {
1093        let g = make_graph();
1094        add(&g, "a");
1095        let result = g
1096            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("a"))
1097            .unwrap();
1098        assert_eq!(result, Some((vec![EntityId::new("a")], 0.0)));
1099    }
1100
1101    #[test]
1102    fn test_shortest_path_weighted_negative_weight_errors() {
1103        let g = make_graph();
1104        add(&g, "a");
1105        add(&g, "b");
1106        g.add_relationship(Relationship::new("a", "b", "NEG", -1.0))
1107            .unwrap();
1108        let result = g.shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"));
1109        assert!(result.is_err());
1110    }
1111
1112    // ── Degree centrality ─────────────────────────────────────────────────────
1113
1114    #[test]
1115    fn test_degree_centrality_basic() {
1116        // Star graph: a -> b, a -> c, a -> d
1117        // a: out=3, in=0 => (3+0)/(4-1) = 1.0
1118        // b: out=0, in=1 => (0+1)/3 = 0.333...
1119        let g = make_graph();
1120        add(&g, "a");
1121        add(&g, "b");
1122        add(&g, "c");
1123        add(&g, "d");
1124        link(&g, "a", "b");
1125        link(&g, "a", "c");
1126        link(&g, "a", "d");
1127
1128        let centrality = g.degree_centrality().unwrap();
1129        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
1130        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
1131
1132        assert!((a_cent - 1.0).abs() < 1e-5, "a centrality was {a_cent}");
1133        assert!(
1134            (b_cent - 1.0 / 3.0).abs() < 1e-5,
1135            "b centrality was {b_cent}"
1136        );
1137    }
1138
1139    // ── Betweenness centrality ────────────────────────────────────────────────
1140
1141    #[test]
1142    fn test_betweenness_centrality_chain() {
1143        // Chain: a -> b -> c -> d
1144        // b and c are on all paths from a to c, a to d, b to d
1145        // Node b and c should have higher centrality than a and d.
1146        let g = make_graph();
1147        add(&g, "a");
1148        add(&g, "b");
1149        add(&g, "c");
1150        add(&g, "d");
1151        link(&g, "a", "b");
1152        link(&g, "b", "c");
1153        link(&g, "c", "d");
1154
1155        let centrality = g.betweenness_centrality().unwrap();
1156        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
1157        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
1158        let c_cent = *centrality.get(&EntityId::new("c")).unwrap();
1159        let d_cent = *centrality.get(&EntityId::new("d")).unwrap();
1160
1161        // Endpoints should have 0 centrality; intermediate nodes should be > 0.
1162        assert!(
1163            (a_cent).abs() < 1e-5,
1164            "expected a_cent ~ 0, got {a_cent}"
1165        );
1166        assert!(b_cent > 0.0, "expected b_cent > 0, got {b_cent}");
1167        assert!(c_cent > 0.0, "expected c_cent > 0, got {c_cent}");
1168        assert!(
1169            (d_cent).abs() < 1e-5,
1170            "expected d_cent ~ 0, got {d_cent}"
1171        );
1172    }
1173
1174    // ── Label propagation communities ─────────────────────────────────────────
1175
1176    #[test]
1177    fn test_label_propagation_communities_two_clusters() {
1178        // Cluster 1: a <-> b <-> c (fully connected via bidirectional edges)
1179        // Cluster 2: x <-> y <-> z
1180        // No edges between clusters.
1181        let g = make_graph();
1182        for id in &["a", "b", "c", "x", "y", "z"] {
1183            add(&g, id);
1184        }
1185        // Cluster 1 (bidirectional via two directed edges each)
1186        link(&g, "a", "b");
1187        link(&g, "b", "a");
1188        link(&g, "b", "c");
1189        link(&g, "c", "b");
1190        link(&g, "a", "c");
1191        link(&g, "c", "a");
1192        // Cluster 2
1193        link(&g, "x", "y");
1194        link(&g, "y", "x");
1195        link(&g, "y", "z");
1196        link(&g, "z", "y");
1197        link(&g, "x", "z");
1198        link(&g, "z", "x");
1199
1200        let communities = g.label_propagation_communities(100).unwrap();
1201
1202        let label_a = communities[&EntityId::new("a")];
1203        let label_b = communities[&EntityId::new("b")];
1204        let label_c = communities[&EntityId::new("c")];
1205        let label_x = communities[&EntityId::new("x")];
1206        let label_y = communities[&EntityId::new("y")];
1207        let label_z = communities[&EntityId::new("z")];
1208
1209        // All nodes in cluster 1 share a label, all in cluster 2 share a label,
1210        // and the two clusters have different labels.
1211        assert_eq!(label_a, label_b, "a and b should be in same community");
1212        assert_eq!(label_b, label_c, "b and c should be in same community");
1213        assert_eq!(label_x, label_y, "x and y should be in same community");
1214        assert_eq!(label_y, label_z, "y and z should be in same community");
1215        assert_ne!(
1216            label_a, label_x,
1217            "cluster 1 and cluster 2 should be different communities"
1218        );
1219    }
1220
1221    // ── Subgraph extraction ───────────────────────────────────────────────────
1222
1223    #[test]
1224    fn test_subgraph_extracts_correct_nodes_and_edges() {
1225        // Full graph: a -> b -> c -> d
1226        // Subgraph of {a, b, c} should contain edges a->b and b->c but not c->d.
1227        let g = make_graph();
1228        add(&g, "a");
1229        add(&g, "b");
1230        add(&g, "c");
1231        add(&g, "d");
1232        link(&g, "a", "b");
1233        link(&g, "b", "c");
1234        link(&g, "c", "d");
1235
1236        let sub = g
1237            .subgraph(&[
1238                EntityId::new("a"),
1239                EntityId::new("b"),
1240                EntityId::new("c"),
1241            ])
1242            .unwrap();
1243
1244        assert_eq!(sub.entity_count().unwrap(), 3);
1245        assert_eq!(sub.relationship_count().unwrap(), 2);
1246
1247        // d should not be present in the subgraph.
1248        assert!(sub.get_entity(&EntityId::new("d")).is_err());
1249
1250        // a -> b and b -> c should be present; c -> d should not.
1251        let path = sub
1252            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
1253            .unwrap();
1254        assert!(path.is_some());
1255        assert_eq!(path.unwrap().len(), 3);
1256    }
1257}