1use anyhow::Result;
31use serde::{Deserialize, Serialize};
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{Mutex, Notify, RwLock};
35use tokio::time::timeout;
36use tracing::{debug, trace, warn};
37
38use super::shard_manager::ModelShardManager;
39
/// Strategy the parameter server uses to apply gradients pushed by workers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum UpdateMode {
    /// Buffer each worker's push and apply the averaged gradients once
    /// `expected_workers` distinct workers have pushed for the shard, or once
    /// the barrier timeout elapses. This is the default mode.
    #[default]
    Sync,
    /// Apply every push to its shard immediately and track a per-shard
    /// staleness counter instead of waiting on a barrier.
    Async,
}
49
/// Static configuration for a [`ParameterServer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterServerConfig {
    /// Length of every embedding row and of every gradient row a worker may
    /// push; must be greater than zero.
    pub embedding_dim: usize,
    /// Declared number of entities.
    // NOTE(review): not read by the server code in this file; the tables are
    // sized by the id lists handed to `ParameterServer::new` — confirm intent.
    pub num_entities: usize,
    /// Declared number of relations.
    // NOTE(review): not read by the server code in this file — confirm intent.
    pub num_relations: usize,
    /// Upper bound on the number of shards; the server uses the smaller of
    /// this value and the shard manager's own shard count. Must be > 0.
    pub num_shards: usize,
    /// Number of distinct workers whose pushes complete a sync barrier for a
    /// shard. Must be > 0.
    pub expected_workers: usize,
    /// Whether pushes are applied behind a barrier or immediately.
    pub update_mode: UpdateMode,
    /// Factor each gradient component is multiplied by before it is
    /// subtracted from the stored parameter.
    pub learning_rate: f32,
    /// Async-mode staleness threshold; exceeding it only logs a warning.
    pub max_staleness: u64,
    /// How long a sync-mode push waits for the remaining workers before a
    /// partial step is flushed.
    pub barrier_timeout: Duration,
}
73
74impl Default for ParameterServerConfig {
75 fn default() -> Self {
76 Self {
77 embedding_dim: 32,
78 num_entities: 64,
79 num_relations: 8,
80 num_shards: 4,
81 expected_workers: 4,
82 update_mode: UpdateMode::Sync,
83 learning_rate: 0.01,
84 max_staleness: 16,
85 barrier_timeout: Duration::from_secs(30),
86 }
87 }
88}
89
/// Point-in-time copy of one shard's parameters, as returned by
/// `ParameterServer::pull`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardSnapshot {
    /// Index of the shard this snapshot was taken from.
    pub shard_id: usize,
    /// Entity embedding rows stored in the shard.
    pub entities: Vec<Vec<f32>>,
    /// Entity ids, in the same order as `entities`.
    pub entity_ids: Vec<String>,
    /// Copy of the server's full relation table (relations are not sharded).
    pub relations: Vec<Vec<f32>>,
    /// Relation ids, in the same order as `relations`.
    pub relation_ids: Vec<String>,
    /// The shard's step counter at the time of the snapshot.
    pub step: u64,
}
109
/// Running counters describing the server's activity.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ParameterServerStats {
    /// Number of successful `pull` calls.
    pub total_pulls: u64,
    /// Incremented once per async push and once per applied sync barrier
    /// (not once per individual worker push in sync mode).
    pub total_pushes: u64,
    /// Number of sync barriers applied, including partial flushes after a
    /// barrier timeout.
    pub barriers_completed: u64,
    /// Largest per-shard staleness counter seen by any async push.
    pub max_staleness_observed: u64,
    /// Sum of squared gradient components of the most recent entity update
    /// divided by a row count; no square root is taken.
    pub last_grad_norm: f64,
}
124
/// Mutable state of a single shard, guarded by the shard's `RwLock`.
#[derive(Debug)]
struct ShardState {
    /// Entity embedding rows; gradient row indices refer to positions here.
    entities: Vec<Vec<f32>>,
    /// Entity ids, in the same order as `entities`.
    entity_ids: Vec<String>,
    /// Update counter: bumped once per applied sync barrier and once per
    /// async push.
    step: u64,
    /// Gradients buffered for the current sync step.
    pending: Vec<PendingGradient>,
    /// Workers that already pushed during the current sync step; used to
    /// reject a second push from the same worker.
    pushed_workers: Vec<u32>,
    /// Async pushes applied since the counter was last reset by a pull.
    staleness: u64,
    /// Signalled when a sync barrier has been applied so that waiting
    /// pushers can return.
    barrier_done: Arc<Notify>,
}
144
/// One worker's buffered contribution to the current sync step.
#[derive(Debug, Clone)]
struct PendingGradient {
    /// Id of the worker that pushed these rows.
    worker_id: u32,
    /// `(row index within the shard, gradient row)` pairs.
    rows: Vec<(usize, Vec<f32>)>,
}
150
/// In-process parameter server holding sharded entity embeddings and a
/// single shared relation table.
pub struct ParameterServer {
    /// Configuration the server was constructed with.
    config: ParameterServerConfig,
    /// One independently locked state per shard.
    shards: Vec<Arc<RwLock<ShardState>>>,
    /// Relation embedding rows; shared by all shards.
    relations: Arc<RwLock<Vec<Vec<f32>>>>,
    /// Relation ids, in the same order as the rows in `relations`.
    relation_ids: Vec<String>,
    /// Activity counters, see [`ParameterServerStats`].
    stats: Arc<Mutex<ParameterServerStats>>,
    /// Maps entity ids to shard indices when the server is built.
    shard_manager: ModelShardManager,
}
171
172impl ParameterServer {
173 pub fn new(
180 config: ParameterServerConfig,
181 entity_ids: Vec<String>,
182 relation_ids: Vec<String>,
183 shard_manager: ModelShardManager,
184 ) -> Result<Self> {
185 if config.embedding_dim == 0 {
186 anyhow::bail!("embedding_dim must be > 0");
187 }
188 if config.num_shards == 0 {
189 anyhow::bail!("num_shards must be > 0");
190 }
191 if config.expected_workers == 0 {
192 anyhow::bail!("expected_workers must be > 0");
193 }
194 let num_shards = config.num_shards.min(shard_manager.num_shards());
195
196 let mut shards = Vec::with_capacity(num_shards);
198 let mut shard_buckets: Vec<(Vec<Vec<f32>>, Vec<String>)> =
199 (0..num_shards).map(|_| (Vec::new(), Vec::new())).collect();
200
201 for id in entity_ids.into_iter() {
202 let s = shard_manager.shard_for(&id);
203 let row = init_row(&id, config.embedding_dim);
204 shard_buckets[s].0.push(row);
205 shard_buckets[s].1.push(id);
206 }
207
208 for (entities, ids) in shard_buckets.into_iter() {
209 shards.push(Arc::new(RwLock::new(ShardState {
210 entities,
211 entity_ids: ids,
212 step: 0,
213 pending: Vec::new(),
214 pushed_workers: Vec::new(),
215 staleness: 0,
216 barrier_done: Arc::new(Notify::new()),
217 })));
218 }
219
220 let mut relations = Vec::with_capacity(relation_ids.len());
222 for id in &relation_ids {
223 relations.push(init_row(id, config.embedding_dim));
224 }
225
226 Ok(Self {
227 config,
228 shards,
229 relations: Arc::new(RwLock::new(relations)),
230 relation_ids,
231 stats: Arc::new(Mutex::new(ParameterServerStats::default())),
232 shard_manager,
233 })
234 }
235
236 pub fn num_shards(&self) -> usize {
238 self.shards.len()
239 }
240
241 pub fn config(&self) -> &ParameterServerConfig {
243 &self.config
244 }
245
246 pub fn shard_manager(&self) -> &ModelShardManager {
248 &self.shard_manager
249 }
250
251 pub async fn pull(&self, shard_id: usize) -> Result<ShardSnapshot> {
256 let shard = self
257 .shards
258 .get(shard_id)
259 .ok_or_else(|| anyhow::anyhow!("shard {shard_id} out of range"))?;
260 let g = shard.read().await;
261 let snap = ShardSnapshot {
262 shard_id,
263 entities: g.entities.clone(),
264 entity_ids: g.entity_ids.clone(),
265 relations: self.relations.read().await.clone(),
266 relation_ids: self.relation_ids.clone(),
267 step: g.step,
268 };
269 drop(g);
270
271 if matches!(self.config.update_mode, UpdateMode::Async) {
273 let mut w = shard.write().await;
274 w.staleness = 0;
275 }
276
277 let mut stats = self.stats.lock().await;
278 stats.total_pulls += 1;
279 Ok(snap)
280 }
281
282 pub async fn push(
290 &self,
291 shard_id: usize,
292 worker_id: u32,
293 rows: Vec<(usize, Vec<f32>)>,
294 ) -> Result<()> {
295 let shard = self
296 .shards
297 .get(shard_id)
298 .ok_or_else(|| anyhow::anyhow!("shard {shard_id} out of range"))?
299 .clone();
300
301 for (idx, grad) in &rows {
303 if grad.len() != self.config.embedding_dim {
304 anyhow::bail!(
305 "gradient row {idx} has dim {} but server expects {}",
306 grad.len(),
307 self.config.embedding_dim
308 );
309 }
310 }
311
312 match self.config.update_mode {
313 UpdateMode::Sync => self.push_sync(shard, shard_id, worker_id, rows).await,
314 UpdateMode::Async => self.push_async(shard, worker_id, rows).await,
315 }
316 }
317
318 pub async fn push_relation(&self, worker_id: u32, rows: Vec<(usize, Vec<f32>)>) -> Result<()> {
323 for (idx, grad) in &rows {
324 if grad.len() != self.config.embedding_dim {
325 anyhow::bail!(
326 "relation gradient row {idx} has dim {} but server expects {}",
327 grad.len(),
328 self.config.embedding_dim
329 );
330 }
331 }
332
333 let mut rel = self.relations.write().await;
334 for (idx, grad) in rows {
335 if let Some(target) = rel.get_mut(idx) {
336 for (t, g) in target.iter_mut().zip(grad.iter()) {
337 *t -= self.config.learning_rate * *g;
338 }
339 }
340 }
341 trace!("worker {worker_id}: relation gradients applied");
342 Ok(())
343 }
344
345 pub async fn stats(&self) -> ParameterServerStats {
347 self.stats.lock().await.clone()
348 }
349
350 pub async fn shard_steps(&self) -> Vec<u64> {
352 let mut steps = Vec::with_capacity(self.shards.len());
353 for s in &self.shards {
354 steps.push(s.read().await.step);
355 }
356 steps
357 }
358
359 async fn push_sync(
362 &self,
363 shard: Arc<RwLock<ShardState>>,
364 shard_id: usize,
365 worker_id: u32,
366 rows: Vec<(usize, Vec<f32>)>,
367 ) -> Result<()> {
368 let (apply_now, barrier) = {
370 let mut g = shard.write().await;
371 if g.pushed_workers.contains(&worker_id) {
372 anyhow::bail!("worker {worker_id} already pushed for shard {shard_id} this step");
373 }
374 g.pending.push(PendingGradient { worker_id, rows });
375 g.pushed_workers.push(worker_id);
376 let ready = g.pushed_workers.len() >= self.config.expected_workers;
377 (ready, g.barrier_done.clone())
378 };
379
380 if apply_now {
381 self.apply_sync_barrier(shard.clone(), shard_id).await?;
382 barrier.notify_waiters();
383 return Ok(());
384 }
385
386 let waited = timeout(self.config.barrier_timeout, barrier.notified()).await;
389 if waited.is_err() {
390 warn!(
391 "shard {shard_id} barrier timed out after {:?}; flushing partial step",
392 self.config.barrier_timeout
393 );
394 self.apply_sync_barrier(shard, shard_id).await?;
395 }
396 Ok(())
397 }
398
399 async fn apply_sync_barrier(
400 &self,
401 shard: Arc<RwLock<ShardState>>,
402 shard_id: usize,
403 ) -> Result<()> {
404 let mut g = shard.write().await;
405 let lr = self.config.learning_rate;
406 let dim = self.config.embedding_dim;
407 let n = g.pending.len().max(1) as f32;
408
409 let mut acc: std::collections::HashMap<usize, Vec<f32>> = std::collections::HashMap::new();
411 for pending in &g.pending {
412 for (idx, grad) in &pending.rows {
413 let entry = acc.entry(*idx).or_insert_with(|| vec![0.0; dim]);
414 for (t, gval) in entry.iter_mut().zip(grad.iter()) {
415 *t += *gval / n;
416 }
417 }
418 }
419
420 let mut sq_sum = 0.0_f64;
422 for (idx, grad) in &acc {
423 if let Some(target) = g.entities.get_mut(*idx) {
424 for (t, gval) in target.iter_mut().zip(grad.iter()) {
425 *t -= lr * *gval;
426 sq_sum += (*gval as f64) * (*gval as f64);
427 }
428 }
429 }
430
431 g.pending.clear();
432 g.pushed_workers.clear();
433 g.step += 1;
434 let new_step = g.step;
435 drop(g);
436
437 let mut stats = self.stats.lock().await;
438 stats.total_pushes += 1;
439 stats.barriers_completed += 1;
440 if !acc.is_empty() {
441 stats.last_grad_norm = sq_sum / acc.len() as f64;
442 }
443 debug!("shard {shard_id} barrier applied (new step = {new_step})");
444 Ok(())
445 }
446
447 async fn push_async(
448 &self,
449 shard: Arc<RwLock<ShardState>>,
450 worker_id: u32,
451 rows: Vec<(usize, Vec<f32>)>,
452 ) -> Result<()> {
453 let lr = self.config.learning_rate;
454 let mut g = shard.write().await;
455 let mut sq_sum = 0.0_f64;
456 let mut applied = 0usize;
457 for (idx, grad) in &rows {
458 if let Some(target) = g.entities.get_mut(*idx) {
459 for (t, gval) in target.iter_mut().zip(grad.iter()) {
460 *t -= lr * *gval;
461 sq_sum += (*gval as f64) * (*gval as f64);
462 }
463 applied += 1;
464 }
465 }
466 g.staleness = g.staleness.saturating_add(1);
467 g.step += 1;
468 let new_staleness = g.staleness;
469 drop(g);
470
471 let mut stats = self.stats.lock().await;
472 stats.total_pushes += 1;
473 stats.max_staleness_observed = stats.max_staleness_observed.max(new_staleness);
474 if applied > 0 {
475 stats.last_grad_norm = sq_sum / applied as f64;
476 }
477
478 if new_staleness > self.config.max_staleness {
479 warn!(
480 "worker {worker_id} async push: staleness {new_staleness} exceeds max {}",
481 self.config.max_staleness
482 );
483 }
484 Ok(())
485 }
486}
487
/// Produces a deterministic pseudo-random embedding row of length `dim`
/// for `seed_id`.
///
/// The id's bytes are hashed with 64-bit FNV-1a; the hash (with its lowest
/// bit forced to 1) seeds a 64-bit linear congruential generator. For each
/// component the generator is advanced once and its upper 32 bits are scaled
/// into the range `[-0.05, 0.05]`. The same `(seed_id, dim)` always yields
/// the same row, and a shorter row is a prefix of a longer one.
fn init_row(seed_id: &str, dim: usize) -> Vec<f32> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x100_0000_01b3;
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    let hash = seed_id.bytes().fold(FNV_OFFSET_BASIS, |acc, b| {
        (acc ^ u64::from(b)).wrapping_mul(FNV_PRIME)
    });

    let mut lcg = hash | 1;
    (0..dim)
        .map(|_| {
            lcg = lcg
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            // Only the upper half of the state is used for the output.
            let upper = (lcg >> 32) as u32;
            (upper as f32 / u32::MAX as f32) * 0.1 - 0.05
        })
        .collect()
}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed_training::shard_manager::{ModelShardManager, ShardingStrategy};

    /// Small configuration used by every test: 4-dimensional rows, 8
    /// entities, 2 relations, 2 shards and a short barrier timeout.
    fn small_cfg(mode: UpdateMode, workers: usize) -> ParameterServerConfig {
        ParameterServerConfig {
            embedding_dim: 4,
            num_entities: 8,
            num_relations: 2,
            num_shards: 2,
            expected_workers: workers,
            update_mode: mode,
            learning_rate: 0.1,
            max_staleness: 8,
            barrier_timeout: Duration::from_millis(500),
        }
    }

    /// Builds a server from `small_cfg` with ids `e0..` and `r0..`.
    fn small_server(mode: UpdateMode, workers: usize) -> ParameterServer {
        let cfg = small_cfg(mode, workers);
        let entity_ids: Vec<String> = (0..cfg.num_entities).map(|i| format!("e{i}")).collect();
        let relation_ids: Vec<String> = (0..cfg.num_relations).map(|i| format!("r{i}")).collect();
        let mgr = ModelShardManager::new(cfg.num_shards, ShardingStrategy::EntityHash);
        ParameterServer::new(cfg, entity_ids, relation_ids, mgr)
            .expect("server construction failed")
    }

    #[tokio::test]
    async fn server_constructs_and_reports_shards() {
        let s = small_server(UpdateMode::Sync, 2);
        assert_eq!(s.num_shards(), 2);
    }

    #[tokio::test]
    async fn server_rejects_zero_dim() {
        let mut cfg = small_cfg(UpdateMode::Sync, 2);
        cfg.embedding_dim = 0;
        let mgr = ModelShardManager::new(cfg.num_shards, ShardingStrategy::EntityHash);
        let res = ParameterServer::new(cfg, vec!["a".into()], vec!["r".into()], mgr);
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn pull_returns_consistent_dim_rows() {
        let s = small_server(UpdateMode::Sync, 2);
        for shard in 0..s.num_shards() {
            let snap = s.pull(shard).await.expect("pull");
            assert_eq!(snap.shard_id, shard);
            assert_eq!(snap.relations.len(), 2);
            for row in &snap.entities {
                assert_eq!(row.len(), 4);
            }
        }
    }

    #[tokio::test]
    async fn push_async_applies_immediately() {
        let s = small_server(UpdateMode::Async, 1);
        let snap = s.pull(0).await.expect("pull");
        let before = snap.entities.first().cloned().unwrap_or_default();

        let grad: Vec<f32> = vec![1.0; 4];
        // Shard 0 may be empty depending on how the ids hash; nothing to
        // check in that case.
        if !snap.entities.is_empty() {
            s.push(0, 0, vec![(0, grad.clone())])
                .await
                .expect("push async");
            let snap2 = s.pull(0).await.expect("pull2");
            let after = snap2.entities.first().cloned().unwrap_or_default();
            for (b, a) in before.iter().zip(after.iter()) {
                // learning_rate (0.1) * gradient (1.0) is subtracted.
                assert!(
                    (b - a - 0.1).abs() < 1e-5,
                    "expected b - a ≈ 0.1, got b={b}, a={a}"
                );
            }
        }
    }

    #[tokio::test]
    async fn push_sync_buffers_until_barrier() {
        let s = Arc::new(small_server(UpdateMode::Sync, 2));
        let snap = s.pull(0).await.expect("pull");
        if snap.entities.is_empty() {
            return;
        }

        let grad: Vec<f32> = vec![2.0; 4];

        // Worker 0 pushes first and has to wait for the barrier.
        let s0 = Arc::clone(&s);
        let g0 = grad.clone();
        let h0 = tokio::spawn(async move {
            s0.push(0, 0, vec![(0, g0)]).await.expect("worker 0 push");
        });

        // Worker 1 completes the barrier.
        let s1 = Arc::clone(&s);
        let g1 = grad.clone();
        let h1 = tokio::spawn(async move {
            s1.push(0, 1, vec![(0, g1)]).await.expect("worker 1 push");
        });

        h0.await.expect("worker 0 join");
        h1.await.expect("worker 1 join");

        let stats = s.stats().await;
        assert_eq!(
            stats.barriers_completed, 1,
            "exactly one barrier should have fired"
        );
        let steps = s.shard_steps().await;
        assert_eq!(steps[0], 1, "shard 0 should have advanced one step");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn push_sync_rejects_double_push_from_same_worker() {
        let s = Arc::new(small_server(UpdateMode::Sync, 2));
        let snap = s.pull(0).await.expect("pull");
        if snap.entities.is_empty() {
            return;
        }

        let g = vec![0.0_f32; 4];
        let s_first = Arc::clone(&s);
        let g_first = g.clone();
        // The first push blocks on the barrier (only one of two workers).
        let h = tokio::spawn(async move {
            s_first.push(0, 7, vec![(0, g_first)]).await
        });

        // Give the spawned task time to register its push before the
        // duplicate is attempted.
        tokio::task::yield_now().await;
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let err = s.push(0, 7, vec![(0, g)]).await;
        assert!(err.is_err(), "second push by same worker must fail");

        let _ = h.await.expect("join push task");
    }

    #[tokio::test]
    async fn push_validates_gradient_dim() {
        let s = small_server(UpdateMode::Async, 1);
        // Three components where four are expected.
        let res = s.push(0, 0, vec![(0, vec![1.0; 3])]).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn relation_push_applies_with_learning_rate() {
        let s = small_server(UpdateMode::Sync, 1);
        let before = s.pull(0).await.expect("pull").relations[0].clone();
        s.push_relation(0, vec![(0, vec![1.0_f32; 4])])
            .await
            .expect("rel push");
        let after = s.pull(0).await.expect("pull2").relations[0].clone();
        for (b, a) in before.iter().zip(after.iter()) {
            assert!((b - a - 0.1).abs() < 1e-5);
        }
    }

    #[tokio::test]
    async fn async_pull_resets_staleness() {
        let s = small_server(UpdateMode::Async, 1);
        let snap = s.pull(0).await.expect("pull");
        if snap.entities.is_empty() {
            return;
        }
        for _ in 0..3 {
            s.push(0, 0, vec![(0, vec![0.1_f32; 4])])
                .await
                .expect("push");
        }
        let stats_before = s.stats().await;
        assert!(stats_before.max_staleness_observed >= 3);

        let _ = s.pull(0).await.expect("pull");
        let stats_after = s.stats().await;
        assert_eq!(
            stats_after.max_staleness_observed, stats_before.max_staleness_observed,
            "max_staleness_observed is monotonic"
        );

        // The pull above must have reset the shard's counter: one more push
        // then has staleness 1 and cannot raise the observed maximum. Without
        // the reset this push would be the fourth in a row and bump it.
        s.push(0, 0, vec![(0, vec![0.1_f32; 4])])
            .await
            .expect("push after pull");
        let stats_final = s.stats().await;
        assert_eq!(
            stats_final.max_staleness_observed, stats_before.max_staleness_observed,
            "pull must reset the shard's staleness counter"
        );
    }
}