#![allow(dead_code)]
use crate::gradient_compression::{CompressionConfig, CompressionMethod};
use crate::gradient_compression_enhanced::{CompressionMetrics, EnhancedGradientCompressor};
use crate::{TorshDistributedError, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use torsh_tensor::Tensor;
use tracing::{debug, info};
/// Point-in-time snapshot of observed network conditions used to steer
/// adaptive gradient compression.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMetrics {
/// Estimated available bandwidth. Field name says Mbps, but the log output
/// in `compress_gradient_adaptive` labels it "MB/s" — TODO confirm the unit.
pub bandwidth_mbps: f32,
/// Estimated latency in milliseconds.
pub latency_ms: f32,
/// Estimated packet-loss fraction (0.0 = lossless).
pub packet_loss: f32,
/// Combined congestion signal in [0, 1], derived from latency and bandwidth
/// (see `NetworkProfiler::estimate_congestion_factor`).
pub congestion_factor: f32,
/// Link stability in [0, 1]; 1.0 means perfectly stable samples.
pub stability_score: f32,
/// Measurement time, in milliseconds since the UNIX epoch.
pub timestamp_ms: u64,
}
impl Default for NetworkMetrics {
    /// Optimistic baseline: gigabit-class bandwidth, 1 ms latency, zero loss,
    /// no congestion, perfect stability, stamped with the current wall clock.
    fn default() -> Self {
        let timestamp_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
            .as_millis() as u64;
        Self {
            bandwidth_mbps: 1000.0,
            latency_ms: 1.0,
            packet_loss: 0.0,
            congestion_factor: 0.0,
            stability_score: 1.0,
            timestamp_ms,
        }
    }
}
impl NetworkMetrics {
    /// Convenience constructor; identical to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }
}
/// Tuning knobs for network-aware adaptive compression.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveCompressionConfig {
/// Fraction of measured bandwidth the compressor aims to use (0..1).
pub target_bandwidth_utilization: f32,
/// Lower clamp applied to the computed compression ratio.
pub min_compression_ratio: f32,
/// Upper clamp applied to the computed compression ratio.
pub max_compression_ratio: f32,
/// Intended interval between network measurements.
/// NOTE(review): not read anywhere in this file — confirm intended use.
pub monitoring_interval: Duration,
/// Number of samples kept in each sliding history window.
pub history_window_size: usize,
/// How aggressively to react to condition changes.
/// NOTE(review): not read anywhere in this file — confirm intended use.
pub adaptation_sensitivity: f32,
/// Weight of convergence quality when blending the compression ratio.
pub convergence_quality_weight: f32,
/// Weight of network pressure (communication efficiency) in the blend.
pub communication_efficiency_weight: f32,
}
impl Default for AdaptiveCompressionConfig {
fn default() -> Self {
Self {
// Aim to use at most 80% of the measured link bandwidth.
target_bandwidth_utilization: 0.8,
// Permit compression ratios between 1% and 90%.
min_compression_ratio: 0.01,
max_compression_ratio: 0.9,
// Intended 100 ms sampling cadence (currently unused by the profiler).
monitoring_interval: Duration::from_millis(100),
history_window_size: 50,
adaptation_sensitivity: 0.3,
// Blend: 60% convergence quality, 40% communication efficiency.
convergence_quality_weight: 0.6,
communication_efficiency_weight: 0.4,
}
}
}
/// Collects network measurements into bounded sliding windows and derives
/// aggregate statistics (averages, stability score) from them.
#[derive(Debug)]
pub struct NetworkProfiler {
/// Full metric snapshots, oldest first; bounded by `config.history_window_size`.
metrics_history: VecDeque<NetworkMetrics>,
/// Bandwidth samples mirroring `metrics_history` (same length, same order).
bandwidth_samples: VecDeque<f32>,
/// Latency samples mirroring `metrics_history`.
latency_samples: VecDeque<f32>,
/// Time of the most recently completed measurement.
last_measurement: Instant,
/// Adaptive-compression settings; the window size is what matters here.
config: AdaptiveCompressionConfig,
}
impl NetworkProfiler {
    /// Creates a profiler whose sample windows hold `config.history_window_size`
    /// entries each.
    pub fn new(config: AdaptiveCompressionConfig) -> Self {
        Self {
            metrics_history: VecDeque::with_capacity(config.history_window_size),
            bandwidth_samples: VecDeque::with_capacity(config.history_window_size),
            latency_samples: VecDeque::with_capacity(config.history_window_size),
            last_measurement: Instant::now(),
            config,
        }
    }

    /// Milliseconds since the UNIX epoch (shared by all timestamp fields).
    fn now_millis() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("time should be after UNIX_EPOCH")
            .as_millis() as u64
    }

    /// Appends `value`, evicting the oldest entry once `limit` is exceeded.
    fn push_bounded<T>(queue: &mut VecDeque<T>, value: T, limit: usize) {
        queue.push_back(value);
        if queue.len() > limit {
            queue.pop_front();
        }
    }

    /// Mean and population standard deviation of `samples`; `(0.0, 0.0)` when
    /// the window is empty (avoids a divide-by-zero the original risked).
    fn mean_and_std(samples: &VecDeque<f32>) -> (f32, f32) {
        if samples.is_empty() {
            return (0.0, 0.0);
        }
        let n = samples.len() as f32;
        let mean = samples.iter().sum::<f32>() / n;
        let variance = samples.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;
        (mean, variance.sqrt())
    }

    /// Takes one (simulated) measurement of network conditions and records it
    /// in the bounded history windows.
    ///
    /// # Errors
    /// Propagates failures from the individual estimators (currently infallible).
    pub fn measure_network_performance(&mut self) -> TorshResult<NetworkMetrics> {
        let now = Instant::now();
        let bandwidth = self.estimate_bandwidth()?;
        let latency = self.estimate_latency()?;
        let packet_loss = self.estimate_packet_loss()?;
        let congestion = self.estimate_congestion_factor(bandwidth, latency)?;
        let stability = self.calculate_stability_score()?;
        let metrics = NetworkMetrics {
            bandwidth_mbps: bandwidth,
            latency_ms: latency,
            packet_loss,
            congestion_factor: congestion,
            stability_score: stability,
            timestamp_ms: Self::now_millis(),
        };
        // All three windows share one bound so they stay in lock-step.
        let limit = self.config.history_window_size;
        Self::push_bounded(&mut self.metrics_history, metrics.clone(), limit);
        Self::push_bounded(&mut self.bandwidth_samples, bandwidth, limit);
        Self::push_bounded(&mut self.latency_samples, latency, limit);
        self.last_measurement = now;
        Ok(metrics)
    }

    /// Simulated bandwidth estimate: 800–1200 varying with wall-clock time.
    /// NOTE(review): placeholder — a real implementation would probe the link.
    fn estimate_bandwidth(&self) -> TorshResult<f32> {
        let base_bandwidth = 1000.0;
        let variation = (std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system time should be after UNIX_EPOCH")
            .as_millis()
            % 100) as f32
            / 100.0;
        let bandwidth = base_bandwidth * (0.8 + 0.4 * variation);
        Ok(bandwidth.max(10.0))
    }

    /// Simulated latency estimate: 0.5–1.5 ms varying with wall-clock time.
    fn estimate_latency(&self) -> TorshResult<f32> {
        let base_latency = 1.0;
        let variation = (std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system time should be after UNIX_EPOCH")
            .as_millis()
            % 50) as f32
            / 50.0;
        let latency = base_latency * (0.5 + variation);
        Ok(latency.max(0.1))
    }

    /// Simulated packet-loss estimate: 0–0.1% varying with wall-clock time.
    fn estimate_packet_loss(&self) -> TorshResult<f32> {
        let base_loss = 0.001;
        let variation = (std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system time should be after UNIX_EPOCH")
            .as_millis()
            % 20) as f32
            / 20.0;
        Ok(base_loss * variation)
    }

    /// High latency and low bandwidth both read as congestion; the two
    /// normalized signals are averaged and capped at 1.0.
    fn estimate_congestion_factor(&self, bandwidth: f32, latency: f32) -> TorshResult<f32> {
        let normalized_latency = (latency / 10.0).min(1.0);
        let normalized_bandwidth = (1000.0 / bandwidth.max(100.0)).min(1.0);
        let congestion = (normalized_latency + normalized_bandwidth) / 2.0;
        Ok(congestion.min(1.0))
    }

    /// Stability = 1 − mean coefficient of variation of the bandwidth and
    /// latency windows, clamped to [0, 1]. Returns 1.0 until five samples
    /// have accumulated (too little history to judge variability).
    fn calculate_stability_score(&self) -> TorshResult<f32> {
        if self.bandwidth_samples.len() < 5 {
            return Ok(1.0);
        }
        let cv = |samples: &VecDeque<f32>| -> f32 {
            let (mean, std_dev) = Self::mean_and_std(samples);
            if mean > 0.0 {
                std_dev / mean
            } else {
                0.0
            }
        };
        let avg_cv = (cv(&self.bandwidth_samples) + cv(&self.latency_samples)) / 2.0;
        let stability = 1.0 - avg_cv.min(1.0);
        Ok(stability.max(0.0))
    }

    /// Averages every field over the recorded metrics window, or `None` when
    /// no measurement has been taken yet.
    pub fn get_average_metrics(&self) -> Option<NetworkMetrics> {
        if self.metrics_history.is_empty() {
            return None;
        }
        let count = self.metrics_history.len() as f32;
        let avg = |field: fn(&NetworkMetrics) -> f32| -> f32 {
            self.metrics_history.iter().map(field).sum::<f32>() / count
        };
        Some(NetworkMetrics {
            bandwidth_mbps: avg(|m| m.bandwidth_mbps),
            latency_ms: avg(|m| m.latency_ms),
            packet_loss: avg(|m| m.packet_loss),
            congestion_factor: avg(|m| m.congestion_factor),
            stability_score: avg(|m| m.stability_score),
            timestamp_ms: Self::now_millis(),
        })
    }
}
/// Gradient compressor that adapts its configuration to measured network
/// conditions and training convergence progress.
pub struct NetworkAwareCompressor {
/// Compressor doing the actual work; rebuilt whenever the config adapts.
enhanced_compressor: EnhancedGradientCompressor,
/// Shared profiler providing (simulated) network measurements.
network_profiler: Arc<Mutex<NetworkProfiler>>,
/// Most recently applied compression configuration.
current_config: Arc<Mutex<CompressionConfig>>,
/// Tuning parameters governing the adaptation.
adaptive_config: AdaptiveCompressionConfig,
/// Bounded window of metrics from past compression runs.
compression_history: Arc<Mutex<VecDeque<CompressionMetrics>>>,
/// Tracks loss/gradient statistics to estimate convergence quality.
convergence_tracker: Arc<Mutex<ConvergenceTracker>>,
}
/// Tracks recent training statistics to estimate how well training is
/// converging. Histories are bounded sliding windows.
#[derive(Debug)]
struct ConvergenceTracker {
    // Bounded histories of training signals; oldest entries evicted first.
    loss_history: VecDeque<f32>,
    gradient_norm_history: VecDeque<f32>,
    learning_rate_history: VecDeque<f32>,
    // Estimated relative loss improvement, clamped to [0, 1].
    convergence_rate: f32,
    // Maximum number of entries kept in each history.
    // Bug fix: the trim limit was previously hardcoded to 100, ignoring the
    // `history_size` that `new` reserved capacity for.
    max_history: usize,
}

impl ConvergenceTracker {
    /// Creates a tracker whose histories hold at most `history_size` entries.
    fn new(history_size: usize) -> Self {
        Self {
            loss_history: VecDeque::with_capacity(history_size),
            gradient_norm_history: VecDeque::with_capacity(history_size),
            learning_rate_history: VecDeque::with_capacity(history_size),
            convergence_rate: 0.0,
            max_history: history_size,
        }
    }

    /// Appends `value`, evicting the oldest entry once `limit` is exceeded.
    fn push_bounded(queue: &mut VecDeque<f32>, value: f32, limit: usize) {
        queue.push_back(value);
        if queue.len() > limit {
            queue.pop_front();
        }
    }

    /// Records one training step and refreshes the convergence-rate estimate.
    fn update_convergence_metrics(&mut self, loss: f32, gradient_norm: f32, learning_rate: f32) {
        Self::push_bounded(&mut self.loss_history, loss, self.max_history);
        Self::push_bounded(&mut self.gradient_norm_history, gradient_norm, self.max_history);
        Self::push_bounded(&mut self.learning_rate_history, learning_rate, self.max_history);
        self.estimate_convergence_rate();
    }

    /// Compares the average of the last 10 losses against the 10 before them
    /// and stores the relative improvement as the convergence rate.
    fn estimate_convergence_rate(&mut self) {
        // Too few samples to compare windows: report a neutral estimate.
        if self.loss_history.len() < 10 {
            self.convergence_rate = 0.5;
            return;
        }
        let recent_losses: Vec<f32> = self.loss_history.iter().rev().take(10).copied().collect();
        let old_losses: Vec<f32> = if self.loss_history.len() >= 20 {
            self.loss_history
                .iter()
                .rev()
                .skip(10)
                .take(10)
                .copied()
                .collect()
        } else {
            // Not enough history for a distinct older window; compare against
            // the recent window itself (yields the stalled-convergence floor).
            recent_losses.clone()
        };
        let recent_avg: f32 = recent_losses.iter().sum::<f32>() / recent_losses.len() as f32;
        let old_avg: f32 = old_losses.iter().sum::<f32>() / old_losses.len() as f32;
        if old_avg > recent_avg && old_avg > 0.0 {
            // Relative improvement of the recent window over the older one.
            let improvement_rate = (old_avg - recent_avg) / old_avg;
            self.convergence_rate = improvement_rate.clamp(0.0, 1.0);
        } else {
            // Loss plateaued or increased: near-stalled convergence.
            self.convergence_rate = 0.1;
        }
    }

    /// Current convergence-quality estimate in [0, 1].
    fn get_convergence_quality(&self) -> f32 {
        self.convergence_rate
    }
}
impl NetworkAwareCompressor {
/// Builds the adaptive compressor around a base compression config. All
/// shared state is wrapped in `Arc<Mutex<_>>` so it can be shared across
/// threads.
///
/// # Errors
/// Fails if `EnhancedGradientCompressor::new` rejects `base_config`.
pub fn new(
base_config: CompressionConfig,
adaptive_config: AdaptiveCompressionConfig,
) -> TorshResult<Self> {
let enhanced_compressor = EnhancedGradientCompressor::new(base_config.clone())?;
let network_profiler = Arc::new(Mutex::new(NetworkProfiler::new(adaptive_config.clone())));
let current_config = Arc::new(Mutex::new(base_config));
let compression_history = Arc::new(Mutex::new(VecDeque::with_capacity(
adaptive_config.history_window_size,
)));
let convergence_tracker = Arc::new(Mutex::new(ConvergenceTracker::new(
adaptive_config.history_window_size,
)));
Ok(Self {
enhanced_compressor,
network_profiler,
current_config,
adaptive_config,
compression_history,
convergence_tracker,
})
}
/// Compresses `gradient` with settings adapted to current conditions.
///
/// Pipeline: (1) fold `training_metrics` into the convergence tracker,
/// (2) sample network performance, (3) derive an optimal compression config,
/// (4) rebuild the enhanced compressor with that config, (5) compress and
/// record the resulting metrics. Locks are taken one at a time in short
/// scopes, never nested, so the method cannot deadlock against itself.
///
/// NOTE(review): step (4) constructs a brand-new `EnhancedGradientCompressor`
/// on every call, which presumably discards any internal error-feedback
/// state — confirm whether the compressor supports in-place reconfiguration.
///
/// # Errors
/// Fails on a poisoned lock, on compressor construction, or on compression.
pub fn compress_gradient_adaptive(
&mut self,
gradient: &Tensor,
training_metrics: Option<TrainingMetrics>,
) -> TorshResult<(
crate::gradient_compression::CompressedGradient,
CompressionMetrics,
)> {
// (1) Update convergence statistics when the caller supplied them.
if let Some(metrics) = training_metrics {
let mut tracker = self.convergence_tracker.lock().map_err(|e| {
TorshDistributedError::communication_error(
"convergence_tracker",
format!("Lock error: {}", e),
)
})?;
tracker.update_convergence_metrics(
metrics.loss,
metrics.gradient_norm,
metrics.learning_rate,
);
}
// (2) Take a fresh network measurement (lock released at end of block).
let network_metrics = {
let mut profiler = self.network_profiler.lock().map_err(|e| {
TorshDistributedError::communication_error(
"network_profiler",
format!("Lock error: {}", e),
)
})?;
profiler.measure_network_performance()?
};
// (3) Derive the config best suited to the measured conditions and
// publish it as the current config.
let optimal_config = self.calculate_optimal_compression_config(&network_metrics)?;
{
let mut config = self.current_config.lock().map_err(|e| {
TorshDistributedError::communication_error(
"current_config",
format!("Lock error: {}", e),
)
})?;
*config = optimal_config.clone();
}
// (4) Rebuild the compressor with the new configuration.
self.enhanced_compressor = EnhancedGradientCompressor::new(optimal_config.clone())?;
// (5) Compress and record the outcome in the bounded history.
let start_time = Instant::now();
let (compressed_gradient, metrics) = self
.enhanced_compressor
.compress_gradient_enhanced(gradient, "adaptive_gradient")?;
let _compression_time = start_time.elapsed();
{
let mut history = self.compression_history.lock().map_err(|e| {
TorshDistributedError::communication_error(
"compression_history",
format!("Lock error: {}", e),
)
})?;
history.push_back(metrics.clone());
if history.len() > self.adaptive_config.history_window_size {
history.pop_front();
}
}
info!(
"Network-aware compression: ratio={:.3}, bandwidth={:.1}MB/s, latency={:.2}ms, stability={:.3}, throughput={:.1}MB/s",
metrics.compression_ratio,
network_metrics.bandwidth_mbps,
network_metrics.latency_ms,
network_metrics.stability_score,
metrics.throughput_mbps
);
Ok((compressed_gradient, metrics))
}
/// Combines network pressure and convergence quality into a concrete
/// `CompressionConfig` (ratio, method, error-feedback settings).
fn calculate_optimal_compression_config(
&self,
network_metrics: &NetworkMetrics,
) -> TorshResult<CompressionConfig> {
let convergence_quality = {
let tracker = self.convergence_tracker.lock().map_err(|e| {
TorshDistributedError::communication_error(
"convergence_tracker",
format!("Lock error: {}", e),
)
})?;
tracker.get_convergence_quality()
};
let optimal_ratio =
self.calculate_optimal_compression_ratio(network_metrics, convergence_quality)?;
let optimal_method =
self.select_optimal_compression_method(network_metrics, optimal_ratio)?;
let base_config = self.current_config.lock().map_err(|e| {
TorshDistributedError::communication_error(
"current_config",
format!("Lock error: {}", e),
)
})?;
let mut optimal_config = base_config.clone();
optimal_config.compression_ratio = optimal_ratio;
optimal_config.method = optimal_method.clone();
// Error feedback is only enabled on reasonably stable links; very stable
// links get stronger momentum.
optimal_config.error_feedback = network_metrics.stability_score > 0.7;
optimal_config.error_feedback_momentum = if network_metrics.stability_score > 0.8 {
0.9 } else {
0.7 };
debug!(
"Optimal compression config: ratio={:.3}, method={:?}, error_feedback={}",
optimal_ratio, optimal_method, optimal_config.error_feedback
);
Ok(optimal_config)
}
/// Blends network pressure (weighted by `communication_efficiency_weight`)
/// with (1 − convergence quality) (weighted by `convergence_quality_weight`)
/// into a compression ratio clamped to the configured [min, max] range.
fn calculate_optimal_compression_ratio(
&self,
network_metrics: &NetworkMetrics,
convergence_quality: f32,
) -> TorshResult<f32> {
let target_bandwidth =
self.adaptive_config.target_bandwidth_utilization * network_metrics.bandwidth_mbps;
// Each pressure term is normalized to roughly [0, 1]; note `loss_factor`
// is not capped here — the final clamp below bounds the result.
let bandwidth_pressure = (1000.0 / target_bandwidth).min(1.0);
let latency_factor = (network_metrics.latency_ms / 10.0).min(1.0);
let loss_factor = network_metrics.packet_loss * 10.0;
let congestion_factor = network_metrics.congestion_factor;
let network_pressure =
(bandwidth_pressure + latency_factor + loss_factor + congestion_factor) / 4.0;
let efficiency_weight = self.adaptive_config.communication_efficiency_weight;
let convergence_weight = self.adaptive_config.convergence_quality_weight;
// Lower convergence quality raises the convergence-driven ratio term.
let network_compression_ratio = network_pressure * 0.8; let convergence_compression_ratio = (1.0 - convergence_quality) * 0.5;
let optimal_ratio = efficiency_weight * network_compression_ratio
+ convergence_weight * convergence_compression_ratio;
let clamped_ratio = optimal_ratio
.max(self.adaptive_config.min_compression_ratio)
.min(self.adaptive_config.max_compression_ratio);
debug!(
"Compression ratio calculation: network_pressure={:.3}, convergence_quality={:.3}, optimal={:.3}",
network_pressure, convergence_quality, clamped_ratio
);
Ok(clamped_ratio)
}
/// Picks a compression method per ratio tier, refined by network conditions
/// (stability, latency, bandwidth). `compression_ratio` doubles as the `k`
/// fraction for TopK/RandomK.
fn select_optimal_compression_method(
&self,
network_metrics: &NetworkMetrics,
compression_ratio: f32,
) -> TorshResult<CompressionMethod> {
// < 0.1: very aggressive — 4-bit quantization on stable links, otherwise
// sign-only gradients.
if compression_ratio < 0.1 {
if network_metrics.stability_score > 0.8 {
Ok(CompressionMethod::Quantization { bits: 4 }) } else {
Ok(CompressionMethod::SignSGD) }
} else if compression_ratio < 0.3 {
// 0.1–0.3: TopK when latency is low, otherwise 8-bit quantization.
if network_metrics.latency_ms < 2.0 {
Ok(CompressionMethod::TopK {
k: compression_ratio,
}) } else {
Ok(CompressionMethod::Quantization { bits: 8 }) }
} else if compression_ratio < 0.7 {
// 0.3–0.7: TopK on high-bandwidth links, otherwise magnitude threshold.
if network_metrics.bandwidth_mbps > 500.0 {
Ok(CompressionMethod::TopK {
k: compression_ratio,
}) } else {
Ok(CompressionMethod::Threshold { threshold: 0.01 }) }
} else {
// >= 0.7: light compression — RandomK only on very stable links.
if network_metrics.stability_score > 0.9 {
Ok(CompressionMethod::RandomK {
k: compression_ratio,
}) } else {
Ok(CompressionMethod::TopK {
k: compression_ratio,
}) }
}
}
/// Averaged network metrics over the profiler's window, or `None` before
/// the first measurement.
///
/// # Errors
/// Fails only on a poisoned profiler lock.
pub fn get_network_metrics(&self) -> TorshResult<Option<NetworkMetrics>> {
let profiler = self.network_profiler.lock().map_err(|e| {
TorshDistributedError::communication_error(
"network_profiler",
format!("Lock error: {}", e),
)
})?;
Ok(profiler.get_average_metrics())
}
/// Aggregates statistics over recorded compression runs; returns defaults
/// when the history is empty.
///
/// # Errors
/// Fails only on a poisoned history lock.
pub fn get_compression_statistics(&self) -> TorshResult<CompressionStatistics> {
let history = self.compression_history.lock().map_err(|e| {
TorshDistributedError::communication_error(
"compression_history",
format!("Lock error: {}", e),
)
})?;
if history.is_empty() {
return Ok(CompressionStatistics::default());
}
let count = history.len() as f32;
let avg_ratio = history.iter().map(|m| m.compression_ratio).sum::<f32>() / count;
// Integer division: sub-microsecond remainders are truncated.
let avg_time_us =
history.iter().map(|m| m.compression_time_us).sum::<u64>() / history.len() as u64;
let avg_throughput = history.iter().map(|m| m.throughput_mbps).sum::<f32>() / count;
let avg_error = history.iter().map(|m| m.compression_error).sum::<f32>() / count;
Ok(CompressionStatistics {
average_compression_ratio: avg_ratio,
average_compression_time_us: avg_time_us,
average_throughput_mbps: avg_throughput,
average_compression_error: avg_error,
total_compressions: history.len(),
})
}
/// Replaces the adaptive tuning parameters on both this compressor and its
/// profiler. The new `history_window_size` takes effect on subsequent pushes.
///
/// # Errors
/// Fails only on a poisoned profiler lock.
pub fn update_adaptive_config(&mut self, config: AdaptiveCompressionConfig) -> TorshResult<()> {
self.adaptive_config = config.clone();
let mut profiler = self.network_profiler.lock().map_err(|e| {
TorshDistributedError::communication_error(
"network_profiler",
format!("Lock error: {}", e),
)
})?;
profiler.config = config;
Ok(())
}
}
/// Per-step training statistics supplied by the caller to feed the
/// convergence tracker.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
/// Training loss at this step.
pub loss: f32,
/// Gradient norm at this step (presumably L2 — confirm with callers).
pub gradient_norm: f32,
/// Optimizer learning rate at this step.
pub learning_rate: f32,
}
/// Aggregate statistics over the recorded compression history
/// (see `NetworkAwareCompressor::get_compression_statistics`).
#[derive(Debug, Clone, Default)]
pub struct CompressionStatistics {
/// Mean compression ratio across recorded runs.
pub average_compression_ratio: f32,
/// Mean compression time in microseconds (integer-truncated).
pub average_compression_time_us: u64,
/// Mean throughput across recorded runs.
pub average_throughput_mbps: f32,
/// Mean compression error across recorded runs.
pub average_compression_error: f32,
/// Number of runs currently in the history window.
pub total_compressions: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::ones;

    /// Builds a metrics snapshot with the given fields and the current
    /// wall-clock timestamp.
    fn metrics_with(
        bandwidth_mbps: f32,
        latency_ms: f32,
        packet_loss: f32,
        congestion_factor: f32,
        stability_score: f32,
    ) -> NetworkMetrics {
        NetworkMetrics {
            bandwidth_mbps,
            latency_ms,
            packet_loss,
            congestion_factor,
            stability_score,
            timestamp_ms: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .expect("time should be after UNIX_EPOCH")
                .as_millis() as u64,
        }
    }

    #[tokio::test]
    async fn test_network_profiler() -> TorshResult<()> {
        let mut profiler = NetworkProfiler::new(AdaptiveCompressionConfig::default());
        let sample = profiler.measure_network_performance()?;
        // All simulated readings must land in their documented ranges.
        assert!(sample.bandwidth_mbps > 0.0);
        assert!(sample.latency_ms > 0.0);
        assert!(sample.packet_loss >= 0.0);
        assert!(sample.congestion_factor >= 0.0 && sample.congestion_factor <= 1.0);
        assert!(sample.stability_score >= 0.0 && sample.stability_score <= 1.0);
        Ok(())
    }

    #[tokio::test]
    async fn test_network_aware_compression() -> TorshResult<()> {
        let mut compressor = NetworkAwareCompressor::new(
            CompressionConfig::default(),
            AdaptiveCompressionConfig::default(),
        )?;
        let gradient = ones::<f32>(&[1000, 1000])?;
        let stats = TrainingMetrics {
            loss: 0.5,
            gradient_norm: 1.0,
            learning_rate: 0.001,
        };
        let (compressed, metrics) =
            compressor.compress_gradient_adaptive(&gradient, Some(stats))?;
        assert!(metrics.compression_ratio > 0.0 && metrics.compression_ratio <= 1.0);
        assert_eq!(compressed.original_shape, vec![1000, 1000]);
        Ok(())
    }

    #[tokio::test]
    async fn test_adaptive_compression_ratio_calculation() -> TorshResult<()> {
        let compressor = NetworkAwareCompressor::new(
            CompressionConfig::default(),
            AdaptiveCompressionConfig::default(),
        )?;
        // A fast, stable link versus a slow, lossy, congested one.
        let fast = metrics_with(2000.0, 0.5, 0.0, 0.1, 0.95);
        let slow = metrics_with(100.0, 5.0, 0.01, 0.8, 0.6);
        let fast_ratio = compressor.calculate_optimal_compression_ratio(&fast, 0.8)?;
        let slow_ratio = compressor.calculate_optimal_compression_ratio(&slow, 0.8)?;
        // Worse network conditions should compress more aggressively.
        assert!(slow_ratio > fast_ratio);
        Ok(())
    }

    #[tokio::test]
    async fn test_compression_method_selection() -> TorshResult<()> {
        let compressor = NetworkAwareCompressor::new(
            CompressionConfig::default(),
            AdaptiveCompressionConfig::default(),
        )?;
        let stable = metrics_with(1000.0, 1.0, 0.0, 0.1, 0.95);
        let aggressive = compressor.select_optimal_compression_method(&stable, 0.05)?;
        let moderate = compressor.select_optimal_compression_method(&stable, 0.5)?;
        assert!(
            matches!(
                aggressive,
                CompressionMethod::Quantization { bits: 4 } | CompressionMethod::SignSGD
            ),
            "Unexpected method for high compression"
        );
        assert!(
            matches!(
                moderate,
                CompressionMethod::TopK { .. } | CompressionMethod::Threshold { .. }
            ),
            "Unexpected method for moderate compression"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_convergence_tracking() -> TorshResult<()> {
        let mut tracker = ConvergenceTracker::new(50);
        // Feed a steadily decreasing loss curve.
        for step in 0..20 {
            tracker.update_convergence_metrics(1.0 - step as f32 * 0.05, 1.0, 0.001);
        }
        let quality = tracker.get_convergence_quality();
        assert!(quality > 0.0 && quality <= 1.0);
        Ok(())
    }
}