// nexus_memory_vectors/graph.rs

//! Graph tree structure for hierarchical memory organization.
//!
//! This module implements a graph tree that organizes memories hierarchically
//! for efficient resource management and improved semantic search.
//!
//! ## Tree Structure
//! - Root nodes: category containers
//! - Lane type nodes: optional intermediate organization
//! - Leaf nodes: the actual memory items
//!
//! ## Relevance Boosting
//! - Priority weights: priority 1 = High (1.5), 2 = Medium (1.2), anything else = Low (1.0)
//! - Depth penalty: 2% per tree level, never reducing a score by more than 20%
//! - Ancestor boost: aggregated ancestor weights, clamped to the range [0.8, 1.2]

use std::collections::{HashMap, HashSet, VecDeque};

use serde::{Deserialize, Serialize};

/// Identifier of a node in the graph tree.
///
/// Memory leaves reuse the memory's own ID, while synthetic container nodes
/// created by [`GraphTree`] are given negative IDs.
pub type NodeId = i64;

22/// Graph tree node representing a memory or category
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct GraphNode {
25    /// Unique node identifier (memory ID or category node ID)
26    pub id: NodeId,
27
28    /// Node type
29    pub node_type: NodeType,
30
31    /// Parent node ID (None for root)
32    pub parent_id: Option<NodeId>,
33
34    /// Child node IDs
35    pub children: Vec<NodeId>,
36
37    /// Depth in tree (0 for root)
38    pub depth: u32,
39
40    /// Node weight for relevance boosting
41    pub weight: f32,
42
43    /// Category this node belongs to
44    pub category: String,
45
46    /// Optional memory lane type
47    pub memory_lane_type: Option<String>,
48}
49
50impl GraphNode {
51    /// Create a new graph node
52    pub fn new(id: NodeId, node_type: NodeType, category: String) -> Self {
53        Self {
54            id,
55            node_type,
56            parent_id: None,
57            children: Vec::new(),
58            depth: 0,
59            weight: 1.0,
60            category,
61            memory_lane_type: None,
62        }
63    }
64
65    /// Check if this is a leaf node (has no children)
66    pub fn is_leaf(&self) -> bool {
67        self.children.is_empty()
68    }
69
70    /// Check if this is a root node (has no parent)
71    pub fn is_root(&self) -> bool {
72        self.parent_id.is_none()
73    }
74
75    /// Add a child node
76    pub fn add_child(&mut self, child_id: NodeId) {
77        if !self.children.contains(&child_id) {
78            self.children.push(child_id);
79        }
80    }
81
82    /// Remove a child node
83    pub fn remove_child(&mut self, child_id: NodeId) -> bool {
84        if let Some(pos) = self.children.iter().position(|&id| id == child_id) {
85            self.children.remove(pos);
86            true
87        } else {
88            false
89        }
90    }
91
92    /// Set weight based on priority level
93    pub fn set_priority_weight(&mut self, priority: u8) {
94        self.weight = match priority {
95            1 => 1.5, // High priority
96            2 => 1.2, // Medium priority
97            _ => 1.0, // Low/default priority
98        };
99    }
100}
101
102/// Type of graph node
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
104pub enum NodeType {
105    /// Category root node
106    CategoryRoot,
107
108    /// Memory lane type node
109    LaneTypeNode,
110
111    /// Actual memory leaf node
112    MemoryLeaf,
113
114    /// Time-based cluster node
115    TimeCluster,
116}
117
118/// Tree node for traversal operations
119#[derive(Debug, Clone)]
120pub struct TreeNode {
121    /// The graph node
122    pub node: GraphNode,
123
124    /// Cumulative path weight from root
125    pub path_weight: f32,
126
127    /// Distance from root
128    pub distance: u32,
129}
130
131impl TreeNode {
132    /// Create a new tree node
133    pub fn new(node: GraphNode) -> Self {
134        let weight = node.weight;
135        Self {
136            node,
137            path_weight: weight,
138            distance: 0,
139        }
140    }
141
142    /// Create a child tree node
143    pub fn child(node: GraphNode, parent: &TreeNode) -> Self {
144        Self {
145            path_weight: parent.path_weight * node.weight,
146            distance: parent.distance + 1,
147            node,
148        }
149    }
150}
151
152/// Graph tree structure for hierarchical memory organization
153#[derive(Debug, Clone, Default)]
154pub struct GraphTree {
155    /// All nodes indexed by ID
156    nodes: HashMap<NodeId, GraphNode>,
157
158    /// Root node IDs
159    roots: Vec<NodeId>,
160
161    /// Category to root node mapping
162    category_roots: HashMap<String, NodeId>,
163
164    /// Next synthetic node ID (for non-memory nodes)
165    next_synthetic_id: NodeId,
166}
167
168impl GraphTree {
169    /// Create a new empty graph tree
170    pub fn new() -> Self {
171        Self {
172            nodes: HashMap::new(),
173            roots: Vec::new(),
174            category_roots: HashMap::new(),
175            next_synthetic_id: -1, // Synthetic IDs are negative
176        }
177    }
178
179    /// Add a memory node to the tree
180    pub fn add_memory(
181        &mut self,
182        memory_id: NodeId,
183        category: &str,
184        memory_lane_type: Option<&str>,
185        priority: Option<u8>,
186    ) {
187        // Ensure category root exists
188        let category_root_id = self.get_or_create_category_root(category);
189
190        // Create memory node
191        let mut node = GraphNode::new(memory_id, NodeType::MemoryLeaf, category.to_string());
192        node.parent_id = Some(category_root_id);
193        node.memory_lane_type = memory_lane_type.map(|s| s.to_string());
194
195        if let Some(p) = priority {
196            node.set_priority_weight(p);
197        }
198
199        // Add to category root's children
200        if let Some(root) = self.nodes.get_mut(&category_root_id) {
201            root.add_child(memory_id);
202            node.depth = root.depth + 1;
203        }
204
205        self.nodes.insert(memory_id, node);
206    }
207
208    /// Remove a memory node from the tree
209    pub fn remove_memory(&mut self, memory_id: NodeId) -> bool {
210        if let Some(node) = self.nodes.remove(&memory_id) {
211            // Remove from parent's children
212            if let Some(parent_id) = node.parent_id {
213                if let Some(parent) = self.nodes.get_mut(&parent_id) {
214                    parent.remove_child(memory_id);
215                }
216            }
217            true
218        } else {
219            false
220        }
221    }
222
223    /// Get a node by ID
224    pub fn get(&self, id: NodeId) -> Option<&GraphNode> {
225        self.nodes.get(&id)
226    }
227
228    /// Get all memory IDs in a category
229    pub fn get_memories_by_category(&self, category: &str) -> Vec<NodeId> {
230        let mut result = Vec::new();
231
232        if let Some(&root_id) = self.category_roots.get(category) {
233            self.collect_leaf_ids(root_id, &mut result);
234        }
235
236        result
237    }
238
239    /// Get all memory IDs with a specific lane type
240    pub fn get_memories_by_lane_type(&self, lane_type: &str) -> Vec<NodeId> {
241        self.nodes
242            .values()
243            .filter(|n| {
244                n.node_type == NodeType::MemoryLeaf
245                    && n.memory_lane_type.as_deref() == Some(lane_type)
246            })
247            .map(|n| n.id)
248            .collect()
249    }
250
251    /// Get ancestors of a node (path to root)
252    pub fn get_ancestors(&self, node_id: NodeId) -> Vec<NodeId> {
253        let mut ancestors = Vec::new();
254        let mut current = self.nodes.get(&node_id);
255
256        while let Some(node) = current {
257            if let Some(parent_id) = node.parent_id {
258                ancestors.push(parent_id);
259                current = self.nodes.get(&parent_id);
260            } else {
261                break;
262            }
263        }
264
265        ancestors
266    }
267
268    /// Get descendants of a node (BFS)
269    pub fn get_descendants(&self, node_id: NodeId) -> Vec<NodeId> {
270        let mut result = Vec::new();
271        let mut queue = vec![node_id];
272        let mut visited = HashSet::new();
273
274        while let Some(id) = queue.pop() {
275            if visited.contains(&id) {
276                continue;
277            }
278            visited.insert(id);
279
280            if let Some(node) = self.nodes.get(&id) {
281                for &child_id in &node.children {
282                    if !visited.contains(&child_id) {
283                        result.push(child_id);
284                        queue.push(child_id);
285                    }
286                }
287            }
288        }
289
290        result
291    }
292
293    /// Calculate boosted relevance score based on tree structure
294    pub fn calculate_boosted_score(&self, memory_id: NodeId, base_similarity: f32) -> f32 {
295        if let Some(node) = self.nodes.get(&memory_id) {
296            // Apply weight from node
297            let weight = node.weight;
298
299            // Apply depth penalty (deeper nodes get slightly lower scores)
300            let depth_factor = 1.0 - (node.depth as f32 * 0.02);
301
302            // Apply ancestor weight aggregation
303            let ancestor_boost = self.calculate_ancestor_boost(memory_id);
304
305            base_similarity * weight * depth_factor.max(0.8) * ancestor_boost
306        } else {
307            base_similarity
308        }
309    }
310
311    /// Get tree statistics
312    pub fn stats(&self) -> TreeStats {
313        let memory_count = self
314            .nodes
315            .values()
316            .filter(|node| node.node_type == NodeType::MemoryLeaf)
317            .count();
318        let max_depth = self
319            .nodes
320            .values()
321            .map(|node| node.depth)
322            .max()
323            .unwrap_or(0);
324
325        TreeStats {
326            total_nodes: self.nodes.len(),
327            root_count: self.roots.len(),
328            category_count: self.category_roots.len(),
329            memory_count,
330            max_depth,
331        }
332    }
333
334    // Private methods
335
336    fn get_or_create_category_root(&mut self, category: &str) -> NodeId {
337        if let Some(&id) = self.category_roots.get(category) {
338            return id;
339        }
340
341        let root_id = self.next_synthetic_id;
342        self.next_synthetic_id -= 1;
343
344        let mut root = GraphNode::new(root_id, NodeType::CategoryRoot, category.to_string());
345        root.depth = 0;
346
347        self.nodes.insert(root_id, root.clone());
348        self.roots.push(root_id);
349        self.category_roots.insert(category.to_string(), root_id);
350
351        root_id
352    }
353
354    fn collect_leaf_ids(&self, node_id: NodeId, result: &mut Vec<NodeId>) {
355        if let Some(node) = self.nodes.get(&node_id) {
356            if node.node_type == NodeType::MemoryLeaf {
357                result.push(node_id);
358            }
359            for &child_id in &node.children {
360                self.collect_leaf_ids(child_id, result);
361            }
362        }
363    }
364
365    fn calculate_ancestor_boost(&self, node_id: NodeId) -> f32 {
366        let ancestors = self.get_ancestors(node_id);
367        if ancestors.is_empty() {
368            return 1.0;
369        }
370
371        let total_weight: f32 = ancestors
372            .iter()
373            .filter_map(|id| self.nodes.get(id))
374            .map(|n| n.weight)
375            .product();
376
377        // Normalize to reasonable range
378        (total_weight / ancestors.len() as f32).clamp(0.8, 1.2)
379    }
380
381    // === Advanced Tree Traversal Algorithms ===
382
383    /// Breadth-first traversal from a starting node
384    pub fn traverse_bfs(&self, start_id: NodeId) -> Vec<NodeId> {
385        let mut result = Vec::new();
386        let mut visited = HashSet::new();
387        let mut queue = VecDeque::new();
388
389        if self.nodes.contains_key(&start_id) {
390            queue.push_back(start_id);
391            visited.insert(start_id);
392        }
393
394        while let Some(node_id) = queue.pop_front() {
395            result.push(node_id);
396
397            if let Some(node) = self.nodes.get(&node_id) {
398                for &child_id in &node.children {
399                    if !visited.contains(&child_id) {
400                        visited.insert(child_id);
401                        queue.push_back(child_id);
402                    }
403                }
404            }
405        }
406
407        result
408    }
409
410    /// Depth-first traversal (pre-order) from a starting node
411    pub fn traverse_dfs_preorder(&self, start_id: NodeId) -> Vec<NodeId> {
412        let mut result = Vec::new();
413        let mut visited = HashSet::new();
414        self.dfs_preorder_helper(start_id, &mut visited, &mut result);
415        result
416    }
417
418    fn dfs_preorder_helper(
419        &self,
420        node_id: NodeId,
421        visited: &mut HashSet<NodeId>,
422        result: &mut Vec<NodeId>,
423    ) {
424        if visited.contains(&node_id) || !self.nodes.contains_key(&node_id) {
425            return;
426        }
427
428        visited.insert(node_id);
429        result.push(node_id);
430
431        if let Some(node) = self.nodes.get(&node_id) {
432            for &child_id in &node.children {
433                self.dfs_preorder_helper(child_id, visited, result);
434            }
435        }
436    }
437
438    /// Depth-first traversal (post-order) from a starting node
439    pub fn traverse_dfs_postorder(&self, start_id: NodeId) -> Vec<NodeId> {
440        let mut result = Vec::new();
441        let mut visited = HashSet::new();
442        self.dfs_postorder_helper(start_id, &mut visited, &mut result);
443        result
444    }
445
446    fn dfs_postorder_helper(
447        &self,
448        node_id: NodeId,
449        visited: &mut HashSet<NodeId>,
450        result: &mut Vec<NodeId>,
451    ) {
452        if visited.contains(&node_id) || !self.nodes.contains_key(&node_id) {
453            return;
454        }
455
456        visited.insert(node_id);
457
458        if let Some(node) = self.nodes.get(&node_id) {
459            for &child_id in &node.children {
460                self.dfs_postorder_helper(child_id, visited, result);
461            }
462        }
463
464        result.push(node_id);
465    }
466
467    /// Get nodes at a specific depth level
468    pub fn get_nodes_at_depth(&self, depth: u32) -> Vec<NodeId> {
469        self.nodes
470            .values()
471            .filter(|n| n.depth == depth)
472            .map(|n| n.id)
473            .collect()
474    }
475
476    /// Get all leaf nodes (memory entries)
477    pub fn get_all_leaves(&self) -> Vec<NodeId> {
478        self.nodes
479            .values()
480            .filter(|n| n.is_leaf() && n.node_type == NodeType::MemoryLeaf)
481            .map(|n| n.id)
482            .collect()
483    }
484
485    /// Get path from root to a specific node
486    pub fn get_path(&self, node_id: NodeId) -> Vec<NodeId> {
487        let mut path = Vec::new();
488        let mut current = node_id;
489
490        while let Some(node) = self.nodes.get(&current) {
491            path.push(current);
492            if let Some(parent_id) = node.parent_id {
493                current = parent_id;
494            } else {
495                break;
496            }
497        }
498
499        path.reverse(); // Root to leaf order
500        path
501    }
502
503    /// Find lowest common ancestor of two nodes
504    pub fn find_lca(&self, node_a: NodeId, node_b: NodeId) -> Option<NodeId> {
505        let mut path_a: HashSet<NodeId> = self.get_ancestors(node_a).into_iter().collect();
506        path_a.insert(node_a);
507
508        // Check node_b and its ancestors
509        if path_a.contains(&node_b) {
510            return Some(node_b);
511        }
512
513        let mut current = node_b;
514        loop {
515            if path_a.contains(&current) {
516                return Some(current);
517            }
518
519            if let Some(node) = self.nodes.get(&current) {
520                if let Some(parent_id) = node.parent_id {
521                    current = parent_id;
522                } else {
523                    break;
524                }
525            } else {
526                break;
527            }
528        }
529
530        None
531    }
532
533    /// Calculate the distance between two nodes
534    pub fn distance(&self, node_a: NodeId, node_b: NodeId) -> Option<u32> {
535        let lca = self.find_lca(node_a, node_b)?;
536
537        let dist_to_lca = |node_id: NodeId| -> u32 {
538            let mut dist = 0;
539            let mut current = node_id;
540
541            while current != lca {
542                if let Some(node) = self.nodes.get(&current) {
543                    if let Some(parent_id) = node.parent_id {
544                        current = parent_id;
545                        dist += 1;
546                    } else {
547                        break;
548                    }
549                } else {
550                    break;
551                }
552            }
553
554            dist
555        };
556
557        Some(dist_to_lca(node_a) + dist_to_lca(node_b))
558    }
559
560    /// Get subtree size for a node (including itself)
561    pub fn subtree_size(&self, node_id: NodeId) -> usize {
562        let descendants = self.get_descendants(node_id);
563        descendants.len() + 1 // +1 for the node itself
564    }
565
566    /// Prune nodes below a certain depth
567    pub fn prune_below_depth(&mut self, max_depth: u32) -> Vec<NodeId> {
568        let to_remove: Vec<NodeId> = self
569            .nodes
570            .values()
571            .filter(|n| n.depth > max_depth)
572            .map(|n| n.id)
573            .collect();
574
575        let mut removed = Vec::new();
576        for id in to_remove {
577            if self.remove_memory(id) {
578                removed.push(id);
579            }
580        }
581
582        removed
583    }
584
585    /// Find all nodes matching a predicate
586    pub fn find_matching<F>(&self, predicate: F) -> Vec<NodeId>
587    where
588        F: Fn(&GraphNode) -> bool,
589    {
590        self.nodes
591            .values()
592            .filter(|n| predicate(n))
593            .map(|n| n.id)
594            .collect()
595    }
596
597    /// Get siblings of a node
598    pub fn get_siblings(&self, node_id: NodeId) -> Vec<NodeId> {
599        let node = match self.nodes.get(&node_id) {
600            Some(n) => n,
601            None => return Vec::new(),
602        };
603
604        let parent_id = match node.parent_id {
605            Some(id) => id,
606            None => return Vec::new(),
607        };
608
609        let parent = match self.nodes.get(&parent_id) {
610            Some(p) => p,
611            None => return Vec::new(),
612        };
613
614        parent
615            .children
616            .iter()
617            .filter(|&&id| id != node_id)
618            .copied()
619            .collect()
620    }
621
622    /// Rebalance weights in the tree
623    pub fn rebalance_weights(&mut self) {
624        // Calculate average weight per level and normalize
625        let mut level_weights: HashMap<u32, Vec<f32>> = HashMap::new();
626
627        for node in self.nodes.values() {
628            level_weights
629                .entry(node.depth)
630                .or_default()
631                .push(node.weight);
632        }
633
634        let mut level_avgs: HashMap<u32, f32> = HashMap::new();
635        for (depth, weights) in level_weights {
636            let avg = weights.iter().sum::<f32>() / weights.len() as f32;
637            level_avgs.insert(depth, avg);
638        }
639
640        // Normalize weights around average
641        for node in self.nodes.values_mut() {
642            if let Some(&avg) = level_avgs.get(&node.depth) {
643                if avg > 0.0 {
644                    node.weight = (node.weight / avg).clamp(0.5, 2.0);
645                }
646            }
647        }
648    }
649}
650
651/// Statistics about the graph tree
652#[derive(Debug, Clone, Default, Serialize, Deserialize)]
653pub struct TreeStats {
654    /// Total number of nodes
655    pub total_nodes: usize,
656
657    /// Number of root nodes
658    pub root_count: usize,
659
660    /// Number of categories
661    pub category_count: usize,
662
663    /// Number of memory leaf nodes
664    pub memory_count: usize,
665
666    /// Maximum tree depth
667    pub max_depth: u32,
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tree holding plain memories (no lane type, no priority).
    fn tree_with(memories: &[(NodeId, &str)]) -> GraphTree {
        let mut tree = GraphTree::new();
        for &(id, category) in memories {
            tree.add_memory(id, category, None, None);
        }
        tree
    }

    /// Looks up the root ID registered for `category`.
    fn root_of(tree: &GraphTree, category: &str) -> NodeId {
        *tree
            .category_roots
            .get(category)
            .expect("category root should exist")
    }

    /// A detached memory leaf in the "general" category.
    fn leaf(id: NodeId) -> GraphNode {
        GraphNode::new(id, NodeType::MemoryLeaf, "general".to_string())
    }

    #[test]
    fn test_graph_node_new() {
        let node = leaf(1);
        assert_eq!(node.id, 1);
        assert!(node.is_leaf());
        assert!(node.is_root());
        assert_eq!(node.weight, 1.0);
    }

    #[test]
    fn test_graph_node_add_remove_child() {
        let mut parent = GraphNode::new(1, NodeType::CategoryRoot, "general".to_string());

        parent.add_child(2);
        assert_eq!(parent.children, vec![2]);
        assert!(!parent.is_leaf());

        parent.add_child(2);
        assert_eq!(parent.children, vec![2], "duplicate children are ignored");

        assert!(parent.remove_child(2));
        assert!(parent.is_leaf());
        assert!(!parent.remove_child(999), "unknown children cannot be removed");
    }

    #[test]
    fn test_graph_node_priority_weight() {
        let mut node = leaf(1);
        for &(priority, expected) in &[(1u8, 1.5f32), (2, 1.2), (3, 1.0)] {
            node.set_priority_weight(priority);
            assert!(
                (node.weight - expected).abs() < 0.01,
                "unexpected weight for priority {}",
                priority
            );
        }
    }

    #[test]
    fn test_tree_node_creation() {
        let tree_node = TreeNode::new(leaf(1));
        assert_eq!(tree_node.path_weight, 1.0);
        assert_eq!(tree_node.distance, 0);
    }

    #[test]
    fn test_graph_tree_add_memory() {
        let tree = tree_with(&[(100, "general")]);
        let node = tree.get(100).expect("memory 100 should be present");
        assert_eq!(node.node_type, NodeType::MemoryLeaf);
        assert!(node.parent_id.is_some());
    }

    #[test]
    fn test_graph_tree_remove_memory() {
        let mut tree = tree_with(&[(100, "general")]);
        assert!(tree.remove_memory(100));
        assert!(tree.get(100).is_none());
        assert!(!tree.remove_memory(100), "second removal is a no-op");
    }

    #[test]
    fn test_graph_tree_get_by_category() {
        let tree = tree_with(&[(100, "general"), (101, "general"), (102, "facts")]);

        let mut general = tree.get_memories_by_category("general");
        general.sort_unstable();
        assert_eq!(general, vec![100, 101]);

        assert_eq!(tree.get_memories_by_category("facts"), vec![102]);
    }

    #[test]
    fn test_graph_tree_get_by_lane_type() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", Some("correction"), None);
        tree.add_memory(101, "general", Some("insight"), None);
        tree.add_memory(102, "facts", Some("correction"), None);

        assert_eq!(tree.get_memories_by_lane_type("correction").len(), 2);
    }

    #[test]
    fn test_graph_tree_ancestors() {
        let tree = tree_with(&[(100, "general")]);
        // The only ancestor is the category root.
        assert_eq!(tree.get_ancestors(100), vec![root_of(&tree, "general")]);
    }

    #[test]
    fn test_graph_tree_boosted_score() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", Some("correction"), Some(1));

        // High priority (weight 1.5) should lift the score above the base similarity.
        assert!(tree.calculate_boosted_score(100, 0.8) > 0.8);
    }

    #[test]
    fn test_graph_tree_stats() {
        let stats = tree_with(&[(100, "general"), (101, "facts")]).stats();
        assert_eq!(stats.memory_count, 2);
        assert_eq!(stats.category_count, 2);
        assert!(stats.total_nodes >= 4, "2 memories + 2 category roots");
    }

    #[test]
    fn test_traverse_bfs() {
        let tree = tree_with(&[(100, "general"), (101, "general")]);
        let root = root_of(&tree, "general");
        // The category root is visited before its leaves.
        assert_eq!(tree.traverse_bfs(root).first(), Some(&root));
    }

    #[test]
    fn test_traverse_dfs_preorder() {
        let tree = tree_with(&[(100, "general"), (101, "general")]);
        let root = root_of(&tree, "general");
        assert_eq!(tree.traverse_dfs_preorder(root).first(), Some(&root));
    }

    #[test]
    fn test_get_path() {
        let path = tree_with(&[(100, "general")]).get_path(100);
        assert_eq!(path.len(), 2, "category root -> memory");
        assert_eq!(path.last(), Some(&100));
    }

    #[test]
    fn test_find_lca() {
        let tree = tree_with(&[(100, "general"), (101, "general")]);
        assert_eq!(tree.find_lca(100, 101), Some(root_of(&tree, "general")));
    }

    #[test]
    fn test_distance() {
        let tree = tree_with(&[(100, "general"), (101, "general")]);
        // Two children of the same parent are two edges apart.
        assert_eq!(tree.distance(100, 101), Some(2));
    }

    #[test]
    fn test_subtree_size() {
        let tree = tree_with(&[(100, "general"), (101, "general")]);
        assert_eq!(tree.subtree_size(root_of(&tree, "general")), 3, "root + 2 memories");
    }

    #[test]
    fn test_get_siblings() {
        let tree = tree_with(&[(100, "general"), (101, "general"), (102, "general")]);
        let mut siblings = tree.get_siblings(100);
        siblings.sort_unstable();
        assert_eq!(siblings, vec![101, 102]);
    }

    #[test]
    fn test_get_all_leaves() {
        let tree = tree_with(&[(100, "general"), (101, "facts")]);
        assert_eq!(tree.get_all_leaves().len(), 2);
    }

    #[test]
    fn test_get_nodes_at_depth() {
        let tree = tree_with(&[(100, "general"), (101, "general")]);
        // Memories hang directly below their category root.
        assert_eq!(tree.get_nodes_at_depth(1).len(), 2);
    }

    #[test]
    fn test_find_matching() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", Some("correction"), Some(1));
        tree.add_memory(101, "general", None, None);

        assert_eq!(tree.find_matching(|n| n.weight > 1.0), vec![100]);
    }
}