use anyhow::Result;
use std::collections::HashMap;
use trustformers_core::parallel::CommunicationBackend;
use trustformers_core::tensor::Tensor;
/// Collective-communication pattern used by [`HierarchicalAggregator`] when
/// all-reducing gradients.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregationStrategy {
    /// Reduce up a binary tree, then broadcast back down the same tree.
    BinaryTree,
    /// Reduce-scatter followed by all-gather between ring neighbours.
    Ring,
    /// Stage-by-stage pairwise exchange with the partner `rank ^ (1 << stage)`.
    Butterfly,
    /// Resolve to one of the concrete strategies at call time, based on world
    /// size, node count and total gradient payload.
    Adaptive,
}
/// Describes where this process sits in the cluster and how aggregation
/// should be carried out.
#[derive(Debug, Clone)]
pub struct HierarchicalConfig {
    /// Number of nodes taking part.
    pub num_nodes: usize,
    /// Number of devices on each node; the world size is
    /// `num_nodes * devices_per_node`.
    pub devices_per_node: usize,
    /// Index of the node this process runs on.
    pub node_rank: usize,
    /// Index of this process's device within its node.
    pub local_rank: usize,
    /// Rank across the whole job; `HierarchicalConfig::new` sets it to
    /// `node_rank * devices_per_node + local_rank`.
    pub global_rank: usize,
    /// Requested aggregation strategy (`Adaptive` is resolved per call).
    pub strategy: AggregationStrategy,
    /// Communication backend. Stored only; nothing in this file reads it.
    pub comm_backend: CommunicationBackend,
    /// Compression switch. Stored only; nothing in this file reads it.
    pub enable_compression: bool,
    /// Compression threshold. Stored only; its meaning is not defined here.
    pub compression_threshold: f32,
    /// When `true`, `HierarchicalAggregator::new` creates a `FaultDetector`.
    pub enable_fault_tolerance: bool,
    /// Communication timeout (milliseconds, going by the name); copied into
    /// `FaultDetector::timeout_threshold`.
    pub comm_timeout_ms: u64,
}
impl Default for HierarchicalConfig {
    /// Single-node, single-device configuration: every rank is 0, the
    /// strategy is `Adaptive`, the backend is MPI, and both compression and
    /// fault tolerance are switched on.
    fn default() -> Self {
        Self {
            // Topology: one node, one device, everything at rank 0.
            num_nodes: 1,
            devices_per_node: 1,
            node_rank: 0,
            local_rank: 0,
            global_rank: 0,
            // Aggregation behaviour.
            strategy: AggregationStrategy::Adaptive,
            comm_backend: CommunicationBackend::Mpi,
            enable_compression: true,
            compression_threshold: 0.1,
            // Fault handling.
            enable_fault_tolerance: true,
            comm_timeout_ms: 30_000,
        }
    }
}
impl HierarchicalConfig {
    /// Creates a configuration for the process at (`node_rank`, `local_rank`)
    /// in a job of `num_nodes` nodes with `devices_per_node` devices each.
    ///
    /// All remaining fields take their `Default` values. The arguments are
    /// not validated against each other.
    pub fn new(
        num_nodes: usize,
        devices_per_node: usize,
        node_rank: usize,
        local_rank: usize,
    ) -> Self {
        Self {
            num_nodes,
            devices_per_node,
            node_rank,
            local_rank,
            // Node-major numbering: a node's devices occupy consecutive ranks.
            global_rank: node_rank * devices_per_node + local_rank,
            ..Self::default()
        }
    }

    /// Total number of ranks in the job.
    pub fn world_size(&self) -> usize {
        self.devices_per_node * self.num_nodes
    }

    /// `true` for the rank with `global_rank == 0`.
    pub fn is_master(&self) -> bool {
        0 == self.global_rank
    }

    /// `true` for the first device (`local_rank == 0`) of any node.
    pub fn is_node_master(&self) -> bool {
        0 == self.local_rank
    }
}
/// Gradient aggregator for one rank of a multi-node job.
///
/// Holds the rank's configuration, the communication layouts derived from it
/// and running statistics about completed aggregations.
#[derive(Debug)]
pub struct HierarchicalAggregator {
    /// Configuration this aggregator was built from.
    config: HierarchicalConfig,
    // Built in `new` but not read by any method in this file yet.
    #[allow(dead_code)]
    node_topology: NodeTopology,
    /// Tree, ring and butterfly layouts for this rank.
    communication_groups: CommunicationGroups,
    /// Counters updated after each aggregation.
    aggregation_stats: AggregationStats,
    // Created when fault tolerance is enabled but not consulted in this file yet.
    #[allow(dead_code)]
    fault_detector: Option<FaultDetector>,
}
/// Link characteristics between nodes and within a node.
///
/// The three matrices are indexed `[from_node][to_node]`. Units are not
/// stated anywhere in this file.
#[derive(Debug, Clone)]
pub struct NodeTopology {
    /// `true` where two distinct nodes are considered connected.
    pub node_adjacency: Vec<Vec<bool>>,
    /// Bandwidth figure for each node pair.
    pub node_bandwidth: Vec<Vec<f32>>,
    /// Latency figure for each node pair.
    pub node_latency: Vec<Vec<f32>>,
    /// Bandwidth figure between devices of the same node.
    pub intra_node_bandwidth: f32,
    /// Latency figure between devices of the same node.
    pub intra_node_latency: f32,
}
/// Rank groups and per-strategy layouts computed for one rank.
#[derive(Debug, Clone)]
pub struct CommunicationGroups {
    /// Global ranks of every device on this rank's node.
    pub node_local_group: Vec<usize>,
    /// Global rank of the first device of every node.
    pub cross_node_group: Vec<usize>,
    /// This rank's position in the binary reduction tree.
    pub tree_structure: TreeStructure,
    /// This rank's neighbours in the ring.
    pub ring_structure: RingStructure,
    /// This rank's exchange partners for each butterfly stage.
    pub butterfly_structure: ButterflyStructure,
}
/// One rank's view of the binary reduction tree.
#[derive(Debug, Clone)]
pub struct TreeStructure {
    /// Rank of the parent; `None` for the root.
    pub parent: Option<usize>,
    /// Ranks of the existing children (at most two).
    pub children: Vec<usize>,
    /// Level of this rank within the tree.
    pub depth: usize,
    /// Height figure for the whole tree.
    pub height: usize,
}
/// One rank's neighbours in the ring of all ranks.
#[derive(Debug, Clone)]
pub struct RingStructure {
    /// Rank this process sends to.
    pub next_rank: usize,
    /// Rank this process receives from.
    pub prev_rank: usize,
    /// Number of ranks in the ring.
    pub ring_size: usize,
}
/// One rank's exchange partners in a butterfly all-reduce.
#[derive(Debug, Clone)]
pub struct ButterflyStructure {
    /// `connections[stage]` lists the partner ranks for that stage; the list
    /// is empty when the partner would fall outside the world.
    pub connections: Vec<Vec<usize>>,
    /// Number of exchange stages.
    pub num_stages: usize,
}
/// Running counters describing the aggregations performed so far.
#[derive(Debug, Clone)]
pub struct AggregationStats {
    /// Number of aggregations recorded.
    pub total_operations: usize,
    /// Running mean of the recorded durations (callers pass milliseconds).
    pub avg_aggregation_time: f32,
    /// Sum of the gradient payload sizes of all recorded aggregations.
    pub total_bytes_transferred: usize,
    /// Compression ratio; starts at `1.0` and is not updated in this file.
    pub compression_ratio: f32,
    /// Number of failed aggregations.
    pub failed_operations: usize,
    /// How many times each strategy has been recorded.
    pub strategy_history: HashMap<AggregationStrategy, usize>,
}
/// Bookkeeping for failed nodes.
///
/// Derives `Clone` like the other plain-data structs in this file so that a
/// detector can be snapshotted together with the state around it.
#[derive(Debug, Clone)]
pub struct FaultDetector {
    /// Nodes currently regarded as failed.
    pub failed_nodes: Vec<usize>,
    /// Timeout threshold; `HierarchicalAggregator::new` fills it from
    /// `HierarchicalConfig::comm_timeout_ms`.
    pub timeout_threshold: u64,
    /// Reaction to apply when a node fails.
    pub recovery_strategy: RecoveryStrategy,
}
/// Reaction to a failed node.
///
/// No code in this file interprets the variants; the descriptions below are
/// inferred from the names only and should be confirmed against the consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecoveryStrategy {
    /// Presumably: carry on without the failed node.
    Skip,
    /// Presumably: attempt the operation again.
    Retry,
    /// Presumably: give up on the whole operation.
    Abort,
}
impl Default for AggregationStats {
    /// Empty statistics: all counters zero, no strategy history, and a
    /// neutral compression ratio of `1.0` (which is why this cannot simply be
    /// derived).
    fn default() -> Self {
        let strategy_history = HashMap::new();
        Self {
            compression_ratio: 1.0,
            avg_aggregation_time: 0.0,
            total_operations: 0,
            total_bytes_transferred: 0,
            failed_operations: 0,
            strategy_history,
        }
    }
}
impl HierarchicalAggregator {
/// Builds the aggregator for the rank described by `config`.
///
/// Computes the node topology and the tree/ring/butterfly layouts up front,
/// starts with default statistics, and creates a [`FaultDetector`] only when
/// `config.enable_fault_tolerance` is set.
///
/// # Errors
///
/// Propagates any error from topology detection or layout construction.
pub fn new(config: HierarchicalConfig) -> Result<Self> {
    let node_topology = Self::detect_network_topology(&config)?;
    let communication_groups = Self::build_communication_groups(&config, &node_topology)?;

    let fault_detector = config.enable_fault_tolerance.then(|| FaultDetector {
        failed_nodes: Vec::new(),
        timeout_threshold: config.comm_timeout_ms,
        recovery_strategy: RecoveryStrategy::Skip,
    });

    Ok(Self {
        aggregation_stats: AggregationStats::default(),
        fault_detector,
        config,
        node_topology,
        communication_groups,
    })
}
/// Produces a synthetic, fully connected topology for `config.num_nodes`
/// nodes; no real network probing happens here.
///
/// Off the diagonal every pair is marked adjacent, and nodes whose ranks
/// differ by exactly one get the higher bandwidth and lower latency figures.
/// On the diagonal the adjacency is `false`, the bandwidth is
/// `f32::INFINITY` and the latency is `0.0`. Units are not stated in this
/// file.
///
/// # Errors
///
/// Never fails at present; the `Result` keeps the signature ready for a real
/// probe.
fn detect_network_topology(config: &HierarchicalConfig) -> Result<NodeTopology> {
    const NEIGHBOUR_BANDWIDTH: f32 = 10000.0;
    const DISTANT_BANDWIDTH: f32 = 1000.0;
    const NEIGHBOUR_LATENCY: f32 = 0.1;
    const DISTANT_LATENCY: f32 = 1.0;
    const INTRA_NODE_BANDWIDTH: f32 = 80000.0;
    const INTRA_NODE_LATENCY: f32 = 0.01;

    let num_nodes = config.num_nodes;
    // "Ranks differ by one" without casting indices to a narrower signed type.
    let neighbours = |i: usize, j: usize| i + 1 == j || j + 1 == i;

    let node_adjacency: Vec<Vec<bool>> = (0..num_nodes)
        .map(|i| (0..num_nodes).map(|j| i != j).collect())
        .collect();

    let node_bandwidth: Vec<Vec<f32>> = (0..num_nodes)
        .map(|i| {
            (0..num_nodes)
                .map(|j| {
                    if i == j {
                        f32::INFINITY
                    } else if neighbours(i, j) {
                        NEIGHBOUR_BANDWIDTH
                    } else {
                        DISTANT_BANDWIDTH
                    }
                })
                .collect()
        })
        .collect();

    let node_latency: Vec<Vec<f32>> = (0..num_nodes)
        .map(|i| {
            (0..num_nodes)
                .map(|j| {
                    if i == j {
                        0.0
                    } else if neighbours(i, j) {
                        NEIGHBOUR_LATENCY
                    } else {
                        DISTANT_LATENCY
                    }
                })
                .collect()
        })
        .collect();

    Ok(NodeTopology {
        node_adjacency,
        node_bandwidth,
        node_latency,
        intra_node_bandwidth: INTRA_NODE_BANDWIDTH,
        intra_node_latency: INTRA_NODE_LATENCY,
    })
}
/// Assembles the rank groups and the tree, ring and butterfly layouts for the
/// rank described by `config`.
///
/// `node_local_group` holds the global ranks of every device on this rank's
/// node; `cross_node_group` holds the global rank of each node's first
/// device.
///
/// # Errors
///
/// Propagates the first error from the tree, ring or butterfly builders, in
/// that order.
fn build_communication_groups(
    config: &HierarchicalConfig,
    topology: &NodeTopology,
) -> Result<CommunicationGroups> {
    let devices = config.devices_per_node;
    let first_rank_on_node = config.node_rank * devices;

    let node_local_group: Vec<usize> = (0..devices)
        .map(|offset| first_rank_on_node + offset)
        .collect();
    let cross_node_group: Vec<usize> = (0..config.num_nodes)
        .map(|node| node * devices)
        .collect();

    Ok(CommunicationGroups {
        node_local_group,
        cross_node_group,
        tree_structure: Self::build_tree_structure(config, topology)?,
        ring_structure: Self::build_ring_structure(config)?,
        butterfly_structure: Self::build_butterfly_structure(config)?,
    })
}
/// Computes this rank's position in an implicit binary tree over all ranks.
///
/// Ranks are numbered like an array-backed binary heap: rank `r` has parent
/// `(r - 1) / 2` and children `2r + 1` and `2r + 2`, where a child exists
/// only if it is below the world size. `_topology` is accepted but not used.
///
/// * `depth` is the level of this rank, `floor(log2(rank + 1))`: 0 for the
///   root, 1 for ranks 1-2, 2 for ranks 3-6, and so on.
/// * `height` is `ceil(log2(world_size))`, and 0 for world sizes 0 and 1.
///
/// # Errors
///
/// Never fails at present.
fn build_tree_structure(
    config: &HierarchicalConfig,
    _topology: &NodeTopology,
) -> Result<TreeStructure> {
    let world_size = config.world_size();
    let rank = config.global_rank;

    let parent = rank.checked_sub(1).map(|above| above / 2);
    let children: Vec<usize> = (1..=2)
        .map(|offset| 2 * rank + offset)
        .filter(|&child| child < world_size)
        .collect();

    // Integer bit arithmetic rather than float `log2`: `log2(rank)` was off by
    // one level (rank 1 came out at depth 0) and only yielded 0 for the root
    // because `-inf` saturates when cast to `usize`.
    // `rank + 1 >= 1`, so `leading_zeros` is at most `BITS - 1`.
    let depth = (usize::BITS - 1 - rank.saturating_add(1).leading_zeros()) as usize;
    // ceil(log2(n)) == bit length of (n - 1); saturating keeps n == 0 at 0.
    let height = (usize::BITS - world_size.saturating_sub(1).leading_zeros()) as usize;

    Ok(TreeStructure {
        parent,
        children,
        depth,
        height,
    })
}
/// Computes this rank's neighbours in a ring over all ranks.
///
/// The successor is `(rank + 1) % world_size` and the predecessor is
/// `(rank + world_size - 1) % world_size`, so both wrap around at the ends.
///
/// # Errors
///
/// Returns an `InvalidInput` I/O error (converted into the function's error
/// type) when the world is empty, i.e. `num_nodes * devices_per_node == 0`.
/// The modulo arithmetic would otherwise panic with a division by zero.
fn build_ring_structure(config: &HierarchicalConfig) -> Result<RingStructure> {
    let ring_size = config.world_size();
    if ring_size == 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "cannot build a ring for an empty world (num_nodes * devices_per_node == 0)",
        )
        .into());
    }

    let rank = config.global_rank;
    Ok(RingStructure {
        next_rank: (rank + 1) % ring_size,
        prev_rank: (rank + ring_size - 1) % ring_size,
        ring_size,
    })
}
/// Computes this rank's exchange partner for every butterfly stage.
///
/// There are `ceil(log2(world_size))` stages (none for world sizes 0 and 1).
/// At stage `s` the partner is `rank ^ (1 << s)`. If that partner does not
/// exist because the world size is not a power of two, the stage's list is
/// left empty.
///
/// # Errors
///
/// Never fails at present.
fn build_butterfly_structure(config: &HierarchicalConfig) -> Result<ButterflyStructure> {
    let world_size = config.world_size();
    let rank = config.global_rank;

    // ceil(log2(n)) as the bit length of (n - 1). This is exact for every
    // `usize`, whereas the `f32` route rounds sizes above 2^24 and relies on
    // float `log2`, whose precision the standard library does not guarantee.
    let num_stages = (usize::BITS - world_size.saturating_sub(1).leading_zeros()) as usize;

    let connections: Vec<Vec<usize>> = (0..num_stages)
        .map(|stage| {
            let partner = rank ^ (1usize << stage);
            if partner < world_size {
                vec![partner]
            } else {
                Vec::new()
            }
        })
        .collect();

    Ok(ButterflyStructure {
        connections,
        num_stages,
    })
}
/// All-reduces `gradients` in place using the configured strategy and records
/// the outcome in the aggregation statistics.
///
/// `Adaptive` is resolved to a concrete strategy first. On success the
/// elapsed time (fractional milliseconds), the payload size and the strategy
/// returned by `select_optimal_strategy` are folded into the statistics.
///
/// # Errors
///
/// Returns the error of the underlying all-reduce, or an error if the
/// adaptive selection itself yields `Adaptive`. Either case also increments
/// `failed_operations`. Errors from strategy selection or from the statistics
/// update are propagated unchanged.
pub fn hierarchical_all_reduce(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
) -> Result<()> {
    let start_time = std::time::Instant::now();
    let strategy = self.select_optimal_strategy(gradients)?;

    // `select_optimal_strategy` normally returns a concrete strategy already.
    // Should it ever hand back `Adaptive`, resolve it here exactly once.
    let resolved = match strategy {
        AggregationStrategy::Adaptive => self.adaptive_strategy_selection(gradients),
        concrete => Ok(concrete),
    };

    let outcome = match resolved {
        Ok(AggregationStrategy::BinaryTree) => self.tree_based_all_reduce(gradients),
        Ok(AggregationStrategy::Ring) => self.ring_based_all_reduce(gradients),
        Ok(AggregationStrategy::Butterfly) => self.butterfly_based_all_reduce(gradients),
        Ok(AggregationStrategy::Adaptive) => Err(anyhow::anyhow!(
            "Invalid adaptive strategy selection: recursive Adaptive strategy returned"
        )),
        Err(err) => Err(err),
    };

    if let Err(err) = outcome {
        // Make the failure visible in the statistics before reporting it.
        self.aggregation_stats.failed_operations += 1;
        return Err(err);
    }

    // Keep sub-millisecond resolution; whole-millisecond truncation would
    // record fast operations as 0 ms.
    let elapsed_ms = start_time.elapsed().as_secs_f32() * 1000.0;
    self.update_aggregation_stats(strategy, elapsed_ms, gradients)?;
    Ok(())
}
/// Returns the strategy to use for this call: the configured one, or the
/// result of the adaptive selection when the configuration says `Adaptive`.
///
/// # Errors
///
/// Propagates any error from the adaptive selection.
fn select_optimal_strategy(
    &self,
    gradients: &HashMap<String, Tensor>,
) -> Result<AggregationStrategy> {
    let configured = self.config.strategy;
    if configured == AggregationStrategy::Adaptive {
        self.adaptive_strategy_selection(gradients)
    } else {
        Ok(configured)
    }
}
/// Picks a concrete strategy from the world size, the total gradient payload
/// and the node count.
///
/// Rules, first match wins:
/// 1. at most 8 ranks → `BinaryTree`
/// 2. payload above 100 MiB → `Ring`
/// 3. more than 16 nodes → `Butterfly`
/// 4. otherwise → `BinaryTree`
///
/// # Errors
///
/// Never fails at present.
fn adaptive_strategy_selection(
    &self,
    gradients: &HashMap<String, Tensor>,
) -> Result<AggregationStrategy> {
    const SMALL_WORLD_MAX_RANKS: usize = 8;
    const LARGE_PAYLOAD_BYTES: usize = 100 * 1024 * 1024;
    const MANY_NODES_ABOVE: usize = 16;

    let ranks = self.config.world_size();
    let nodes = self.config.num_nodes;
    let payload_bytes: usize = gradients.values().map(|tensor| tensor.memory_usage()).sum();

    let choice = match (ranks, nodes) {
        (r, _) if r <= SMALL_WORLD_MAX_RANKS => AggregationStrategy::BinaryTree,
        _ if payload_bytes > LARGE_PAYLOAD_BYTES => AggregationStrategy::Ring,
        (_, n) if n > MANY_NODES_ABOVE => AggregationStrategy::Butterfly,
        _ => AggregationStrategy::BinaryTree,
    };
    Ok(choice)
}
/// Tree all-reduce: reduce towards the root, then broadcast the result back
/// down the same tree.
///
/// # Errors
///
/// Propagates the first error from either phase.
fn tree_based_all_reduce(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
    // Work on a copy so the layout is not borrowed from `self` while the
    // phases borrow `self` mutably.
    let layout = self.communication_groups.tree_structure.clone();
    self.tree_reduce_up(gradients, &layout)?;
    self.tree_broadcast_down(gradients, &layout)
}
/// Ring all-reduce: a reduce-scatter phase followed by an all-gather phase.
///
/// # Errors
///
/// Propagates the first error from either phase.
fn ring_based_all_reduce(&mut self, gradients: &mut HashMap<String, Tensor>) -> Result<()> {
    // Work on a copy so the layout is not borrowed from `self` while the
    // phases borrow `self` mutably.
    let layout = self.communication_groups.ring_structure.clone();
    self.ring_reduce_scatter(gradients, &layout)?;
    self.ring_all_gather(gradients, &layout)
}
/// Butterfly all-reduce: runs every exchange stage in ascending order.
///
/// # Errors
///
/// Stops at, and returns, the first stage error.
fn butterfly_based_all_reduce(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
) -> Result<()> {
    // Work on a copy so the layout is not borrowed from `self` while the
    // stages borrow `self` mutably.
    let layout = self.communication_groups.butterfly_structure.clone();
    (0..layout.num_stages)
        .try_for_each(|stage| self.butterfly_stage_operation(gradients, &layout, stage))
}
/// Upward phase of the tree all-reduce.
///
/// Adds each child's gradients into the local ones, then forwards the local
/// gradients to the parent. The root has no parent and sends nothing.
///
/// # Errors
///
/// Propagates errors from receiving, from tensor addition and from sending.
fn tree_reduce_up(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
    tree: &TreeStructure,
) -> Result<()> {
    for child in tree.children.iter().copied() {
        for (key, local) in gradients.iter_mut() {
            let from_child = self.simulate_receive_gradient(child, key)?;
            *local = local.add(&from_child)?;
        }
    }

    let parent = match tree.parent {
        Some(rank) => rank,
        None => return Ok(()),
    };
    for (key, local) in gradients.iter() {
        self.simulate_send_gradient(parent, key, local)?;
    }
    Ok(())
}
/// Downward phase of the tree all-reduce.
///
/// A rank with a parent replaces every local gradient with the tensor
/// received from that parent; afterwards the local gradients are sent to each
/// child.
///
/// NOTE(review): `simulate_receive_gradient` currently returns a one-element
/// zero tensor, so non-root ranks end up with that placeholder. Confirm this
/// is acceptable until a real backend is wired in.
///
/// # Errors
///
/// Propagates errors from receiving and from sending.
fn tree_broadcast_down(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
    tree: &TreeStructure,
) -> Result<()> {
    if let Some(parent) = tree.parent {
        for (key, local) in gradients.iter_mut() {
            let from_parent = self.simulate_receive_gradient(parent, key)?;
            *local = from_parent;
        }
    }

    for child in tree.children.iter().copied() {
        for (key, local) in gradients.iter() {
            self.simulate_send_gradient(child, key, local)?;
        }
    }
    Ok(())
}
/// Reduce-scatter phase of the ring all-reduce.
///
/// Runs `ring_size - 1` rounds. In each round, for every gradient, a tensor
/// is received from the previous rank, added into the local gradient, and the
/// result is sent to the next rank.
///
/// NOTE(review): whole tensors are exchanged. The earlier code computed
/// per-round send/receive chunk indices, `(rank - step) mod ring_size` and
/// `(rank - step - 1) mod ring_size`, but never used them; they were removed
/// as dead code. Reintroduce them when real chunking is implemented.
///
/// # Errors
///
/// Propagates errors from receiving, from tensor addition and from sending.
fn ring_reduce_scatter(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
    ring: &RingStructure,
) -> Result<()> {
    // Saturating: an empty ring means zero rounds instead of an underflow.
    let rounds = ring.ring_size.saturating_sub(1);
    for _round in 0..rounds {
        for (name, gradient) in gradients.iter_mut() {
            let incoming = self.simulate_receive_gradient(ring.prev_rank, name)?;
            *gradient = gradient.add(&incoming)?;
            self.simulate_send_gradient(ring.next_rank, name, gradient)?;
        }
    }
    Ok(())
}
/// All-gather phase of the ring all-reduce.
///
/// Runs `ring_size - 1` rounds. In each round, for every gradient, a tensor
/// is received from the previous rank, added into the local gradient, and the
/// result is sent to the next rank.
///
/// NOTE(review): this phase accumulates received data just like the
/// reduce-scatter phase. An all-gather would normally copy received chunks
/// rather than add them; confirm the intended behaviour before connecting a
/// real backend.
///
/// # Errors
///
/// Propagates errors from receiving, from tensor addition and from sending.
fn ring_all_gather(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
    ring: &RingStructure,
) -> Result<()> {
    // Saturating: an empty ring means zero rounds instead of an underflow.
    let rounds = ring.ring_size.saturating_sub(1);
    for _round in 0..rounds {
        for (name, gradient) in gradients.iter_mut() {
            let incoming = self.simulate_receive_gradient(ring.prev_rank, name)?;
            *gradient = gradient.add(&incoming)?;
            self.simulate_send_gradient(ring.next_rank, name, gradient)?;
        }
    }
    Ok(())
}
/// Performs one butterfly stage.
///
/// For each partner listed for `stage`, every gradient receives the partner's
/// tensor, adds it into the local gradient and sends the result back to that
/// partner. A stage index with no entry in `butterfly.connections` is a
/// no-op.
///
/// NOTE(review): the tensor sent is the already-accumulated one, and the
/// receive happens before the send. Re-check both against the real exchange
/// protocol once the simulated transport is replaced.
///
/// # Errors
///
/// Propagates errors from receiving, from tensor addition and from sending.
fn butterfly_stage_operation(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
    butterfly: &ButterflyStructure,
    stage: usize,
) -> Result<()> {
    let partners = match butterfly.connections.get(stage) {
        Some(partners) => partners,
        None => return Ok(()),
    };

    for partner in partners.iter().copied() {
        for (key, local) in gradients.iter_mut() {
            let from_partner = self.simulate_receive_gradient(partner, key)?;
            *local = local.add(&from_partner)?;
            self.simulate_send_gradient(partner, key, local)?;
        }
    }
    Ok(())
}
/// Stand-in for receiving the gradient `_name` from `_from_rank`.
///
/// Both arguments are ignored; the result is always a zero tensor of shape
/// `[1]`.
///
/// # Errors
///
/// Propagates a failure to create the placeholder tensor.
fn simulate_receive_gradient(&self, _from_rank: usize, _name: &str) -> Result<Tensor> {
    let placeholder = Tensor::zeros(&[1])?;
    Ok(placeholder)
}
/// Stand-in for sending `_gradient` (named `_name`) to `_to_rank`.
///
/// All arguments are ignored and nothing is transmitted; the call always
/// succeeds.
fn simulate_send_gradient(
    &self,
    _to_rank: usize,
    _name: &str,
    _gradient: &Tensor,
) -> Result<()> {
    // Intentionally a no-op.
    Ok(())
}
/// Folds one completed aggregation into the running statistics.
///
/// Increments the operation count, updates the running mean duration with
/// `elapsed_ms`, adds the total `memory_usage()` of `gradients` to the bytes
/// transferred, and bumps the usage counter for `strategy`.
///
/// # Errors
///
/// Never fails at present.
fn update_aggregation_stats(
    &mut self,
    strategy: AggregationStrategy,
    elapsed_ms: f32,
    gradients: &HashMap<String, Tensor>,
) -> Result<()> {
    let stats = &mut self.aggregation_stats;

    let completed_before = stats.total_operations;
    stats.total_operations = completed_before + 1;

    // Running mean: (old_mean * old_count + new_sample) / new_count.
    let total_ms_before = stats.avg_aggregation_time * completed_before as f32;
    stats.avg_aggregation_time = (total_ms_before + elapsed_ms) / stats.total_operations as f32;

    let payload_bytes: usize = gradients.values().map(|tensor| tensor.memory_usage()).sum();
    stats.total_bytes_transferred += payload_bytes;

    *stats.strategy_history.entry(strategy).or_default() += 1;
    Ok(())
}
/// Borrows the statistics accumulated since construction or the last
/// `reset_stats` call.
pub fn get_stats(&self) -> &AggregationStats {
    let stats = &self.aggregation_stats;
    stats
}
/// Discards all accumulated statistics and starts again from
/// `AggregationStats::default()`.
pub fn reset_stats(&mut self) {
    let fresh = AggregationStats::default();
    self.aggregation_stats = fresh;
}
/// Recommends a strategy from the cluster shape alone, ignoring payload size.
///
/// At most 8 ranks → `BinaryTree`; otherwise more than 16 nodes →
/// `Butterfly`; otherwise `Ring`.
pub fn get_recommended_strategy(&self) -> AggregationStrategy {
    if self.config.world_size() <= 8 {
        return AggregationStrategy::BinaryTree;
    }
    match self.config.num_nodes {
        nodes if nodes > 16 => AggregationStrategy::Butterfly,
        _ => AggregationStrategy::Ring,
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` derives the global rank and leaves the rank flags consistent.
    #[test]
    fn test_hierarchical_config() {
        let cfg = HierarchicalConfig::new(4, 8, 2, 3);

        assert_eq!(cfg.num_nodes, 4);
        assert_eq!(cfg.devices_per_node, 8);
        assert_eq!(cfg.node_rank, 2);
        assert_eq!(cfg.local_rank, 3);
        // 2 * 8 + 3
        assert_eq!(cfg.global_rank, 19);
        assert_eq!(cfg.world_size(), 32);
        assert!(!cfg.is_master());
        assert!(!cfg.is_node_master());
    }

    /// The root of an eight-rank tree has no parent and children 1 and 2.
    #[test]
    fn test_tree_structure_building() {
        let cfg = HierarchicalConfig::new(2, 4, 0, 0);
        let topology = HierarchicalAggregator::detect_network_topology(&cfg)
            .expect("Operation failed in test");
        let root = HierarchicalAggregator::build_tree_structure(&cfg, &topology)
            .expect("Operation failed in test");

        assert!(root.parent.is_none());
        assert_eq!(root.children, vec![1, 2]);
        assert_eq!(root.depth, 0);
    }

    /// Rank 1 of an eight-rank ring sits between ranks 0 and 2.
    #[test]
    fn test_ring_structure_building() {
        let cfg = HierarchicalConfig::new(2, 4, 0, 1);
        let ring =
            HierarchicalAggregator::build_ring_structure(&cfg).expect("Operation failed in test");

        assert_eq!(ring.next_rank, 2);
        assert_eq!(ring.prev_rank, 0);
        assert_eq!(ring.ring_size, 8);
    }

    /// Sixteen ranks with a large payload select the ring strategy.
    #[test]
    fn test_adaptive_strategy_selection() {
        let cfg = HierarchicalConfig::new(4, 4, 0, 0);
        let aggregator = HierarchicalAggregator::new(cfg).expect("Construction failed");

        let big_tensor = Tensor::zeros(&[8000, 8000]).expect("Failed to create tensor");
        let mut gradients = HashMap::new();
        gradients.insert("param1".to_string(), big_tensor);

        let chosen = aggregator
            .adaptive_strategy_selection(&gradients)
            .expect("Operation failed in test");
        assert!(matches!(chosen, AggregationStrategy::Ring));
    }

    /// A single recorded operation shows up in count, mean and history.
    #[test]
    fn test_aggregation_stats_update() {
        let cfg = HierarchicalConfig::new(2, 2, 0, 0);
        let mut aggregator = HierarchicalAggregator::new(cfg).expect("Construction failed");

        let small_tensor = Tensor::zeros(&[10, 10]).expect("Failed to create tensor");
        let mut gradients = HashMap::new();
        gradients.insert("param1".to_string(), small_tensor);

        aggregator
            .update_aggregation_stats(AggregationStrategy::BinaryTree, 100.0, &gradients)
            .expect("Operation failed in test");

        let stats = aggregator.get_stats();
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.avg_aggregation_time, 100.0);
        let tree_uses = stats.strategy_history.get(&AggregationStrategy::BinaryTree);
        assert_eq!(tree_uses, Some(&1));
    }

    /// Small worlds get the tree; many single-device nodes get the butterfly.
    #[test]
    fn test_recommended_strategy() {
        let small =
            HierarchicalAggregator::new(HierarchicalConfig::new(2, 2, 0, 0))
                .expect("Construction failed");
        assert!(matches!(
            small.get_recommended_strategy(),
            AggregationStrategy::BinaryTree
        ));

        let large =
            HierarchicalAggregator::new(HierarchicalConfig::new(20, 1, 0, 0))
                .expect("Construction failed");
        assert!(matches!(
            large.get_recommended_strategy(),
            AggregationStrategy::Butterfly
        ));
    }

    /// Eight ranks need three butterfly stages.
    #[test]
    fn test_butterfly_structure() {
        let cfg = HierarchicalConfig::new(1, 8, 0, 0);
        let butterfly = HierarchicalAggregator::build_butterfly_structure(&cfg)
            .expect("Operation failed in test");

        assert_eq!(butterfly.num_stages, 3);
        assert_eq!(butterfly.connections.len(), 3);
    }

    /// The synthetic topology has one row per node and positive intra-node
    /// figures.
    #[test]
    fn test_network_topology_detection() {
        let cfg = HierarchicalConfig::new(3, 2, 0, 0);
        let topology = HierarchicalAggregator::detect_network_topology(&cfg)
            .expect("Operation failed in test");

        assert_eq!(topology.node_adjacency.len(), 3);
        assert_eq!(topology.node_bandwidth.len(), 3);
        assert_eq!(topology.node_latency.len(), 3);
        assert!(topology.intra_node_bandwidth > 0.0);
        assert!(topology.intra_node_latency > 0.0);
    }
}