/// Configuration for online (incremental) embedding updates.
#[derive(Debug, Clone)]
pub struct OnlineUpdateConfig {
    /// Base learning rate before decay is applied.
    pub learning_rate: f64,
    /// Multiplicative per-step decay factor for the learning rate.
    pub decay: f64,
    /// L2 regularization coefficient added to each gradient.
    pub regularization: f64,
    /// Number of examples per update batch.
    pub batch_size: usize,
    /// Maximum allowed L2 norm of a gradient before it is rescaled.
    pub max_grad_norm: f64,
}

impl Default for OnlineUpdateConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            decay: 0.9999,
            regularization: 1e-4,
            batch_size: 32,
            max_grad_norm: 1.0,
        }
    }
}
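
// Example (sketch, values are illustrative): override individual fields and
// keep the remaining defaults via struct-update syntax:
//
//     let cfg = OnlineUpdateConfig {
//         learning_rate: 0.01,
//         decay: 0.999,
//         ..Default::default()
//     };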

/// Adam optimizer state (Kingma & Ba, 2015) with bias-corrected moment estimates.
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
    /// First-moment (mean) estimates, one per parameter.
    pub m: Vec<f64>,
    /// Second-moment (uncentered variance) estimates, one per parameter.
    pub v: Vec<f64>,
    /// Number of steps taken so far.
    pub t: u64,
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first-moment estimates.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimates.
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f64,
}

impl AdamOptimizer {
    /// Creates an optimizer for `param_count` parameters with the standard
    /// Adam defaults (`beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`).
    pub fn new(param_count: usize, lr: f64) -> Self {
        Self {
            m: vec![0.0; param_count],
            v: vec![0.0; param_count],
            t: 0,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }

    /// Applies one bias-corrected Adam update to `params` in place.
    pub fn step(&mut self, params: &mut [f64], gradients: &[f64]) {
        self.t += 1;
        let t = self.t as f64;
        let bias_corr1 = 1.0 - self.beta1.powf(t);
        let bias_corr2 = 1.0 - self.beta2.powf(t);

        for i in 0..params.len().min(gradients.len()).min(self.m.len()) {
            let g = gradients[i];
            // Update the first- and second-moment estimates.
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            // Bias-correct the estimates, then update the parameter.
            let m_hat = self.m[i] / bias_corr1;
            let v_hat = self.v[i] / bias_corr2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
        }
    }

    /// Clears the moment estimates and resets the step counter.
    pub fn reset(&mut self) {
        self.m.iter_mut().for_each(|x| *x = 0.0);
        self.v.iter_mut().for_each(|x| *x = 0.0);
        self.t = 0;
    }

    /// Number of steps taken since creation or the last `reset`.
    pub fn step_count(&self) -> u64 {
        self.t
    }
}
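
// Usage sketch (standalone; sizes and values are illustrative):
//
//     let mut opt = AdamOptimizer::new(3, 0.01);
//     let mut params = vec![0.0_f64; 3];
//     let grads = vec![0.5, -0.2, 0.1];
//     opt.step(&mut params, &grads);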

/// Online trainer that applies per-triple Adam updates to embedding vectors.
pub struct OnlineEmbeddingTrainer {
    /// Update hyperparameters.
    pub config: OnlineUpdateConfig,
    /// Shared Adam state reused for every updated vector.
    pub optimizer: AdamOptimizer,
    /// Number of update steps performed.
    pub step: u64,
    /// Loss value recorded after each update step.
    pub loss_history: Vec<f64>,
}

impl OnlineEmbeddingTrainer {
    /// Creates a trainer with Adam state sized for `param_count` parameters.
    pub fn new(config: OnlineUpdateConfig, param_count: usize) -> Self {
        let lr = config.learning_rate;
        Self {
            config,
            optimizer: AdamOptimizer::new(param_count, lr),
            step: 0,
            loss_history: Vec::new(),
        }
    }

    /// Performs one online update for a single `(head, relation, tail)` triple.
    ///
    /// The score is the translational distance `||h + r - t||` (TransE-style):
    /// a positive `label` pulls `head + relation` toward `tail`, while a
    /// non-positive `label` pushes them apart. Out-of-range indices are ignored.
    pub fn update_step(
        &mut self,
        embeddings: &mut [Vec<f64>],
        triple: (usize, usize, usize),
        label: f64,
    ) {
        let (head, relation, tail) = triple;
        if embeddings.is_empty() {
            return;
        }

        let n_emb = embeddings.len();
        let dim = embeddings[0].len();

        if head >= n_emb || relation >= n_emb || tail >= n_emb || dim == 0 {
            return;
        }

        // Exponentially decayed learning rate for this step.
        let effective_lr = self.config.learning_rate * self.config.decay.powf(self.step as f64);

        let h = embeddings[head].clone();
        let r = embeddings[relation].clone();
        let t = embeddings[tail].clone();

        let diff: Vec<f64> = (0..dim).map(|i| h[i] + r[i] - t[i]).collect();
        let norm: f64 = diff.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-10);
        // Recorded for monitoring only; the parameter update below uses the gradient.
        let loss = (label * (-norm)).max(0.0) + norm * 1e-4;
        let base_grad_sign = if label > 0.0 { 1.0 } else { -1.0 };
        let mut grads: Vec<f64> = diff.iter().map(|&d| base_grad_sign * d / norm).collect();

        // Clip the gradient to the configured maximum L2 norm.
        let grad_norm: f64 = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
        if grad_norm > self.config.max_grad_norm {
            let scale = self.config.max_grad_norm / grad_norm;
            grads.iter_mut().for_each(|g| *g *= scale);
        }

        let reg = self.config.regularization;

        // Keep the optimizer's stored learning rate in sync with the decayed value.
        self.optimizer.lr = effective_lr;

        // Apply the same Adam update to the head, relation, and tail vectors. The
        // gradient of `||h + r - t||` carries sign +1 with respect to `h` and `r`
        // and -1 with respect to `t`; L2 regularization is added per vector.
        let updates = [(head, 1.0, &h), (relation, 1.0, &r), (tail, -1.0, &t)];
        for &(idx, sign, base) in &updates {
            let mut params = embeddings[idx].clone();
            let vec_grads: Vec<f64> =
                (0..dim).map(|i| sign * grads[i] + reg * base[i]).collect();
            let off = dim.min(self.optimizer.m.len());
            adam_step_slice(
                &mut self.optimizer.m[0..off],
                &mut self.optimizer.v[0..off],
                &mut self.optimizer.t,
                &mut params,
                &vec_grads,
                effective_lr,
                self.optimizer.beta1,
                self.optimizer.beta2,
                self.optimizer.epsilon,
            );
            embeddings[idx] = params;
        }

        self.loss_history.push(loss);
        self.step += 1;
    }

    /// Mean of all recorded losses, or `0.0` if no updates have been made.
    pub fn avg_loss(&self) -> f64 {
        if self.loss_history.is_empty() {
            return 0.0;
        }
        self.loss_history.iter().sum::<f64>() / self.loss_history.len() as f64
    }

    /// Mean of the most recent `n` recorded losses (fewer if less history exists).
    pub fn recent_loss(&self, n: usize) -> f64 {
        if self.loss_history.is_empty() {
            return 0.0;
        }
        let start = self.loss_history.len().saturating_sub(n);
        let slice = &self.loss_history[start..];
        slice.iter().sum::<f64>() / slice.len() as f64
    }

    /// Number of update steps performed so far.
    pub fn step_count(&self) -> u64 {
        self.step
    }
}
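
// Usage sketch (standalone; sizes and indices are illustrative):
//
//     let mut trainer = OnlineEmbeddingTrainer::new(OnlineUpdateConfig::default(), 64);
//     let mut embeddings: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
//     trainer.update_step(&mut embeddings, (0, 1, 2), 1.0);  // positive triple
//     trainer.update_step(&mut embeddings, (3, 1, 4), -1.0); // negative triple
//     let recent = trainer.recent_loss(10);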

/// Applies one bias-corrected Adam update to `params` in place, using the
/// provided moment slices `m` and `v` and the shared step counter `t`.
#[allow(clippy::too_many_arguments)]
fn adam_step_slice(
    m: &mut [f64],
    v: &mut [f64],
    t: &mut u64,
    params: &mut [f64],
    grads: &[f64],
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
) {
    *t += 1;
    let tc = *t as f64;
    let bc1 = 1.0 - beta1.powf(tc);
    let bc2 = 1.0 - beta2.powf(tc);

    let len = params.len().min(grads.len()).min(m.len()).min(v.len());
    for i in 0..len {
        let g = grads[i];
        m[i] = beta1 * m[i] + (1.0 - beta1) * g;
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g;
        let m_hat = m[i] / bc1;
        let v_hat = v[i] / bc2;
        params[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}
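
// Reference: the per-element update above implements Adam with bias correction
// (Kingma & Ba, 2015):
//
//     m_t   = beta1 * m_{t-1} + (1 - beta1) * g
//     v_t   = beta2 * v_{t-1} + (1 - beta2) * g^2
//     m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
//     theta = theta - lr * m_hat / (sqrt(v_hat) + epsilon)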

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config_values() {
        let cfg = OnlineUpdateConfig::default();
        assert!((cfg.learning_rate - 0.001).abs() < 1e-12);
        assert!((cfg.decay - 0.9999).abs() < 1e-12);
        assert!((cfg.regularization - 1e-4).abs() < 1e-12);
        assert_eq!(cfg.batch_size, 32);
        assert!((cfg.max_grad_norm - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_config_clone() {
        let cfg = OnlineUpdateConfig::default();
        let cloned = cfg.clone();
        assert!((cloned.learning_rate - cfg.learning_rate).abs() < 1e-12);
    }

    #[test]
    fn test_adam_creation() {
        let opt = AdamOptimizer::new(10, 0.001);
        assert_eq!(opt.m.len(), 10);
        assert_eq!(opt.v.len(), 10);
        assert_eq!(opt.t, 0);
        assert!((opt.lr - 0.001).abs() < 1e-12);
    }

    #[test]
    fn test_adam_step_changes_params() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![1.0_f64; 4];
        let grads = vec![0.1, 0.2, 0.3, 0.4];
        opt.step(&mut params, &grads);
        for &p in &params {
            assert!(p < 1.0, "params should decrease with positive gradient");
        }
    }

    #[test]
    fn test_adam_step_count() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![0.0_f64; 4];
        let grads = vec![0.1; 4];
        opt.step(&mut params, &grads);
        opt.step(&mut params, &grads);
        assert_eq!(opt.step_count(), 2);
    }

    #[test]
    fn test_adam_reset() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![0.0_f64; 4];
        let grads = vec![0.1; 4];
        opt.step(&mut params, &grads);
        opt.reset();
        assert_eq!(opt.step_count(), 0);
        assert!(opt.m.iter().all(|&x| x == 0.0));
        assert!(opt.v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_adam_converges_simple_quadratic() {
        let mut opt = AdamOptimizer::new(1, 0.1);
        let mut params = vec![0.0_f64];
        for _ in 0..500 {
            let g = 2.0 * (params[0] - 3.0);
            opt.step(&mut params, &[g]);
        }
        assert!(
            (params[0] - 3.0).abs() < 0.1,
            "Adam should converge to x=3, got {}",
            params[0]
        );
    }

    #[test]
    fn test_adam_zero_gradient_no_change() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let params_before = vec![1.0_f64, 2.0, 3.0, 4.0];
        let mut params = params_before.clone();
        let grads = vec![1e-15_f64; 4];
        opt.step(&mut params, &grads);
        for (a, b) in params.iter().zip(params_before.iter()) {
            assert!(
                (a - b).abs() < 1e-3,
                "near-zero gradient should barely change params"
            );
        }
    }
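
    // Additional sketch: on the very first step, Adam's update magnitude is
    // approximately `lr`, because `m_hat = g` and `sqrt(v_hat) = |g|`, so the
    // ratio is roughly `sign(g)`. Test name and values are illustrative.
    #[test]
    fn test_adam_first_step_magnitude_near_lr() {
        let mut opt = AdamOptimizer::new(1, 0.1);
        let mut params = vec![0.0_f64];
        opt.step(&mut params, &[1.0]);
        assert!((params[0] + 0.1).abs() < 1e-6);
    }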

    #[test]
    fn test_trainer_creation() {
        let cfg = OnlineUpdateConfig::default();
        let trainer = OnlineEmbeddingTrainer::new(cfg, 100);
        assert_eq!(trainer.step_count(), 0);
        assert_eq!(trainer.avg_loss(), 0.0);
    }

    #[test]
    fn test_trainer_update_increments_step() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        assert_eq!(trainer.step_count(), 1);
    }

    #[test]
    fn test_trainer_records_loss() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        assert!(!trainer.loss_history.is_empty());
        assert!(trainer.avg_loss().is_finite());
    }

    #[test]
    fn test_trainer_recent_loss_empty() {
        let cfg = OnlineUpdateConfig::default();
        let trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        assert_eq!(trainer.recent_loss(5), 0.0);
    }

    #[test]
    fn test_trainer_recent_loss_fewer_than_n() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        let rl = trainer.recent_loss(5);
        assert!(rl.is_finite());
    }

    #[test]
    fn test_trainer_modifies_embeddings() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let initial = vec![vec![1.0_f64; 8]; 10];
        let mut embs = initial.clone();
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        let changed = embs
            .iter()
            .zip(initial.iter())
            .any(|(a, b)| a.iter().zip(b.iter()).any(|(x, y)| (x - y).abs() > 1e-12));
        assert!(changed, "update_step should modify at least one embedding");
    }

    #[test]
    fn test_trainer_out_of_bounds_indices_ignored() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 5];
        trainer.update_step(&mut embs, (10, 20, 30), 1.0);
        assert_eq!(trainer.step_count(), 0);
    }

    #[test]
    fn test_trainer_multiple_steps() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        for i in 0..20 {
            let h = i % 5;
            let r = (i + 1) % 5;
            let t = (i + 2) % 5;
            trainer.update_step(&mut embs, (h, r, t), 1.0);
        }
        assert_eq!(trainer.step_count(), 20);
        assert!(trainer.avg_loss().is_finite());
    }

    #[test]
    fn test_trainer_positive_vs_negative_label() {
        let cfg = OnlineUpdateConfig::default();
        let mut t_pos = OnlineEmbeddingTrainer::new(cfg.clone(), 64);
        let mut t_neg = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs_pos: Vec<Vec<f64>> = vec![vec![0.5; 8]; 10];
        let mut embs_neg = embs_pos.clone();

        for _ in 0..10 {
            t_pos.update_step(&mut embs_pos, (0, 1, 2), 1.0);
            t_neg.update_step(&mut embs_neg, (0, 1, 2), -1.0);
        }
        let diff_exists = embs_pos[0]
            .iter()
            .zip(embs_neg[0].iter())
            .any(|(a, b)| (a - b).abs() > 1e-9);
        assert!(
            diff_exists,
            "positive and negative training should produce different embeddings"
        );
    }

    #[test]
    fn test_adam_optimizer_lr_decay() {
        // With decay = 0.5 the effective learning rate halves every step,
        // exercising the decayed-learning-rate path of the trainer.
        let cfg = OnlineUpdateConfig {
            decay: 0.5,
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 32);
        let mut embs: Vec<Vec<f64>> = vec![vec![1.0; 8]; 10];
        for _ in 0..100 {
            trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        }
        assert_eq!(trainer.step_count(), 100);
    }
}