//! Discovery and recovery of trajectory branches that are no longer actively tracked.
//!
//! Source path: `rag_plusplus_core/trajectory/branch/resolution.rs`.

use std::collections::{HashSet, VecDeque};

use crate::trajectory::graph::{NodeId, TrajectoryGraph};
use super::operations::{BranchError, BranchId, BranchStatus};
use super::state_machine::BranchStateMachine;

12#[derive(Debug, Clone, PartialEq, Eq)]
14pub enum RecoveryStrategy {
15 Reactivate,
17 Copy,
19 MergeInto(BranchId),
21 SplitIndependent,
23}
24
25#[derive(Debug, Clone)]
27pub struct RecoverableBranch {
28 pub branch_id: Option<BranchId>,
30 pub fork_point: NodeId,
32 pub entry_node: NodeId,
34 pub nodes: Vec<NodeId>,
36 pub head: NodeId,
38 pub depth: u32,
40 pub lost_reason: LostReason,
42 pub recovery_score: f32,
44 pub suggested_strategy: RecoveryStrategy,
46}
47
/// Why a branch ended up outside active tracking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LostReason {
    /// The branch is still tracked, but its status is archived.
    Archived,
    /// The nodes are not owned by any tracked branch.
    Untracked,
    /// Not constructed by `BranchResolver`; presumably nodes orphaned when an ancestor
    /// was deleted — confirm with the producer.
    OrphanedByDeletion,
    /// A child of a fork point that is neither the fork's selected child nor the start of
    /// an active branch.
    UnselectedRegeneration,
    /// Not constructed by `BranchResolver`; presumably a branch the user stopped
    /// pursuing — confirm with the producer.
    Abandoned,
    /// Not constructed by `BranchResolver`; presumably an exploratory branch that
    /// diverged from the main line — confirm with the producer.
    ExplorationDivergence,
}

65pub struct BranchResolver<'a> {
95 machine: &'a BranchStateMachine,
96}
97
98impl<'a> BranchResolver<'a> {
99 pub fn new(machine: &'a BranchStateMachine) -> Self {
101 Self { machine }
102 }
103
104 pub fn find_recoverable_branches(&self) -> Vec<RecoverableBranch> {
109 let mut recoverable = Vec::new();
110 let graph = self.machine.graph();
111
112 let tracked_nodes: HashSet<NodeId> = self.machine.all_branches()
114 .flat_map(|b| b.nodes.iter().copied())
115 .collect();
116
117 let fork_points: Vec<NodeId> = graph.find_branch_points()
119 .iter()
120 .map(|bp| bp.branch_point)
121 .collect();
122
123 for fork_point in fork_points {
125 if let Some(episode) = graph.get_node(fork_point) {
126 for &child_id in &episode.children {
127 let subtree = self.collect_subtree(graph, child_id);
128
129 let untracked: Vec<NodeId> = subtree.iter()
131 .filter(|n| !tracked_nodes.contains(n))
132 .copied()
133 .collect();
134
135 if !untracked.is_empty() {
136 let branch = self.create_recoverable_branch(
138 graph,
139 fork_point,
140 child_id,
141 untracked,
142 );
143 recoverable.push(branch);
144 }
145 }
146 }
147 }
148
149 for branch in self.machine.all_branches() {
151 if branch.status == BranchStatus::Archived {
152 let recoverable_branch = RecoverableBranch {
153 branch_id: Some(branch.id),
154 fork_point: branch.fork_point,
155 entry_node: branch.nodes.first().copied().unwrap_or(branch.fork_point),
156 nodes: branch.nodes.clone(),
157 head: branch.head,
158 depth: self.compute_depth(graph, branch.head),
159 lost_reason: LostReason::Archived,
160 recovery_score: self.compute_recovery_score(graph, &branch.nodes),
161 suggested_strategy: RecoveryStrategy::Reactivate,
162 };
163 recoverable.push(recoverable_branch);
164 }
165 }
166
167 recoverable.sort_by(|a, b| {
169 b.recovery_score.partial_cmp(&a.recovery_score).unwrap_or(std::cmp::Ordering::Equal)
170 });
171
172 recoverable
173 }
174
175 pub fn find_unselected_regenerations(&self) -> Vec<RecoverableBranch> {
177 let graph = self.machine.graph();
178 let mut unselected = Vec::new();
179
180 for fork in self.machine.fork_points() {
182 let selected = fork.selected_child;
183
184 for &child_id in &fork.children {
185 if Some(child_id) == selected {
187 continue;
188 }
189
190 if let Some(branch) = self.machine.get_branch(child_id) {
192 if branch.is_active() {
193 continue;
194 }
195 }
196
197 let subtree = self.collect_subtree(graph, fork.node_id);
199 let child_nodes: Vec<NodeId> = subtree.into_iter()
200 .filter(|&n| {
201 self.is_descendant_of(graph, n, child_id) || n == child_id
202 })
203 .collect();
204
205 if !child_nodes.is_empty() {
206 let recoverable = RecoverableBranch {
207 branch_id: None,
208 fork_point: fork.node_id,
209 entry_node: child_id,
210 nodes: child_nodes.clone(),
211 head: self.find_deepest_leaf(graph, &child_nodes),
212 depth: fork.depth + 1,
213 lost_reason: LostReason::UnselectedRegeneration,
214 recovery_score: self.compute_recovery_score(graph, &child_nodes),
215 suggested_strategy: RecoveryStrategy::SplitIndependent,
216 };
217 unselected.push(recoverable);
218 }
219 }
220 }
221
222 unselected
223 }
224
225 pub fn recover(
227 &self,
228 machine: &mut BranchStateMachine,
229 recoverable: &RecoverableBranch,
230 ) -> Result<BranchId, BranchError> {
231 match &recoverable.suggested_strategy {
232 RecoveryStrategy::Reactivate => {
233 if let Some(branch_id) = recoverable.branch_id {
234 self.reactivate_branch(machine, branch_id)
235 } else {
236 Err(BranchError::InvalidState("No branch ID for reactivation".to_string()))
237 }
238 }
239 RecoveryStrategy::Copy => {
240 self.copy_as_new_branch(machine, recoverable)
241 }
242 RecoveryStrategy::MergeInto(target) => {
243 if let Some(branch_id) = recoverable.branch_id {
244 machine.merge(branch_id, *target)?;
245 Ok(*target)
246 } else {
247 Err(BranchError::InvalidState("No branch ID for merge".to_string()))
248 }
249 }
250 RecoveryStrategy::SplitIndependent => {
251 self.create_independent_branch(machine, recoverable)
252 }
253 }
254 }
255
256 fn reactivate_branch(
261 &self,
262 machine: &mut BranchStateMachine,
263 branch_id: BranchId,
264 ) -> Result<BranchId, BranchError> {
265 let _branch = machine.get_branch(branch_id)
269 .ok_or(BranchError::BranchNotFound(branch_id))?;
270
271 Ok(branch_id)
275 }
276
277 fn copy_as_new_branch(
278 &self,
279 machine: &mut BranchStateMachine,
280 recoverable: &RecoverableBranch,
281 ) -> Result<BranchId, BranchError> {
282 let result = machine.split(recoverable.entry_node)?;
284 Ok(result.new_branch)
285 }
286
287 fn create_independent_branch(
288 &self,
289 machine: &mut BranchStateMachine,
290 recoverable: &RecoverableBranch,
291 ) -> Result<BranchId, BranchError> {
292 let result = machine.split(recoverable.entry_node)?;
294 Ok(result.new_branch)
295 }
296
297 fn create_recoverable_branch(
302 &self,
303 graph: &TrajectoryGraph,
304 fork_point: NodeId,
305 entry_node: NodeId,
306 nodes: Vec<NodeId>,
307 ) -> RecoverableBranch {
308 let head = self.find_deepest_leaf(graph, &nodes);
309 let depth = self.compute_depth(graph, head);
310 let score = self.compute_recovery_score(graph, &nodes);
311
312 RecoverableBranch {
313 branch_id: None,
314 fork_point,
315 entry_node,
316 nodes,
317 head,
318 depth,
319 lost_reason: LostReason::Untracked,
320 recovery_score: score,
321 suggested_strategy: RecoveryStrategy::SplitIndependent,
322 }
323 }
324
325 fn collect_subtree(&self, graph: &TrajectoryGraph, root: NodeId) -> Vec<NodeId> {
326 let mut nodes = Vec::new();
327 let mut stack = vec![root];
328 let mut visited = HashSet::new();
329
330 while let Some(node_id) = stack.pop() {
331 if visited.contains(&node_id) {
332 continue;
333 }
334 visited.insert(node_id);
335 nodes.push(node_id);
336
337 if let Some(episode) = graph.get_node(node_id) {
338 for &child in &episode.children {
339 stack.push(child);
340 }
341 }
342 }
343
344 nodes
345 }
346
347 fn compute_depth(&self, graph: &TrajectoryGraph, node_id: NodeId) -> u32 {
348 graph.depth(node_id).unwrap_or(0) as u32
349 }
350
351 fn find_deepest_leaf(&self, graph: &TrajectoryGraph, nodes: &[NodeId]) -> NodeId {
352 nodes.iter()
353 .filter(|&&n| graph.get_node(n).map_or(false, |e| e.is_leaf()))
354 .max_by_key(|&&n| self.compute_depth(graph, n))
355 .copied()
356 .unwrap_or_else(|| nodes.first().copied().unwrap_or(0))
357 }
358
359 fn is_descendant_of(&self, graph: &TrajectoryGraph, node: NodeId, ancestor: NodeId) -> bool {
360 if node == ancestor {
361 return true;
362 }
363
364 let mut queue = VecDeque::new();
366 queue.push_back(ancestor);
367 let mut visited = HashSet::new();
368
369 while let Some(current) = queue.pop_front() {
370 if current == node {
371 return true;
372 }
373 if visited.contains(¤t) {
374 continue;
375 }
376 visited.insert(current);
377
378 if let Some(episode) = graph.get_node(current) {
379 for &child in &episode.children {
380 queue.push_back(child);
381 }
382 }
383 }
384
385 false
386 }
387
388 fn compute_recovery_score(&self, graph: &TrajectoryGraph, nodes: &[NodeId]) -> f32 {
396 let length_factor = (nodes.len() as f32).ln_1p();
397
398 let max_depth = nodes.iter()
399 .map(|&n| self.compute_depth(graph, n))
400 .max()
401 .unwrap_or(0);
402 let depth_factor = (max_depth as f32).sqrt();
403
404 let content_factor: f32 = nodes.iter()
406 .filter_map(|&n| graph.get_node(n))
407 .map(|e| (e.content_length as f32).ln_1p())
408 .sum::<f32>()
409 / nodes.len().max(1) as f32;
410
411 let feedback_factor: f32 = nodes.iter()
413 .filter_map(|&n| graph.get_node(n))
414 .filter(|e| e.has_thumbs_up)
415 .count() as f32;
416
417 0.3 * length_factor + 0.3 * depth_factor + 0.2 * content_factor + 0.2 * feedback_factor
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trajectory::graph::{Edge, EdgeType};

    // Builds this tree (edge types in parentheses):
    //
    //   1 -> 2 (continuation) -> 3 (regeneration)
    //                         -> 4 (regeneration)
    //   1 -> 5 (branch)
    //
    // Nodes 1 and 2 each have two children.
    fn make_branching_graph() -> TrajectoryGraph {
        let edges = vec![
            Edge { parent: 1, child: 2, edge_type: EdgeType::Continuation },
            Edge { parent: 2, child: 3, edge_type: EdgeType::Regeneration },
            Edge { parent: 2, child: 4, edge_type: EdgeType::Regeneration },
            Edge { parent: 1, child: 5, edge_type: EdgeType::Branch },
        ];
        TrajectoryGraph::from_edges(edges.into_iter())
    }

    #[test]
    fn test_resolver_creation() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);

        let recoverable = resolver.find_recoverable_branches();

        for branch in &recoverable {
            assert!(branch.recovery_score.is_finite());
            assert!(branch.recovery_score >= 0.0);
            match &branch.lost_reason {
                LostReason::Archived => assert!(branch.branch_id.is_some()),
                LostReason::Untracked => {
                    assert!(branch.branch_id.is_none());
                    assert!(!branch.nodes.is_empty());
                    assert!(branch.nodes.contains(&branch.head));
                }
                other => panic!("unexpected lost reason: {:?}", other),
            }
        }

        // Candidates are ordered by score, highest first.
        assert!(recoverable
            .windows(2)
            .all(|pair| pair[0].recovery_score >= pair[1].recovery_score));
    }

    #[test]
    fn test_recovery_score() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);

        let nodes = vec![1, 2, 3];
        let score = resolver.compute_recovery_score(&graph, &nodes);

        assert!(score.is_finite());
        // The length term alone (0.3 * ln(1 + 3)) makes a non-empty candidate score positive.
        assert!(score > 0.0);

        // Every factor is zero for an empty candidate.
        assert_eq!(resolver.compute_recovery_score(&graph, &[]), 0.0);
    }

    #[test]
    fn test_collect_subtree() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);

        let subtree = resolver.collect_subtree(&graph, 2);

        // The root comes first, followed by its descendants.
        assert_eq!(subtree.first(), Some(&2));
        assert!(subtree.contains(&3));
        assert!(subtree.contains(&4));
        // Neither the parent nor the sibling branch belongs to node 2's subtree.
        assert!(!subtree.contains(&1));
        assert!(!subtree.contains(&5));
        assert_eq!(subtree.len(), 3);
    }

    #[test]
    fn test_is_descendant() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);

        assert!(resolver.is_descendant_of(&graph, 3, 1));
        assert!(resolver.is_descendant_of(&graph, 3, 2));
        assert!(resolver.is_descendant_of(&graph, 5, 1));
        // A node counts as its own descendant.
        assert!(resolver.is_descendant_of(&graph, 4, 4));

        assert!(!resolver.is_descendant_of(&graph, 1, 3));
        // Siblings and cousins are not descendants of each other.
        assert!(!resolver.is_descendant_of(&graph, 4, 3));
        assert!(!resolver.is_descendant_of(&graph, 5, 2));
    }

    #[test]
    fn test_recovery_strategy() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);

        // Archived branches are reactivated; untracked nodes are split off independently.
        for branch in resolver.find_recoverable_branches() {
            let expected = match branch.lost_reason {
                LostReason::Archived => RecoveryStrategy::Reactivate,
                _ => RecoveryStrategy::SplitIndependent,
            };
            assert_eq!(branch.suggested_strategy, expected);
        }
    }

    #[test]
    fn test_unselected_regenerations() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);

        for branch in resolver.find_unselected_regenerations() {
            assert!(branch.branch_id.is_none());
            assert_eq!(branch.lost_reason, LostReason::UnselectedRegeneration);
            assert_eq!(branch.suggested_strategy, RecoveryStrategy::SplitIndependent);
            assert!(branch.nodes.contains(&branch.entry_node));
            assert!(branch.nodes.contains(&branch.head));
            assert!(!branch.nodes.contains(&branch.fork_point));
        }
    }
}