rag_plusplus_core/trajectory/branch/
resolution.rs

1//! Branch Resolution Module
2//!
3//! Provides algorithms for finding and recovering "lost" branches
4//! in trajectory DAGs. This is the key solution to the problem where
5//! valuable exploration paths become inaccessible over time.
6
7use std::collections::{HashSet, VecDeque};
8use crate::trajectory::graph::{NodeId, TrajectoryGraph};
9use super::operations::{BranchId, BranchStatus, BranchError};
10use super::state_machine::BranchStateMachine;
11
12/// Strategy for recovering a lost branch.
13#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum RecoveryStrategy {
15    /// Reactivate the branch without modification
16    Reactivate,
17    /// Create a copy of the branch
18    Copy,
19    /// Merge into another branch
20    MergeInto(BranchId),
21    /// Split from parent and make independent
22    SplitIndependent,
23}
24
25/// A branch that can potentially be recovered.
26#[derive(Debug, Clone)]
27pub struct RecoverableBranch {
28    /// ID of the branch (if it exists) or generated ID
29    pub branch_id: Option<BranchId>,
30    /// Fork point where this branch diverges
31    pub fork_point: NodeId,
32    /// The first node of this branch path
33    pub entry_node: NodeId,
34    /// All nodes in this branch
35    pub nodes: Vec<NodeId>,
36    /// Head (deepest leaf) of this branch
37    pub head: NodeId,
38    /// Depth of the branch
39    pub depth: u32,
40    /// Why this branch is considered "lost"
41    pub lost_reason: LostReason,
42    /// Score indicating how valuable this branch might be (higher = more valuable)
43    pub recovery_score: f32,
44    /// Suggested recovery strategy
45    pub suggested_strategy: RecoveryStrategy,
46}
47
/// Reasons why a branch might be considered "lost".
///
/// The enum is fieldless, so it is `Copy` and `Hash`: values can be passed
/// around freely and used as keys when grouping or counting candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LostReason {
    /// Branch was archived and forgotten.
    Archived,
    /// Branch has no explicit tracking (exists in the DAG but not in the
    /// state machine).
    Untracked,
    /// Branch's parent was deleted.
    OrphanedByDeletion,
    /// Branch was created by regeneration but not selected.
    UnselectedRegeneration,
    /// Branch was explicitly abandoned.
    Abandoned,
    /// Branch diverged during exploration.
    ExplorationDivergence,
}
64
65/// Resolver for finding and recovering lost branches.
66///
67/// The resolver analyzes a trajectory DAG and state machine to find
68/// branches that exist in the data but are not actively tracked or
69/// have become inaccessible.
70///
71/// # Algorithm
72///
73/// 1. **Discovery**: Find all paths in the DAG that aren't part of active branches
74/// 2. **Scoring**: Rank branches by potential value (length, depth, content)
75/// 3. **Strategy**: Suggest recovery strategies for each branch
76/// 4. **Recovery**: Execute the chosen recovery strategy
77///
78/// # Example
79///
80/// ```ignore
81/// let resolver = BranchResolver::new(&machine);
82///
83/// // Find all recoverable branches
84/// let lost = resolver.find_recoverable_branches();
85///
86/// for branch in lost {
87///     println!("Found lost branch at {:?} with {} nodes",
88///              branch.fork_point, branch.nodes.len());
89///
90///     // Recover using suggested strategy
91///     resolver.recover(&mut machine, &branch)?;
92/// }
93/// ```
94pub struct BranchResolver<'a> {
95    machine: &'a BranchStateMachine,
96}
97
98impl<'a> BranchResolver<'a> {
99    /// Create a new resolver for a state machine.
100    pub fn new(machine: &'a BranchStateMachine) -> Self {
101        Self { machine }
102    }
103
104    /// Find all branches that could be recovered.
105    ///
106    /// This analyzes the DAG to find paths that aren't actively tracked
107    /// in the state machine.
108    pub fn find_recoverable_branches(&self) -> Vec<RecoverableBranch> {
109        let mut recoverable = Vec::new();
110        let graph = self.machine.graph();
111
112        // Collect all nodes that are already tracked
113        let tracked_nodes: HashSet<NodeId> = self.machine.all_branches()
114            .flat_map(|b| b.nodes.iter().copied())
115            .collect();
116
117        // Find fork points in the graph
118        let fork_points: Vec<NodeId> = graph.find_branch_points()
119            .iter()
120            .map(|bp| bp.branch_point)
121            .collect();
122
123        // For each fork point, check if any children are untracked
124        for fork_point in fork_points {
125            if let Some(episode) = graph.get_node(fork_point) {
126                for &child_id in &episode.children {
127                    let subtree = self.collect_subtree(graph, child_id);
128
129                    // Check if this subtree is untracked or partially tracked
130                    let untracked: Vec<NodeId> = subtree.iter()
131                        .filter(|n| !tracked_nodes.contains(n))
132                        .copied()
133                        .collect();
134
135                    if !untracked.is_empty() {
136                        // Found a recoverable branch
137                        let branch = self.create_recoverable_branch(
138                            graph,
139                            fork_point,
140                            child_id,
141                            untracked,
142                        );
143                        recoverable.push(branch);
144                    }
145                }
146            }
147        }
148
149        // Also find archived branches that could be recovered
150        for branch in self.machine.all_branches() {
151            if branch.status == BranchStatus::Archived {
152                let recoverable_branch = RecoverableBranch {
153                    branch_id: Some(branch.id),
154                    fork_point: branch.fork_point,
155                    entry_node: branch.nodes.first().copied().unwrap_or(branch.fork_point),
156                    nodes: branch.nodes.clone(),
157                    head: branch.head,
158                    depth: self.compute_depth(graph, branch.head),
159                    lost_reason: LostReason::Archived,
160                    recovery_score: self.compute_recovery_score(graph, &branch.nodes),
161                    suggested_strategy: RecoveryStrategy::Reactivate,
162                };
163                recoverable.push(recoverable_branch);
164            }
165        }
166
167        // Sort by recovery score (highest first)
168        recoverable.sort_by(|a, b| {
169            b.recovery_score.partial_cmp(&a.recovery_score).unwrap_or(std::cmp::Ordering::Equal)
170        });
171
172        recoverable
173    }
174
175    /// Find branches created by regeneration that weren't selected.
176    pub fn find_unselected_regenerations(&self) -> Vec<RecoverableBranch> {
177        let graph = self.machine.graph();
178        let mut unselected = Vec::new();
179
180        // Find fork points that represent regenerations
181        for fork in self.machine.fork_points() {
182            let selected = fork.selected_child;
183
184            for &child_id in &fork.children {
185                // Skip the selected child
186                if Some(child_id) == selected {
187                    continue;
188                }
189
190                // Check if this branch exists and is active
191                if let Some(branch) = self.machine.get_branch(child_id) {
192                    if branch.is_active() {
193                        continue;
194                    }
195                }
196
197                // This is an unselected regeneration
198                let subtree = self.collect_subtree(graph, fork.node_id);
199                let child_nodes: Vec<NodeId> = subtree.into_iter()
200                    .filter(|&n| {
201                        self.is_descendant_of(graph, n, child_id) || n == child_id
202                    })
203                    .collect();
204
205                if !child_nodes.is_empty() {
206                    let recoverable = RecoverableBranch {
207                        branch_id: None,
208                        fork_point: fork.node_id,
209                        entry_node: child_id,
210                        nodes: child_nodes.clone(),
211                        head: self.find_deepest_leaf(graph, &child_nodes),
212                        depth: fork.depth + 1,
213                        lost_reason: LostReason::UnselectedRegeneration,
214                        recovery_score: self.compute_recovery_score(graph, &child_nodes),
215                        suggested_strategy: RecoveryStrategy::SplitIndependent,
216                    };
217                    unselected.push(recoverable);
218                }
219            }
220        }
221
222        unselected
223    }
224
225    /// Recover a lost branch using its suggested strategy.
226    pub fn recover(
227        &self,
228        machine: &mut BranchStateMachine,
229        recoverable: &RecoverableBranch,
230    ) -> Result<BranchId, BranchError> {
231        match &recoverable.suggested_strategy {
232            RecoveryStrategy::Reactivate => {
233                if let Some(branch_id) = recoverable.branch_id {
234                    self.reactivate_branch(machine, branch_id)
235                } else {
236                    Err(BranchError::InvalidState("No branch ID for reactivation".to_string()))
237                }
238            }
239            RecoveryStrategy::Copy => {
240                self.copy_as_new_branch(machine, recoverable)
241            }
242            RecoveryStrategy::MergeInto(target) => {
243                if let Some(branch_id) = recoverable.branch_id {
244                    machine.merge(branch_id, *target)?;
245                    Ok(*target)
246                } else {
247                    Err(BranchError::InvalidState("No branch ID for merge".to_string()))
248                }
249            }
250            RecoveryStrategy::SplitIndependent => {
251                self.create_independent_branch(machine, recoverable)
252            }
253        }
254    }
255
256    // =========================================================================
257    // RECOVERY IMPLEMENTATIONS
258    // =========================================================================
259
260    fn reactivate_branch(
261        &self,
262        machine: &mut BranchStateMachine,
263        branch_id: BranchId,
264    ) -> Result<BranchId, BranchError> {
265        // Get mutable access to the branch and recover it
266        // Note: This requires interior mutability or returning the ID
267        // For now, we'll use a simple approach
268        let _branch = machine.get_branch(branch_id)
269            .ok_or(BranchError::BranchNotFound(branch_id))?;
270
271        // Record the recovery operation
272        // machine.recover_branch(branch_id)?;
273
274        Ok(branch_id)
275    }
276
277    fn copy_as_new_branch(
278        &self,
279        machine: &mut BranchStateMachine,
280        recoverable: &RecoverableBranch,
281    ) -> Result<BranchId, BranchError> {
282        // Split at the fork point to create a new branch
283        let result = machine.split(recoverable.entry_node)?;
284        Ok(result.new_branch)
285    }
286
287    fn create_independent_branch(
288        &self,
289        machine: &mut BranchStateMachine,
290        recoverable: &RecoverableBranch,
291    ) -> Result<BranchId, BranchError> {
292        // Split at the entry node
293        let result = machine.split(recoverable.entry_node)?;
294        Ok(result.new_branch)
295    }
296
297    // =========================================================================
298    // HELPER METHODS
299    // =========================================================================
300
301    fn create_recoverable_branch(
302        &self,
303        graph: &TrajectoryGraph,
304        fork_point: NodeId,
305        entry_node: NodeId,
306        nodes: Vec<NodeId>,
307    ) -> RecoverableBranch {
308        let head = self.find_deepest_leaf(graph, &nodes);
309        let depth = self.compute_depth(graph, head);
310        let score = self.compute_recovery_score(graph, &nodes);
311
312        RecoverableBranch {
313            branch_id: None,
314            fork_point,
315            entry_node,
316            nodes,
317            head,
318            depth,
319            lost_reason: LostReason::Untracked,
320            recovery_score: score,
321            suggested_strategy: RecoveryStrategy::SplitIndependent,
322        }
323    }
324
325    fn collect_subtree(&self, graph: &TrajectoryGraph, root: NodeId) -> Vec<NodeId> {
326        let mut nodes = Vec::new();
327        let mut stack = vec![root];
328        let mut visited = HashSet::new();
329
330        while let Some(node_id) = stack.pop() {
331            if visited.contains(&node_id) {
332                continue;
333            }
334            visited.insert(node_id);
335            nodes.push(node_id);
336
337            if let Some(episode) = graph.get_node(node_id) {
338                for &child in &episode.children {
339                    stack.push(child);
340                }
341            }
342        }
343
344        nodes
345    }
346
347    fn compute_depth(&self, graph: &TrajectoryGraph, node_id: NodeId) -> u32 {
348        graph.depth(node_id).unwrap_or(0) as u32
349    }
350
351    fn find_deepest_leaf(&self, graph: &TrajectoryGraph, nodes: &[NodeId]) -> NodeId {
352        nodes.iter()
353            .filter(|&&n| graph.get_node(n).map_or(false, |e| e.is_leaf()))
354            .max_by_key(|&&n| self.compute_depth(graph, n))
355            .copied()
356            .unwrap_or_else(|| nodes.first().copied().unwrap_or(0))
357    }
358
359    fn is_descendant_of(&self, graph: &TrajectoryGraph, node: NodeId, ancestor: NodeId) -> bool {
360        if node == ancestor {
361            return true;
362        }
363
364        // BFS from ancestor to find node
365        let mut queue = VecDeque::new();
366        queue.push_back(ancestor);
367        let mut visited = HashSet::new();
368
369        while let Some(current) = queue.pop_front() {
370            if current == node {
371                return true;
372            }
373            if visited.contains(&current) {
374                continue;
375            }
376            visited.insert(current);
377
378            if let Some(episode) = graph.get_node(current) {
379                for &child in &episode.children {
380                    queue.push_back(child);
381                }
382            }
383        }
384
385        false
386    }
387
388    /// Compute a score indicating how valuable a branch might be.
389    ///
390    /// Higher scores indicate more valuable branches to recover.
391    /// Factors:
392    /// - Length (more nodes = more content)
393    /// - Depth (deeper = more exploration)
394    /// - Content richness (based on episode metadata)
395    fn compute_recovery_score(&self, graph: &TrajectoryGraph, nodes: &[NodeId]) -> f32 {
396        let length_factor = (nodes.len() as f32).ln_1p();
397
398        let max_depth = nodes.iter()
399            .map(|&n| self.compute_depth(graph, n))
400            .max()
401            .unwrap_or(0);
402        let depth_factor = (max_depth as f32).sqrt();
403
404        // Content richness: sum of content lengths
405        let content_factor: f32 = nodes.iter()
406            .filter_map(|&n| graph.get_node(n))
407            .map(|e| (e.content_length as f32).ln_1p())
408            .sum::<f32>()
409            / nodes.len().max(1) as f32;
410
411        // Check for positive feedback
412        let feedback_factor: f32 = nodes.iter()
413            .filter_map(|&n| graph.get_node(n))
414            .filter(|e| e.has_thumbs_up)
415            .count() as f32;
416
417        // Combine factors
418        0.3 * length_factor + 0.3 * depth_factor + 0.2 * content_factor + 0.2 * feedback_factor
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trajectory::graph::{Edge, EdgeType};

    /// Small graph with two fork points (nodes 1 and 2).
    fn make_branching_graph() -> TrajectoryGraph {
        // 1 -> 2 -> 3
        //        -> 4 (regeneration - potentially lost)
        //   -> 5 (branch - potentially lost)
        let edges = vec![
            Edge { parent: 1, child: 2, edge_type: EdgeType::Continuation },
            Edge { parent: 2, child: 3, edge_type: EdgeType::Regeneration },
            Edge { parent: 2, child: 4, edge_type: EdgeType::Regeneration },
            Edge { parent: 1, child: 5, edge_type: EdgeType::Branch },
        ];
        TrajectoryGraph::from_edges(edges.into_iter())
    }

    #[test]
    fn test_resolver_creation() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);

        let recoverable = resolver.find_recoverable_branches();

        // How many candidates are found depends on how `from_graph` sets up
        // branches, but the ordering and shape of the result do not.
        assert!(
            recoverable
                .windows(2)
                .all(|pair| pair[0].recovery_score >= pair[1].recovery_score),
            "candidates must be sorted by descending recovery score"
        );
        for branch in &recoverable {
            if branch.lost_reason == LostReason::Untracked {
                assert!(!branch.nodes.is_empty());
                assert!(branch.nodes.contains(&branch.head));
            }
        }
    }

    #[test]
    fn test_recovery_score() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);

        // Three nodes: the length factor alone makes the score strictly positive.
        let nodes = vec![1, 2, 3];
        let score = resolver.compute_recovery_score(&graph, &nodes);
        assert!(score.is_finite());
        assert!(score > 0.0);

        // No nodes: every factor is zero.
        let empty_score = resolver.compute_recovery_score(&graph, &[]);
        assert_eq!(empty_score, 0.0);
    }

    #[test]
    fn test_collect_subtree() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);

        // Collect subtree from node 2
        let subtree = resolver.collect_subtree(&graph, 2);

        // Exactly node 2 and its children (3 and 4), root first
        assert_eq!(subtree.len(), 3);
        assert_eq!(subtree[0], 2);
        assert!(subtree.contains(&3));
        assert!(subtree.contains(&4));

        // Nothing from outside the subtree
        assert!(!subtree.contains(&1));
        assert!(!subtree.contains(&5));
    }

    #[test]
    fn test_is_descendant() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);

        // Node 3 is a descendant of node 1 and of node 2
        assert!(resolver.is_descendant_of(&graph, 3, 1));
        assert!(resolver.is_descendant_of(&graph, 3, 2));

        // A node counts as its own descendant
        assert!(resolver.is_descendant_of(&graph, 2, 2));

        // Node 1 is not a descendant of node 3
        assert!(!resolver.is_descendant_of(&graph, 1, 3));

        // Node 5 hangs off node 1, not node 2
        assert!(!resolver.is_descendant_of(&graph, 5, 2));
    }

    #[test]
    fn test_recovery_strategy() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);

        let recoverable = resolver.find_recoverable_branches();

        // The suggested strategy must match the reason the branch was reported.
        for branch in recoverable {
            match branch.lost_reason {
                LostReason::Archived => {
                    assert_eq!(branch.suggested_strategy, RecoveryStrategy::Reactivate);
                    assert!(branch.branch_id.is_some());
                }
                LostReason::Untracked => {
                    assert_eq!(branch.suggested_strategy, RecoveryStrategy::SplitIndependent);
                    assert!(branch.branch_id.is_none());
                }
                other => panic!(
                    "find_recoverable_branches reported unexpected reason {:?}",
                    other
                ),
            }
        }
    }
}