#![allow(dead_code)]
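//! Gradient synchronization for 3D (tensor / pipeline / data) parallelism.
//!
//! Gradients are reduced along each parallel dimension in turn, with optional
//! bucketing and compression before the data-parallel all-reduce. The
//! collective operations in this module are simulated with short sleeps and
//! stand in for real communication backends.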
use crate::TorshResult;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use torsh_tensor::Tensor;
use super::{
config::{CommunicationStrategy, RankMapping, ThreeDParallelismConfig},
model_shards::ModelShards,
};
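/// Coordinates gradient synchronization across the tensor-, pipeline-, and
/// data-parallel dimensions of a 3D-parallel training job.
///
/// A minimal sketch of the intended call pattern (assumes the configuration,
/// rank mapping, and model shards are built elsewhere):
///
/// ```ignore
/// let synchronizer = GradientSynchronizer::new(&config, &rank_mapping)?;
/// synchronizer.synchronize_gradients(&model_shards).await?;
/// let stats = synchronizer.get_sync_stats();
/// ```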
pub struct GradientSynchronizer {
config: ThreeDParallelismConfig,
rank_mapping: RankMapping,
gradient_buffers: Arc<Mutex<HashMap<String, GradientBuffer>>>,
sync_stats: Arc<Mutex<SyncStatistics>>,
compression_config: GradientCompressionConfig,
bucket_config: GradientBucketingConfig,
}
impl GradientSynchronizer {
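    /// Creates a synchronizer with default compression (8-bit quantization with
    /// error feedback) and bucketing (25 MB buckets, overlapped communication).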
pub fn new(config: &ThreeDParallelismConfig, rank_mapping: &RankMapping) -> TorshResult<Self> {
let gradient_buffers = Arc::new(Mutex::new(HashMap::new()));
let sync_stats = Arc::new(Mutex::new(SyncStatistics::new()));
let compression_config = GradientCompressionConfig {
enable_compression: true,
compression_ratio: 0.1,
error_feedback: true,
quantization_bits: 8,
};
let bucket_config = GradientBucketingConfig {
bucket_size_mb: 25.0,
max_buckets: 16,
overlap_communication: true,
};
Ok(Self {
config: config.clone(),
rank_mapping: rank_mapping.clone(),
gradient_buffers,
sync_stats,
compression_config,
bucket_config,
})
}
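    /// Synchronizes gradients along all three parallel dimensions in order:
    /// tensor parallel, pipeline parallel, then data parallel. Total timing is
    /// recorded in the shared [`SyncStatistics`].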
pub async fn synchronize_gradients(&self, model_shards: &ModelShards) -> TorshResult<()> {
let start_time = Instant::now();
self.synchronize_tensor_parallel_gradients(model_shards)
.await?;
self.synchronize_pipeline_parallel_gradients(model_shards)
.await?;
self.synchronize_data_parallel_gradients(model_shards)
.await?;
let mut stats = self.sync_stats.lock().expect("lock should not be poisoned");
stats.total_sync_operations += 1;
stats.total_sync_time += start_time.elapsed();
Ok(())
}
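    /// All-reduces weight and bias gradients within the tensor-parallel group
    /// for the pipeline stage owned by this rank. No-op when `tp_size <= 1`.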
async fn synchronize_tensor_parallel_gradients(
&self,
model_shards: &ModelShards,
) -> TorshResult<()> {
if self.config.tp_size <= 1 {
return Ok(());
}
let _tp_ranks = self.rank_mapping.tp_group_ranks();
for (stage_idx, stage_layers) in model_shards.pipeline_stages.iter().enumerate() {
if stage_idx != self.rank_mapping.pp_rank {
                continue;
            }

for layer_shard in stage_layers {
if let Some(ref grad_weight) = layer_shard.grad_weight {
self.all_reduce_tensor_parallel(grad_weight).await?;
}
if let Some(ref grad_bias) = layer_shard.grad_bias {
self.all_reduce_tensor_parallel(grad_bias).await?;
}
}
}
Ok(())
}
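    /// Synchronizes gradients shared across pipeline stages; currently only
    /// the shared embedding weights on the first and last stages need this.
    /// No-op when `pp_size <= 1`.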
async fn synchronize_pipeline_parallel_gradients(
&self,
model_shards: &ModelShards,
) -> TorshResult<()> {
if self.config.pp_size <= 1 {
return Ok(());
}
if self.rank_mapping.is_pp_head() || self.rank_mapping.is_pp_tail() {
self.synchronize_shared_embeddings(model_shards).await?;
}
Ok(())
}
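    /// Buckets this rank's gradients and all-reduces each bucket across the
    /// data-parallel group, optionally compressing them first. No-op when
    /// `dp_size <= 1`.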
async fn synchronize_data_parallel_gradients(
&self,
model_shards: &ModelShards,
) -> TorshResult<()> {
if self.config.dp_size <= 1 {
return Ok(());
}
let gradient_buckets = self.create_gradient_buckets(model_shards).await?;
for bucket in gradient_buckets {
self.synchronize_gradient_bucket(&bucket).await?;
}
Ok(())
}
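    /// Dispatches a tensor-parallel all-reduce to the configured communication
    /// strategy. The adaptive strategy uses a tree all-reduce for gradients
    /// under 1 MiB and a ring all-reduce otherwise.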
async fn all_reduce_tensor_parallel(&self, gradient: &Tensor<f32>) -> TorshResult<()> {
match self.config.comm_strategy {
CommunicationStrategy::AllReduce => self.standard_all_reduce_tp(gradient).await,
CommunicationStrategy::HierarchicalAllReduce => {
self.hierarchical_all_reduce_tp(gradient).await
}
CommunicationStrategy::RingAllReduce => self.ring_all_reduce_tp(gradient).await,
CommunicationStrategy::TreeAllReduce => self.tree_all_reduce_tp(gradient).await,
CommunicationStrategy::Adaptive => {
let gradient_size = gradient.numel() * std::mem::size_of::<f32>();
if gradient_size < 1024 * 1024 {
self.tree_all_reduce_tp(gradient).await
} else {
self.ring_all_reduce_tp(gradient).await
}
}
}
}
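    // The all-reduce variants below are placeholders: they model relative
    // communication latencies with short sleeps instead of performing real
    // collective operations.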
async fn standard_all_reduce_tp(&self, _gradient: &Tensor<f32>) -> TorshResult<()> {
tokio::time::sleep(Duration::from_micros(50)).await;
Ok(())
}
async fn hierarchical_all_reduce_tp(&self, _gradient: &Tensor<f32>) -> TorshResult<()> {
tokio::time::sleep(Duration::from_micros(40)).await;
Ok(())
}
async fn ring_all_reduce_tp(&self, _gradient: &Tensor<f32>) -> TorshResult<()> {
tokio::time::sleep(Duration::from_micros(60)).await;
Ok(())
}
async fn tree_all_reduce_tp(&self, _gradient: &Tensor<f32>) -> TorshResult<()> {
tokio::time::sleep(Duration::from_micros(30)).await;
Ok(())
}
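    /// All-reduces embedding-layer weight gradients across pipeline-parallel
    /// ranks so that shared embedding weights stay consistent.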
async fn synchronize_shared_embeddings(&self, model_shards: &ModelShards) -> TorshResult<()> {
for stage_layers in model_shards.pipeline_stages.iter() {
for layer_shard in stage_layers {
if matches!(
layer_shard.layer_type,
super::model_shards::LayerType::Embedding
) {
if let Some(ref grad_weight) = layer_shard.grad_weight {
self.all_reduce_pipeline_parallel(grad_weight).await?;
}
}
}
}
Ok(())
}
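    /// Placeholder pipeline-parallel all-reduce; sleeps to model communication
    /// cost.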
async fn all_reduce_pipeline_parallel(&self, _gradient: &Tensor<f32>) -> TorshResult<()> {
tokio::time::sleep(Duration::from_micros(100)).await;
Ok(())
}
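    /// Groups this rank's weight and bias gradients into buckets of at most
    /// `bucket_size_mb`. Note that the configured `max_buckets` limit is not
    /// enforced here.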
async fn create_gradient_buckets(
&self,
model_shards: &ModelShards,
) -> TorshResult<Vec<GradientBucket>> {
let mut buckets = Vec::new();
let mut current_bucket = GradientBucket::new();
let bucket_size_bytes = (self.bucket_config.bucket_size_mb * 1024.0 * 1024.0) as usize;
for (stage_idx, stage_layers) in model_shards.pipeline_stages.iter().enumerate() {
if stage_idx != self.rank_mapping.pp_rank {
                continue;
            }
for layer_shard in stage_layers {
if let Some(ref grad_weight) = layer_shard.grad_weight {
let gradient_size = grad_weight.numel() * std::mem::size_of::<f32>();
if current_bucket.size_bytes + gradient_size > bucket_size_bytes
&& !current_bucket.gradients.is_empty()
{
buckets.push(current_bucket);
current_bucket = GradientBucket::new();
}
current_bucket.add_gradient(
format!("layer_{}_weight", layer_shard.layer_id),
grad_weight.clone(),
);
}
if let Some(ref grad_bias) = layer_shard.grad_bias {
let gradient_size = grad_bias.numel() * std::mem::size_of::<f32>();
if current_bucket.size_bytes + gradient_size > bucket_size_bytes
&& !current_bucket.gradients.is_empty()
{
buckets.push(current_bucket);
current_bucket = GradientBucket::new();
}
current_bucket.add_gradient(
format!("layer_{}_bias", layer_shard.layer_id),
grad_bias.clone(),
);
}
}
}
if !current_bucket.gradients.is_empty() {
buckets.push(current_bucket);
}
Ok(buckets)
}
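    /// All-reduces one bucket across the data-parallel group. When compression
    /// is enabled, gradients are quantized before the reduce and dequantized
    /// afterwards; in this simulated path the dequantized values are not
    /// written back.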
async fn synchronize_gradient_bucket(&self, bucket: &GradientBucket) -> TorshResult<()> {
let start_time = Instant::now();
let compressed_gradients = if self.compression_config.enable_compression {
self.compress_gradients(&bucket.gradients).await?
} else {
bucket.gradients.clone()
};
for gradient in compressed_gradients.values() {
self.all_reduce_data_parallel(gradient).await?;
}
if self.compression_config.enable_compression {
self.decompress_gradients(&compressed_gradients).await?;
}
let mut stats = self.sync_stats.lock().expect("lock should not be poisoned");
stats.total_buckets_synced += 1;
stats.total_bucket_sync_time += start_time.elapsed();
Ok(())
}
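    /// Placeholder data-parallel all-reduce; sleeps for a strategy-dependent
    /// duration to model relative communication cost.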
async fn all_reduce_data_parallel(&self, _gradient: &Tensor<f32>) -> TorshResult<()> {
match self.config.comm_strategy {
CommunicationStrategy::AllReduce => {
tokio::time::sleep(Duration::from_micros(80)).await;
}
CommunicationStrategy::HierarchicalAllReduce => {
tokio::time::sleep(Duration::from_micros(60)).await;
}
CommunicationStrategy::RingAllReduce => {
tokio::time::sleep(Duration::from_micros(100)).await;
}
CommunicationStrategy::TreeAllReduce => {
tokio::time::sleep(Duration::from_micros(50)).await;
}
CommunicationStrategy::Adaptive => {
tokio::time::sleep(Duration::from_micros(70)).await;
}
}
Ok(())
}
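    /// Quantizes each gradient in the map. The quantization itself is
    /// simulated and currently returns an unmodified clone.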
async fn compress_gradients(
&self,
gradients: &HashMap<String, Tensor<f32>>,
) -> TorshResult<HashMap<String, Tensor<f32>>> {
let mut compressed = HashMap::new();
for (name, gradient) in gradients {
let compressed_gradient = self.quantize_gradient(gradient).await?;
compressed.insert(name.clone(), compressed_gradient);
}
Ok(compressed)
}
async fn decompress_gradients(
&self,
gradients: &HashMap<String, Tensor<f32>>,
) -> TorshResult<()> {
for gradient in gradients.values() {
self.dequantize_gradient(gradient).await?;
}
Ok(())
}
async fn quantize_gradient(&self, gradient: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
tokio::time::sleep(Duration::from_micros(10)).await;
        Ok(gradient.clone())
    }
async fn dequantize_gradient(&self, _gradient: &Tensor<f32>) -> TorshResult<()> {
tokio::time::sleep(Duration::from_micros(5)).await;
        Ok(())
    }
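    /// Returns a snapshot of the accumulated synchronization statistics.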
pub fn get_sync_stats(&self) -> SyncStatistics {
self.sync_stats
.lock()
.expect("lock should not be poisoned")
.clone()
}
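    /// Replaces the compression settings used for subsequent synchronizations.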
pub fn update_compression_config(&mut self, config: GradientCompressionConfig) {
self.compression_config = config;
}
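    /// Drops any accumulated gradient buffers.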
pub fn clear_buffers(&self) {
let mut buffers = self
.gradient_buffers
.lock()
.expect("lock should not be poisoned");
buffers.clear();
}
}
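/// A group of gradients that are reduced together in one collective call.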
#[derive(Debug, Clone)]
struct GradientBucket {
gradients: HashMap<String, Tensor<f32>>,
size_bytes: usize,
}
impl GradientBucket {
fn new() -> Self {
Self {
gradients: HashMap::new(),
size_bytes: 0,
}
}
fn add_gradient(&mut self, name: String, gradient: Tensor<f32>) {
let gradient_size = gradient.numel() * std::mem::size_of::<f32>();
self.size_bytes += gradient_size;
self.gradients.insert(name, gradient);
}
}
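/// Per-parameter accumulated gradient state; reserved for gradient
/// accumulation and not yet used by the synchronizer.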
#[derive(Debug, Clone)]
struct GradientBuffer {
accumulated_gradient: Tensor<f32>,
accumulation_count: usize,
}
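/// Counters and timings collected across synchronization calls.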
#[derive(Debug, Clone)]
pub struct SyncStatistics {
pub total_sync_operations: u64,
pub total_sync_time: Duration,
pub total_buckets_synced: u64,
pub total_bucket_sync_time: Duration,
pub average_sync_time_ms: f64,
pub communication_efficiency: f64,
}
impl SyncStatistics {
fn new() -> Self {
Self {
total_sync_operations: 0,
total_sync_time: Duration::ZERO,
total_buckets_synced: 0,
total_bucket_sync_time: Duration::ZERO,
average_sync_time_ms: 0.0,
communication_efficiency: 1.0,
}
}
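    /// Recomputes `average_sync_time_ms` from the accumulated totals.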
pub fn update_average_sync_time(&mut self) {
if self.total_sync_operations > 0 {
self.average_sync_time_ms =
self.total_sync_time.as_secs_f64() * 1000.0 / self.total_sync_operations as f64;
}
}
}
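/// Settings for gradient compression applied before the data-parallel
/// all-reduce.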
#[derive(Debug, Clone)]
pub struct GradientCompressionConfig {
pub enable_compression: bool,
pub compression_ratio: f32,
pub error_feedback: bool,
pub quantization_bits: u32,
}
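/// Settings controlling how gradients are grouped into communication buckets.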
#[derive(Debug, Clone)]
pub struct GradientBucketingConfig {
pub bucket_size_mb: f32,
pub max_buckets: usize,
pub overlap_communication: bool,
}