Skip to main content

llm_agent_runtime/
graph.rs

1//! # Module: Graph
2//!
3//! ## Responsibility
4//! Provides an in-memory knowledge graph with typed entities and relationships.
5//! Mirrors the public API of `mem-graph`.
6//!
7//! ## Guarantees
8//! - Thread-safe: `GraphStore` wraps state in `Arc<Mutex<_>>`
9//! - BFS/DFS traversal and shortest-path are correct for directed graphs
10//! - Non-panicking: all operations return `Result`
11//!
12//! ## NOT Responsible For
13//! - Persistence to disk or external store
14//! - Graph sharding / distributed graphs
15
16use crate::error::AgentRuntimeError;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20use std::sync::{Arc, Mutex};
21
22// ── Lock recovery helper ───────────────────────────────────────────────────────
23
24/// Recover from a poisoned mutex by logging a warning and returning the inner
25/// value. This is safe because we never leave shared data in a partially-written
26/// state across an await or panic boundary in this module.
27fn recover_lock<'a, T>(
28    result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
29    ctx: &str,
30) -> std::sync::MutexGuard<'a, T>
31where
32    T: ?Sized,
33{
34    match result {
35        Ok(guard) => guard,
36        Err(poisoned) => {
37            tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
38            poisoned.into_inner()
39        }
40    }
41}
42
43// ── OrdF32 newtype ─────────────────────────────────────────────────────────────
44
/// Newtype wrapper for `f32` that implements a genuine total order (`Ord`).
///
/// NaN sorts greater than every non-NaN value and equal to any other NaN, so
/// the ordering is antisymmetric and transitive as `BinaryHeap` requires.
/// Equality is defined through the same ordering, which keeps `PartialEq`,
/// `Eq` and `Ord` mutually consistent (`0.0 == -0.0`, `NaN == NaN`).
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (self.0.is_nan(), other.0.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither operand is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => self.0.partial_cmp(&other.0).unwrap_or(Ordering::Equal),
        }
    }
}
65
66// ── EntityId ──────────────────────────────────────────────────────────────────
67
68/// Stable identifier for a graph entity.
69#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
70pub struct EntityId(pub String);
71
72impl EntityId {
73    /// Create a new `EntityId` from any string-like value.
74    pub fn new(id: impl Into<String>) -> Self {
75        Self(id.into())
76    }
77}
78
79impl std::fmt::Display for EntityId {
80    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81        write!(f, "{}", self.0)
82    }
83}
84
85// ── Entity ────────────────────────────────────────────────────────────────────
86
87/// A node in the knowledge graph.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct Entity {
90    /// Unique identifier.
91    pub id: EntityId,
92    /// Human-readable label (e.g. "Person", "Concept").
93    pub label: String,
94    /// Arbitrary key-value properties.
95    pub properties: HashMap<String, Value>,
96}
97
98impl Entity {
99    /// Construct a new entity with no properties.
100    pub fn new(id: impl Into<String>, label: impl Into<String>) -> Self {
101        Self {
102            id: EntityId::new(id),
103            label: label.into(),
104            properties: HashMap::new(),
105        }
106    }
107
108    /// Construct a new entity with the given properties.
109    pub fn with_properties(
110        id: impl Into<String>,
111        label: impl Into<String>,
112        properties: HashMap<String, Value>,
113    ) -> Self {
114        Self {
115            id: EntityId::new(id),
116            label: label.into(),
117            properties,
118        }
119    }
120}
121
122// ── Relationship ──────────────────────────────────────────────────────────────
123
124/// A directed, typed edge between two entities.
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct Relationship {
127    /// Source entity.
128    pub from: EntityId,
129    /// Target entity.
130    pub to: EntityId,
131    /// Relationship type label (e.g. "KNOWS", "PART_OF").
132    pub kind: String,
133    /// Optional weight for weighted-graph use cases.
134    pub weight: f32,
135}
136
137impl Relationship {
138    /// Construct a new relationship with the given kind and weight.
139    pub fn new(
140        from: impl Into<String>,
141        to: impl Into<String>,
142        kind: impl Into<String>,
143        weight: f32,
144    ) -> Self {
145        Self {
146            from: EntityId::new(from),
147            to: EntityId::new(to),
148            kind: kind.into(),
149            weight,
150        }
151    }
152}
153
154// ── MemGraphError (mirrors upstream) ─────────────────────────────────────────
155
/// Graph-specific errors, mirroring `mem-graph::MemGraphError`.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. Within the visible part of this module only
/// `DuplicateRelationship` is constructed (when a duplicate edge is added).
#[derive(Debug, thiserror::Error)]
pub enum MemGraphError {
    /// The requested entity was not found. The payload is the entity ID.
    #[error("Entity '{0}' not found")]
    EntityNotFound(String),

    /// A relationship with the same source, target and kind already exists.
    #[error("Relationship '{kind}' from '{from}' to '{to}' already exists")]
    DuplicateRelationship {
        /// Source entity ID.
        from: String,
        /// Target entity ID.
        to: String,
        /// Relationship kind label.
        kind: String,
    },

    /// Generic internal error. The payload is a free-form description.
    #[error("Graph internal error: {0}")]
    Internal(String),
}
178
179impl From<MemGraphError> for AgentRuntimeError {
180    fn from(e: MemGraphError) -> Self {
181        AgentRuntimeError::Graph(e.to_string())
182    }
183}
184
185// ── GraphStore ────────────────────────────────────────────────────────────────
186
/// In-memory knowledge graph supporting entities, relationships, BFS/DFS,
/// shortest-path, weighted shortest-path, and graph analytics.
///
/// Cloning a `GraphStore` clones the `Arc`, so every clone refers to the same
/// underlying graph state.
///
/// ## Guarantees
/// - Thread-safe via `Arc<Mutex<_>>`
/// - BFS/DFS are non-recursive (stack-safe)
/// - Shortest-path is hop-count based (BFS)
/// - Weighted shortest-path uses Dijkstra's algorithm
#[derive(Debug, Clone)]
pub struct GraphStore {
    // Shared, mutex-protected graph state; every method locks it for the
    // duration of the call.
    inner: Arc<Mutex<GraphInner>>,
}
199
/// Mutable graph state guarded by the mutex inside `GraphStore`.
#[derive(Debug)]
struct GraphInner {
    // Entities keyed by their ID.
    entities: HashMap<EntityId, Entity>,
    // Directed edges, kept in insertion order as a flat list.
    relationships: Vec<Relationship>,
}
205
206impl GraphStore {
207    /// Create a new, empty graph store.
208    pub fn new() -> Self {
209        Self {
210            inner: Arc::new(Mutex::new(GraphInner {
211                entities: HashMap::new(),
212                relationships: Vec::new(),
213            })),
214        }
215    }
216
217    /// Add an entity to the graph.
218    ///
219    /// If an entity with the same ID already exists, it is replaced.
220    pub fn add_entity(&self, entity: Entity) -> Result<(), AgentRuntimeError> {
221        let mut inner = recover_lock(self.inner.lock(), "add_entity");
222        inner.entities.insert(entity.id.clone(), entity);
223        Ok(())
224    }
225
226    /// Retrieve an entity by ID.
227    pub fn get_entity(&self, id: &EntityId) -> Result<Entity, AgentRuntimeError> {
228        let inner = recover_lock(self.inner.lock(), "get_entity");
229        inner
230            .entities
231            .get(id)
232            .cloned()
233            .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))
234    }
235
236    /// Add a directed relationship between two existing entities.
237    ///
238    /// Both source and target entities must already exist in the graph.
239    pub fn add_relationship(&self, rel: Relationship) -> Result<(), AgentRuntimeError> {
240        let mut inner = recover_lock(self.inner.lock(), "add_relationship");
241
242        if !inner.entities.contains_key(&rel.from) {
243            return Err(AgentRuntimeError::Graph(format!(
244                "source entity '{}' not found",
245                rel.from.0
246            )));
247        }
248        if !inner.entities.contains_key(&rel.to) {
249            return Err(AgentRuntimeError::Graph(format!(
250                "target entity '{}' not found",
251                rel.to.0
252            )));
253        }
254
255        // Reject duplicate (from, to, kind) triples — the DuplicateRelationship
256        // error variant existed but was never raised, silently allowing duplicate
257        // edges that corrupt relationship_count() and BFS/DFS result counts.
258        let duplicate = inner
259            .relationships
260            .iter()
261            .any(|r| r.from == rel.from && r.to == rel.to && r.kind == rel.kind);
262        if duplicate {
263            return Err(AgentRuntimeError::Graph(
264                MemGraphError::DuplicateRelationship {
265                    from: rel.from.0.clone(),
266                    to: rel.to.0.clone(),
267                    kind: rel.kind.clone(),
268                }
269                .to_string(),
270            ));
271        }
272
273        inner.relationships.push(rel);
274        Ok(())
275    }
276
277    /// Remove an entity and all relationships involving it.
278    pub fn remove_entity(&self, id: &EntityId) -> Result<(), AgentRuntimeError> {
279        let mut inner = recover_lock(self.inner.lock(), "remove_entity");
280
281        if inner.entities.remove(id).is_none() {
282            return Err(AgentRuntimeError::Graph(format!(
283                "entity '{}' not found",
284                id.0
285            )));
286        }
287        inner.relationships.retain(|r| &r.from != id && &r.to != id);
288        Ok(())
289    }
290
291    /// Return all direct neighbours of the given entity (BFS layer 1).
292    fn neighbours(relationships: &[Relationship], id: &EntityId) -> Vec<EntityId> {
293        relationships
294            .iter()
295            .filter(|r| &r.from == id)
296            .map(|r| r.to.clone())
297            .collect()
298    }
299
300    /// Breadth-first search starting from `start`.
301    ///
302    /// Returns entity IDs in BFS discovery order (not including the start node).
303    #[tracing::instrument(skip(self))]
304    pub fn bfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
305        let inner = recover_lock(self.inner.lock(), "bfs");
306
307        if !inner.entities.contains_key(start) {
308            return Err(AgentRuntimeError::Graph(format!(
309                "start entity '{}' not found",
310                start.0
311            )));
312        }
313
314        let mut visited: HashSet<EntityId> = HashSet::new();
315        let mut queue: VecDeque<EntityId> = VecDeque::new();
316        let mut result: Vec<EntityId> = Vec::new();
317
318        visited.insert(start.clone());
319        queue.push_back(start.clone());
320
321        while let Some(current) = queue.pop_front() {
322            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
323            for neighbour in neighbours {
324                if visited.insert(neighbour.clone()) {
325                    result.push(neighbour.clone());
326                    queue.push_back(neighbour);
327                }
328            }
329        }
330
331        tracing::debug!("BFS visited {} nodes", result.len());
332        Ok(result)
333    }
334
335    /// Depth-first search starting from `start`.
336    ///
337    /// Returns entity IDs in DFS discovery order (not including the start node).
338    #[tracing::instrument(skip(self))]
339    pub fn dfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
340        let inner = recover_lock(self.inner.lock(), "dfs");
341
342        if !inner.entities.contains_key(start) {
343            return Err(AgentRuntimeError::Graph(format!(
344                "start entity '{}' not found",
345                start.0
346            )));
347        }
348
349        let mut visited: HashSet<EntityId> = HashSet::new();
350        let mut stack: Vec<EntityId> = Vec::new();
351        let mut result: Vec<EntityId> = Vec::new();
352
353        visited.insert(start.clone());
354        stack.push(start.clone());
355
356        while let Some(current) = stack.pop() {
357            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
358            for neighbour in neighbours {
359                if visited.insert(neighbour.clone()) {
360                    result.push(neighbour.clone());
361                    stack.push(neighbour);
362                }
363            }
364        }
365
366        tracing::debug!("DFS visited {} nodes", result.len());
367        Ok(result)
368    }
369
370    /// Find the shortest path (by hop count) between `from` and `to`.
371    ///
372    /// # Returns
373    /// - `Some(path)` — ordered list of `EntityId`s from `from` to `to` (inclusive)
374    /// - `None` — no path exists
375    #[tracing::instrument(skip(self))]
376    pub fn shortest_path(
377        &self,
378        from: &EntityId,
379        to: &EntityId,
380    ) -> Result<Option<Vec<EntityId>>, AgentRuntimeError> {
381        let inner = recover_lock(self.inner.lock(), "shortest_path");
382
383        if !inner.entities.contains_key(from) {
384            return Err(AgentRuntimeError::Graph(format!(
385                "source entity '{}' not found",
386                from.0
387            )));
388        }
389        if !inner.entities.contains_key(to) {
390            return Err(AgentRuntimeError::Graph(format!(
391                "target entity '{}' not found",
392                to.0
393            )));
394        }
395
396        if from == to {
397            return Ok(Some(vec![from.clone()]));
398        }
399
400        let mut visited: HashSet<EntityId> = HashSet::new();
401        let mut queue: VecDeque<Vec<EntityId>> = VecDeque::new();
402
403        visited.insert(from.clone());
404        queue.push_back(vec![from.clone()]);
405
406        while let Some(path) = queue.pop_front() {
407            let current = match path.last() {
408                Some(c) => c.clone(),
409                None => continue,
410            };
411
412            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
413
414            for neighbour in neighbours {
415                if &neighbour == to {
416                    let mut full_path = path.clone();
417                    full_path.push(neighbour);
418                    return Ok(Some(full_path));
419                }
420                if visited.insert(neighbour.clone()) {
421                    let mut new_path = path.clone();
422                    new_path.push(neighbour.clone());
423                    queue.push_back(new_path);
424                }
425            }
426        }
427
428        Ok(None)
429    }
430
431    /// Find the shortest weighted path between `from` and `to` using Dijkstra's algorithm.
432    ///
433    /// Uses `Relationship::weight` as edge cost. Negative weights are not supported
434    /// and will cause this method to return an error.
435    ///
436    /// # Returns
437    /// - `Ok(Some((path, total_weight)))` — the shortest path and its total weight
438    /// - `Ok(None)` — no path exists between `from` and `to`
439    /// - `Err(...)` — either entity not found, or a negative edge weight was encountered
440    pub fn shortest_path_weighted(
441        &self,
442        from: &EntityId,
443        to: &EntityId,
444    ) -> Result<Option<(Vec<EntityId>, f32)>, AgentRuntimeError> {
445        let inner = recover_lock(self.inner.lock(), "shortest_path_weighted");
446
447        if !inner.entities.contains_key(from) {
448            return Err(AgentRuntimeError::Graph(format!(
449                "source entity '{}' not found",
450                from.0
451            )));
452        }
453        if !inner.entities.contains_key(to) {
454            return Err(AgentRuntimeError::Graph(format!(
455                "target entity '{}' not found",
456                to.0
457            )));
458        }
459
460        // Validate: no negative weights
461        for rel in &inner.relationships {
462            if rel.weight < 0.0 {
463                return Err(AgentRuntimeError::Graph(format!(
464                    "negative weight {:.4} on edge '{}' -> '{}'",
465                    rel.weight, rel.from.0, rel.to.0
466                )));
467            }
468        }
469
470        if from == to {
471            return Ok(Some((vec![from.clone()], 0.0)));
472        }
473
474        // Dijkstra using a max-heap with negated weights to simulate a min-heap.
475        // Heap entries: (negated_cost, node_id)
476        let mut dist: HashMap<EntityId, f32> = HashMap::new();
477        let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
478        // BinaryHeap is a max-heap; negate weights to get min-heap behaviour.
479        let mut heap: BinaryHeap<(OrdF32, EntityId)> = BinaryHeap::new();
480
481        dist.insert(from.clone(), 0.0);
482        heap.push((OrdF32(-0.0), from.clone()));
483
484        while let Some((OrdF32(neg_cost), current)) = heap.pop() {
485            let cost = -neg_cost;
486
487            // Skip stale entries.
488            if let Some(&best) = dist.get(&current) {
489                if cost > best {
490                    continue;
491                }
492            }
493
494            if &current == to {
495                // Reconstruct path in reverse.
496                let mut path = vec![to.clone()];
497                let mut node = to.clone();
498                while let Some(p) = prev.get(&node) {
499                    path.push(p.clone());
500                    node = p.clone();
501                }
502                path.reverse();
503                return Ok(Some((path, cost)));
504            }
505
506            for rel in inner.relationships.iter().filter(|r| &r.from == &current) {
507                let next_cost = cost + rel.weight;
508                let entry = dist.entry(rel.to.clone()).or_insert(f32::INFINITY);
509                if next_cost < *entry {
510                    *entry = next_cost;
511                    prev.insert(rel.to.clone(), current.clone());
512                    heap.push((OrdF32(-next_cost), rel.to.clone()));
513                }
514            }
515        }
516
517        Ok(None)
518    }
519
520    /// Compute the transitive closure: all entities reachable from `start`.
521    pub fn transitive_closure(
522        &self,
523        start: &EntityId,
524    ) -> Result<HashSet<EntityId>, AgentRuntimeError> {
525        let reachable = self.bfs(start)?;
526        let mut set: HashSet<EntityId> = reachable.into_iter().collect();
527        set.insert(start.clone());
528        Ok(set)
529    }
530
531    /// Return the number of entities in the graph.
532    pub fn entity_count(&self) -> Result<usize, AgentRuntimeError> {
533        let inner = recover_lock(self.inner.lock(), "entity_count");
534        Ok(inner.entities.len())
535    }
536
537    /// Return the number of relationships in the graph.
538    pub fn relationship_count(&self) -> Result<usize, AgentRuntimeError> {
539        let inner = recover_lock(self.inner.lock(), "relationship_count");
540        Ok(inner.relationships.len())
541    }
542
543    /// Compute normalized degree centrality for each entity.
544    /// Degree = (in_degree + out_degree) / (n - 1), or 0.0 if n <= 1.
545    pub fn degree_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
546        let inner = recover_lock(self.inner.lock(), "degree_centrality");
547        let n = inner.entities.len();
548
549        let mut out_degree: HashMap<EntityId, usize> = HashMap::new();
550        let mut in_degree: HashMap<EntityId, usize> = HashMap::new();
551
552        for id in inner.entities.keys() {
553            out_degree.insert(id.clone(), 0);
554            in_degree.insert(id.clone(), 0);
555        }
556
557        for rel in &inner.relationships {
558            *out_degree.entry(rel.from.clone()).or_insert(0) += 1;
559            *in_degree.entry(rel.to.clone()).or_insert(0) += 1;
560        }
561
562        let denom = if n <= 1 { 1.0 } else { (n - 1) as f32 };
563        let mut result = HashMap::new();
564        for id in inner.entities.keys() {
565            let od = *out_degree.get(id).unwrap_or(&0);
566            let id_ = *in_degree.get(id).unwrap_or(&0);
567            let centrality = if n <= 1 {
568                0.0
569            } else {
570                (od + id_) as f32 / denom
571            };
572            result.insert(id.clone(), centrality);
573        }
574
575        Ok(result)
576    }
577
578    /// Compute normalized betweenness centrality for each entity.
579    /// Uses Brandes' algorithm with hop-count BFS.
580    ///
581    /// # Complexity
582    /// O(V * E) time. Not suitable for very large graphs (>1000 nodes).
583    pub fn betweenness_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
584        let inner = recover_lock(self.inner.lock(), "betweenness_centrality");
585        let n = inner.entities.len();
586        let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
587
588        let mut centrality: HashMap<EntityId, f32> =
589            nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
590
591        for source in &nodes {
592            // BFS to find shortest path counts and predecessors.
593            let mut stack: Vec<EntityId> = Vec::new();
594            let mut predecessors: HashMap<EntityId, Vec<EntityId>> =
595                nodes.iter().map(|id| (id.clone(), vec![])).collect();
596            let mut sigma: HashMap<EntityId, f32> =
597                nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
598            let mut dist: HashMap<EntityId, i64> =
599                nodes.iter().map(|id| (id.clone(), -1i64)).collect();
600
601            *sigma.entry(source.clone()).or_insert(0.0) = 1.0;
602            *dist.entry(source.clone()).or_insert(-1) = 0;
603
604            let mut queue: VecDeque<EntityId> = VecDeque::new();
605            queue.push_back(source.clone());
606
607            while let Some(v) = queue.pop_front() {
608                stack.push(v.clone());
609                let d_v = *dist.get(&v).unwrap_or(&0);
610                let sigma_v = *sigma.get(&v).unwrap_or(&0.0);
611                for rel in inner.relationships.iter().filter(|r| &r.from == &v) {
612                    let w = &rel.to;
613                    let d_w = dist.get(w).copied().unwrap_or(-1);
614                    if d_w < 0 {
615                        queue.push_back(w.clone());
616                        *dist.entry(w.clone()).or_insert(-1) = d_v + 1;
617                    }
618                    if dist.get(w).copied().unwrap_or(-1) == d_v + 1 {
619                        *sigma.entry(w.clone()).or_insert(0.0) += sigma_v;
620                        predecessors.entry(w.clone()).or_default().push(v.clone());
621                    }
622                }
623            }
624
625            let mut delta: HashMap<EntityId, f32> =
626                nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
627
628            while let Some(w) = stack.pop() {
629                let delta_w = *delta.get(&w).unwrap_or(&0.0);
630                let sigma_w = *sigma.get(&w).unwrap_or(&1.0);
631                for v in predecessors.get(&w).cloned().unwrap_or_default() {
632                    let sigma_v = *sigma.get(&v).unwrap_or(&1.0);
633                    let contribution = (sigma_v / sigma_w) * (1.0 + delta_w);
634                    *delta.entry(v.clone()).or_insert(0.0) += contribution;
635                }
636                if &w != source {
637                    *centrality.entry(w.clone()).or_insert(0.0) += delta_w;
638                }
639            }
640        }
641
642        // Normalize by 2 / ((n-1) * (n-2)) for directed graphs.
643        if n > 2 {
644            let norm = 2.0 / (((n - 1) * (n - 2)) as f32);
645            for v in centrality.values_mut() {
646                *v *= norm;
647            }
648        } else {
649            for v in centrality.values_mut() {
650                *v = 0.0;
651            }
652        }
653
654        Ok(centrality)
655    }
656
657    /// Detect communities using label propagation.
658    /// Each entity starts as its own community. In each iteration, each entity
659    /// adopts the most frequent label among its neighbours.
660    /// Returns a map of entity ID → community ID (usize).
661    pub fn label_propagation_communities(
662        &self,
663        max_iterations: usize,
664    ) -> Result<HashMap<EntityId, usize>, AgentRuntimeError> {
665        let inner = recover_lock(self.inner.lock(), "label_propagation_communities");
666        let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
667
668        // Assign each node a unique initial label (index in nodes vec).
669        let mut labels: HashMap<EntityId, usize> = nodes
670            .iter()
671            .enumerate()
672            .map(|(i, id)| (id.clone(), i))
673            .collect();
674
675        for _ in 0..max_iterations {
676            let mut changed = false;
677            // Iterate in a stable order.
678            for node in &nodes {
679                // Collect neighbour labels.
680                let neighbour_labels: Vec<usize> = inner
681                    .relationships
682                    .iter()
683                    .filter(|r| &r.from == node || &r.to == node)
684                    .map(|r| {
685                        if &r.from == node {
686                            labels.get(&r.to).copied().unwrap_or(0)
687                        } else {
688                            labels.get(&r.from).copied().unwrap_or(0)
689                        }
690                    })
691                    .collect();
692
693                if neighbour_labels.is_empty() {
694                    continue;
695                }
696
697                // Find the most frequent label.
698                let mut freq: HashMap<usize, usize> = HashMap::new();
699                for &lbl in &neighbour_labels {
700                    *freq.entry(lbl).or_insert(0) += 1;
701                }
702                let best = freq
703                    .into_iter()
704                    .max_by_key(|&(_, count)| count)
705                    .map(|(lbl, _)| lbl);
706
707                if let Some(new_label) = best {
708                    let current = labels.entry(node.clone()).or_insert(0);
709                    if *current != new_label {
710                        *current = new_label;
711                        changed = true;
712                    }
713                }
714            }
715
716            if !changed {
717                break;
718            }
719        }
720
721        Ok(labels)
722    }
723
724    /// Extract a subgraph containing only the specified entities and the
725    /// relationships between them.
726    pub fn subgraph(&self, node_ids: &[EntityId]) -> Result<GraphStore, AgentRuntimeError> {
727        let inner = recover_lock(self.inner.lock(), "subgraph");
728        let id_set: HashSet<&EntityId> = node_ids.iter().collect();
729
730        let new_store = GraphStore::new();
731
732        for id in node_ids {
733            let entity = inner
734                .entities
735                .get(id)
736                .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))?
737                .clone();
738            // We hold inner lock; call directly on the new store's inner.
739            let mut new_inner = recover_lock(new_store.inner.lock(), "subgraph:add_entity");
740            new_inner.entities.insert(entity.id.clone(), entity);
741        }
742
743        for rel in inner.relationships.iter() {
744            if id_set.contains(&rel.from) && id_set.contains(&rel.to) {
745                let mut new_inner =
746                    recover_lock(new_store.inner.lock(), "subgraph:add_relationship");
747                new_inner.relationships.push(rel.clone());
748            }
749        }
750
751        Ok(new_store)
752    }
753}
754
755impl Default for GraphStore {
756    fn default() -> Self {
757        Self::new()
758    }
759}
760
761// ── Tests ─────────────────────────────────────────────────────────────────────
762
763#[cfg(test)]
764mod tests {
765    use super::*;
766
767    fn make_graph() -> GraphStore {
768        GraphStore::new()
769    }
770
771    fn add(g: &GraphStore, id: &str) {
772        g.add_entity(Entity::new(id, "Node")).unwrap();
773    }
774
775    fn link(g: &GraphStore, from: &str, to: &str) {
776        g.add_relationship(Relationship::new(from, to, "CONNECTS", 1.0))
777            .unwrap();
778    }
779
780    fn link_w(g: &GraphStore, from: &str, to: &str, weight: f32) {
781        g.add_relationship(Relationship::new(from, to, "CONNECTS", weight))
782            .unwrap();
783    }
784
785    // ── EntityId ──────────────────────────────────────────────────────────────
786
787    #[test]
788    fn test_entity_id_equality() {
789        assert_eq!(EntityId::new("a"), EntityId::new("a"));
790        assert_ne!(EntityId::new("a"), EntityId::new("b"));
791    }
792
793    #[test]
794    fn test_entity_id_display() {
795        let id = EntityId::new("hello");
796        assert_eq!(id.to_string(), "hello");
797    }
798
799    // ── Entity ────────────────────────────────────────────────────────────────
800
801    #[test]
802    fn test_entity_new_has_empty_properties() {
803        let e = Entity::new("e1", "Person");
804        assert!(e.properties.is_empty());
805    }
806
807    #[test]
808    fn test_entity_with_properties_stores_props() {
809        let mut props = HashMap::new();
810        props.insert("age".into(), Value::Number(42.into()));
811        let e = Entity::with_properties("e1", "Person", props);
812        assert!(e.properties.contains_key("age"));
813    }
814
815    // ── GraphStore basic ops ──────────────────────────────────────────────────
816
817    #[test]
818    fn test_graph_add_entity_increments_count() {
819        let g = make_graph();
820        add(&g, "a");
821        assert_eq!(g.entity_count().unwrap(), 1);
822    }
823
824    #[test]
825    fn test_graph_get_entity_returns_entity() {
826        let g = make_graph();
827        g.add_entity(Entity::new("e1", "Person")).unwrap();
828        let e = g.get_entity(&EntityId::new("e1")).unwrap();
829        assert_eq!(e.label, "Person");
830    }
831
832    #[test]
833    fn test_graph_get_entity_missing_returns_error() {
834        let g = make_graph();
835        assert!(g.get_entity(&EntityId::new("ghost")).is_err());
836    }
837
838    #[test]
839    fn test_graph_add_relationship_increments_count() {
840        let g = make_graph();
841        add(&g, "a");
842        add(&g, "b");
843        link(&g, "a", "b");
844        assert_eq!(g.relationship_count().unwrap(), 1);
845    }
846
847    #[test]
848    fn test_graph_add_relationship_missing_source_fails() {
849        let g = make_graph();
850        add(&g, "b");
851        let result = g.add_relationship(Relationship::new("ghost", "b", "X", 1.0));
852        assert!(result.is_err());
853    }
854
855    #[test]
856    fn test_graph_add_relationship_missing_target_fails() {
857        let g = make_graph();
858        add(&g, "a");
859        let result = g.add_relationship(Relationship::new("a", "ghost", "X", 1.0));
860        assert!(result.is_err());
861    }
862
863    #[test]
864    fn test_graph_remove_entity_removes_relationships() {
865        let g = make_graph();
866        add(&g, "a");
867        add(&g, "b");
868        link(&g, "a", "b");
869        g.remove_entity(&EntityId::new("a")).unwrap();
870        assert_eq!(g.entity_count().unwrap(), 1);
871        assert_eq!(g.relationship_count().unwrap(), 0);
872    }
873
874    #[test]
875    fn test_graph_remove_entity_missing_returns_error() {
876        let g = make_graph();
877        assert!(g.remove_entity(&EntityId::new("ghost")).is_err());
878    }
879
880    // ── BFS ───────────────────────────────────────────────────────────────────
881
882    #[test]
883    fn test_bfs_finds_direct_neighbours() {
884        let g = make_graph();
885        add(&g, "a");
886        add(&g, "b");
887        add(&g, "c");
888        link(&g, "a", "b");
889        link(&g, "a", "c");
890        let visited = g.bfs(&EntityId::new("a")).unwrap();
891        assert_eq!(visited.len(), 2);
892    }
893
894    #[test]
895    fn test_bfs_traverses_chain() {
896        let g = make_graph();
897        add(&g, "a");
898        add(&g, "b");
899        add(&g, "c");
900        add(&g, "d");
901        link(&g, "a", "b");
902        link(&g, "b", "c");
903        link(&g, "c", "d");
904        let visited = g.bfs(&EntityId::new("a")).unwrap();
905        assert_eq!(visited.len(), 3);
906        assert_eq!(visited[0], EntityId::new("b"));
907    }
908
909    #[test]
910    fn test_bfs_handles_isolated_node() {
911        let g = make_graph();
912        add(&g, "a");
913        let visited = g.bfs(&EntityId::new("a")).unwrap();
914        assert!(visited.is_empty());
915    }
916
917    #[test]
918    fn test_bfs_missing_start_returns_error() {
919        let g = make_graph();
920        assert!(g.bfs(&EntityId::new("ghost")).is_err());
921    }
922
923    // ── DFS ───────────────────────────────────────────────────────────────────
924
925    #[test]
926    fn test_dfs_visits_all_reachable_nodes() {
927        let g = make_graph();
928        add(&g, "a");
929        add(&g, "b");
930        add(&g, "c");
931        add(&g, "d");
932        link(&g, "a", "b");
933        link(&g, "a", "c");
934        link(&g, "b", "d");
935        let visited = g.dfs(&EntityId::new("a")).unwrap();
936        assert_eq!(visited.len(), 3);
937    }
938
939    #[test]
940    fn test_dfs_handles_isolated_node() {
941        let g = make_graph();
942        add(&g, "a");
943        let visited = g.dfs(&EntityId::new("a")).unwrap();
944        assert!(visited.is_empty());
945    }
946
947    #[test]
948    fn test_dfs_missing_start_returns_error() {
949        let g = make_graph();
950        assert!(g.dfs(&EntityId::new("ghost")).is_err());
951    }
952
953    // ── Shortest path ─────────────────────────────────────────────────────────
954
955    #[test]
956    fn test_shortest_path_direct_connection() {
957        let g = make_graph();
958        add(&g, "a");
959        add(&g, "b");
960        link(&g, "a", "b");
961        let path = g
962            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
963            .unwrap();
964        assert_eq!(path, Some(vec![EntityId::new("a"), EntityId::new("b")]));
965    }
966
967    #[test]
968    fn test_shortest_path_multi_hop() {
969        let g = make_graph();
970        add(&g, "a");
971        add(&g, "b");
972        add(&g, "c");
973        link(&g, "a", "b");
974        link(&g, "b", "c");
975        let path = g
976            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
977            .unwrap();
978        assert_eq!(path.as_ref().map(|p| p.len()), Some(3));
979    }
980
981    #[test]
982    fn test_shortest_path_returns_none_for_disconnected() {
983        let g = make_graph();
984        add(&g, "a");
985        add(&g, "b");
986        let path = g
987            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
988            .unwrap();
989        assert_eq!(path, None);
990    }
991
992    #[test]
993    fn test_shortest_path_same_node_returns_single_element() {
994        let g = make_graph();
995        add(&g, "a");
996        let path = g
997            .shortest_path(&EntityId::new("a"), &EntityId::new("a"))
998            .unwrap();
999        assert_eq!(path, Some(vec![EntityId::new("a")]));
1000    }
1001
1002    #[test]
1003    fn test_shortest_path_missing_source_returns_error() {
1004        let g = make_graph();
1005        add(&g, "b");
1006        assert!(g
1007            .shortest_path(&EntityId::new("ghost"), &EntityId::new("b"))
1008            .is_err());
1009    }
1010
1011    #[test]
1012    fn test_shortest_path_missing_target_returns_error() {
1013        let g = make_graph();
1014        add(&g, "a");
1015        assert!(g
1016            .shortest_path(&EntityId::new("a"), &EntityId::new("ghost"))
1017            .is_err());
1018    }
1019
1020    // ── Transitive closure ────────────────────────────────────────────────────
1021
1022    #[test]
1023    fn test_transitive_closure_includes_start() {
1024        let g = make_graph();
1025        add(&g, "a");
1026        add(&g, "b");
1027        link(&g, "a", "b");
1028        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
1029        assert!(closure.contains(&EntityId::new("a")));
1030        assert!(closure.contains(&EntityId::new("b")));
1031    }
1032
1033    #[test]
1034    fn test_transitive_closure_isolated_node_contains_only_self() {
1035        let g = make_graph();
1036        add(&g, "a");
1037        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
1038        assert_eq!(closure.len(), 1);
1039    }
1040
1041    // ── MemGraphError conversion ──────────────────────────────────────────────
1042
1043    #[test]
1044    fn test_mem_graph_error_converts_to_runtime_error() {
1045        let e = MemGraphError::EntityNotFound("x".into());
1046        let re: AgentRuntimeError = e.into();
1047        assert!(matches!(re, AgentRuntimeError::Graph(_)));
1048    }
1049
1050    // ── Weighted shortest path ────────────────────────────────────────────────
1051
1052    #[test]
1053    fn test_shortest_path_weighted_simple() {
1054        // a --(1.0)--> b --(2.0)--> c
1055        // a --(10.0)--> c  (direct but heavier)
1056        let g = make_graph();
1057        add(&g, "a");
1058        add(&g, "b");
1059        add(&g, "c");
1060        link_w(&g, "a", "b", 1.0);
1061        link_w(&g, "b", "c", 2.0);
1062        g.add_relationship(Relationship::new("a", "c", "DIRECT", 10.0))
1063            .unwrap();
1064
1065        let result = g
1066            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("c"))
1067            .unwrap();
1068        assert!(result.is_some());
1069        let (path, weight) = result.unwrap();
1070        // The cheapest path is a -> b -> c with total weight 3.0
1071        assert_eq!(
1072            path,
1073            vec![EntityId::new("a"), EntityId::new("b"), EntityId::new("c")]
1074        );
1075        assert!((weight - 3.0).abs() < 1e-5);
1076    }
1077
1078    #[test]
1079    fn test_shortest_path_weighted_returns_none_for_disconnected() {
1080        let g = make_graph();
1081        add(&g, "a");
1082        add(&g, "b");
1083        let result = g
1084            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"))
1085            .unwrap();
1086        assert!(result.is_none());
1087    }
1088
1089    #[test]
1090    fn test_shortest_path_weighted_same_node() {
1091        let g = make_graph();
1092        add(&g, "a");
1093        let result = g
1094            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("a"))
1095            .unwrap();
1096        assert_eq!(result, Some((vec![EntityId::new("a")], 0.0)));
1097    }
1098
1099    #[test]
1100    fn test_shortest_path_weighted_negative_weight_errors() {
1101        let g = make_graph();
1102        add(&g, "a");
1103        add(&g, "b");
1104        g.add_relationship(Relationship::new("a", "b", "NEG", -1.0))
1105            .unwrap();
1106        let result = g.shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"));
1107        assert!(result.is_err());
1108    }
1109
1110    // ── Degree centrality ─────────────────────────────────────────────────────
1111
1112    #[test]
1113    fn test_degree_centrality_basic() {
1114        // Star graph: a -> b, a -> c, a -> d
1115        // a: out=3, in=0 => (3+0)/(4-1) = 1.0
1116        // b: out=0, in=1 => (0+1)/3 = 0.333...
1117        let g = make_graph();
1118        add(&g, "a");
1119        add(&g, "b");
1120        add(&g, "c");
1121        add(&g, "d");
1122        link(&g, "a", "b");
1123        link(&g, "a", "c");
1124        link(&g, "a", "d");
1125
1126        let centrality = g.degree_centrality().unwrap();
1127        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
1128        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
1129
1130        assert!((a_cent - 1.0).abs() < 1e-5, "a centrality was {a_cent}");
1131        assert!(
1132            (b_cent - 1.0 / 3.0).abs() < 1e-5,
1133            "b centrality was {b_cent}"
1134        );
1135    }
1136
1137    // ── Betweenness centrality ────────────────────────────────────────────────
1138
1139    #[test]
1140    fn test_betweenness_centrality_chain() {
1141        // Chain: a -> b -> c -> d
1142        // b and c are on all paths from a to c, a to d, b to d
1143        // Node b and c should have higher centrality than a and d.
1144        let g = make_graph();
1145        add(&g, "a");
1146        add(&g, "b");
1147        add(&g, "c");
1148        add(&g, "d");
1149        link(&g, "a", "b");
1150        link(&g, "b", "c");
1151        link(&g, "c", "d");
1152
1153        let centrality = g.betweenness_centrality().unwrap();
1154        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
1155        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
1156        let c_cent = *centrality.get(&EntityId::new("c")).unwrap();
1157        let d_cent = *centrality.get(&EntityId::new("d")).unwrap();
1158
1159        // Endpoints should have 0 centrality; intermediate nodes should be > 0.
1160        assert!((a_cent).abs() < 1e-5, "expected a_cent ~ 0, got {a_cent}");
1161        assert!(b_cent > 0.0, "expected b_cent > 0, got {b_cent}");
1162        assert!(c_cent > 0.0, "expected c_cent > 0, got {c_cent}");
1163        assert!((d_cent).abs() < 1e-5, "expected d_cent ~ 0, got {d_cent}");
1164    }
1165
1166    // ── Label propagation communities ─────────────────────────────────────────
1167
1168    #[test]
1169    fn test_label_propagation_communities_two_clusters() {
1170        // Cluster 1: a <-> b <-> c (fully connected via bidirectional edges)
1171        // Cluster 2: x <-> y <-> z
1172        // No edges between clusters.
1173        let g = make_graph();
1174        for id in &["a", "b", "c", "x", "y", "z"] {
1175            add(&g, id);
1176        }
1177        // Cluster 1 (bidirectional via two directed edges each)
1178        link(&g, "a", "b");
1179        link(&g, "b", "a");
1180        link(&g, "b", "c");
1181        link(&g, "c", "b");
1182        link(&g, "a", "c");
1183        link(&g, "c", "a");
1184        // Cluster 2
1185        link(&g, "x", "y");
1186        link(&g, "y", "x");
1187        link(&g, "y", "z");
1188        link(&g, "z", "y");
1189        link(&g, "x", "z");
1190        link(&g, "z", "x");
1191
1192        let communities = g.label_propagation_communities(100).unwrap();
1193
1194        let label_a = communities[&EntityId::new("a")];
1195        let label_b = communities[&EntityId::new("b")];
1196        let label_c = communities[&EntityId::new("c")];
1197        let label_x = communities[&EntityId::new("x")];
1198        let label_y = communities[&EntityId::new("y")];
1199        let label_z = communities[&EntityId::new("z")];
1200
1201        // All nodes in cluster 1 share a label, all in cluster 2 share a label,
1202        // and the two clusters have different labels.
1203        assert_eq!(label_a, label_b, "a and b should be in same community");
1204        assert_eq!(label_b, label_c, "b and c should be in same community");
1205        assert_eq!(label_x, label_y, "x and y should be in same community");
1206        assert_eq!(label_y, label_z, "y and z should be in same community");
1207        assert_ne!(
1208            label_a, label_x,
1209            "cluster 1 and cluster 2 should be different communities"
1210        );
1211    }
1212
1213    // ── Subgraph extraction ───────────────────────────────────────────────────
1214
1215    #[test]
1216    fn test_subgraph_extracts_correct_nodes_and_edges() {
1217        // Full graph: a -> b -> c -> d
1218        // Subgraph of {a, b, c} should contain edges a->b and b->c but not c->d.
1219        let g = make_graph();
1220        add(&g, "a");
1221        add(&g, "b");
1222        add(&g, "c");
1223        add(&g, "d");
1224        link(&g, "a", "b");
1225        link(&g, "b", "c");
1226        link(&g, "c", "d");
1227
1228        let sub = g
1229            .subgraph(&[EntityId::new("a"), EntityId::new("b"), EntityId::new("c")])
1230            .unwrap();
1231
1232        assert_eq!(sub.entity_count().unwrap(), 3);
1233        assert_eq!(sub.relationship_count().unwrap(), 2);
1234
1235        // d should not be present in the subgraph.
1236        assert!(sub.get_entity(&EntityId::new("d")).is_err());
1237
1238        // a -> b and b -> c should be present; c -> d should not.
1239        let path = sub
1240            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
1241            .unwrap();
1242        assert!(path.is_some());
1243        assert_eq!(path.unwrap().len(), 3);
1244    }
1245}