use std::collections::VecDeque;
use thiserror::Error;
/// Index of a node inside an [`InferenceGraph`]'s `nodes` vector.
pub type NodeId = usize;

/// Per-node latency annotation; `None` means the cost was never measured.
#[derive(Debug, Clone, Default)]
pub struct NodeLatency {
    /// Measured latency in nanoseconds, if known.
    pub latency_ns: Option<u64>,
}

impl NodeLatency {
    /// Creates a latency record with a known measurement.
    pub fn new(latency_ns: u64) -> Self {
        NodeLatency {
            latency_ns: Some(latency_ns),
        }
    }
}
/// A directed graph of inference stages with per-node latency costs.
///
/// Nodes are addressed by insertion index; edges are `(from, to)` pairs.
/// Both fields are public, so consumers can also build graphs directly.
#[derive(Debug, Clone, Default)]
pub struct InferenceGraph {
    /// Latency annotation for each node, indexed by [`NodeId`].
    pub nodes: Vec<NodeLatency>,
    /// Directed edges as `(from, to)` node-id pairs.
    pub edges: Vec<(NodeId, NodeId)>,
}

impl InferenceGraph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a node and returns its freshly assigned id.
    pub fn add_node(&mut self, latency: NodeLatency) -> NodeId {
        self.nodes.push(latency);
        self.nodes.len() - 1
    }

    /// Adds a directed edge `from -> to`.
    ///
    /// # Errors
    ///
    /// Returns [`CriticalPathError::InvalidNode`] for the first endpoint
    /// (checking `from` before `to`) that is not a known node id.
    pub fn add_edge(&mut self, from: NodeId, to: NodeId) -> Result<(), CriticalPathError> {
        let bound = self.nodes.len();
        for endpoint in [from, to] {
            if endpoint >= bound {
                return Err(CriticalPathError::InvalidNode(endpoint));
            }
        }
        self.edges.push((from, to));
        Ok(())
    }

    /// Number of nodes in the graph.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of edges in the graph.
    pub fn num_edges(&self) -> usize {
        self.edges.len()
    }
}
/// The result of a critical-path computation over an [`InferenceGraph`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CriticalPathReport {
    /// Node ids along the critical path, in execution (topological) order.
    pub nodes: Vec<NodeId>,
    /// Sum of the per-node latencies along the path, in nanoseconds.
    pub total_latency_ns: u64,
    /// The single most expensive node on the path.
    pub bottleneck: NodeId,
}

/// Emitted when a node has no recorded latency and a default cost was used.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MissingCostWarning {
    /// Id of the node whose `latency_ns` was `None`.
    pub node_id: NodeId,
}

/// A [`CriticalPathReport`] bundled with any warnings produced while
/// computing it.
#[derive(Debug, Clone)]
pub struct CriticalPathResult {
    /// The computed critical path.
    pub report: CriticalPathReport,
    /// One warning per node that was missing a latency measurement.
    pub warnings: Vec<MissingCostWarning>,
}
/// Errors produced while building or analyzing an [`InferenceGraph`].
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum CriticalPathError {
    /// The graph contains at least one cycle; the payload is a
    /// comma-separated list of the node ids that could not be ordered.
    #[error("Cycle detected; these nodes were not reachable via topological sort: {0}")]
    CycleDetected(String),
    /// An edge endpoint does not correspond to any existing node.
    #[error("Edge references out-of-range node id {0}")]
    InvalidNode(NodeId),
}
/// Computes the critical (maximum total latency) path through `graph`.
///
/// Nodes with no recorded latency are assigned a default cost of 1 ns and
/// reported through [`MissingCostWarning`]s (in ascending node-id order).
/// The report contains the path's node ids in execution order, the summed
/// latency along that path (saturating on `u64` overflow), and the single
/// costliest node on the path. An empty graph yields an empty path with
/// `total_latency_ns == 0` and `bottleneck == 0`.
///
/// # Errors
///
/// * [`CriticalPathError::InvalidNode`] if an edge references an
///   out-of-range node id (possible because `edges` is a public field and
///   may bypass `add_edge`'s validation).
/// * [`CriticalPathError::CycleDetected`] if the graph is not a DAG; the
///   message lists every node that could not be topologically ordered.
pub fn critical_path(graph: &InferenceGraph) -> Result<CriticalPathResult, CriticalPathError> {
    let n = graph.num_nodes();
    // Trivial case: nothing to schedule.
    if n == 0 {
        return Ok(CriticalPathResult {
            report: CriticalPathReport {
                nodes: vec![],
                total_latency_ns: 0,
                bottleneck: 0,
            },
            warnings: vec![],
        });
    }

    // Build successor/predecessor adjacency lists, re-validating each edge
    // because `edges` can be mutated without going through `add_edge`.
    let mut succ: Vec<Vec<NodeId>> = vec![vec![]; n];
    let mut pred: Vec<Vec<NodeId>> = vec![vec![]; n];
    let mut in_degree: Vec<usize> = vec![0; n];
    for &(from, to) in &graph.edges {
        if from >= n {
            return Err(CriticalPathError::InvalidNode(from));
        }
        if to >= n {
            return Err(CriticalPathError::InvalidNode(to));
        }
        succ[from].push(to);
        pred[to].push(from);
        in_degree[to] += 1;
    }

    // Resolve per-node costs, defaulting missing latencies to 1 ns and
    // recording one warning per defaulted node.
    let mut warnings: Vec<MissingCostWarning> = vec![];
    let costs: Vec<u64> = graph
        .nodes
        .iter()
        .enumerate()
        .map(|(id, nl)| {
            nl.latency_ns.unwrap_or_else(|| {
                warnings.push(MissingCostWarning { node_id: id });
                1
            })
        })
        .collect();

    // Kahn's algorithm: repeatedly pop nodes whose predecessors are all
    // processed. Afterwards, `remaining_in[v] > 0` holds exactly for nodes
    // on or downstream of a cycle.
    let mut remaining_in = in_degree;
    let mut queue: VecDeque<NodeId> = (0..n).filter(|&v| remaining_in[v] == 0).collect();
    let mut topo_order: Vec<NodeId> = Vec::with_capacity(n);
    while let Some(u) = queue.pop_front() {
        topo_order.push(u);
        for &v in &succ[u] {
            remaining_in[v] -= 1;
            if remaining_in[v] == 0 {
                queue.push_back(v);
            }
        }
    }
    if topo_order.len() != n {
        // O(n) cycle reporting: the unordered nodes are precisely those
        // whose in-degree never dropped to zero (previously an O(n^2)
        // `contains` scan per node).
        let cyclic: Vec<String> = (0..n)
            .filter(|&v| remaining_in[v] > 0)
            .map(|v| v.to_string())
            .collect();
        return Err(CriticalPathError::CycleDetected(cyclic.join(", ")));
    }

    // Longest-path DP in topological order: dist[v] is the maximum total
    // cost of any path ending at v, including v's own cost. Predecessors
    // are fully settled before v is visited.
    let mut dist: Vec<u64> = vec![0; n];
    let mut best_pred: Vec<Option<NodeId>> = vec![None; n];
    for &v in &topo_order {
        dist[v] = costs[v];
        for &u in &pred[v] {
            // Saturate rather than wrap/panic on pathological latency sums.
            let candidate = dist[u].saturating_add(costs[v]);
            if candidate > dist[v] {
                dist[v] = candidate;
                best_pred[v] = Some(u);
            }
        }
    }

    // The critical path ends at the node with the largest accumulated cost
    // (`max_by_key` breaks ties toward the highest node id).
    let (end_node, &max_dist) = dist
        .iter()
        .enumerate()
        .max_by_key(|&(_, d)| d)
        .expect("n > 0, so dist is non-empty");

    // Walk predecessor links back to a source, then reverse into
    // execution order.
    let mut path: Vec<NodeId> = vec![end_node];
    let mut current = end_node;
    while let Some(prev) = best_pred[current] {
        path.push(prev);
        current = prev;
    }
    path.reverse();

    // The bottleneck is the single costliest node on the chosen path.
    let bottleneck = path
        .iter()
        .copied()
        .max_by_key(|&v| costs[v])
        .expect("path always contains at least end_node");

    Ok(CriticalPathResult {
        report: CriticalPathReport {
            nodes: path,
            total_latency_ns: max_dist,
            bottleneck,
        },
        warnings,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a graph from per-node latencies and a directed edge list.
    fn build_graph(latencies: &[Option<u64>], edges: &[(usize, usize)]) -> InferenceGraph {
        let mut g = InferenceGraph::new();
        for &lat in latencies {
            g.add_node(NodeLatency { latency_ns: lat });
        }
        for &(from, to) in edges {
            g.add_edge(from, to).expect("valid edge");
        }
        g
    }

    #[test]
    fn test_linear_chain() {
        let g = build_graph(&[Some(10), Some(20), Some(5)], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");
        assert_eq!(res.report.nodes, vec![0, 1, 2]);
        assert_eq!(res.report.total_latency_ns, 35);
        assert_eq!(res.report.bottleneck, 1);
        assert!(res.warnings.is_empty());
    }

    #[test]
    fn test_diamond_longer_branch_wins() {
        let g = build_graph(
            &[Some(1), Some(100), Some(1), Some(1)],
            &[(0, 1), (0, 2), (1, 3), (2, 3)],
        );
        let res = critical_path(&g).expect("no cycle");
        assert_eq!(res.report.nodes, vec![0, 1, 3]);
        assert_eq!(res.report.total_latency_ns, 102);
        assert_eq!(res.report.bottleneck, 1);
        assert!(res.warnings.is_empty());
    }

    #[test]
    fn test_single_node() {
        let g = build_graph(&[Some(42)], &[]);
        let res = critical_path(&g).expect("no cycle");
        assert_eq!(res.report.nodes, vec![0]);
        assert_eq!(res.report.total_latency_ns, 42);
        assert_eq!(res.report.bottleneck, 0);
        assert!(res.warnings.is_empty());
    }

    #[test]
    fn test_missing_latency_warning() {
        let g = build_graph(&[None, None, None], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");
        // Each missing latency defaults to 1 ns and produces a warning.
        assert_eq!(res.report.total_latency_ns, 3);
        assert_eq!(res.warnings.len(), 3);
        let warned_ids: Vec<NodeId> = res.warnings.iter().map(|w| w.node_id).collect();
        assert!(warned_ids.contains(&0));
        assert!(warned_ids.contains(&1));
        assert!(warned_ids.contains(&2));
    }

    #[test]
    fn test_empty_graph() {
        let g = InferenceGraph::new();
        let res = critical_path(&g).expect("no cycle");
        assert!(res.report.nodes.is_empty());
        assert_eq!(res.report.total_latency_ns, 0);
        assert_eq!(res.report.bottleneck, 0);
        assert!(res.warnings.is_empty());
    }

    #[test]
    fn test_cycle_detected() {
        let g = build_graph(&[Some(1), Some(1), Some(1)], &[(0, 1), (1, 2), (2, 0)]);
        let err = critical_path(&g).expect_err("should detect cycle");
        // BUG FIX: `matches!` returns a bool; the bare expression silently
        // discarded the result, so the variant was never actually checked.
        assert!(matches!(err, CriticalPathError::CycleDetected(_)));
    }

    #[test]
    fn test_parallel_branches() {
        let g = build_graph(&[Some(5), Some(10), Some(1), Some(3)], &[(0, 1), (2, 3)]);
        let res = critical_path(&g).expect("no cycle");
        assert_eq!(res.report.total_latency_ns, 15);
        assert_eq!(*res.report.nodes.last().expect("non-empty"), 1);
    }

    #[test]
    fn test_fan_out_fan_in() {
        let g = build_graph(
            &[Some(1), Some(2), Some(5), Some(3), Some(1)],
            &[(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)],
        );
        let res = critical_path(&g).expect("no cycle");
        assert_eq!(res.report.total_latency_ns, 7);
        assert_eq!(res.report.nodes, vec![0, 2, 4]);
        assert_eq!(res.report.bottleneck, 2);
    }

    #[test]
    fn test_invalid_edge() {
        let mut g = InferenceGraph::new();
        g.add_node(NodeLatency::new(10));
        let err = g.add_edge(0, 5).expect_err("node 5 does not exist");
        // BUG FIX: wrap `matches!` in `assert!` so the check is enforced.
        assert!(matches!(err, CriticalPathError::InvalidNode(5)));
    }

    #[test]
    fn test_mixed_latencies() {
        let g = build_graph(&[Some(10), None, Some(50)], &[(0, 1), (1, 2)]);
        let res = critical_path(&g).expect("no cycle");
        assert_eq!(res.report.total_latency_ns, 61);
        assert_eq!(res.report.bottleneck, 2);
        assert_eq!(res.warnings.len(), 1);
        assert_eq!(res.warnings[0].node_id, 1);
    }
}