use crate::types::Rank;
use std::collections::{HashMap, HashSet, VecDeque};
/// How ranks are connected to one another.
///
/// The default is [`TopologyStrategy::FullMesh`]. `Eq` is derived because
/// every variant (including the `usize` degree) supports total equality.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum TopologyStrategy {
    /// Every rank has a direct link to every other rank.
    #[default]
    FullMesh,
    /// Ranks sit on a ring and each links to the `degree / 2` nearest ranks on
    /// either side (wrapping around). `degree` is expected to be even and >= 2.
    KRegular { degree: usize },
    /// Each rank links to every rank whose id differs in exactly one bit;
    /// requires a power-of-two world size.
    Hypercube,
}
/// Forwarding state for one rank.
///
/// `neighbors` are the ranks reachable over a direct link; `next_hop` maps
/// other destinations to the neighbour a message should be forwarded through.
#[derive(Debug, Clone)]
pub struct RoutingTable {
    /// Ranks reachable over a direct link.
    pub neighbors: HashSet<Rank>,
    /// Destination rank -> neighbour to forward through.
    pub next_hop: HashMap<Rank, Rank>,
}

impl RoutingTable {
    /// Returns the rank a message addressed to `dest` should be sent to next.
    ///
    /// A direct neighbour is its own next hop (this takes precedence over any
    /// `next_hop` entry). Otherwise `next_hop` is consulted; `None` means the
    /// table knows no route to `dest`.
    pub fn route(&self, dest: Rank) -> Option<Rank> {
        self.neighbors
            .get(&dest)
            .or_else(|| self.next_hop.get(&dest))
            .copied()
    }

    /// Returns `true` when `dest` is reachable over a direct link.
    pub fn is_neighbor(&self, dest: Rank) -> bool {
        self.neighbors.contains(&dest)
    }
}
/// A rooted spanning tree over the ranks of a topology.
///
/// Derives `Debug` and `Clone` like [`RoutingTable`] so trees can be logged
/// and copied.
#[derive(Debug, Clone)]
pub struct SpanningTree {
    /// Child rank -> its parent. As built by `build_spanning_tree`, the root
    /// has no entry.
    pub parent: HashMap<Rank, Rank>,
    /// Rank -> its children.
    pub children: HashMap<Rank, Vec<Rank>>,
}
/// Returns the set of ranks directly connected to `rank` under `strategy`.
///
/// * `FullMesh` — every other rank in `0..world_size`.
/// * `KRegular { degree }` — the `degree / 2` nearest ranks on each side of
///   `rank` on a ring of `world_size` ranks, wrapping around. When
///   `degree / 2 >= world_size / 2` this is simply every other rank. `rank`
///   itself is never included, and a `world_size` of 0 yields an empty set.
/// * `Hypercube` — the ranks whose id differs from `rank` in exactly one bit.
///
/// `rank` is expected to lie in `0..world_size`.
///
/// # Panics
///
/// Panics for `Hypercube` when `world_size` is not a power of two.
pub fn build_neighbors(strategy: &TopologyStrategy, rank: Rank, world_size: u32) -> HashSet<Rank> {
    match strategy {
        TopologyStrategy::FullMesh => (0..world_size).filter(|&r| r != rank).collect(),
        TopologyStrategy::KRegular { degree } => {
            // Offsets beyond world_size / 2 only revisit ranks already reached from
            // the other direction. Letting them through would eventually wrap back
            // onto `rank` itself (offset == world_size) or underflow
            // `rank + world_size - d` (offset > rank + world_size).
            let half = (*degree / 2).min(world_size as usize / 2) as u32;
            let mut neighbors = HashSet::new();
            for d in 1..=half {
                neighbors.insert((rank + d) % world_size);
                neighbors.insert((rank + world_size - d) % world_size);
            }
            neighbors
        }
        TopologyStrategy::Hypercube => {
            assert!(
                world_size.is_power_of_two(),
                "Hypercube requires power-of-2 world_size, got {world_size}"
            );
            // One dimension per bit of log2(world_size); flip each bit in turn.
            let bits = world_size.trailing_zeros();
            (0..bits).map(|b| rank ^ (1 << b)).collect()
        }
    }
}
/// Number of steps between `from` and `to` on a ring of `world_size` ranks,
/// going whichever way round is shorter.
///
/// Both ranks are expected to lie in `0..world_size`.
fn ring_distance(from: Rank, to: Rank, world_size: u32) -> u32 {
    let forward = (world_size + to - from) % world_size;
    let backward = (world_size + from - to) % world_size;
    std::cmp::min(forward, backward)
}
pub fn build_routing_table(
strategy: &TopologyStrategy,
rank: Rank,
world_size: u32,
) -> RoutingTable {
let neighbors = build_neighbors(strategy, rank, world_size);
if matches!(strategy, TopologyStrategy::FullMesh) {
return RoutingTable {
neighbors,
next_hop: HashMap::new(),
};
}
let mut next_hop = HashMap::new();
for dest in 0..world_size {
if dest == rank || neighbors.contains(&dest) {
continue;
}
let hop = match strategy {
TopologyStrategy::FullMesh => unreachable!(),
TopologyStrategy::KRegular { .. } => {
*neighbors
.iter()
.min_by_key(|&&n| ring_distance(n, dest, world_size))
.expect("KRegular always has neighbors")
}
TopologyStrategy::Hypercube => {
let diff = rank ^ dest;
let highest_bit = 31 - diff.leading_zeros();
rank ^ (1 << highest_bit)
}
};
next_hop.insert(dest, hop);
}
RoutingTable {
neighbors,
next_hop,
}
}
/// Returns a ring ordering of the ranks `0..world_size`.
///
/// This is currently the identity permutation `0, 1, …, world_size - 1`.
pub fn optimal_ring_order(world_size: u32) -> Vec<Rank> {
    let mut order = Vec::with_capacity(world_size as usize);
    order.extend(0..world_size);
    order
}
/// Builds a breadth-first spanning tree of the topology, rooted at `root`.
///
/// Each rank's children are recorded in ascending rank order, and every rank
/// reached by the search (including `root`) gets a possibly empty entry in
/// `children`. The root has no `parent` entry. Ranks not reachable from `root`
/// appear in neither map.
///
/// # Panics
///
/// Panics (via [`build_neighbors`]) for `Hypercube` when `world_size` is not a
/// power of two.
pub fn build_spanning_tree(
    strategy: &TopologyStrategy,
    root: Rank,
    world_size: u32,
) -> SpanningTree {
    let mut parent: HashMap<Rank, Rank> = HashMap::new();
    let mut children: HashMap<Rank, Vec<Rank>> = HashMap::from([(root, Vec::new())]);
    let mut seen: HashSet<Rank> = HashSet::from([root]);
    let mut frontier: VecDeque<Rank> = VecDeque::from([root]);

    while let Some(current) = frontier.pop_front() {
        // Claim every not-yet-seen neighbour as a child of `current`.
        let mut discovered = Vec::new();
        for candidate in build_neighbors(strategy, current, world_size) {
            if seen.insert(candidate) {
                discovered.push(candidate);
            }
        }
        // Sort so the tree does not depend on HashSet iteration order.
        discovered.sort();

        for &child in &discovered {
            parent.insert(child, current);
            children.insert(child, Vec::new());
        }
        frontier.extend(discovered.iter().copied());
        children
            .get_mut(&current)
            .expect("BFS tree invariant: all visited nodes have children entry")
            .extend(discovered);
    }

    SpanningTree { parent, children }
}
/// Picks a neighbour of `rank`, other than `failed_hop`, through which to
/// reach `dest`.
///
/// * `KRegular` — the remaining neighbour with the smallest ring distance to
///   `dest`.
/// * `Hypercube` — the remaining neighbour with the smallest Hamming distance
///   to `dest`.
///
/// Ties are broken by the lower rank, so the choice is deterministic and does
/// not depend on `HashSet` iteration order. Returns `None` when `failed_hop`
/// was the only neighbour.
///
/// # Panics
///
/// Panics for `FullMesh` when at least one candidate exists; FullMesh has no
/// routing table, so callers are not expected to ask for an alternative hop.
/// Also panics (via [`build_neighbors`]) for `Hypercube` when `world_size` is
/// not a power of two.
pub fn find_alternative_hop(
    strategy: &TopologyStrategy,
    rank: Rank,
    dest: Rank,
    failed_hop: Rank,
    world_size: u32,
) -> Option<Rank> {
    let candidates: Vec<Rank> = build_neighbors(strategy, rank, world_size)
        .into_iter()
        .filter(|&n| n != failed_hop)
        .collect();
    // Checked before the match so FullMesh with no candidates still yields None.
    if candidates.is_empty() {
        return None;
    }
    match strategy {
        TopologyStrategy::FullMesh => {
            unreachable!(
                "FullMesh has no routing table, so relay/alternative-hop logic is never invoked"
            )
        }
        TopologyStrategy::KRegular { .. } => candidates
            .into_iter()
            .min_by_key(|&n| (ring_distance(n, dest, world_size), n)),
        TopologyStrategy::Hypercube => candidates
            .into_iter()
            .min_by_key(|&n| ((n ^ dest).count_ones(), n)),
    }
}
/// Returns a shareable routing table for `rank`, or `None` for
/// [`TopologyStrategy::FullMesh`], where every peer is a direct neighbour and
/// no forwarding table is kept.
pub fn recompute_routing_table(
    topology: &TopologyStrategy,
    rank: Rank,
    world_size: u32,
) -> Option<std::sync::Arc<RoutingTable>> {
    match topology {
        TopologyStrategy::FullMesh => None,
        TopologyStrategy::KRegular { .. } | TopologyStrategy::Hypercube => {
            let table = build_routing_table(topology, rank, world_size);
            Some(std::sync::Arc::new(table))
        }
    }
}
/// Parses a textual topology specification.
///
/// Accepted forms, case-insensitive and ignoring surrounding whitespace:
/// `full_mesh`, `hypercube`, and `k_regular:<degree>`, where `<degree>` is an
/// even integer >= 2 (whitespace after the colon is allowed).
///
/// Returns `None` for any of these:
/// * an unknown name;
/// * a missing or non-numeric degree;
/// * a degree that is odd or smaller than 2.
///
/// The function never panics, so it is safe to call on untrusted input.
pub fn parse_topology(s: &str) -> Option<TopologyStrategy> {
    let s = s.trim().to_lowercase();
    match s.as_str() {
        "full_mesh" => Some(TopologyStrategy::FullMesh),
        "hypercube" => Some(TopologyStrategy::Hypercube),
        other => {
            let degree: usize = other.strip_prefix("k_regular:")?.trim().parse().ok()?;
            // KRegular links degree / 2 ranks on each side of the ring, so an odd or
            // zero degree cannot be realised; reject it instead of panicking.
            (degree >= 2 && degree % 2 == 0).then_some(TopologyStrategy::KRegular { degree })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_full_mesh_neighbors() {
        let n = build_neighbors(&TopologyStrategy::FullMesh, 0, 4);
        assert_eq!(n, HashSet::from([1, 2, 3]));
    }

    #[test]
    fn test_kregular_neighbors() {
        let n = build_neighbors(&TopologyStrategy::KRegular { degree: 4 }, 0, 8);
        assert_eq!(n, HashSet::from([1, 2, 6, 7]));
    }

    #[test]
    fn test_kregular_neighbors_wrap() {
        let n = build_neighbors(&TopologyStrategy::KRegular { degree: 4 }, 7, 8);
        assert_eq!(n, HashSet::from([5, 6, 0, 1]));
    }

    #[test]
    fn test_hypercube_neighbors() {
        let n = build_neighbors(&TopologyStrategy::Hypercube, 0, 8);
        assert_eq!(n, HashSet::from([1, 2, 4]));
        let n = build_neighbors(&TopologyStrategy::Hypercube, 5, 8);
        assert_eq!(n, HashSet::from([4, 7, 1]));
    }

    #[test]
    fn test_kregular_routing() {
        let rt = build_routing_table(&TopologyStrategy::KRegular { degree: 4 }, 0, 16);
        assert_eq!(rt.neighbors, HashSet::from([14, 15, 1, 2]));
        let hop = rt.route(5).unwrap();
        assert!(rt.neighbors.contains(&hop));
    }

    #[test]
    fn test_hypercube_routing() {
        let rt = build_routing_table(&TopologyStrategy::Hypercube, 0, 8);
        assert_eq!(rt.route(7), Some(4));
        assert_eq!(rt.route(3), Some(2));
    }

    #[test]
    fn test_full_mesh_no_next_hop() {
        let rt = build_routing_table(&TopologyStrategy::FullMesh, 0, 4);
        assert!(rt.next_hop.is_empty());
        assert_eq!(rt.route(3), Some(3));
    }

    #[test]
    fn test_spanning_tree() {
        let tree = build_spanning_tree(&TopologyStrategy::KRegular { degree: 4 }, 0, 8);
        assert_eq!(tree.children.len(), 8);
        assert!(!tree.parent.contains_key(&0));
        for r in 1..8 {
            assert!(tree.parent.contains_key(&r));
        }
    }

    #[test]
    fn test_parse_topology() {
        assert_eq!(
            parse_topology("full_mesh"),
            Some(TopologyStrategy::FullMesh)
        );
        assert_eq!(
            parse_topology("hypercube"),
            Some(TopologyStrategy::Hypercube)
        );
        assert_eq!(
            parse_topology("k_regular:8"),
            Some(TopologyStrategy::KRegular { degree: 8 })
        );
        assert_eq!(parse_topology("invalid"), None);
    }

    #[test]
    fn test_find_alternative_hop() {
        let alt = find_alternative_hop(&TopologyStrategy::KRegular { degree: 4 }, 0, 5, 2, 16);
        assert!(alt.is_some());
        let hop = alt.unwrap();
        assert_ne!(hop, 2);
        let neighbors = build_neighbors(&TopologyStrategy::KRegular { degree: 4 }, 0, 16);
        assert!(
            neighbors.contains(&hop),
            "hop {hop} not in neighbors {neighbors:?}"
        );
        let alt = find_alternative_hop(&TopologyStrategy::Hypercube, 0, 7, 4, 8);
        assert!(alt.is_some());
        let hop = alt.unwrap();
        assert_ne!(hop, 4);
        let neighbors = build_neighbors(&TopologyStrategy::Hypercube, 0, 8);
        assert!(
            neighbors.contains(&hop),
            "hop {hop} not in neighbors {neighbors:?}"
        );
    }

    #[test]
    fn test_find_alternative_hop_degree_2() {
        let alt = find_alternative_hop(&TopologyStrategy::KRegular { degree: 2 }, 0, 3, 1, 8);
        assert_eq!(alt, Some(7));
    }

    /// Failing the counter-clockwise neighbour (7) leaves only the clockwise one (1).
    #[test]
    fn test_find_alternative_hop_other_direction() {
        let alt = find_alternative_hop(&TopologyStrategy::KRegular { degree: 2 }, 0, 5, 7, 8);
        assert_eq!(alt, Some(1));
    }

    /// With two ranks on a degree-2 ring the only neighbour is the failed hop.
    #[test]
    fn test_find_alternative_hop_no_candidate() {
        let alt = find_alternative_hop(&TopologyStrategy::KRegular { degree: 2 }, 0, 1, 1, 2);
        assert_eq!(alt, None);
    }

    #[test]
    fn test_ring_distance() {
        assert_eq!(ring_distance(0, 3, 8), 3);
        assert_eq!(ring_distance(0, 5, 8), 3);
        assert_eq!(ring_distance(7, 0, 8), 1);
        assert_eq!(ring_distance(4, 4, 8), 0);
    }
}