Skip to main content

oxirs_embed/distributed_training/
worker.rs

1//! Distributed training worker that pulls latest params, computes local
2//! gradients on a TransE-shaped objective, and pushes them back to a
3//! [`super::parameter_server::ParameterServer`].
4//!
5//! The worker implements the classic parameter-server inner loop:
6//!
7//! ```text
8//! for step in 0..max_steps {
9//!     for shard in shards_this_worker_owns {
10//!         snap = ps.pull(shard)
11//!         (loss, grads) = local_step(snap, mini_batch)
12//!         ps.push(shard, grads)
13//!     }
14//! }
15//! ```
16//!
17//! The loss is a margin-ranking loss on TransE: `score(h,r,t) = ||h + r - t||`.
18//! Concretely, for each positive triple `(h, r, t)` we sample a negative tail
19//! `t'` from the same shard's entities and minimise
20//! `max(0, margin + score(h,r,t) - score(h,r,t'))`.  This is a small but
21//! genuinely non-trivial signal — sufficient to drive convergence on a toy
22//! graph and to validate that the parameter-server plumbing actually moves
23//! parameters in the right direction.
24//!
25//! Workers are intentionally **stateless** between iterations: every step pulls
26//! fresh parameters from the server.  This is wasteful but matches the
27//! prototype contract (and makes tests trivially reproducible).
28
29use anyhow::Result;
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33use std::time::Instant;
34use tracing::{debug, trace};
35
36use super::parameter_server::{ParameterServer, ShardSnapshot, UpdateMode};
37
/// A single gradient row as pushed to the server:
/// `(row_index_in_shard, gradient_row)`.
type GradRow = (usize, Vec<f32>);

/// Result tuple of [`Worker::local_step`]:
/// `(mean_loss, entity_gradient_rows, relation_gradient_rows, sample_count)`.
type LocalStepOutput = (f64, Vec<GradRow>, Vec<GradRow>, usize);
44
45/// One TransE-shaped sample: `(head, relation, tail)`.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct TripleSample {
48    /// Head entity ID.
49    pub head: String,
50    /// Relation IRI.
51    pub relation: String,
52    /// Tail entity ID.
53    pub tail: String,
54}
55
56impl TripleSample {
57    /// Convenience constructor.
58    pub fn new(
59        head: impl Into<String>,
60        relation: impl Into<String>,
61        tail: impl Into<String>,
62    ) -> Self {
63        Self {
64            head: head.into(),
65            relation: relation.into(),
66            tail: tail.into(),
67        }
68    }
69}
70
71/// Per-iteration loss reported by [`Worker::run`].
72#[derive(Debug, Clone, Default, Serialize, Deserialize)]
73pub struct WorkerLoss {
74    /// Worker rank.
75    pub worker_id: u32,
76    /// Sequence of mean-batch losses, one entry per pulled shard.
77    pub history: Vec<f64>,
78    /// Sum of all losses across the run.
79    pub total_loss: f64,
80    /// Number of (h,r,t) triples that contributed to the loss.
81    pub samples: usize,
82}
83
84impl WorkerLoss {
85    /// Mean of the recorded losses.
86    pub fn mean(&self) -> f64 {
87        if self.history.is_empty() {
88            0.0
89        } else {
90            self.total_loss / self.history.len() as f64
91        }
92    }
93}
94
95/// Worker configuration.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct WorkerConfig {
98    /// Worker rank (must be unique within a parameter-server cohort).
99    pub worker_id: u32,
100    /// Maximum number of training iterations the worker will perform.
101    pub max_steps: usize,
102    /// TransE margin (γ).
103    pub margin: f32,
104    /// L2 regularisation coefficient.  `0.0` disables.
105    pub l2_reg: f32,
106    /// Random seed for deterministic negative sampling.
107    pub seed: u64,
108}
109
110impl Default for WorkerConfig {
111    fn default() -> Self {
112        Self {
113            worker_id: 0,
114            max_steps: 50,
115            margin: 1.0,
116            l2_reg: 0.0,
117            seed: 1,
118        }
119    }
120}
121
122/// In-process distributed-training worker.
123pub struct Worker {
124    config: WorkerConfig,
125    server: Arc<ParameterServer>,
126    /// Triples this worker is responsible for.  Each triple is routed to the
127    /// shard owning its head entity at training time.
128    samples: Vec<TripleSample>,
129    /// LCG state for deterministic negative sampling.
130    rng_state: u64,
131}
132
133impl Worker {
134    /// Build a new worker.
135    pub fn new(
136        config: WorkerConfig,
137        server: Arc<ParameterServer>,
138        samples: Vec<TripleSample>,
139    ) -> Self {
140        let rng_state = config.seed | 1;
141        Self {
142            config,
143            server,
144            samples,
145            rng_state,
146        }
147    }
148
149    /// Worker configuration view.
150    pub fn config(&self) -> &WorkerConfig {
151        &self.config
152    }
153
154    /// Run the worker's training loop.
155    ///
156    /// Returns the per-step loss history.  This is `async` and re-entrant:
157    /// callers typically `tokio::spawn` one task per worker and `join_all`
158    /// them.
159    pub async fn run(mut self) -> Result<WorkerLoss> {
160        let mut loss = WorkerLoss {
161            worker_id: self.config.worker_id,
162            ..Default::default()
163        };
164        let started = Instant::now();
165
166        // Group samples by shard ownership of the head entity, refreshed each step.
167        for step in 0..self.config.max_steps {
168            // Build per-shard groups using the current shard manager mapping.
169            // We materialise indices (not references) so the immutable borrow
170            // of `self.samples` is dropped before the per-shard loop, which
171            // re-borrows `self` mutably to mutate the RNG.
172            let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
173            for (i, s) in self.samples.iter().enumerate() {
174                let shard = self.server.shard_manager().shard_for(&s.head);
175                groups.entry(shard).or_default().push(i);
176            }
177
178            for (shard_id, indices) in groups {
179                let snap = self.server.pull(shard_id).await?;
180                // Clone the small per-shard sample slice into an owned vec so
181                // we can mutably borrow `self` (for RNG state) inside
182                // `local_step` without conflicting with the immutable borrow
183                // of `self.samples`.
184                let shard_samples: Vec<TripleSample> =
185                    indices.iter().map(|&i| self.samples[i].clone()).collect();
186                let sample_refs: Vec<&TripleSample> = shard_samples.iter().collect();
187                let (mean_loss, grads, rel_grads, n) = self.local_step(&snap, &sample_refs)?;
188                if !rel_grads.is_empty() {
189                    self.server
190                        .push_relation(self.config.worker_id, rel_grads)
191                        .await?;
192                }
193                if !grads.is_empty() {
194                    self.server
195                        .push(shard_id, self.config.worker_id, grads)
196                        .await?;
197                }
198                loss.history.push(mean_loss);
199                loss.total_loss += mean_loss;
200                loss.samples += n;
201                trace!(
202                    worker = self.config.worker_id,
203                    step,
204                    shard = shard_id,
205                    samples = n,
206                    loss = mean_loss,
207                    "worker step done"
208                );
209            }
210        }
211
212        debug!(
213            worker = self.config.worker_id,
214            elapsed_ms = started.elapsed().as_millis() as u64,
215            mean_loss = loss.mean(),
216            "worker finished"
217        );
218        Ok(loss)
219    }
220
221    /// Compute the local gradient batch for one shard's worth of samples.
222    ///
223    /// Returns `(mean_loss, entity_gradient_rows, relation_gradient_rows,
224    /// sample_count)`.  Gradient rows are `(row_index, gradient_vector)`
225    /// pairs ready to be pushed back to the server.  The entity rows are
226    /// indexed inside the shard; the relation rows use the global relation
227    /// table index because relations are fully replicated.
228    ///
229    /// Implementation note: we use a *closed-form* derivative of the margin
230    /// loss with respect to head/tail rows.  Both entity- and relation-row
231    /// gradients are returned to the caller, which is responsible for
232    /// applying them via [`ParameterServer::push`] /
233    /// [`ParameterServer::push_relation`] in the desired order.
234    fn local_step(
235        &mut self,
236        snap: &ShardSnapshot,
237        samples: &[&TripleSample],
238    ) -> Result<LocalStepOutput> {
239        if snap.entities.is_empty() || samples.is_empty() {
240            return Ok((0.0, Vec::new(), Vec::new(), 0));
241        }
242
243        let dim = snap.entities[0].len();
244        let entity_index: HashMap<&str, usize> = snap
245            .entity_ids
246            .iter()
247            .enumerate()
248            .map(|(i, s)| (s.as_str(), i))
249            .collect();
250        let relation_index: HashMap<&str, usize> = snap
251            .relation_ids
252            .iter()
253            .enumerate()
254            .map(|(i, s)| (s.as_str(), i))
255            .collect();
256
257        // Accumulate gradients per row in the shard.
258        let mut grad_acc: HashMap<usize, Vec<f32>> = HashMap::new();
259        // Accumulate relation gradients (replicated table → idx is global).
260        let mut rel_grad: HashMap<usize, Vec<f32>> = HashMap::new();
261        let mut total_loss = 0.0_f64;
262        let mut counted = 0usize;
263
264        for s in samples {
265            // We can only train on triples whose **head** lives on this shard.
266            let h_idx = match entity_index.get(s.head.as_str()) {
267                Some(&i) => i,
268                None => continue,
269            };
270            let r_idx = match relation_index.get(s.relation.as_str()) {
271                Some(&i) => i,
272                None => continue,
273            };
274
275            // Tail may live elsewhere; we still get a useful gradient on the
276            // head row by treating the tail as constant.  If the tail does
277            // happen to live on this shard we update both.
278            let t_idx_local = entity_index.get(s.tail.as_str()).copied();
279            let head = &snap.entities[h_idx];
280            let rel = &snap.relations[r_idx];
281
282            // For tail, prefer the shard's own copy if available, otherwise
283            // we fabricate a vector by looking up the relation row → that's
284            // a pragmatic toy choice; the prototype is intentionally not a
285            // full distributed embedding lookup.
286            let tail_vec: Vec<f32> = match t_idx_local {
287                Some(i) => snap.entities[i].clone(),
288                None => snap.relations[r_idx].clone(),
289            };
290
291            // Sample a negative tail t' from this shard's entities.
292            let neg_idx = self.next_index(snap.entities.len());
293            let neg = &snap.entities[neg_idx];
294
295            // Score(h, r, t) = ||h + r - t||₂  (using f32 throughout).
296            let pos_diff: Vec<f32> = head
297                .iter()
298                .zip(rel.iter())
299                .zip(tail_vec.iter())
300                .map(|((h, r), t)| h + r - t)
301                .collect();
302            let neg_diff: Vec<f32> = head
303                .iter()
304                .zip(rel.iter())
305                .zip(neg.iter())
306                .map(|((h, r), n)| h + r - n)
307                .collect();
308
309            let pos_score = l2_norm(&pos_diff);
310            let neg_score = l2_norm(&neg_diff);
311            let margin = self.config.margin;
312            let raw_loss = (margin + pos_score - neg_score).max(0.0);
313            total_loss += raw_loss as f64;
314            counted += 1;
315
316            // Subgradient when raw_loss > 0:
317            //   ∂L/∂h = (pos_diff/||pos_diff||) - (neg_diff/||neg_diff||)
318            //   ∂L/∂r = same
319            //   ∂L/∂t = -pos_diff/||pos_diff||
320            //   ∂L/∂t' = neg_diff/||neg_diff||
321            if raw_loss > 0.0 {
322                let pos_norm = pos_score.max(1e-6);
323                let neg_norm = neg_score.max(1e-6);
324
325                let grad_h: Vec<f32> = pos_diff
326                    .iter()
327                    .zip(neg_diff.iter())
328                    .map(|(p, n)| p / pos_norm - n / neg_norm)
329                    .collect();
330                let grad_r = grad_h.clone();
331                let grad_t: Vec<f32> = pos_diff.iter().map(|p| -p / pos_norm).collect();
332                let grad_neg: Vec<f32> = neg_diff.iter().map(|n| n / neg_norm).collect();
333
334                accumulate_grad(&mut grad_acc, h_idx, &grad_h, dim);
335                if let Some(ti) = t_idx_local {
336                    accumulate_grad(&mut grad_acc, ti, &grad_t, dim);
337                }
338                accumulate_grad(&mut grad_acc, neg_idx, &grad_neg, dim);
339                accumulate_grad(&mut rel_grad, r_idx, &grad_r, dim);
340            }
341
342            // Optional L2 regularisation on the head row.
343            if self.config.l2_reg > 0.0 {
344                let entry = grad_acc.entry(h_idx).or_insert_with(|| vec![0.0; dim]);
345                for (e, h) in entry.iter_mut().zip(head.iter()) {
346                    *e += self.config.l2_reg * *h;
347                }
348            }
349        }
350
351        let mean_loss = if counted == 0 {
352            0.0
353        } else {
354            total_loss / counted as f64
355        };
356        let grads: Vec<(usize, Vec<f32>)> = grad_acc.into_iter().collect();
357        let rel_grads: Vec<(usize, Vec<f32>)> = rel_grad.into_iter().collect();
358        Ok((mean_loss, grads, rel_grads, counted))
359    }
360
361    fn next_index(&mut self, n: usize) -> usize {
362        // LCG step (Numerical Recipes).
363        self.rng_state = self
364            .rng_state
365            .wrapping_mul(6364136223846793005)
366            .wrapping_add(1442695040888963407);
367        ((self.rng_state >> 32) as usize) % n.max(1)
368    }
369}
370
371/// Run multiple workers concurrently and gather their losses.
372///
373/// `workers` is consumed; each worker is spawned on its own tokio task.
374pub async fn run_workers(workers: Vec<Worker>) -> Result<Vec<WorkerLoss>> {
375    let mut handles = Vec::with_capacity(workers.len());
376    for w in workers {
377        handles.push(tokio::spawn(async move { w.run().await }));
378    }
379    let mut out = Vec::with_capacity(handles.len());
380    for h in handles {
381        match h.await {
382            Ok(Ok(loss)) => out.push(loss),
383            Ok(Err(e)) => return Err(e),
384            Err(join_err) => return Err(anyhow::anyhow!("worker join failed: {join_err}")),
385        }
386    }
387    Ok(out)
388}
389
390/// Pretty-print server state for debugging (used by examples).
391pub async fn describe_server(server: &ParameterServer) -> String {
392    let stats = server.stats().await;
393    let steps = server.shard_steps().await;
394    let mode = match server.config().update_mode {
395        UpdateMode::Sync => "sync",
396        UpdateMode::Async => "async",
397    };
398    format!(
399        "ParameterServer[mode={mode}, shards={}, total_pulls={}, total_pushes={}, barriers={}, steps={steps:?}]",
400        server.num_shards(),
401        stats.total_pulls,
402        stats.total_pushes,
403        stats.barriers_completed,
404    )
405}
406
407// ── helpers ─────────────────────────────────────────────────────────────────
408
/// Euclidean (L2) norm of `v`; zero for an empty slice.
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    sum_of_squares.sqrt()
}
412
/// Add `grad` element-wise into the row stored under `idx`.
///
/// If no row exists yet, a zero row of length `dim` is created first.  Only
/// the overlapping prefix of the row and `grad` is updated.
fn accumulate_grad(target: &mut HashMap<usize, Vec<f32>>, idx: usize, grad: &[f32], dim: usize) {
    let row = target.entry(idx).or_insert_with(|| vec![0.0; dim]);
    row.iter_mut().zip(grad).for_each(|(acc, &g)| *acc += g);
}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed_training::parameter_server::{
        ParameterServer, ParameterServerConfig, UpdateMode,
    };
    use crate::distributed_training::shard_manager::{ModelShardManager, ShardingStrategy};

    /// Build a server for 8 entities (`e0..e7`) and 2 relations (`r0`, `r1`)
    /// split over 2 entity-hash shards.
    fn build_server(workers: usize, mode: UpdateMode) -> Arc<ParameterServer> {
        let config = ParameterServerConfig {
            embedding_dim: 8,
            num_entities: 8,
            num_relations: 2,
            num_shards: 2,
            expected_workers: workers,
            update_mode: mode,
            learning_rate: 0.05,
            max_staleness: 16,
            barrier_timeout: std::time::Duration::from_millis(500),
        };
        let entities = (0..config.num_entities)
            .map(|i| format!("e{i}"))
            .collect::<Vec<String>>();
        let relations = (0..config.num_relations)
            .map(|i| format!("r{i}"))
            .collect::<Vec<String>>();
        let manager = ModelShardManager::new(config.num_shards, ShardingStrategy::EntityHash);
        let server = ParameterServer::new(config, entities, relations, manager)
            .expect("server construction failed");
        Arc::new(server)
    }

    /// Four disjoint triples covering every entity and both relations.
    fn small_triples() -> Vec<TripleSample> {
        [
            ("e0", "r0", "e1"),
            ("e2", "r0", "e3"),
            ("e4", "r1", "e5"),
            ("e6", "r1", "e7"),
        ]
        .iter()
        .map(|&(h, r, t)| TripleSample::new(h, r, t))
        .collect()
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn worker_runs_async_and_records_loss() {
        let server = build_server(1, UpdateMode::Async);
        let config = WorkerConfig {
            worker_id: 0,
            max_steps: 5,
            seed: 7,
            ..WorkerConfig::default()
        };
        let outcome = Worker::new(config, Arc::clone(&server), small_triples())
            .run()
            .await
            .expect("worker run failed");
        assert_eq!(outcome.worker_id, 0);
        assert!(
            !outcome.history.is_empty(),
            "worker should record at least one loss entry"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn four_workers_async_complete() {
        let server = build_server(1, UpdateMode::Async);
        let workers: Vec<Worker> = (0..4u32)
            .map(|rank| {
                let config = WorkerConfig {
                    worker_id: rank,
                    max_steps: 3,
                    l2_reg: 1e-4,
                    seed: 1 + u64::from(rank),
                    ..WorkerConfig::default()
                };
                Worker::new(config, Arc::clone(&server), small_triples())
            })
            .collect();
        let losses = run_workers(workers).await.expect("workers failed");
        assert_eq!(losses.len(), 4);
        assert!(losses
            .iter()
            .flat_map(|l| l.history.iter())
            .all(|x| x.is_finite()));
    }

    #[tokio::test]
    async fn describe_server_renders() {
        let server = build_server(1, UpdateMode::Async);
        let rendered = describe_server(&server).await;
        assert!(rendered.contains("mode=async"));
    }
}