// rag_plusplus_core/trajectory/graph.rs

//! DAG Traversal Algorithms for Trajectory Structure
//!
//! Trajectories are DAGs (Directed Acyclic Graphs) where:
//! - Each node is an episode (unit of experience)
//! - Edges represent parent→child relationships
//! - Multiple children mark regenerations or branches
//!
//! This module provides high-performance algorithms for:
//! - Path finding (root to leaf traversal)
//! - Branch detection (identifying decision points)
//! - Primary path selection (choosing the "best" linear path through the DAG)
12
13use std::collections::{HashMap, HashSet, VecDeque};
14
/// Unique identifier for a node (episode) in the trajectory DAG.
///
/// Used as the key of the graph's episode map and in every parent/child link.
pub type NodeId = u64;
17
/// Directed edge in the trajectory DAG, pointing from `parent` to `child`.
///
/// Consumed by `TrajectoryGraph::from_edges`, which links the two episodes.
/// NOTE(review): `from_edges` does not store `edge_type` anywhere on the graph,
/// so it is currently informational only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Edge {
    /// Episode the edge originates from.
    pub parent: NodeId,
    /// Episode the edge leads to.
    pub child: NodeId,
    /// Kind of relationship this edge represents.
    pub edge_type: EdgeType,
}
25
/// Type of edge in the DAG.
///
/// Defaults to [`EdgeType::Continuation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EdgeType {
    /// Normal continuation (the parent's single child).
    #[default]
    Continuation,
    /// Regeneration (one of several alternative children of the same parent).
    Regeneration,
    /// Branch (explicit user-initiated fork).
    Branch,
}
37
/// An episode in the trajectory DAG.
///
/// An episode represents a unit of experience - a message turn, interaction,
/// or decision point within a trajectory. It carries its own structural links
/// (`parent`, `children`) plus the metadata used for primary path selection.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Identifier of this episode.
    pub id: NodeId,
    /// Parent episode, or `None` for a root. Only a single parent is tracked.
    pub parent: Option<NodeId>,
    /// Child episodes in insertion order; more than one marks a branch point.
    pub children: Vec<NodeId>,
    // Metadata for primary path selection (see `PathSelectionPolicy`).
    /// Relative weight (`1.0` from `Episode::new`); used by the `HighestWeight`
    /// policy and summed into `PathResult::total_weight`.
    pub weight: f32,
    /// Positive feedback flag; strongest criterion of `FeedbackFirst`.
    pub has_thumbs_up: bool,
    /// Negative feedback flag; penalized by `FeedbackFirst`.
    pub has_thumbs_down: bool,
    /// Size of the episode's content, used by `LongestContent` and as a
    /// `FeedbackFirst` tie-breaker.
    /// NOTE(review): unit (bytes vs. characters) is not established here — confirm.
    pub content_length: usize,
    /// Error flag.
    /// NOTE(review): not consulted by any selection policy in this module.
    pub has_error: bool,
    /// Creation timestamp; earlier wins under `FirstByTime` and in the final
    /// `FeedbackFirst` tie-break.
    /// NOTE(review): epoch/unit (seconds vs. millis) not established here — confirm.
    pub created_at: i64,
}
55
56impl Episode {
57    pub fn new(id: NodeId) -> Self {
58        Self {
59            id,
60            parent: None,
61            children: Vec::new(),
62            weight: 1.0,
63            has_thumbs_up: false,
64            has_thumbs_down: false,
65            content_length: 0,
66            has_error: false,
67            created_at: 0,
68        }
69    }
70
71    /// Check if this episode is a branching point (multiple children).
72    #[inline]
73    pub fn is_branch_point(&self) -> bool {
74        self.children.len() > 1
75    }
76
77    /// Check if this episode is a leaf (no children).
78    #[inline]
79    pub fn is_leaf(&self) -> bool {
80        self.children.is_empty()
81    }
82
83    /// Check if this episode is a root (no parent).
84    #[inline]
85    pub fn is_root(&self) -> bool {
86        self.parent.is_none()
87    }
88}
89
/// Information about a branch point in the DAG.
#[derive(Debug, Clone)]
pub struct BranchInfo {
    /// The episode where branching occurs
    pub branch_point: NodeId,
    /// All children at this branch, in the branch point's `children` order
    pub children: Vec<NodeId>,
    /// Type of branching.
    /// NOTE(review): the graph does not retain edge types, so the producers in
    /// this module always report `EdgeType::Regeneration`.
    pub branch_type: EdgeType,
    /// Index into `children` of the child followed by the primary path;
    /// `None` when reported by `TrajectoryGraph::find_branch_points`.
    pub selected_child_idx: Option<usize>,
}
102
/// Policy for selecting which child to follow at branch points.
///
/// Defaults to [`PathSelectionPolicy::FeedbackFirst`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PathSelectionPolicy {
    /// Prefer a child with a thumbs-up, then one without a thumbs-down, then
    /// the longest content, then the earliest creation time.
    #[default]
    FeedbackFirst,
    /// Always pick the child with the earliest creation time.
    FirstByTime,
    /// Always pick the child with the longest content.
    LongestContent,
    /// Pick the child with the highest weight.
    HighestWeight,
}
116
/// Traversal order for DAG walking.
///
/// Defaults to [`TraversalOrder::DepthFirst`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TraversalOrder {
    /// Depth-first, starting from the roots and following each episode's
    /// children in order. Episodes not reachable from a root are not visited.
    #[default]
    DepthFirst,
    /// Breadth-first, level by level, starting from the roots.
    BreadthFirst,
    /// Topological order: an episode is visited only after every episode that
    /// lists it as a child.
    /// NOTE(review): episodes on, or only reachable through, a cycle are skipped.
    Topological,
    /// Reverse topological (leaves to roots)
    ReverseTopological,
}
130
/// Result of finding a path through the DAG.
#[derive(Debug, Clone)]
pub struct PathResult {
    /// Ordered list of node IDs from root to leaf
    pub nodes: Vec<NodeId>,
    /// Branch points encountered along `nodes`, in path order
    pub branch_points: Vec<BranchInfo>,
    /// Sum of `Episode::weight` over every episode in `nodes`
    pub total_weight: f32,
}
141
/// Trajectory DAG structure optimized for traversal operations.
///
/// A trajectory represents a sequence of experiences (episodes) forming
/// a path through time. The DAG structure captures branching and
/// regeneration points.
#[derive(Debug, Clone)]
pub struct TrajectoryGraph {
    /// Every episode in the graph, keyed by id.
    nodes: HashMap<NodeId, Episode>,
    /// Cached ids of episodes without a parent; derived from `nodes` and not
    /// recomputed on every mutation.
    roots: Vec<NodeId>,
    /// Cached ids of episodes without children; maintained like `roots`.
    leaves: Vec<NodeId>,
}
153
154impl TrajectoryGraph {
155    /// Create a new empty graph.
156    pub fn new() -> Self {
157        Self {
158            nodes: HashMap::new(),
159            roots: Vec::new(),
160            leaves: Vec::new(),
161        }
162    }
163
164    /// Build graph from a list of edges.
165    ///
166    /// # Arguments
167    ///
168    /// * `edges` - Iterator of (parent_id, child_id) pairs
169    ///
170    /// # Example
171    ///
172    /// ```
173    /// use rag_plusplus_core::trajectory::{TrajectoryGraph, Edge, EdgeType};
174    ///
175    /// let edges = vec![
176    ///     Edge { parent: 1, child: 2, edge_type: EdgeType::Continuation },
177    ///     Edge { parent: 2, child: 3, edge_type: EdgeType::Continuation },
178    ///     Edge { parent: 2, child: 4, edge_type: EdgeType::Regeneration }, // Branch
179    /// ];
180    ///
181    /// let graph = TrajectoryGraph::from_edges(edges.iter().copied());
182    /// assert_eq!(graph.node_count(), 4);
183    /// assert!(graph.is_branch_point(2));
184    /// ```
185    pub fn from_edges(edges: impl IntoIterator<Item = Edge>) -> Self {
186        let mut graph = Self::new();
187
188        for edge in edges {
189            // Ensure nodes exist
190            graph.nodes.entry(edge.parent).or_insert_with(|| Episode::new(edge.parent));
191            graph.nodes.entry(edge.child).or_insert_with(|| Episode::new(edge.child));
192
193            // Add edge
194            if let Some(parent) = graph.nodes.get_mut(&edge.parent) {
195                if !parent.children.contains(&edge.child) {
196                    parent.children.push(edge.child);
197                }
198            }
199            if let Some(child) = graph.nodes.get_mut(&edge.child) {
200                child.parent = Some(edge.parent);
201            }
202        }
203
204        graph.update_roots_and_leaves();
205        graph
206    }
207
208    /// Add a single episode.
209    pub fn add_node(&mut self, node: Episode) {
210        self.nodes.insert(node.id, node);
211    }
212
213    /// Get an episode by ID.
214    #[inline]
215    pub fn get_node(&self, id: NodeId) -> Option<&Episode> {
216        self.nodes.get(&id)
217    }
218
219    /// Get a mutable reference to an episode.
220    #[inline]
221    pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Episode> {
222        self.nodes.get_mut(&id)
223    }
224
225    /// Number of episodes in the graph.
226    #[inline]
227    pub fn node_count(&self) -> usize {
228        self.nodes.len()
229    }
230
231    /// Get all root episodes (episodes without parents).
232    #[inline]
233    pub fn roots(&self) -> &[NodeId] {
234        &self.roots
235    }
236
237    /// Get all leaf episodes (episodes without children).
238    #[inline]
239    pub fn leaves(&self) -> &[NodeId] {
240        &self.leaves
241    }
242
243    /// Check if an episode is a branch point.
244    #[inline]
245    pub fn is_branch_point(&self, id: NodeId) -> bool {
246        self.nodes.get(&id).map_or(false, |n| n.is_branch_point())
247    }
248
249    /// Find all branch points in the graph.
250    pub fn find_branch_points(&self) -> Vec<BranchInfo> {
251        self.nodes
252            .values()
253            .filter(|n| n.is_branch_point())
254            .map(|n| BranchInfo {
255                branch_point: n.id,
256                children: n.children.clone(),
257                branch_type: if n.children.len() > 1 {
258                    EdgeType::Regeneration
259                } else {
260                    EdgeType::Continuation
261                },
262                selected_child_idx: None,
263            })
264            .collect()
265    }
266
267    /// Update roots and leaves lists (call after modifications).
268    fn update_roots_and_leaves(&mut self) {
269        self.roots = self.nodes.values()
270            .filter(|n| n.is_root())
271            .map(|n| n.id)
272            .collect();
273
274        self.leaves = self.nodes.values()
275            .filter(|n| n.is_leaf())
276            .map(|n| n.id)
277            .collect();
278    }
279
280    // =========================================================================
281    // TRAVERSAL ALGORITHMS
282    // =========================================================================
283
284    /// Traverse the graph in specified order, calling visitor for each episode.
285    ///
286    /// # Arguments
287    ///
288    /// * `order` - Traversal order
289    /// * `visitor` - Callback function for each episode
290    pub fn traverse<F>(&self, order: TraversalOrder, mut visitor: F)
291    where
292        F: FnMut(&Episode),
293    {
294        match order {
295            TraversalOrder::DepthFirst => self.traverse_dfs(&mut visitor),
296            TraversalOrder::BreadthFirst => self.traverse_bfs(&mut visitor),
297            TraversalOrder::Topological => self.traverse_topological(&mut visitor),
298            TraversalOrder::ReverseTopological => self.traverse_reverse_topological(&mut visitor),
299        }
300    }
301
302    fn traverse_dfs<F>(&self, visitor: &mut F)
303    where
304        F: FnMut(&Episode),
305    {
306        let mut visited = HashSet::new();
307        let mut stack: Vec<NodeId> = self.roots.clone();
308
309        while let Some(id) = stack.pop() {
310            if visited.contains(&id) {
311                continue;
312            }
313            visited.insert(id);
314
315            if let Some(node) = self.nodes.get(&id) {
316                visitor(node);
317                // Push children in reverse order for correct DFS order
318                for &child_id in node.children.iter().rev() {
319                    if !visited.contains(&child_id) {
320                        stack.push(child_id);
321                    }
322                }
323            }
324        }
325    }
326
327    fn traverse_bfs<F>(&self, visitor: &mut F)
328    where
329        F: FnMut(&Episode),
330    {
331        let mut visited = HashSet::new();
332        let mut queue: VecDeque<NodeId> = self.roots.iter().copied().collect();
333
334        while let Some(id) = queue.pop_front() {
335            if visited.contains(&id) {
336                continue;
337            }
338            visited.insert(id);
339
340            if let Some(node) = self.nodes.get(&id) {
341                visitor(node);
342                for &child_id in &node.children {
343                    if !visited.contains(&child_id) {
344                        queue.push_back(child_id);
345                    }
346                }
347            }
348        }
349    }
350
351    fn traverse_topological<F>(&self, visitor: &mut F)
352    where
353        F: FnMut(&Episode),
354    {
355        // Kahn's algorithm
356        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
357        for node in self.nodes.values() {
358            in_degree.entry(node.id).or_insert(0);
359            for &child in &node.children {
360                *in_degree.entry(child).or_insert(0) += 1;
361            }
362        }
363
364        let mut queue: VecDeque<NodeId> = in_degree
365            .iter()
366            .filter(|(_, &deg)| deg == 0)
367            .map(|(&id, _)| id)
368            .collect();
369
370        while let Some(id) = queue.pop_front() {
371            if let Some(node) = self.nodes.get(&id) {
372                visitor(node);
373                for &child in &node.children {
374                    if let Some(deg) = in_degree.get_mut(&child) {
375                        *deg -= 1;
376                        if *deg == 0 {
377                            queue.push_back(child);
378                        }
379                    }
380                }
381            }
382        }
383    }
384
385    fn traverse_reverse_topological<F>(&self, visitor: &mut F)
386    where
387        F: FnMut(&Episode),
388    {
389        let mut order = Vec::with_capacity(self.nodes.len());
390        self.traverse_topological(&mut |node| order.push(node.id));
391
392        for id in order.into_iter().rev() {
393            if let Some(node) = self.nodes.get(&id) {
394                visitor(node);
395            }
396        }
397    }
398
399    // =========================================================================
400    // PRIMARY PATH SELECTION
401    // =========================================================================
402
403    /// Find the primary path through the DAG using the specified selection policy.
404    ///
405    /// The primary path is a linear sequence from a root to a leaf that represents
406    /// the "best" version of the trajectory (e.g., after regenerations).
407    ///
408    /// # Arguments
409    ///
410    /// * `policy` - How to choose at branch points
411    ///
412    /// # Returns
413    ///
414    /// PathResult containing the selected path and branch information.
415    pub fn find_primary_path(&self, policy: PathSelectionPolicy) -> Option<PathResult> {
416        if self.roots.is_empty() {
417            return None;
418        }
419
420        // Start from first root
421        let start = self.roots[0];
422        let mut path = Vec::new();
423        let mut branch_points = Vec::new();
424        let mut total_weight = 0.0;
425        let mut current = start;
426
427        loop {
428            let node = self.nodes.get(&current)?;
429            path.push(current);
430            total_weight += node.weight;
431
432            if node.children.is_empty() {
433                break;
434            }
435
436            // Select next node based on policy
437            let (next_idx, next) = self.select_child(node, policy)?;
438
439            if node.is_branch_point() {
440                branch_points.push(BranchInfo {
441                    branch_point: current,
442                    children: node.children.clone(),
443                    branch_type: EdgeType::Regeneration,
444                    selected_child_idx: Some(next_idx),
445                });
446            }
447
448            current = next;
449        }
450
451        Some(PathResult {
452            nodes: path,
453            branch_points,
454            total_weight,
455        })
456    }
457
458    /// Select which child to follow at a branch point.
459    fn select_child(&self, parent: &Episode, policy: PathSelectionPolicy) -> Option<(usize, NodeId)> {
460        if parent.children.is_empty() {
461            return None;
462        }
463
464        let children: Vec<&Episode> = parent.children
465            .iter()
466            .filter_map(|&id| self.nodes.get(&id))
467            .collect();
468
469        if children.is_empty() {
470            return Some((0, parent.children[0]));
471        }
472
473        let selected_idx = match policy {
474            PathSelectionPolicy::FeedbackFirst => {
475                // Priority: thumbs_up > no_thumbs_down > longest > first_by_time
476                children.iter().enumerate()
477                    .max_by(|(_, a), (_, b)| {
478                        // First: thumbs_up wins
479                        match (a.has_thumbs_up, b.has_thumbs_up) {
480                            (true, false) => return std::cmp::Ordering::Greater,
481                            (false, true) => return std::cmp::Ordering::Less,
482                            _ => {}
483                        }
484                        // Second: no thumbs_down is better
485                        match (a.has_thumbs_down, b.has_thumbs_down) {
486                            (false, true) => return std::cmp::Ordering::Greater,
487                            (true, false) => return std::cmp::Ordering::Less,
488                            _ => {}
489                        }
490                        // Third: longer content
491                        match a.content_length.cmp(&b.content_length) {
492                            std::cmp::Ordering::Equal => {}
493                            other => return other,
494                        }
495                        // Fourth: earlier creation time
496                        a.created_at.cmp(&b.created_at).reverse()
497                    })
498                    .map(|(i, _)| i)
499                    .unwrap_or(0)
500            }
501            PathSelectionPolicy::FirstByTime => {
502                children.iter().enumerate()
503                    .min_by_key(|(_, n)| n.created_at)
504                    .map(|(i, _)| i)
505                    .unwrap_or(0)
506            }
507            PathSelectionPolicy::LongestContent => {
508                children.iter().enumerate()
509                    .max_by_key(|(_, n)| n.content_length)
510                    .map(|(i, _)| i)
511                    .unwrap_or(0)
512            }
513            PathSelectionPolicy::HighestWeight => {
514                children.iter().enumerate()
515                    .max_by(|(_, a), (_, b)| a.weight.partial_cmp(&b.weight).unwrap_or(std::cmp::Ordering::Equal))
516                    .map(|(i, _)| i)
517                    .unwrap_or(0)
518            }
519        };
520
521        Some((selected_idx, parent.children[selected_idx]))
522    }
523
524    // =========================================================================
525    // PATH FINDING
526    // =========================================================================
527
528    /// Find all paths from an episode to all reachable leaves.
529    pub fn find_all_paths_from(&self, start: NodeId) -> Vec<Vec<NodeId>> {
530        let mut paths = Vec::new();
531        let mut current_path = vec![start];
532        self.find_paths_recursive(start, &mut current_path, &mut paths);
533        paths
534    }
535
536    fn find_paths_recursive(
537        &self,
538        current: NodeId,
539        path: &mut Vec<NodeId>,
540        paths: &mut Vec<Vec<NodeId>>,
541    ) {
542        if let Some(node) = self.nodes.get(&current) {
543            if node.is_leaf() {
544                paths.push(path.clone());
545            } else {
546                for &child in &node.children {
547                    path.push(child);
548                    self.find_paths_recursive(child, path, paths);
549                    path.pop();
550                }
551            }
552        }
553    }
554
555    /// Find the path from root to a specific episode.
556    pub fn find_path_to(&self, target: NodeId) -> Option<Vec<NodeId>> {
557        let mut path = Vec::new();
558        let mut current = target;
559
560        loop {
561            path.push(current);
562            match self.nodes.get(&current)?.parent {
563                Some(parent) => current = parent,
564                None => break,
565            }
566        }
567
568        path.reverse();
569        Some(path)
570    }
571
572    /// Compute the depth of an episode (distance from root).
573    pub fn depth(&self, node: NodeId) -> Option<usize> {
574        self.find_path_to(node).map(|p| p.len() - 1)
575    }
576
577    /// Find the lowest common ancestor of two episodes.
578    pub fn lowest_common_ancestor(&self, a: NodeId, b: NodeId) -> Option<NodeId> {
579        let path_a = self.find_path_to(a)?;
580        let path_b = self.find_path_to(b)?;
581
582        let path_a_set: HashSet<_> = path_a.iter().copied().collect();
583
584        // Walk up from b until we find a node in a's path
585        for &node in path_b.iter().rev() {
586            if path_a_set.contains(&node) {
587                return Some(node);
588            }
589        }
590
591        None
592    }
593}
594
595impl Default for TrajectoryGraph {
596    fn default() -> Self {
597        Self::new()
598    }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for an [`Edge`].
    fn edge(parent: NodeId, child: NodeId, edge_type: EdgeType) -> Edge {
        Edge {
            parent,
            child,
            edge_type,
        }
    }

    /// Linear chain: `1 -> 2 -> 3 -> 4`.
    fn make_linear_graph() -> TrajectoryGraph {
        let chain = [(1, 2), (2, 3), (3, 4)];
        TrajectoryGraph::from_edges(
            chain
                .iter()
                .map(|&(parent, child)| edge(parent, child, EdgeType::Continuation)),
        )
    }

    /// Branching graph:
    ///
    /// ```text
    /// 1 -> 2 -> 3
    ///        -> 4   (regeneration)
    ///   -> 5        (separate branch)
    /// ```
    fn make_branching_graph() -> TrajectoryGraph {
        TrajectoryGraph::from_edges(vec![
            edge(1, 2, EdgeType::Continuation),
            edge(2, 3, EdgeType::Regeneration),
            edge(2, 4, EdgeType::Regeneration),
            edge(1, 5, EdgeType::Branch),
        ])
    }

    #[test]
    fn test_linear_graph() {
        let graph = make_linear_graph();
        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.roots(), &[1]);
        assert_eq!(graph.leaves(), &[4]);
        assert!(!graph.is_branch_point(1));
    }

    #[test]
    fn test_branching_graph() {
        let graph = make_branching_graph();
        assert_eq!(graph.node_count(), 5);
        // Node 1 forks into 2 and 5; node 2 forks into 3 and 4.
        assert!([1, 2].iter().all(|&id| graph.is_branch_point(id)));
        assert_eq!(graph.find_branch_points().len(), 2);
    }

    #[test]
    fn test_find_path_to() {
        assert_eq!(make_linear_graph().find_path_to(4), Some(vec![1, 2, 3, 4]));
    }

    #[test]
    fn test_primary_path() {
        let result = make_linear_graph()
            .find_primary_path(PathSelectionPolicy::FirstByTime)
            .expect("a linear graph has a root");
        assert_eq!(result.nodes, vec![1, 2, 3, 4]);
        assert!(result.branch_points.is_empty());
    }

    #[test]
    fn test_dfs_traversal() {
        let mut order = Vec::new();
        make_linear_graph().traverse(TraversalOrder::DepthFirst, |episode| order.push(episode.id));
        assert_eq!(order, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_depth() {
        let graph = make_linear_graph();
        assert_eq!(graph.depth(1), Some(0));
        assert_eq!(graph.depth(4), Some(3));
    }
}