Skip to main content

llm_agent_runtime/
graph.rs

1//! # Module: Graph
2//!
3//! ## Responsibility
4//! Provides an in-memory knowledge graph with typed entities and relationships.
5//! Mirrors the public API of `mem-graph`.
6//!
7//! ## Guarantees
8//! - Thread-safe: `GraphStore` wraps state in `Arc<Mutex<_>>`
9//! - BFS/DFS traversal and shortest-path are correct for directed graphs
10//! - Non-panicking: all operations return `Result`
11//!
12//! ## NOT Responsible For
13//! - Persistence to disk or external store
14//! - Graph sharding / distributed graphs
15
16use crate::error::AgentRuntimeError;
17use crate::util::recover_lock;
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
21use std::sync::{Arc, Mutex};
22
23// ── OrdF32 newtype ─────────────────────────────────────────────────────────────
24
/// Newtype wrapper for `f32` that implements a genuine total order (`Eq` + `Ord`),
/// so it can be used as a `BinaryHeap` key.
///
/// Comparison delegates to [`f32::total_cmp`] (IEEE 754 `totalOrder`):
/// positive NaN sorts above `+inf`, negative NaN below `-inf`, and `-0.0 < +0.0`.
/// A `partial_cmp(..).unwrap_or(Greater)` scheme would make NaN compare `Greater`
/// in *both* directions, which is not antisymmetric and breaks heap invariants.
/// `PartialEq` is derived from the same comparison so `Eq` and `Ord` agree.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
45
46// ── EntityId ──────────────────────────────────────────────────────────────────
47
48/// Stable identifier for a graph entity.
49#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
50pub struct EntityId(pub String);
51
52impl EntityId {
53    /// Create a new `EntityId` from any string-like value.
54    pub fn new(id: impl Into<String>) -> Self {
55        Self(id.into())
56    }
57
58    /// Return the inner ID string as a `&str`.
59    pub fn as_str(&self) -> &str {
60        &self.0
61    }
62}
63
64impl AsRef<str> for EntityId {
65    fn as_ref(&self) -> &str {
66        &self.0
67    }
68}
69
70impl std::fmt::Display for EntityId {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        write!(f, "{}", self.0)
73    }
74}
75
76// ── Entity ────────────────────────────────────────────────────────────────────
77
78/// A node in the knowledge graph.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct Entity {
81    /// Unique identifier.
82    pub id: EntityId,
83    /// Human-readable label (e.g. "Person", "Concept").
84    pub label: String,
85    /// Arbitrary key-value properties.
86    pub properties: HashMap<String, Value>,
87}
88
89impl Entity {
90    /// Construct a new entity with no properties.
91    pub fn new(id: impl Into<String>, label: impl Into<String>) -> Self {
92        Self {
93            id: EntityId::new(id),
94            label: label.into(),
95            properties: HashMap::new(),
96        }
97    }
98
99    /// Construct a new entity with the given properties.
100    pub fn with_properties(
101        id: impl Into<String>,
102        label: impl Into<String>,
103        properties: HashMap<String, Value>,
104    ) -> Self {
105        Self {
106            id: EntityId::new(id),
107            label: label.into(),
108            properties,
109        }
110    }
111}
112
113// ── Relationship ──────────────────────────────────────────────────────────────
114
115/// A directed, typed edge between two entities.
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct Relationship {
118    /// Source entity.
119    pub from: EntityId,
120    /// Target entity.
121    pub to: EntityId,
122    /// Relationship type label (e.g. "KNOWS", "PART_OF").
123    pub kind: String,
124    /// Optional weight for weighted-graph use cases.
125    pub weight: f32,
126}
127
128impl Relationship {
129    /// Construct a new relationship with the given kind and weight.
130    pub fn new(
131        from: impl Into<String>,
132        to: impl Into<String>,
133        kind: impl Into<String>,
134        weight: f32,
135    ) -> Self {
136        Self {
137            from: EntityId::new(from),
138            to: EntityId::new(to),
139            kind: kind.into(),
140            weight,
141        }
142    }
143}
144
145// ── MemGraphError (mirrors upstream) ─────────────────────────────────────────
146
147/// Graph-specific errors, mirrors `mem-graph::MemGraphError`.
148#[derive(Debug, thiserror::Error)]
149pub enum MemGraphError {
150    /// The requested entity was not found.
151    #[error("Entity '{0}' not found")]
152    EntityNotFound(String),
153
154    /// A relationship between the two entities already exists with the same kind.
155    #[error("Relationship '{kind}' from '{from}' to '{to}' already exists")]
156    DuplicateRelationship {
157        /// Source entity ID.
158        from: String,
159        /// Target entity ID.
160        to: String,
161        /// Relationship kind label.
162        kind: String,
163    },
164
165    /// Generic internal error.
166    #[error("Graph internal error: {0}")]
167    Internal(String),
168}
169
170impl From<MemGraphError> for AgentRuntimeError {
171    fn from(e: MemGraphError) -> Self {
172        AgentRuntimeError::Graph(e.to_string())
173    }
174}
175
176// ── GraphStore ────────────────────────────────────────────────────────────────
177
178/// In-memory knowledge graph supporting entities, relationships, BFS/DFS,
179/// shortest-path, weighted shortest-path, and graph analytics.
180///
181/// ## Guarantees
182/// - Thread-safe via `Arc<Mutex<_>>`
183/// - BFS/DFS are non-recursive (stack-safe)
184/// - Shortest-path is hop-count based (BFS)
185/// - Weighted shortest-path uses Dijkstra's algorithm
186#[derive(Debug, Clone)]
187pub struct GraphStore {
188    inner: Arc<Mutex<GraphInner>>,
189}
190
191#[derive(Debug)]
192struct GraphInner {
193    entities: HashMap<EntityId, Entity>,
194    relationships: Vec<Relationship>,
195    /// Cached result of cycle detection. Invalidated on any mutation.
196    cycle_cache: Option<bool>,
197}
198
199impl GraphStore {
200    /// Create a new, empty graph store.
201    pub fn new() -> Self {
202        Self {
203            inner: Arc::new(Mutex::new(GraphInner {
204                entities: HashMap::new(),
205                relationships: Vec::new(),
206                cycle_cache: None,
207            })),
208        }
209    }
210
211    /// Add an entity to the graph.
212    ///
213    /// If an entity with the same ID already exists, it is replaced.
214    pub fn add_entity(&self, entity: Entity) -> Result<(), AgentRuntimeError> {
215        let mut inner = recover_lock(self.inner.lock(), "add_entity");
216        inner.cycle_cache = None;
217        inner.entities.insert(entity.id.clone(), entity);
218        Ok(())
219    }
220
221    /// Retrieve an entity by ID.
222    pub fn get_entity(&self, id: &EntityId) -> Result<Entity, AgentRuntimeError> {
223        let inner = recover_lock(self.inner.lock(), "get_entity");
224        inner
225            .entities
226            .get(id)
227            .cloned()
228            .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))
229    }
230
231    /// Add a directed relationship between two existing entities.
232    ///
233    /// Both source and target entities must already exist in the graph.
234    pub fn add_relationship(&self, rel: Relationship) -> Result<(), AgentRuntimeError> {
235        let mut inner = recover_lock(self.inner.lock(), "add_relationship");
236
237        if !inner.entities.contains_key(&rel.from) {
238            return Err(AgentRuntimeError::Graph(format!(
239                "source entity '{}' not found",
240                rel.from.0
241            )));
242        }
243        if !inner.entities.contains_key(&rel.to) {
244            return Err(AgentRuntimeError::Graph(format!(
245                "target entity '{}' not found",
246                rel.to.0
247            )));
248        }
249
250        // Reject duplicate (from, to, kind) triples — the DuplicateRelationship
251        // error variant existed but was never raised, silently allowing duplicate
252        // edges that corrupt relationship_count() and BFS/DFS result counts.
253        let duplicate = inner
254            .relationships
255            .iter()
256            .any(|r| r.from == rel.from && r.to == rel.to && r.kind == rel.kind);
257        if duplicate {
258            return Err(AgentRuntimeError::Graph(
259                MemGraphError::DuplicateRelationship {
260                    from: rel.from.0.clone(),
261                    to: rel.to.0.clone(),
262                    kind: rel.kind.clone(),
263                }
264                .to_string(),
265            ));
266        }
267
268        inner.cycle_cache = None;
269        inner.relationships.push(rel);
270        Ok(())
271    }
272
273    /// Remove an entity and all relationships involving it.
274    pub fn remove_entity(&self, id: &EntityId) -> Result<(), AgentRuntimeError> {
275        let mut inner = recover_lock(self.inner.lock(), "remove_entity");
276
277        if inner.entities.remove(id).is_none() {
278            return Err(AgentRuntimeError::Graph(format!(
279                "entity '{}' not found",
280                id.0
281            )));
282        }
283        inner.cycle_cache = None;
284        inner.relationships.retain(|r| &r.from != id && &r.to != id);
285        Ok(())
286    }
287
288    /// Return all direct neighbours of the given entity (BFS layer 1).
289    fn neighbours(relationships: &[Relationship], id: &EntityId) -> Vec<EntityId> {
290        relationships
291            .iter()
292            .filter(|r| &r.from == id)
293            .map(|r| r.to.clone())
294            .collect()
295    }
296
297    /// Breadth-first search starting from `start`.
298    ///
299    /// Returns entity IDs in BFS discovery order (not including the start node).
300    #[tracing::instrument(skip(self))]
301    pub fn bfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
302        let inner = recover_lock(self.inner.lock(), "bfs");
303
304        if !inner.entities.contains_key(start) {
305            return Err(AgentRuntimeError::Graph(format!(
306                "start entity '{}' not found",
307                start.0
308            )));
309        }
310
311        let mut visited: HashSet<EntityId> = HashSet::new();
312        let mut queue: VecDeque<EntityId> = VecDeque::new();
313        let mut result: Vec<EntityId> = Vec::new();
314
315        visited.insert(start.clone());
316        queue.push_back(start.clone());
317
318        while let Some(current) = queue.pop_front() {
319            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
320            for neighbour in neighbours {
321                if visited.insert(neighbour.clone()) {
322                    result.push(neighbour.clone());
323                    queue.push_back(neighbour);
324                }
325            }
326        }
327
328        tracing::debug!("BFS visited {} nodes", result.len());
329        Ok(result)
330    }
331
332    /// Depth-first search starting from `start`.
333    ///
334    /// Returns entity IDs in DFS discovery order (not including the start node).
335    #[tracing::instrument(skip(self))]
336    pub fn dfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
337        let inner = recover_lock(self.inner.lock(), "dfs");
338
339        if !inner.entities.contains_key(start) {
340            return Err(AgentRuntimeError::Graph(format!(
341                "start entity '{}' not found",
342                start.0
343            )));
344        }
345
346        let mut visited: HashSet<EntityId> = HashSet::new();
347        let mut stack: Vec<EntityId> = Vec::new();
348        let mut result: Vec<EntityId> = Vec::new();
349
350        visited.insert(start.clone());
351        stack.push(start.clone());
352
353        while let Some(current) = stack.pop() {
354            let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, &current);
355            for neighbour in neighbours {
356                if visited.insert(neighbour.clone()) {
357                    result.push(neighbour.clone());
358                    stack.push(neighbour);
359                }
360            }
361        }
362
363        tracing::debug!("DFS visited {} nodes", result.len());
364        Ok(result)
365    }
366
367    /// Find the shortest path (by hop count) between `from` and `to`.
368    ///
369    /// # Returns
370    /// - `Some(path)` — ordered list of `EntityId`s from `from` to `to` (inclusive)
371    /// - `None` — no path exists
372    #[tracing::instrument(skip(self))]
373    pub fn shortest_path(
374        &self,
375        from: &EntityId,
376        to: &EntityId,
377    ) -> Result<Option<Vec<EntityId>>, AgentRuntimeError> {
378        let inner = recover_lock(self.inner.lock(), "shortest_path");
379
380        if !inner.entities.contains_key(from) {
381            return Err(AgentRuntimeError::Graph(format!(
382                "source entity '{}' not found",
383                from.0
384            )));
385        }
386        if !inner.entities.contains_key(to) {
387            return Err(AgentRuntimeError::Graph(format!(
388                "target entity '{}' not found",
389                to.0
390            )));
391        }
392
393        if from == to {
394            return Ok(Some(vec![from.clone()]));
395        }
396
397        // Item 5 — predecessor-map BFS; O(1) per enqueue instead of O(path_len).
398        let mut visited: HashSet<EntityId> = HashSet::new();
399        let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
400        let mut queue: VecDeque<EntityId> = VecDeque::new();
401
402        visited.insert(from.clone());
403        queue.push_back(from.clone());
404
405        while let Some(current) = queue.pop_front() {
406            for neighbour in Self::neighbours(&inner.relationships, &current) {
407                if &neighbour == to {
408                    // Reconstruct path by following prev back from current.
409                    let mut path = vec![neighbour, current.clone()];
410                    let mut node = current;
411                    while let Some(p) = prev.get(&node) {
412                        path.push(p.clone());
413                        node = p.clone();
414                    }
415                    path.reverse();
416                    return Ok(Some(path));
417                }
418                if visited.insert(neighbour.clone()) {
419                    prev.insert(neighbour.clone(), current.clone());
420                    queue.push_back(neighbour);
421                }
422            }
423        }
424
425        Ok(None)
426    }
427
428    /// Find the shortest weighted path between `from` and `to` using Dijkstra's algorithm.
429    ///
430    /// Uses `Relationship::weight` as edge cost. Negative weights are not supported
431    /// and will cause this method to return an error.
432    ///
433    /// # Returns
434    /// - `Ok(Some((path, total_weight)))` — the shortest path and its total weight
435    /// - `Ok(None)` — no path exists between `from` and `to`
436    /// - `Err(...)` — either entity not found, or a negative edge weight was encountered
437    pub fn shortest_path_weighted(
438        &self,
439        from: &EntityId,
440        to: &EntityId,
441    ) -> Result<Option<(Vec<EntityId>, f32)>, AgentRuntimeError> {
442        let inner = recover_lock(self.inner.lock(), "shortest_path_weighted");
443
444        if !inner.entities.contains_key(from) {
445            return Err(AgentRuntimeError::Graph(format!(
446                "source entity '{}' not found",
447                from.0
448            )));
449        }
450        if !inner.entities.contains_key(to) {
451            return Err(AgentRuntimeError::Graph(format!(
452                "target entity '{}' not found",
453                to.0
454            )));
455        }
456
457        // Validate: no negative weights
458        for rel in &inner.relationships {
459            if rel.weight < 0.0 {
460                return Err(AgentRuntimeError::Graph(format!(
461                    "negative weight {:.4} on edge '{}' -> '{}'",
462                    rel.weight, rel.from.0, rel.to.0
463                )));
464            }
465        }
466
467        if from == to {
468            return Ok(Some((vec![from.clone()], 0.0)));
469        }
470
471        // Dijkstra using a max-heap with negated weights to simulate a min-heap.
472        // Heap entries: (negated_cost, node_id)
473        let mut dist: HashMap<EntityId, f32> = HashMap::new();
474        let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
475        // BinaryHeap is a max-heap; negate weights to get min-heap behaviour.
476        let mut heap: BinaryHeap<(OrdF32, EntityId)> = BinaryHeap::new();
477
478        dist.insert(from.clone(), 0.0);
479        heap.push((OrdF32(-0.0), from.clone()));
480
481        while let Some((OrdF32(neg_cost), current)) = heap.pop() {
482            let cost = -neg_cost;
483
484            // Skip stale entries.
485            if let Some(&best) = dist.get(&current) {
486                if cost > best {
487                    continue;
488                }
489            }
490
491            if &current == to {
492                // Reconstruct path in reverse.
493                let mut path = vec![to.clone()];
494                let mut node = to.clone();
495                while let Some(p) = prev.get(&node) {
496                    path.push(p.clone());
497                    node = p.clone();
498                }
499                path.reverse();
500                return Ok(Some((path, cost)));
501            }
502
503            for rel in inner.relationships.iter().filter(|r| &r.from == &current) {
504                let next_cost = cost + rel.weight;
505                let entry = dist.entry(rel.to.clone()).or_insert(f32::INFINITY);
506                if next_cost < *entry {
507                    *entry = next_cost;
508                    prev.insert(rel.to.clone(), current.clone());
509                    heap.push((OrdF32(-next_cost), rel.to.clone()));
510                }
511            }
512        }
513
514        Ok(None)
515    }
516
517    /// BFS that builds a `HashSet` directly (including `start`).
518    ///
519    /// Operates on a pre-locked `GraphInner` to avoid acquiring the mutex twice
520    /// and to skip the intermediate `Vec` allocation that `bfs()` produces.
521    fn bfs_into_set(inner: &GraphInner, start: &EntityId) -> HashSet<EntityId> {
522        let mut visited: HashSet<EntityId> = HashSet::new();
523        let mut queue: VecDeque<EntityId> = VecDeque::new();
524        visited.insert(start.clone());
525        queue.push_back(start.clone());
526        while let Some(current) = queue.pop_front() {
527            for neighbour in Self::neighbours(&inner.relationships, &current) {
528                if visited.insert(neighbour.clone()) {
529                    queue.push_back(neighbour);
530                }
531            }
532        }
533        visited
534    }
535
536    /// Compute the transitive closure: all entities reachable from `start`
537    /// (including `start` itself).
538    ///
539    /// Uses a single lock acquisition and builds the result as a `HashSet`
540    /// directly, avoiding the intermediate `Vec` that would otherwise be
541    /// produced by delegating to [`bfs`].
542    ///
543    /// [`bfs`]: GraphStore::bfs
544    pub fn transitive_closure(
545        &self,
546        start: &EntityId,
547    ) -> Result<HashSet<EntityId>, AgentRuntimeError> {
548        let inner = recover_lock(self.inner.lock(), "transitive_closure");
549        if !inner.entities.contains_key(start) {
550            return Err(AgentRuntimeError::Graph(format!(
551                "start entity '{}' not found",
552                start.0
553            )));
554        }
555        Ok(Self::bfs_into_set(&inner, start))
556    }
557
558    /// Return the number of entities in the graph.
559    pub fn entity_count(&self) -> Result<usize, AgentRuntimeError> {
560        let inner = recover_lock(self.inner.lock(), "entity_count");
561        Ok(inner.entities.len())
562    }
563
564    /// Return the number of relationships in the graph.
565    pub fn relationship_count(&self) -> Result<usize, AgentRuntimeError> {
566        let inner = recover_lock(self.inner.lock(), "relationship_count");
567        Ok(inner.relationships.len())
568    }
569
570    /// Compute normalized degree centrality for each entity.
571    /// Degree = (in_degree + out_degree) / (n - 1), or 0.0 if n <= 1.
572    pub fn degree_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
573        let inner = recover_lock(self.inner.lock(), "degree_centrality");
574        let n = inner.entities.len();
575
576        let mut out_degree: HashMap<EntityId, usize> = HashMap::new();
577        let mut in_degree: HashMap<EntityId, usize> = HashMap::new();
578
579        for id in inner.entities.keys() {
580            out_degree.insert(id.clone(), 0);
581            in_degree.insert(id.clone(), 0);
582        }
583
584        for rel in &inner.relationships {
585            *out_degree.entry(rel.from.clone()).or_insert(0) += 1;
586            *in_degree.entry(rel.to.clone()).or_insert(0) += 1;
587        }
588
589        let denom = if n <= 1 { 1.0 } else { (n - 1) as f32 };
590        let mut result = HashMap::new();
591        for id in inner.entities.keys() {
592            let od = *out_degree.get(id).unwrap_or(&0);
593            let id_ = *in_degree.get(id).unwrap_or(&0);
594            let centrality = if n <= 1 {
595                0.0
596            } else {
597                (od + id_) as f32 / denom
598            };
599            result.insert(id.clone(), centrality);
600        }
601
602        Ok(result)
603    }
604
605    /// Compute normalized betweenness centrality for each entity.
606    /// Uses Brandes' algorithm with hop-count BFS.
607    ///
608    /// # Complexity
609    /// O(V * E) time. Not suitable for very large graphs (>1000 nodes).
610    pub fn betweenness_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
611        let inner = recover_lock(self.inner.lock(), "betweenness_centrality");
612        let n = inner.entities.len();
613        let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
614
615        let mut centrality: HashMap<EntityId, f32> =
616            nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
617
618        for source in &nodes {
619            // BFS to find shortest path counts and predecessors.
620            let mut stack: Vec<EntityId> = Vec::new();
621            let mut predecessors: HashMap<EntityId, Vec<EntityId>> =
622                nodes.iter().map(|id| (id.clone(), vec![])).collect();
623            let mut sigma: HashMap<EntityId, f32> =
624                nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
625            let mut dist: HashMap<EntityId, i64> =
626                nodes.iter().map(|id| (id.clone(), -1i64)).collect();
627
628            *sigma.entry(source.clone()).or_insert(0.0) = 1.0;
629            *dist.entry(source.clone()).or_insert(-1) = 0;
630
631            let mut queue: VecDeque<EntityId> = VecDeque::new();
632            queue.push_back(source.clone());
633
634            while let Some(v) = queue.pop_front() {
635                stack.push(v.clone());
636                let d_v = *dist.get(&v).unwrap_or(&0);
637                let sigma_v = *sigma.get(&v).unwrap_or(&0.0);
638                for rel in inner.relationships.iter().filter(|r| &r.from == &v) {
639                    let w = &rel.to;
640                    let d_w = dist.get(w).copied().unwrap_or(-1);
641                    if d_w < 0 {
642                        queue.push_back(w.clone());
643                        *dist.entry(w.clone()).or_insert(-1) = d_v + 1;
644                    }
645                    if dist.get(w).copied().unwrap_or(-1) == d_v + 1 {
646                        *sigma.entry(w.clone()).or_insert(0.0) += sigma_v;
647                        predecessors.entry(w.clone()).or_default().push(v.clone());
648                    }
649                }
650            }
651
652            let mut delta: HashMap<EntityId, f32> =
653                nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
654
655            while let Some(w) = stack.pop() {
656                let delta_w = *delta.get(&w).unwrap_or(&0.0);
657                let sigma_w = *sigma.get(&w).unwrap_or(&1.0);
658                // Item 5 — iterate predecessors by reference to avoid cloning the Vec.
659                for v in predecessors
660                    .get(&w)
661                    .map(|ps| ps.as_slice())
662                    .unwrap_or_default()
663                {
664                    let sigma_v = *sigma.get(v).unwrap_or(&1.0);
665                    let contribution = (sigma_v / sigma_w) * (1.0 + delta_w);
666                    *delta.entry(v.clone()).or_insert(0.0) += contribution;
667                }
668                if &w != source {
669                    *centrality.entry(w.clone()).or_insert(0.0) += delta_w;
670                }
671            }
672        }
673
674        // Normalize by 2 / ((n-1) * (n-2)) for directed graphs.
675        if n > 2 {
676            let norm = 2.0 / (((n - 1) * (n - 2)) as f32);
677            for v in centrality.values_mut() {
678                *v *= norm;
679            }
680        } else {
681            for v in centrality.values_mut() {
682                *v = 0.0;
683            }
684        }
685
686        Ok(centrality)
687    }
688
689    /// Detect communities using label propagation.
690    /// Each entity starts as its own community. In each iteration, each entity
691    /// adopts the most frequent label among its neighbours.
692    /// Returns a map of entity ID → community ID (usize).
693    pub fn label_propagation_communities(
694        &self,
695        max_iterations: usize,
696    ) -> Result<HashMap<EntityId, usize>, AgentRuntimeError> {
697        let inner = recover_lock(self.inner.lock(), "label_propagation_communities");
698        let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
699
700        // Assign each node a unique initial label (index in nodes vec).
701        let mut labels: HashMap<EntityId, usize> = nodes
702            .iter()
703            .enumerate()
704            .map(|(i, id)| (id.clone(), i))
705            .collect();
706
707        for _ in 0..max_iterations {
708            let mut changed = false;
709            // Iterate in a stable order.
710            for node in &nodes {
711                // Collect neighbour labels.
712                let neighbour_labels: Vec<usize> = inner
713                    .relationships
714                    .iter()
715                    .filter(|r| &r.from == node || &r.to == node)
716                    .map(|r| {
717                        if &r.from == node {
718                            labels.get(&r.to).copied().unwrap_or(0)
719                        } else {
720                            labels.get(&r.from).copied().unwrap_or(0)
721                        }
722                    })
723                    .collect();
724
725                if neighbour_labels.is_empty() {
726                    continue;
727                }
728
729                // Find the most frequent label.
730                let mut freq: HashMap<usize, usize> = HashMap::new();
731                for &lbl in &neighbour_labels {
732                    *freq.entry(lbl).or_insert(0) += 1;
733                }
734                let best = freq
735                    .into_iter()
736                    .max_by_key(|&(_, count)| count)
737                    .map(|(lbl, _)| lbl);
738
739                if let Some(new_label) = best {
740                    let current = labels.entry(node.clone()).or_insert(0);
741                    if *current != new_label {
742                        *current = new_label;
743                        changed = true;
744                    }
745                }
746            }
747
748            if !changed {
749                break;
750            }
751        }
752
753        Ok(labels)
754    }
755
756    /// Detect whether the directed graph contains any cycles.
757    ///
758    /// Uses iterative DFS with a three-color marking scheme.  The result is
759    /// cached until the next mutation (`add_entity`, `add_relationship`, or
760    /// `remove_entity`).
761    ///
762    /// # Returns
763    /// - `Ok(true)` — at least one cycle exists
764    /// - `Ok(false)` — the graph is acyclic (a DAG)
765    pub fn detect_cycles(&self) -> Result<bool, AgentRuntimeError> {
766        let mut inner = recover_lock(self.inner.lock(), "detect_cycles");
767
768        if let Some(cached) = inner.cycle_cache {
769            return Ok(cached);
770        }
771
772        // Iterative DFS with three-color marking: 0=white, 1=gray, 2=black.
773        let mut color: HashMap<&EntityId, u8> =
774            inner.entities.keys().map(|id| (id, 0u8)).collect();
775
776        let has_cycle = 'outer: {
777            for start in inner.entities.keys() {
778                if *color.get(start).unwrap_or(&0) != 0 {
779                    continue;
780                }
781
782                // Stack holds (node_id, iterator position for adjacency).
783                let mut stack: Vec<(&EntityId, usize)> = vec![(start, 0)];
784                *color.entry(start).or_insert(0) = 1;
785
786                while let Some((node, idx)) = stack.last_mut() {
787                    // Collect neighbors on first visit.
788                    let neighbors: Vec<&EntityId> = inner
789                        .relationships
790                        .iter()
791                        .filter(|r| &r.from == *node)
792                        .map(|r| &r.to)
793                        .collect();
794
795                    if *idx < neighbors.len() {
796                        let next = neighbors[*idx];
797                        *idx += 1;
798                        match color.get(next).copied().unwrap_or(0) {
799                            1 => break 'outer true, // back edge → cycle
800                            0 => {
801                                *color.entry(next).or_insert(0) = 1;
802                                stack.push((next, 0));
803                            }
804                            _ => {} // already fully processed
805                        }
806                    } else {
807                        // All neighbors processed; color black.
808                        *color.entry(*node).or_insert(0) = 2;
809                        stack.pop();
810                    }
811                }
812            }
813            false
814        };
815
816        inner.cycle_cache = Some(has_cycle);
817        Ok(has_cycle)
818    }
819
820    /// Return `true` if there is any path from `from` to `to`.
821    ///
822    /// Both nodes must exist or returns `Err`. Uses BFS internally.
823    /// Returns `Ok(false)` if nodes exist but are not connected.
824    pub fn path_exists(&self, from: &str, to: &str) -> Result<bool, AgentRuntimeError> {
825        let from_id = EntityId::new(from);
826        let to_id = EntityId::new(to);
827        match self.shortest_path(&from_id, &to_id) {
828            Ok(Some(_)) => Ok(true),
829            Ok(None) => Ok(false),
830            Err(e) => Err(e),
831        }
832    }
833
834    /// BFS limited by maximum depth and maximum node count.
835    ///
836    /// Returns the subset of nodes visited within those limits (including start).
837    pub fn bfs_bounded(
838        &self,
839        start: &str,
840        max_depth: usize,
841        max_nodes: usize,
842    ) -> Result<Vec<EntityId>, AgentRuntimeError> {
843        let inner = recover_lock(self.inner.lock(), "bfs_bounded");
844        let start_id = EntityId::new(start);
845        if !inner.entities.contains_key(&start_id) {
846            return Err(AgentRuntimeError::Graph(format!(
847                "start entity '{start}' not found"
848            )));
849        }
850
851        let mut visited: std::collections::HashMap<EntityId, usize> = std::collections::HashMap::new();
852        let mut queue: VecDeque<(EntityId, usize)> = VecDeque::new();
853        let mut result: Vec<EntityId> = Vec::new();
854
855        visited.insert(start_id.clone(), 0);
856        queue.push_back((start_id.clone(), 0));
857        result.push(start_id);
858
859        while let Some((current, depth)) = queue.pop_front() {
860            if result.len() >= max_nodes {
861                break;
862            }
863            if depth >= max_depth {
864                continue;
865            }
866            for neighbour in Self::neighbours(&inner.relationships, &current) {
867                if !visited.contains_key(&neighbour) {
868                    let new_depth = depth + 1;
869                    visited.insert(neighbour.clone(), new_depth);
870                    result.push(neighbour.clone());
871                    if result.len() >= max_nodes {
872                        break;
873                    }
874                    queue.push_back((neighbour, new_depth));
875                }
876            }
877        }
878
879        Ok(result)
880    }
881
882    /// DFS limited by maximum depth and maximum node count.
883    ///
884    /// Returns the subset of nodes visited within those limits (including start).
885    pub fn dfs_bounded(
886        &self,
887        start: &str,
888        max_depth: usize,
889        max_nodes: usize,
890    ) -> Result<Vec<EntityId>, AgentRuntimeError> {
891        let inner = recover_lock(self.inner.lock(), "dfs_bounded");
892        let start_id = EntityId::new(start);
893        if !inner.entities.contains_key(&start_id) {
894            return Err(AgentRuntimeError::Graph(format!(
895                "start entity '{start}' not found"
896            )));
897        }
898
899        let mut visited: HashSet<EntityId> = HashSet::new();
900        let mut stack: Vec<(EntityId, usize)> = Vec::new();
901        let mut result: Vec<EntityId> = Vec::new();
902
903        visited.insert(start_id.clone());
904        stack.push((start_id.clone(), 0));
905        result.push(start_id);
906
907        while let Some((current, depth)) = stack.pop() {
908            if result.len() >= max_nodes {
909                break;
910            }
911            if depth >= max_depth {
912                continue;
913            }
914            for neighbour in Self::neighbours(&inner.relationships, &current) {
915                if visited.insert(neighbour.clone()) {
916                    result.push(neighbour.clone());
917                    if result.len() >= max_nodes {
918                        break;
919                    }
920                    stack.push((neighbour, depth + 1));
921                }
922            }
923        }
924
925        Ok(result)
926    }
927
928    /// Extract a subgraph containing only the specified entities and the
929    /// relationships between them.
930    pub fn subgraph(&self, node_ids: &[EntityId]) -> Result<GraphStore, AgentRuntimeError> {
931        let inner = recover_lock(self.inner.lock(), "subgraph");
932        let id_set: HashSet<&EntityId> = node_ids.iter().collect();
933
934        let new_store = GraphStore::new();
935
936        for id in node_ids {
937            let entity = inner
938                .entities
939                .get(id)
940                .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))?
941                .clone();
942            // We hold inner lock; call directly on the new store's inner.
943            let mut new_inner = recover_lock(new_store.inner.lock(), "subgraph:add_entity");
944            new_inner.entities.insert(entity.id.clone(), entity);
945        }
946
947        for rel in inner.relationships.iter() {
948            if id_set.contains(&rel.from) && id_set.contains(&rel.to) {
949                let mut new_inner =
950                    recover_lock(new_store.inner.lock(), "subgraph:add_relationship");
951                new_inner.relationships.push(rel.clone());
952            }
953        }
954
955        Ok(new_store)
956    }
957}
958
959impl Default for GraphStore {
960    fn default() -> Self {
961        Self::new()
962    }
963}
964
965// ── Tests ─────────────────────────────────────────────────────────────────────
966
#[cfg(test)]
mod tests {
    // Unit tests for the graph value types and the `GraphStore` operations.
    // Every test builds its own small graph via the helpers below.
    use super::*;

    fn make_graph() -> GraphStore {
        GraphStore::new()
    }

    /// Add an entity with id `id` and label `"Node"`.
    fn add(g: &GraphStore, id: &str) {
        g.add_entity(Entity::new(id, "Node")).unwrap();
    }

    /// Add a `CONNECTS` edge of weight 1.0 from `from` to `to`.
    fn link(g: &GraphStore, from: &str, to: &str) {
        g.add_relationship(Relationship::new(from, to, "CONNECTS", 1.0))
            .unwrap();
    }

    /// Add a `CONNECTS` edge from `from` to `to` with an explicit weight.
    fn link_w(g: &GraphStore, from: &str, to: &str, weight: f32) {
        g.add_relationship(Relationship::new(from, to, "CONNECTS", weight))
            .unwrap();
    }

    // ── EntityId ──────────────────────────────────────────────────────────────

    #[test]
    fn test_entity_id_equality() {
        assert_eq!(EntityId::new("a"), EntityId::new("a"));
        assert_ne!(EntityId::new("a"), EntityId::new("b"));
    }

    #[test]
    fn test_entity_id_display() {
        let id = EntityId::new("hello");
        assert_eq!(id.to_string(), "hello");
    }

    // ── Entity ────────────────────────────────────────────────────────────────

    #[test]
    fn test_entity_new_has_empty_properties() {
        let e = Entity::new("e1", "Person");
        assert!(e.properties.is_empty());
    }

    #[test]
    fn test_entity_with_properties_stores_props() {
        let mut props = HashMap::new();
        props.insert("age".into(), Value::Number(42.into()));
        let e = Entity::with_properties("e1", "Person", props);
        assert!(e.properties.contains_key("age"));
    }

    // ── GraphStore basic ops ──────────────────────────────────────────────────

    #[test]
    fn test_graph_add_entity_increments_count() {
        let g = make_graph();
        add(&g, "a");
        assert_eq!(g.entity_count().unwrap(), 1);
    }

    #[test]
    fn test_graph_get_entity_returns_entity() {
        let g = make_graph();
        g.add_entity(Entity::new("e1", "Person")).unwrap();
        let e = g.get_entity(&EntityId::new("e1")).unwrap();
        assert_eq!(e.label, "Person");
    }

    #[test]
    fn test_graph_get_entity_missing_returns_error() {
        let g = make_graph();
        assert!(g.get_entity(&EntityId::new("ghost")).is_err());
    }

    #[test]
    fn test_graph_add_relationship_increments_count() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        assert_eq!(g.relationship_count().unwrap(), 1);
    }

    #[test]
    fn test_graph_add_relationship_missing_source_fails() {
        let g = make_graph();
        add(&g, "b");
        let result = g.add_relationship(Relationship::new("ghost", "b", "X", 1.0));
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_add_relationship_missing_target_fails() {
        let g = make_graph();
        add(&g, "a");
        let result = g.add_relationship(Relationship::new("a", "ghost", "X", 1.0));
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_remove_entity_removes_relationships() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        g.remove_entity(&EntityId::new("a")).unwrap();
        assert_eq!(g.entity_count().unwrap(), 1);
        assert_eq!(g.relationship_count().unwrap(), 0);
    }

    #[test]
    fn test_graph_remove_entity_missing_returns_error() {
        let g = make_graph();
        assert!(g.remove_entity(&EntityId::new("ghost")).is_err());
    }

    #[test]
    fn test_default_graph_is_empty() {
        let g = GraphStore::default();
        assert_eq!(g.entity_count().unwrap(), 0);
        assert_eq!(g.relationship_count().unwrap(), 0);
    }

    // ── BFS ───────────────────────────────────────────────────────────────────

    #[test]
    fn test_bfs_finds_direct_neighbours() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "a", "c");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 2);
    }

    #[test]
    fn test_bfs_traverses_chain() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 3);
        assert_eq!(visited[0], EntityId::new("b"));
    }

    #[test]
    fn test_bfs_handles_isolated_node() {
        let g = make_graph();
        add(&g, "a");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert!(visited.is_empty());
    }

    #[test]
    fn test_bfs_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.bfs(&EntityId::new("ghost")).is_err());
    }

    // ── DFS ───────────────────────────────────────────────────────────────────

    #[test]
    fn test_dfs_visits_all_reachable_nodes() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "a", "c");
        link(&g, "b", "d");
        let visited = g.dfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 3);
    }

    #[test]
    fn test_dfs_handles_isolated_node() {
        let g = make_graph();
        add(&g, "a");
        let visited = g.dfs(&EntityId::new("a")).unwrap();
        assert!(visited.is_empty());
    }

    #[test]
    fn test_dfs_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.dfs(&EntityId::new("ghost")).is_err());
    }

    // ── Shortest path ─────────────────────────────────────────────────────────

    #[test]
    fn test_shortest_path_direct_connection() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert_eq!(path, Some(vec![EntityId::new("a"), EntityId::new("b")]));
    }

    #[test]
    fn test_shortest_path_multi_hop() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "b", "c");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert_eq!(path.as_ref().map(|p| p.len()), Some(3));
    }

    #[test]
    fn test_shortest_path_returns_none_for_disconnected() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert_eq!(path, None);
    }

    #[test]
    fn test_shortest_path_same_node_returns_single_element() {
        let g = make_graph();
        add(&g, "a");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("a"))
            .unwrap();
        assert_eq!(path, Some(vec![EntityId::new("a")]));
    }

    #[test]
    fn test_shortest_path_missing_source_returns_error() {
        let g = make_graph();
        add(&g, "b");
        assert!(g
            .shortest_path(&EntityId::new("ghost"), &EntityId::new("b"))
            .is_err());
    }

    #[test]
    fn test_shortest_path_missing_target_returns_error() {
        let g = make_graph();
        add(&g, "a");
        assert!(g
            .shortest_path(&EntityId::new("a"), &EntityId::new("ghost"))
            .is_err());
    }

    // ── Transitive closure ────────────────────────────────────────────────────

    #[test]
    fn test_transitive_closure_includes_start() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
        assert!(closure.contains(&EntityId::new("a")));
        assert!(closure.contains(&EntityId::new("b")));
    }

    #[test]
    fn test_transitive_closure_isolated_node_contains_only_self() {
        let g = make_graph();
        add(&g, "a");
        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
        assert_eq!(closure.len(), 1);
    }

    // ── MemGraphError conversion ──────────────────────────────────────────────

    #[test]
    fn test_mem_graph_error_converts_to_runtime_error() {
        let e = MemGraphError::EntityNotFound("x".into());
        let re: AgentRuntimeError = e.into();
        assert!(matches!(re, AgentRuntimeError::Graph(_)));
    }

    // ── Weighted shortest path ────────────────────────────────────────────────

    #[test]
    fn test_shortest_path_weighted_simple() {
        // a --(1.0)--> b --(2.0)--> c
        // a --(10.0)--> c  (direct but heavier)
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link_w(&g, "a", "b", 1.0);
        link_w(&g, "b", "c", 2.0);
        g.add_relationship(Relationship::new("a", "c", "DIRECT", 10.0))
            .unwrap();

        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        let (path, weight) = result.expect("a path from a to c exists");
        // The cheapest path is a -> b -> c with total weight 3.0
        assert_eq!(
            path,
            vec![EntityId::new("a"), EntityId::new("b"), EntityId::new("c")]
        );
        assert!((weight - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_shortest_path_weighted_returns_none_for_disconnected() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_shortest_path_weighted_same_node() {
        let g = make_graph();
        add(&g, "a");
        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("a"))
            .unwrap();
        assert_eq!(result, Some((vec![EntityId::new("a")], 0.0)));
    }

    #[test]
    fn test_shortest_path_weighted_negative_weight_errors() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        g.add_relationship(Relationship::new("a", "b", "NEG", -1.0))
            .unwrap();
        let result = g.shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"));
        assert!(result.is_err());
    }

    // ── Degree centrality ─────────────────────────────────────────────────────

    #[test]
    fn test_degree_centrality_basic() {
        // Star graph: a -> b, a -> c, a -> d
        // a: out=3, in=0 => (3+0)/(4-1) = 1.0
        // b: out=0, in=1 => (0+1)/3 = 0.333...
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "a", "c");
        link(&g, "a", "d");

        let centrality = g.degree_centrality().unwrap();
        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();

        assert!((a_cent - 1.0).abs() < 1e-5, "a centrality was {a_cent}");
        assert!(
            (b_cent - 1.0 / 3.0).abs() < 1e-5,
            "b centrality was {b_cent}"
        );
    }

    // ── Betweenness centrality ────────────────────────────────────────────────

    #[test]
    fn test_betweenness_centrality_chain() {
        // Chain: a -> b -> c -> d
        // b and c are on all paths from a to c, a to d, b to d
        // Node b and c should have higher centrality than a and d.
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        let centrality = g.betweenness_centrality().unwrap();
        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
        let c_cent = *centrality.get(&EntityId::new("c")).unwrap();
        let d_cent = *centrality.get(&EntityId::new("d")).unwrap();

        // Endpoints should have 0 centrality; intermediate nodes should be > 0.
        assert!(a_cent.abs() < 1e-5, "expected a_cent ~ 0, got {a_cent}");
        assert!(b_cent > 0.0, "expected b_cent > 0, got {b_cent}");
        assert!(c_cent > 0.0, "expected c_cent > 0, got {c_cent}");
        assert!(d_cent.abs() < 1e-5, "expected d_cent ~ 0, got {d_cent}");
    }

    // ── Label propagation communities ─────────────────────────────────────────

    #[test]
    fn test_label_propagation_communities_two_clusters() {
        // Cluster 1: a <-> b <-> c (fully connected via bidirectional edges)
        // Cluster 2: x <-> y <-> z
        // No edges between clusters.
        let g = make_graph();
        for id in &["a", "b", "c", "x", "y", "z"] {
            add(&g, id);
        }
        // Cluster 1 (bidirectional via two directed edges each)
        link(&g, "a", "b");
        link(&g, "b", "a");
        link(&g, "b", "c");
        link(&g, "c", "b");
        link(&g, "a", "c");
        link(&g, "c", "a");
        // Cluster 2
        link(&g, "x", "y");
        link(&g, "y", "x");
        link(&g, "y", "z");
        link(&g, "z", "y");
        link(&g, "x", "z");
        link(&g, "z", "x");

        let communities = g.label_propagation_communities(100).unwrap();

        let label_a = communities[&EntityId::new("a")];
        let label_b = communities[&EntityId::new("b")];
        let label_c = communities[&EntityId::new("c")];
        let label_x = communities[&EntityId::new("x")];
        let label_y = communities[&EntityId::new("y")];
        let label_z = communities[&EntityId::new("z")];

        // All nodes in cluster 1 share a label, all in cluster 2 share a label,
        // and the two clusters have different labels.
        assert_eq!(label_a, label_b, "a and b should be in same community");
        assert_eq!(label_b, label_c, "b and c should be in same community");
        assert_eq!(label_x, label_y, "x and y should be in same community");
        assert_eq!(label_y, label_z, "y and z should be in same community");
        assert_ne!(
            label_a, label_x,
            "cluster 1 and cluster 2 should be different communities"
        );
    }

    // ── Subgraph extraction ───────────────────────────────────────────────────

    #[test]
    fn test_subgraph_extracts_correct_nodes_and_edges() {
        // Full graph: a -> b -> c -> d
        // Subgraph of {a, b, c} should contain edges a->b and b->c but not c->d.
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        let sub = g
            .subgraph(&[EntityId::new("a"), EntityId::new("b"), EntityId::new("c")])
            .unwrap();

        assert_eq!(sub.entity_count().unwrap(), 3);
        assert_eq!(sub.relationship_count().unwrap(), 2);

        // d should not be present in the subgraph.
        assert!(sub.get_entity(&EntityId::new("d")).is_err());

        // a -> b and b -> c should be present; c -> d should not.
        let path = sub
            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert!(path.is_some());
        assert_eq!(path.unwrap().len(), 3);
    }

    #[test]
    fn test_subgraph_missing_entity_returns_error() {
        let g = make_graph();
        add(&g, "a");
        assert!(g
            .subgraph(&[EntityId::new("a"), EntityId::new("ghost")])
            .is_err());
    }

    // ── detect_cycles ──────────────────────────────────────────────────────────

    #[test]
    fn test_detect_cycles_dag_returns_false() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "b", "c");
        assert!(!g.detect_cycles().unwrap());
    }

    #[test]
    fn test_detect_cycles_self_loop_returns_true() {
        let g = make_graph();
        add(&g, "a");
        // A single self-loop is the smallest possible cycle.
        g.add_relationship(Relationship::new("a", "a", "SELF", 1.0))
            .unwrap();
        assert!(g.detect_cycles().unwrap());
    }

    #[test]
    fn test_detect_cycles_simple_cycle_returns_true() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        g.add_relationship(Relationship::new("b", "a", "BACK", 1.0))
            .unwrap();
        assert!(g.detect_cycles().unwrap());
    }

    #[test]
    fn test_detect_cycles_empty_graph_returns_false() {
        let g = make_graph();
        assert!(!g.detect_cycles().unwrap());
    }

    #[test]
    fn test_detect_cycles_result_is_cached() {
        let g = make_graph();
        add(&g, "x");
        add(&g, "y");
        link(&g, "x", "y");
        // First call.
        let r1 = g.detect_cycles().unwrap();
        // Second call should return the cached value.
        let r2 = g.detect_cycles().unwrap();
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_detect_cycles_cache_invalidated_on_mutation() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        assert!(!g.detect_cycles().unwrap());

        // Add a back edge to create a cycle — cache must be invalidated.
        g.add_relationship(Relationship::new("b", "a", "BACK", 1.0))
            .unwrap();
        assert!(
            g.detect_cycles().unwrap(),
            "cache should be invalidated after adding a back edge"
        );
    }

    // ── bfs_bounded / dfs_bounded ─────────────────────────────────────────────

    #[test]
    fn test_bfs_bounded_respects_max_depth() {
        // Chain: a -> b -> c -> d
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        // max_depth=1 should only visit a and b (depth 0 and 1)
        let visited = g.bfs_bounded("a", 1, 100).unwrap();
        assert!(visited.contains(&EntityId::new("a")));
        assert!(visited.contains(&EntityId::new("b")));
        assert!(!visited.contains(&EntityId::new("c")), "c is at depth 2, should not be visited");
    }

    #[test]
    fn test_bfs_bounded_respects_max_nodes() {
        // Star: a -> b, a -> c, a -> d
        let g = make_graph();
        for id in &["a", "b", "c", "d"] {
            add(&g, id);
        }
        link(&g, "a", "b");
        link(&g, "a", "c");
        link(&g, "a", "d");

        let visited = g.bfs_bounded("a", 10, 2).unwrap();
        assert_eq!(visited.len(), 2, "should stop at 2 nodes");
        assert_eq!(visited[0], EntityId::new("a"), "start is always first");
    }

    #[test]
    fn test_bfs_bounded_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.bfs_bounded("ghost", 1, 10).is_err());
    }

    // ── #5/#35 path_exists ────────────────────────────────────────────────────

    #[test]
    fn test_path_exists_returns_true() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "b", "c");
        assert!(g.path_exists("a", "c").unwrap());
    }

    #[test]
    fn test_path_exists_returns_false() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        assert!(!g.path_exists("a", "b").unwrap());
    }

    #[test]
    fn test_path_exists_missing_node_returns_error() {
        let g = make_graph();
        add(&g, "a");
        assert!(g.path_exists("a", "ghost").is_err());
    }

    // ── #13 EntityId::as_str ──────────────────────────────────────────────────

    #[test]
    fn test_entity_id_as_str() {
        let id = EntityId::new("my-entity");
        assert_eq!(id.as_str(), "my-entity");
    }

    // ── #38 EntityId AsRef<str> ───────────────────────────────────────────────

    #[test]
    fn test_entity_id_as_ref_str() {
        let id = EntityId::new("asref-test");
        let s: &str = id.as_ref();
        assert_eq!(s, "asref-test");
    }

    #[test]
    fn test_dfs_bounded_respects_max_nodes() {
        // Chain: a -> b -> c -> d
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        // max_nodes=2 means only 2 nodes total
        let visited = g.dfs_bounded("a", 100, 2).unwrap();
        assert_eq!(visited.len(), 2, "should stop at 2 nodes");
    }

    #[test]
    fn test_dfs_bounded_respects_max_depth() {
        // Chain: a -> b -> c
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "b", "c");

        // max_depth=1 records a (depth 0) and b (depth 1) but does not expand b.
        let visited = g.dfs_bounded("a", 1, 100).unwrap();
        assert_eq!(visited, vec![EntityId::new("a"), EntityId::new("b")]);
    }

    #[test]
    fn test_dfs_bounded_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.dfs_bounded("ghost", 1, 10).is_err());
    }
}