use super::collective::{allreduce, ReduceOp};
use super::communication::{AsyncCommunicator, CommunicationError, MessagePriority, TensorMessage};
use super::process::{Communicator, ProcessError};
use crate::error::NumRs2Error;
use oxicode::{Decode, Encode};
use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
/// Errors produced by the distributed coordination primitives in this file.
///
/// The `Display` text of every variant is the fixed prefix from its `#[error]`
/// attribute followed by the variant's payload.
#[derive(Error, Debug)]
pub enum CoordinatorError {
/// Wraps a `ProcessError`; `#[from]` lets `?` perform the conversion.
#[error("Process error: {0}")]
Process(#[from] ProcessError),
/// Wraps a `CommunicationError`; `#[from]` lets `?` perform the conversion.
#[error("Communication error: {0}")]
Communication(#[from] CommunicationError),
/// A parameter key was rejected; the payload is presumably the offending key.
#[error("Invalid parameter key: {0}")]
InvalidKey(String),
/// A lookup by key found no entry; the payload is the key that was looked up.
#[error("Parameter not found: {0}")]
ParameterNotFound(String),
/// Saving or loading a checkpoint failed; the payload describes the failing step.
#[error("Checkpoint error: {0}")]
Checkpoint(String),
/// Free-form message reported under the "Recovery error" prefix.
#[error("Recovery error: {0}")]
Recovery(String),
/// Free-form message reported under the "Synchronization error" prefix.
#[error("Synchronization error: {0}")]
Synchronization(String),
/// Free-form message reported under the "Configuration error" prefix.
#[error("Configuration error: {0}")]
Configuration(String),
}
impl From<CoordinatorError> for NumRs2Error {
fn from(err: CoordinatorError) -> Self {
NumRs2Error::DistributedComputing(err.to_string())
}
}
/// In-memory parameter store with per-key gradient accumulation and versioning.
///
/// All three maps are keyed by the parameter name passed to the methods.
pub struct ParameterServer {
// Handle stored at construction; also used to build `async_comm`.
communicator: Arc<Communicator>,
// Asynchronous communicator created from `communicator` in `new`.
async_comm: AsyncCommunicator,
// Number of server slots; returned by `num_servers` and used as the modulus
// in `get_server_for_key`.
num_ps: usize,
// Current parameter values per key (written by `init_parameter` and
// `apply_gradients`, cloned out by `pull_parameters`).
parameters: Arc<RwLock<HashMap<String, Vec<f32>>>>,
// Element-wise running sum of pushed gradients per key; zeroed again by
// `apply_gradients`.
gradient_buffer: Arc<Mutex<HashMap<String, Vec<f32>>>>,
// Update counter per key: set to 0 by `init_parameter`, incremented by
// `apply_gradients`.
versions: Arc<RwLock<HashMap<String, u64>>>,
}
impl ParameterServer {
    /// Creates a parameter server over `communicator` with `num_ps` server slots.
    ///
    /// # Errors
    ///
    /// * [`CoordinatorError::Configuration`] when `num_ps` is zero, because
    ///   [`get_server_for_key`](Self::get_server_for_key) reduces modulo `num_ps`.
    /// * Any error raised by `AsyncCommunicator::new`, converted through `?`.
    pub fn new(communicator: Arc<Communicator>, num_ps: usize) -> Result<Self, CoordinatorError> {
        // Validate up front so key sharding can never divide by zero later.
        if num_ps == 0 {
            return Err(CoordinatorError::Configuration(
                "num_ps must be at least 1".to_string(),
            ));
        }
        let async_comm = AsyncCommunicator::new(communicator.clone())?;
        Ok(Self {
            communicator,
            async_comm,
            num_ps,
            parameters: Arc::new(RwLock::new(HashMap::new())),
            gradient_buffer: Arc::new(Mutex::new(HashMap::new())),
            versions: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Stores `initial_values` under `key` and resets the key's version to 0.
    ///
    /// An existing entry for `key` is overwritten. Gradients already accumulated
    /// for the key are left untouched.
    pub async fn init_parameter(
        &self,
        key: &str,
        initial_values: Vec<f32>,
    ) -> Result<(), CoordinatorError> {
        // Lock order (parameters, then versions) matches `apply_gradients`.
        let mut params = self.parameters.write().await;
        params.insert(key.to_string(), initial_values);
        let mut versions = self.versions.write().await;
        versions.insert(key.to_string(), 0);
        Ok(())
    }

    /// Adds `gradients` element-wise to the accumulator for `key`.
    ///
    /// The accumulator is created zero-filled with `gradients.len()` entries on the
    /// first push. If an accumulator of a different length already exists, only the
    /// overlapping prefix is updated because the `zip` stops at the shorter side.
    pub async fn push_gradients(
        &self,
        key: &str,
        gradients: &[f32],
    ) -> Result<(), CoordinatorError> {
        let mut buffer = self.gradient_buffer.lock().await;
        let accumulated = buffer
            .entry(key.to_string())
            .or_insert_with(|| vec![0.0; gradients.len()]);
        for (acc, &grad) in accumulated.iter_mut().zip(gradients) {
            *acc += grad;
        }
        Ok(())
    }

    /// Returns a copy of the values currently stored under `key`.
    ///
    /// # Errors
    ///
    /// [`CoordinatorError::ParameterNotFound`] when `key` was never initialised.
    pub async fn pull_parameters(&self, key: &str) -> Result<Vec<f32>, CoordinatorError> {
        let params = self.parameters.read().await;
        params
            .get(key)
            .cloned()
            .ok_or_else(|| CoordinatorError::ParameterNotFound(key.to_string()))
    }

    /// Applies the accumulated gradients for `key` as `param -= learning_rate * grad`,
    /// zeroes the consumed gradient entries and increments the key's version.
    ///
    /// Only the overlapping prefix of the parameter and gradient vectors is touched.
    /// The version is incremented only if a version entry exists for `key`.
    ///
    /// # Errors
    ///
    /// [`CoordinatorError::ParameterNotFound`] when no gradients were pushed for
    /// `key`, or when `key` has no stored parameter.
    pub async fn apply_gradients(
        &self,
        key: &str,
        learning_rate: f32,
    ) -> Result<(), CoordinatorError> {
        let mut buffer = self.gradient_buffer.lock().await;
        let gradients = buffer
            .get_mut(key)
            .ok_or_else(|| CoordinatorError::ParameterNotFound(key.to_string()))?;
        let mut params = self.parameters.write().await;
        let parameters = params
            .get_mut(key)
            .ok_or_else(|| CoordinatorError::ParameterNotFound(key.to_string()))?;
        for (param, grad) in parameters.iter_mut().zip(gradients.iter_mut()) {
            *param -= learning_rate * *grad;
            // Reset so the next round of pushes starts from zero.
            *grad = 0.0;
        }
        let mut versions = self.versions.write().await;
        if let Some(version) = versions.get_mut(key) {
            *version += 1;
        }
        Ok(())
    }

    /// Returns the version counter stored for `key`.
    ///
    /// # Errors
    ///
    /// [`CoordinatorError::ParameterNotFound`] when `key` has no version entry.
    pub async fn get_version(&self, key: &str) -> Result<u64, CoordinatorError> {
        let versions = self.versions.read().await;
        versions
            .get(key)
            .copied()
            .ok_or_else(|| CoordinatorError::ParameterNotFound(key.to_string()))
    }

    /// Returns the number of server slots this instance was configured with.
    pub fn num_servers(&self) -> usize {
        self.num_ps
    }

    /// Maps `key` to a server index in `0..num_servers()`.
    ///
    /// Uses a multiply-by-31 rolling hash over the key's bytes with wrapping
    /// arithmetic.
    pub fn get_server_for_key(&self, key: &str) -> usize {
        let hash = key
            .bytes()
            .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(u64::from(b)));
        // Reduce in u64 before narrowing: casting the hash to `usize` first would
        // drop the high bits on 32-bit targets and make the mapping depend on the
        // platform's pointer width. The remainder is `< num_ps`, so the final
        // cast is lossless.
        (hash % self.num_ps as u64) as usize
    }
}
/// State for an all-reduce arranged as a ring over the communicator's ranks.
pub struct RingAllReduce {
// Source of this process's rank and the total number of ranks.
communicator: Arc<Communicator>,
// Asynchronous communicator used to send chunks to the next rank.
async_comm: AsyncCommunicator,
// Successor table built in `new`: `ring[i]` is `(i + 1) % size`.
ring: Vec<usize>,
}
impl RingAllReduce {
    /// Builds the ring topology for `communicator`.
    ///
    /// The successor of rank `i` is `(i + 1) % size`, where `size` is
    /// `communicator.size()`.
    ///
    /// # Errors
    ///
    /// Any error raised by `AsyncCommunicator::new`, converted through `?`.
    pub fn new(communicator: Arc<Communicator>) -> Result<Self, CoordinatorError> {
        let async_comm = AsyncCommunicator::new(communicator.clone())?;
        let size = communicator.size();
        let ring: Vec<usize> = (0..size).map(|i| (i + 1) % size).collect();
        Ok(Self {
            communicator,
            async_comm,
            ring,
        })
    }

    /// Runs the send side of a two-phase ring all-reduce over `data` and returns a
    /// copy of `data`.
    ///
    /// `data` is split into `size` chunks of `ceil(len / size)` elements. Each phase
    /// performs `size - 1` steps, each sending one chunk to the next rank in the ring.
    /// With zero or one rank nothing is sent.
    ///
    /// NOTE(review): nothing is received or reduced in this function, so the
    /// returned vector always equals `data`; presumably the receive/reduce half is
    /// still to be implemented. Confirm before relying on the result.
    ///
    /// # Errors
    ///
    /// Any error raised by `isend`, converted through `?`.
    pub async fn allreduce(&self, data: &[f32]) -> Result<Vec<f32>, CoordinatorError> {
        let size = self.communicator.size();
        let rank = self.communicator.rank();
        // `<= 1` rather than `== 1`: a zero-sized communicator must not reach
        // `div_ceil(size)` or `size - 1` below.
        if size <= 1 {
            return Ok(data.to_vec());
        }
        let chunk_size = data.len().div_ceil(size);
        let result = data.to_vec();
        let next_rank = self.ring[rank];

        // Phase 1: every step sends this rank's own chunk.
        // NOTE(review): the chunk index does not vary with the step here; kept as
        // in the original until the receive side defines the intended schedule.
        for _step in 0..size - 1 {
            self.send_chunk(&result, rank, chunk_size, next_rank).await?;
        }

        // Phase 2: the chunk index walks backwards around the ring.
        for step in 0..size - 1 {
            // Add `size` before subtracting `step`: `rank + 1 - step` underflows
            // whenever `step > rank + 1` (e.g. rank 0 with four or more ranks).
            let chunk_index = (rank + 1 + size - step) % size;
            self.send_chunk(&result, chunk_index, chunk_size, next_rank)
                .await?;
        }
        Ok(result)
    }

    /// Sends chunk number `chunk_index` of `buffer` to rank `dest`; does nothing
    /// when the chunk starts past the end of `buffer`.
    async fn send_chunk(
        &self,
        buffer: &[f32],
        chunk_index: usize,
        chunk_size: usize,
        dest: usize,
    ) -> Result<(), CoordinatorError> {
        let start = chunk_index * chunk_size;
        if start >= buffer.len() {
            return Ok(());
        }
        // The last chunk may be shorter than `chunk_size`.
        let end = (start + chunk_size).min(buffer.len());
        let msg = TensorMessage::new(
            buffer[start..end].to_vec(),
            super::communication::CompressionStrategy::None,
            MessagePriority::High,
        );
        self.async_comm.isend(msg, dest).await?;
        Ok(())
    }

    /// Returns the successor table: entry `i` is the rank that follows rank `i`.
    pub fn topology(&self) -> &[usize] {
        &self.ring
    }
}
/// State for an all-reduce arranged as a tree over the communicator's ranks,
/// with rank 0 as the root.
pub struct TreeAllReduce {
// Handle stored at construction; supplies rank and size in `new`.
communicator: Arc<Communicator>,
// Asynchronous communicator used to send to the parent and the children.
async_comm: AsyncCommunicator,
// Maximum number of children per node, as passed to `new`.
branching_factor: usize,
// `None` for rank 0, otherwise `(rank - 1) / branching_factor`.
parent: Option<usize>,
// Ranks `rank * branching_factor + i` for `i` in `1..=branching_factor`
// that are below the communicator size.
children: Vec<usize>,
}
impl TreeAllReduce {
pub fn new(
communicator: Arc<Communicator>,
branching_factor: usize,
) -> Result<Self, CoordinatorError> {
let async_comm = AsyncCommunicator::new(communicator.clone())?;
let rank = communicator.rank();
let size = communicator.size();
let parent = if rank == 0 {
None
} else {
Some((rank - 1) / branching_factor)
};
let children: Vec<usize> = (1..=branching_factor)
.map(|i| rank * branching_factor + i)
.filter(|&c| c < size)
.collect();
Ok(Self {
communicator,
async_comm,
branching_factor,
parent,
children,
})
}
pub async fn allreduce(&self, data: &[f32]) -> Result<Vec<f32>, CoordinatorError> {
let result = data.to_vec();
if !self.children.is_empty() {
for &child in &self.children {
let _ = child; }
}
if let Some(parent_rank) = self.parent {
let msg = TensorMessage::new(
result.clone(),
super::communication::CompressionStrategy::None,
MessagePriority::High,
);
self.async_comm.isend(msg, parent_rank).await?;
}
if let Some(parent_rank) = self.parent {
let _ = parent_rank;
}
for &child in &self.children {
let msg = TensorMessage::new(
result.clone(),
super::communication::CompressionStrategy::None,
MessagePriority::High,
);
self.async_comm.isend(msg, child).await?;
}
Ok(result)
}
pub fn branching_factor(&self) -> usize {
self.branching_factor
}
pub fn parent(&self) -> Option<usize> {
self.parent
}
pub fn children(&self) -> &[usize] {
&self.children
}
}
/// Counter-based barrier state sized by the communicator.
pub struct DistributedBarrier {
// Supplies the participant count (`size()`) that `wait` compares against.
communicator: Arc<Communicator>,
// Incremented each time the arrival counter reaches the participant count.
generation: Arc<Mutex<u64>>,
// Number of `wait` calls seen since the counter was last reset to 0.
arrived: Arc<Mutex<usize>>,
}
impl DistributedBarrier {
    /// Creates a barrier for `communicator` with generation 0 and no arrivals.
    pub fn new(communicator: Arc<Communicator>) -> Result<Self, CoordinatorError> {
        let generation = Arc::new(Mutex::new(0));
        let arrived = Arc::new(Mutex::new(0));
        Ok(Self {
            communicator,
            generation,
            arrived,
        })
    }

    /// Registers one arrival at the barrier.
    ///
    /// The call that brings the arrival count up to `communicator.size()` resets
    /// the count to 0 and increments the generation. Every other call releases the
    /// counter lock, sleeps for 10 ms and returns.
    ///
    /// NOTE(review): early arrivers return after the fixed sleep without checking
    /// that the remaining participants have arrived, and the counter lives in this
    /// instance only. Confirm whether callers expect a blocking rendezvous.
    pub async fn wait(&self) -> Result<(), CoordinatorError> {
        let expected = self.communicator.size();
        let mut count = self.arrived.lock().await;
        *count += 1;

        if *count != expected {
            // Not the last arrival: release the counter before sleeping.
            drop(count);
            tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            return Ok(());
        }

        // Last arrival: start a new round while still holding the counter lock.
        *count = 0;
        let mut current_generation = self.generation.lock().await;
        *current_generation += 1;
        Ok(())
    }

    /// Returns the number of completed barrier rounds recorded so far.
    pub async fn generation(&self) -> u64 {
        let current = self.generation.lock().await;
        *current
    }
}
/// Serializable snapshot of named parameter vectors plus free-form metadata.
///
/// Derives both the serde and the oxicode traits; `save`/`load` use the oxicode
/// encoding.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct Checkpoint {
/// Caller-chosen identifier of this checkpoint.
pub id: String,
/// Generation number supplied by the caller at construction.
pub generation: u64,
/// Parameter values keyed by parameter name.
pub parameters: HashMap<String, Vec<f32>>,
/// Arbitrary string key/value annotations.
pub metadata: HashMap<String, String>,
}
impl Checkpoint {
    /// Creates an empty checkpoint with the given `id` and `generation`.
    pub fn new(id: String, generation: u64) -> Self {
        let parameters = HashMap::new();
        let metadata = HashMap::new();
        Self {
            id,
            generation,
            parameters,
            metadata,
        }
    }

    /// Stores `values` under `key`, replacing any previous entry for that key.
    pub fn add_parameter(&mut self, key: String, values: Vec<f32>) {
        let _previous = self.parameters.insert(key, values);
    }

    /// Stores the metadata pair `key` / `value`, replacing any previous value.
    pub fn add_metadata(&mut self, key: String, value: String) {
        let _previous = self.metadata.insert(key, value);
    }

    /// Encodes the checkpoint with oxicode and writes the bytes to `path`.
    ///
    /// # Errors
    ///
    /// [`CoordinatorError::Checkpoint`] when encoding or writing the file fails.
    pub fn save(&self, path: &PathBuf) -> Result<(), CoordinatorError> {
        let encoded = oxicode::encode_to_vec(self)
            .map_err(|err| CoordinatorError::Checkpoint(format!("Failed to serialize: {}", err)))?;
        std::fs::write(path, encoded)
            .map_err(|err| CoordinatorError::Checkpoint(format!("Failed to write file: {}", err)))
    }

    /// Reads `path` and decodes its contents as a checkpoint.
    ///
    /// # Errors
    ///
    /// [`CoordinatorError::Checkpoint`] when reading the file or decoding fails.
    pub fn load(path: &PathBuf) -> Result<Self, CoordinatorError> {
        let raw = std::fs::read(path)
            .map_err(|err| CoordinatorError::Checkpoint(format!("Failed to read file: {}", err)))?;
        // The decoder also reports how many bytes it consumed; that count is unused.
        let (restored, _consumed): (Self, usize) = oxicode::decode_from_slice(&raw).map_err(
            |err| CoordinatorError::Checkpoint(format!("Failed to deserialize: {}", err)),
        )?;
        Ok(restored)
    }

    /// Returns the values stored under `key`, if any.
    pub fn get_parameter(&self, key: &str) -> Option<&Vec<f32>> {
        let Self { parameters, .. } = self;
        parameters.get(key)
    }

    /// Returns the metadata value stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        let Self { metadata, .. } = self;
        metadata.get(key)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new checkpoint carries its id and generation and holds no data.
    #[test]
    fn test_checkpoint_creation() {
        let checkpoint = Checkpoint::new("test".to_string(), 100);
        assert_eq!(checkpoint.id, "test");
        assert_eq!(checkpoint.generation, 100);
        assert!(checkpoint.parameters.is_empty());
        assert!(checkpoint.metadata.is_empty());
    }

    /// A stored parameter can be read back by key.
    #[test]
    fn test_checkpoint_add_parameter() {
        let mut checkpoint = Checkpoint::new("test".to_string(), 100);
        let params = vec![1.0, 2.0, 3.0];
        checkpoint.add_parameter("weights".to_string(), params.clone());
        assert_eq!(checkpoint.parameters.len(), 1);
        assert_eq!(checkpoint.get_parameter("weights"), Some(&params));
    }

    /// A stored metadata entry can be read back by key.
    #[test]
    fn test_checkpoint_add_metadata() {
        let mut checkpoint = Checkpoint::new("test".to_string(), 100);
        checkpoint.add_metadata("model".to_string(), "resnet50".to_string());
        assert_eq!(checkpoint.metadata.len(), 1);
        assert_eq!(
            checkpoint.get_metadata("model"),
            Some(&"resnet50".to_string())
        );
    }

    /// Encoding followed by decoding preserves id and generation.
    #[test]
    fn test_checkpoint_serialization() {
        let mut checkpoint = Checkpoint::new("test".to_string(), 100);
        checkpoint.add_parameter("weights".to_string(), vec![1.0, 2.0, 3.0]);
        checkpoint.add_metadata("model".to_string(), "test_model".to_string());
        let serialized = oxicode::encode_to_vec(&checkpoint);
        assert!(serialized.is_ok());
        let bytes = serialized.expect("serialization failed");
        let deserialized: Result<(Checkpoint, usize), _> = oxicode::decode_from_slice(&bytes);
        assert!(deserialized.is_ok());
        let (restored, _) = deserialized.expect("deserialization failed");
        assert_eq!(restored.id, checkpoint.id);
        assert_eq!(restored.generation, checkpoint.generation);
    }

    /// `save` followed by `load` round-trips a checkpoint through the file system.
    #[test]
    fn test_checkpoint_save_load() {
        let mut checkpoint = Checkpoint::new("test".to_string(), 100);
        checkpoint.add_parameter("weights".to_string(), vec![1.0, 2.0, 3.0]);
        // Include the process id so concurrent test runs sharing the temp
        // directory do not overwrite each other's file.
        let file_name = format!("test_checkpoint_{}.bin", std::process::id());
        let path = std::env::temp_dir().join(file_name);
        let save_result = checkpoint.save(&path);
        let load_result = Checkpoint::load(&path);
        // Clean up before asserting so a failing assertion leaves no file behind.
        let _ = std::fs::remove_file(&path);
        assert!(save_result.is_ok());
        assert!(load_result.is_ok());
        let loaded = load_result.expect("load failed");
        assert_eq!(loaded.id, checkpoint.id);
        assert_eq!(loaded.generation, checkpoint.generation);
        assert_eq!(loaded.parameters, checkpoint.parameters);
    }

    /// Looking up an unknown parameter yields `None`.
    #[test]
    fn test_checkpoint_get_missing_parameter() {
        let checkpoint = Checkpoint::new("test".to_string(), 100);
        assert_eq!(checkpoint.get_parameter("missing"), None);
    }

    /// Looking up an unknown metadata key yields `None`.
    #[test]
    fn test_checkpoint_get_missing_metadata() {
        let checkpoint = Checkpoint::new("test".to_string(), 100);
        assert_eq!(checkpoint.get_metadata("missing"), None);
    }

    /// Several parameters are stored independently of each other.
    #[test]
    fn test_checkpoint_multiple_parameters() {
        let mut checkpoint = Checkpoint::new("test".to_string(), 100);
        checkpoint.add_parameter("layer1".to_string(), vec![1.0, 2.0]);
        checkpoint.add_parameter("layer2".to_string(), vec![3.0, 4.0, 5.0]);
        assert_eq!(checkpoint.parameters.len(), 2);
        assert_eq!(checkpoint.get_parameter("layer1"), Some(&vec![1.0, 2.0]));
        assert_eq!(
            checkpoint.get_parameter("layer2"),
            Some(&vec![3.0, 4.0, 5.0])
        );
    }

    /// Several metadata entries are stored independently of each other.
    #[test]
    fn test_checkpoint_multiple_metadata() {
        let mut checkpoint = Checkpoint::new("test".to_string(), 100);
        checkpoint.add_metadata("model".to_string(), "resnet".to_string());
        checkpoint.add_metadata("dataset".to_string(), "imagenet".to_string());
        assert_eq!(checkpoint.metadata.len(), 2);
        assert_eq!(
            checkpoint.get_metadata("model"),
            Some(&"resnet".to_string())
        );
        assert_eq!(
            checkpoint.get_metadata("dataset"),
            Some(&"imagenet".to_string())
        );
    }
}