Skip to main content

oxirs_embed/distributed_training/
parameter_server.rs

1//! In-process toy parameter server with sharded embeddings.
2//!
3//! [`ParameterServer`] is a **prototype**: it lives entirely inside one Rust
4//! process, owns sharded copies of the entity / relation embedding tables, and
5//! lets workers (typically [`super::worker::Worker`]) **pull** the latest
6//! parameters and **push** locally-computed gradients back.  Two update modes
7//! are supported:
8//!
9//! 1. [`UpdateMode::Sync`] — pushes are **buffered** per-step until every
10//!    expected worker has pushed; only then is the average applied to the
11//!    parameters and a new step starts.  This is the standard mini-batch SGD
12//!    contract and gives reproducible convergence.  The barrier is per-shard:
13//!    a worker that pushed for shard A is free to push for shard B without
14//!    waiting on shard A's barrier to clear.
15//!
16//! 2. [`UpdateMode::Async`] — pushes are applied **immediately** with no
17//!    barrier.  This trades a small amount of staleness (workers may be working
18//!    against a slightly outdated copy of the parameters) for higher
19//!    throughput.  We track a per-shard *staleness counter* (the number of
20//!    pushes between the last `pull` and the next `pull`) so callers can
21//!    bound the expected divergence.
22//!
23//! The server is bounded by design to **4–8 workers / 4–8 shards** for the
24//! prototype.  Larger setups should use a real RPC-based parameter server.
25//!
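//! The data-path methods (`pull`, `push`, `push_relation`, `stats`) are
//! `async` and use `tokio::sync::RwLock`, so different shards can be served
//! concurrently and a sync-mode `pull` takes only a read lock; entity pushes
//! (and the staleness reset performed by an async-mode `pull`) take a write
//! lock on the affected shard only.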
29
30use anyhow::Result;
31use serde::{Deserialize, Serialize};
32use std::sync::Arc;
33use std::time::Duration;
34use tokio::sync::{Mutex, Notify, RwLock};
35use tokio::time::timeout;
36use tracing::{debug, trace, warn};
37
38use super::shard_manager::ModelShardManager;
39
/// How [`ParameterServer::push`] applies gradients.
///
/// The variant is chosen once, in [`ParameterServerConfig::update_mode`], and
/// `push` dispatches on it for every call.  [`UpdateMode::Sync`] is the
/// `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum UpdateMode {
    /// Synchronous: gradients buffered until `expected_workers` push, then averaged.
    #[default]
    Sync,
    /// Asynchronous: gradients applied immediately (eventual consistency).
    Async,
}
49
/// Configuration for [`ParameterServer`].
///
/// `ParameterServer::new` rejects a zero `embedding_dim`, `num_shards` or
/// `expected_workers`; the remaining fields are taken as-is.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParameterServerConfig {
    /// Embedding dimensionality (length of each row).  Every gradient row
    /// passed to `push` / `push_relation` must have exactly this length.
    pub embedding_dim: usize,
    /// Total number of entity rows the server will own (across all shards).
    /// Informational: `ParameterServer::new` sizes the tables from the ID
    /// list it is handed and does not read this field.
    pub num_entities: usize,
    /// Total number of relation rows the server will own.  Informational, for
    /// the same reason as `num_entities`.
    pub num_relations: usize,
    /// Number of shards to split the entity table into.  The server creates
    /// `min(num_shards, shard_manager.num_shards())` shards.
    pub num_shards: usize,
    /// Number of workers expected to push per step in [`UpdateMode::Sync`].
    pub expected_workers: usize,
    /// Sync/async update mode.
    pub update_mode: UpdateMode,
    /// Optimizer learning rate applied during `push`.
    pub learning_rate: f32,
    /// Maximum staleness tolerated in async mode before logging a warning.
    pub max_staleness: u64,
    /// Per-step barrier timeout in [`UpdateMode::Sync`]; if exceeded, the
    /// barrier proceeds with whatever pushes have arrived.
    pub barrier_timeout: Duration,
}
73
/// Prototype-sized defaults: a 4-shard, 4-worker synchronous server over
/// 64 entities and 8 relations with 32-dimensional embeddings.
impl Default for ParameterServerConfig {
    fn default() -> Self {
        Self {
            embedding_dim: 32,
            num_entities: 64,
            num_relations: 8,
            num_shards: 4,
            expected_workers: 4,
            update_mode: UpdateMode::Sync,
            learning_rate: 0.01,
            // Warn once more than 16 async pushes land between two pulls.
            max_staleness: 16,
            // A sync barrier gives up waiting for stragglers after 30 s.
            barrier_timeout: Duration::from_secs(30),
        }
    }
}
89
/// Public-facing snapshot of one shard's contents.
///
/// Returned by [`ParameterServer::pull`].  Workers operate on this owned copy
/// locally, then submit gradients via `push`.  All vectors are clones, so the
/// snapshot holds no lock and does not change when the server is updated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardSnapshot {
    /// Shard index.
    pub shard_id: usize,
    /// One row of `embedding_dim` weights per entity owned by the shard.
    /// The position of a row here is the row index `push` expects.
    pub entities: Vec<Vec<f32>>,
    /// Mapping from row index inside `entities` back to the global entity ID.
    pub entity_ids: Vec<String>,
    /// Relation table — small, replicated to every shard for convenience.
    pub relations: Vec<Vec<f32>>,
    /// Mapping from row index inside `relations` back to the global relation ID.
    pub relation_ids: Vec<String>,
    /// Shard step counter at the moment the snapshot was taken.
    pub step: u64,
}
109
/// Aggregate stats reported by [`ParameterServer::stats`].
///
/// One instance is shared by all shards, so every counter is server-wide.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ParameterServerStats {
    /// Total `pull` operations served.
    pub total_pulls: u64,
    /// Total `push` operations applied (sync averages count as one).
    /// Relation pushes are not counted.
    pub total_pushes: u64,
    /// Number of completed sync barriers.
    pub barriers_completed: u64,
    /// Maximum observed staleness in async mode (pushes since last pull on a shard).
    pub max_staleness_observed: u64,
    /// Sum of squared gradient components divided by the number of rows, i.e.
    /// the mean per-row squared L2 norm, of the most recently applied entity
    /// gradient.  A single value: whichever shard applied an update last
    /// overwrites it.
    pub last_grad_norm: f64,
}
124
125// ── Internal shard state (private) ───────────────────────────────────────────
126
/// Mutable state of one entity shard, always accessed through the shard's
/// `RwLock`.
#[derive(Debug)]
struct ShardState {
    /// Owned entity rows for this shard.
    entities: Vec<Vec<f32>>,
    /// Global entity IDs, aligned with `entities`.
    entity_ids: Vec<String>,
    /// Per-shard step counter; incremented once per applied update.
    step: u64,
    /// Pending sync-mode pushes for the **current** step.
    pending: Vec<PendingGradient>,
    /// Tracks pushes seen in this step (set of `worker_id`).  Kept as a `Vec`
    /// and scanned linearly; cleared together with `pending`.
    pushed_workers: Vec<u32>,
    /// Async-mode staleness: number of pushes since the last pull.
    staleness: u64,
    /// Notified when a barrier completes (sync mode).
    barrier_done: Arc<Notify>,
}
144
/// One worker's buffered sync-mode push, held until the shard's barrier
/// applies it.
#[derive(Debug, Clone)]
struct PendingGradient {
    /// Worker that submitted the push (kept for diagnostics).
    worker_id: u32,
    rows: Vec<(usize, Vec<f32>)>, // (row index inside shard, gradient row)
}
150
151// ── ParameterServer ──────────────────────────────────────────────────────────
152
/// Sharded parameter server with sync/async update modes.
///
/// See module-level docs for semantics.  Entity rows live in per-shard state,
/// the relation table is a single replicated copy, and statistics are kept in
/// one shared struct.
pub struct ParameterServer {
    /// Immutable configuration captured at construction.
    config: ParameterServerConfig,
    /// Per-shard state.  Each shard has its own RwLock so independent shards
    /// can be served concurrently.
    shards: Vec<Arc<RwLock<ShardState>>>,
    /// Replicated relation table.  Tiny enough to keep behind a single lock.
    relations: Arc<RwLock<Vec<Vec<f32>>>>,
    /// Stable list of relation IDs (declared at construction), aligned with
    /// the rows of `relations`.
    relation_ids: Vec<String>,
    /// Stats (single mutex; cold path).
    stats: Arc<Mutex<ParameterServerStats>>,
    /// Shard manager — owned because the server may need to reshard if elastic
    /// scaling is added later.  Today it is read-only.
    shard_manager: ModelShardManager,
}
171
172impl ParameterServer {
173    /// Build a new parameter server.
174    ///
175    /// `entity_ids` and `relation_ids` must list every entity / relation the
176    /// server should own; the server hashes each entity ID into a shard via
177    /// the supplied [`ModelShardManager`].  Embedding rows are initialised to
178    /// small uniform values in `[-0.05, 0.05]` for reproducibility.
179    pub fn new(
180        config: ParameterServerConfig,
181        entity_ids: Vec<String>,
182        relation_ids: Vec<String>,
183        shard_manager: ModelShardManager,
184    ) -> Result<Self> {
185        if config.embedding_dim == 0 {
186            anyhow::bail!("embedding_dim must be > 0");
187        }
188        if config.num_shards == 0 {
189            anyhow::bail!("num_shards must be > 0");
190        }
191        if config.expected_workers == 0 {
192            anyhow::bail!("expected_workers must be > 0");
193        }
194        let num_shards = config.num_shards.min(shard_manager.num_shards());
195
196        // Initial weights — deterministic small values keyed off the entity ID.
197        let mut shards = Vec::with_capacity(num_shards);
198        let mut shard_buckets: Vec<(Vec<Vec<f32>>, Vec<String>)> =
199            (0..num_shards).map(|_| (Vec::new(), Vec::new())).collect();
200
201        for id in entity_ids.into_iter() {
202            let s = shard_manager.shard_for(&id);
203            let row = init_row(&id, config.embedding_dim);
204            shard_buckets[s].0.push(row);
205            shard_buckets[s].1.push(id);
206        }
207
208        for (entities, ids) in shard_buckets.into_iter() {
209            shards.push(Arc::new(RwLock::new(ShardState {
210                entities,
211                entity_ids: ids,
212                step: 0,
213                pending: Vec::new(),
214                pushed_workers: Vec::new(),
215                staleness: 0,
216                barrier_done: Arc::new(Notify::new()),
217            })));
218        }
219
220        // Build relation table — tiny, fully replicated.
221        let mut relations = Vec::with_capacity(relation_ids.len());
222        for id in &relation_ids {
223            relations.push(init_row(id, config.embedding_dim));
224        }
225
226        Ok(Self {
227            config,
228            shards,
229            relations: Arc::new(RwLock::new(relations)),
230            relation_ids,
231            stats: Arc::new(Mutex::new(ParameterServerStats::default())),
232            shard_manager,
233        })
234    }
235
236    /// Number of shards.
237    pub fn num_shards(&self) -> usize {
238        self.shards.len()
239    }
240
241    /// Configuration snapshot.
242    pub fn config(&self) -> &ParameterServerConfig {
243        &self.config
244    }
245
246    /// Shard manager view (read-only).
247    pub fn shard_manager(&self) -> &ModelShardManager {
248        &self.shard_manager
249    }
250
251    /// Pull a snapshot of `shard_id`.
252    ///
253    /// Returns owned data so workers can compute gradients without holding any
254    /// lock.  Panics-free: an out-of-range `shard_id` returns an `Err`.
255    pub async fn pull(&self, shard_id: usize) -> Result<ShardSnapshot> {
256        let shard = self
257            .shards
258            .get(shard_id)
259            .ok_or_else(|| anyhow::anyhow!("shard {shard_id} out of range"))?;
260        let g = shard.read().await;
261        let snap = ShardSnapshot {
262            shard_id,
263            entities: g.entities.clone(),
264            entity_ids: g.entity_ids.clone(),
265            relations: self.relations.read().await.clone(),
266            relation_ids: self.relation_ids.clone(),
267            step: g.step,
268        };
269        drop(g);
270
271        // Reset async staleness window on pull.
272        if matches!(self.config.update_mode, UpdateMode::Async) {
273            let mut w = shard.write().await;
274            w.staleness = 0;
275        }
276
277        let mut stats = self.stats.lock().await;
278        stats.total_pulls += 1;
279        Ok(snap)
280    }
281
282    /// Push entity-row gradients to `shard_id`.
283    ///
284    /// `rows` is `(row_index_inside_shard, gradient_row)` pairs.  In
285    /// [`UpdateMode::Sync`] the push is buffered until `expected_workers`
286    /// pushes have accumulated for this shard *or* the per-step barrier
287    /// timeout fires.  In [`UpdateMode::Async`] the gradient is applied
288    /// immediately and the per-shard staleness counter is incremented.
289    pub async fn push(
290        &self,
291        shard_id: usize,
292        worker_id: u32,
293        rows: Vec<(usize, Vec<f32>)>,
294    ) -> Result<()> {
295        let shard = self
296            .shards
297            .get(shard_id)
298            .ok_or_else(|| anyhow::anyhow!("shard {shard_id} out of range"))?
299            .clone();
300
301        // Validate gradient shapes up-front, before we touch any state.
302        for (idx, grad) in &rows {
303            if grad.len() != self.config.embedding_dim {
304                anyhow::bail!(
305                    "gradient row {idx} has dim {} but server expects {}",
306                    grad.len(),
307                    self.config.embedding_dim
308                );
309            }
310        }
311
312        match self.config.update_mode {
313            UpdateMode::Sync => self.push_sync(shard, shard_id, worker_id, rows).await,
314            UpdateMode::Async => self.push_async(shard, worker_id, rows).await,
315        }
316    }
317
318    /// Apply gradients to relation rows.
319    ///
320    /// Relation gradients are always averaged across the most-recent push
321    /// regardless of `update_mode` — they're a tiny fully-replicated table.
322    pub async fn push_relation(&self, worker_id: u32, rows: Vec<(usize, Vec<f32>)>) -> Result<()> {
323        for (idx, grad) in &rows {
324            if grad.len() != self.config.embedding_dim {
325                anyhow::bail!(
326                    "relation gradient row {idx} has dim {} but server expects {}",
327                    grad.len(),
328                    self.config.embedding_dim
329                );
330            }
331        }
332
333        let mut rel = self.relations.write().await;
334        for (idx, grad) in rows {
335            if let Some(target) = rel.get_mut(idx) {
336                for (t, g) in target.iter_mut().zip(grad.iter()) {
337                    *t -= self.config.learning_rate * *g;
338                }
339            }
340        }
341        trace!("worker {worker_id}: relation gradients applied");
342        Ok(())
343    }
344
345    /// Snapshot of current stats.
346    pub async fn stats(&self) -> ParameterServerStats {
347        self.stats.lock().await.clone()
348    }
349
350    /// Total per-shard step counts (for tests).
351    pub async fn shard_steps(&self) -> Vec<u64> {
352        let mut steps = Vec::with_capacity(self.shards.len());
353        for s in &self.shards {
354            steps.push(s.read().await.step);
355        }
356        steps
357    }
358
359    // ── Internals ───────────────────────────────────────────────────────────
360
361    async fn push_sync(
362        &self,
363        shard: Arc<RwLock<ShardState>>,
364        shard_id: usize,
365        worker_id: u32,
366        rows: Vec<(usize, Vec<f32>)>,
367    ) -> Result<()> {
368        // Append the push and check the barrier.
369        let (apply_now, barrier) = {
370            let mut g = shard.write().await;
371            if g.pushed_workers.contains(&worker_id) {
372                anyhow::bail!("worker {worker_id} already pushed for shard {shard_id} this step");
373            }
374            g.pending.push(PendingGradient { worker_id, rows });
375            g.pushed_workers.push(worker_id);
376            let ready = g.pushed_workers.len() >= self.config.expected_workers;
377            (ready, g.barrier_done.clone())
378        };
379
380        if apply_now {
381            self.apply_sync_barrier(shard.clone(), shard_id).await?;
382            barrier.notify_waiters();
383            return Ok(());
384        }
385
386        // Wait for someone else (or our own future call) to fill the barrier,
387        // up to `barrier_timeout`.
388        let waited = timeout(self.config.barrier_timeout, barrier.notified()).await;
389        if waited.is_err() {
390            warn!(
391                "shard {shard_id} barrier timed out after {:?}; flushing partial step",
392                self.config.barrier_timeout
393            );
394            self.apply_sync_barrier(shard, shard_id).await?;
395        }
396        Ok(())
397    }
398
399    async fn apply_sync_barrier(
400        &self,
401        shard: Arc<RwLock<ShardState>>,
402        shard_id: usize,
403    ) -> Result<()> {
404        let mut g = shard.write().await;
405        let lr = self.config.learning_rate;
406        let dim = self.config.embedding_dim;
407        let n = g.pending.len().max(1) as f32;
408
409        // Average gradients per row.
410        let mut acc: std::collections::HashMap<usize, Vec<f32>> = std::collections::HashMap::new();
411        for pending in &g.pending {
412            for (idx, grad) in &pending.rows {
413                let entry = acc.entry(*idx).or_insert_with(|| vec![0.0; dim]);
414                for (t, gval) in entry.iter_mut().zip(grad.iter()) {
415                    *t += *gval / n;
416                }
417            }
418        }
419
420        // Apply averaged gradients.
421        let mut sq_sum = 0.0_f64;
422        for (idx, grad) in &acc {
423            if let Some(target) = g.entities.get_mut(*idx) {
424                for (t, gval) in target.iter_mut().zip(grad.iter()) {
425                    *t -= lr * *gval;
426                    sq_sum += (*gval as f64) * (*gval as f64);
427                }
428            }
429        }
430
431        g.pending.clear();
432        g.pushed_workers.clear();
433        g.step += 1;
434        let new_step = g.step;
435        drop(g);
436
437        let mut stats = self.stats.lock().await;
438        stats.total_pushes += 1;
439        stats.barriers_completed += 1;
440        if !acc.is_empty() {
441            stats.last_grad_norm = sq_sum / acc.len() as f64;
442        }
443        debug!("shard {shard_id} barrier applied (new step = {new_step})");
444        Ok(())
445    }
446
447    async fn push_async(
448        &self,
449        shard: Arc<RwLock<ShardState>>,
450        worker_id: u32,
451        rows: Vec<(usize, Vec<f32>)>,
452    ) -> Result<()> {
453        let lr = self.config.learning_rate;
454        let mut g = shard.write().await;
455        let mut sq_sum = 0.0_f64;
456        let mut applied = 0usize;
457        for (idx, grad) in &rows {
458            if let Some(target) = g.entities.get_mut(*idx) {
459                for (t, gval) in target.iter_mut().zip(grad.iter()) {
460                    *t -= lr * *gval;
461                    sq_sum += (*gval as f64) * (*gval as f64);
462                }
463                applied += 1;
464            }
465        }
466        g.staleness = g.staleness.saturating_add(1);
467        g.step += 1;
468        let new_staleness = g.staleness;
469        drop(g);
470
471        let mut stats = self.stats.lock().await;
472        stats.total_pushes += 1;
473        stats.max_staleness_observed = stats.max_staleness_observed.max(new_staleness);
474        if applied > 0 {
475            stats.last_grad_norm = sq_sum / applied as f64;
476        }
477
478        if new_staleness > self.config.max_staleness {
479            warn!(
480                "worker {worker_id} async push: staleness {new_staleness} exceeds max {}",
481                self.config.max_staleness
482            );
483        }
484        Ok(())
485    }
486}
487
/// Deterministically generate one embedding row of `dim` weights for
/// `seed_id`.
///
/// The ID is hashed with 64-bit FNV-1a and the hash seeds a 64-bit linear
/// congruential generator (multiplier / increment are the constants from
/// Knuth's MMIX generator).  The upper 32 bits of each state are scaled into
/// `[-0.05, 0.05]`.  Pure Rust, no `rand` dependency, identical output on
/// every run for the same `(seed_id, dim)`.
fn init_row(seed_id: &str, dim: usize) -> Vec<f32> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x100_0000_01b3;
    const LCG_MULTIPLIER: u64 = 6364136223846793005;
    const LCG_INCREMENT: u64 = 1442695040888963407;

    // FNV-1a: xor the byte in first, then multiply by the prime.
    let hash = seed_id.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    });

    // Force the seed odd, as the original generator does.
    let mut state = hash | 1;
    (0..dim)
        .map(|_| {
            state = state
                .wrapping_mul(LCG_MULTIPLIER)
                .wrapping_add(LCG_INCREMENT);
            // Only the high half of the state is used for the sample.
            let raw = (state >> 32) as u32;
            // Unit interval first, then stretch to width 0.1 centred on zero.
            (raw as f32 / u32::MAX as f32) * 0.1 - 0.05
        })
        .collect()
}
510
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed_training::shard_manager::{ModelShardManager, ShardingStrategy};

    /// Small configuration shared by all tests: 8 entities, 2 relations,
    /// 2 shards, 4-dimensional rows, `lr = 0.1`, 500 ms barrier timeout.
    fn small_cfg(mode: UpdateMode, workers: usize) -> ParameterServerConfig {
        ParameterServerConfig {
            embedding_dim: 4,
            num_entities: 8,
            num_relations: 2,
            num_shards: 2,
            expected_workers: workers,
            update_mode: mode,
            learning_rate: 0.1,
            max_staleness: 8,
            barrier_timeout: Duration::from_millis(500),
        }
    }

    /// Build a server over entities `e0..e7` and relations `r0..r1`.
    fn small_server(mode: UpdateMode, workers: usize) -> ParameterServer {
        let cfg = small_cfg(mode, workers);
        let entity_ids: Vec<String> = (0..cfg.num_entities).map(|i| format!("e{i}")).collect();
        let relation_ids: Vec<String> = (0..cfg.num_relations).map(|i| format!("r{i}")).collect();
        let mgr = ModelShardManager::new(cfg.num_shards, ShardingStrategy::EntityHash);
        ParameterServer::new(cfg, entity_ids, relation_ids, mgr)
            .expect("server construction failed")
    }

    /// Index of the first shard that owns at least one entity.
    ///
    /// The hash may leave a particular shard empty, so tests that need a row
    /// to push to must not hard-code shard 0 (and must not silently pass when
    /// it happens to be empty).  Reads the shard state directly so that no
    /// pull is counted and no staleness counter is reset.
    async fn populated_shard(server: &ParameterServer) -> usize {
        for (idx, shard) in server.shards.iter().enumerate() {
            if !shard.read().await.entities.is_empty() {
                return idx;
            }
        }
        panic!("no shard owns any entity");
    }

    #[tokio::test]
    async fn server_constructs_and_reports_shards() {
        let s = small_server(UpdateMode::Sync, 2);
        assert_eq!(s.num_shards(), 2);
    }

    #[tokio::test]
    async fn server_rejects_zero_dim() {
        let mut cfg = small_cfg(UpdateMode::Sync, 2);
        cfg.embedding_dim = 0;
        let mgr = ModelShardManager::new(cfg.num_shards, ShardingStrategy::EntityHash);
        let res = ParameterServer::new(cfg, vec!["a".into()], vec!["r".into()], mgr);
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn pull_returns_consistent_dim_rows() {
        let s = small_server(UpdateMode::Sync, 2);
        for shard in 0..s.num_shards() {
            let snap = s.pull(shard).await.expect("pull");
            assert_eq!(snap.shard_id, shard);
            assert_eq!(snap.relations.len(), 2);
            assert_eq!(snap.entities.len(), snap.entity_ids.len());
            for row in &snap.entities {
                assert_eq!(row.len(), 4);
            }
        }
    }

    #[tokio::test]
    async fn out_of_range_shard_is_an_error() {
        let s = small_server(UpdateMode::Async, 1);
        let missing = s.num_shards();
        assert!(s.pull(missing).await.is_err());
        assert!(s.push(missing, 0, vec![(0, vec![1.0; 4])]).await.is_err());
    }

    #[tokio::test]
    async fn push_async_applies_immediately() {
        let s = small_server(UpdateMode::Async, 1);
        let shard = populated_shard(&s).await;
        let before = s.pull(shard).await.expect("pull").entities[0].clone();

        s.push(shard, 0, vec![(0, vec![1.0; 4])])
            .await
            .expect("push async");

        let after = s.pull(shard).await.expect("pull2").entities[0].clone();
        assert_eq!(before.len(), after.len());
        // After applying grad=1.0 with lr=0.1, weights drop by 0.1.
        for (b, a) in before.iter().zip(after.iter()) {
            assert!(
                (b - a - 0.1).abs() < 1e-5,
                "expected b - a ≈ 0.1, got b={b}, a={a}"
            );
        }
    }

    #[tokio::test]
    async fn push_sync_buffers_until_barrier() {
        let s = Arc::new(small_server(UpdateMode::Sync, 2));
        let shard = populated_shard(&s).await;
        let before = s.pull(shard).await.expect("pull").entities[0].clone();

        let grad: Vec<f32> = vec![2.0; 4];

        // Push from worker 0; should *not* increment step yet.
        let s0 = Arc::clone(&s);
        let g0 = grad.clone();
        let h0 = tokio::spawn(async move {
            s0.push(shard, 0, vec![(0, g0)])
                .await
                .expect("worker 0 push");
        });

        // Push from worker 1; this completes the barrier.
        let s1 = Arc::clone(&s);
        let g1 = grad.clone();
        let h1 = tokio::spawn(async move {
            s1.push(shard, 1, vec![(0, g1)])
                .await
                .expect("worker 1 push");
        });

        h0.await.expect("worker 0 join");
        h1.await.expect("worker 1 join");

        let stats = s.stats().await;
        assert_eq!(
            stats.barriers_completed, 1,
            "exactly one barrier should have fired"
        );
        let steps = s.shard_steps().await;
        assert_eq!(steps[shard], 1, "shard should have advanced one step");

        // Two workers each pushed grad=2.0: the average is 2.0, so with
        // lr=0.1 every weight of row 0 drops by 0.2.
        let after = s.pull(shard).await.expect("pull2").entities[0].clone();
        for (b, a) in before.iter().zip(after.iter()) {
            assert!(
                (b - a - 0.2).abs() < 1e-5,
                "expected b - a ≈ 0.2, got b={b}, a={a}"
            );
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn push_sync_rejects_double_push_from_same_worker() {
        // Force a 2-worker barrier and use the same worker_id twice.  The
        // first push will block on the barrier; the second concurrent push
        // from the same worker must be rejected.
        let s = Arc::new(small_server(UpdateMode::Sync, 2));
        let shard = populated_shard(&s).await;

        let g = vec![0.0_f32; 4];
        let s_first = Arc::clone(&s);
        let g_first = g.clone();
        let h = tokio::spawn(async move {
            // This will block on the barrier (only one worker pushed).
            // The barrier_timeout in `small_cfg` is 500ms — it will eventually
            // unblock and return Ok, but until then we have time to fire the
            // second push from the same worker.
            s_first.push(shard, 7, vec![(0, g_first)]).await
        });

        // Yield so the spawned push registers its worker_id before we re-push.
        tokio::task::yield_now().await;
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let err = s.push(shard, 7, vec![(0, g)]).await;
        assert!(err.is_err(), "second push by same worker must fail");

        // The original push completes through the timeout flush.
        let first = h.await.expect("join push task");
        assert!(first.is_ok(), "first push must succeed after the timeout");
    }

    #[tokio::test]
    async fn push_validates_gradient_dim() {
        let s = small_server(UpdateMode::Async, 1);
        // Wrong-length gradient.
        let res = s.push(0, 0, vec![(0, vec![1.0; 3])]).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn relation_push_applies_with_learning_rate() {
        let s = small_server(UpdateMode::Sync, 1);
        let before = s.pull(0).await.expect("pull").relations[0].clone();
        s.push_relation(0, vec![(0, vec![1.0_f32; 4])])
            .await
            .expect("rel push");
        let after = s.pull(0).await.expect("pull2").relations[0].clone();
        for (b, a) in before.iter().zip(after.iter()) {
            assert!((b - a - 0.1).abs() < 1e-5);
        }
    }

    #[tokio::test]
    async fn async_pull_resets_staleness() {
        let s = small_server(UpdateMode::Async, 1);
        let shard = populated_shard(&s).await;

        for _ in 0..3 {
            s.push(shard, 0, vec![(0, vec![0.1_f32; 4])])
                .await
                .expect("push");
        }
        let stats_before = s.stats().await;
        assert_eq!(stats_before.max_staleness_observed, 3);
        assert_eq!(s.shards[shard].read().await.staleness, 3);

        // Pulling resets the per-shard counter; the observed maximum stays.
        let _ = s.pull(shard).await.expect("pull");
        assert_eq!(
            s.shards[shard].read().await.staleness,
            0,
            "pull must reset the per-shard staleness counter"
        );
        let stats_after = s.stats().await;
        assert_eq!(
            stats_after.max_staleness_observed, stats_before.max_staleness_observed,
            "max_staleness_observed is monotonic"
        );
    }
}