use std::collections::{HashSet, VecDeque};
use crate::trajectory::graph::{NodeId, TrajectoryGraph};
use super::operations::{BranchId, BranchStatus, BranchError};
use super::state_machine::BranchStateMachine;
/// How [`BranchResolver::recover`] should bring a [`RecoverableBranch`] back.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RecoveryStrategy {
    /// Return the existing branch named by `RecoverableBranch::branch_id`
    /// (requires `branch_id` to be `Some`).
    Reactivate,
    /// Create a new branch by splitting the state machine at
    /// `RecoverableBranch::entry_node`.
    Copy,
    /// Merge the existing branch named by `RecoverableBranch::branch_id` into
    /// the given target branch (requires `branch_id` to be `Some`).
    MergeInto(BranchId),
    /// Create an independent branch by splitting the state machine at
    /// `RecoverableBranch::entry_node`.
    SplitIndependent,
}
/// A candidate branch that [`BranchResolver`] found and that may be recovered.
#[derive(Clone, Debug)]
pub struct RecoverableBranch {
    /// Id of an existing branch when the candidate is a tracked branch
    /// (e.g. an archived one); `None` when no branch owns these nodes.
    pub branch_id: Option<BranchId>,
    /// Node the candidate diverges from.
    pub fork_point: NodeId,
    /// First node of the candidate below `fork_point`.
    pub entry_node: NodeId,
    /// Nodes that make up the candidate.
    pub nodes: Vec<NodeId>,
    /// Tip node of the candidate.
    pub head: NodeId,
    /// Graph depth associated with the candidate.
    /// NOTE(review): producers differ on whether this is the depth of `head`
    /// or of `entry_node` — confirm the intended meaning.
    pub depth: u32,
    /// Why the candidate is not part of an active branch.
    pub lost_reason: LostReason,
    /// Heuristic ranking score; higher means a better recovery candidate.
    pub recovery_score: f32,
    /// Strategy [`BranchResolver::recover`] will apply.
    pub suggested_strategy: RecoveryStrategy,
}
/// Why a [`RecoverableBranch`] is not currently part of an active branch.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LostReason {
    /// A tracked branch whose status is `BranchStatus::Archived`.
    Archived,
    /// Graph nodes that no branch of the state machine owns.
    Untracked,
    /// Not constructed by `BranchResolver`.
    OrphanedByDeletion,
    /// A non-selected child of a fork point whose branch, if any, is inactive.
    UnselectedRegeneration,
    /// Not constructed by `BranchResolver`.
    Abandoned,
    /// Not constructed by `BranchResolver`.
    ExplorationDivergence,
}
/// Locates lost, archived, or unselected branches in a [`BranchStateMachine`]
/// and scores them for recovery.
///
/// The resolver only holds a shared borrow of the state machine; mutation is
/// done through the separate `&mut` argument of [`BranchResolver::recover`].
pub struct BranchResolver<'a> {
    /// State machine (and, through it, the trajectory graph) being inspected.
    machine: &'a BranchStateMachine,
}
impl<'a> BranchResolver<'a> {
    /// Creates a resolver that inspects `machine` through a shared borrow.
    pub fn new(machine: &'a BranchStateMachine) -> Self {
        Self { machine }
    }

    /// Finds branches that can be recovered, sorted by descending
    /// [`RecoverableBranch::recovery_score`].
    ///
    /// Two kinds of candidates are reported:
    /// * [`LostReason::Untracked`]: for every branch point returned by
    ///   `TrajectoryGraph::find_branch_points`, each child subtree that contains
    ///   nodes not owned by any branch of the state machine. Only the untracked
    ///   nodes of that subtree are listed in `nodes`.
    /// * [`LostReason::Archived`]: every branch whose status is
    ///   [`BranchStatus::Archived`], suggested for [`RecoveryStrategy::Reactivate`].
    ///
    /// NOTE(review): an untracked node below nested branch points is reported
    /// once per enclosing branch point, so untracked candidates may overlap.
    pub fn find_recoverable_branches(&self) -> Vec<RecoverableBranch> {
        let graph = self.machine.graph();
        // Nodes already owned by some branch, archived ones included. This keeps
        // archived nodes from also being reported as untracked.
        let tracked_nodes: HashSet<NodeId> = self
            .machine
            .all_branches()
            .flat_map(|b| b.nodes.iter().copied())
            .collect();

        let mut recoverable = Vec::new();

        for bp in graph.find_branch_points().iter() {
            let fork_point = bp.branch_point;
            if let Some(episode) = graph.get_node(fork_point) {
                for &child_id in &episode.children {
                    let untracked: Vec<NodeId> = self
                        .collect_subtree(graph, child_id)
                        .into_iter()
                        .filter(|n| !tracked_nodes.contains(n))
                        .collect();
                    if !untracked.is_empty() {
                        recoverable.push(self.create_recoverable_branch(
                            graph, fork_point, child_id, untracked,
                        ));
                    }
                }
            }
        }

        for branch in self.machine.all_branches() {
            if branch.status != BranchStatus::Archived {
                continue;
            }
            recoverable.push(RecoverableBranch {
                branch_id: Some(branch.id),
                fork_point: branch.fork_point,
                // An archived branch with no nodes is entered at its fork point.
                entry_node: branch.nodes.first().copied().unwrap_or(branch.fork_point),
                nodes: branch.nodes.clone(),
                head: branch.head,
                depth: self.compute_depth(graph, branch.head),
                lost_reason: LostReason::Archived,
                recovery_score: self.compute_recovery_score(graph, &branch.nodes),
                suggested_strategy: RecoveryStrategy::Reactivate,
            });
        }

        // Best candidates first; incomparable (NaN) scores are treated as equal.
        recoverable.sort_by(|a, b| {
            b.recovery_score
                .partial_cmp(&a.recovery_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        recoverable
    }

    /// Finds children of fork points that were not selected and are not part
    /// of an active branch. Each one is reported with its whole subtree as a
    /// [`LostReason::UnselectedRegeneration`] candidate.
    ///
    /// Children that are not reachable from their fork node in the graph are
    /// skipped.
    pub fn find_unselected_regenerations(&self) -> Vec<RecoverableBranch> {
        let graph = self.machine.graph();
        let mut unselected = Vec::new();

        for fork in self.machine.fork_points() {
            let selected = fork.selected_child;
            for &child_id in &fork.children {
                if Some(child_id) == selected {
                    continue;
                }
                // NOTE(review): the child's *node* id is used as a branch id here.
                // This only works if branches are keyed by their entry node — confirm.
                if self
                    .machine
                    .get_branch(child_id)
                    .map_or(false, |branch| branch.is_active())
                {
                    continue;
                }
                // Every node reachable from `child_id` is also reachable from the
                // fork, so the child's own subtree is exactly the part of the
                // fork's subtree that lies below this child. One walk replaces a
                // per-node descendant search.
                if !self.is_descendant_of(graph, child_id, fork.node_id) {
                    continue;
                }
                let child_nodes = self.collect_subtree(graph, child_id);
                let head = self.find_deepest_leaf(graph, &child_nodes);
                let recovery_score = self.compute_recovery_score(graph, &child_nodes);
                unselected.push(RecoverableBranch {
                    branch_id: None,
                    fork_point: fork.node_id,
                    entry_node: child_id,
                    nodes: child_nodes,
                    head,
                    // NOTE(review): depth of the entry node, whereas the other
                    // producers store the depth of `head`.
                    depth: fork.depth + 1,
                    lost_reason: LostReason::UnselectedRegeneration,
                    recovery_score,
                    suggested_strategy: RecoveryStrategy::SplitIndependent,
                });
            }
        }
        unselected
    }

    /// Applies `recoverable.suggested_strategy` to `machine`.
    ///
    /// Returns the id of the resulting branch: the reactivated branch, the
    /// merge target, or the branch created by `split`.
    ///
    /// NOTE(review): `machine` is taken mutably while `self` already holds a
    /// shared borrow of a state machine. The borrow checker rejects passing the
    /// same machine unless the resolver is dropped first.
    ///
    /// # Errors
    /// * [`BranchError::InvalidState`] when the strategy is `Reactivate` or
    ///   `MergeInto` but `recoverable.branch_id` is `None`.
    /// * [`BranchError::BranchNotFound`] when the branch to reactivate does not exist.
    /// * Any error propagated from the state machine's `merge` or `split`.
    pub fn recover(
        &self,
        machine: &mut BranchStateMachine,
        recoverable: &RecoverableBranch,
    ) -> Result<BranchId, BranchError> {
        match &recoverable.suggested_strategy {
            RecoveryStrategy::Reactivate => match recoverable.branch_id {
                Some(branch_id) => self.reactivate_branch(machine, branch_id),
                None => Err(BranchError::InvalidState("No branch ID for reactivation".to_string())),
            },
            RecoveryStrategy::Copy => self.copy_as_new_branch(machine, recoverable),
            RecoveryStrategy::MergeInto(target) => match recoverable.branch_id {
                Some(branch_id) => {
                    machine.merge(branch_id, *target)?;
                    Ok(*target)
                }
                None => Err(BranchError::InvalidState("No branch ID for merge".to_string())),
            },
            RecoveryStrategy::SplitIndependent => {
                self.create_independent_branch(machine, recoverable)
            }
        }
    }

    /// Checks that `branch_id` exists in `machine` and returns it.
    ///
    /// TODO(review): the branch status is not changed, so an archived branch
    /// stays archived. Hook this up to the state machine's status transition
    /// once that API is confirmed.
    fn reactivate_branch(
        &self,
        machine: &mut BranchStateMachine,
        branch_id: BranchId,
    ) -> Result<BranchId, BranchError> {
        machine
            .get_branch(branch_id)
            .ok_or(BranchError::BranchNotFound(branch_id))?;
        Ok(branch_id)
    }

    /// Implements [`RecoveryStrategy::Copy`].
    ///
    /// This is currently identical to [`RecoveryStrategy::SplitIndependent`]:
    /// both split the state machine at the candidate's entry node.
    fn copy_as_new_branch(
        &self,
        machine: &mut BranchStateMachine,
        recoverable: &RecoverableBranch,
    ) -> Result<BranchId, BranchError> {
        self.create_independent_branch(machine, recoverable)
    }

    /// Splits the state machine at `recoverable.entry_node` and returns the id
    /// of the newly created branch.
    fn create_independent_branch(
        &self,
        machine: &mut BranchStateMachine,
        recoverable: &RecoverableBranch,
    ) -> Result<BranchId, BranchError> {
        let result = machine.split(recoverable.entry_node)?;
        Ok(result.new_branch)
    }

    /// Builds a [`LostReason::Untracked`] candidate from `nodes`.
    ///
    /// The head is the deepest leaf among `nodes`, the depth is that head's
    /// depth, and the suggested strategy is [`RecoveryStrategy::SplitIndependent`].
    fn create_recoverable_branch(
        &self,
        graph: &TrajectoryGraph,
        fork_point: NodeId,
        entry_node: NodeId,
        nodes: Vec<NodeId>,
    ) -> RecoverableBranch {
        let head = self.find_deepest_leaf(graph, &nodes);
        let depth = self.compute_depth(graph, head);
        let recovery_score = self.compute_recovery_score(graph, &nodes);
        RecoverableBranch {
            branch_id: None,
            fork_point,
            entry_node,
            nodes,
            head,
            depth,
            lost_reason: LostReason::Untracked,
            recovery_score,
            suggested_strategy: RecoveryStrategy::SplitIndependent,
        }
    }

    /// Returns `root` followed by every node reachable from it through
    /// `children` links, each exactly once, in depth-first order.
    ///
    /// Children are pushed on a stack, so later children are visited before
    /// earlier ones. `root` is included even if the graph has no such node.
    fn collect_subtree(&self, graph: &TrajectoryGraph, root: NodeId) -> Vec<NodeId> {
        let mut nodes = Vec::new();
        let mut stack = vec![root];
        let mut visited = HashSet::new();
        while let Some(node_id) = stack.pop() {
            // `insert` returns false when the node was already visited.
            if !visited.insert(node_id) {
                continue;
            }
            nodes.push(node_id);
            if let Some(episode) = graph.get_node(node_id) {
                stack.extend(episode.children.iter().copied());
            }
        }
        nodes
    }

    /// Depth of `node_id` as reported by the graph, or 0 when it is unknown.
    fn compute_depth(&self, graph: &TrajectoryGraph, node_id: NodeId) -> u32 {
        graph.depth(node_id).unwrap_or(0) as u32
    }

    /// Returns the leaf (per `is_leaf`) among `nodes` with the greatest depth.
    ///
    /// On ties the last such node in `nodes` order wins. When no node is a
    /// leaf, the first node is returned; when `nodes` is empty, node id 0.
    fn find_deepest_leaf(&self, graph: &TrajectoryGraph, nodes: &[NodeId]) -> NodeId {
        nodes
            .iter()
            .filter(|&&n| graph.get_node(n).map_or(false, |e| e.is_leaf()))
            .max_by_key(|&&n| self.compute_depth(graph, n))
            .copied()
            .unwrap_or_else(|| nodes.first().copied().unwrap_or(0))
    }

    /// Returns true if `node` is `ancestor` itself or is reachable from
    /// `ancestor` by following `children` links (breadth-first search).
    fn is_descendant_of(&self, graph: &TrajectoryGraph, node: NodeId, ancestor: NodeId) -> bool {
        if node == ancestor {
            return true;
        }
        let mut queue = VecDeque::new();
        queue.push_back(ancestor);
        let mut visited = HashSet::new();
        while let Some(current) = queue.pop_front() {
            if current == node {
                return true;
            }
            // `insert` returns false when the node was already expanded, which
            // also protects against cycles.
            if !visited.insert(current) {
                continue;
            }
            if let Some(episode) = graph.get_node(current) {
                queue.extend(episode.children.iter().copied());
            }
        }
        false
    }

    /// Heuristic recovery score for a set of nodes:
    /// `0.3·ln(1+len) + 0.3·sqrt(max depth) + 0.2·mean(ln(1+content_length))
    ///  + 0.2·(number of nodes with a thumbs-up)`.
    ///
    /// Nodes missing from the graph add nothing to the content or feedback
    /// terms but still count toward the length and the content-mean
    /// denominator. An empty slice scores 0.
    fn compute_recovery_score(&self, graph: &TrajectoryGraph, nodes: &[NodeId]) -> f32 {
        let length_factor = (nodes.len() as f32).ln_1p();
        let max_depth = nodes
            .iter()
            .map(|&n| self.compute_depth(graph, n))
            .max()
            .unwrap_or(0);
        let depth_factor = (max_depth as f32).sqrt();
        let content_factor: f32 = nodes
            .iter()
            .filter_map(|&n| graph.get_node(n))
            .map(|e| (e.content_length as f32).ln_1p())
            .sum::<f32>()
            / nodes.len().max(1) as f32;
        // Raw count, so this term is not bounded like the logarithmic ones.
        let feedback_factor: f32 = nodes
            .iter()
            .filter_map(|&n| graph.get_node(n))
            .filter(|e| e.has_thumbs_up)
            .count() as f32;
        0.3 * length_factor + 0.3 * depth_factor + 0.2 * content_factor + 0.2 * feedback_factor
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trajectory::graph::{Edge, EdgeType};

    // Builds the graph:
    //
    //   1 ─┬─ 2 ─┬─ 3
    //      │     └─ 4
    //      └─ 5
    fn make_branching_graph() -> TrajectoryGraph {
        let edges = vec![
            Edge { parent: 1, child: 2, edge_type: EdgeType::Continuation },
            Edge { parent: 2, child: 3, edge_type: EdgeType::Regeneration },
            Edge { parent: 2, child: 4, edge_type: EdgeType::Regeneration },
            Edge { parent: 1, child: 5, edge_type: EdgeType::Branch },
        ];
        TrajectoryGraph::from_edges(edges.into_iter())
    }

    #[test]
    fn test_resolver_creation() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);
        let recoverable = resolver.find_recoverable_branches();
        // Results are sorted best-first and carry finite scores.
        assert!(recoverable
            .windows(2)
            .all(|w| w[0].recovery_score >= w[1].recovery_score));
        assert!(recoverable.iter().all(|b| b.recovery_score.is_finite()));
    }

    #[test]
    fn test_recovery_score() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);
        let score = resolver.compute_recovery_score(&graph, &[1, 2, 3]);
        // The length term alone is positive for a non-empty slice.
        assert!(score > 0.0);
        assert_eq!(resolver.compute_recovery_score(&graph, &[]), 0.0);
    }

    #[test]
    fn test_collect_subtree() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);
        let subtree = resolver.collect_subtree(&graph, 2);
        assert_eq!(subtree.first(), Some(&2));
        assert!(subtree.contains(&3));
        assert!(subtree.contains(&4));
        assert!(!subtree.contains(&1));
        assert!(!subtree.contains(&5));
        assert_eq!(subtree.len(), 3);
    }

    #[test]
    fn test_is_descendant() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph.clone());
        let resolver = BranchResolver::new(&machine);
        assert!(resolver.is_descendant_of(&graph, 3, 1));
        assert!(resolver.is_descendant_of(&graph, 3, 2));
        assert!(resolver.is_descendant_of(&graph, 2, 2));
        assert!(!resolver.is_descendant_of(&graph, 1, 3));
        assert!(!resolver.is_descendant_of(&graph, 5, 2));
    }

    #[test]
    fn test_recovery_strategy() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);
        for branch in resolver.find_recoverable_branches() {
            match branch.lost_reason {
                LostReason::Archived => {
                    assert!(branch.branch_id.is_some());
                    assert_eq!(branch.suggested_strategy, RecoveryStrategy::Reactivate);
                }
                _ => {
                    assert!(branch.branch_id.is_none());
                    assert_eq!(branch.suggested_strategy, RecoveryStrategy::SplitIndependent);
                }
            }
        }
    }

    #[test]
    fn test_unselected_regenerations_are_well_formed() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        let resolver = BranchResolver::new(&machine);
        for branch in resolver.find_unselected_regenerations() {
            assert_eq!(branch.lost_reason, LostReason::UnselectedRegeneration);
            assert!(branch.branch_id.is_none());
            assert!(branch.nodes.contains(&branch.entry_node));
            assert!(branch.nodes.contains(&branch.head));
        }
    }
}