1use anyhow::Result;
30use serde::{Deserialize, Serialize};
31use std::collections::HashMap;
32use std::sync::Arc;
33use std::time::Instant;
34use tracing::{debug, trace};
35
36use super::parameter_server::{ParameterServer, ShardSnapshot, UpdateMode};
37
/// One gradient row: `(row index, gradient vector)`.
///
/// As produced by `Worker::local_step`, the index is the position of the row
/// inside the pulled snapshot's `entities` (or `relations`) list.
type GradRow = (usize, Vec<f32>);

/// Output of one local training step, in order: mean loss over the counted
/// samples, entity gradient rows, relation gradient rows, and the number of
/// samples that were counted.
type LocalStepOutput = (f64, Vec<GradRow>, Vec<GradRow>, usize);
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct TripleSample {
48 pub head: String,
50 pub relation: String,
52 pub tail: String,
54}
55
56impl TripleSample {
57 pub fn new(
59 head: impl Into<String>,
60 relation: impl Into<String>,
61 tail: impl Into<String>,
62 ) -> Self {
63 Self {
64 head: head.into(),
65 relation: relation.into(),
66 tail: tail.into(),
67 }
68 }
69}
70
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
73pub struct WorkerLoss {
74 pub worker_id: u32,
76 pub history: Vec<f64>,
78 pub total_loss: f64,
80 pub samples: usize,
82}
83
84impl WorkerLoss {
85 pub fn mean(&self) -> f64 {
87 if self.history.is_empty() {
88 0.0
89 } else {
90 self.total_loss / self.history.len() as f64
91 }
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct WorkerConfig {
98 pub worker_id: u32,
100 pub max_steps: usize,
102 pub margin: f32,
104 pub l2_reg: f32,
106 pub seed: u64,
108}
109
110impl Default for WorkerConfig {
111 fn default() -> Self {
112 Self {
113 worker_id: 0,
114 max_steps: 50,
115 margin: 1.0,
116 l2_reg: 0.0,
117 seed: 1,
118 }
119 }
120}
121
/// A training worker that pulls shard snapshots from a shared parameter
/// server, computes gradients locally, and pushes them back.
pub struct Worker {
    /// Settings for this worker (id, step count, loss hyper-parameters, seed).
    config: WorkerConfig,
    /// Shared handle to the parameter server this worker talks to.
    server: Arc<ParameterServer>,
    /// Training triples this worker iterates over on every step.
    samples: Vec<TripleSample>,
    /// Current state of the linear congruential generator used to pick
    /// negative samples.
    rng_state: u64,
}
132
133impl Worker {
134 pub fn new(
136 config: WorkerConfig,
137 server: Arc<ParameterServer>,
138 samples: Vec<TripleSample>,
139 ) -> Self {
140 let rng_state = config.seed | 1;
141 Self {
142 config,
143 server,
144 samples,
145 rng_state,
146 }
147 }
148
149 pub fn config(&self) -> &WorkerConfig {
151 &self.config
152 }
153
154 pub async fn run(mut self) -> Result<WorkerLoss> {
160 let mut loss = WorkerLoss {
161 worker_id: self.config.worker_id,
162 ..Default::default()
163 };
164 let started = Instant::now();
165
166 for step in 0..self.config.max_steps {
168 let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
173 for (i, s) in self.samples.iter().enumerate() {
174 let shard = self.server.shard_manager().shard_for(&s.head);
175 groups.entry(shard).or_default().push(i);
176 }
177
178 for (shard_id, indices) in groups {
179 let snap = self.server.pull(shard_id).await?;
180 let shard_samples: Vec<TripleSample> =
185 indices.iter().map(|&i| self.samples[i].clone()).collect();
186 let sample_refs: Vec<&TripleSample> = shard_samples.iter().collect();
187 let (mean_loss, grads, rel_grads, n) = self.local_step(&snap, &sample_refs)?;
188 if !rel_grads.is_empty() {
189 self.server
190 .push_relation(self.config.worker_id, rel_grads)
191 .await?;
192 }
193 if !grads.is_empty() {
194 self.server
195 .push(shard_id, self.config.worker_id, grads)
196 .await?;
197 }
198 loss.history.push(mean_loss);
199 loss.total_loss += mean_loss;
200 loss.samples += n;
201 trace!(
202 worker = self.config.worker_id,
203 step,
204 shard = shard_id,
205 samples = n,
206 loss = mean_loss,
207 "worker step done"
208 );
209 }
210 }
211
212 debug!(
213 worker = self.config.worker_id,
214 elapsed_ms = started.elapsed().as_millis() as u64,
215 mean_loss = loss.mean(),
216 "worker finished"
217 );
218 Ok(loss)
219 }
220
221 fn local_step(
235 &mut self,
236 snap: &ShardSnapshot,
237 samples: &[&TripleSample],
238 ) -> Result<LocalStepOutput> {
239 if snap.entities.is_empty() || samples.is_empty() {
240 return Ok((0.0, Vec::new(), Vec::new(), 0));
241 }
242
243 let dim = snap.entities[0].len();
244 let entity_index: HashMap<&str, usize> = snap
245 .entity_ids
246 .iter()
247 .enumerate()
248 .map(|(i, s)| (s.as_str(), i))
249 .collect();
250 let relation_index: HashMap<&str, usize> = snap
251 .relation_ids
252 .iter()
253 .enumerate()
254 .map(|(i, s)| (s.as_str(), i))
255 .collect();
256
257 let mut grad_acc: HashMap<usize, Vec<f32>> = HashMap::new();
259 let mut rel_grad: HashMap<usize, Vec<f32>> = HashMap::new();
261 let mut total_loss = 0.0_f64;
262 let mut counted = 0usize;
263
264 for s in samples {
265 let h_idx = match entity_index.get(s.head.as_str()) {
267 Some(&i) => i,
268 None => continue,
269 };
270 let r_idx = match relation_index.get(s.relation.as_str()) {
271 Some(&i) => i,
272 None => continue,
273 };
274
275 let t_idx_local = entity_index.get(s.tail.as_str()).copied();
279 let head = &snap.entities[h_idx];
280 let rel = &snap.relations[r_idx];
281
282 let tail_vec: Vec<f32> = match t_idx_local {
287 Some(i) => snap.entities[i].clone(),
288 None => snap.relations[r_idx].clone(),
289 };
290
291 let neg_idx = self.next_index(snap.entities.len());
293 let neg = &snap.entities[neg_idx];
294
295 let pos_diff: Vec<f32> = head
297 .iter()
298 .zip(rel.iter())
299 .zip(tail_vec.iter())
300 .map(|((h, r), t)| h + r - t)
301 .collect();
302 let neg_diff: Vec<f32> = head
303 .iter()
304 .zip(rel.iter())
305 .zip(neg.iter())
306 .map(|((h, r), n)| h + r - n)
307 .collect();
308
309 let pos_score = l2_norm(&pos_diff);
310 let neg_score = l2_norm(&neg_diff);
311 let margin = self.config.margin;
312 let raw_loss = (margin + pos_score - neg_score).max(0.0);
313 total_loss += raw_loss as f64;
314 counted += 1;
315
316 if raw_loss > 0.0 {
322 let pos_norm = pos_score.max(1e-6);
323 let neg_norm = neg_score.max(1e-6);
324
325 let grad_h: Vec<f32> = pos_diff
326 .iter()
327 .zip(neg_diff.iter())
328 .map(|(p, n)| p / pos_norm - n / neg_norm)
329 .collect();
330 let grad_r = grad_h.clone();
331 let grad_t: Vec<f32> = pos_diff.iter().map(|p| -p / pos_norm).collect();
332 let grad_neg: Vec<f32> = neg_diff.iter().map(|n| n / neg_norm).collect();
333
334 accumulate_grad(&mut grad_acc, h_idx, &grad_h, dim);
335 if let Some(ti) = t_idx_local {
336 accumulate_grad(&mut grad_acc, ti, &grad_t, dim);
337 }
338 accumulate_grad(&mut grad_acc, neg_idx, &grad_neg, dim);
339 accumulate_grad(&mut rel_grad, r_idx, &grad_r, dim);
340 }
341
342 if self.config.l2_reg > 0.0 {
344 let entry = grad_acc.entry(h_idx).or_insert_with(|| vec![0.0; dim]);
345 for (e, h) in entry.iter_mut().zip(head.iter()) {
346 *e += self.config.l2_reg * *h;
347 }
348 }
349 }
350
351 let mean_loss = if counted == 0 {
352 0.0
353 } else {
354 total_loss / counted as f64
355 };
356 let grads: Vec<(usize, Vec<f32>)> = grad_acc.into_iter().collect();
357 let rel_grads: Vec<(usize, Vec<f32>)> = rel_grad.into_iter().collect();
358 Ok((mean_loss, grads, rel_grads, counted))
359 }
360
361 fn next_index(&mut self, n: usize) -> usize {
362 self.rng_state = self
364 .rng_state
365 .wrapping_mul(6364136223846793005)
366 .wrapping_add(1442695040888963407);
367 ((self.rng_state >> 32) as usize) % n.max(1)
368 }
369}
370
371pub async fn run_workers(workers: Vec<Worker>) -> Result<Vec<WorkerLoss>> {
375 let mut handles = Vec::with_capacity(workers.len());
376 for w in workers {
377 handles.push(tokio::spawn(async move { w.run().await }));
378 }
379 let mut out = Vec::with_capacity(handles.len());
380 for h in handles {
381 match h.await {
382 Ok(Ok(loss)) => out.push(loss),
383 Ok(Err(e)) => return Err(e),
384 Err(join_err) => return Err(anyhow::anyhow!("worker join failed: {join_err}")),
385 }
386 }
387 Ok(out)
388}
389
390pub async fn describe_server(server: &ParameterServer) -> String {
392 let stats = server.stats().await;
393 let steps = server.shard_steps().await;
394 let mode = match server.config().update_mode {
395 UpdateMode::Sync => "sync",
396 UpdateMode::Async => "async",
397 };
398 format!(
399 "ParameterServer[mode={mode}, shards={}, total_pulls={}, total_pushes={}, barriers={}, steps={steps:?}]",
400 server.num_shards(),
401 stats.total_pulls,
402 stats.total_pushes,
403 stats.barriers_completed,
404 )
405}
406
/// Euclidean (L2) norm of `v`: the square root of the sum of squared
/// components. An empty slice has norm zero.
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().copied().map(|x| x * x).sum();
    sum_of_squares.sqrt()
}
412
/// Adds `grad` element-wise onto the row stored under `idx` in `target`.
///
/// A missing row is first created as `dim` zeros. Only the overlapping prefix
/// of the row and `grad` is updated, so a shorter `grad` leaves the remaining
/// components untouched.
fn accumulate_grad(target: &mut HashMap<usize, Vec<f32>>, idx: usize, grad: &[f32], dim: usize) {
    let row = target.entry(idx).or_insert_with(|| vec![0.0; dim]);
    row.iter_mut()
        .zip(grad)
        .for_each(|(slot, delta)| *slot += *delta);
}
419
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed_training::parameter_server::{
        ParameterServer, ParameterServerConfig, UpdateMode,
    };
    use crate::distributed_training::shard_manager::{ModelShardManager, ShardingStrategy};

    /// Builds a small two-shard server with entities `e0..e7` and relations
    /// `r0..r1`, expecting `workers` workers in the given update mode.
    fn build_server(workers: usize, mode: UpdateMode) -> Arc<ParameterServer> {
        let cfg = ParameterServerConfig {
            embedding_dim: 8,
            num_entities: 8,
            num_relations: 2,
            num_shards: 2,
            expected_workers: workers,
            update_mode: mode,
            learning_rate: 0.05,
            max_staleness: 16,
            barrier_timeout: std::time::Duration::from_millis(500),
        };
        let entity_ids: Vec<String> = (0..cfg.num_entities).map(|i| format!("e{i}")).collect();
        let relation_ids: Vec<String> = (0..cfg.num_relations).map(|i| format!("r{i}")).collect();
        let mgr = ModelShardManager::new(cfg.num_shards, ShardingStrategy::EntityHash);
        let server = ParameterServer::new(cfg, entity_ids, relation_ids, mgr)
            .expect("server construction failed");
        Arc::new(server)
    }

    /// Four triples that together touch every entity and both relations of
    /// the server built by `build_server`.
    fn small_triples() -> Vec<TripleSample> {
        [
            ("e0", "r0", "e1"),
            ("e2", "r0", "e3"),
            ("e4", "r1", "e5"),
            ("e6", "r1", "e7"),
        ]
        .iter()
        .map(|&(head, relation, tail)| TripleSample::new(head, relation, tail))
        .collect()
    }

    /// A single async-mode worker finishes and records at least one loss.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn worker_runs_async_and_records_loss() {
        let server = build_server(1, UpdateMode::Async);
        let worker = Worker::new(
            WorkerConfig {
                worker_id: 0,
                max_steps: 5,
                margin: 1.0,
                l2_reg: 0.0,
                seed: 7,
            },
            Arc::clone(&server),
            small_triples(),
        );

        let loss = worker.run().await.expect("worker run failed");

        assert_eq!(loss.worker_id, 0);
        assert!(
            !loss.history.is_empty(),
            "worker should record at least one loss entry"
        );
    }

    /// Four concurrent async-mode workers all finish with finite losses.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn four_workers_async_complete() {
        let server = build_server(1, UpdateMode::Async);
        let ws: Vec<Worker> = (0..4u32)
            .map(|i| {
                Worker::new(
                    WorkerConfig {
                        worker_id: i,
                        max_steps: 3,
                        margin: 1.0,
                        l2_reg: 1e-4,
                        seed: 1 + u64::from(i),
                    },
                    Arc::clone(&server),
                    small_triples(),
                )
            })
            .collect();

        let losses = run_workers(ws).await.expect("workers failed");

        assert_eq!(losses.len(), 4);
        for l in &losses {
            assert!(l.history.iter().all(|x| x.is_finite()));
        }
    }

    /// The server description names the configured update mode.
    #[tokio::test]
    async fn describe_server_renders() {
        let server = build_server(1, UpdateMode::Async);
        let desc = describe_server(&server).await;
        assert!(desc.contains("mode=async"));
    }
}