use crate::backend::ReduceOp;
use crate::{ProcessGroup, TorshDistributedError, TorshResult};
use log::info;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use torsh_core::dtype::FloatElement;
use torsh_tensor::Tensor;
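/// Host-side handle modeling a CUDA stream for scheduling purposes.
///
/// No real CUDA calls are made: operation latencies are simulated with
/// `tokio::time::sleep`, and per-stream load is tracked with atomic counters.
/// Stream id 0 is reserved for the device's default stream.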
#[derive(Debug)]
pub struct CudaStream {
stream_id: u64,
device_id: i32,
is_default: bool,
pending_operations: AtomicU32,
bandwidth_usage: AtomicU32,
num_dependencies: AtomicU32,
}
impl CudaStream {
pub fn new(device_id: i32) -> Self {
static STREAM_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
Self {
stream_id: STREAM_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
device_id,
is_default: false,
pending_operations: AtomicU32::new(0),
bandwidth_usage: AtomicU32::new(0),
num_dependencies: AtomicU32::new(0),
}
}
pub fn default(device_id: i32) -> Self {
Self {
stream_id: 0,
device_id,
is_default: true,
pending_operations: AtomicU32::new(0),
bandwidth_usage: AtomicU32::new(0),
num_dependencies: AtomicU32::new(0),
}
}
pub fn stream_id(&self) -> u64 {
self.stream_id
}
pub fn device_id(&self) -> i32 {
self.device_id
}
    pub async fn synchronize(&self) -> TorshResult<()> {
        let pending_ops = self.pending_operations.load(Ordering::Relaxed);
        if !self.is_default && pending_ops > 0 {
            // Simulated drain: ~5µs per pending operation, 10µs minimum.
            let sync_time_us = (pending_ops as u64 * 5).max(10);
            tokio::time::sleep(tokio::time::Duration::from_micros(sync_time_us)).await;
            self.pending_operations.store(0, Ordering::Relaxed);
        }
        Ok(())
    }
pub fn record_event(&self) -> CudaEvent {
CudaEvent::new(self.stream_id, self.device_id)
}
pub fn add_pending_operation(&self) {
self.pending_operations.fetch_add(1, Ordering::Relaxed);
}
    pub fn complete_operation(&self) {
        // Saturating decrement: `synchronize` may already have cleared the
        // counter, and a bare `fetch_sub` would wrap around to u32::MAX.
        let _ = self
            .pending_operations
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }
pub fn set_bandwidth_usage(&self, usage: u32) {
self.bandwidth_usage
.store(usage.min(100), Ordering::Relaxed);
}
pub fn add_dependency(&self) {
self.num_dependencies.fetch_add(1, Ordering::Relaxed);
}
    pub fn remove_dependency(&self) {
        // Saturating decrement, mirroring `complete_operation`.
        let _ = self
            .num_dependencies
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }
pub fn get_load_metrics(&self) -> (u32, u32, u32) {
(
self.pending_operations.load(Ordering::Relaxed),
self.bandwidth_usage.load(Ordering::Relaxed),
self.num_dependencies.load(Ordering::Relaxed),
)
}
}
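/// Marker recorded on a stream, used to order work across streams
/// (a simulated `cudaEvent_t`).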
#[derive(Debug, Clone)]
pub struct CudaEvent {
event_id: u64,
stream_id: u64,
#[allow(dead_code)]
device_id: i32,
recorded_at: Instant,
}
impl CudaEvent {
fn new(stream_id: u64, device_id: i32) -> Self {
static EVENT_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
Self {
event_id: EVENT_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
stream_id,
device_id,
recorded_at: Instant::now(),
}
}
    pub async fn wait_on_stream(&self, stream: &CudaStream) -> TorshResult<()> {
        if stream.stream_id != self.stream_id {
            // Simulated cross-stream wait: an event "settles" about 100µs
            // after it was recorded.
            let elapsed = self.elapsed();
            let remaining_time = if elapsed < Duration::from_micros(100) {
                Duration::from_micros(100) - elapsed
            } else {
                Duration::from_micros(1)
            };
            tokio::time::sleep(remaining_time).await;
            stream.add_dependency();
        }
        Ok(())
    }
pub fn elapsed(&self) -> Duration {
self.recorded_at.elapsed()
}
}
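/// Size-bucketed GPU memory pool (simulated).
///
/// Freed blocks are cached per exact size in `free_blocks` and reused on the
/// next allocation of the same size, avoiding repeated "device" allocations.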
pub struct GpuMemoryPool {
#[allow(dead_code)]
device_id: i32,
free_blocks: RwLock<HashMap<usize, VecDeque<GpuMemoryBlock>>>,
allocated_blocks: RwLock<HashMap<u64, GpuMemoryBlock>>,
total_allocated: std::sync::atomic::AtomicU64,
peak_allocated: std::sync::atomic::AtomicU64,
}
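/// A single pooled allocation, identified by a process-wide block id.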
#[derive(Debug, Clone)]
pub struct GpuMemoryBlock {
block_id: u64,
size: usize,
#[allow(dead_code)]
device_id: i32,
#[allow(dead_code)]
allocated_at: Instant,
}
impl GpuMemoryPool {
pub fn new(device_id: i32) -> Self {
Self {
device_id,
free_blocks: RwLock::new(HashMap::new()),
allocated_blocks: RwLock::new(HashMap::new()),
total_allocated: std::sync::atomic::AtomicU64::new(0),
peak_allocated: std::sync::atomic::AtomicU64::new(0),
}
}
    pub fn allocate(&self, size: usize) -> TorshResult<GpuMemoryBlock> {
        // Fast path: reuse a cached block of exactly the requested size.
        {
            let mut free_blocks = self
                .free_blocks
                .write()
                .expect("lock should not be poisoned");
            if let Some(blocks) = free_blocks.get_mut(&size) {
                if let Some(block) = blocks.pop_front() {
                    let mut allocated_blocks = self
                        .allocated_blocks
                        .write()
                        .expect("lock should not be poisoned");
                    allocated_blocks.insert(block.block_id, block.clone());
                    // Re-count the reused bytes: `deallocate` subtracted them
                    // when the block was returned to the free list.
                    let new_total = self
                        .total_allocated
                        .fetch_add(size as u64, std::sync::atomic::Ordering::SeqCst)
                        + size as u64;
                    let _ = self
                        .peak_allocated
                        .fetch_max(new_total, std::sync::atomic::Ordering::SeqCst);
                    return Ok(block);
                }
            }
        }
        // Simulated capacity guard: treat 8 GiB as the per-device limit.
        // Checked before minting a block id so failures don't consume ids.
        if size > 8 * 1024 * 1024 * 1024 {
            return Err(TorshDistributedError::backend_error(
                "gpu_memory",
                format!("Allocation size {} exceeds available GPU memory", size),
            ));
        }
        static BLOCK_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
        let block = GpuMemoryBlock {
            block_id: BLOCK_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
            size,
            device_id: self.device_id,
            allocated_at: Instant::now(),
        };
let mut allocated_blocks = self
.allocated_blocks
.write()
.expect("lock should not be poisoned");
allocated_blocks.insert(block.block_id, block.clone());
let new_total = self
.total_allocated
.fetch_add(size as u64, std::sync::atomic::Ordering::SeqCst)
+ size as u64;
let _ = self
.peak_allocated
.fetch_max(new_total, std::sync::atomic::Ordering::SeqCst);
Ok(block)
}
pub fn deallocate(&self, block: GpuMemoryBlock) -> TorshResult<()> {
{
let mut allocated_blocks = self
.allocated_blocks
.write()
.expect("lock should not be poisoned");
allocated_blocks.remove(&block.block_id);
}
let block_size = block.size;
{
let mut free_blocks = self
.free_blocks
.write()
.expect("lock should not be poisoned");
free_blocks.entry(block.size).or_default().push_back(block);
}
self.total_allocated
.fetch_sub(block_size as u64, std::sync::atomic::Ordering::SeqCst);
Ok(())
}
pub fn get_stats(&self) -> MemoryPoolStats {
let allocated_blocks = self
.allocated_blocks
.read()
.expect("lock should not be poisoned");
let free_blocks = self
.free_blocks
.read()
.expect("lock should not be poisoned");
let active_blocks = allocated_blocks.len();
let free_block_count: usize = free_blocks.values().map(|v| v.len()).sum();
MemoryPoolStats {
device_id: self.device_id,
active_blocks,
free_blocks: free_block_count,
total_allocated: self
.total_allocated
.load(std::sync::atomic::Ordering::SeqCst),
peak_allocated: self
.peak_allocated
.load(std::sync::atomic::Ordering::SeqCst),
}
}
}
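/// Point-in-time snapshot of a `GpuMemoryPool`.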
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
pub device_id: i32,
pub active_blocks: usize,
pub free_blocks: usize,
pub total_allocated: u64,
pub peak_allocated: u64,
}
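/// Scheduler that overlaps NCCL-style collectives with compute by spreading
/// them across multiple communication streams.
///
/// Usage sketch (mirrors `test_scheduled_all_reduce` at the bottom of this
/// file; `pg` is a previously initialized process group):
///
/// ```ignore
/// let scheduler = NcclScheduler::new(/* device_id */ 0, /* comm streams */ 2);
/// scheduler
///     .schedule_all_reduce(&mut tensor, ReduceOp::Sum, &pg, SchedulePriority::High)
///     .await?;
/// let stats = scheduler.get_performance_stats();
/// ```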
pub struct NcclScheduler {
#[allow(dead_code)]
device_id: i32,
#[allow(dead_code)]
compute_stream: CudaStream,
comm_streams: Vec<CudaStream>,
memory_pool: Arc<GpuMemoryPool>,
pending_ops: Mutex<VecDeque<ScheduledNcclOp>>,
performance_stats: Arc<Mutex<NcclPerformanceStats>>,
#[allow(dead_code)]
overlap_enabled: bool,
}
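/// Bookkeeping entry kept in the scheduler's priority-ordered pending queue.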
#[derive(Debug)]
struct ScheduledNcclOp {
#[allow(dead_code)]
op_id: u64,
#[allow(dead_code)]
op_type: NcclOpType,
#[allow(dead_code)]
tensor_size: usize,
#[allow(dead_code)]
stream_id: u64,
#[allow(dead_code)]
scheduled_at: Instant,
priority: SchedulePriority,
}
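/// The collective kinds the scheduler knows how to schedule.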
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
pub enum NcclOpType {
AllReduce(ReduceOp),
Broadcast { src_rank: u32 },
ReduceScatter(ReduceOp),
AllGather,
}
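/// Relative priority used to order the pending-operation queue.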
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[allow(dead_code)]
pub enum SchedulePriority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
impl NcclScheduler {
pub fn new(device_id: i32, num_comm_streams: usize) -> Self {
let compute_stream = CudaStream::default(device_id);
let comm_streams: Vec<_> = (0..num_comm_streams)
.map(|_| CudaStream::new(device_id))
.collect();
Self {
device_id,
compute_stream,
comm_streams,
memory_pool: Arc::new(GpuMemoryPool::new(device_id)),
pending_ops: Mutex::new(VecDeque::new()),
performance_stats: Arc::new(Mutex::new(NcclPerformanceStats::new())),
overlap_enabled: true,
}
}
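    /// Schedules an all-reduce on the least-loaded communication stream.
    ///
    /// Tensors larger than 64 MiB reserve a pooled workspace first; the
    /// elapsed time is recorded in the performance stats afterwards.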
pub async fn schedule_all_reduce<T: FloatElement>(
&self,
tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
priority: SchedulePriority,
) -> TorshResult<()> {
let start_time = Instant::now();
let tensor_size = tensor.numel() * std::mem::size_of::<T>();
let stream = self.choose_optimal_stream().await;
let _workspace = if tensor_size > 64 * 1024 * 1024 {
Some(self.memory_pool.allocate(tensor_size)?)
} else {
None
};
let op_id = self
.schedule_operation(
NcclOpType::AllReduce(op),
tensor_size,
stream.stream_id(),
priority,
)
.await;
let result = self
.execute_optimized_all_reduce(tensor, op, group, stream)
.await;
let duration = start_time.elapsed();
self.record_performance(op_id, NcclOpType::AllReduce(op), tensor_size, duration)
.await;
result
}
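    /// Schedules a broadcast from `src_rank` on the least-loaded stream and
    /// records its timing.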
pub async fn schedule_broadcast<T: FloatElement>(
&self,
tensor: &mut Tensor<T>,
src_rank: u32,
group: &ProcessGroup,
priority: SchedulePriority,
) -> TorshResult<()> {
let start_time = Instant::now();
let tensor_size = tensor.numel() * std::mem::size_of::<T>();
let stream = self.choose_optimal_stream().await;
let op_id = self
.schedule_operation(
NcclOpType::Broadcast { src_rank },
tensor_size,
stream.stream_id(),
priority,
)
.await;
let result = self
.execute_optimized_broadcast(tensor, src_rank, group, stream)
.await;
let duration = start_time.elapsed();
self.record_performance(
op_id,
NcclOpType::Broadcast { src_rank },
tensor_size,
duration,
)
.await;
result
}
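    /// Executes a set of operations as fused batches of up to 8, amortizing
    /// launch overhead (the simulated analogue of an NCCL group).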
pub async fn execute_fused_operations(
&self,
operations: Vec<FusedNcclOp>,
group: &ProcessGroup,
) -> TorshResult<()> {
if operations.is_empty() {
return Ok(());
}
info!(
"🔥 NCCL Fusion: Executing {} fused operations",
operations.len()
);
let start_time = Instant::now();
let stream = self.choose_optimal_stream().await;
        // Despite the name, no reordering is performed here; operations run
        // in submission order.
        let sorted_ops = operations;
for batch in sorted_ops.chunks(8) {
self.execute_fused_batch(batch, group, stream).await?;
}
let duration = start_time.elapsed();
let total_size: usize = sorted_ops.iter().map(|op| op.tensor_size).sum();
{
let mut stats = self
.performance_stats
.lock()
.expect("lock should not be poisoned");
stats.record_fusion(sorted_ops.len(), total_size, duration);
}
info!(" Fusion completed in {:?}", duration);
Ok(())
}
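    /// Picks the communication stream with the lowest weighted load score:
    /// 40% queued operations, 30% bandwidth usage, 30% dependencies.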
    async fn choose_optimal_stream(&self) -> &CudaStream {
        // Assumes at least one communication stream was created in `new`.
        debug_assert!(
            !self.comm_streams.is_empty(),
            "NcclScheduler requires at least one communication stream"
        );
        let mut best_stream_idx = 0;
        let mut best_score = f64::INFINITY;
        for (idx, stream) in self.comm_streams.iter().enumerate() {
            let load_score = stream.pending_operations.load(Ordering::Relaxed) as f64;
            let bandwidth_score = stream.bandwidth_usage.load(Ordering::Relaxed) as f64 / 100.0;
            let dependency_score = stream.num_dependencies.load(Ordering::Relaxed) as f64;
            let total_score = load_score * 0.4 + bandwidth_score * 0.3 + dependency_score * 0.3;
            if total_score < best_score {
                best_score = total_score;
                best_stream_idx = idx;
            }
        }
        &self.comm_streams[best_stream_idx]
    }
async fn schedule_operation(
&self,
op_type: NcclOpType,
tensor_size: usize,
stream_id: u64,
priority: SchedulePriority,
) -> u64 {
static OP_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
let op_id = OP_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
let scheduled_op = ScheduledNcclOp {
op_id,
op_type,
tensor_size,
stream_id,
scheduled_at: Instant::now(),
priority,
};
let mut pending_ops = self
.pending_ops
.lock()
.expect("lock should not be poisoned");
        // Keep the queue in descending priority order: insert ahead of the
        // first operation with strictly lower priority.
        let insert_pos = pending_ops
            .iter()
            .position(|op| op.priority < priority)
            .unwrap_or(pending_ops.len());
        pending_ops.insert(insert_pos, scheduled_op);
op_id
}
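    /// Dispatches to an algorithm by payload size: ring for > 2 MiB,
    /// tree for < 4 KiB, hierarchical for everything in between.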
async fn execute_optimized_all_reduce<T: FloatElement>(
&self,
tensor: &mut Tensor<T>,
op: ReduceOp,
group: &ProcessGroup,
stream: &CudaStream,
) -> TorshResult<()> {
let tensor_size = tensor.numel() * std::mem::size_of::<T>();
if tensor_size > 2 * 1024 * 1024 {
self.execute_ring_all_reduce(tensor, op, group, stream)
.await
} else if tensor_size < 4 * 1024 {
self.execute_tree_all_reduce(tensor, op, group, stream)
.await
} else {
self.execute_hierarchical_all_reduce(tensor, op, group, stream)
.await
}
}
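    /// Simulated ring all-reduce: a reduce-scatter pass followed by an
    /// all-gather pass, each taking `world_size - 1` steps.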
async fn execute_ring_all_reduce<T: FloatElement>(
&self,
tensor: &mut Tensor<T>,
_op: ReduceOp,
_group: &ProcessGroup,
stream: &CudaStream,
) -> TorshResult<()> {
info!(
" NCCL Ring All-Reduce: {} elements on stream {} (device {})",
tensor.numel(),
stream.stream_id(),
stream.device_id()
);
stream.add_pending_operation();
let event = stream.record_event();
        // Placeholder topology: a fixed 4-rank ring (the real size would
        // come from the process group).
        let world_size = 4;
        let chunk_size = tensor.numel() / world_size;
        // Reduce-scatter phase: transfer plus local reduction per step.
        for _step in 0..world_size - 1 {
            let transfer_size = chunk_size * std::mem::size_of::<T>();
            let transfer_time_ns = (transfer_size as f64 * 0.01).max(100.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(transfer_time_ns as u64)).await;
            let compute_time_ns = (chunk_size as f64 * 5.0).max(50.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(compute_time_ns as u64)).await;
        }
        // All-gather phase: pure transfers, no reduction.
        for _step in 0..world_size - 1 {
            let transfer_size = chunk_size * std::mem::size_of::<T>();
            let transfer_time_ns = (transfer_size as f64 * 0.008).max(80.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(transfer_time_ns as u64)).await;
        }
stream.synchronize().await?;
stream.complete_operation();
info!(" Ring All-Reduce completed (event: {})", event.event_id);
Ok(())
}
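    /// Simulated tree all-reduce: reduce up a binary tree, then broadcast
    /// the result back down. Step count scales with `log2(world_size)`.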
async fn execute_tree_all_reduce<T: FloatElement>(
&self,
tensor: &mut Tensor<T>,
_op: ReduceOp,
_group: &ProcessGroup,
stream: &CudaStream,
) -> TorshResult<()> {
info!(
"🌳 NCCL Tree All-Reduce: {} elements on stream {} (device {})",
tensor.numel(),
stream.stream_id(),
stream.device_id()
);
stream.add_pending_operation();
let event = stream.record_event();
        // Placeholder topology: fixed 4-rank binary tree.
        let world_size = 4;
        let tree_depth = (world_size as f32).log2().ceil() as u32;
        // Reduce phase: combine values up the tree, halving participants
        // at each level.
        for level in 0..tree_depth {
            let participants = world_size >> level;
            if participants <= 1 {
                break;
            }
            let transfer_size = tensor.numel() * std::mem::size_of::<T>();
            let transfer_time_ns = (transfer_size as f64 * 0.02).max(50.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(transfer_time_ns as u64)).await;
            let compute_time_ns = (tensor.numel() as f64 * 3.0).max(30.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(compute_time_ns as u64)).await;
        }
        // Broadcast phase: push the reduced result back down the tree.
        for level in (0..tree_depth).rev() {
            let participants = world_size >> level;
            if participants <= 1 {
                continue;
            }
            let transfer_size = tensor.numel() * std::mem::size_of::<T>();
            let transfer_time_ns = (transfer_size as f64 * 0.015).max(40.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(transfer_time_ns as u64)).await;
        }
stream.synchronize().await?;
stream.complete_operation();
info!(" Tree All-Reduce completed (event: {})", event.event_id);
Ok(())
}
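    /// Simulated hierarchical all-reduce: intra-node reduce over NVLink,
    /// inter-node all-reduce over the network, then intra-node broadcast.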
async fn execute_hierarchical_all_reduce<T: FloatElement>(
&self,
tensor: &mut Tensor<T>,
_op: ReduceOp,
_group: &ProcessGroup,
stream: &CudaStream,
) -> TorshResult<()> {
info!(
"🏗️ NCCL Hierarchical All-Reduce: {} elements on stream {} (device {})",
tensor.numel(),
stream.stream_id(),
stream.device_id()
);
stream.add_pending_operation();
let event = stream.record_event();
        // Placeholder topology: 4 nodes with 8 GPUs each.
        let gpus_per_node = 8;
        let num_nodes = 4;
        let tensor_size_bytes = tensor.numel() * std::mem::size_of::<T>();
        // Phase 1: intra-node reduce over fast NVLink links.
        let intra_node_steps = (gpus_per_node as f32).log2().ceil() as u32;
        for _step in 0..intra_node_steps {
            let nvlink_time_ns = (tensor_size_bytes as f64 * 0.002).max(20.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(nvlink_time_ns as u64)).await;
            let compute_time_ns = (tensor.numel() as f64 * 2.0).max(10.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(compute_time_ns as u64)).await;
        }
        // Phase 2: inter-node all-reduce over the (slower) network.
        let inter_node_steps = (num_nodes as f32).log2().ceil() as u32;
        for _step in 0..inter_node_steps {
            let network_time_ns = (tensor_size_bytes as f64 * 0.05).max(200.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(network_time_ns as u64)).await;
            let compute_time_ns = (tensor.numel() as f64 * 1.0).max(5.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(compute_time_ns as u64)).await;
        }
        // Phase 3: intra-node broadcast of the final result.
        for _step in 0..intra_node_steps {
            let nvlink_time_ns = (tensor_size_bytes as f64 * 0.001).max(15.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(nvlink_time_ns as u64)).await;
        }
stream.synchronize().await?;
stream.complete_operation();
info!(
" Hierarchical All-Reduce completed (event: {})",
event.event_id
);
Ok(())
}
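    /// Simulated broadcast: payloads over 64 KiB are split into chunks and
    /// pipelined down the tree; smaller payloads use a plain tree broadcast.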
async fn execute_optimized_broadcast<T: FloatElement>(
&self,
tensor: &mut Tensor<T>,
src_rank: u32,
_group: &ProcessGroup,
stream: &CudaStream,
) -> TorshResult<()> {
info!(
" NCCL Optimized Broadcast: {} elements from rank {} on stream {} (device {})",
tensor.numel(),
src_rank,
stream.stream_id(),
stream.device_id()
);
stream.add_pending_operation();
let event = stream.record_event();
let tensor_size_bytes = tensor.numel() * std::mem::size_of::<T>();
        // Placeholder topology: fixed 4-rank tree.
        let world_size = 4;
        if tensor_size_bytes > 64 * 1024 {
            // Large payloads: split into chunks and pipeline them down the
            // tree so later chunks overlap earlier ones.
            let num_chunks = 8;
            let chunk_size = tensor.numel().div_ceil(num_chunks);
            let chunk_size_bytes = chunk_size * std::mem::size_of::<T>();
            let tree_depth = (world_size as f32).log2().ceil() as u32;
            for chunk in 0..num_chunks {
                for level in 0..tree_depth {
                    let participants = 1 << (level + 1);
                    if participants > world_size {
                        break;
                    }
                    let transfer_time_ns = (chunk_size_bytes as f64 * 0.01).max(30.0);
                    if chunk > 0 {
                        // Pipelining hides ~70% of each later chunk's transfer.
                        let overlap_factor = 0.7;
                        let effective_time = transfer_time_ns * (1.0 - overlap_factor);
                        tokio::time::sleep(tokio::time::Duration::from_nanos(
                            effective_time as u64,
                        ))
                        .await;
                    } else {
                        tokio::time::sleep(tokio::time::Duration::from_nanos(
                            transfer_time_ns as u64,
                        ))
                        .await;
                    }
                }
            }
        } else {
            // Small payloads: plain tree broadcast, one hop per level.
            let tree_depth = (world_size as f32).log2().ceil() as u32;
            for _level in 0..tree_depth {
                let transfer_time_ns = (tensor_size_bytes as f64 * 0.008).max(20.0);
                tokio::time::sleep(tokio::time::Duration::from_nanos(transfer_time_ns as u64))
                    .await;
            }
        }
stream.synchronize().await?;
stream.complete_operation();
info!(
" Optimized Broadcast completed (event: {})",
event.event_id
);
Ok(())
}
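    /// Splits a fused batch by collective kind and runs each group together.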
async fn execute_fused_batch(
&self,
batch: &[FusedNcclOp],
_group: &ProcessGroup,
stream: &CudaStream,
) -> TorshResult<()> {
info!(
"⚡ Executing fused batch of {} operations on stream {}",
batch.len(),
stream.stream_id()
);
let mut all_reduce_ops = Vec::new();
let mut broadcast_ops = Vec::new();
        for op in batch {
            match op.op_type {
                NcclOpType::AllReduce(_) => all_reduce_ops.push(op),
                NcclOpType::Broadcast { .. } => broadcast_ops.push(op),
                // ReduceScatter/AllGather are not fused yet; skip them.
                _ => {}
            }
        }
if !all_reduce_ops.is_empty() {
self.execute_fused_all_reduce(&all_reduce_ops, stream)
.await?;
}
if !broadcast_ops.is_empty() {
self.execute_fused_broadcast(&broadcast_ops, stream).await?;
}
Ok(())
}
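    /// Fuses all-reduces inside one group, bucketing by size: < 64 KiB
    /// (small), 64 KiB to 1 MiB (medium), and >= 1 MiB, run individually.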
async fn execute_fused_all_reduce(
&self,
ops: &[&FusedNcclOp],
stream: &CudaStream,
) -> TorshResult<()> {
let total_size: usize = ops.iter().map(|op| op.tensor_size).sum();
info!(
" 🔥 Fused All-Reduce: {} operations, {} total bytes on stream {}",
ops.len(),
total_size,
stream.stream_id()
);
stream.add_pending_operation();
let event = stream.record_event();
info!(" Starting fused group: ncclGroupStart() equivalent");
let mut sorted_ops: Vec<_> = ops.iter().collect();
sorted_ops.sort_by_key(|op| op.tensor_size);
let small_ops: Vec<_> = sorted_ops
.iter()
.filter(|op| op.tensor_size < 64 * 1024)
.collect();
let medium_ops: Vec<_> = sorted_ops
.iter()
.filter(|op| op.tensor_size >= 64 * 1024 && op.tensor_size < 1024 * 1024)
.collect();
let large_ops: Vec<_> = sorted_ops
.iter()
.filter(|op| op.tensor_size >= 1024 * 1024)
.collect();
        if !small_ops.is_empty() {
            // Small tensors benefit most from batching (~80% efficiency).
            let small_total_size: usize = small_ops.iter().map(|op| op.tensor_size).sum();
            let fusion_efficiency = 0.8;
            let effective_time = (small_total_size as f64 * 3.0 * fusion_efficiency).max(100.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(effective_time as u64)).await;
            info!(" ⚡ Fused {} small operations", small_ops.len());
        }
        if !medium_ops.is_empty() {
            // Medium tensors still fuse, with a smaller benefit (~60%).
            let medium_total_size: usize = medium_ops.iter().map(|op| op.tensor_size).sum();
            let fusion_efficiency = 0.6;
            let effective_time = (medium_total_size as f64 * 4.0 * fusion_efficiency).max(200.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(effective_time as u64)).await;
            info!(" ⚡ Fused {} medium operations", medium_ops.len());
        }
        // Large tensors gain little from fusion; run them back to back.
        for (i, op) in large_ops.iter().enumerate() {
            let individual_time = (op.tensor_size as f64 * 5.0).max(500.0);
            tokio::time::sleep(tokio::time::Duration::from_nanos(individual_time as u64)).await;
            info!(" 📦 Large operation {}/{}", i + 1, large_ops.len());
        }
info!(" 🏁 Ending fused group: ncclGroupEnd() equivalent");
stream.synchronize().await?;
stream.complete_operation();
info!(
" Fused All-Reduce completed (event: {})",
event.event_id
);
Ok(())
}
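    /// Fuses broadcasts into a single simulated transfer.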
async fn execute_fused_broadcast(
&self,
ops: &[&FusedNcclOp],
stream: &CudaStream,
) -> TorshResult<()> {
let total_size: usize = ops.iter().map(|op| op.tensor_size).sum();
info!(
" Fused Broadcast: {} operations, {} total bytes on stream {}",
ops.len(),
total_size,
stream.stream_id()
);
        let event = stream.record_event();
        // Coarse model: one sleep proportional to the combined payload.
        tokio::time::sleep(tokio::time::Duration::from_nanos(total_size as u64 * 3)).await;
info!(
" Fused Broadcast completed (event: {})",
event.event_id
);
Ok(())
}
async fn record_performance(
&self,
op_id: u64,
op_type: NcclOpType,
tensor_size: usize,
duration: Duration,
) {
let mut stats = self
.performance_stats
.lock()
.expect("lock should not be poisoned");
stats.record_operation(op_id, op_type, tensor_size, duration);
}
pub fn get_performance_stats(&self) -> NcclPerformanceStats {
self.performance_stats
.lock()
.expect("lock should not be poisoned")
.clone()
}
pub fn get_memory_stats(&self) -> MemoryPoolStats {
self.memory_pool.get_stats()
}
}
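/// Description of one operation submitted to `execute_fused_operations`.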
#[derive(Debug)]
pub struct FusedNcclOp {
pub op_type: NcclOpType,
pub tensor_size: usize,
pub priority: SchedulePriority,
}
impl FusedNcclOp {
pub fn all_reduce(tensor_size: usize, op: ReduceOp, priority: SchedulePriority) -> Self {
Self {
op_type: NcclOpType::AllReduce(op),
tensor_size,
priority,
}
}
pub fn broadcast(tensor_size: usize, src_rank: u32, priority: SchedulePriority) -> Self {
Self {
op_type: NcclOpType::Broadcast { src_rank },
tensor_size,
priority,
}
}
}
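/// Aggregate counters for all scheduled collectives.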
#[derive(Debug, Clone)]
pub struct NcclPerformanceStats {
pub total_operations: u64,
pub total_bytes_transferred: u64,
pub total_duration: Duration,
pub average_bandwidth_gbps: f64,
pub operation_breakdown: HashMap<String, OperationStats>,
pub fusion_stats: FusionStats,
}
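/// Per-operation-type breakdown of the aggregate stats.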
#[derive(Debug, Clone)]
pub struct OperationStats {
pub count: u64,
pub total_bytes: u64,
pub total_duration: Duration,
pub min_duration: Duration,
pub max_duration: Duration,
pub average_bandwidth_gbps: f64,
}
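/// Counters for fused-batch execution.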
#[derive(Debug, Clone)]
pub struct FusionStats {
pub total_fused_batches: u64,
pub total_fused_operations: u64,
    /// Percentage of all recorded operations that ran inside a fused batch.
    pub fusion_efficiency: f64,
    pub average_batch_size: f64,
}
impl NcclPerformanceStats {
fn new() -> Self {
Self {
total_operations: 0,
total_bytes_transferred: 0,
total_duration: Duration::ZERO,
average_bandwidth_gbps: 0.0,
operation_breakdown: HashMap::new(),
fusion_stats: FusionStats {
total_fused_batches: 0,
total_fused_operations: 0,
fusion_efficiency: 0.0,
average_batch_size: 0.0,
},
}
}
fn record_operation(
&mut self,
_op_id: u64,
op_type: NcclOpType,
tensor_size: usize,
duration: Duration,
) {
self.total_operations += 1;
self.total_bytes_transferred += tensor_size as u64;
self.total_duration += duration;
if !self.total_duration.is_zero() {
let total_gb = self.total_bytes_transferred as f64 / (1024.0 * 1024.0 * 1024.0);
let total_seconds = self.total_duration.as_secs_f64();
self.average_bandwidth_gbps = total_gb / total_seconds;
}
let op_name = format!("{:?}", op_type);
let entry = self
.operation_breakdown
.entry(op_name)
.or_insert_with(|| OperationStats {
count: 0,
total_bytes: 0,
total_duration: Duration::ZERO,
min_duration: Duration::MAX,
max_duration: Duration::ZERO,
average_bandwidth_gbps: 0.0,
});
entry.count += 1;
entry.total_bytes += tensor_size as u64;
entry.total_duration += duration;
entry.min_duration = entry.min_duration.min(duration);
entry.max_duration = entry.max_duration.max(duration);
if !entry.total_duration.is_zero() {
let op_gb = entry.total_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
let op_seconds = entry.total_duration.as_secs_f64();
entry.average_bandwidth_gbps = op_gb / op_seconds;
}
}
fn record_fusion(&mut self, num_operations: usize, total_size: usize, duration: Duration) {
self.fusion_stats.total_fused_batches += 1;
self.fusion_stats.total_fused_operations += num_operations as u64;
if self.fusion_stats.total_fused_batches > 0 {
self.fusion_stats.average_batch_size = self.fusion_stats.total_fused_operations as f64
/ self.fusion_stats.total_fused_batches as f64;
}
if self.total_operations > 0 {
self.fusion_stats.fusion_efficiency = (self.fusion_stats.total_fused_operations as f64
/ self.total_operations as f64)
* 100.0;
}
        // Approximate attribution: spread the batch evenly across its member
        // operations, recording each uniformly as AllReduce(Sum).
        for _ in 0..num_operations {
            self.record_operation(
                0,
                NcclOpType::AllReduce(ReduceOp::Sum),
                total_size / num_operations,
                duration / num_operations as u32,
            );
        }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{init_process_group, BackendType};
#[tokio::test]
async fn test_nccl_scheduler() {
let scheduler = NcclScheduler::new(0, 4);
assert_eq!(scheduler.device_id, 0);
assert_eq!(scheduler.comm_streams.len(), 4);
}
#[tokio::test]
async fn test_cuda_stream() {
let stream = CudaStream::new(0);
assert_eq!(stream.device_id(), 0);
assert!(!stream.is_default);
let result = stream.synchronize().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn test_memory_pool() {
let pool = GpuMemoryPool::new(0);
let block1 = pool.allocate(1024).unwrap();
let block2 = pool.allocate(2048).unwrap();
assert_eq!(block1.size, 1024);
assert_eq!(block2.size, 2048);
assert_eq!(block1.device_id, 0);
assert_eq!(block2.device_id, 0);
let stats_before = pool.get_stats();
assert_eq!(stats_before.active_blocks, 2);
pool.deallocate(block1).unwrap();
let stats_after = pool.get_stats();
assert_eq!(stats_after.active_blocks, 1);
assert_eq!(stats_after.free_blocks, 1);
}
#[tokio::test]
async fn test_scheduled_all_reduce() {
let pg = init_process_group(BackendType::Nccl, 0, 1, "127.0.0.1", 29500)
.await
.unwrap();
let scheduler = NcclScheduler::new(0, 2);
let mut tensor: Tensor<f32> = Tensor::from_vec(vec![1.0; 1000], &[1000]).unwrap();
let result = scheduler
.schedule_all_reduce(&mut tensor, ReduceOp::Sum, &pg, SchedulePriority::High)
.await;
assert!(result.is_ok());
let stats = scheduler.get_performance_stats();
assert!(stats.total_operations > 0);
}
#[tokio::test]
async fn test_fused_operations() {
let pg = init_process_group(BackendType::Nccl, 0, 1, "127.0.0.1", 29500)
.await
.unwrap();
let scheduler = NcclScheduler::new(0, 2);
let fused_ops = vec![
FusedNcclOp::all_reduce(1024, ReduceOp::Sum, SchedulePriority::Normal),
FusedNcclOp::broadcast(2048, 0, SchedulePriority::Normal),
FusedNcclOp::all_reduce(512, ReduceOp::Max, SchedulePriority::High),
];
let result = scheduler.execute_fused_operations(fused_ops, &pg).await;
assert!(result.is_ok());
let stats = scheduler.get_performance_stats();
assert!(stats.fusion_stats.total_fused_batches > 0);
}
}