use super::{api, DistributedScalar, ReduceOp};
use crate::autograd::Variable;
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use std::collections::{HashMap, VecDeque};
use std::sync::{mpsc, Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
/// Hands gradient all-reduce requests to a background worker thread and collects the
/// completion records the worker sends back (one [`GradientSyncCompletion`] per request).
pub struct AsyncGradientSynchronizer<T: DistributedScalar> {
/// Handle of the worker thread spawned by `new`; taken and joined when the synchronizer drops.
worker_handle: Option<JoinHandle<()>>,
/// Sending half of the request channel whose receiving half is owned by the worker.
request_sender: mpsc::Sender<GradientSyncRequest<T>>,
/// Receiving half of the channel on which the worker reports completions.
completion_receiver: mpsc::Receiver<GradientSyncCompletion>,
/// Configuration passed to `new`.
// NOTE(review): stored but not read by any method visible in this file.
config: AsyncConfig,
/// Bucket state; a clone of this `Arc` is moved into the worker thread.
bucket_manager: Arc<Mutex<GradientBucketManager<T>>>,
}
/// Settings for [`AsyncGradientSynchronizer`] and its [`GradientBucketManager`].
#[derive(Debug, Clone)]
pub struct AsyncConfig {
/// Presumably a cap on concurrently running sync operations.
// NOTE(review): not read by any code visible in this file.
pub max_concurrent_ops: usize,
/// Presumably the default wait budget for a sync.
// NOTE(review): not read here; `sync_all` takes an explicit timeout instead.
pub sync_timeout: Duration,
/// Whether gradients should be compressed before syncing (not read in this file).
pub enable_compression: bool,
/// Size threshold for compression; the unit is unspecified here (default is 1024 * 1024).
// NOTE(review): not read in this file — confirm whether this is bytes or elements.
pub compression_threshold: usize,
/// Whether gradient bucketing is enabled (not read in this file).
pub enable_bucketing: bool,
/// Bucket capacity in MiB. The bucket manager multiplies it by `1024 * 1024` to get a byte
/// limit and marks a bucket ready once its estimated size reaches that limit.
pub bucket_size_mb: usize,
}
impl Default for AsyncConfig {
fn default() -> Self {
Self {
max_concurrent_ops: 4,
sync_timeout: Duration::from_secs(30),
enable_compression: false,
compression_threshold: 1024 * 1024, enable_bucketing: true,
bucket_size_mb: 25,
}
}
}
/// One gradient all-reduce job queued for the synchronizer's worker thread.
#[derive(Debug)]
pub struct GradientSyncRequest<T: DistributedScalar> {
/// Identifier returned by `submit_gradient` and echoed in the matching completion.
pub id: u64,
/// Name of the parameter this gradient belongs to.
pub param_name: String,
/// Gradient tensor to reduce (the worker reduces a clone of it).
pub gradient: Tensor<T>,
/// Reduction passed to `api::all_reduce`.
pub reduce_op: ReduceOp,
/// Key by which the worker orders each batch of queued requests.
pub priority: Priority,
}
/// Scheduling priority of a gradient sync request.
///
/// Variants are declared from lowest to highest, so the derived `Ord` ranks
/// `Critical > High > Normal > Low`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
/// Outcome record the worker thread sends back for each processed request.
#[derive(Debug)]
pub struct GradientSyncCompletion {
/// Id of the [`GradientSyncRequest`] this record belongs to.
pub request_id: u64,
/// `true` when the all-reduce returned `Ok`.
pub success: bool,
/// Display text of the error when `success` is `false`; `None` otherwise.
pub error: Option<String>,
/// Time the worker spent processing the request, measured with `Instant`.
pub duration: Duration,
}
/// Groups gradients into size-limited buckets so they can be synchronized together.
pub struct GradientBucketManager<T: DistributedScalar> {
/// Buckets keyed by their id.
buckets: HashMap<usize, GradientBucket<T>>,
/// Parameter name -> id of the bucket it was assigned to. The entry is kept after a sync.
param_to_bucket: HashMap<String, usize>,
/// Id the next newly created bucket will receive.
next_bucket_id: usize,
/// Settings; only `bucket_size_mb` is read by the manager.
config: AsyncConfig,
}
/// A batch of gradients that is synchronized as a unit.
#[derive(Debug, Clone)]
pub struct GradientBucket<T: DistributedScalar> {
/// Bucket id (same as its key in the manager's map).
pub id: usize,
/// Parameter names, kept parallel to `gradients` (entry i names gradient i).
pub parameters: Vec<String>,
/// Gradients held in this bucket.
pub gradients: Vec<Tensor<T>>,
/// Estimated size in bytes (`numel * size_of::<T>()` summed over `gradients`).
pub total_size: usize,
/// Set once `total_size` reaches the bucket byte limit; cleared when the bucket is synced.
pub ready: bool,
/// When a gradient was last added to this bucket.
pub last_update: Instant,
}
impl<T: DistributedScalar> AsyncGradientSynchronizer<T> {
    /// Creates a synchronizer and starts its background worker thread.
    ///
    /// The worker receives [`GradientSyncRequest`]s over an internal channel and reports one
    /// [`GradientSyncCompletion`] per processed request. It runs until the request channel is
    /// closed; the `Drop` impl below closes it before joining the thread.
    ///
    /// # Errors
    ///
    /// Returns a distributed error if the operating system refuses to spawn the worker thread.
    pub fn new(config: AsyncConfig) -> RusTorchResult<Self> {
        let (request_sender, request_receiver) = mpsc::channel();
        let (completion_sender, completion_receiver) = mpsc::channel();
        let bucket_manager = Arc::new(Mutex::new(GradientBucketManager::new(config.clone())));
        let bucket_manager_worker = Arc::clone(&bucket_manager);
        let worker_handle = thread::Builder::new()
            .name("gradient-sync-worker".to_string())
            .spawn(move || {
                Self::worker_loop(request_receiver, completion_sender, bucket_manager_worker);
            })
            .map_err(|e| {
                RusTorchError::distributed(&format!(
                    "Failed to spawn gradient sync worker: {}",
                    e
                ))
            })?;
        Ok(Self {
            worker_handle: Some(worker_handle),
            request_sender,
            completion_receiver,
            config,
            bucket_manager,
        })
    }

    /// Queues `gradient` for an averaging all-reduce and returns the new request id.
    ///
    /// Ids come from a single process-wide counter that starts at 1. Because statics are never
    /// per-generic-instantiation, the counter is shared by every `T`.
    ///
    /// # Errors
    ///
    /// Fails if the worker thread is no longer receiving requests.
    pub fn submit_gradient(
        &self,
        param_name: String,
        gradient: Tensor<T>,
        priority: Priority,
    ) -> RusTorchResult<u64> {
        // Atomic so that concurrent submitters never observe the same id; only uniqueness is
        // required, so `Relaxed` ordering is sufficient.
        static NEXT_REQUEST_ID: std::sync::atomic::AtomicU64 =
            std::sync::atomic::AtomicU64::new(1);
        let request_id = NEXT_REQUEST_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let request = GradientSyncRequest {
            id: request_id,
            param_name,
            gradient,
            reduce_op: ReduceOp::Average,
            priority,
        };
        self.request_sender
            .send(request)
            .map_err(|_| RusTorchError::distributed("Failed to submit gradient sync request"))?;
        Ok(request_id)
    }

    /// Drains every completion that is already available, without blocking.
    pub fn check_completions(&self) -> Vec<GradientSyncCompletion> {
        self.completion_receiver.try_iter().collect()
    }

    /// Blocks until the completion for `request_id` arrives or `timeout` elapses.
    ///
    /// Completions for other request ids that arrive while waiting are consumed and discarded.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the request failed;
    /// - the timeout expired;
    /// - the worker thread stopped and can no longer report.
    pub fn wait_for_completion(&self, request_id: u64, timeout: Duration) -> RusTorchResult<()> {
        let start = Instant::now();
        loop {
            let remaining = timeout.saturating_sub(start.elapsed());
            if remaining.is_zero() {
                return Err(RusTorchError::distributed("Gradient sync timeout"));
            }
            match self.completion_receiver.recv_timeout(remaining) {
                Ok(completion) if completion.request_id == request_id => {
                    return Self::completion_result(completion);
                }
                Ok(_) => continue,
                Err(mpsc::RecvTimeoutError::Timeout) => {
                    return Err(RusTorchError::distributed("Gradient sync timeout"));
                }
                // Without this arm a dead worker made the loop spin until the timeout.
                Err(mpsc::RecvTimeoutError::Disconnected) => {
                    return Err(RusTorchError::distributed(
                        "Gradient sync worker stopped before completing the request",
                    ));
                }
            }
        }
    }

    /// Submits every gradient in every ready bucket at `Priority::High` and waits for all of
    /// them.
    ///
    /// On success the synced buckets are cleared via `mark_bucket_synced`, so the next call does
    /// not resubmit them. Completions for request ids this call did not submit are drained and
    /// discarded.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - a submission fails;
    /// - any sync fails;
    /// - not all syncs complete within `timeout`;
    /// - the bucket-manager lock is poisoned.
    pub fn sync_all(&self, timeout: Duration) -> RusTorchResult<()> {
        let ready_buckets = self.lock_bucket_manager()?.get_ready_buckets();
        let start = Instant::now();
        let mut pending_ids = Vec::new();
        for bucket in &ready_buckets {
            // `zip` stops at the shorter list, so a gradient without a name is skipped.
            for (name, grad) in bucket.parameters.iter().zip(&bucket.gradients) {
                pending_ids.push(self.submit_gradient(name.clone(), grad.clone(), Priority::High)?);
            }
        }
        while !pending_ids.is_empty() && start.elapsed() < timeout {
            for completion in self.check_completions() {
                if let Some(pos) = pending_ids
                    .iter()
                    .position(|&id| id == completion.request_id)
                {
                    pending_ids.swap_remove(pos);
                    Self::completion_result(completion)?;
                }
            }
            if !pending_ids.is_empty() {
                thread::sleep(Duration::from_millis(10));
            }
        }
        if !pending_ids.is_empty() {
            return Err(RusTorchError::distributed(
                "Some gradient syncs did not complete",
            ));
        }
        let mut manager = self.lock_bucket_manager()?;
        for bucket in &ready_buckets {
            manager.mark_bucket_synced(bucket.id)?;
        }
        Ok(())
    }

    /// Converts a completion record into `Ok(())` or a "Gradient sync failed" error.
    fn completion_result(completion: GradientSyncCompletion) -> RusTorchResult<()> {
        if completion.success {
            Ok(())
        } else {
            Err(RusTorchError::distributed(&format!(
                "Gradient sync failed: {}",
                completion
                    .error
                    .unwrap_or_else(|| "Unknown error".to_string())
            )))
        }
    }

    /// Locks the shared bucket manager, mapping lock poisoning to an error instead of a panic.
    fn lock_bucket_manager(
        &self,
    ) -> RusTorchResult<std::sync::MutexGuard<'_, GradientBucketManager<T>>> {
        self.bucket_manager
            .lock()
            .map_err(|_| RusTorchError::distributed("Gradient bucket manager lock poisoned"))
    }

    /// Body of the worker thread.
    ///
    /// Blocks for the next request, then drains everything else already queued. It processes
    /// that batch from highest to lowest priority, keeping FIFO order within equal priorities.
    /// Requests arriving mid-batch wait for the next batch.
    ///
    /// The loop ends in either of two cases:
    /// - every request sender is gone and the queue is empty;
    /// - the completion receiver is gone.
    ///
    /// Requests still queued at shutdown are processed first.
    fn worker_loop(
        receiver: mpsc::Receiver<GradientSyncRequest<T>>,
        completion_sender: mpsc::Sender<GradientSyncCompletion>,
        bucket_manager: Arc<Mutex<GradientBucketManager<T>>>,
    ) {
        let mut batch: VecDeque<GradientSyncRequest<T>> = VecDeque::new();
        // `recv` returns `Err` only once every sender has been dropped and the queue is empty.
        while let Ok(first) = receiver.recv() {
            batch.push_back(first);
            batch.extend(receiver.try_iter());
            // `Priority` orders Critical > Low; reverse it so the most urgent run first. The
            // sort is stable, so submission order is kept among equal priorities.
            batch
                .make_contiguous()
                .sort_by_key(|req| std::cmp::Reverse(req.priority));
            while let Some(request) = batch.pop_front() {
                let start_time = Instant::now();
                let result = Self::process_gradient_sync(&request, &bucket_manager);
                let duration = start_time.elapsed();
                let completion = GradientSyncCompletion {
                    request_id: request.id,
                    success: result.is_ok(),
                    error: result.err().map(|e| e.to_string()),
                    duration,
                };
                // Nobody can observe results once the synchronizer's receiver is gone.
                if completion_sender.send(completion).is_err() {
                    return;
                }
            }
        }
    }

    /// Runs `api::all_reduce` for one request.
    ///
    /// NOTE(review): the reduction runs on a clone and the reduced tensor is discarded. Callers
    /// therefore only learn whether the collective succeeded. `_bucket_manager` is unused.
    fn process_gradient_sync(
        request: &GradientSyncRequest<T>,
        _bucket_manager: &Arc<Mutex<GradientBucketManager<T>>>,
    ) -> RusTorchResult<()> {
        let mut grad_copy = request.gradient.clone();
        api::all_reduce(&mut grad_copy, request.reduce_op, None, false)?;
        Ok(())
    }
}
impl<T: DistributedScalar> GradientBucketManager<T> {
    /// Creates an empty manager using `config.bucket_size_mb` as the bucket capacity.
    pub fn new(config: AsyncConfig) -> Self {
        Self {
            buckets: HashMap::new(),
            param_to_bucket: HashMap::new(),
            next_bucket_id: 0,
            config,
        }
    }

    /// Adds `gradient` for `param_name` to its bucket and marks the bucket ready once full.
    ///
    /// A parameter already assigned to a bucket goes back to that bucket. If the bucket still
    /// holds an unsynced gradient for the same parameter, that gradient is replaced rather than
    /// duplicated. Always returns `Ok`; the `Result` is kept for API stability.
    pub fn add_gradient(&mut self, param_name: String, gradient: Tensor<T>) -> RusTorchResult<()> {
        let bucket_id = match self.param_to_bucket.get(&param_name).copied() {
            Some(existing_id) => existing_id,
            None => self.find_or_create_bucket(&param_name, &gradient)?,
        };
        let gradient_size = self.estimate_tensor_size(&gradient);
        let size_limit = self.bucket_size_limit();
        if let Some(bucket) = self.buckets.get_mut(&bucket_id) {
            let slot = bucket.parameters.iter().position(|name| *name == param_name);
            match slot {
                Some(i) if i < bucket.gradients.len() => {
                    let replaced_size = bucket.gradients[i].numel() * std::mem::size_of::<T>();
                    bucket.gradients[i] = gradient;
                    bucket.total_size =
                        bucket.total_size.saturating_sub(replaced_size) + gradient_size;
                }
                _ => {
                    bucket.parameters.push(param_name.clone());
                    bucket.gradients.push(gradient);
                    bucket.total_size += gradient_size;
                }
            }
            bucket.last_update = Instant::now();
            if bucket.total_size >= size_limit {
                bucket.ready = true;
            }
        }
        self.param_to_bucket.insert(param_name, bucket_id);
        Ok(())
    }

    /// Returns the id of a non-ready bucket with room for `gradient`, or creates a new one.
    ///
    /// Among fitting buckets the lowest id is chosen rather than the first in `HashMap`
    /// iteration order. That order is randomized per process, and identical call sequences
    /// must yield identical bucket layouts.
    fn find_or_create_bucket(
        &mut self,
        _param_name: &str,
        gradient: &Tensor<T>,
    ) -> RusTorchResult<usize> {
        let gradient_size = self.estimate_tensor_size(gradient);
        let size_limit = self.bucket_size_limit();
        let reusable = self
            .buckets
            .iter()
            .filter(|(_, bucket)| {
                !bucket.ready && bucket.total_size.saturating_add(gradient_size) <= size_limit
            })
            .map(|(&id, _)| id)
            .min();
        if let Some(id) = reusable {
            return Ok(id);
        }
        let bucket_id = self.next_bucket_id;
        self.next_bucket_id += 1;
        self.buckets.insert(
            bucket_id,
            GradientBucket {
                id: bucket_id,
                parameters: Vec::new(),
                gradients: Vec::new(),
                total_size: 0,
                ready: false,
                last_update: Instant::now(),
            },
        );
        Ok(bucket_id)
    }

    /// Returns clones of all ready buckets, sorted by id.
    ///
    /// Sorting fixes the order in which the buckets are submitted, instead of following
    /// randomized `HashMap` order.
    pub fn get_ready_buckets(&self) -> Vec<GradientBucket<T>> {
        let mut ready: Vec<GradientBucket<T>> = self
            .buckets
            .values()
            .filter(|bucket| bucket.ready)
            .cloned()
            .collect();
        ready.sort_unstable_by_key(|bucket| bucket.id);
        ready
    }

    /// Empties the bucket and clears its `ready` flag.
    ///
    /// Unknown ids are ignored. Parameters keep their bucket assignment in `param_to_bucket`.
    /// Always returns `Ok`.
    pub fn mark_bucket_synced(&mut self, bucket_id: usize) -> RusTorchResult<()> {
        if let Some(bucket) = self.buckets.get_mut(&bucket_id) {
            bucket.ready = false;
            bucket.gradients.clear();
            bucket.parameters.clear();
            bucket.total_size = 0;
        }
        Ok(())
    }

    /// Bucket capacity in bytes (`bucket_size_mb` MiB, saturating instead of overflowing).
    fn bucket_size_limit(&self) -> usize {
        self.config.bucket_size_mb.saturating_mul(1024 * 1024)
    }

    /// Estimated payload size of `tensor` in bytes: element count times `size_of::<T>()`.
    fn estimate_tensor_size(&self, tensor: &Tensor<T>) -> usize {
        tensor.numel() * std::mem::size_of::<T>()
    }
}
impl<T: DistributedScalar> Drop for AsyncGradientSynchronizer<T> {
    /// Shuts down the worker thread and waits for it to exit.
    ///
    /// The real request sender is swapped for one whose channel is already closed, which
    /// disconnects the worker's receiver. The worker then processes any still-queued requests,
    /// its `recv` returns `Err`, and `join` returns. A panic in the worker is ignored.
    fn drop(&mut self) {
        let (closed_sender, _) = mpsc::channel();
        std::mem::drop(std::mem::replace(&mut self.request_sender, closed_sender));
        if let Some(handle) = self.worker_handle.take() {
            let _ = handle.join();
        }
    }
}
/// Gradient compression helpers.
///
/// Only `CompressionAlgorithm::None` and a placeholder `TopK` are implemented; the other
/// algorithms return a distributed error.
pub mod compression {
    use super::{DistributedScalar, *};

    /// Compression scheme to apply to a gradient.
    #[derive(Debug, Clone, Copy)]
    pub enum CompressionAlgorithm {
        None,
        TopK { k: usize },
        Quantization { bits: u8 },
        ErrorFeedback,
    }

    /// Compresses `gradient` with `algorithm`.
    ///
    /// # Errors
    ///
    /// - `TopK` fails when `k` exceeds the element count.
    /// - `Quantization` and `ErrorFeedback` are not implemented and always fail.
    pub fn compress_gradient<T: DistributedScalar>(
        gradient: &Tensor<T>,
        algorithm: CompressionAlgorithm,
    ) -> RusTorchResult<CompressedGradient<T>> {
        // Every variant is listed so that adding a new one is a compile error, not a silent
        // "not implemented".
        match algorithm {
            CompressionAlgorithm::None => Ok(CompressedGradient {
                data: gradient.clone(),
                metadata: CompressionMetadata::None,
                original_shape: gradient.shape().to_vec(),
            }),
            CompressionAlgorithm::TopK { k } => compress_top_k(gradient, k),
            CompressionAlgorithm::Quantization { .. } | CompressionAlgorithm::ErrorFeedback => {
                Err(RusTorchError::distributed(
                    "Compression algorithm not implemented",
                ))
            }
        }
    }

    /// Reverses [`compress_gradient`].
    ///
    /// # Errors
    ///
    /// Quantization metadata is not supported and always fails.
    pub fn decompress_gradient<T: DistributedScalar>(
        compressed: &CompressedGradient<T>,
    ) -> RusTorchResult<Tensor<T>> {
        match &compressed.metadata {
            CompressionMetadata::None => Ok(compressed.data.clone()),
            CompressionMetadata::TopK { .. } => decompress_top_k(compressed),
            CompressionMetadata::Quantization { .. } => {
                Err(RusTorchError::distributed("Decompression not implemented"))
            }
        }
    }

    /// A gradient together with what is needed to reconstruct it.
    #[derive(Debug, Clone)]
    pub struct CompressedGradient<T: DistributedScalar> {
        /// Compressed payload.
        pub data: Tensor<T>,
        /// Algorithm-specific reconstruction data.
        pub metadata: CompressionMetadata,
        /// Shape of the gradient before compression.
        pub original_shape: Vec<usize>,
    }

    /// Algorithm-specific data stored alongside a compressed gradient.
    #[derive(Debug, Clone)]
    pub enum CompressionMetadata {
        None,
        TopK { k: usize, indices: Vec<usize> },
        Quantization { scale: f32, zero_point: i8 },
    }

    /// Placeholder top-k "compression".
    ///
    /// It validates `k`, but then stores the full gradient and the indices `0..k`; no element
    /// selection is performed yet.
    fn compress_top_k<T: DistributedScalar>(
        gradient: &Tensor<T>,
        k: usize,
    ) -> RusTorchResult<CompressedGradient<T>> {
        let total_elements = gradient.numel();
        if k > total_elements {
            return Err(RusTorchError::tensor_op("K larger than tensor size"));
        }
        let indices = (0..k).collect();
        let compressed_data = gradient.clone();
        Ok(CompressedGradient {
            data: compressed_data,
            metadata: CompressionMetadata::TopK { k, indices },
            original_shape: gradient.shape().to_vec(),
        })
    }

    /// Counterpart of the placeholder [`compress_top_k`]: returns the stored tensor unchanged.
    fn decompress_top_k<T: DistributedScalar>(
        compressed: &CompressedGradient<T>,
    ) -> RusTorchResult<Tensor<T>> {
        Ok(compressed.data.clone())
    }
}
/// Convenience wrapper that syncs autograd [`Variable`] gradients through an
/// [`AsyncGradientSynchronizer`].
pub struct AsyncGradContext<T: DistributedScalar> {
/// The synchronizer, behind a mutex so the `&self` methods can use it.
synchronizer: Arc<Mutex<AsyncGradientSynchronizer<T>>>,
/// Request id -> `param_{id}` name for requests submitted by `sync_parameter_async`.
pending_ops: Arc<Mutex<HashMap<u64, String>>>,
}
impl<T: DistributedScalar> AsyncGradContext<T> {
pub fn new(config: AsyncConfig) -> RusTorchResult<Self> {
let synchronizer = AsyncGradientSynchronizer::new(config)?;
Ok(Self {
synchronizer: Arc::new(Mutex::new(synchronizer)),
pending_ops: Arc::new(Mutex::new(HashMap::new())),
})
}
pub fn sync_parameter_async(&self, parameter: &Variable<T>) -> RusTorchResult<()> {
let grad_lock = parameter.grad();
let grad_guard = grad_lock.read().unwrap();
if let Some(ref gradient) = *grad_guard {
let param_name = format!("param_{}", parameter.id());
let gradient_clone = gradient.clone();
drop(grad_guard);
let synchronizer = self.synchronizer.lock().unwrap();
let request_id = synchronizer.submit_gradient(
param_name.clone(),
gradient_clone,
Priority::Normal,
)?;
let mut pending = self.pending_ops.lock().unwrap();
pending.insert(request_id, param_name);
}
Ok(())
}
pub fn synchronize(&self) -> RusTorchResult<()> {
let timeout = Duration::from_secs(30);
let synchronizer = self.synchronizer.lock().unwrap();
synchronizer.sync_all(timeout)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_async_config_default() {
        let config = AsyncConfig::default();
        assert_eq!(config.max_concurrent_ops, 4);
        assert_eq!(config.bucket_size_mb, 25);
        assert!(config.enable_bucketing);
    }

    #[test]
    fn test_priority_ordering() {
        assert!(Priority::Critical > Priority::High);
        assert!(Priority::High > Priority::Normal);
        assert!(Priority::Normal > Priority::Low);
    }

    #[test]
    fn test_bucket_manager_creation() {
        let config = AsyncConfig::default();
        let manager = GradientBucketManager::<f32>::new(config);
        assert_eq!(manager.buckets.len(), 0);
        assert_eq!(manager.next_bucket_id, 0);
    }

    /// Small gradients for different parameters share one bucket that is not yet ready.
    #[test]
    fn test_bucket_manager_groups_small_gradients() {
        let mut manager = GradientBucketManager::<f32>::new(AsyncConfig::default());
        manager
            .add_gradient("w1".to_string(), Tensor::ones(&[2, 2]))
            .unwrap();
        manager
            .add_gradient("w2".to_string(), Tensor::ones(&[3]))
            .unwrap();
        assert_eq!(manager.buckets.len(), 1);
        assert_eq!(
            manager.param_to_bucket.get("w1"),
            manager.param_to_bucket.get("w2")
        );
        assert!(manager.get_ready_buckets().is_empty());
    }

    /// With a zero-byte limit a bucket is ready immediately; marking it synced empties it.
    #[test]
    fn test_bucket_manager_ready_then_synced() {
        let config = AsyncConfig {
            bucket_size_mb: 0,
            ..AsyncConfig::default()
        };
        let mut manager = GradientBucketManager::<f32>::new(config);
        manager
            .add_gradient("w".to_string(), Tensor::ones(&[2, 2]))
            .unwrap();
        let ready = manager.get_ready_buckets();
        assert_eq!(ready.len(), 1);
        assert_eq!(ready[0].parameters, vec!["w".to_string()]);
        manager.mark_bucket_synced(ready[0].id).unwrap();
        assert!(manager.get_ready_buckets().is_empty());
        assert_eq!(manager.buckets[&ready[0].id].total_size, 0);
    }

    #[test]
    fn test_compression_none() {
        let tensor: Tensor<f32> = Tensor::ones(&[2, 2]);
        let compressed =
            compression::compress_gradient(&tensor, compression::CompressionAlgorithm::None)
                .unwrap();
        let decompressed = compression::decompress_gradient(&compressed).unwrap();
        assert_eq!(tensor.shape(), decompressed.shape());
    }
}