1use std::collections::{HashMap, HashSet, VecDeque};
14
/// Identifier of an [`Episode`] within a [`TrajectoryGraph`].
pub type NodeId = u64;
17
/// A directed parent → child link between two episodes, as consumed by
/// [`TrajectoryGraph::from_edges`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Edge {
    /// Id of the episode the link starts from.
    pub parent: NodeId,
    /// Id of the episode the link points to.
    pub child: NodeId,
    /// How the child relates to the parent.
    ///
    /// NOTE(review): `TrajectoryGraph::from_edges` does not read this field,
    /// so it is currently dropped when the graph is built.
    pub edge_type: EdgeType,
}
25
/// Kind of relationship an [`Edge`] expresses between its parent and child.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names only — confirm against the producer of these edges.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum EdgeType {
    /// Presumably the child simply continues the parent. This is the default.
    #[default]
    Continuation,
    /// Presumably the child is a regenerated alternative to a sibling.
    Regeneration,
    /// Presumably the child starts a new branch diverging from the parent.
    Branch,
}
37
/// One node of a [`TrajectoryGraph`], plus the metadata path-selection
/// policies use to choose between alternative children.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Id of this episode; also the key it is stored under in the graph.
    pub id: NodeId,
    /// Parent episode, or `None` for a root.
    pub parent: Option<NodeId>,
    /// Child episodes in insertion order; two or more make this a branch point.
    pub children: Vec<NodeId>,
    /// Weight summed along a path and compared by `HighestWeight`; `Episode::new` sets `1.0`.
    pub weight: f32,
    /// Whether the episode received positive feedback (preferred by `FeedbackFirst`).
    pub has_thumbs_up: bool,
    /// Whether the episode received negative feedback (avoided by `FeedbackFirst`).
    pub has_thumbs_down: bool,
    /// Length of the episode's content. NOTE(review): unit (bytes/chars/tokens) not established here.
    pub content_length: usize,
    /// Whether the episode is flagged as errored. NOTE(review): no policy here reads it.
    pub has_error: bool,
    /// Creation timestamp; smaller values sort earlier. NOTE(review): unit/epoch not established here.
    pub created_at: i64,
}
55
56impl Episode {
57 pub fn new(id: NodeId) -> Self {
58 Self {
59 id,
60 parent: None,
61 children: Vec::new(),
62 weight: 1.0,
63 has_thumbs_up: false,
64 has_thumbs_down: false,
65 content_length: 0,
66 has_error: false,
67 created_at: 0,
68 }
69 }
70
71 #[inline]
73 pub fn is_branch_point(&self) -> bool {
74 self.children.len() > 1
75 }
76
77 #[inline]
79 pub fn is_leaf(&self) -> bool {
80 self.children.is_empty()
81 }
82
83 #[inline]
85 pub fn is_root(&self) -> bool {
86 self.parent.is_none()
87 }
88}
89
90#[derive(Debug, Clone)]
92pub struct BranchInfo {
93 pub branch_point: NodeId,
95 pub children: Vec<NodeId>,
97 pub branch_type: EdgeType,
99 pub selected_child_idx: Option<usize>,
101}
102
/// Rule used to pick which child to follow when building a primary path.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PathSelectionPolicy {
    /// Prefer a thumbs-up, then the absence of a thumbs-down, then longer
    /// content, then the earlier `created_at`. This is the default.
    #[default]
    FeedbackFirst,
    /// Pick the child with the smallest `created_at`.
    FirstByTime,
    /// Pick the child with the largest `content_length`.
    LongestContent,
    /// Pick the child with the largest `weight`.
    HighestWeight,
}
116
/// Order in which `TrajectoryGraph::traverse` visits episodes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TraversalOrder {
    /// Pre-order depth-first walk from the roots, children in list order.
    /// This is the default.
    #[default]
    DepthFirst,
    /// Level-by-level walk from the roots.
    BreadthFirst,
    /// Every parent before its children (in-degree based, Kahn-style).
    /// Episodes on or below a cycle are not visited.
    Topological,
    /// The exact reverse of `Topological`: children before their parents.
    ReverseTopological,
}
130
/// A single path through the graph, as produced by
/// `TrajectoryGraph::find_primary_path`.
#[derive(Debug, Clone)]
pub struct PathResult {
    /// Node ids from the starting root to the last node of the path.
    pub nodes: Vec<NodeId>,
    /// Branch points passed along the path, each with the chosen child index.
    pub branch_points: Vec<BranchInfo>,
    /// Sum of `Episode::weight` over every node in `nodes`.
    pub total_weight: f32,
}
141
/// A graph of [`Episode`]s linked by parent/child ids, with cached lists of
/// root and leaf ids.
#[derive(Debug, Clone)]
pub struct TrajectoryGraph {
    /// All episodes, keyed by their id.
    nodes: HashMap<NodeId, Episode>,
    /// Cached ids of episodes that have no parent.
    roots: Vec<NodeId>,
    /// Cached ids of episodes that have no children.
    leaves: Vec<NodeId>,
}
153
154impl TrajectoryGraph {
155 pub fn new() -> Self {
157 Self {
158 nodes: HashMap::new(),
159 roots: Vec::new(),
160 leaves: Vec::new(),
161 }
162 }
163
164 pub fn from_edges(edges: impl IntoIterator<Item = Edge>) -> Self {
186 let mut graph = Self::new();
187
188 for edge in edges {
189 graph.nodes.entry(edge.parent).or_insert_with(|| Episode::new(edge.parent));
191 graph.nodes.entry(edge.child).or_insert_with(|| Episode::new(edge.child));
192
193 if let Some(parent) = graph.nodes.get_mut(&edge.parent) {
195 if !parent.children.contains(&edge.child) {
196 parent.children.push(edge.child);
197 }
198 }
199 if let Some(child) = graph.nodes.get_mut(&edge.child) {
200 child.parent = Some(edge.parent);
201 }
202 }
203
204 graph.update_roots_and_leaves();
205 graph
206 }
207
208 pub fn add_node(&mut self, node: Episode) {
210 self.nodes.insert(node.id, node);
211 }
212
213 #[inline]
215 pub fn get_node(&self, id: NodeId) -> Option<&Episode> {
216 self.nodes.get(&id)
217 }
218
219 #[inline]
221 pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Episode> {
222 self.nodes.get_mut(&id)
223 }
224
225 #[inline]
227 pub fn node_count(&self) -> usize {
228 self.nodes.len()
229 }
230
231 #[inline]
233 pub fn roots(&self) -> &[NodeId] {
234 &self.roots
235 }
236
237 #[inline]
239 pub fn leaves(&self) -> &[NodeId] {
240 &self.leaves
241 }
242
243 #[inline]
245 pub fn is_branch_point(&self, id: NodeId) -> bool {
246 self.nodes.get(&id).map_or(false, |n| n.is_branch_point())
247 }
248
249 pub fn find_branch_points(&self) -> Vec<BranchInfo> {
251 self.nodes
252 .values()
253 .filter(|n| n.is_branch_point())
254 .map(|n| BranchInfo {
255 branch_point: n.id,
256 children: n.children.clone(),
257 branch_type: if n.children.len() > 1 {
258 EdgeType::Regeneration
259 } else {
260 EdgeType::Continuation
261 },
262 selected_child_idx: None,
263 })
264 .collect()
265 }
266
267 fn update_roots_and_leaves(&mut self) {
269 self.roots = self.nodes.values()
270 .filter(|n| n.is_root())
271 .map(|n| n.id)
272 .collect();
273
274 self.leaves = self.nodes.values()
275 .filter(|n| n.is_leaf())
276 .map(|n| n.id)
277 .collect();
278 }
279
280 pub fn traverse<F>(&self, order: TraversalOrder, mut visitor: F)
291 where
292 F: FnMut(&Episode),
293 {
294 match order {
295 TraversalOrder::DepthFirst => self.traverse_dfs(&mut visitor),
296 TraversalOrder::BreadthFirst => self.traverse_bfs(&mut visitor),
297 TraversalOrder::Topological => self.traverse_topological(&mut visitor),
298 TraversalOrder::ReverseTopological => self.traverse_reverse_topological(&mut visitor),
299 }
300 }
301
302 fn traverse_dfs<F>(&self, visitor: &mut F)
303 where
304 F: FnMut(&Episode),
305 {
306 let mut visited = HashSet::new();
307 let mut stack: Vec<NodeId> = self.roots.clone();
308
309 while let Some(id) = stack.pop() {
310 if visited.contains(&id) {
311 continue;
312 }
313 visited.insert(id);
314
315 if let Some(node) = self.nodes.get(&id) {
316 visitor(node);
317 for &child_id in node.children.iter().rev() {
319 if !visited.contains(&child_id) {
320 stack.push(child_id);
321 }
322 }
323 }
324 }
325 }
326
327 fn traverse_bfs<F>(&self, visitor: &mut F)
328 where
329 F: FnMut(&Episode),
330 {
331 let mut visited = HashSet::new();
332 let mut queue: VecDeque<NodeId> = self.roots.iter().copied().collect();
333
334 while let Some(id) = queue.pop_front() {
335 if visited.contains(&id) {
336 continue;
337 }
338 visited.insert(id);
339
340 if let Some(node) = self.nodes.get(&id) {
341 visitor(node);
342 for &child_id in &node.children {
343 if !visited.contains(&child_id) {
344 queue.push_back(child_id);
345 }
346 }
347 }
348 }
349 }
350
351 fn traverse_topological<F>(&self, visitor: &mut F)
352 where
353 F: FnMut(&Episode),
354 {
355 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
357 for node in self.nodes.values() {
358 in_degree.entry(node.id).or_insert(0);
359 for &child in &node.children {
360 *in_degree.entry(child).or_insert(0) += 1;
361 }
362 }
363
364 let mut queue: VecDeque<NodeId> = in_degree
365 .iter()
366 .filter(|(_, °)| deg == 0)
367 .map(|(&id, _)| id)
368 .collect();
369
370 while let Some(id) = queue.pop_front() {
371 if let Some(node) = self.nodes.get(&id) {
372 visitor(node);
373 for &child in &node.children {
374 if let Some(deg) = in_degree.get_mut(&child) {
375 *deg -= 1;
376 if *deg == 0 {
377 queue.push_back(child);
378 }
379 }
380 }
381 }
382 }
383 }
384
385 fn traverse_reverse_topological<F>(&self, visitor: &mut F)
386 where
387 F: FnMut(&Episode),
388 {
389 let mut order = Vec::with_capacity(self.nodes.len());
390 self.traverse_topological(&mut |node| order.push(node.id));
391
392 for id in order.into_iter().rev() {
393 if let Some(node) = self.nodes.get(&id) {
394 visitor(node);
395 }
396 }
397 }
398
399 pub fn find_primary_path(&self, policy: PathSelectionPolicy) -> Option<PathResult> {
416 if self.roots.is_empty() {
417 return None;
418 }
419
420 let start = self.roots[0];
422 let mut path = Vec::new();
423 let mut branch_points = Vec::new();
424 let mut total_weight = 0.0;
425 let mut current = start;
426
427 loop {
428 let node = self.nodes.get(¤t)?;
429 path.push(current);
430 total_weight += node.weight;
431
432 if node.children.is_empty() {
433 break;
434 }
435
436 let (next_idx, next) = self.select_child(node, policy)?;
438
439 if node.is_branch_point() {
440 branch_points.push(BranchInfo {
441 branch_point: current,
442 children: node.children.clone(),
443 branch_type: EdgeType::Regeneration,
444 selected_child_idx: Some(next_idx),
445 });
446 }
447
448 current = next;
449 }
450
451 Some(PathResult {
452 nodes: path,
453 branch_points,
454 total_weight,
455 })
456 }
457
458 fn select_child(&self, parent: &Episode, policy: PathSelectionPolicy) -> Option<(usize, NodeId)> {
460 if parent.children.is_empty() {
461 return None;
462 }
463
464 let children: Vec<&Episode> = parent.children
465 .iter()
466 .filter_map(|&id| self.nodes.get(&id))
467 .collect();
468
469 if children.is_empty() {
470 return Some((0, parent.children[0]));
471 }
472
473 let selected_idx = match policy {
474 PathSelectionPolicy::FeedbackFirst => {
475 children.iter().enumerate()
477 .max_by(|(_, a), (_, b)| {
478 match (a.has_thumbs_up, b.has_thumbs_up) {
480 (true, false) => return std::cmp::Ordering::Greater,
481 (false, true) => return std::cmp::Ordering::Less,
482 _ => {}
483 }
484 match (a.has_thumbs_down, b.has_thumbs_down) {
486 (false, true) => return std::cmp::Ordering::Greater,
487 (true, false) => return std::cmp::Ordering::Less,
488 _ => {}
489 }
490 match a.content_length.cmp(&b.content_length) {
492 std::cmp::Ordering::Equal => {}
493 other => return other,
494 }
495 a.created_at.cmp(&b.created_at).reverse()
497 })
498 .map(|(i, _)| i)
499 .unwrap_or(0)
500 }
501 PathSelectionPolicy::FirstByTime => {
502 children.iter().enumerate()
503 .min_by_key(|(_, n)| n.created_at)
504 .map(|(i, _)| i)
505 .unwrap_or(0)
506 }
507 PathSelectionPolicy::LongestContent => {
508 children.iter().enumerate()
509 .max_by_key(|(_, n)| n.content_length)
510 .map(|(i, _)| i)
511 .unwrap_or(0)
512 }
513 PathSelectionPolicy::HighestWeight => {
514 children.iter().enumerate()
515 .max_by(|(_, a), (_, b)| a.weight.partial_cmp(&b.weight).unwrap_or(std::cmp::Ordering::Equal))
516 .map(|(i, _)| i)
517 .unwrap_or(0)
518 }
519 };
520
521 Some((selected_idx, parent.children[selected_idx]))
522 }
523
524 pub fn find_all_paths_from(&self, start: NodeId) -> Vec<Vec<NodeId>> {
530 let mut paths = Vec::new();
531 let mut current_path = vec![start];
532 self.find_paths_recursive(start, &mut current_path, &mut paths);
533 paths
534 }
535
536 fn find_paths_recursive(
537 &self,
538 current: NodeId,
539 path: &mut Vec<NodeId>,
540 paths: &mut Vec<Vec<NodeId>>,
541 ) {
542 if let Some(node) = self.nodes.get(¤t) {
543 if node.is_leaf() {
544 paths.push(path.clone());
545 } else {
546 for &child in &node.children {
547 path.push(child);
548 self.find_paths_recursive(child, path, paths);
549 path.pop();
550 }
551 }
552 }
553 }
554
555 pub fn find_path_to(&self, target: NodeId) -> Option<Vec<NodeId>> {
557 let mut path = Vec::new();
558 let mut current = target;
559
560 loop {
561 path.push(current);
562 match self.nodes.get(¤t)?.parent {
563 Some(parent) => current = parent,
564 None => break,
565 }
566 }
567
568 path.reverse();
569 Some(path)
570 }
571
572 pub fn depth(&self, node: NodeId) -> Option<usize> {
574 self.find_path_to(node).map(|p| p.len() - 1)
575 }
576
577 pub fn lowest_common_ancestor(&self, a: NodeId, b: NodeId) -> Option<NodeId> {
579 let path_a = self.find_path_to(a)?;
580 let path_b = self.find_path_to(b)?;
581
582 let path_a_set: HashSet<_> = path_a.iter().copied().collect();
583
584 for &node in path_b.iter().rev() {
586 if path_a_set.contains(&node) {
587 return Some(node);
588 }
589 }
590
591 None
592 }
593}
594
595impl Default for TrajectoryGraph {
596 fn default() -> Self {
597 Self::new()
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    // 1 -> 2 -> 3 -> 4
    fn make_linear_graph() -> TrajectoryGraph {
        let edges = vec![
            Edge { parent: 1, child: 2, edge_type: EdgeType::Continuation },
            Edge { parent: 2, child: 3, edge_type: EdgeType::Continuation },
            Edge { parent: 3, child: 4, edge_type: EdgeType::Continuation },
        ];
        TrajectoryGraph::from_edges(edges)
    }

    // 1 -> 2 -> {3, 4}   (3 and 4 are regenerations of each other)
    // 1 -> 5             (branch)
    fn make_branching_graph() -> TrajectoryGraph {
        let edges = vec![
            Edge { parent: 1, child: 2, edge_type: EdgeType::Continuation },
            Edge { parent: 2, child: 3, edge_type: EdgeType::Regeneration },
            Edge { parent: 2, child: 4, edge_type: EdgeType::Regeneration },
            Edge { parent: 1, child: 5, edge_type: EdgeType::Branch },
        ];
        TrajectoryGraph::from_edges(edges)
    }

    /// Collects node ids in the order `traverse` visits them.
    fn visit_order(graph: &TrajectoryGraph, order: TraversalOrder) -> Vec<NodeId> {
        let mut visited = Vec::new();
        graph.traverse(order, |node| visited.push(node.id));
        visited
    }

    #[test]
    fn test_linear_graph() {
        let graph = make_linear_graph();
        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.roots(), &[1]);
        assert_eq!(graph.leaves(), &[4]);
        assert!(!graph.is_branch_point(1));
    }

    #[test]
    fn test_branching_graph() {
        let graph = make_branching_graph();
        assert_eq!(graph.node_count(), 5);
        assert!(graph.is_branch_point(1));
        assert!(graph.is_branch_point(2));
        assert!(!graph.is_branch_point(99));

        let branches = graph.find_branch_points();
        assert_eq!(branches.len(), 2);

        let mut leaves = graph.leaves().to_vec();
        leaves.sort_unstable();
        assert_eq!(leaves, vec![3, 4, 5]);
    }

    #[test]
    fn test_find_path_to() {
        let graph = make_linear_graph();
        let path = graph.find_path_to(4).unwrap();
        assert_eq!(path, vec![1, 2, 3, 4]);
        assert_eq!(graph.find_path_to(99), None);
    }

    #[test]
    fn test_primary_path() {
        let graph = make_linear_graph();
        let result = graph.find_primary_path(PathSelectionPolicy::FirstByTime).unwrap();
        assert_eq!(result.nodes, vec![1, 2, 3, 4]);
        assert!(result.branch_points.is_empty());
    }

    #[test]
    fn test_primary_path_records_branch_points() {
        let graph = make_branching_graph();
        let result = graph.find_primary_path(PathSelectionPolicy::FirstByTime).unwrap();
        assert_eq!(result.nodes, vec![1, 2, 3]);
        assert_eq!(result.total_weight, 3.0);
        let selected: Vec<_> = result
            .branch_points
            .iter()
            .map(|b| (b.branch_point, b.selected_child_idx))
            .collect();
        assert_eq!(selected, vec![(1, Some(0)), (2, Some(0))]);
    }

    #[test]
    fn test_feedback_first_policy() {
        let mut graph = make_branching_graph();
        graph.get_node_mut(2).unwrap().content_length = 10;
        graph.get_node_mut(4).unwrap().has_thumbs_up = true;
        let result = graph.find_primary_path(PathSelectionPolicy::FeedbackFirst).unwrap();
        assert_eq!(result.nodes, vec![1, 2, 4]);
    }

    #[test]
    fn test_weight_and_content_policies() {
        let mut graph = make_branching_graph();
        graph.get_node_mut(5).unwrap().weight = 2.0;
        let by_weight = graph.find_primary_path(PathSelectionPolicy::HighestWeight).unwrap();
        assert_eq!(by_weight.nodes, vec![1, 5]);

        graph.get_node_mut(2).unwrap().content_length = 5;
        graph.get_node_mut(3).unwrap().content_length = 7;
        let by_length = graph.find_primary_path(PathSelectionPolicy::LongestContent).unwrap();
        assert_eq!(by_length.nodes, vec![1, 2, 3]);
    }

    #[test]
    fn test_dfs_traversal() {
        let graph = make_linear_graph();
        assert_eq!(visit_order(&graph, TraversalOrder::DepthFirst), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_bfs_traversal() {
        let graph = make_branching_graph();
        assert_eq!(visit_order(&graph, TraversalOrder::BreadthFirst), vec![1, 2, 5, 3, 4]);
    }

    #[test]
    fn test_topological_traversals() {
        let graph = make_linear_graph();
        assert_eq!(visit_order(&graph, TraversalOrder::Topological), vec![1, 2, 3, 4]);
        assert_eq!(visit_order(&graph, TraversalOrder::ReverseTopological), vec![4, 3, 2, 1]);
    }

    #[test]
    fn test_find_all_paths_from() {
        let graph = make_branching_graph();
        assert_eq!(
            graph.find_all_paths_from(1),
            vec![vec![1, 2, 3], vec![1, 2, 4], vec![1, 5]]
        );
        assert_eq!(graph.find_all_paths_from(5), vec![vec![5]]);
        assert!(graph.find_all_paths_from(99).is_empty());
    }

    #[test]
    fn test_lowest_common_ancestor() {
        let graph = make_branching_graph();
        assert_eq!(graph.lowest_common_ancestor(3, 4), Some(2));
        assert_eq!(graph.lowest_common_ancestor(3, 5), Some(1));
        assert_eq!(graph.lowest_common_ancestor(2, 4), Some(2));
        assert_eq!(graph.lowest_common_ancestor(3, 99), None);
    }

    #[test]
    fn test_depth() {
        let graph = make_linear_graph();
        assert_eq!(graph.depth(1), Some(0));
        assert_eq!(graph.depth(4), Some(3));
    }
}
679}