use std::collections::{HashMap, HashSet};
use std::time::{SystemTime, UNIX_EPOCH};
use crate::trajectory::graph::{NodeId, TrajectoryGraph};
use super::operations::{Branch, BranchId, BranchOperation, BranchError, ForkPoint};
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time earlier than the epoch, `0` is returned
/// instead of propagating the error.
#[inline]
fn current_timestamp() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Snapshot of where the state machine is currently positioned.
///
/// Rebuilt from the selected branch's head by `initialize_from_graph` and
/// `traverse`; `from_graph` starts it with all-zero / empty placeholder values.
#[derive(Debug, Clone)]
pub struct BranchContext {
/// Id of the branch the machine is positioned on.
pub current_branch: BranchId,
/// Head node of `current_branch` at the time this context was built.
pub current_node: NodeId,
/// Graph depth of `current_node`; `0` when the graph reports no depth for it.
pub depth: u32,
/// Node path returned by the graph's `find_path_to(current_node)`; empty when none is found.
pub path: Vec<NodeId>,
}
/// Outcome of a successful `BranchStateMachine::split`.
#[derive(Debug, Clone)]
pub struct SplitResult {
/// Branch that owned the split node before the split.
pub original_branch: BranchId,
/// Newly allocated branch that now owns the moved subtree.
pub new_branch: BranchId,
/// Node passed to `split`; the first node of `new_branch`.
pub split_point: NodeId,
/// Every node of the subtree rooted at `split_point`, in the order they were collected.
pub moved_nodes: Vec<NodeId>,
}
/// Outcome of a successful `BranchStateMachine::merge`.
#[derive(Debug, Clone)]
pub struct MergeResult {
/// Branch that received the merged nodes.
pub target_branch: BranchId,
/// Branch that was folded into `target_branch` and marked merged.
pub merged_branch: BranchId,
/// Fork point of `merged_branch`, reported as the point of the merge.
pub merge_point: NodeId,
}
/// Tracks the branches of a `TrajectoryGraph` and the operations applied to them.
///
/// Branches are derived from the graph when the machine is built (`from_graph`)
/// and then reshaped by `split`, `merge`, `traverse` and `archive`, each of which
/// appends a `BranchOperation` to the history.
#[derive(Debug, Clone)]
pub struct BranchStateMachine {
/// All known branches, keyed by id.
branches: HashMap<BranchId, Branch>,
/// Branch selected by the most recent `traverse` (initially `0`, the root branch).
current_branch: BranchId,
/// Operations in the order they were applied.
history: Vec<BranchOperation>,
/// Nodes with more than one child, keyed by node id.
fork_points: HashMap<NodeId, ForkPoint>,
/// Next id to hand out; starts at `1` because id `0` is reserved for the root branch.
next_branch_id: BranchId,
/// Branch that currently owns each node.
node_to_branch: HashMap<NodeId, BranchId>,
/// Underlying graph the branches are derived from.
graph: TrajectoryGraph,
/// Cached position information for `current_branch`.
context: BranchContext,
}
impl BranchStateMachine {
/// Builds a state machine over `graph`, deriving branches from its structure.
///
/// Branch `0` is reserved for the root branch; further branches are numbered
/// from `1` upward as forks are discovered. If the graph has no roots the
/// machine starts with no branches and a placeholder context.
pub fn from_graph(graph: TrajectoryGraph) -> Self {
    let placeholder_context = BranchContext {
        current_branch: 0,
        current_node: 0,
        depth: 0,
        path: Vec::new(),
    };
    let mut machine = Self {
        graph,
        context: placeholder_context,
        branches: HashMap::new(),
        node_to_branch: HashMap::new(),
        fork_points: HashMap::new(),
        history: Vec::new(),
        current_branch: 0,
        next_branch_id: 1,
    };
    machine.initialize_from_graph();
    machine
}
fn initialize_from_graph(&mut self) {
let roots: Vec<NodeId> = self.graph.roots().to_vec();
if roots.is_empty() {
return;
}
let root_node = roots[0];
let root_branch = Branch::root(0, root_node);
self.branches.insert(0, root_branch);
self.node_to_branch.insert(root_node, 0);
self.build_branches_from_root(root_node, 0, 0);
if let Some(head) = self.branches.get(&0).map(|b| b.head) {
self.context = BranchContext {
current_branch: 0,
current_node: head,
depth: self.compute_depth(head),
path: self.compute_path(head),
};
}
}
/// Walks the graph below `node_id` and materializes branches for it.
///
/// `branch_id` is the branch that already owns `node_id`, and `depth` is the depth
/// recorded for `node_id` (each child is one deeper). A node with exactly one child
/// extends the owning branch with that child. A node with several children is
/// recorded as a `ForkPoint`, and each child starts a fresh branch (ids allocated
/// consecutively, in child order) whose parent is the owning branch.
///
/// Uses an explicit work stack instead of recursion so long linear trajectories
/// cannot overflow the call stack. Children are expanded depth-first, first child
/// first — the same order a recursive walk uses — so branch ids are assigned
/// identically. A child that is already owned by some branch is skipped, which
/// keeps cyclic or shared-child input from looping forever or being re-owned.
fn build_branches_from_root(&mut self, node_id: NodeId, branch_id: BranchId, depth: u32) {
    let mut pending: Vec<(NodeId, BranchId, u32)> = vec![(node_id, branch_id, depth)];
    while let Some((node, owner, node_depth)) = pending.pop() {
        let children: Vec<NodeId> = match self.graph.get_node(node) {
            Some(episode) => episode
                .children
                .iter()
                .copied()
                .filter(|child| !self.node_to_branch.contains_key(child))
                .collect(),
            None => continue,
        };
        match children.len() {
            0 => {}
            1 => {
                // Single continuation: stays on the owning branch.
                let child = children[0];
                if let Some(branch) = self.branches.get_mut(&owner) {
                    branch.add_node(child);
                }
                self.node_to_branch.insert(child, owner);
                pending.push((child, owner, node_depth + 1));
            }
            _ => {
                // Fork: reserve one branch id per child up front, in child order.
                let child_branch_ids: Vec<BranchId> = children
                    .iter()
                    .map(|_| {
                        let bid = self.next_branch_id;
                        self.next_branch_id += 1;
                        bid
                    })
                    .collect();
                let fork = ForkPoint::new(node, child_branch_ids.clone(), node_depth);
                self.fork_points.insert(node, fork);
                for (&child, &child_branch_id) in children.iter().zip(&child_branch_ids) {
                    let mut child_branch = Branch::new(child_branch_id, node, child);
                    child_branch.parent_branch = Some(owner);
                    if let Some(parent) = self.branches.get_mut(&owner) {
                        parent.child_branches.push(child_branch_id);
                    }
                    self.branches.insert(child_branch_id, child_branch);
                    self.node_to_branch.insert(child, child_branch_id);
                }
                // Push in reverse so the first child is popped (and fully expanded) first.
                for (&child, &child_branch_id) in children.iter().zip(&child_branch_ids).rev() {
                    pending.push((child, child_branch_id, node_depth + 1));
                }
            }
        }
    }
}
pub fn split(&mut self, node_id: NodeId) -> Result<SplitResult, BranchError> {
if self.graph.get_node(node_id).is_none() {
return Err(BranchError::NodeNotFound(node_id));
}
if self.graph.roots().contains(&node_id) {
return Err(BranchError::CannotSplitRoot);
}
let source_branch_id = *self.node_to_branch.get(&node_id)
.ok_or(BranchError::NodeNotFound(node_id))?;
let new_branch_id = self.next_branch_id;
self.next_branch_id += 1;
let fork_point = self.graph.get_node(node_id)
.and_then(|e| e.parent)
.ok_or(BranchError::NoParent(source_branch_id))?;
let moved_nodes = self.collect_subtree(node_id);
let mut new_branch = Branch::new(new_branch_id, fork_point, node_id);
new_branch.parent_branch = Some(source_branch_id);
new_branch.nodes = moved_nodes.clone();
if let Some(&head) = moved_nodes.iter()
.filter(|&&n| self.graph.get_node(n).map_or(false, |e| e.is_leaf()))
.max_by_key(|&&n| self.compute_depth(n))
{
new_branch.head = head;
}
for &n in &moved_nodes {
self.node_to_branch.insert(n, new_branch_id);
}
if let Some(source) = self.branches.get_mut(&source_branch_id) {
source.nodes.retain(|n| !moved_nodes.contains(n));
source.child_branches.push(new_branch_id);
}
self.branches.insert(new_branch_id, new_branch);
let operation = BranchOperation::split(source_branch_id, new_branch_id, node_id);
self.history.push(operation);
Ok(SplitResult {
original_branch: source_branch_id,
new_branch: new_branch_id,
split_point: node_id,
moved_nodes,
})
}
/// Folds `from_branch` into `into_branch`.
///
/// The source's nodes are appended to the target's node list and re-owned by the
/// target; the target's `updated_at` is refreshed and the source is marked merged.
/// The source's fork point is reported as the merge point.
///
/// # Errors (checked in this order)
/// - `BranchError::SelfMerge` if both ids are equal.
/// - `BranchError::BranchNotFound(from_branch)`, then `BranchNotFound(into_branch)`.
/// - `BranchError::AlreadyMerged(from_branch)` if the source was merged before.
pub fn merge(&mut self, from_branch: BranchId, into_branch: BranchId) -> Result<MergeResult, BranchError> {
    if from_branch == into_branch {
        return Err(BranchError::SelfMerge);
    }
    let source = self
        .branches
        .get(&from_branch)
        .ok_or(BranchError::BranchNotFound(from_branch))?;
    if !self.branches.contains_key(&into_branch) {
        return Err(BranchError::BranchNotFound(into_branch));
    }
    if source.is_merged() {
        return Err(BranchError::AlreadyMerged(from_branch));
    }
    let merge_point = source.fork_point;
    let moved_nodes = source.nodes.clone();

    self.node_to_branch
        .extend(moved_nodes.iter().map(|&node| (node, into_branch)));
    if let Some(target) = self.branches.get_mut(&into_branch) {
        target.nodes.extend(moved_nodes);
        target.updated_at = current_timestamp();
    }
    if let Some(merged) = self.branches.get_mut(&from_branch) {
        merged.mark_merged();
    }

    self.history
        .push(BranchOperation::merge(from_branch, into_branch, merge_point));
    Ok(MergeResult {
        target_branch: into_branch,
        merged_branch: from_branch,
        merge_point,
    })
}
/// Makes `target_branch` the current branch and records a traverse operation.
///
/// The context is rebuilt around the target branch's head, with its depth and
/// path recomputed from the graph.
///
/// # Errors
/// `BranchError::BranchNotFound` if `target_branch` does not exist; nothing changes then.
pub fn traverse(&mut self, target_branch: BranchId) -> Result<(), BranchError> {
    let head = match self.branches.get(&target_branch) {
        Some(branch) => branch.head,
        None => return Err(BranchError::BranchNotFound(target_branch)),
    };
    let previous = self.current_branch;
    self.current_branch = target_branch;

    let depth = self.compute_depth(head);
    let path = self.compute_path(head);
    self.context = BranchContext {
        current_branch: target_branch,
        current_node: head,
        depth,
        path,
    };

    self.history
        .push(BranchOperation::traverse(previous, target_branch));
    Ok(())
}
/// Archives `branch_id` and logs the operation together with an optional `reason`.
///
/// # Errors
/// `BranchError::BranchNotFound` if no such branch exists; nothing is recorded then.
pub fn archive(&mut self, branch_id: BranchId, reason: Option<String>) -> Result<(), BranchError> {
    if let Some(branch) = self.branches.get_mut(&branch_id) {
        branch.archive();
    } else {
        return Err(BranchError::BranchNotFound(branch_id));
    }
    self.history.push(BranchOperation::archive(branch_id, reason));
    Ok(())
}
/// Branch selected by the most recent `traverse` (the root branch initially), if present.
pub fn current(&self) -> Option<&Branch> {
    self.get_branch(self.current_branch)
}
/// Looks up a branch by id.
pub fn get_branch(&self, branch_id: BranchId) -> Option<&Branch> {
    self.branches.get(&branch_id)
}
/// Branches whose `is_active()` is true, in unspecified (hash map) order.
pub fn active_branches(&self) -> impl Iterator<Item = &Branch> {
    self.all_branches().filter(|branch| branch.is_active())
}
/// Every known branch regardless of status, in unspecified (hash map) order.
pub fn all_branches(&self) -> impl Iterator<Item = &Branch> {
    self.branches.values()
}
/// Number of known branches, including archived and merged ones.
pub fn branch_count(&self) -> usize {
    self.branches.len()
}
/// Fork points discovered while building branches, in unspecified order.
pub fn fork_points(&self) -> impl Iterator<Item = &ForkPoint> {
    self.fork_points.values()
}
/// Operations applied so far, oldest first.
pub fn history(&self) -> &[BranchOperation] {
    self.history.as_slice()
}
/// Cached position information for the current branch.
pub fn context(&self) -> &BranchContext {
    &self.context
}
/// Branch that currently owns `node_id`, if any.
pub fn find_branch_for_node(&self, node_id: NodeId) -> Option<BranchId> {
    let owner = self.node_to_branch.get(&node_id)?;
    Some(*owner)
}
/// Ids of the branches recorded as children of `branch_id`; empty for an unknown id.
pub fn child_branches(&self, branch_id: BranchId) -> Vec<BranchId> {
    match self.get_branch(branch_id) {
        Some(branch) => branch.child_branches.clone(),
        None => Vec::new(),
    }
}
/// Returns `root` plus every node reachable from it through `children` links.
///
/// Depth-first with an explicit stack: each node is emitted once, when first popped,
/// so the last-listed child of a node is visited first. A node the graph does not
/// know is still emitted but not expanded.
fn collect_subtree(&self, root: NodeId) -> Vec<NodeId> {
    let mut seen: HashSet<NodeId> = HashSet::new();
    let mut order = Vec::new();
    let mut frontier = vec![root];
    while let Some(current) = frontier.pop() {
        // `insert` returns false when the node was already seen.
        if !seen.insert(current) {
            continue;
        }
        order.push(current);
        if let Some(episode) = self.graph.get_node(current) {
            frontier.extend(episode.children.iter().copied());
        }
    }
    order
}
/// Depth the graph reports for `node_id`, or `0` when it reports none.
fn compute_depth(&self, node_id: NodeId) -> u32 {
    self.graph.depth(node_id).map_or(0, |d| d as u32)
}
/// Path the graph's `find_path_to` returns for `node_id`, or an empty path when none.
fn compute_path(&self, node_id: NodeId) -> Vec<NodeId> {
    let found = self.graph.find_path_to(node_id);
    found.unwrap_or_default()
}
/// Read-only access to the underlying trajectory graph.
pub fn graph(&self) -> &TrajectoryGraph {
    &self.graph
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::operations::BranchStatus;
    use crate::trajectory::graph::{Edge, EdgeType};

    // Shape of the fixture:
    //   1 -> 2 -> {3, 4}
    //   1 -> 5
    // so nodes 1 and 2 are both forks.
    fn make_branching_graph() -> TrajectoryGraph {
        let edges = vec![
            Edge { parent: 1, child: 2, edge_type: EdgeType::Continuation },
            Edge { parent: 2, child: 3, edge_type: EdgeType::Regeneration },
            Edge { parent: 2, child: 4, edge_type: EdgeType::Regeneration },
            Edge { parent: 1, child: 5, edge_type: EdgeType::Branch },
        ];
        TrajectoryGraph::from_edges(edges.into_iter())
    }

    #[test]
    fn test_state_machine_creation() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        assert!(machine.branch_count() > 0);
        assert_eq!(machine.fork_points().count(), 2);
        assert_eq!(machine.context().current_branch, 0);
    }

    #[test]
    fn test_split_operation() {
        let graph = make_branching_graph();
        let mut machine = BranchStateMachine::from_graph(graph);
        let initial_count = machine.branch_count();
        let source = machine
            .find_branch_for_node(3)
            .expect("node 3 is owned by a branch");
        let result = machine
            .split(3)
            .expect("node 3 is a non-root node with a parent");
        assert_eq!(result.original_branch, source);
        assert_eq!(result.split_point, 3);
        assert_eq!(result.moved_nodes, vec![3]);
        assert_eq!(machine.branch_count(), initial_count + 1);
        assert!(machine.get_branch(result.new_branch).is_some());
        assert_eq!(machine.find_branch_for_node(3), Some(result.new_branch));
        assert!(machine.child_branches(source).contains(&result.new_branch));
    }

    #[test]
    fn test_split_unknown_node() {
        let mut machine = BranchStateMachine::from_graph(make_branching_graph());
        let count = machine.branch_count();
        assert!(matches!(machine.split(999), Err(BranchError::NodeNotFound(999))));
        assert_eq!(machine.branch_count(), count);
        assert!(machine.history().is_empty());
    }

    #[test]
    fn test_traverse() {
        let graph = make_branching_graph();
        let mut machine = BranchStateMachine::from_graph(graph);
        let original_branch = machine.current_branch;
        for branch in machine.all_branches().map(|b| b.id).collect::<Vec<_>>() {
            if branch != original_branch {
                assert!(machine.traverse(branch).is_ok());
                assert_eq!(machine.current_branch, branch);
                assert_eq!(machine.context().current_branch, branch);
                break;
            }
        }
    }

    #[test]
    fn test_traverse_missing_branch() {
        let mut machine = BranchStateMachine::from_graph(make_branching_graph());
        assert!(matches!(machine.traverse(999), Err(BranchError::BranchNotFound(999))));
        assert!(machine.history().is_empty());
    }

    #[test]
    fn test_merge() {
        let mut machine = BranchStateMachine::from_graph(make_branching_graph());
        let from = machine.find_branch_for_node(3).expect("node 3 is owned");
        let into = machine.find_branch_for_node(2).expect("node 2 is owned");
        assert!(matches!(machine.merge(from, from), Err(BranchError::SelfMerge)));
        let result = machine.merge(from, into).expect("both branches exist");
        assert_eq!(result.target_branch, into);
        assert_eq!(result.merged_branch, from);
        assert!(machine.get_branch(from).unwrap().is_merged());
        assert!(matches!(machine.merge(from, into), Err(BranchError::AlreadyMerged(_))));
    }

    #[test]
    fn test_archive() {
        let graph = make_branching_graph();
        let mut machine = BranchStateMachine::from_graph(graph);
        let branch_id = machine.current_branch;
        assert!(machine.archive(branch_id, Some("test".to_string())).is_ok());
        let branch = machine.get_branch(branch_id).unwrap();
        assert_eq!(branch.status, BranchStatus::Archived);
    }

    #[test]
    fn test_cannot_split_root() {
        let graph = make_branching_graph();
        let root = *graph.roots().first().unwrap();
        let mut machine = BranchStateMachine::from_graph(graph);
        let result = machine.split(root);
        assert!(matches!(result, Err(BranchError::CannotSplitRoot)));
    }

    #[test]
    fn test_operation_history() {
        let graph = make_branching_graph();
        let mut machine = BranchStateMachine::from_graph(graph);
        machine.archive(0, None).expect("root branch exists");
        assert_eq!(machine.history().len(), 1);
    }

    #[test]
    fn test_find_branch_for_node() {
        let graph = make_branching_graph();
        let machine = BranchStateMachine::from_graph(graph);
        for node_id in [1, 2, 3, 4, 5] {
            assert!(machine.graph().get_node(node_id).is_some());
            let branch = machine
                .find_branch_for_node(node_id)
                .expect("every node reachable from the root has a branch");
            assert!(machine.get_branch(branch).is_some());
        }
        assert_eq!(machine.find_branch_for_node(1), Some(0));
        assert_eq!(machine.find_branch_for_node(999), None);
    }
}