1use serde::{Deserialize, Serialize};
17use std::collections::{HashMap, HashSet, VecDeque};
18
/// Identifier of a node in a [`GraphTree`].
///
/// Memory leaves use caller-supplied ids; synthetic structural nodes created by
/// `GraphTree` (category roots) are assigned negative ids counting down from `-1`.
pub type NodeId = i64;
21
/// A single node of a [`GraphTree`]: either a memory leaf or a synthetic structural node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique id of this node within its tree.
    pub id: NodeId,

    /// Role of this node in the tree (category root, memory leaf, ...).
    pub node_type: NodeType,

    /// Id of the parent node; `None` for a root (see [`GraphNode::is_root`]).
    pub parent_id: Option<NodeId>,

    /// Ids of the direct children. [`GraphNode::add_child`] keeps this free of duplicates.
    pub children: Vec<NodeId>,

    /// Number of edges from the top-level root: category roots sit at 0 and memory leaves
    /// added by `GraphTree::add_memory` sit one level below their root.
    pub depth: u32,

    /// Multiplicative ranking weight, `1.0` by default. Set from a priority level by
    /// [`GraphNode::set_priority_weight`] and possibly renormalised by
    /// `GraphTree::rebalance_weights`.
    pub weight: f32,

    /// Name of the category this node belongs to.
    pub category: String,

    /// Optional lane-type label of a memory (e.g. `"correction"`); `None` when not given.
    pub memory_lane_type: Option<String>,
}
49
50impl GraphNode {
51 pub fn new(id: NodeId, node_type: NodeType, category: String) -> Self {
53 Self {
54 id,
55 node_type,
56 parent_id: None,
57 children: Vec::new(),
58 depth: 0,
59 weight: 1.0,
60 category,
61 memory_lane_type: None,
62 }
63 }
64
65 pub fn is_leaf(&self) -> bool {
67 self.children.is_empty()
68 }
69
70 pub fn is_root(&self) -> bool {
72 self.parent_id.is_none()
73 }
74
75 pub fn add_child(&mut self, child_id: NodeId) {
77 if !self.children.contains(&child_id) {
78 self.children.push(child_id);
79 }
80 }
81
82 pub fn remove_child(&mut self, child_id: NodeId) -> bool {
84 if let Some(pos) = self.children.iter().position(|&id| id == child_id) {
85 self.children.remove(pos);
86 true
87 } else {
88 false
89 }
90 }
91
92 pub fn set_priority_weight(&mut self, priority: u8) {
94 self.weight = match priority {
95 1 => 1.5, 2 => 1.2, _ => 1.0, };
99 }
100}
101
/// Kind of node stored in a [`GraphTree`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeType {
    /// Synthetic top-level node that groups all memories of one category.
    CategoryRoot,

    /// Presumably an intermediate node grouping memories by lane type.
    /// NOTE(review): not created by any `GraphTree` method visible here — confirm usage.
    LaneTypeNode,

    /// A stored memory; the only kind counted in `TreeStats::memory_count`.
    MemoryLeaf,

    /// Presumably an intermediate node grouping memories by time.
    /// NOTE(review): not created by any `GraphTree` method visible here — confirm usage.
    TimeCluster,
}
117
/// A [`GraphNode`] annotated with the state of a walk that reached it.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// The node that was reached.
    pub node: GraphNode,

    /// Product of node weights along the path from the walk's start node to this one,
    /// inclusive (see [`TreeNode::new`] and [`TreeNode::child`]).
    pub path_weight: f32,

    /// Number of steps from the walk's start node (0 for the start node itself).
    pub distance: u32,
}
130
131impl TreeNode {
132 pub fn new(node: GraphNode) -> Self {
134 let weight = node.weight;
135 Self {
136 node,
137 path_weight: weight,
138 distance: 0,
139 }
140 }
141
142 pub fn child(node: GraphNode, parent: &TreeNode) -> Self {
144 Self {
145 path_weight: parent.path_weight * node.weight,
146 distance: parent.distance + 1,
147 node,
148 }
149 }
150}
151
152#[derive(Debug, Clone, Default)]
154pub struct GraphTree {
155 nodes: HashMap<NodeId, GraphNode>,
157
158 roots: Vec<NodeId>,
160
161 category_roots: HashMap<String, NodeId>,
163
164 next_synthetic_id: NodeId,
166}
167
168impl GraphTree {
169 pub fn new() -> Self {
171 Self {
172 nodes: HashMap::new(),
173 roots: Vec::new(),
174 category_roots: HashMap::new(),
175 next_synthetic_id: -1, }
177 }
178
179 pub fn add_memory(
181 &mut self,
182 memory_id: NodeId,
183 category: &str,
184 memory_lane_type: Option<&str>,
185 priority: Option<u8>,
186 ) {
187 let category_root_id = self.get_or_create_category_root(category);
189
190 let mut node = GraphNode::new(memory_id, NodeType::MemoryLeaf, category.to_string());
192 node.parent_id = Some(category_root_id);
193 node.memory_lane_type = memory_lane_type.map(|s| s.to_string());
194
195 if let Some(p) = priority {
196 node.set_priority_weight(p);
197 }
198
199 if let Some(root) = self.nodes.get_mut(&category_root_id) {
201 root.add_child(memory_id);
202 node.depth = root.depth + 1;
203 }
204
205 self.nodes.insert(memory_id, node);
206 }
207
208 pub fn remove_memory(&mut self, memory_id: NodeId) -> bool {
210 if let Some(node) = self.nodes.remove(&memory_id) {
211 if let Some(parent_id) = node.parent_id {
213 if let Some(parent) = self.nodes.get_mut(&parent_id) {
214 parent.remove_child(memory_id);
215 }
216 }
217 true
218 } else {
219 false
220 }
221 }
222
223 pub fn get(&self, id: NodeId) -> Option<&GraphNode> {
225 self.nodes.get(&id)
226 }
227
228 pub fn get_memories_by_category(&self, category: &str) -> Vec<NodeId> {
230 let mut result = Vec::new();
231
232 if let Some(&root_id) = self.category_roots.get(category) {
233 self.collect_leaf_ids(root_id, &mut result);
234 }
235
236 result
237 }
238
239 pub fn get_memories_by_lane_type(&self, lane_type: &str) -> Vec<NodeId> {
241 self.nodes
242 .values()
243 .filter(|n| {
244 n.node_type == NodeType::MemoryLeaf
245 && n.memory_lane_type.as_deref() == Some(lane_type)
246 })
247 .map(|n| n.id)
248 .collect()
249 }
250
251 pub fn get_ancestors(&self, node_id: NodeId) -> Vec<NodeId> {
253 let mut ancestors = Vec::new();
254 let mut current = self.nodes.get(&node_id);
255
256 while let Some(node) = current {
257 if let Some(parent_id) = node.parent_id {
258 ancestors.push(parent_id);
259 current = self.nodes.get(&parent_id);
260 } else {
261 break;
262 }
263 }
264
265 ancestors
266 }
267
268 pub fn get_descendants(&self, node_id: NodeId) -> Vec<NodeId> {
270 let mut result = Vec::new();
271 let mut queue = vec![node_id];
272 let mut visited = HashSet::new();
273
274 while let Some(id) = queue.pop() {
275 if visited.contains(&id) {
276 continue;
277 }
278 visited.insert(id);
279
280 if let Some(node) = self.nodes.get(&id) {
281 for &child_id in &node.children {
282 if !visited.contains(&child_id) {
283 result.push(child_id);
284 queue.push(child_id);
285 }
286 }
287 }
288 }
289
290 result
291 }
292
293 pub fn calculate_boosted_score(&self, memory_id: NodeId, base_similarity: f32) -> f32 {
295 if let Some(node) = self.nodes.get(&memory_id) {
296 let weight = node.weight;
298
299 let depth_factor = 1.0 - (node.depth as f32 * 0.02);
301
302 let ancestor_boost = self.calculate_ancestor_boost(memory_id);
304
305 base_similarity * weight * depth_factor.max(0.8) * ancestor_boost
306 } else {
307 base_similarity
308 }
309 }
310
311 pub fn stats(&self) -> TreeStats {
313 let memory_count = self
314 .nodes
315 .values()
316 .filter(|node| node.node_type == NodeType::MemoryLeaf)
317 .count();
318 let max_depth = self
319 .nodes
320 .values()
321 .map(|node| node.depth)
322 .max()
323 .unwrap_or(0);
324
325 TreeStats {
326 total_nodes: self.nodes.len(),
327 root_count: self.roots.len(),
328 category_count: self.category_roots.len(),
329 memory_count,
330 max_depth,
331 }
332 }
333
334 fn get_or_create_category_root(&mut self, category: &str) -> NodeId {
337 if let Some(&id) = self.category_roots.get(category) {
338 return id;
339 }
340
341 let root_id = self.next_synthetic_id;
342 self.next_synthetic_id -= 1;
343
344 let mut root = GraphNode::new(root_id, NodeType::CategoryRoot, category.to_string());
345 root.depth = 0;
346
347 self.nodes.insert(root_id, root.clone());
348 self.roots.push(root_id);
349 self.category_roots.insert(category.to_string(), root_id);
350
351 root_id
352 }
353
354 fn collect_leaf_ids(&self, node_id: NodeId, result: &mut Vec<NodeId>) {
355 if let Some(node) = self.nodes.get(&node_id) {
356 if node.node_type == NodeType::MemoryLeaf {
357 result.push(node_id);
358 }
359 for &child_id in &node.children {
360 self.collect_leaf_ids(child_id, result);
361 }
362 }
363 }
364
365 fn calculate_ancestor_boost(&self, node_id: NodeId) -> f32 {
366 let ancestors = self.get_ancestors(node_id);
367 if ancestors.is_empty() {
368 return 1.0;
369 }
370
371 let total_weight: f32 = ancestors
372 .iter()
373 .filter_map(|id| self.nodes.get(id))
374 .map(|n| n.weight)
375 .product();
376
377 (total_weight / ancestors.len() as f32).clamp(0.8, 1.2)
379 }
380
381 pub fn traverse_bfs(&self, start_id: NodeId) -> Vec<NodeId> {
385 let mut result = Vec::new();
386 let mut visited = HashSet::new();
387 let mut queue = VecDeque::new();
388
389 if self.nodes.contains_key(&start_id) {
390 queue.push_back(start_id);
391 visited.insert(start_id);
392 }
393
394 while let Some(node_id) = queue.pop_front() {
395 result.push(node_id);
396
397 if let Some(node) = self.nodes.get(&node_id) {
398 for &child_id in &node.children {
399 if !visited.contains(&child_id) {
400 visited.insert(child_id);
401 queue.push_back(child_id);
402 }
403 }
404 }
405 }
406
407 result
408 }
409
410 pub fn traverse_dfs_preorder(&self, start_id: NodeId) -> Vec<NodeId> {
412 let mut result = Vec::new();
413 let mut visited = HashSet::new();
414 self.dfs_preorder_helper(start_id, &mut visited, &mut result);
415 result
416 }
417
418 fn dfs_preorder_helper(
419 &self,
420 node_id: NodeId,
421 visited: &mut HashSet<NodeId>,
422 result: &mut Vec<NodeId>,
423 ) {
424 if visited.contains(&node_id) || !self.nodes.contains_key(&node_id) {
425 return;
426 }
427
428 visited.insert(node_id);
429 result.push(node_id);
430
431 if let Some(node) = self.nodes.get(&node_id) {
432 for &child_id in &node.children {
433 self.dfs_preorder_helper(child_id, visited, result);
434 }
435 }
436 }
437
438 pub fn traverse_dfs_postorder(&self, start_id: NodeId) -> Vec<NodeId> {
440 let mut result = Vec::new();
441 let mut visited = HashSet::new();
442 self.dfs_postorder_helper(start_id, &mut visited, &mut result);
443 result
444 }
445
446 fn dfs_postorder_helper(
447 &self,
448 node_id: NodeId,
449 visited: &mut HashSet<NodeId>,
450 result: &mut Vec<NodeId>,
451 ) {
452 if visited.contains(&node_id) || !self.nodes.contains_key(&node_id) {
453 return;
454 }
455
456 visited.insert(node_id);
457
458 if let Some(node) = self.nodes.get(&node_id) {
459 for &child_id in &node.children {
460 self.dfs_postorder_helper(child_id, visited, result);
461 }
462 }
463
464 result.push(node_id);
465 }
466
467 pub fn get_nodes_at_depth(&self, depth: u32) -> Vec<NodeId> {
469 self.nodes
470 .values()
471 .filter(|n| n.depth == depth)
472 .map(|n| n.id)
473 .collect()
474 }
475
476 pub fn get_all_leaves(&self) -> Vec<NodeId> {
478 self.nodes
479 .values()
480 .filter(|n| n.is_leaf() && n.node_type == NodeType::MemoryLeaf)
481 .map(|n| n.id)
482 .collect()
483 }
484
485 pub fn get_path(&self, node_id: NodeId) -> Vec<NodeId> {
487 let mut path = Vec::new();
488 let mut current = node_id;
489
490 while let Some(node) = self.nodes.get(¤t) {
491 path.push(current);
492 if let Some(parent_id) = node.parent_id {
493 current = parent_id;
494 } else {
495 break;
496 }
497 }
498
499 path.reverse(); path
501 }
502
503 pub fn find_lca(&self, node_a: NodeId, node_b: NodeId) -> Option<NodeId> {
505 let mut path_a: HashSet<NodeId> = self.get_ancestors(node_a).into_iter().collect();
506 path_a.insert(node_a);
507
508 if path_a.contains(&node_b) {
510 return Some(node_b);
511 }
512
513 let mut current = node_b;
514 loop {
515 if path_a.contains(¤t) {
516 return Some(current);
517 }
518
519 if let Some(node) = self.nodes.get(¤t) {
520 if let Some(parent_id) = node.parent_id {
521 current = parent_id;
522 } else {
523 break;
524 }
525 } else {
526 break;
527 }
528 }
529
530 None
531 }
532
533 pub fn distance(&self, node_a: NodeId, node_b: NodeId) -> Option<u32> {
535 let lca = self.find_lca(node_a, node_b)?;
536
537 let dist_to_lca = |node_id: NodeId| -> u32 {
538 let mut dist = 0;
539 let mut current = node_id;
540
541 while current != lca {
542 if let Some(node) = self.nodes.get(¤t) {
543 if let Some(parent_id) = node.parent_id {
544 current = parent_id;
545 dist += 1;
546 } else {
547 break;
548 }
549 } else {
550 break;
551 }
552 }
553
554 dist
555 };
556
557 Some(dist_to_lca(node_a) + dist_to_lca(node_b))
558 }
559
560 pub fn subtree_size(&self, node_id: NodeId) -> usize {
562 let descendants = self.get_descendants(node_id);
563 descendants.len() + 1 }
565
566 pub fn prune_below_depth(&mut self, max_depth: u32) -> Vec<NodeId> {
568 let to_remove: Vec<NodeId> = self
569 .nodes
570 .values()
571 .filter(|n| n.depth > max_depth)
572 .map(|n| n.id)
573 .collect();
574
575 let mut removed = Vec::new();
576 for id in to_remove {
577 if self.remove_memory(id) {
578 removed.push(id);
579 }
580 }
581
582 removed
583 }
584
585 pub fn find_matching<F>(&self, predicate: F) -> Vec<NodeId>
587 where
588 F: Fn(&GraphNode) -> bool,
589 {
590 self.nodes
591 .values()
592 .filter(|n| predicate(n))
593 .map(|n| n.id)
594 .collect()
595 }
596
597 pub fn get_siblings(&self, node_id: NodeId) -> Vec<NodeId> {
599 let node = match self.nodes.get(&node_id) {
600 Some(n) => n,
601 None => return Vec::new(),
602 };
603
604 let parent_id = match node.parent_id {
605 Some(id) => id,
606 None => return Vec::new(),
607 };
608
609 let parent = match self.nodes.get(&parent_id) {
610 Some(p) => p,
611 None => return Vec::new(),
612 };
613
614 parent
615 .children
616 .iter()
617 .filter(|&&id| id != node_id)
618 .copied()
619 .collect()
620 }
621
622 pub fn rebalance_weights(&mut self) {
624 let mut level_weights: HashMap<u32, Vec<f32>> = HashMap::new();
626
627 for node in self.nodes.values() {
628 level_weights
629 .entry(node.depth)
630 .or_default()
631 .push(node.weight);
632 }
633
634 let mut level_avgs: HashMap<u32, f32> = HashMap::new();
635 for (depth, weights) in level_weights {
636 let avg = weights.iter().sum::<f32>() / weights.len() as f32;
637 level_avgs.insert(depth, avg);
638 }
639
640 for node in self.nodes.values_mut() {
642 if let Some(&avg) = level_avgs.get(&node.depth) {
643 if avg > 0.0 {
644 node.weight = (node.weight / avg).clamp(0.5, 2.0);
645 }
646 }
647 }
648 }
649}
650
/// Summary counts produced by `GraphTree::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TreeStats {
    /// Number of nodes of every kind, including synthetic category roots.
    pub total_nodes: usize,

    /// Number of top-level root nodes.
    pub root_count: usize,

    /// Number of distinct categories that have a root node.
    pub category_count: usize,

    /// Number of memory-leaf nodes.
    pub memory_count: usize,

    /// Largest `depth` of any node (0 for an empty tree).
    pub max_depth: u32,
}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// Looks up the synthetic root id of `category`, panicking if it was never created.
    fn category_root(tree: &GraphTree, category: &str) -> NodeId {
        tree.category_roots
            .get(category)
            .copied()
            .expect("category root is created by add_memory")
    }

    #[test]
    fn test_graph_node_new() {
        let node = GraphNode::new(1, NodeType::MemoryLeaf, "general".to_string());
        assert_eq!(node.id, 1);
        assert_eq!(node.depth, 0);
        assert!(node.is_leaf());
        assert!(node.is_root());
        assert_eq!(node.weight, 1.0);
        assert!(node.memory_lane_type.is_none());
    }

    #[test]
    fn test_graph_node_add_remove_child() {
        let mut parent = GraphNode::new(1, NodeType::CategoryRoot, "general".to_string());
        parent.add_child(2);
        assert_eq!(parent.children.len(), 1);
        assert!(!parent.is_leaf());

        // Adding the same child twice must not duplicate it.
        parent.add_child(2);
        assert_eq!(parent.children.len(), 1);

        assert!(parent.remove_child(2));
        assert!(parent.is_leaf());
        // Removing an unknown child reports failure.
        assert!(!parent.remove_child(999));
    }

    #[test]
    fn test_graph_node_priority_weight() {
        let mut node = GraphNode::new(1, NodeType::MemoryLeaf, "general".to_string());

        node.set_priority_weight(1);
        assert!((node.weight - 1.5).abs() < 0.01);

        node.set_priority_weight(2);
        assert!((node.weight - 1.2).abs() < 0.01);

        node.set_priority_weight(3);
        assert!((node.weight - 1.0).abs() < 0.01);

        // Priorities outside 1..=2 fall back to the neutral weight.
        node.set_priority_weight(0);
        assert!((node.weight - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_tree_node_creation() {
        let node = GraphNode::new(1, NodeType::MemoryLeaf, "general".to_string());
        let tree_node = TreeNode::new(node);

        assert_eq!(tree_node.path_weight, 1.0);
        assert_eq!(tree_node.distance, 0);

        let mut leaf = GraphNode::new(2, NodeType::MemoryLeaf, "general".to_string());
        leaf.set_priority_weight(1);
        let child = TreeNode::child(leaf, &tree_node);
        assert_eq!(child.node.id, 2);
        assert_eq!(child.distance, 1);
        assert!((child.path_weight - 1.5).abs() < 1e-6);
    }

    #[test]
    fn test_graph_tree_add_memory() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);

        let node = tree.get(100).expect("memory was just added");
        assert_eq!(node.node_type, NodeType::MemoryLeaf);
        assert_eq!(node.depth, 1);

        let root_id = category_root(&tree, "general");
        assert_eq!(node.parent_id, Some(root_id));
        assert!(tree.get(root_id).unwrap().children.contains(&100));
    }

    #[test]
    fn test_graph_tree_remove_memory() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);

        assert!(tree.remove_memory(100));
        assert!(tree.get(100).is_none());
        // The parent must no longer list the removed memory.
        let root_id = category_root(&tree, "general");
        assert!(tree.get(root_id).unwrap().children.is_empty());
        // Removing twice reports failure.
        assert!(!tree.remove_memory(100));
    }

    #[test]
    fn test_graph_tree_get_by_category() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);
        tree.add_memory(102, "facts", None, None);

        let general = tree.get_memories_by_category("general");
        assert_eq!(general.len(), 2);
        assert!(general.contains(&100));
        assert!(general.contains(&101));

        let facts = tree.get_memories_by_category("facts");
        assert_eq!(facts, vec![102]);

        assert!(tree.get_memories_by_category("unknown").is_empty());
    }

    #[test]
    fn test_graph_tree_get_by_lane_type() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", Some("correction"), None);
        tree.add_memory(101, "general", Some("insight"), None);
        tree.add_memory(102, "facts", Some("correction"), None);

        let corrections = tree.get_memories_by_lane_type("correction");
        assert_eq!(corrections.len(), 2);
        assert!(corrections.contains(&100));
        assert!(corrections.contains(&102));

        assert_eq!(tree.get_memories_by_lane_type("insight"), vec![101]);
        assert!(tree.get_memories_by_lane_type("missing").is_empty());
    }

    #[test]
    fn test_graph_tree_ancestors() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);

        // A memory's only ancestor is its category root.
        let root_id = category_root(&tree, "general");
        assert_eq!(tree.get_ancestors(100), vec![root_id]);
        assert!(tree.get_ancestors(root_id).is_empty());
    }

    #[test]
    fn test_graph_tree_boosted_score() {
        let mut tree = GraphTree::new();
        // Priority 1 gives weight 1.5.
        tree.add_memory(100, "general", Some("correction"), Some(1));

        let score = tree.calculate_boosted_score(100, 0.8);
        // weight 1.5, depth 1 => factor 0.98, neutral category root => boost 1.0.
        let expected = 0.8 * 1.5 * 0.98;
        assert!((score - expected).abs() < 1e-5);

        // Unknown memories keep their base similarity.
        assert_eq!(tree.calculate_boosted_score(999, 0.8), 0.8);
    }

    #[test]
    fn test_graph_tree_stats() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "facts", None, None);

        let stats = tree.stats();
        assert_eq!(stats.memory_count, 2);
        assert_eq!(stats.category_count, 2);
        assert_eq!(stats.root_count, 2);
        // Two memories plus two category roots.
        assert_eq!(stats.total_nodes, 4);
        assert_eq!(stats.max_depth, 1);
    }

    #[test]
    fn test_traverse_bfs() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);

        let root_id = category_root(&tree, "general");
        let bfs_order = tree.traverse_bfs(root_id);

        assert_eq!(bfs_order.len(), 3);
        // The root comes first, followed by its children.
        assert_eq!(bfs_order[0], root_id);
        assert!(bfs_order[1..].contains(&100));
        assert!(bfs_order[1..].contains(&101));

        assert!(tree.traverse_bfs(999).is_empty());
    }

    #[test]
    fn test_traverse_dfs_preorder() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);

        let root_id = category_root(&tree, "general");
        let dfs_order = tree.traverse_dfs_preorder(root_id);

        assert_eq!(dfs_order.len(), 3);
        assert_eq!(dfs_order[0], root_id);
        assert!(tree.traverse_dfs_preorder(999).is_empty());
    }

    #[test]
    fn test_traverse_dfs_postorder() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);

        let root_id = category_root(&tree, "general");
        let dfs_order = tree.traverse_dfs_postorder(root_id);

        // Children are emitted before their parent, so the root comes last.
        assert_eq!(dfs_order.len(), 3);
        assert_eq!(dfs_order[2], root_id);
        assert!(tree.traverse_dfs_postorder(999).is_empty());
    }

    #[test]
    fn test_get_path() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);

        // Root-to-leaf order: category root, then the memory.
        let root_id = category_root(&tree, "general");
        assert_eq!(tree.get_path(100), vec![root_id, 100]);
    }

    #[test]
    fn test_find_lca() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);
        tree.add_memory(102, "facts", None, None);

        let root_id = category_root(&tree, "general");
        assert_eq!(tree.find_lca(100, 101), Some(root_id));
        assert_eq!(tree.find_lca(100, root_id), Some(root_id));
        // Different categories have no common ancestor.
        assert_eq!(tree.find_lca(100, 102), None);
    }

    #[test]
    fn test_distance() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);
        tree.add_memory(102, "facts", None, None);

        let root_id = category_root(&tree, "general");
        assert_eq!(tree.distance(100, 101), Some(2));
        assert_eq!(tree.distance(100, 100), Some(0));
        assert_eq!(tree.distance(100, root_id), Some(1));
        assert_eq!(tree.distance(100, 102), None);
    }

    #[test]
    fn test_subtree_size() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);

        // Root plus its two memories.
        let root_id = category_root(&tree, "general");
        assert_eq!(tree.subtree_size(root_id), 3);
        assert_eq!(tree.subtree_size(100), 1);
    }

    #[test]
    fn test_get_siblings() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);
        tree.add_memory(102, "general", None, None);

        let siblings = tree.get_siblings(100);
        assert_eq!(siblings.len(), 2);
        assert!(siblings.contains(&101));
        assert!(siblings.contains(&102));

        // A root has no parent and therefore no siblings.
        let root_id = category_root(&tree, "general");
        assert!(tree.get_siblings(root_id).is_empty());
    }

    #[test]
    fn test_get_all_leaves() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "facts", None, None);

        let leaves = tree.get_all_leaves();
        assert_eq!(leaves.len(), 2);
        assert!(leaves.contains(&100));
        assert!(leaves.contains(&101));
    }

    #[test]
    fn test_get_nodes_at_depth() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "general", None, None);

        // Memories sit one level below their category root.
        assert_eq!(tree.get_nodes_at_depth(1).len(), 2);
        assert_eq!(tree.get_nodes_at_depth(0).len(), 1);
    }

    #[test]
    fn test_find_matching() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", Some("correction"), Some(1));
        tree.add_memory(101, "general", None, None);

        let high_priority = tree.find_matching(|n| n.weight > 1.0);
        assert_eq!(high_priority, vec![100]);
    }

    #[test]
    fn test_prune_below_depth() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, None);
        tree.add_memory(101, "facts", None, None);

        // Nothing is deeper than 1.
        assert!(tree.prune_below_depth(1).is_empty());

        let mut removed = tree.prune_below_depth(0);
        removed.sort_unstable();
        assert_eq!(removed, vec![100, 101]);

        // Category roots survive pruning; only the memories are gone.
        let stats = tree.stats();
        assert_eq!(stats.memory_count, 0);
        assert_eq!(stats.category_count, 2);
        assert!(tree.get_all_leaves().is_empty());
        let root_id = category_root(&tree, "general");
        assert!(tree.get(root_id).unwrap().children.is_empty());
    }

    #[test]
    fn test_rebalance_weights() {
        let mut tree = GraphTree::new();
        tree.add_memory(100, "general", None, Some(1)); // weight 1.5
        tree.add_memory(101, "general", None, None); // weight 1.0

        tree.rebalance_weights();

        // The depth-1 mean is 1.25, so the weights become 1.5/1.25 and 1.0/1.25.
        assert!((tree.get(100).unwrap().weight - 1.2).abs() < 1e-5);
        assert!((tree.get(101).unwrap().weight - 0.8).abs() < 1e-5);
        // The lone depth-0 root is its own mean and stays neutral.
        let root_id = category_root(&tree, "general");
        assert!((tree.get(root_id).unwrap().weight - 1.0).abs() < 1e-5);
    }
}