1use crate::error::AgentRuntimeError;
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20use std::sync::{Arc, Mutex};
21
22fn recover_lock<'a, T>(
28 result: std::sync::LockResult<std::sync::MutexGuard<'a, T>>,
29 ctx: &str,
30) -> std::sync::MutexGuard<'a, T>
31where
32 T: ?Sized,
33{
34 match result {
35 Ok(guard) => guard,
36 Err(poisoned) => {
37 tracing::warn!("mutex poisoned in {ctx}, recovering inner value");
38 poisoned.into_inner()
39 }
40 }
41}
42
/// Totally-ordered wrapper around `f32`, usable as a `BinaryHeap` priority.
///
/// Both ordering and equality follow [`f32::total_cmp`] (IEEE 754 `totalOrder`).
/// Every value, NaN included, therefore compares consistently and
/// antisymmetrically, as the `Ord`/`Eq` contracts require. A plain
/// `partial_cmp(..).unwrap_or(..)` fallback cannot guarantee this.
#[derive(Debug, Clone, Copy)]
struct OrdF32(f32);

impl PartialEq for OrdF32 {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`. A derived `==` would say NaN != NaN and
        // -0.0 == 0.0, both of which contradict `total_cmp`.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrdF32 {}

impl PartialOrd for OrdF32 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF32 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.total_cmp(&other.0)
    }
}
65
66#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
70pub struct EntityId(pub String);
71
72impl EntityId {
73 pub fn new(id: impl Into<String>) -> Self {
75 Self(id.into())
76 }
77}
78
79impl std::fmt::Display for EntityId {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 write!(f, "{}", self.0)
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct Entity {
90 pub id: EntityId,
92 pub label: String,
94 pub properties: HashMap<String, Value>,
96}
97
98impl Entity {
99 pub fn new(id: impl Into<String>, label: impl Into<String>) -> Self {
101 Self {
102 id: EntityId::new(id),
103 label: label.into(),
104 properties: HashMap::new(),
105 }
106 }
107
108 pub fn with_properties(
110 id: impl Into<String>,
111 label: impl Into<String>,
112 properties: HashMap<String, Value>,
113 ) -> Self {
114 Self {
115 id: EntityId::new(id),
116 label: label.into(),
117 properties,
118 }
119 }
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct Relationship {
127 pub from: EntityId,
129 pub to: EntityId,
131 pub kind: String,
133 pub weight: f32,
135}
136
137impl Relationship {
138 pub fn new(
140 from: impl Into<String>,
141 to: impl Into<String>,
142 kind: impl Into<String>,
143 weight: f32,
144 ) -> Self {
145 Self {
146 from: EntityId::new(from),
147 to: EntityId::new(to),
148 kind: kind.into(),
149 weight,
150 }
151 }
152}
153
/// Errors describing graph-store failures.
///
/// `From<MemGraphError> for AgentRuntimeError` flattens these into
/// `AgentRuntimeError::Graph`, carrying only the `Display` text.
#[derive(Debug, thiserror::Error)]
pub enum MemGraphError {
    /// The named entity id is not present in the graph.
    #[error("Entity '{0}' not found")]
    EntityNotFound(String),

    /// An edge with the same `from`, `to` and `kind` is already stored.
    #[error("Relationship '{kind}' from '{from}' to '{to}' already exists")]
    DuplicateRelationship {
        /// Source entity id.
        from: String,
        /// Target entity id.
        to: String,
        /// Relationship kind.
        kind: String,
    },

    /// Any other internal failure, described by the message.
    #[error("Graph internal error: {0}")]
    Internal(String),
}
178
179impl From<MemGraphError> for AgentRuntimeError {
180 fn from(e: MemGraphError) -> Self {
181 AgentRuntimeError::Graph(e.to_string())
182 }
183}
184
/// In-memory directed graph of [`Entity`] nodes and [`Relationship`] edges,
/// guarded by a mutex.
///
/// Cloning a `GraphStore` clones the `Arc`, so every clone is a handle to the
/// *same* underlying graph state.
#[derive(Debug, Clone)]
pub struct GraphStore {
    inner: Arc<Mutex<GraphInner>>,
}

/// Graph state protected by `GraphStore::inner`.
#[derive(Debug)]
struct GraphInner {
    // Entities keyed by their id.
    entities: HashMap<EntityId, Entity>,
    // Edges in insertion order; `add_relationship` rejects exact
    // (from, to, kind) duplicates.
    relationships: Vec<Relationship>,
}
205
206impl GraphStore {
207 pub fn new() -> Self {
209 Self {
210 inner: Arc::new(Mutex::new(GraphInner {
211 entities: HashMap::new(),
212 relationships: Vec::new(),
213 })),
214 }
215 }
216
217 pub fn add_entity(&self, entity: Entity) -> Result<(), AgentRuntimeError> {
221 let mut inner = recover_lock(self.inner.lock(), "add_entity");
222 inner.entities.insert(entity.id.clone(), entity);
223 Ok(())
224 }
225
226 pub fn get_entity(&self, id: &EntityId) -> Result<Entity, AgentRuntimeError> {
228 let inner = recover_lock(self.inner.lock(), "get_entity");
229 inner
230 .entities
231 .get(id)
232 .cloned()
233 .ok_or_else(|| AgentRuntimeError::Graph(format!("entity '{}' not found", id.0)))
234 }
235
236 pub fn add_relationship(&self, rel: Relationship) -> Result<(), AgentRuntimeError> {
240 let mut inner = recover_lock(self.inner.lock(), "add_relationship");
241
242 if !inner.entities.contains_key(&rel.from) {
243 return Err(AgentRuntimeError::Graph(format!(
244 "source entity '{}' not found",
245 rel.from.0
246 )));
247 }
248 if !inner.entities.contains_key(&rel.to) {
249 return Err(AgentRuntimeError::Graph(format!(
250 "target entity '{}' not found",
251 rel.to.0
252 )));
253 }
254
255 let duplicate = inner
259 .relationships
260 .iter()
261 .any(|r| r.from == rel.from && r.to == rel.to && r.kind == rel.kind);
262 if duplicate {
263 return Err(AgentRuntimeError::Graph(
264 MemGraphError::DuplicateRelationship {
265 from: rel.from.0.clone(),
266 to: rel.to.0.clone(),
267 kind: rel.kind.clone(),
268 }
269 .to_string(),
270 ));
271 }
272
273 inner.relationships.push(rel);
274 Ok(())
275 }
276
277 pub fn remove_entity(&self, id: &EntityId) -> Result<(), AgentRuntimeError> {
279 let mut inner = recover_lock(self.inner.lock(), "remove_entity");
280
281 if inner.entities.remove(id).is_none() {
282 return Err(AgentRuntimeError::Graph(format!(
283 "entity '{}' not found",
284 id.0
285 )));
286 }
287 inner.relationships.retain(|r| &r.from != id && &r.to != id);
288 Ok(())
289 }
290
291 fn neighbours(relationships: &[Relationship], id: &EntityId) -> Vec<EntityId> {
293 relationships
294 .iter()
295 .filter(|r| &r.from == id)
296 .map(|r| r.to.clone())
297 .collect()
298 }
299
300 #[tracing::instrument(skip(self))]
304 pub fn bfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
305 let inner = recover_lock(self.inner.lock(), "bfs");
306
307 if !inner.entities.contains_key(start) {
308 return Err(AgentRuntimeError::Graph(format!(
309 "start entity '{}' not found",
310 start.0
311 )));
312 }
313
314 let mut visited: HashSet<EntityId> = HashSet::new();
315 let mut queue: VecDeque<EntityId> = VecDeque::new();
316 let mut result: Vec<EntityId> = Vec::new();
317
318 visited.insert(start.clone());
319 queue.push_back(start.clone());
320
321 while let Some(current) = queue.pop_front() {
322 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
323 for neighbour in neighbours {
324 if visited.insert(neighbour.clone()) {
325 result.push(neighbour.clone());
326 queue.push_back(neighbour);
327 }
328 }
329 }
330
331 tracing::debug!("BFS visited {} nodes", result.len());
332 Ok(result)
333 }
334
335 #[tracing::instrument(skip(self))]
339 pub fn dfs(&self, start: &EntityId) -> Result<Vec<EntityId>, AgentRuntimeError> {
340 let inner = recover_lock(self.inner.lock(), "dfs");
341
342 if !inner.entities.contains_key(start) {
343 return Err(AgentRuntimeError::Graph(format!(
344 "start entity '{}' not found",
345 start.0
346 )));
347 }
348
349 let mut visited: HashSet<EntityId> = HashSet::new();
350 let mut stack: Vec<EntityId> = Vec::new();
351 let mut result: Vec<EntityId> = Vec::new();
352
353 visited.insert(start.clone());
354 stack.push(start.clone());
355
356 while let Some(current) = stack.pop() {
357 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
358 for neighbour in neighbours {
359 if visited.insert(neighbour.clone()) {
360 result.push(neighbour.clone());
361 stack.push(neighbour);
362 }
363 }
364 }
365
366 tracing::debug!("DFS visited {} nodes", result.len());
367 Ok(result)
368 }
369
370 #[tracing::instrument(skip(self))]
376 pub fn shortest_path(
377 &self,
378 from: &EntityId,
379 to: &EntityId,
380 ) -> Result<Option<Vec<EntityId>>, AgentRuntimeError> {
381 let inner = recover_lock(self.inner.lock(), "shortest_path");
382
383 if !inner.entities.contains_key(from) {
384 return Err(AgentRuntimeError::Graph(format!(
385 "source entity '{}' not found",
386 from.0
387 )));
388 }
389 if !inner.entities.contains_key(to) {
390 return Err(AgentRuntimeError::Graph(format!(
391 "target entity '{}' not found",
392 to.0
393 )));
394 }
395
396 if from == to {
397 return Ok(Some(vec![from.clone()]));
398 }
399
400 let mut visited: HashSet<EntityId> = HashSet::new();
401 let mut queue: VecDeque<Vec<EntityId>> = VecDeque::new();
402
403 visited.insert(from.clone());
404 queue.push_back(vec![from.clone()]);
405
406 while let Some(path) = queue.pop_front() {
407 let current = match path.last() {
408 Some(c) => c.clone(),
409 None => continue,
410 };
411
412 let neighbours: Vec<EntityId> = Self::neighbours(&inner.relationships, ¤t);
413
414 for neighbour in neighbours {
415 if &neighbour == to {
416 let mut full_path = path.clone();
417 full_path.push(neighbour);
418 return Ok(Some(full_path));
419 }
420 if visited.insert(neighbour.clone()) {
421 let mut new_path = path.clone();
422 new_path.push(neighbour.clone());
423 queue.push_back(new_path);
424 }
425 }
426 }
427
428 Ok(None)
429 }
430
431 pub fn shortest_path_weighted(
441 &self,
442 from: &EntityId,
443 to: &EntityId,
444 ) -> Result<Option<(Vec<EntityId>, f32)>, AgentRuntimeError> {
445 let inner = recover_lock(self.inner.lock(), "shortest_path_weighted");
446
447 if !inner.entities.contains_key(from) {
448 return Err(AgentRuntimeError::Graph(format!(
449 "source entity '{}' not found",
450 from.0
451 )));
452 }
453 if !inner.entities.contains_key(to) {
454 return Err(AgentRuntimeError::Graph(format!(
455 "target entity '{}' not found",
456 to.0
457 )));
458 }
459
460 for rel in &inner.relationships {
462 if rel.weight < 0.0 {
463 return Err(AgentRuntimeError::Graph(format!(
464 "negative weight {:.4} on edge '{}' -> '{}'",
465 rel.weight, rel.from.0, rel.to.0
466 )));
467 }
468 }
469
470 if from == to {
471 return Ok(Some((vec![from.clone()], 0.0)));
472 }
473
474 let mut dist: HashMap<EntityId, f32> = HashMap::new();
477 let mut prev: HashMap<EntityId, EntityId> = HashMap::new();
478 let mut heap: BinaryHeap<(OrdF32, EntityId)> = BinaryHeap::new();
480
481 dist.insert(from.clone(), 0.0);
482 heap.push((OrdF32(-0.0), from.clone()));
483
484 while let Some((OrdF32(neg_cost), current)) = heap.pop() {
485 let cost = -neg_cost;
486
487 if let Some(&best) = dist.get(¤t) {
489 if cost > best {
490 continue;
491 }
492 }
493
494 if ¤t == to {
495 let mut path = vec![to.clone()];
497 let mut node = to.clone();
498 while let Some(p) = prev.get(&node) {
499 path.push(p.clone());
500 node = p.clone();
501 }
502 path.reverse();
503 return Ok(Some((path, cost)));
504 }
505
506 for rel in inner.relationships.iter().filter(|r| &r.from == ¤t) {
507 let next_cost = cost + rel.weight;
508 let entry = dist.entry(rel.to.clone()).or_insert(f32::INFINITY);
509 if next_cost < *entry {
510 *entry = next_cost;
511 prev.insert(rel.to.clone(), current.clone());
512 heap.push((OrdF32(-next_cost), rel.to.clone()));
513 }
514 }
515 }
516
517 Ok(None)
518 }
519
520 pub fn transitive_closure(
522 &self,
523 start: &EntityId,
524 ) -> Result<HashSet<EntityId>, AgentRuntimeError> {
525 let reachable = self.bfs(start)?;
526 let mut set: HashSet<EntityId> = reachable.into_iter().collect();
527 set.insert(start.clone());
528 Ok(set)
529 }
530
531 pub fn entity_count(&self) -> Result<usize, AgentRuntimeError> {
533 let inner = recover_lock(self.inner.lock(), "entity_count");
534 Ok(inner.entities.len())
535 }
536
537 pub fn relationship_count(&self) -> Result<usize, AgentRuntimeError> {
539 let inner = recover_lock(self.inner.lock(), "relationship_count");
540 Ok(inner.relationships.len())
541 }
542
543 pub fn degree_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
546 let inner = recover_lock(self.inner.lock(), "degree_centrality");
547 let n = inner.entities.len();
548
549 let mut out_degree: HashMap<EntityId, usize> = HashMap::new();
550 let mut in_degree: HashMap<EntityId, usize> = HashMap::new();
551
552 for id in inner.entities.keys() {
553 out_degree.insert(id.clone(), 0);
554 in_degree.insert(id.clone(), 0);
555 }
556
557 for rel in &inner.relationships {
558 *out_degree.entry(rel.from.clone()).or_insert(0) += 1;
559 *in_degree.entry(rel.to.clone()).or_insert(0) += 1;
560 }
561
562 let denom = if n <= 1 { 1.0 } else { (n - 1) as f32 };
563 let mut result = HashMap::new();
564 for id in inner.entities.keys() {
565 let od = *out_degree.get(id).unwrap_or(&0);
566 let id_ = *in_degree.get(id).unwrap_or(&0);
567 let centrality = if n <= 1 {
568 0.0
569 } else {
570 (od + id_) as f32 / denom
571 };
572 result.insert(id.clone(), centrality);
573 }
574
575 Ok(result)
576 }
577
578 pub fn betweenness_centrality(&self) -> Result<HashMap<EntityId, f32>, AgentRuntimeError> {
584 let inner = recover_lock(self.inner.lock(), "betweenness_centrality");
585 let n = inner.entities.len();
586 let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
587
588 let mut centrality: HashMap<EntityId, f32> =
589 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
590
591 for source in &nodes {
592 let mut stack: Vec<EntityId> = Vec::new();
594 let mut predecessors: HashMap<EntityId, Vec<EntityId>> =
595 nodes.iter().map(|id| (id.clone(), vec![])).collect();
596 let mut sigma: HashMap<EntityId, f32> =
597 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
598 let mut dist: HashMap<EntityId, i64> =
599 nodes.iter().map(|id| (id.clone(), -1i64)).collect();
600
601 *sigma.entry(source.clone()).or_insert(0.0) = 1.0;
602 *dist.entry(source.clone()).or_insert(-1) = 0;
603
604 let mut queue: VecDeque<EntityId> = VecDeque::new();
605 queue.push_back(source.clone());
606
607 while let Some(v) = queue.pop_front() {
608 stack.push(v.clone());
609 let d_v = *dist.get(&v).unwrap_or(&0);
610 let sigma_v = *sigma.get(&v).unwrap_or(&0.0);
611 for rel in inner.relationships.iter().filter(|r| &r.from == &v) {
612 let w = &rel.to;
613 let d_w = dist.get(w).copied().unwrap_or(-1);
614 if d_w < 0 {
615 queue.push_back(w.clone());
616 *dist.entry(w.clone()).or_insert(-1) = d_v + 1;
617 }
618 if dist.get(w).copied().unwrap_or(-1) == d_v + 1 {
619 *sigma.entry(w.clone()).or_insert(0.0) += sigma_v;
620 predecessors.entry(w.clone()).or_default().push(v.clone());
621 }
622 }
623 }
624
625 let mut delta: HashMap<EntityId, f32> =
626 nodes.iter().map(|id| (id.clone(), 0.0f32)).collect();
627
628 while let Some(w) = stack.pop() {
629 let delta_w = *delta.get(&w).unwrap_or(&0.0);
630 let sigma_w = *sigma.get(&w).unwrap_or(&1.0);
631 for v in predecessors.get(&w).cloned().unwrap_or_default() {
632 let sigma_v = *sigma.get(&v).unwrap_or(&1.0);
633 let contribution = (sigma_v / sigma_w) * (1.0 + delta_w);
634 *delta.entry(v.clone()).or_insert(0.0) += contribution;
635 }
636 if &w != source {
637 *centrality.entry(w.clone()).or_insert(0.0) += delta_w;
638 }
639 }
640 }
641
642 if n > 2 {
644 let norm = 2.0 / (((n - 1) * (n - 2)) as f32);
645 for v in centrality.values_mut() {
646 *v *= norm;
647 }
648 } else {
649 for v in centrality.values_mut() {
650 *v = 0.0;
651 }
652 }
653
654 Ok(centrality)
655 }
656
657 pub fn label_propagation_communities(
662 &self,
663 max_iterations: usize,
664 ) -> Result<HashMap<EntityId, usize>, AgentRuntimeError> {
665 let inner = recover_lock(self.inner.lock(), "label_propagation_communities");
666 let nodes: Vec<EntityId> = inner.entities.keys().cloned().collect();
667
668 let mut labels: HashMap<EntityId, usize> = nodes
670 .iter()
671 .enumerate()
672 .map(|(i, id)| (id.clone(), i))
673 .collect();
674
675 for _ in 0..max_iterations {
676 let mut changed = false;
677 for node in &nodes {
679 let neighbour_labels: Vec<usize> = inner
681 .relationships
682 .iter()
683 .filter(|r| &r.from == node || &r.to == node)
684 .map(|r| {
685 if &r.from == node {
686 labels.get(&r.to).copied().unwrap_or(0)
687 } else {
688 labels.get(&r.from).copied().unwrap_or(0)
689 }
690 })
691 .collect();
692
693 if neighbour_labels.is_empty() {
694 continue;
695 }
696
697 let mut freq: HashMap<usize, usize> = HashMap::new();
699 for &lbl in &neighbour_labels {
700 *freq.entry(lbl).or_insert(0) += 1;
701 }
702 let best = freq
703 .into_iter()
704 .max_by_key(|&(_, count)| count)
705 .map(|(lbl, _)| lbl);
706
707 if let Some(new_label) = best {
708 let current = labels.entry(node.clone()).or_insert(0);
709 if *current != new_label {
710 *current = new_label;
711 changed = true;
712 }
713 }
714 }
715
716 if !changed {
717 break;
718 }
719 }
720
721 Ok(labels)
722 }
723
724 pub fn subgraph(&self, node_ids: &[EntityId]) -> Result<GraphStore, AgentRuntimeError> {
727 let inner = recover_lock(self.inner.lock(), "subgraph");
728 let id_set: HashSet<&EntityId> = node_ids.iter().collect();
729
730 let new_store = GraphStore::new();
731
732 for id in node_ids {
733 let entity = inner
734 .entities
735 .get(id)
736 .ok_or_else(|| {
737 AgentRuntimeError::Graph(format!("entity '{}' not found", id.0))
738 })?
739 .clone();
740 let mut new_inner = recover_lock(new_store.inner.lock(), "subgraph:add_entity");
742 new_inner.entities.insert(entity.id.clone(), entity);
743 }
744
745 for rel in inner.relationships.iter() {
746 if id_set.contains(&rel.from) && id_set.contains(&rel.to) {
747 let mut new_inner =
748 recover_lock(new_store.inner.lock(), "subgraph:add_relationship");
749 new_inner.relationships.push(rel.clone());
750 }
751 }
752
753 Ok(new_store)
754 }
755}
756
757impl Default for GraphStore {
758 fn default() -> Self {
759 Self::new()
760 }
761}
762
/// Unit tests for the graph store: entity/relationship CRUD, traversals,
/// shortest paths, centrality measures, community detection and subgraphs.
#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ─────────────────────────────────────────────────────────────

    /// Fresh, empty graph.
    fn make_graph() -> GraphStore {
        GraphStore::new()
    }

    /// Add an entity `id` labelled `"Node"`; panics on failure.
    fn add(g: &GraphStore, id: &str) {
        g.add_entity(Entity::new(id, "Node")).unwrap();
    }

    /// Add a `CONNECTS` edge `from -> to` with weight 1.0; panics on failure.
    fn link(g: &GraphStore, from: &str, to: &str) {
        g.add_relationship(Relationship::new(from, to, "CONNECTS", 1.0))
            .unwrap();
    }

    /// Add a `CONNECTS` edge `from -> to` with an explicit weight; panics on failure.
    fn link_w(g: &GraphStore, from: &str, to: &str, weight: f32) {
        g.add_relationship(Relationship::new(from, to, "CONNECTS", weight))
            .unwrap();
    }

    // ── EntityId ────────────────────────────────────────────────────────────

    #[test]
    fn test_entity_id_equality() {
        assert_eq!(EntityId::new("a"), EntityId::new("a"));
        assert_ne!(EntityId::new("a"), EntityId::new("b"));
    }

    #[test]
    fn test_entity_id_display() {
        let id = EntityId::new("hello");
        assert_eq!(id.to_string(), "hello");
    }

    // ── Entity ──────────────────────────────────────────────────────────────

    #[test]
    fn test_entity_new_has_empty_properties() {
        let e = Entity::new("e1", "Person");
        assert!(e.properties.is_empty());
    }

    #[test]
    fn test_entity_with_properties_stores_props() {
        let mut props = HashMap::new();
        props.insert("age".into(), Value::Number(42.into()));
        let e = Entity::with_properties("e1", "Person", props);
        assert!(e.properties.contains_key("age"));
    }

    // ── GraphStore CRUD ─────────────────────────────────────────────────────

    #[test]
    fn test_graph_add_entity_increments_count() {
        let g = make_graph();
        add(&g, "a");
        assert_eq!(g.entity_count().unwrap(), 1);
    }

    #[test]
    fn test_graph_get_entity_returns_entity() {
        let g = make_graph();
        g.add_entity(Entity::new("e1", "Person")).unwrap();
        let e = g.get_entity(&EntityId::new("e1")).unwrap();
        assert_eq!(e.label, "Person");
    }

    #[test]
    fn test_graph_get_entity_missing_returns_error() {
        let g = make_graph();
        assert!(g.get_entity(&EntityId::new("ghost")).is_err());
    }

    #[test]
    fn test_graph_add_relationship_increments_count() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        assert_eq!(g.relationship_count().unwrap(), 1);
    }

    #[test]
    fn test_graph_add_relationship_missing_source_fails() {
        let g = make_graph();
        add(&g, "b");
        let result = g.add_relationship(Relationship::new("ghost", "b", "X", 1.0));
        assert!(result.is_err());
    }

    #[test]
    fn test_graph_add_relationship_missing_target_fails() {
        let g = make_graph();
        add(&g, "a");
        let result = g.add_relationship(Relationship::new("a", "ghost", "X", 1.0));
        assert!(result.is_err());
    }

    // Removing an entity must also drop every edge incident to it.
    #[test]
    fn test_graph_remove_entity_removes_relationships() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        g.remove_entity(&EntityId::new("a")).unwrap();
        assert_eq!(g.entity_count().unwrap(), 1);
        assert_eq!(g.relationship_count().unwrap(), 0);
    }

    #[test]
    fn test_graph_remove_entity_missing_returns_error() {
        let g = make_graph();
        assert!(g.remove_entity(&EntityId::new("ghost")).is_err());
    }

    // ── BFS ─────────────────────────────────────────────────────────────────

    // The start node itself is not part of the returned list.
    #[test]
    fn test_bfs_finds_direct_neighbours() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "a", "c");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 2);
    }

    #[test]
    fn test_bfs_traverses_chain() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 3);
        // The first discovered node is the direct successor of the start.
        assert_eq!(visited[0], EntityId::new("b"));
    }

    #[test]
    fn test_bfs_handles_isolated_node() {
        let g = make_graph();
        add(&g, "a");
        let visited = g.bfs(&EntityId::new("a")).unwrap();
        assert!(visited.is_empty());
    }

    #[test]
    fn test_bfs_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.bfs(&EntityId::new("ghost")).is_err());
    }

    // ── DFS ─────────────────────────────────────────────────────────────────

    // Only the count is checked here, not the visit order.
    #[test]
    fn test_dfs_visits_all_reachable_nodes() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "a", "c");
        link(&g, "b", "d");
        let visited = g.dfs(&EntityId::new("a")).unwrap();
        assert_eq!(visited.len(), 3);
    }

    #[test]
    fn test_dfs_handles_isolated_node() {
        let g = make_graph();
        add(&g, "a");
        let visited = g.dfs(&EntityId::new("a")).unwrap();
        assert!(visited.is_empty());
    }

    #[test]
    fn test_dfs_missing_start_returns_error() {
        let g = make_graph();
        assert!(g.dfs(&EntityId::new("ghost")).is_err());
    }

    // ── Unweighted shortest path ────────────────────────────────────────────

    #[test]
    fn test_shortest_path_direct_connection() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert_eq!(path, Some(vec![EntityId::new("a"), EntityId::new("b")]));
    }

    // The returned path includes both endpoints: a -> b -> c has 3 nodes.
    #[test]
    fn test_shortest_path_multi_hop() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link(&g, "a", "b");
        link(&g, "b", "c");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert_eq!(path.as_ref().map(|p| p.len()), Some(3));
    }

    #[test]
    fn test_shortest_path_returns_none_for_disconnected() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert_eq!(path, None);
    }

    #[test]
    fn test_shortest_path_same_node_returns_single_element() {
        let g = make_graph();
        add(&g, "a");
        let path = g
            .shortest_path(&EntityId::new("a"), &EntityId::new("a"))
            .unwrap();
        assert_eq!(path, Some(vec![EntityId::new("a")]));
    }

    #[test]
    fn test_shortest_path_missing_source_returns_error() {
        let g = make_graph();
        add(&g, "b");
        assert!(g
            .shortest_path(&EntityId::new("ghost"), &EntityId::new("b"))
            .is_err());
    }

    #[test]
    fn test_shortest_path_missing_target_returns_error() {
        let g = make_graph();
        add(&g, "a");
        assert!(g
            .shortest_path(&EntityId::new("a"), &EntityId::new("ghost"))
            .is_err());
    }

    // ── Transitive closure ──────────────────────────────────────────────────

    #[test]
    fn test_transitive_closure_includes_start() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        link(&g, "a", "b");
        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
        assert!(closure.contains(&EntityId::new("a")));
        assert!(closure.contains(&EntityId::new("b")));
    }

    #[test]
    fn test_transitive_closure_isolated_node_contains_only_self() {
        let g = make_graph();
        add(&g, "a");
        let closure = g.transitive_closure(&EntityId::new("a")).unwrap();
        assert_eq!(closure.len(), 1);
    }

    // ── Error conversion ────────────────────────────────────────────────────

    #[test]
    fn test_mem_graph_error_converts_to_runtime_error() {
        let e = MemGraphError::EntityNotFound("x".into());
        let re: AgentRuntimeError = e.into();
        assert!(matches!(re, AgentRuntimeError::Graph(_)));
    }

    // ── Weighted shortest path ──────────────────────────────────────────────

    #[test]
    fn test_shortest_path_weighted_simple() {
        let g = make_graph();
        // a -> b -> c costs 1 + 2 = 3, cheaper than the direct a -> c edge (10).
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        link_w(&g, "a", "b", 1.0);
        link_w(&g, "b", "c", 2.0);
        g.add_relationship(Relationship::new("a", "c", "DIRECT", 10.0))
            .unwrap();

        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert!(result.is_some());
        let (path, weight) = result.unwrap();
        assert_eq!(
            // The two-hop route wins over the heavier direct edge.
            path,
            vec![EntityId::new("a"), EntityId::new("b"), EntityId::new("c")]
        );
        assert!((weight - 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_shortest_path_weighted_returns_none_for_disconnected() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"))
            .unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_shortest_path_weighted_same_node() {
        let g = make_graph();
        add(&g, "a");
        let result = g
            .shortest_path_weighted(&EntityId::new("a"), &EntityId::new("a"))
            .unwrap();
        assert_eq!(result, Some((vec![EntityId::new("a")], 0.0)));
    }

    // Adding a negative-weight edge is allowed; weighted search then refuses to run.
    #[test]
    fn test_shortest_path_weighted_negative_weight_errors() {
        let g = make_graph();
        add(&g, "a");
        add(&g, "b");
        g.add_relationship(Relationship::new("a", "b", "NEG", -1.0))
            .unwrap();
        let result = g.shortest_path_weighted(&EntityId::new("a"), &EntityId::new("b"));
        assert!(result.is_err());
    }

    // ── Degree centrality ───────────────────────────────────────────────────

    #[test]
    fn test_degree_centrality_basic() {
        let g = make_graph();
        // Star: a points at b, c and d (n = 4, so the denominator is n - 1 = 3).
        // a has total degree 3 -> 3/3 = 1.0; b has total degree 1 -> 1/3.
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "a", "c");
        link(&g, "a", "d");

        let centrality = g.degree_centrality().unwrap();
        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();

        assert!((a_cent - 1.0).abs() < 1e-5, "a centrality was {a_cent}");
        assert!(
            (b_cent - 1.0 / 3.0).abs() < 1e-5,
            "b centrality was {b_cent}"
        );
    }

    // ── Betweenness centrality ──────────────────────────────────────────────

    #[test]
    fn test_betweenness_centrality_chain() {
        let g = make_graph();
        // Directed chain a -> b -> c -> d: only the interior nodes b and c lie
        // strictly between other pairs, so only they get a non-zero score.
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        let centrality = g.betweenness_centrality().unwrap();
        let a_cent = *centrality.get(&EntityId::new("a")).unwrap();
        let b_cent = *centrality.get(&EntityId::new("b")).unwrap();
        let c_cent = *centrality.get(&EntityId::new("c")).unwrap();
        let d_cent = *centrality.get(&EntityId::new("d")).unwrap();

        assert!(
            // Chain endpoints are never intermediate nodes.
            (a_cent).abs() < 1e-5,
            "expected a_cent ~ 0, got {a_cent}"
        );
        assert!(b_cent > 0.0, "expected b_cent > 0, got {b_cent}");
        assert!(c_cent > 0.0, "expected c_cent > 0, got {c_cent}");
        assert!(
            (d_cent).abs() < 1e-5,
            "expected d_cent ~ 0, got {d_cent}"
        );
    }

    // ── Label propagation ───────────────────────────────────────────────────

    #[test]
    fn test_label_propagation_communities_two_clusters() {
        let g = make_graph();
        // Two disconnected triangles: {a, b, c} and {x, y, z}.
        for id in &["a", "b", "c", "x", "y", "z"] {
            add(&g, id);
        }
        link(&g, "a", "b");
        // Cluster 1: every pair connected in both directions.
        link(&g, "b", "a");
        link(&g, "b", "c");
        link(&g, "c", "b");
        link(&g, "a", "c");
        link(&g, "c", "a");
        link(&g, "x", "y");
        // Cluster 2: every pair connected in both directions.
        link(&g, "y", "x");
        link(&g, "y", "z");
        link(&g, "z", "y");
        link(&g, "x", "z");
        link(&g, "z", "x");

        let communities = g.label_propagation_communities(100).unwrap();

        let label_a = communities[&EntityId::new("a")];
        let label_b = communities[&EntityId::new("b")];
        let label_c = communities[&EntityId::new("c")];
        let label_x = communities[&EntityId::new("x")];
        let label_y = communities[&EntityId::new("y")];
        let label_z = communities[&EntityId::new("z")];

        // Only label equality matters; the concrete label values are arbitrary.
        assert_eq!(label_a, label_b, "a and b should be in same community");
        assert_eq!(label_b, label_c, "b and c should be in same community");
        assert_eq!(label_x, label_y, "x and y should be in same community");
        assert_eq!(label_y, label_z, "y and z should be in same community");
        assert_ne!(
            label_a, label_x,
            "cluster 1 and cluster 2 should be different communities"
        );
    }

    // ── Subgraph ────────────────────────────────────────────────────────────

    #[test]
    fn test_subgraph_extracts_correct_nodes_and_edges() {
        let g = make_graph();
        // Chain a -> b -> c -> d; extracting {a, b, c} keeps edges a->b and
        // b->c and drops c->d, whose target is outside the set.
        add(&g, "a");
        add(&g, "b");
        add(&g, "c");
        add(&g, "d");
        link(&g, "a", "b");
        link(&g, "b", "c");
        link(&g, "c", "d");

        let sub = g
            .subgraph(&[
                EntityId::new("a"),
                EntityId::new("b"),
                EntityId::new("c"),
            ])
            .unwrap();

        assert_eq!(sub.entity_count().unwrap(), 3);
        assert_eq!(sub.relationship_count().unwrap(), 2);

        assert!(sub.get_entity(&EntityId::new("d")).is_err());

        let path = sub
            // The extracted graph is a fully working store in its own right.
            .shortest_path(&EntityId::new("a"), &EntityId::new("c"))
            .unwrap();
        assert!(path.is_some());
        assert_eq!(path.unwrap().len(), 3);
    }
}
1257}