use std::collections::{HashMap, HashSet, VecDeque};
/// Identifier used to refer to a node (episode) of a [`TrajectoryGraph`].
pub type NodeId = u64;
/// A directed parent → child link between two nodes, used as input to
/// `TrajectoryGraph::from_edges`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Edge {
/// Identifier of the node the edge starts from.
pub parent: NodeId,
/// Identifier of the node the edge points to.
pub child: NodeId,
/// Kind of relationship this edge represents.
///
/// NOTE(review): `TrajectoryGraph::from_edges` does not read this field.
pub edge_type: EdgeType,
}
/// Classification attached to an [`Edge`] or a [`BranchInfo`].
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names only; confirm them against the code that produces the edges.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EdgeType {
/// Default variant; presumably a plain follow-up of the parent node.
#[default]
Continuation,
/// Presumably an alternative re-generation of the same step. This is the
/// type the graph assigns to every branch point it reports.
Regeneration,
/// Presumably an explicit fork into a separate line of episodes.
Branch,
}
/// A single node of a trajectory graph: its tree links plus the metadata that
/// path-selection policies compare when choosing between siblings.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Identifier of this episode.
    pub id: NodeId,
    /// Identifier of the parent episode; `None` marks a root.
    pub parent: Option<NodeId>,
    /// Identifiers of the child episodes, in insertion order.
    pub children: Vec<NodeId>,
    /// Weight of the episode; [`Episode::new`] initialises it to `1.0`.
    pub weight: f32,
    /// Whether the episode carries positive feedback.
    pub has_thumbs_up: bool,
    /// Whether the episode carries negative feedback.
    pub has_thumbs_down: bool,
    /// Length of the episode's content (unit is not visible in this file).
    pub content_length: usize,
    /// Whether the episode is flagged as erroneous. Not consulted by the
    /// code in this file.
    pub has_error: bool,
    /// Creation time; smaller values are treated as earlier (epoch/unit are
    /// not visible in this file).
    pub created_at: i64,
}

impl Episode {
    /// Creates a detached episode: no parent, no children, weight `1.0`, and
    /// every flag/counter zeroed.
    pub fn new(id: NodeId) -> Self {
        Episode {
            id,
            // Links start empty; the caller wires the node into a graph.
            parent: None,
            children: Vec::new(),
            // Neutral metadata.
            weight: 1.0,
            content_length: 0,
            created_at: 0,
            has_thumbs_up: false,
            has_thumbs_down: false,
            has_error: false,
        }
    }

    /// Returns `true` when the episode has two or more children.
    #[inline]
    pub fn is_branch_point(&self) -> bool {
        self.children.len() >= 2
    }

    /// Returns `true` when the episode has no children.
    #[inline]
    pub fn is_leaf(&self) -> bool {
        let child_ids = &self.children;
        child_ids.is_empty()
    }

    /// Returns `true` when the episode has no parent.
    #[inline]
    pub fn is_root(&self) -> bool {
        let parent_id = self.parent;
        parent_id.is_none()
    }
}
/// Description of a node that has more than one child.
#[derive(Debug, Clone)]
pub struct BranchInfo {
/// Identifier of the node at which the graph branches.
pub branch_point: NodeId,
/// Copy of the branch point's child identifiers.
pub children: Vec<NodeId>,
/// Classification assigned to this branch point.
pub branch_type: EdgeType,
/// Index into `children` of the child chosen by a path search; `None`
/// when the branch point was merely enumerated and nothing was chosen.
pub selected_child_idx: Option<usize>,
}
/// Rule used to pick one child when a path search has several to choose from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PathSelectionPolicy {
/// Default. Prefers a child with a thumbs-up, then one without a
/// thumbs-down, then the larger `content_length`, then the smaller
/// `created_at`.
#[default]
FeedbackFirst,
/// Picks the child with the smallest `created_at`.
FirstByTime,
/// Picks the child with the largest `content_length`.
LongestContent,
/// Picks the child with the largest `weight`.
HighestWeight,
}
/// Order in which `TrajectoryGraph::traverse` hands nodes to its visitor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TraversalOrder {
/// Default. Starts at the roots and follows each child chain to its end
/// before moving on to the next sibling.
#[default]
DepthFirst,
/// Starts at the roots and visits nodes level by level.
BreadthFirst,
/// Visits a node only after every node that lists it as a child.
Topological,
/// The `Topological` order, reversed.
ReverseTopological,
}
/// Outcome of a primary-path search through a `TrajectoryGraph`.
#[derive(Debug, Clone)]
pub struct PathResult {
/// Identifiers of the nodes on the path, starting at the root it began from.
pub nodes: Vec<NodeId>,
/// Branch points encountered along the path, each with the chosen child.
pub branch_points: Vec<BranchInfo>,
/// Sum of the `weight` of every node on the path.
pub total_weight: f32,
}
/// A graph of [`Episode`]s linked through their `parent` / `children` fields.
#[derive(Debug, Clone)]
pub struct TrajectoryGraph {
// Every episode, keyed by its id.
nodes: HashMap<NodeId, Episode>,
// Cached ids of the nodes that have no parent.
roots: Vec<NodeId>,
// Cached ids of the nodes that have no children.
leaves: Vec<NodeId>,
}
impl TrajectoryGraph {
pub fn new() -> Self {
Self {
nodes: HashMap::new(),
roots: Vec::new(),
leaves: Vec::new(),
}
}
pub fn from_edges(edges: impl IntoIterator<Item = Edge>) -> Self {
let mut graph = Self::new();
for edge in edges {
graph.nodes.entry(edge.parent).or_insert_with(|| Episode::new(edge.parent));
graph.nodes.entry(edge.child).or_insert_with(|| Episode::new(edge.child));
if let Some(parent) = graph.nodes.get_mut(&edge.parent) {
if !parent.children.contains(&edge.child) {
parent.children.push(edge.child);
}
}
if let Some(child) = graph.nodes.get_mut(&edge.child) {
child.parent = Some(edge.parent);
}
}
graph.update_roots_and_leaves();
graph
}
pub fn add_node(&mut self, node: Episode) {
self.nodes.insert(node.id, node);
}
#[inline]
pub fn get_node(&self, id: NodeId) -> Option<&Episode> {
self.nodes.get(&id)
}
#[inline]
pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Episode> {
self.nodes.get_mut(&id)
}
#[inline]
pub fn node_count(&self) -> usize {
self.nodes.len()
}
#[inline]
pub fn roots(&self) -> &[NodeId] {
&self.roots
}
#[inline]
pub fn leaves(&self) -> &[NodeId] {
&self.leaves
}
#[inline]
pub fn is_branch_point(&self, id: NodeId) -> bool {
self.nodes.get(&id).map_or(false, |n| n.is_branch_point())
}
pub fn find_branch_points(&self) -> Vec<BranchInfo> {
self.nodes
.values()
.filter(|n| n.is_branch_point())
.map(|n| BranchInfo {
branch_point: n.id,
children: n.children.clone(),
branch_type: if n.children.len() > 1 {
EdgeType::Regeneration
} else {
EdgeType::Continuation
},
selected_child_idx: None,
})
.collect()
}
fn update_roots_and_leaves(&mut self) {
self.roots = self.nodes.values()
.filter(|n| n.is_root())
.map(|n| n.id)
.collect();
self.leaves = self.nodes.values()
.filter(|n| n.is_leaf())
.map(|n| n.id)
.collect();
}
pub fn traverse<F>(&self, order: TraversalOrder, mut visitor: F)
where
F: FnMut(&Episode),
{
match order {
TraversalOrder::DepthFirst => self.traverse_dfs(&mut visitor),
TraversalOrder::BreadthFirst => self.traverse_bfs(&mut visitor),
TraversalOrder::Topological => self.traverse_topological(&mut visitor),
TraversalOrder::ReverseTopological => self.traverse_reverse_topological(&mut visitor),
}
}
fn traverse_dfs<F>(&self, visitor: &mut F)
where
F: FnMut(&Episode),
{
let mut visited = HashSet::new();
let mut stack: Vec<NodeId> = self.roots.clone();
while let Some(id) = stack.pop() {
if visited.contains(&id) {
continue;
}
visited.insert(id);
if let Some(node) = self.nodes.get(&id) {
visitor(node);
for &child_id in node.children.iter().rev() {
if !visited.contains(&child_id) {
stack.push(child_id);
}
}
}
}
}
fn traverse_bfs<F>(&self, visitor: &mut F)
where
F: FnMut(&Episode),
{
let mut visited = HashSet::new();
let mut queue: VecDeque<NodeId> = self.roots.iter().copied().collect();
while let Some(id) = queue.pop_front() {
if visited.contains(&id) {
continue;
}
visited.insert(id);
if let Some(node) = self.nodes.get(&id) {
visitor(node);
for &child_id in &node.children {
if !visited.contains(&child_id) {
queue.push_back(child_id);
}
}
}
}
}
fn traverse_topological<F>(&self, visitor: &mut F)
where
F: FnMut(&Episode),
{
let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
for node in self.nodes.values() {
in_degree.entry(node.id).or_insert(0);
for &child in &node.children {
*in_degree.entry(child).or_insert(0) += 1;
}
}
let mut queue: VecDeque<NodeId> = in_degree
.iter()
.filter(|(_, °)| deg == 0)
.map(|(&id, _)| id)
.collect();
while let Some(id) = queue.pop_front() {
if let Some(node) = self.nodes.get(&id) {
visitor(node);
for &child in &node.children {
if let Some(deg) = in_degree.get_mut(&child) {
*deg -= 1;
if *deg == 0 {
queue.push_back(child);
}
}
}
}
}
}
fn traverse_reverse_topological<F>(&self, visitor: &mut F)
where
F: FnMut(&Episode),
{
let mut order = Vec::with_capacity(self.nodes.len());
self.traverse_topological(&mut |node| order.push(node.id));
for id in order.into_iter().rev() {
if let Some(node) = self.nodes.get(&id) {
visitor(node);
}
}
}
pub fn find_primary_path(&self, policy: PathSelectionPolicy) -> Option<PathResult> {
if self.roots.is_empty() {
return None;
}
let start = self.roots[0];
let mut path = Vec::new();
let mut branch_points = Vec::new();
let mut total_weight = 0.0;
let mut current = start;
loop {
let node = self.nodes.get(¤t)?;
path.push(current);
total_weight += node.weight;
if node.children.is_empty() {
break;
}
let (next_idx, next) = self.select_child(node, policy)?;
if node.is_branch_point() {
branch_points.push(BranchInfo {
branch_point: current,
children: node.children.clone(),
branch_type: EdgeType::Regeneration,
selected_child_idx: Some(next_idx),
});
}
current = next;
}
Some(PathResult {
nodes: path,
branch_points,
total_weight,
})
}
fn select_child(&self, parent: &Episode, policy: PathSelectionPolicy) -> Option<(usize, NodeId)> {
if parent.children.is_empty() {
return None;
}
let children: Vec<&Episode> = parent.children
.iter()
.filter_map(|&id| self.nodes.get(&id))
.collect();
if children.is_empty() {
return Some((0, parent.children[0]));
}
let selected_idx = match policy {
PathSelectionPolicy::FeedbackFirst => {
children.iter().enumerate()
.max_by(|(_, a), (_, b)| {
match (a.has_thumbs_up, b.has_thumbs_up) {
(true, false) => return std::cmp::Ordering::Greater,
(false, true) => return std::cmp::Ordering::Less,
_ => {}
}
match (a.has_thumbs_down, b.has_thumbs_down) {
(false, true) => return std::cmp::Ordering::Greater,
(true, false) => return std::cmp::Ordering::Less,
_ => {}
}
match a.content_length.cmp(&b.content_length) {
std::cmp::Ordering::Equal => {}
other => return other,
}
a.created_at.cmp(&b.created_at).reverse()
})
.map(|(i, _)| i)
.unwrap_or(0)
}
PathSelectionPolicy::FirstByTime => {
children.iter().enumerate()
.min_by_key(|(_, n)| n.created_at)
.map(|(i, _)| i)
.unwrap_or(0)
}
PathSelectionPolicy::LongestContent => {
children.iter().enumerate()
.max_by_key(|(_, n)| n.content_length)
.map(|(i, _)| i)
.unwrap_or(0)
}
PathSelectionPolicy::HighestWeight => {
children.iter().enumerate()
.max_by(|(_, a), (_, b)| a.weight.partial_cmp(&b.weight).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0)
}
};
Some((selected_idx, parent.children[selected_idx]))
}
pub fn find_all_paths_from(&self, start: NodeId) -> Vec<Vec<NodeId>> {
let mut paths = Vec::new();
let mut current_path = vec![start];
self.find_paths_recursive(start, &mut current_path, &mut paths);
paths
}
fn find_paths_recursive(
&self,
current: NodeId,
path: &mut Vec<NodeId>,
paths: &mut Vec<Vec<NodeId>>,
) {
if let Some(node) = self.nodes.get(¤t) {
if node.is_leaf() {
paths.push(path.clone());
} else {
for &child in &node.children {
path.push(child);
self.find_paths_recursive(child, path, paths);
path.pop();
}
}
}
}
pub fn find_path_to(&self, target: NodeId) -> Option<Vec<NodeId>> {
let mut path = Vec::new();
let mut current = target;
loop {
path.push(current);
match self.nodes.get(¤t)?.parent {
Some(parent) => current = parent,
None => break,
}
}
path.reverse();
Some(path)
}
pub fn depth(&self, node: NodeId) -> Option<usize> {
self.find_path_to(node).map(|p| p.len() - 1)
}
pub fn lowest_common_ancestor(&self, a: NodeId, b: NodeId) -> Option<NodeId> {
let path_a = self.find_path_to(a)?;
let path_b = self.find_path_to(b)?;
let path_a_set: HashSet<_> = path_a.iter().copied().collect();
for &node in path_b.iter().rev() {
if path_a_set.contains(&node) {
return Some(node);
}
}
None
}
}
impl Default for TrajectoryGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an edge of the given type.
    fn edge(parent: NodeId, child: NodeId, edge_type: EdgeType) -> Edge {
        Edge { parent, child, edge_type }
    }

    /// 1 → 2 → 3 → 4
    fn make_linear_graph() -> TrajectoryGraph {
        TrajectoryGraph::from_edges(vec![
            edge(1, 2, EdgeType::Continuation),
            edge(2, 3, EdgeType::Continuation),
            edge(3, 4, EdgeType::Continuation),
        ])
    }

    /// 1 → {2, 5}, 2 → {3, 4}
    fn make_branching_graph() -> TrajectoryGraph {
        TrajectoryGraph::from_edges(vec![
            edge(1, 2, EdgeType::Continuation),
            edge(2, 3, EdgeType::Regeneration),
            edge(2, 4, EdgeType::Regeneration),
            edge(1, 5, EdgeType::Branch),
        ])
    }

    /// Collects the ids visited by `traverse` in the given order.
    fn visit_order(graph: &TrajectoryGraph, order: TraversalOrder) -> Vec<NodeId> {
        let mut visited = Vec::new();
        graph.traverse(order, |node| visited.push(node.id));
        visited
    }

    #[test]
    fn test_linear_graph() {
        let graph = make_linear_graph();
        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.roots(), &[1]);
        assert_eq!(graph.leaves(), &[4]);
        assert!(!graph.is_branch_point(1));
    }

    #[test]
    fn test_branching_graph() {
        let graph = make_branching_graph();
        assert_eq!(graph.node_count(), 5);
        assert!(graph.is_branch_point(1));
        assert!(graph.is_branch_point(2));
        let branches = graph.find_branch_points();
        assert_eq!(branches.len(), 2);
        // Compare as a sorted list so the check does not depend on the
        // order in which leaves are reported.
        let mut leaves = graph.leaves().to_vec();
        leaves.sort_unstable();
        assert_eq!(leaves, vec![3, 4, 5]);
    }

    #[test]
    fn test_find_path_to() {
        let graph = make_linear_graph();
        let path = graph.find_path_to(4).unwrap();
        assert_eq!(path, vec![1, 2, 3, 4]);
        assert_eq!(graph.find_path_to(99), None);
    }

    #[test]
    fn test_primary_path() {
        let graph = make_linear_graph();
        let result = graph.find_primary_path(PathSelectionPolicy::FirstByTime).unwrap();
        assert_eq!(result.nodes, vec![1, 2, 3, 4]);
        assert!(result.branch_points.is_empty());
    }

    #[test]
    fn test_primary_path_records_branch_points() {
        let graph = make_branching_graph();
        // All timestamps are equal, so `FirstByTime` keeps the first child.
        let result = graph.find_primary_path(PathSelectionPolicy::FirstByTime).unwrap();
        assert_eq!(result.nodes, vec![1, 2, 3]);
        let chosen: Vec<(NodeId, Option<usize>)> = result
            .branch_points
            .iter()
            .map(|b| (b.branch_point, b.selected_child_idx))
            .collect();
        assert_eq!(chosen, vec![(1, Some(0)), (2, Some(0))]);
        assert_eq!(result.total_weight, 3.0);
    }

    #[test]
    fn test_dfs_traversal() {
        let graph = make_linear_graph();
        assert_eq!(visit_order(&graph, TraversalOrder::DepthFirst), vec![1, 2, 3, 4]);
        let branching = make_branching_graph();
        assert_eq!(
            visit_order(&branching, TraversalOrder::DepthFirst),
            vec![1, 2, 3, 4, 5]
        );
    }

    #[test]
    fn test_bfs_traversal() {
        let graph = make_branching_graph();
        assert_eq!(
            visit_order(&graph, TraversalOrder::BreadthFirst),
            vec![1, 2, 5, 3, 4]
        );
    }

    #[test]
    fn test_topological_traversals() {
        let graph = make_branching_graph();
        assert_eq!(
            visit_order(&graph, TraversalOrder::Topological),
            vec![1, 2, 5, 3, 4]
        );
        assert_eq!(
            visit_order(&graph, TraversalOrder::ReverseTopological),
            vec![4, 3, 5, 2, 1]
        );
    }

    #[test]
    fn test_find_all_paths_from() {
        let graph = make_branching_graph();
        assert_eq!(
            graph.find_all_paths_from(1),
            vec![vec![1, 2, 3], vec![1, 2, 4], vec![1, 5]]
        );
        assert!(graph.find_all_paths_from(99).is_empty());
    }

    #[test]
    fn test_lowest_common_ancestor() {
        let graph = make_branching_graph();
        assert_eq!(graph.lowest_common_ancestor(3, 4), Some(2));
        assert_eq!(graph.lowest_common_ancestor(3, 5), Some(1));
        assert_eq!(graph.lowest_common_ancestor(4, 4), Some(4));
        assert_eq!(graph.lowest_common_ancestor(3, 99), None);
    }

    #[test]
    fn test_depth() {
        let graph = make_linear_graph();
        assert_eq!(graph.depth(1), Some(0));
        assert_eq!(graph.depth(4), Some(3));
    }
}