use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, Notify, RwLock};
use tokio::time::timeout;
use tracing::{debug, trace, warn};
use super::shard_manager::ModelShardManager;
/// Strategy the server uses to fold pushed gradients into shard parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum UpdateMode {
    /// Pushes are buffered per shard and applied together at a barrier.
    /// This is the default mode.
    #[default]
    Sync,
    /// Every push is applied to the shard as soon as it arrives.
    Async,
}
/// Tunables for a [`ParameterServer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterServerConfig {
    /// Length of every embedding row; pushed gradient rows must match it.
    /// Must be greater than zero.
    pub embedding_dim: usize,
    /// Declared number of entities. The server code in this file does not
    /// read it; the actual rows come from the id list given to `new`.
    pub num_entities: usize,
    /// Declared number of relations. Like `num_entities`, informational here.
    pub num_relations: usize,
    /// Upper bound on the number of shards; must be greater than zero.
    pub num_shards: usize,
    /// Number of distinct workers whose pushes complete a sync barrier.
    /// Must be greater than zero.
    pub expected_workers: usize,
    /// Whether pushes are barrier-synchronised or applied immediately.
    pub update_mode: UpdateMode,
    /// Step size multiplied with each gradient component before subtraction.
    pub learning_rate: f32,
    /// Staleness value above which an async push logs a warning.
    pub max_staleness: u64,
    /// How long a sync push waits for the barrier before flushing.
    pub barrier_timeout: Duration,
}
impl Default for ParameterServerConfig {
fn default() -> Self {
Self {
embedding_dim: 32,
num_entities: 64,
num_relations: 8,
num_shards: 4,
expected_workers: 4,
update_mode: UpdateMode::Sync,
learning_rate: 0.01,
max_staleness: 16,
barrier_timeout: Duration::from_secs(30),
}
}
}
/// Owned copy of one shard's parameters as returned by `ParameterServer::pull`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardSnapshot {
    /// Index of the shard this snapshot was taken from.
    pub shard_id: usize,
    /// Entity embedding rows held by the shard.
    pub entities: Vec<Vec<f32>>,
    /// Entity ids, stored in the same order as `entities`.
    pub entity_ids: Vec<String>,
    /// Copy of the server-wide relation embedding rows.
    pub relations: Vec<Vec<f32>>,
    /// Relation ids, stored in the same order as `relations`.
    pub relation_ids: Vec<String>,
    /// Shard step counter at the time the snapshot was taken.
    pub step: u64,
}
/// Counters the server accumulates while serving pulls and pushes.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ParameterServerStats {
    /// Number of successful `pull` calls.
    pub total_pulls: u64,
    /// Running count of gradient applications recorded by the push paths.
    pub total_pushes: u64,
    /// Number of sync barriers that have been applied.
    pub barriers_completed: u64,
    /// Largest per-shard staleness counter seen by an async push.
    pub max_staleness_observed: u64,
    /// Sum of squared gradient components of the most recent update divided
    /// by the number of rows involved. No square root is taken.
    pub last_grad_norm: f64,
}
/// Mutable state of a single shard, guarded by the shard's `RwLock`.
#[derive(Debug)]
struct ShardState {
    /// Entity embedding rows owned by this shard.
    entities: Vec<Vec<f32>>,
    /// Entity ids, stored in the same order as `entities`.
    entity_ids: Vec<String>,
    /// Number of updates applied to this shard so far.
    step: u64,
    /// Gradients buffered for the current sync round.
    pending: Vec<PendingGradient>,
    /// Workers that already pushed in the current sync round.
    pushed_workers: Vec<u32>,
    /// Async pushes since the last pull; a pull in async mode resets it to 0.
    staleness: u64,
    /// Wakes workers that are waiting for the sync barrier to be applied.
    barrier_done: Arc<Notify>,
}
/// One worker's buffered contribution to a sync round.
#[derive(Debug, Clone)]
struct PendingGradient {
    /// Id of the worker that submitted these rows.
    worker_id: u32,
    /// `(row index within the shard, gradient vector)` pairs.
    rows: Vec<(usize, Vec<f32>)>,
}
/// In-process parameter server holding sharded entity embeddings and a
/// single shared table of relation embeddings.
pub struct ParameterServer {
    /// Configuration the server was built with.
    config: ParameterServerConfig,
    /// One independently locked state per shard.
    shards: Vec<Arc<RwLock<ShardState>>>,
    /// Relation embedding rows shared by all shards.
    relations: Arc<RwLock<Vec<Vec<f32>>>>,
    /// Relation ids, stored in the same order as `relations`.
    relation_ids: Vec<String>,
    /// Pull/push counters.
    stats: Arc<Mutex<ParameterServerStats>>,
    /// Maps an entity id to the shard that owns it.
    shard_manager: ModelShardManager,
}
impl ParameterServer {
pub fn new(
config: ParameterServerConfig,
entity_ids: Vec<String>,
relation_ids: Vec<String>,
shard_manager: ModelShardManager,
) -> Result<Self> {
if config.embedding_dim == 0 {
anyhow::bail!("embedding_dim must be > 0");
}
if config.num_shards == 0 {
anyhow::bail!("num_shards must be > 0");
}
if config.expected_workers == 0 {
anyhow::bail!("expected_workers must be > 0");
}
let num_shards = config.num_shards.min(shard_manager.num_shards());
let mut shards = Vec::with_capacity(num_shards);
let mut shard_buckets: Vec<(Vec<Vec<f32>>, Vec<String>)> =
(0..num_shards).map(|_| (Vec::new(), Vec::new())).collect();
for id in entity_ids.into_iter() {
let s = shard_manager.shard_for(&id);
let row = init_row(&id, config.embedding_dim);
shard_buckets[s].0.push(row);
shard_buckets[s].1.push(id);
}
for (entities, ids) in shard_buckets.into_iter() {
shards.push(Arc::new(RwLock::new(ShardState {
entities,
entity_ids: ids,
step: 0,
pending: Vec::new(),
pushed_workers: Vec::new(),
staleness: 0,
barrier_done: Arc::new(Notify::new()),
})));
}
let mut relations = Vec::with_capacity(relation_ids.len());
for id in &relation_ids {
relations.push(init_row(id, config.embedding_dim));
}
Ok(Self {
config,
shards,
relations: Arc::new(RwLock::new(relations)),
relation_ids,
stats: Arc::new(Mutex::new(ParameterServerStats::default())),
shard_manager,
})
}
/// Number of shards this server actually hosts.
pub fn num_shards(&self) -> usize {
    let hosted = &self.shards;
    hosted.len()
}
pub fn config(&self) -> &ParameterServerConfig {
&self.config
}
pub fn shard_manager(&self) -> &ModelShardManager {
&self.shard_manager
}
/// Copies the current parameters of shard `shard_id`, together with the
/// shared relation table, into a [`ShardSnapshot`].
///
/// In async mode the shard's staleness counter is reset to zero afterwards.
/// Every successful call increments `total_pulls`.
///
/// # Errors
///
/// Fails when `shard_id` does not name an existing shard.
pub async fn pull(&self, shard_id: usize) -> Result<ShardSnapshot> {
    let shard = match self.shards.get(shard_id) {
        Some(handle) => handle,
        None => anyhow::bail!("shard {shard_id} out of range"),
    };

    // The read guard lives only for this block, so the optional write lock
    // below never overlaps with it.
    let snapshot = {
        let state = shard.read().await;
        ShardSnapshot {
            shard_id,
            entities: state.entities.clone(),
            entity_ids: state.entity_ids.clone(),
            relations: self.relations.read().await.clone(),
            relation_ids: self.relation_ids.clone(),
            step: state.step,
        }
    };

    if self.config.update_mode == UpdateMode::Async {
        shard.write().await.staleness = 0;
    }

    let mut counters = self.stats.lock().await;
    counters.total_pulls += 1;
    Ok(snapshot)
}
pub async fn push(
&self,
shard_id: usize,
worker_id: u32,
rows: Vec<(usize, Vec<f32>)>,
) -> Result<()> {
let shard = self
.shards
.get(shard_id)
.ok_or_else(|| anyhow::anyhow!("shard {shard_id} out of range"))?
.clone();
for (idx, grad) in &rows {
if grad.len() != self.config.embedding_dim {
anyhow::bail!(
"gradient row {idx} has dim {} but server expects {}",
grad.len(),
self.config.embedding_dim
);
}
}
match self.config.update_mode {
UpdateMode::Sync => self.push_sync(shard, shard_id, worker_id, rows).await,
UpdateMode::Async => self.push_async(shard, worker_id, rows).await,
}
}
/// Applies relation gradients immediately: `row -= learning_rate * grad`.
///
/// Each entry of `rows` is `(relation row index, gradient)`. Indices beyond
/// the relation table are skipped without error. `worker_id` is only used
/// in the trace log line.
///
/// # Errors
///
/// Fails, without applying anything, when any gradient's length differs
/// from `embedding_dim`.
pub async fn push_relation(&self, worker_id: u32, rows: Vec<(usize, Vec<f32>)>) -> Result<()> {
    let expected_dim = self.config.embedding_dim;
    let mismatch = rows.iter().find(|(_, grad)| grad.len() != expected_dim);
    if let Some((idx, grad)) = mismatch {
        anyhow::bail!(
            "relation gradient row {idx} has dim {} but server expects {}",
            grad.len(),
            self.config.embedding_dim
        );
    }

    let step_size = self.config.learning_rate;
    let mut table = self.relations.write().await;
    for (idx, grad) in rows {
        if let Some(row) = table.get_mut(idx) {
            row.iter_mut()
                .zip(grad.iter())
                .for_each(|(weight, g)| *weight -= step_size * *g);
        }
    }
    trace!("worker {worker_id}: relation gradients applied");
    Ok(())
}
pub async fn stats(&self) -> ParameterServerStats {
self.stats.lock().await.clone()
}
/// Reads the step counter of every shard, in shard order.
///
/// Shards are locked one after another, so the values are not a single
/// atomic snapshot across shards.
pub async fn shard_steps(&self) -> Vec<u64> {
    let mut out = Vec::with_capacity(self.shards.len());
    for shard in self.shards.iter() {
        let state = shard.read().await;
        out.push(state.step);
    }
    out
}
async fn push_sync(
&self,
shard: Arc<RwLock<ShardState>>,
shard_id: usize,
worker_id: u32,
rows: Vec<(usize, Vec<f32>)>,
) -> Result<()> {
let (apply_now, barrier) = {
let mut g = shard.write().await;
if g.pushed_workers.contains(&worker_id) {
anyhow::bail!("worker {worker_id} already pushed for shard {shard_id} this step");
}
g.pending.push(PendingGradient { worker_id, rows });
g.pushed_workers.push(worker_id);
let ready = g.pushed_workers.len() >= self.config.expected_workers;
(ready, g.barrier_done.clone())
};
if apply_now {
self.apply_sync_barrier(shard.clone(), shard_id).await?;
barrier.notify_waiters();
return Ok(());
}
let waited = timeout(self.config.barrier_timeout, barrier.notified()).await;
if waited.is_err() {
warn!(
"shard {shard_id} barrier timed out after {:?}; flushing partial step",
self.config.barrier_timeout
);
self.apply_sync_barrier(shard, shard_id).await?;
}
Ok(())
}
async fn apply_sync_barrier(
&self,
shard: Arc<RwLock<ShardState>>,
shard_id: usize,
) -> Result<()> {
let mut g = shard.write().await;
let lr = self.config.learning_rate;
let dim = self.config.embedding_dim;
let n = g.pending.len().max(1) as f32;
let mut acc: std::collections::HashMap<usize, Vec<f32>> = std::collections::HashMap::new();
for pending in &g.pending {
for (idx, grad) in &pending.rows {
let entry = acc.entry(*idx).or_insert_with(|| vec![0.0; dim]);
for (t, gval) in entry.iter_mut().zip(grad.iter()) {
*t += *gval / n;
}
}
}
let mut sq_sum = 0.0_f64;
for (idx, grad) in &acc {
if let Some(target) = g.entities.get_mut(*idx) {
for (t, gval) in target.iter_mut().zip(grad.iter()) {
*t -= lr * *gval;
sq_sum += (*gval as f64) * (*gval as f64);
}
}
}
g.pending.clear();
g.pushed_workers.clear();
g.step += 1;
let new_step = g.step;
drop(g);
let mut stats = self.stats.lock().await;
stats.total_pushes += 1;
stats.barriers_completed += 1;
if !acc.is_empty() {
stats.last_grad_norm = sq_sum / acc.len() as f64;
}
debug!("shard {shard_id} barrier applied (new step = {new_step})");
Ok(())
}
/// Applies `rows` to the shard right away (`row -= learning_rate * grad`),
/// then bumps the shard's staleness and step counters and the stats.
///
/// Row indices outside the shard are skipped. A warning is logged when the
/// resulting staleness exceeds `max_staleness`; the push still succeeds.
async fn push_async(
    &self,
    shard: Arc<RwLock<ShardState>>,
    worker_id: u32,
    rows: Vec<(usize, Vec<f32>)>,
) -> Result<()> {
    let step_size = self.config.learning_rate;

    // Keep the shard write lock only while the parameters are being mutated.
    let (staleness_now, touched_rows, sq_total) = {
        let mut state = shard.write().await;
        let mut sq_total = 0.0_f64;
        let mut touched_rows = 0usize;
        for (row_idx, grad) in rows.iter() {
            let row = match state.entities.get_mut(*row_idx) {
                Some(row) => row,
                None => continue,
            };
            for (weight, gval) in row.iter_mut().zip(grad.iter()) {
                *weight -= step_size * *gval;
                sq_total += f64::from(*gval) * f64::from(*gval);
            }
            touched_rows += 1;
        }
        state.staleness = state.staleness.saturating_add(1);
        state.step += 1;
        (state.staleness, touched_rows, sq_total)
    };

    let mut counters = self.stats.lock().await;
    counters.total_pushes += 1;
    counters.max_staleness_observed = counters.max_staleness_observed.max(staleness_now);
    if touched_rows > 0 {
        counters.last_grad_norm = sq_total / touched_rows as f64;
    }

    let limit = self.config.max_staleness;
    if staleness_now > limit {
        warn!(
            "worker {worker_id} async push: staleness {new_staleness} exceeds max {}",
            self.config.max_staleness,
            new_staleness = staleness_now
        );
    }
    Ok(())
}
}
/// Deterministically generates an embedding row of length `dim` from `seed_id`.
///
/// The id's bytes are hashed with 64-bit FNV-1a; the hash (forced odd) seeds
/// a linear congruential generator whose upper 32 bits are scaled into
/// roughly `[-0.05, 0.05]`. The same `(seed_id, dim)` always yields the same
/// row.
fn init_row(seed_id: &str, dim: usize) -> Vec<f32> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x100_0000_01b3;
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    let hash = seed_id.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    let mut state = hash | 1;
    (0..dim)
        .map(|_| {
            state = state
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            let raw = (state >> 32) as u32;
            (raw as f32 / u32::MAX as f32) * 0.1 - 0.05
        })
        .collect()
}
// Unit tests for the parameter server: construction, pull, and both push modes.
#[cfg(test)]
mod tests {
use super::*;
use crate::distributed_training::shard_manager::{ModelShardManager, ShardingStrategy};
// Small configuration shared by the tests: 4-dim rows, 8 entities, 2 relations,
// 2 shards, learning rate 0.1 and a 500 ms barrier timeout.
fn small_cfg(mode: UpdateMode, workers: usize) -> ParameterServerConfig {
ParameterServerConfig {
embedding_dim: 4,
num_entities: 8,
num_relations: 2,
num_shards: 2,
expected_workers: workers,
update_mode: mode,
learning_rate: 0.1,
max_staleness: 8,
barrier_timeout: Duration::from_millis(500),
}
}
// Builds a server from `small_cfg` with entity ids "e0".."e7", relation ids
// "r0".."r1" and an entity-hash shard manager.
fn small_server(mode: UpdateMode, workers: usize) -> ParameterServer {
let cfg = small_cfg(mode, workers);
let entity_ids: Vec<String> = (0..cfg.num_entities).map(|i| format!("e{i}")).collect();
let relation_ids: Vec<String> = (0..cfg.num_relations).map(|i| format!("r{i}")).collect();
let mgr = ModelShardManager::new(cfg.num_shards, ShardingStrategy::EntityHash);
ParameterServer::new(cfg, entity_ids, relation_ids, mgr)
.expect("server construction failed")
}
// Construction succeeds and exposes the configured number of shards.
#[tokio::test]
async fn server_constructs_and_reports_shards() {
let s = small_server(UpdateMode::Sync, 2);
assert_eq!(s.num_shards(), 2);
}
// A zero embedding dimension is rejected by `new`.
#[tokio::test]
async fn server_rejects_zero_dim() {
let mut cfg = small_cfg(UpdateMode::Sync, 2);
cfg.embedding_dim = 0;
let mgr = ModelShardManager::new(cfg.num_shards, ShardingStrategy::EntityHash);
let res = ParameterServer::new(cfg, vec!["a".into()], vec!["r".into()], mgr);
assert!(res.is_err());
}
// Every pulled entity row has the configured width, and each snapshot carries
// both relation rows.
#[tokio::test]
async fn pull_returns_consistent_dim_rows() {
let s = small_server(UpdateMode::Sync, 2);
for shard in 0..s.num_shards() {
let snap = s.pull(shard).await.expect("pull");
assert_eq!(snap.shard_id, shard);
assert_eq!(snap.relations.len(), 2);
for row in &snap.entities {
assert_eq!(row.len(), 4);
}
}
}
// In async mode a gradient of ones with learning rate 0.1 lowers every
// component of the row by 0.1 right away.
#[tokio::test]
async fn push_async_applies_immediately() {
let s = small_server(UpdateMode::Async, 1);
let snap = s.pull(0).await.expect("pull");
let before = snap.entities.first().cloned().unwrap_or_default();
let grad: Vec<f32> = vec![1.0; 4];
// Skip the check if hashing happened to leave shard 0 without entities.
if !snap.entities.is_empty() {
s.push(0, 0, vec![(0, grad.clone())])
.await
.expect("push async");
let snap2 = s.pull(0).await.expect("pull2");
let after = snap2.entities.first().cloned().unwrap_or_default();
for (b, a) in before.iter().zip(after.iter()) {
assert!(
(b - a - 0.1).abs() < 1e-5,
"expected b - a ≈ 0.1, got b={b}, a={a}"
);
}
}
}
// With two expected workers, two concurrent pushes complete exactly one
// barrier and advance shard 0 by one step.
#[tokio::test]
async fn push_sync_buffers_until_barrier() {
let s = Arc::new(small_server(UpdateMode::Sync, 2));
let snap = s.pull(0).await.expect("pull");
if snap.entities.is_empty() {
return;
}
let grad: Vec<f32> = vec![2.0; 4];
let s0 = Arc::clone(&s);
let g0 = grad.clone();
let h0 = tokio::spawn(async move {
s0.push(0, 0, vec![(0, g0)]).await.expect("worker 0 push");
});
let s1 = Arc::clone(&s);
let g1 = grad.clone();
let h1 = tokio::spawn(async move {
s1.push(0, 1, vec![(0, g1)]).await.expect("worker 1 push");
});
h0.await.expect("worker 0 join");
h1.await.expect("worker 1 join");
let stats = s.stats().await;
assert_eq!(
stats.barriers_completed, 1,
"exactly one barrier should have fired"
);
let steps = s.shard_steps().await;
assert_eq!(steps[0], 1, "shard 0 should have advanced one step");
}
// A second push from the same worker id within one round must be rejected.
// The first push is spawned and left waiting for the barrier (two workers are
// expected, only one pushes).
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn push_sync_rejects_double_push_from_same_worker() {
let s = Arc::new(small_server(UpdateMode::Sync, 2));
let snap = s.pull(0).await.expect("pull");
if snap.entities.is_empty() {
return;
}
let g = vec![0.0_f32; 4];
let s_first = Arc::clone(&s);
let g_first = g.clone();
let h = tokio::spawn(async move {
s_first.push(0, 7, vec![(0, g_first)]).await
});
// Give the spawned push time to register worker 7 before pushing again.
tokio::task::yield_now().await;
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let err = s.push(0, 7, vec![(0, g)]).await;
assert!(err.is_err(), "second push by same worker must fail");
// The first push only returns once its barrier wait ends; its result is
// deliberately ignored here.
let _ = h.await.expect("join push task");
}
// A gradient whose length differs from `embedding_dim` (3 vs 4) is rejected.
#[tokio::test]
async fn push_validates_gradient_dim() {
let s = small_server(UpdateMode::Async, 1);
let res = s.push(0, 0, vec![(0, vec![1.0; 3])]).await;
assert!(res.is_err());
}
// Relation gradients are applied immediately, scaled by the learning rate.
#[tokio::test]
async fn relation_push_applies_with_learning_rate() {
let s = small_server(UpdateMode::Sync, 1);
let before = s.pull(0).await.expect("pull").relations[0].clone();
s.push_relation(0, vec![(0, vec![1.0_f32; 4])])
.await
.expect("rel push");
let after = s.pull(0).await.expect("pull2").relations[0].clone();
for (b, a) in before.iter().zip(after.iter()) {
assert!((b - a - 0.1).abs() < 1e-5);
}
}
// NOTE(review): despite the name, this only checks that a pull leaves
// `max_staleness_observed` unchanged after three async pushes; the per-shard
// staleness counter itself is not observable through the public API.
#[tokio::test]
async fn async_pull_resets_staleness() {
let s = small_server(UpdateMode::Async, 1);
let snap = s.pull(0).await.expect("pull");
if snap.entities.is_empty() {
return;
}
for _ in 0..3 {
s.push(0, 0, vec![(0, vec![0.1_f32; 4])])
.await
.expect("push");
}
let stats_before = s.stats().await;
assert!(stats_before.max_staleness_observed >= 3);
let _ = s.pull(0).await.expect("pull");
let stats_after = s.stats().await;
assert_eq!(
stats_after.max_staleness_observed, stats_before.max_staleness_observed,
"max_staleness_observed is monotonic"
);
}
}