use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use zeph_config::OrchestrationConfig;
use super::graph::{TaskGraph, TaskId, TaskNode};
/// Coarse shape of a task dependency graph, as detected by [`TopologyClassifier::classify`].
///
/// Serialized in `snake_case` (e.g. `all_parallel`, `linear_chain`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Topology {
/// No dependency edges at all (also reported for an empty graph).
AllParallel,
/// Every task lies on one dependency chain (n tasks, n - 1 edges, longest chain n - 1).
LinearChain,
/// A single root and a longest dependency chain of exactly one edge.
FanOut,
/// Two or more roots feeding exactly one dependent task that has at least two
/// dependencies; the longest chain is a single edge.
FanIn,
/// A single root, a longest chain of two or more edges, and no task with more than one
/// dependency.
Hierarchical,
/// Any shape not matched by the other variants (e.g. a diamond).
Mixed,
}
/// Dispatch strategy selected for a [`Topology`] by [`TopologyClassifier::strategy`].
///
/// This file only selects the strategy; how each one is executed lives in the scheduler
/// (not visible here). Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DispatchStrategy {
/// Selected for `AllParallel`, and for `FanOut`/`FanIn` when tree-optimized dispatch is off.
FullParallel,
/// Selected for `LinearChain`.
Sequential,
/// Selected for `Hierarchical`.
LevelBarrier,
/// Selected for `Mixed` when cascade routing is off.
Adaptive,
/// Selected for `FanOut`/`FanIn` when `OrchestrationConfig::tree_optimized_dispatch` is set.
TreeOptimized,
/// Selected for `Mixed` when `OrchestrationConfig::cascade_routing` is set.
CascadeAware,
}
/// Result of [`TopologyClassifier::analyze`]: detected shape plus derived dispatch settings.
#[derive(Debug, Clone)]
pub struct TopologyAnalysis {
/// Detected dependency shape (`AllParallel` when topology selection is disabled).
pub topology: Topology,
/// Dispatch strategy chosen for `topology` under the active config.
pub strategy: DispatchStrategy,
/// Concurrency cap derived from `config.max_parallel` and `topology`.
pub max_parallel: usize,
/// Number of edges on the longest dependency chain; 0 when selection is disabled or the
/// graph is empty.
pub depth: usize,
/// Per-task length (in edges) of the longest dependency chain ending at that task; roots
/// are 0. Empty when selection is disabled or the graph is empty.
pub depths: HashMap<TaskId, usize>,
}
pub struct TopologyClassifier;
impl TopologyClassifier {
#[must_use]
pub fn classify(graph: &TaskGraph) -> Topology {
let tasks = &graph.tasks;
if tasks.is_empty() {
return Topology::AllParallel;
}
let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
if edge_count == 0 {
return Topology::AllParallel;
}
let (longest, depths) = compute_longest_path_and_depths(tasks);
Self::classify_with_depths(graph, longest, &depths)
}
#[must_use]
pub fn classify_with_depths(
graph: &TaskGraph,
longest_path: usize,
_depths: &HashMap<TaskId, usize>,
) -> Topology {
let tasks = &graph.tasks;
let n = tasks.len();
if n == 0 {
return Topology::AllParallel;
}
let edge_count: usize = tasks.iter().map(|t| t.depends_on.len()).sum();
if edge_count == 0 {
return Topology::AllParallel;
}
if edge_count == n - 1 && longest_path == n - 1 {
return Topology::LinearChain;
}
let roots_count = tasks.iter().filter(|t| t.depends_on.is_empty()).count();
if roots_count == 1 && longest_path == 1 {
return Topology::FanOut;
}
let non_roots_count = tasks.iter().filter(|t| !t.depends_on.is_empty()).count();
if roots_count >= 2 && non_roots_count == 1 && longest_path == 1 {
let sink_dep_count = tasks
.iter()
.filter(|t| !t.depends_on.is_empty())
.map(|t| t.depends_on.len())
.next()
.unwrap_or(0);
if sink_dep_count >= 2 {
return Topology::FanIn;
}
}
if roots_count == 1 && longest_path >= 2 {
let max_dep_count = tasks.iter().map(|t| t.depends_on.len()).max().unwrap_or(0);
if max_dep_count <= 1 {
return Topology::Hierarchical;
}
}
Topology::Mixed
}
#[must_use]
pub fn compute_max_parallel(topology: Topology, base: usize) -> usize {
match topology {
Topology::AllParallel | Topology::FanOut | Topology::FanIn | Topology::Hierarchical => {
base
}
Topology::LinearChain => 1,
Topology::Mixed => (base / 2 + 1).min(base).max(1),
}
}
#[must_use]
pub fn strategy(topology: Topology, config: &OrchestrationConfig) -> DispatchStrategy {
match topology {
Topology::FanOut | Topology::FanIn if config.tree_optimized_dispatch => {
DispatchStrategy::TreeOptimized
}
Topology::Mixed if config.cascade_routing => DispatchStrategy::CascadeAware,
Topology::AllParallel | Topology::FanOut | Topology::FanIn => {
DispatchStrategy::FullParallel
}
Topology::LinearChain => DispatchStrategy::Sequential,
Topology::Hierarchical => DispatchStrategy::LevelBarrier,
Topology::Mixed => DispatchStrategy::Adaptive,
}
}
#[must_use]
pub fn analyze(graph: &TaskGraph, config: &OrchestrationConfig) -> TopologyAnalysis {
let tasks = &graph.tasks;
let n = tasks.len();
if !config.topology_selection || n == 0 {
return TopologyAnalysis {
topology: Topology::AllParallel,
strategy: DispatchStrategy::FullParallel,
max_parallel: config.max_parallel as usize,
depth: 0,
depths: HashMap::new(),
};
}
let (longest, depths) = compute_longest_path_and_depths(tasks);
let topology = Self::classify_with_depths(graph, longest, &depths);
let strategy = Self::strategy(topology, config);
let base = config.max_parallel as usize;
let max_parallel = Self::compute_max_parallel(topology, base);
TopologyAnalysis {
topology,
strategy,
max_parallel,
depth: longest,
depths,
}
}
}
pub(crate) fn compute_depths_for_scheduler(
graph: &TaskGraph,
) -> (usize, std::collections::HashMap<TaskId, usize>) {
compute_longest_path_and_depths(&graph.tasks)
}
/// Computes longest-path depths over the task dependency graph (Kahn's algorithm).
///
/// Returns `(longest, depths)`:
/// - `depths[id]` is the number of edges on the longest dependency chain ending at task
///   `id`; roots have depth 0.
/// - `longest` is the maximum of those values.
///
/// Task ids are used directly as vector indices. Every `TaskId::index()`, including
/// those referenced in `depends_on`, must therefore be `< tasks.len()`, otherwise this
/// panics. Tasks inside or downstream of a dependency cycle are never released from the
/// work list. Their depth reflects only the predecessors that were processed.
fn compute_longest_path_and_depths(tasks: &[TaskNode]) -> (usize, HashMap<TaskId, usize>) {
    if tasks.is_empty() {
        return (0, HashMap::new());
    }
    let count = tasks.len();

    // Unprocessed-dependency counter per task, plus the reverse adjacency list.
    // Duplicate entries in `depends_on` count (and are released) once per occurrence.
    let mut pending = vec![0usize; count];
    let mut successors: Vec<Vec<usize>> = vec![Vec::new(); count];
    for node in tasks {
        let idx = node.id.index();
        pending[idx] = node.depends_on.len();
        for parent in &node.depends_on {
            successors[parent.index()].push(idx);
        }
    }

    // Seed with every task that has no dependencies. Processing order does not affect
    // the result: a task's depth is final once all of its predecessors were processed.
    let mut ready: Vec<usize> = (0..count).filter(|&idx| pending[idx] == 0).collect();
    let mut level = vec![0usize; count];

    while let Some(current) = ready.pop() {
        let next_level = level[current] + 1;
        for &child in &successors[current] {
            level[child] = level[child].max(next_level);
            pending[child] -= 1;
            if pending[child] == 0 {
                ready.push(child);
            }
        }
    }

    let longest = level.iter().copied().max().unwrap_or(0);
    let depths = tasks
        .iter()
        .map(|node| (node.id, level[node.id.index()]))
        .collect();
    (longest, depths)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{TaskGraph, TaskId, TaskNode};

    /// Builds task `id` depending on every id in `deps`.
    fn make_node(id: u32, deps: &[u32]) -> TaskNode {
        let mut n = TaskNode::new(id, format!("t{id}"), "desc");
        n.depends_on = deps.iter().map(|&d| TaskId(d)).collect();
        n
    }

    /// Wraps `nodes` in a fresh graph. Fixtures below always use dense ids
    /// `0..nodes.len()`, which the depth computation indexes by.
    fn graph_from(nodes: Vec<TaskNode>) -> TaskGraph {
        let mut g = TaskGraph::new("test");
        g.tasks = nodes;
        g
    }

    /// Topology selection enabled with a parallelism budget of 4.
    fn default_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    #[test]
    fn classify_empty_graph() {
        let g = graph_from(vec![]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_single_task() {
        let g = graph_from(vec![make_node(0, &[])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_all_parallel() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::AllParallel);
    }

    #[test]
    fn classify_two_task_chain() {
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_linear_chain() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::LinearChain);
    }

    #[test]
    fn classify_fan_out() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanOut);
    }

    #[test]
    fn classify_fan_in() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_fan_in_two_roots() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0, 1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::FanIn);
    }

    #[test]
    fn classify_chain_with_isolated_task_is_mixed() {
        // Two roots, but the only dependent has a single dependency: not a fan-in.
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[0]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_hierarchical() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
            make_node(4, &[2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_hierarchical_three_levels() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    #[test]
    fn classify_diamond_is_mixed() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Mixed);
    }

    #[test]
    fn classify_fan_out_with_chain_on_branch_is_hierarchical() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        assert_eq!(TopologyClassifier::classify(&g), Topology::Hierarchical);
    }

    /// Topology selection enabled, both strategy overrides explicitly disabled.
    fn no_overrides_config() -> zeph_config::OrchestrationConfig {
        zeph_config::OrchestrationConfig {
            topology_selection: true,
            max_parallel: 4,
            cascade_routing: false,
            tree_optimized_dispatch: false,
            ..zeph_config::OrchestrationConfig::default()
        }
    }

    #[test]
    fn strategy_all_parallel_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::AllParallel, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_out_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_fan_in_is_full_parallel() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn, &no_overrides_config()),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn strategy_linear_chain_is_sequential() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain, &no_overrides_config()),
            DispatchStrategy::Sequential
        );
    }

    #[test]
    fn strategy_hierarchical_is_level_barrier() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical, &no_overrides_config()),
            DispatchStrategy::LevelBarrier
        );
    }

    #[test]
    fn strategy_mixed_is_adaptive() {
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &no_overrides_config()),
            DispatchStrategy::Adaptive
        );
    }

    #[test]
    fn strategy_fan_out_tree_optimized_when_enabled() {
        let mut cfg = no_overrides_config();
        cfg.tree_optimized_dispatch = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut, &cfg),
            DispatchStrategy::TreeOptimized
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanIn, &cfg),
            DispatchStrategy::TreeOptimized
        );
    }

    #[test]
    fn strategy_mixed_cascade_aware_when_enabled() {
        let mut cfg = no_overrides_config();
        cfg.cascade_routing = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &cfg),
            DispatchStrategy::CascadeAware
        );
    }

    #[test]
    fn strategy_tree_optimized_does_not_affect_non_fan_topologies() {
        let mut cfg = no_overrides_config();
        cfg.tree_optimized_dispatch = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::Hierarchical, &cfg),
            DispatchStrategy::LevelBarrier
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::LinearChain, &cfg),
            DispatchStrategy::Sequential
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &cfg),
            DispatchStrategy::Adaptive
        );
    }

    #[test]
    fn strategy_both_overrides_enabled() {
        let mut cfg = no_overrides_config();
        cfg.tree_optimized_dispatch = true;
        cfg.cascade_routing = true;
        assert_eq!(
            TopologyClassifier::strategy(Topology::FanOut, &cfg),
            DispatchStrategy::TreeOptimized
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::Mixed, &cfg),
            DispatchStrategy::CascadeAware
        );
        assert_eq!(
            TopologyClassifier::strategy(Topology::AllParallel, &cfg),
            DispatchStrategy::FullParallel
        );
    }

    #[test]
    fn analyze_disabled_returns_full_parallel() {
        let mut cfg = default_config();
        cfg.topology_selection = false;
        let g = graph_from(vec![make_node(0, &[]), make_node(1, &[0])]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.topology, Topology::AllParallel);
        assert_eq!(analysis.depth, 0);
        assert!(analysis.depths.is_empty());
    }

    #[test]
    fn analyze_empty_graph_defaults() {
        let analysis = TopologyClassifier::analyze(&graph_from(vec![]), &default_config());
        assert_eq!(analysis.topology, Topology::AllParallel);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 0);
        assert!(analysis.depths.is_empty());
    }

    #[test]
    fn analyze_linear_chain_returns_sequential() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::LinearChain);
        assert_eq!(analysis.strategy, DispatchStrategy::Sequential);
        assert_eq!(analysis.max_parallel, 1);
        assert_eq!(analysis.depth, 2);
    }

    #[test]
    fn analyze_hierarchical_returns_level_barrier() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Hierarchical);
        assert_eq!(analysis.strategy, DispatchStrategy::LevelBarrier);
        assert_eq!(analysis.max_parallel, 4);
        assert_eq!(analysis.depth, 2);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 2);
    }

    #[test]
    fn analyze_fan_in_returns_full_parallel() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[]),
            make_node(2, &[]),
            make_node(3, &[0, 1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::FanIn);
        assert_eq!(analysis.strategy, DispatchStrategy::FullParallel);
        assert_eq!(analysis.max_parallel, 4);
    }

    #[test]
    fn analyze_mixed_is_conservative() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.topology, Topology::Mixed);
        assert_eq!(analysis.strategy, DispatchStrategy::Adaptive);
        assert_eq!(analysis.max_parallel, 3);
    }

    #[test]
    fn analyze_diamond_depths_use_longest_path() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &default_config());
        assert_eq!(analysis.depth, 2);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 2);
    }

    #[test]
    fn analyze_depths_correct_for_fan_out() {
        let cfg = default_config();
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[0]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.depths[&TaskId(0)], 0);
        assert_eq!(analysis.depths[&TaskId(1)], 1);
        assert_eq!(analysis.depths[&TaskId(2)], 1);
        assert_eq!(analysis.depths[&TaskId(3)], 1);
    }

    #[test]
    fn analyze_mixed_respects_max_parallel_one() {
        let mut cfg = default_config();
        cfg.max_parallel = 1;
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1, 2]),
        ]);
        let analysis = TopologyClassifier::analyze(&g, &cfg);
        assert_eq!(analysis.max_parallel, 1);
    }

    #[test]
    fn scheduler_depths_match_analyze() {
        let g = graph_from(vec![
            make_node(0, &[]),
            make_node(1, &[0]),
            make_node(2, &[0]),
            make_node(3, &[1]),
        ]);
        let (depth, depths) = compute_depths_for_scheduler(&g);
        let analysis = TopologyClassifier::analyze(&g, &default_config());
        assert_eq!(depth, analysis.depth);
        assert_eq!(depths, analysis.depths);
    }

    #[test]
    fn longest_path_of_no_tasks_is_zero() {
        let (longest, depths) = compute_longest_path_and_depths(&[]);
        assert_eq!(longest, 0);
        assert!(depths.is_empty());
    }

    #[test]
    fn classify_with_depths_matches_classify_for_all_variants() {
        let graphs = vec![
            // Empty and single-task graphs exercise the early-return paths.
            graph_from(vec![]),
            graph_from(vec![make_node(0, &[])]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[1]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[0]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[]),
                make_node(2, &[]),
                make_node(3, &[0, 1, 2]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1]),
            ]),
            graph_from(vec![
                make_node(0, &[]),
                make_node(1, &[0]),
                make_node(2, &[0]),
                make_node(3, &[1, 2]),
            ]),
        ];
        for g in &graphs {
            let expected = TopologyClassifier::classify(g);
            // Compute depths directly, independently of config-driven `analyze`.
            let (longest, depths) = compute_longest_path_and_depths(&g.tasks);
            let actual = TopologyClassifier::classify_with_depths(g, longest, &depths);
            assert_eq!(
                actual,
                expected,
                "classify_with_depths mismatch for graph with {} tasks",
                g.tasks.len()
            );
        }
    }

    #[test]
    fn compute_max_parallel_all_parallel_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::AllParallel, 8),
            8
        );
    }

    #[test]
    fn compute_max_parallel_fan_out_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanOut, 6),
            6
        );
    }

    #[test]
    fn compute_max_parallel_fan_in_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::FanIn, 4),
            4
        );
    }

    #[test]
    fn compute_max_parallel_hierarchical_returns_base() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Hierarchical, 10),
            10
        );
    }

    #[test]
    fn compute_max_parallel_linear_chain_returns_one() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 8),
            1
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::LinearChain, 1),
            1
        );
    }

    #[test]
    fn compute_max_parallel_mixed_is_half_plus_one() {
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 4),
            3
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 2),
            2
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 1),
            1
        );
        assert_eq!(
            TopologyClassifier::compute_max_parallel(Topology::Mixed, 8),
            5
        );
    }
}