#![allow(dead_code)]
use crate::rpc::{register_function, rpc_async};
use crate::{TorshDistributedError, TorshResult};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::Mutex;
use torsh_nn::Parameter;
use torsh_tensor::Tensor;
use tracing::{debug, info};
/// Tuning knobs for the parameter server's built-in SGD-style update rule
/// and its operational limits.
#[derive(Debug, Clone)]
pub struct ParameterServerConfig {
/// Step size used when subtracting the scaled update from each parameter.
pub learning_rate: f32,
/// When true, a per-parameter momentum buffer is kept and folded into updates.
pub use_momentum: bool,
/// Momentum coefficient (only consulted when `use_momentum` is true).
pub momentum: f32,
/// L2 penalty coefficient; values > 0.0 add `weight_decay * param` to gradients.
pub weight_decay: f32,
/// Intended cap on concurrent updates.
/// NOTE(review): not enforced anywhere in this file — confirm whether enforcement lives elsewhere.
pub max_concurrent_updates: usize,
/// If set, gradients whose L2 norm exceeds this value are rescaled down to it.
pub gradient_clip_value: Option<f32>,
}
impl Default for ParameterServerConfig {
fn default() -> Self {
Self {
learning_rate: 0.01,
use_momentum: true,
momentum: 0.9,
weight_decay: 0.0,
max_concurrent_updates: 10,
gradient_clip_value: None,
}
}
}
/// Requests a worker can send to the parameter server over RPC.
/// Tensor payloads travel as flat `Vec<f32>`s; shapes are not transmitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterServerMessage {
/// Submit gradients for one or more named parameters.
PushGradients {
worker_id: u32,
/// Flattened gradient values, keyed by parameter name.
gradients: HashMap<String, Vec<f32>>,
/// The worker's view of the parameter version at send time.
version: u64,
},
/// Fetch the current values of the named parameters.
PullParameters {
worker_id: u32,
param_names: Vec<String>,
},
/// Seed the server with an initial set of parameter values.
InitializeParameters {
parameters: HashMap<String, Vec<f32>>,
},
/// Request a snapshot of server-side statistics.
GetStats,
}
/// Replies the parameter server sends back; each variant pairs with the
/// correspondingly named [`ParameterServerMessage`] request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParameterServerResponse {
/// Outcome of a gradient push, with the server's new global version.
PushResponse { success: bool, new_version: u64 },
/// Requested parameter values (flattened) plus the version they reflect.
PullResponse {
parameters: HashMap<String, Vec<f32>>,
version: u64,
},
/// Outcome of parameter initialization.
InitResponse { success: bool },
/// Snapshot of server statistics.
StatsResponse { stats: ParameterServerStats },
}
/// Point-in-time operational statistics reported by the parameter server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterServerStats {
/// Number of distinct parameters currently stored.
pub num_parameters: usize,
/// Total gradient pushes accepted since startup.
pub total_pushes: u64,
/// Total parameter pulls served since startup.
pub total_pulls: u64,
/// Global parameter version (incremented once per push).
pub current_version: u64,
/// Count of distinct worker ids seen pushing gradients.
pub active_workers: usize,
/// Approximate memory footprint of stored parameters, in MiB.
pub memory_usage_mb: f64,
}
/// Shared mutable state behind the parameter server. `DashMap` provides
/// per-shard locking for the parameter tables; each tensor additionally
/// sits behind its own `RwLock` so updates to different parameters can
/// proceed concurrently.
struct ParameterServerState {
/// Current parameter tensors, keyed by name.
parameters: DashMap<String, Arc<RwLock<Tensor>>>,
/// Per-parameter momentum buffers (populated only when `config.use_momentum`).
momentum_buffers: DashMap<String, Arc<RwLock<Tensor>>>,
/// Global version counter, bumped once per accepted gradient push.
version: Arc<RwLock<u64>>,
config: ParameterServerConfig,
/// Async mutex: stats are touched from async methods alongside `.await`s.
stats: Arc<Mutex<ParameterServerStats>>,
/// Ids of every worker that has pushed gradients so far.
active_workers: Arc<RwLock<std::collections::HashSet<u32>>>,
/// Bounded diagnostic log of (worker_id, param_name, grad_norm) entries.
gradient_history: Arc<RwLock<Vec<(u32, String, f32)>>>, }
impl ParameterServerState {
    /// Creates an empty server state for the given configuration: no
    /// parameters, version 0, zeroed statistics, no known workers.
    fn new(config: ParameterServerConfig) -> Self {
        Self {
            parameters: DashMap::new(),
            momentum_buffers: DashMap::new(),
            version: Arc::new(RwLock::new(0)),
            config,
            stats: Arc::new(Mutex::new(ParameterServerStats {
                num_parameters: 0,
                total_pushes: 0,
                total_pulls: 0,
                current_version: 0,
                active_workers: 0,
                memory_usage_mb: 0.0,
            })),
            active_workers: Arc::new(RwLock::new(std::collections::HashSet::new())),
            gradient_history: Arc::new(RwLock::new(Vec::new())),
        }
    }

    /// Approximate resident size of all stored parameter tensors in MiB,
    /// counting f32 elements only. Momentum buffers are not counted; a
    /// poisoned per-parameter lock contributes 0 rather than panicking.
    fn approximate_memory_mb(&self) -> f64 {
        let total_elements: usize = self
            .parameters
            .iter()
            .map(|entry| {
                entry
                    .value()
                    .read()
                    .map(|t| t.shape().dims().iter().product::<usize>())
                    .unwrap_or(0)
            })
            .sum();
        (total_elements * std::mem::size_of::<f32>()) as f64 / (1024.0 * 1024.0)
    }

    /// Installs the given parameters as flat 1-D tensors (values arrive
    /// pre-flattened; original shapes are not transmitted over the wire) and,
    /// when momentum is enabled, allocates a matching zero-filled momentum
    /// buffer per entry. Returns `Ok(true)` on success.
    async fn initialize_parameters(
        &self,
        parameters: HashMap<String, Vec<f32>>,
    ) -> TorshResult<bool> {
        info!(
            "Initializing {} parameters on parameter server",
            parameters.len()
        );
        for (name, data) in parameters {
            // Shapes are lost over the wire, so every parameter is stored flat.
            let shape = vec![data.len()];
            let tensor = Tensor::from_vec(data, &shape)?;
            self.parameters
                .insert(name.clone(), Arc::new(RwLock::new(tensor)));
            if self.config.use_momentum {
                let zeros = Tensor::zeros(&shape, torsh_core::DeviceType::Cpu)?;
                self.momentum_buffers
                    .insert(name, Arc::new(RwLock::new(zeros)));
            }
        }
        {
            let mut stats = self.stats.lock().await;
            stats.num_parameters = self.parameters.len();
            stats.current_version = *self.version.read().expect("lock should not be poisoned");
            // Keep the memory estimate in sync from the moment of initialization.
            stats.memory_usage_mb = self.approximate_memory_mb();
        }
        Ok(true)
    }

    /// Applies one batch of worker gradients: optional norm clipping,
    /// optional weight decay, optional heavy-ball momentum, then an SGD step
    /// `param -= lr * update`. Bumps and returns the new global version.
    ///
    /// `_version` (the worker's view of the parameter version) is currently
    /// ignored, i.e. updates are fully asynchronous with no staleness check.
    async fn push_gradients(
        &self,
        worker_id: u32,
        gradients: HashMap<String, Vec<f32>>,
        _version: u64,
    ) -> TorshResult<u64> {
        debug!(
            "Received gradients from worker {} for {} parameters",
            worker_id,
            gradients.len()
        );
        {
            let mut workers = self
                .active_workers
                .write()
                .expect("lock should not be poisoned");
            workers.insert(worker_id);
        }
        let mut gradient_norms = Vec::new();
        for (param_name, grad_data) in gradients {
            // Gradients for unknown parameter names are silently dropped.
            if let Some(param_entry) = self.parameters.get(&param_name) {
                // Clone the Arc so the DashMap shard lock is released before
                // we take the per-parameter write lock.
                let param_tensor = param_entry.clone();
                let mut param_guard = param_tensor.write().expect("lock should not be poisoned");
                let shape = param_guard.shape().dims().to_vec();
                let grad_tensor = Tensor::from_vec(grad_data, &shape)?;
                let grad_norm = grad_tensor.norm()?.item()?;
                gradient_norms.push((worker_id, param_name.clone(), grad_norm));
                // Clip by rescaling so the gradient's L2 norm never exceeds the cap.
                let clipped_grad = if let Some(clip_value) = self.config.gradient_clip_value {
                    if grad_norm > clip_value {
                        grad_tensor.mul_scalar(clip_value / grad_norm)?
                    } else {
                        grad_tensor
                    }
                } else {
                    grad_tensor
                };
                // L2 regularization folded into the gradient: grad += wd * param.
                let grad_with_decay = if self.config.weight_decay > 0.0 {
                    let weight_penalty = param_guard.mul_scalar(self.config.weight_decay)?;
                    clipped_grad.add(&weight_penalty)?
                } else {
                    clipped_grad
                };
                // Heavy-ball momentum: buf = momentum * buf + grad; update = buf.
                let update = if self.config.use_momentum {
                    if let Some(momentum_entry) = self.momentum_buffers.get(&param_name) {
                        let momentum_tensor = momentum_entry.clone();
                        let mut momentum_guard = momentum_tensor
                            .write()
                            .expect("lock should not be poisoned");
                        *momentum_guard = momentum_guard
                            .mul_scalar(self.config.momentum)?
                            .add(&grad_with_decay)?;
                        momentum_guard.clone()
                    } else {
                        grad_with_decay
                    }
                } else {
                    grad_with_decay
                };
                *param_guard = param_guard.sub(&update.mul_scalar(self.config.learning_rate)?)?;
            }
        }
        {
            // Bounded diagnostic log; drop the oldest half once it exceeds 1000.
            let mut history = self
                .gradient_history
                .write()
                .expect("lock should not be poisoned");
            history.extend(gradient_norms);
            if history.len() > 1000 {
                history.drain(0..500);
            }
        }
        let new_version = {
            let mut version = self.version.write().expect("lock should not be poisoned");
            *version += 1;
            *version
        };
        {
            let mut stats = self.stats.lock().await;
            stats.total_pushes += 1;
            stats.current_version = new_version;
            stats.active_workers = self
                .active_workers
                .read()
                .expect("lock should not be poisoned")
                .len();
            // Use the real element counts rather than assuming a fixed
            // 1000 elements per parameter (previous behavior).
            stats.memory_usage_mb = self.approximate_memory_mb();
        }
        Ok(new_version)
    }

    /// Returns the flattened values of each requested parameter that exists
    /// (unknown names are skipped), together with the global version.
    async fn pull_parameters(
        &self,
        worker_id: u32,
        param_names: Vec<String>,
    ) -> TorshResult<(HashMap<String, Vec<f32>>, u64)> {
        debug!(
            "Worker {} pulling {} parameters",
            worker_id,
            param_names.len()
        );
        let mut parameters = HashMap::new();
        for param_name in param_names {
            if let Some(param_entry) = self.parameters.get(&param_name) {
                let param_tensor = param_entry.clone();
                let param_guard = param_tensor.read().expect("lock should not be poisoned");
                let data = param_guard.flatten()?.to_vec()?;
                parameters.insert(param_name, data);
            }
        }
        let version = *self.version.read().expect("lock should not be poisoned");
        {
            let mut stats = self.stats.lock().await;
            stats.total_pulls += 1;
        }
        Ok((parameters, version))
    }

    /// Snapshot of the current statistics.
    async fn get_stats(&self) -> ParameterServerStats {
        self.stats.lock().await.clone()
    }
}
/// Server-side handle: owns the shared state and registers the RPC
/// endpoints workers call (`ps_initialize`, `ps_push_gradients`,
/// `ps_pull_parameters`, `ps_get_stats`).
pub struct ParameterServer {
/// Shared with RPC handlers; `Arc` so clones can move into closures.
state: Arc<ParameterServerState>,
/// Rank this server instance runs on (used for logging/identification).
server_rank: u32,
}
impl ParameterServer {
/// Builds a server on `server_rank` with fresh state from `config`.
pub fn new(server_rank: u32, config: ParameterServerConfig) -> Self {
Self {
state: Arc::new(ParameterServerState::new(config)),
server_rank,
}
}
/// Registers the four RPC endpoints for this server.
///
/// NOTE(review): every handler below is a stub — it validates the message
/// variant but returns canned data and never touches `ParameterServerState`
/// (the `_state` clone is unused). Presumably the handlers should delegate
/// to `initialize_parameters` / `push_gradients` / `pull_parameters` /
/// `get_stats`; confirm whether `register_function` supports async closures
/// before wiring that up.
pub async fn start(&self) -> TorshResult<()> {
info!("Starting parameter server on rank {}", self.server_rank);
// Cloned for the handler closures, but currently unused (see NOTE above).
let _state = self.state.clone();
register_function("ps_initialize", move |msg: ParameterServerMessage| {
match msg {
ParameterServerMessage::InitializeParameters {
parameters: _parameters,
} => {
// Stub: reports success without storing anything.
Ok(ParameterServerResponse::InitResponse { success: true })
}
_ => Err("Invalid message type for ps_initialize".to_string()),
}
})
.await?;
register_function("ps_push_gradients", move |msg: ParameterServerMessage| {
match msg {
ParameterServerMessage::PushGradients {
worker_id: _,
gradients: _,
version,
} => {
// Stub: echoes version + 1 without applying the gradients.
Ok(ParameterServerResponse::PushResponse {
success: true,
new_version: version + 1,
})
}
_ => Err("Invalid message type for ps_push_gradients".to_string()),
}
})
.await?;
register_function("ps_pull_parameters", move |msg: ParameterServerMessage| {
match msg {
ParameterServerMessage::PullParameters {
worker_id: _,
param_names: _,
} => {
// Stub: returns an empty parameter map at a fixed version.
Ok(ParameterServerResponse::PullResponse {
parameters: std::collections::HashMap::new(),
version: 1,
})
}
_ => Err("Invalid message type for ps_pull_parameters".to_string()),
}
})
.await?;
register_function("ps_get_stats", move |msg: ParameterServerMessage| {
match msg {
ParameterServerMessage::GetStats => {
// Stub: fabricated stats, not read from server state.
let stats = ParameterServerStats {
num_parameters: 0,
total_pushes: 0,
total_pulls: 0,
current_version: 1,
active_workers: 0,
memory_usage_mb: 0.0,
};
Ok(ParameterServerResponse::StatsResponse { stats })
}
_ => Err("Invalid message type for ps_get_stats".to_string()),
}
})
.await?;
info!(
"Parameter server started successfully on rank {}",
self.server_rank
);
Ok(())
}
/// Snapshot of the real (state-backed) statistics.
pub async fn get_statistics(&self) -> ParameterServerStats {
self.state.get_stats().await
}
/// Current global parameter version.
pub fn get_version(&self) -> u64 {
*self
.state
.version
.read()
.expect("lock should not be poisoned")
}
/// Number of parameters currently stored.
pub fn num_parameters(&self) -> usize {
self.state.parameters.len()
}
/// Whether a parameter with the given name exists.
pub fn has_parameter(&self, name: &str) -> bool {
self.state.parameters.contains_key(name)
}
}
/// Worker-side client that talks to a parameter server over RPC and tracks
/// the last parameter version it has observed.
pub struct ParameterServerClient {
/// Rank the parameter server lives on (RPC destination).
server_rank: u32,
/// This worker's id, sent with every push/pull.
worker_id: u32,
/// Last version returned by the server; updated after pushes and pulls.
current_version: Arc<RwLock<u64>>,
}
impl ParameterServerClient {
pub fn new(server_rank: u32, worker_id: u32) -> Self {
Self {
server_rank,
worker_id,
current_version: Arc::new(RwLock::new(0)),
}
}
pub async fn initialize_parameters(
&self,
parameters: HashMap<String, Parameter>,
) -> TorshResult<()> {
let mut param_data = HashMap::new();
for (name, param) in parameters {
let tensor = param.tensor();
let tensor_guard = tensor.read();
let data = tensor_guard.flatten()?.to_vec()?;
param_data.insert(name, data);
}
let message = ParameterServerMessage::InitializeParameters {
parameters: param_data,
};
let response: ParameterServerResponse =
rpc_async(self.server_rank, "ps_initialize", message).await?;
match response {
ParameterServerResponse::InitResponse { success } => {
if success {
info!("Successfully initialized parameters on parameter server");
Ok(())
} else {
Err(TorshDistributedError::backend_error(
"parameter_server",
"Failed to initialize parameters",
))
}
}
_ => Err(TorshDistributedError::backend_error(
"parameter_server",
"Unexpected response type",
)),
}
}
pub async fn push_gradients(&self, gradients: HashMap<String, Tensor>) -> TorshResult<u64> {
let mut grad_data = HashMap::new();
for (name, grad) in gradients {
let data = grad.flatten()?.to_vec()?;
grad_data.insert(name, data);
}
let current_version = *self
.current_version
.read()
.expect("lock should not be poisoned");
let message = ParameterServerMessage::PushGradients {
worker_id: self.worker_id,
gradients: grad_data,
version: current_version,
};
let response: ParameterServerResponse =
rpc_async(self.server_rank, "ps_push_gradients", message).await?;
match response {
ParameterServerResponse::PushResponse {
success,
new_version,
} => {
if success {
*self
.current_version
.write()
.expect("lock should not be poisoned") = new_version;
debug!(
"Successfully pushed gradients, new version: {}",
new_version
);
Ok(new_version)
} else {
Err(TorshDistributedError::backend_error(
"parameter_server",
"Failed to push gradients",
))
}
}
_ => Err(TorshDistributedError::backend_error(
"parameter_server",
"Unexpected response type",
)),
}
}
pub async fn pull_parameters(
&self,
param_names: Vec<String>,
) -> TorshResult<HashMap<String, Tensor>> {
let message = ParameterServerMessage::PullParameters {
worker_id: self.worker_id,
param_names: param_names.clone(),
};
let response: ParameterServerResponse =
rpc_async(self.server_rank, "ps_pull_parameters", message).await?;
match response {
ParameterServerResponse::PullResponse {
parameters,
version,
} => {
let mut result = HashMap::new();
for (name, data) in parameters {
let shape = vec![data.len()]; let tensor = Tensor::from_vec(data, &shape)?;
result.insert(name, tensor);
}
*self
.current_version
.write()
.expect("lock should not be poisoned") = version;
debug!(
"Successfully pulled {} parameters, version: {}",
result.len(),
version
);
Ok(result)
}
_ => Err(TorshDistributedError::backend_error(
"parameter_server",
"Unexpected response type",
)),
}
}
pub async fn get_server_stats(&self) -> TorshResult<ParameterServerStats> {
let message = ParameterServerMessage::GetStats;
let response: ParameterServerResponse =
rpc_async(self.server_rank, "ps_get_stats", message).await?;
match response {
ParameterServerResponse::StatsResponse { stats } => Ok(stats),
_ => Err(TorshDistributedError::backend_error(
"parameter_server",
"Unexpected response type",
)),
}
}
pub fn get_local_version(&self) -> u64 {
*self
.current_version
.read()
.expect("lock should not be poisoned")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built server starts empty at version 0.
    #[tokio::test]
    async fn test_parameter_server_creation() {
        let srv = ParameterServer::new(0, ParameterServerConfig::default());
        assert_eq!(srv.server_rank, 0);
        assert_eq!(srv.num_parameters(), 0);
        assert_eq!(srv.get_version(), 0);
    }

    /// Struct-update syntax overrides only the named fields.
    #[tokio::test]
    async fn test_parameter_server_config() {
        let cfg = ParameterServerConfig {
            learning_rate: 0.001,
            use_momentum: false,
            gradient_clip_value: Some(1.0),
            ..Default::default()
        };
        assert_eq!(cfg.learning_rate, 0.001);
        assert!(!cfg.use_momentum);
        assert_eq!(cfg.gradient_clip_value, Some(1.0));
    }

    /// A new client mirrors its constructor arguments and starts at version 0.
    #[tokio::test]
    async fn test_parameter_server_client() {
        let cli = ParameterServerClient::new(0, 1);
        assert_eq!(cli.server_rank, 0);
        assert_eq!(cli.worker_id, 1);
        assert_eq!(cli.get_local_version(), 0);
    }

    /// Plain-data round-trip sanity check on the stats struct.
    #[tokio::test]
    async fn test_parameter_server_stats() {
        let snapshot = ParameterServerStats {
            num_parameters: 100,
            total_pushes: 50,
            total_pulls: 30,
            current_version: 10,
            active_workers: 3,
            memory_usage_mb: 128.5,
        };
        assert_eq!(snapshot.num_parameters, 100);
        assert_eq!(snapshot.total_pushes, 50);
        assert_eq!(snapshot.active_workers, 3);
        assert_eq!(snapshot.memory_usage_mb, 128.5);
    }

    /// Placeholder for a full RPC round-trip; needs a live runtime, so it
    /// stays ignored in unit runs.
    #[tokio::test]
    #[ignore]
    async fn test_parameter_server_integration() -> TorshResult<()> {
        let _server = ParameterServer::new(0, ParameterServerConfig::default());
        Ok(())
    }
}