/// Hyperparameters controlling online (incremental) embedding updates.
#[derive(Debug, Clone)]
pub struct OnlineUpdateConfig {
    /// Base learning rate, before per-step decay is applied.
    pub learning_rate: f64,
    /// Multiplicative decay factor raised to the step count to scale the learning rate.
    pub decay: f64,
    /// L2 regularization strength added to the per-parameter gradients.
    pub regularization: f64,
    /// Batch size. NOTE(review): not read by the code visible in this file — confirm it is used by callers.
    pub batch_size: usize,
    /// Maximum allowed global gradient norm; gradients above it are rescaled (clipped).
    pub max_grad_norm: f64,
}
impl Default for OnlineUpdateConfig {
fn default() -> Self {
Self {
learning_rate: 0.001,
decay: 0.9999,
regularization: 1e-4,
batch_size: 32,
max_grad_norm: 1.0,
}
}
}
/// State for the Adam optimizer over a flat parameter vector.
#[derive(Debug, Clone)]
pub struct AdamOptimizer {
    /// First-moment (mean) estimates, one entry per parameter.
    pub m: Vec<f64>,
    /// Second-moment (uncentered variance) estimates, one entry per parameter.
    pub v: Vec<f64>,
    /// Number of steps taken so far; used for bias correction.
    pub t: u64,
    /// Learning rate applied at each step.
    pub lr: f64,
    /// Exponential decay rate for the first-moment estimates.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimates.
    pub beta2: f64,
    /// Small constant added to the denominator to avoid division by zero.
    pub epsilon: f64,
}
impl AdamOptimizer {
    /// Creates an optimizer for `param_count` parameters with zeroed moments
    /// and standard Adam hyperparameters (beta1 = 0.9, beta2 = 0.999,
    /// epsilon = 1e-8).
    pub fn new(param_count: usize, lr: f64) -> Self {
        Self {
            m: vec![0.0; param_count],
            v: vec![0.0; param_count],
            t: 0,
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
        }
    }

    /// Applies one Adam update to `params` in place.
    ///
    /// Only the common prefix of `params`, `gradients`, and the internal
    /// moment buffers is updated; extra elements are left untouched rather
    /// than causing a panic. The step counter advances once per call.
    pub fn step(&mut self, params: &mut [f64], gradients: &[f64]) {
        self.t += 1;
        let t = self.t as f64;
        // Bias correction compensates for the zero-initialized moments.
        let bias_corr1 = 1.0 - self.beta1.powf(t);
        let bias_corr2 = 1.0 - self.beta2.powf(t);
        // Hoist the scalar hyperparameters so the loop only borrows m/v.
        let (b1, b2, lr, eps) = (self.beta1, self.beta2, self.lr, self.epsilon);
        // `zip` truncates to the shortest input, replacing the manual min()
        // of lengths and removing per-index bounds checks. It also covers
        // `v.len()`, which the previous indexed loop did not guard.
        for (((p, &g), m) , v) in params
            .iter_mut()
            .zip(gradients)
            .zip(self.m.iter_mut())
            .zip(self.v.iter_mut())
        {
            *m = b1 * *m + (1.0 - b1) * g;
            *v = b2 * *v + (1.0 - b2) * g * g;
            let m_hat = *m / bias_corr1;
            let v_hat = *v / bias_corr2;
            *p -= lr * m_hat / (v_hat.sqrt() + eps);
        }
    }

    /// Clears the moment estimates and step counter, as if freshly created.
    pub fn reset(&mut self) {
        self.m.fill(0.0);
        self.v.fill(0.0);
        self.t = 0;
    }

    /// Number of optimizer steps taken since creation or the last reset.
    pub fn step_count(&self) -> u64 {
        self.t
    }
}
/// Online trainer that incrementally updates translation-style embeddings
/// (scored via `head + relation - tail` in `update_step`) with Adam.
pub struct OnlineEmbeddingTrainer {
    /// Hyperparameters for learning rate, decay, regularization, and clipping.
    pub config: OnlineUpdateConfig,
    /// Shared Adam state; one moment buffer is reused across all embeddings updated per step.
    pub optimizer: AdamOptimizer,
    /// Number of completed `update_step` calls; drives the learning-rate decay.
    pub step: u64,
    /// Loss recorded after each update step, in order.
    pub loss_history: Vec<f64>,
}
impl OnlineEmbeddingTrainer {
    /// Builds a trainer; the Adam optimizer is sized for `param_count`
    /// parameters and seeded with the configured base learning rate.
    pub fn new(config: OnlineUpdateConfig, param_count: usize) -> Self {
        let lr = config.learning_rate;
        Self {
            config,
            optimizer: AdamOptimizer::new(param_count, lr),
            step: 0,
            loss_history: Vec::new(),
        }
    }

    /// Runs one online update for `triple = (head, relation, tail)`.
    ///
    /// A `label` > 0 pulls `head + relation` toward `tail`; any other label
    /// pushes them apart. Out-of-range indices, an empty embedding table, or
    /// zero-dimension embeddings cause an early return with no state change
    /// (the step counter is not advanced). The learning rate decays
    /// geometrically with the number of completed steps.
    pub fn update_step(
        &mut self,
        embeddings: &mut [Vec<f64>],
        triple: (usize, usize, usize),
        label: f64,
    ) {
        let (head, relation, tail) = triple;
        if embeddings.is_empty() {
            return;
        }
        let n_emb = embeddings.len();
        let dim = embeddings[0].len();
        if head >= n_emb || relation >= n_emb || tail >= n_emb || dim == 0 {
            return;
        }
        // Exponential learning-rate decay driven by completed steps.
        let effective_lr = self.config.learning_rate * self.config.decay.powf(self.step as f64);
        // Snapshot the pre-update vectors; regularization gradients use these.
        let h = embeddings[head].clone();
        let r = embeddings[relation].clone();
        let t = embeddings[tail].clone();
        // Translation residual: h + r - t (zero for a perfectly scored triple).
        let diff: Vec<f64> = (0..dim).map(|i| h[i] + r[i] - t[i]).collect();
        let norm: f64 = diff.iter().map(|x| x * x).sum::<f64>().sqrt().max(1e-10);
        // NOTE(review): for label > 0 the term (label * -norm).max(0.0) is
        // always 0, so the recorded loss reduces to the small norm penalty —
        // confirm this matches the intended loss definition.
        let loss = (label * (-norm)).max(0.0) + norm * 1e-4;
        let base_grad_sign = if label > 0.0 { 1.0 } else { -1.0 };
        // Gradient of the residual norm w.r.t. (h + r), signed by the label.
        let mut grads: Vec<f64> = diff.iter().map(|&d| base_grad_sign * d / norm).collect();
        // Global-norm gradient clipping.
        let grad_norm: f64 = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
        if grad_norm > self.config.max_grad_norm {
            let scale = self.config.max_grad_norm / grad_norm;
            grads.iter_mut().for_each(|g| *g *= scale);
        }
        let reg = self.config.regularization;
        // Mirror the decayed rate on the optimizer for external observability;
        // the rate actually used is passed to `apply_adam` explicitly.
        self.optimizer.lr = effective_lr;
        // NOTE(review): the three updates below share one Adam moment buffer
        // (slots 0..dim) and advance `t` three times per call — confirm this
        // is intended rather than per-embedding optimizer state.
        // Head: pull along the gradient, plus L2 pull toward zero.
        let mut h_params = embeddings[head].clone();
        let h_grads: Vec<f64> = (0..dim).map(|i| grads[i] + reg * h[i]).collect();
        self.apply_adam(&mut h_params, &h_grads, effective_lr);
        embeddings[head] = h_params;
        // Relation: same gradient direction as the head.
        let mut r_params = embeddings[relation].clone();
        let r_grads: Vec<f64> = (0..dim).map(|i| grads[i] + reg * r[i]).collect();
        self.apply_adam(&mut r_params, &r_grads, effective_lr);
        embeddings[relation] = r_params;
        // Tail: opposite gradient sign (it enters the residual negated).
        let mut t_params = embeddings[tail].clone();
        let t_grads: Vec<f64> = (0..dim).map(|i| -grads[i] + reg * t[i]).collect();
        self.apply_adam(&mut t_params, &t_grads, effective_lr);
        embeddings[tail] = t_params;
        self.loss_history.push(loss);
        self.step += 1;
    }

    /// One Adam step over the common prefix of `params` and the shared
    /// optimizer buffers. Factored out of `update_step`, which previously
    /// repeated this borrow-splitting boilerplate three times verbatim.
    fn apply_adam(&mut self, params: &mut [f64], grads: &[f64], lr: f64) {
        let off = params.len().min(self.optimizer.m.len());
        let (b1, b2, eps) = (
            self.optimizer.beta1,
            self.optimizer.beta2,
            self.optimizer.epsilon,
        );
        adam_step_slice(
            &mut self.optimizer.m[..off],
            &mut self.optimizer.v[..off],
            &mut self.optimizer.t,
            params,
            grads,
            lr,
            b1,
            b2,
            eps,
        );
    }

    /// Mean of every loss recorded so far, or 0.0 if no updates have run.
    pub fn avg_loss(&self) -> f64 {
        if self.loss_history.is_empty() {
            return 0.0;
        }
        self.loss_history.iter().sum::<f64>() / self.loss_history.len() as f64
    }

    /// Mean of the most recent `n` losses (all of them if fewer than `n`
    /// exist), or 0.0 if no updates have run.
    pub fn recent_loss(&self, n: usize) -> f64 {
        if self.loss_history.is_empty() {
            return 0.0;
        }
        let start = self.loss_history.len().saturating_sub(n);
        let slice = &self.loss_history[start..];
        slice.iter().sum::<f64>() / slice.len() as f64
    }

    /// Number of successful `update_step` calls so far.
    pub fn step_count(&self) -> u64 {
        self.step
    }
}
/// Performs one Adam update on `params` in place, using externally owned
/// optimizer state (`m`, `v`, `t`).
///
/// Only the common prefix of all four slices is updated, so mismatched
/// lengths are handled by truncation rather than panicking. `t` is
/// incremented once per call, before computing the bias corrections.
#[allow(clippy::too_many_arguments)]
fn adam_step_slice(
    m: &mut [f64],
    v: &mut [f64],
    t: &mut u64,
    params: &mut [f64],
    grads: &[f64],
    lr: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
) {
    *t += 1;
    let tc = *t as f64;
    // Bias correction for the zero-initialized moment estimates.
    let bc1 = 1.0 - beta1.powf(tc);
    let bc2 = 1.0 - beta2.powf(tc);
    // `zip` truncates to the shortest slice, replacing the manual min() of
    // four lengths and removing per-index bounds checks.
    for (((p, &g), m_i), v_i) in params
        .iter_mut()
        .zip(grads)
        .zip(m.iter_mut())
        .zip(v.iter_mut())
    {
        *m_i = beta1 * *m_i + (1.0 - beta1) * g;
        *v_i = beta2 * *v_i + (1.0 - beta2) * g * g;
        let m_hat = *m_i / bc1;
        let v_hat = *v_i / bc2;
        *p -= lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}
// Unit tests for the config, optimizer, and online trainer.
// Fix: `for &p in ¶ms` was a mojibake of `for &p in &params` (the
// `&para` HTML entity corrupted `&params`), which made the file fail to
// compile under `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config_values() {
        let cfg = OnlineUpdateConfig::default();
        assert!((cfg.learning_rate - 0.001).abs() < 1e-12);
        assert!((cfg.decay - 0.9999).abs() < 1e-12);
        assert!((cfg.regularization - 1e-4).abs() < 1e-12);
        assert_eq!(cfg.batch_size, 32);
        assert!((cfg.max_grad_norm - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_config_clone() {
        let cfg = OnlineUpdateConfig::default();
        let cloned = cfg.clone();
        assert!((cloned.learning_rate - cfg.learning_rate).abs() < 1e-12);
    }

    #[test]
    fn test_adam_creation() {
        let opt = AdamOptimizer::new(10, 0.001);
        assert_eq!(opt.m.len(), 10);
        assert_eq!(opt.v.len(), 10);
        assert_eq!(opt.t, 0);
        assert!((opt.lr - 0.001).abs() < 1e-12);
    }

    #[test]
    fn test_adam_step_changes_params() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![1.0_f64; 4];
        let grads = vec![0.1, 0.2, 0.3, 0.4];
        opt.step(&mut params, &grads);
        // Fixed mojibake: was `for &p in ¶ms {`.
        for &p in &params {
            assert!(p < 1.0, "params should decrease with positive gradient");
        }
    }

    #[test]
    fn test_adam_step_count() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![0.0_f64; 4];
        let grads = vec![0.1; 4];
        opt.step(&mut params, &grads);
        opt.step(&mut params, &grads);
        assert_eq!(opt.step_count(), 2);
    }

    #[test]
    fn test_adam_reset() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let mut params = vec![0.0_f64; 4];
        let grads = vec![0.1; 4];
        opt.step(&mut params, &grads);
        opt.reset();
        assert_eq!(opt.step_count(), 0);
        assert!(opt.m.iter().all(|&x| x == 0.0));
        assert!(opt.v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_adam_converges_simple_quadratic() {
        let mut opt = AdamOptimizer::new(1, 0.1);
        let mut params = vec![0.0_f64];
        for _ in 0..500 {
            let g = 2.0 * (params[0] - 3.0);
            opt.step(&mut params, &[g]);
        }
        assert!(
            (params[0] - 3.0).abs() < 0.1,
            "Adam should converge to x=3, got {}",
            params[0]
        );
    }

    #[test]
    fn test_adam_zero_gradient_no_change() {
        let mut opt = AdamOptimizer::new(4, 0.01);
        let params_before = vec![1.0_f64, 2.0, 3.0, 4.0];
        let mut params = params_before.clone();
        let grads = vec![1e-15_f64; 4];
        opt.step(&mut params, &grads);
        for (a, b) in params.iter().zip(params_before.iter()) {
            assert!(
                (a - b).abs() < 1e-3,
                "near-zero gradient should barely change params"
            );
        }
    }

    #[test]
    fn test_trainer_creation() {
        let cfg = OnlineUpdateConfig::default();
        let trainer = OnlineEmbeddingTrainer::new(cfg, 100);
        assert_eq!(trainer.step_count(), 0);
        assert_eq!(trainer.avg_loss(), 0.0);
    }

    #[test]
    fn test_trainer_update_increments_step() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        assert_eq!(trainer.step_count(), 1);
    }

    #[test]
    fn test_trainer_records_loss() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        assert!(!trainer.loss_history.is_empty());
        assert!(trainer.avg_loss().is_finite());
    }

    #[test]
    fn test_trainer_recent_loss_empty() {
        let cfg = OnlineUpdateConfig::default();
        let trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        assert_eq!(trainer.recent_loss(5), 0.0);
    }

    #[test]
    fn test_trainer_recent_loss_fewer_than_n() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        let rl = trainer.recent_loss(5);
        assert!(rl.is_finite());
    }

    #[test]
    fn test_trainer_modifies_embeddings() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let initial = vec![vec![1.0_f64; 8]; 10];
        let mut embs = initial.clone();
        trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        let changed = embs
            .iter()
            .zip(initial.iter())
            .any(|(a, b)| a.iter().zip(b.iter()).any(|(x, y)| (x - y).abs() > 1e-12));
        assert!(changed, "update_step should modify at least one embedding");
    }

    #[test]
    fn test_trainer_out_of_bounds_indices_ignored() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 5];
        trainer.update_step(&mut embs, (10, 20, 30), 1.0);
        assert_eq!(trainer.step_count(), 0);
    }

    #[test]
    fn test_trainer_multiple_steps() {
        let cfg = OnlineUpdateConfig::default();
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs: Vec<Vec<f64>> = vec![vec![0.1; 8]; 10];
        for i in 0..20 {
            let h = i % 5;
            let r = (i + 1) % 5;
            let t = (i + 2) % 5;
            trainer.update_step(&mut embs, (h, r, t), 1.0);
        }
        assert_eq!(trainer.step_count(), 20);
        assert!(trainer.avg_loss().is_finite());
    }

    #[test]
    fn test_trainer_positive_vs_negative_label() {
        let cfg = OnlineUpdateConfig::default();
        let mut t_pos = OnlineEmbeddingTrainer::new(cfg.clone(), 64);
        let mut t_neg = OnlineEmbeddingTrainer::new(cfg, 64);
        let mut embs_pos: Vec<Vec<f64>> = vec![vec![0.5; 8]; 10];
        let mut embs_neg = embs_pos.clone();
        for _ in 0..10 {
            t_pos.update_step(&mut embs_pos, (0, 1, 2), 1.0);
            t_neg.update_step(&mut embs_neg, (0, 1, 2), -1.0);
        }
        let diff_exists = embs_pos[0]
            .iter()
            .zip(embs_neg[0].iter())
            .any(|(a, b)| (a - b).abs() > 1e-9);
        assert!(
            diff_exists,
            "positive and negative training should produce different embeddings"
        );
    }

    #[test]
    fn test_adam_optimizer_lr_decay() {
        let cfg = OnlineUpdateConfig {
            decay: 0.5,
            learning_rate: 0.01,
            ..Default::default()
        };
        let mut trainer = OnlineEmbeddingTrainer::new(cfg, 32);
        let mut embs: Vec<Vec<f64>> = vec![vec![1.0; 8]; 10];
        for _ in 0..100 {
            trainer.update_step(&mut embs, (0, 1, 2), 1.0);
        }
        assert_eq!(trainer.step_count(), 100);
    }
}