
oxirs_embed/training_online/mod.rs

//! Online / incremental embedding training utilities.
//!
//! Provides an Adam optimizer and an online embedding trainer that performs
//! incremental gradient steps on TransE-style KG embeddings without retraining
//! from scratch.
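//!
//! A minimal usage sketch (illustrative only; the single shared embedding table
//! and the triple indices are assumptions of this example):
//!
//! ```ignore
//! let config = OnlineUpdateConfig::default();
//! // One table holding entity and relation embeddings: 10 rows of dimension 8,
//! // so param_count = 8 * 10 per the formula in `OnlineEmbeddingTrainer::new`.
//! let mut trainer = OnlineEmbeddingTrainer::new(config, 8 * 10);
//! let mut embeddings: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
//! // Stream one observed (head, relation, tail) triple as a positive example.
//! trainer.update_step(&mut embeddings, (0, 1, 2), 1.0);
//! println!("recent loss: {}", trainer.recent_loss(10));
//! ```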

// ─────────────────────────────────────────────
// OnlineUpdateConfig
// ─────────────────────────────────────────────

/// Configuration for online (incremental) embedding updates.
#[derive(Debug, Clone)]
pub struct OnlineUpdateConfig {
    /// Base learning rate.
    pub learning_rate: f64,
    /// Per-step learning rate decay: lr_t = lr * decay^t.
    pub decay: f64,
    /// L2 regularization coefficient.
    pub regularization: f64,
    /// Mini-batch size (reserved for future batched updates).
    pub batch_size: usize,
    /// Maximum gradient norm for gradient clipping.
    pub max_grad_norm: f64,
}

impl Default for OnlineUpdateConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            decay: 0.9999,
            regularization: 1e-4,
            batch_size: 32,
            max_grad_norm: 1.0,
        }
    }
}

// ─────────────────────────────────────────────
// AdamOptimizer
// ─────────────────────────────────────────────

/// Adam optimizer for a flat parameter vector.
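///
/// A minimal usage sketch mirroring the `test_adam_converges_simple_quadratic`
/// test below (marked `ignore` because the doctest path to this type is not
/// fixed here):
///
/// ```ignore
/// // Minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
/// let mut opt = AdamOptimizer::new(1, 0.1);
/// let mut params = vec![0.0_f64];
/// for _ in 0..500 {
///     let g = 2.0 * (params[0] - 3.0);
///     opt.step(&mut params, &[g]);
/// }
/// assert!((params[0] - 3.0).abs() < 0.1);
/// ```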
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
    /// First moment (mean of gradient).
    pub m: Vec<f64>,
    /// Second moment (uncentred variance of gradient).
    pub v: Vec<f64>,
    /// Step counter (1-indexed after first call to `step`).
    pub t: u64,
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate for first moment.
    pub beta1: f64,
    /// Exponential decay rate for second moment.
    pub beta2: f64,
    /// Numerical stability epsilon.
    pub epsilon: f64,
}

impl AdamOptimizer {
    /// Create a new Adam optimizer for `param_count` parameters.
    pub fn new(param_count: usize, lr: f64) -> Self {
        Self {
            m: vec![0.0; param_count],
            v: vec![0.0; param_count],
            t: 0,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }

    /// Perform one Adam update step.
    ///
    /// Updates `params` in-place using `gradients`.
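    ///
    /// This is the standard Adam update implemented below:
    /// m_t = β1·m_{t-1} + (1 − β1)·g_t and v_t = β2·v_{t-1} + (1 − β2)·g_t²,
    /// bias-corrected as m̂_t = m_t / (1 − β1^t) and v̂_t = v_t / (1 − β2^t),
    /// then θ ← θ − lr · m̂_t / (√v̂_t + ε).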
    pub fn step(&mut self, params: &mut [f64], gradients: &[f64]) {
        self.t += 1;
        let t = self.t as f64;
        let bias_corr1 = 1.0 - self.beta1.powf(t);
        let bias_corr2 = 1.0 - self.beta2.powf(t);

        for i in 0..params.len().min(gradients.len()).min(self.m.len()) {
            let g = gradients[i];
            // Update biased first/second moments
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            // Bias-corrected moments
            let m_hat = self.m[i] / bias_corr1;
            let v_hat = self.v[i] / bias_corr2;
            // Update
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
        }
    }

    /// Reset all moment estimates and step counter.
    pub fn reset(&mut self) {
        self.m.iter_mut().for_each(|x| *x = 0.0);
        self.v.iter_mut().for_each(|x| *x = 0.0);
        self.t = 0;
    }

    /// Number of steps taken since last reset.
    pub fn step_count(&self) -> u64 {
        self.t
    }
}

// ─────────────────────────────────────────────
// OnlineEmbeddingTrainer
// ─────────────────────────────────────────────

/// Incremental embedding trainer using TransE-style scoring.
///
/// Embeddings are stored as `Vec<Vec<f64>>` (num_embeddings × dim).
/// Each `update_step` call:
/// 1. Computes the TransE distance-based loss for the provided triple.
/// 2. Clips the shared gradient to `max_grad_norm`.
/// 3. Runs one Adam step for each of the head, relation, and tail embeddings.
/// 4. Records the loss.
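///
/// A monitoring sketch (illustrative; the triple stream and the convergence
/// threshold are assumptions of this example):
///
/// ```ignore
/// let mut trainer = OnlineEmbeddingTrainer::new(OnlineUpdateConfig::default(), 8 * 10);
/// let mut embeddings: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
/// for &(h, r, t) in &[(0usize, 1usize, 2usize), (3, 1, 4)] {
///     trainer.update_step(&mut embeddings, (h, r, t), 1.0);
/// }
/// // Track convergence over a sliding window of recent updates.
/// if trainer.recent_loss(100) < 1e-3 {
///     // e.g. pause updates or lower the learning rate
/// }
/// ```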
pub struct OnlineEmbeddingTrainer {
    pub config: OnlineUpdateConfig,
    pub optimizer: AdamOptimizer,
    pub step: u64,
    pub loss_history: Vec<f64>,
}

impl OnlineEmbeddingTrainer {
    /// Create a new trainer.
    ///
    /// `param_count` should match the total number of parameters (entity + relation embeddings
    /// flattened), but it is used only to size the Adam moment vectors. Pass `dim * (n_e + n_r)`
    /// or any conservative upper bound.
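    ///
    /// For example, 1,000 entities and 50 relations at dimension 64 would give
    /// `param_count = 64 * (1000 + 50) = 67_200`.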
    pub fn new(config: OnlineUpdateConfig, param_count: usize) -> Self {
        let lr = config.learning_rate;
        Self {
            config,
            optimizer: AdamOptimizer::new(param_count, lr),
            step: 0,
            loss_history: Vec::new(),
        }
    }

    /// Perform one TransE-style online gradient step for `triple = (head, relation, tail)`.
    ///
    /// The step is driven by the TransE distance d = ||h + r - t||_2: for a positive
    /// triple the distance is decreased, for a negative triple it is increased.
    /// The recorded loss is max(0, -label * d) + 1e-4 * d.
    ///
    /// `label` = +1.0 for positive triples, -1.0 for negative.
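    ///
    /// Gradient sketch (what the body below applies, before L2 regularization):
    /// ∂d/∂h = ∂d/∂r = (h + r − t) / d and ∂d/∂t = −(h + r − t) / d,
    /// each scaled by sign(label) and clipped to `max_grad_norm` ahead of the Adam step.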
    pub fn update_step(
        &mut self,
        embeddings: &mut [Vec<f64>],
        triple: (usize, usize, usize),
        label: f64,
    ) {
        let (head, relation, tail) = triple;
        if embeddings.is_empty() {
            return;
        }

        let n_emb = embeddings.len();
        let dim = embeddings[0].len();

        if head >= n_emb || relation >= n_emb || tail >= n_emb || dim == 0 {
            return;
        }

        // Effective learning rate with decay
        let effective_lr = self.config.learning_rate * self.config.decay.powf(self.step as f64);

        // TransE score: -||h + r - t||_2
        let h = embeddings[head].clone();
        let r = embeddings[relation].clone();
        let t = embeddings[tail].clone();

        let diff: Vec<f64> = (0..dim).map(|i| h[i] + r[i] - t[i]).collect();
        let norm: f64 = diff.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-10);
        // Recorded loss: hinge on the signed distance plus a small norm penalty.
        // For a positive triple the hinge term is zero, so only the penalty remains;
        // for a negative triple it is the distance itself (plus the penalty).
        let loss = (label * (-norm)).max(0.0) + norm * 1e-4;

        // Gradient of ||diff||_2 w.r.t. diff[i] = diff[i] / norm
        // Multiply by label (sign flip for negative)
        let base_grad_sign = if label > 0.0 { 1.0 } else { -1.0 };
        let mut grads: Vec<f64> = diff.iter().map(|&d| base_grad_sign * d / norm).collect();

        // Gradient clipping
        let grad_norm: f64 = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
        if grad_norm > self.config.max_grad_norm {
            let scale = self.config.max_grad_norm / grad_norm;
            grads.iter_mut().for_each(|g| *g *= scale);
        }

        // L2 regularization coefficient, added per-embedding below.
        let reg = self.config.regularization;

        // Run one Adam step per participating embedding (head, relation, tail).
        // Note: all three updates reuse the first `dim` moment slots of the shared
        // optimizer, and the optimizer's step counter advances three times per call.
        let optimizer_lr = effective_lr;
        self.optimizer.lr = optimizer_lr;

        for &(idx, sign, base) in &[(head, 1.0_f64, &h), (relation, 1.0, &r), (tail, -1.0, &t)] {
            let mut params = embeddings[idx].clone();
            // Signed TransE gradient plus the L2 regularization term for this embedding.
            let grad_vec: Vec<f64> = (0..dim).map(|i| sign * grads[i] + reg * base[i]).collect();
            let off = dim.min(self.optimizer.m.len());
            adam_step_slice(
                &mut self.optimizer.m[0..off],
                &mut self.optimizer.v[0..off],
                &mut self.optimizer.t,
                &mut params,
                &grad_vec,
                optimizer_lr,
                self.optimizer.beta1,
                self.optimizer.beta2,
                self.optimizer.epsilon,
            );
            embeddings[idx] = params;
        }

        self.loss_history.push(loss);
        self.step += 1;
    }

    /// Average loss over all recorded steps.
    pub fn avg_loss(&self) -> f64 {
        if self.loss_history.is_empty() {
            return 0.0;
        }
        self.loss_history.iter().sum::<f64>() / self.loss_history.len() as f64
    }

    /// Average loss over the most recent `n` steps.
    pub fn recent_loss(&self, n: usize) -> f64 {
        if self.loss_history.is_empty() {
            return 0.0;
        }
        let start = self.loss_history.len().saturating_sub(n);
        let slice = &self.loss_history[start..];
        slice.iter().sum::<f64>() / slice.len() as f64
    }

    /// Total number of update steps taken.
    pub fn step_count(&self) -> u64 {
        self.step
    }
}

// ─────────────────────────────────────────────
// Internal helpers
// ─────────────────────────────────────────────

/// Apply one Adam step to a slice of parameters.
/// Operates on pre-existing moment slices (reused across calls for the same "slot").
#[allow(clippy::too_many_arguments)]
fn adam_step_slice(
    m: &mut [f64],
    v: &mut [f64],
    t: &mut u64,
    params: &mut [f64],
    grads: &[f64],
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
) {
    *t += 1;
    let tc = *t as f64;
    let bc1 = 1.0 - beta1.powf(tc);
    let bc2 = 1.0 - beta2.powf(tc);

    let len = params.len().min(grads.len()).min(m.len()).min(v.len());
    for i in 0..len {
        let g = grads[i];
        m[i] = beta1 * m[i] + (1.0 - beta1) * g;
        v[i] = beta2 * v[i] + (1.0 - beta2) * g * g;
        let m_hat = m[i] / bc1;
        let v_hat = v[i] / bc2;
        params[i] -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}

// ─────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── OnlineUpdateConfig ────────────────────

    #[test]
    fn test_default_config_values() {
        let cfg = OnlineUpdateConfig::default();
        assert!((cfg.learning_rate - 0.001).abs() < 1e-12);
        assert!((cfg.decay - 0.9999).abs() < 1e-12);
        assert!((cfg.regularization - 1e-4).abs() < 1e-12);
        assert_eq!(cfg.batch_size, 32);
        assert!((cfg.max_grad_norm - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_config_clone() {
        let cfg = OnlineUpdateConfig::default();
        let cloned = cfg.clone();
        assert!((cloned.learning_rate - cfg.learning_rate).abs() < 1e-12);
    }

    // ── AdamOptimizer ─────────────────────────

    #[test]
    fn test_adam_creation() {
        let opt = AdamOptimizer::new(10, 0.001);
        assert_eq!(opt.m.len(), 10);
        assert_eq!(opt.v.len(), 10);
        assert_eq!(opt.t, 0);
        assert!((opt.lr - 0.001).abs() < 1e-12);
    }

    #[test]
    fn test_adam_step_changes_params() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![1.0_f64; 4];
        let grads = vec![0.1, 0.2, 0.3, 0.4];
        opt.step(&mut params, &grads);
        // All params should have decreased (positive grad → negative update)
        for &p in &params {
            assert!(p < 1.0, "params should decrease with positive gradient");
        }
    }

    #[test]
    fn test_adam_step_count() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![0.0_f64; 4];
        let grads = vec![0.1; 4];
        opt.step(&mut params, &grads);
        opt.step(&mut params, &grads);
        assert_eq!(opt.step_count(), 2);
    }

    #[test]
    fn test_adam_reset() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![0.0_f64; 4];
        let grads = vec![0.1; 4];
        opt.step(&mut params, &grads);
        opt.reset();
        assert_eq!(opt.step_count(), 0);
        assert!(opt.m.iter().all(|&x| x == 0.0));
        assert!(opt.v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_adam_converges_simple_quadratic() {
        // Minimize f(x) = (x - 3)^2; gradient = 2(x - 3)
        let mut opt = AdamOptimizer::new(1, 0.1);
        let mut params = vec![0.0_f64];
        for _ in 0..500 {
            let g = 2.0 * (params[0] - 3.0);
            opt.step(&mut params, &[g]);
        }
        assert!(
            (params[0] - 3.0).abs() < 0.1,
            "Adam should converge to x=3, got {}",
            params[0]
        );
    }

    #[test]
    fn test_adam_zero_gradient_no_change() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let params_before = vec![1.0_f64, 2.0, 3.0, 4.0];
        let mut params = params_before.clone();
        // Gradient is essentially zero
        let grads = vec![1e-15_f64; 4];
        opt.step(&mut params, &grads);
        // Params should barely change
        for (a, b) in params.iter().zip(params_before.iter()) {
            assert!(
                (a - b).abs() < 1e-3,
                "near-zero gradient should barely change params"
            );
        }
    }

    // ── OnlineEmbeddingTrainer ────────────────

    #[test]
    fn test_trainer_creation() {
        let cfg = OnlineUpdateConfig::default();
        let trainer = OnlineEmbeddingTrainer::new(cfg, 100);
        assert_eq!(trainer.step_count(), 0);
        assert_eq!(trainer.avg_loss(), 0.0);
    }

    #[test]
    fn test_trainer_update_increments_step() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        assert_eq!(trainer.step_count(), 1);
    }

    #[test]
    fn test_trainer_records_loss() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        assert!(!trainer.loss_history.is_empty());
        assert!(trainer.avg_loss().is_finite());
    }

    #[test]
    fn test_trainer_recent_loss_empty() {
        let cfg = OnlineUpdateConfig::default();
        let trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        assert_eq!(trainer.recent_loss(5), 0.0);
    }

    #[test]
    fn test_trainer_recent_loss_fewer_than_n() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        // Only 1 step but ask for last 5 → should return the one value
        let rl = trainer.recent_loss(5);
        assert!(rl.is_finite());
    }

    #[test]
    fn test_trainer_modifies_embeddings() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let initial = vec![vec![1.0_f64; 8]; 10];
        let mut embs = initial.clone();
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        let changed = embs
            .iter()
            .zip(initial.iter())
            .any(|(a, b)| a.iter().zip(b.iter()).any(|(x, y)| (x - y).abs() > 1e-12));
        assert!(changed, "update_step should modify at least one embedding");
    }

    #[test]
    fn test_trainer_out_of_bounds_indices_ignored() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 5];
        // Indices out of range — should not panic
        trainer.update_step(&mut embs, (10, 20, 30), 1.0);
        assert_eq!(trainer.step_count(), 0); // no step taken for bad indices
    }

    #[test]
    fn test_trainer_multiple_steps() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        for i in 0..20 {
            let h = i % 5;
            let r = (i + 1) % 5;
            let t = (i + 2) % 5;
            trainer.update_step(&mut embs, (h, r, t), 1.0);
        }
        assert_eq!(trainer.step_count(), 20);
        assert!(trainer.avg_loss().is_finite());
    }

    #[test]
    fn test_trainer_positive_vs_negative_label() {
        // Positive updates should behave differently from negative
        let cfg = OnlineUpdateConfig::default();
        let mut t_pos = OnlineEmbeddingTrainer::new(cfg.clone(), 64);
        let mut t_neg = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs_pos: Vec<Vec<f64>> = vec![vec![0.5; 8]; 10];
        let mut embs_neg = embs_pos.clone();

        for _ in 0..10 {
            t_pos.update_step(&mut embs_pos, (0, 1, 2), 1.0);
            t_neg.update_step(&mut embs_neg, (0, 1, 2), -1.0);
        }
        // The two embedding sets should differ after training with different labels
        let diff_exists = embs_pos[0]
            .iter()
            .zip(embs_neg[0].iter())
            .any(|(a, b)| (a - b).abs() > 1e-9);
        assert!(
            diff_exists,
            "positive and negative training should produce different embeddings"
        );
    }

    #[test]
    fn test_adam_optimizer_lr_decay() {
        // Exercise the decayed effective learning rate path (lr shrinks as steps increase).
        let cfg = OnlineUpdateConfig {
            decay: 0.5,
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 32);
        // After many steps the effective lr is tiny, so updates become negligible.
        let mut embs: Vec<Vec<f64>> = vec![vec![1.0; 8]; 10];
        for _ in 0..100 {
            trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        }
        // Just verify that all 100 steps completed without panicking.
        assert_eq!(trainer.step_count(), 100);
    }
}