use crate::error::{GnnError, Result};
use crate::search::cosine_similarity;
use ndarray::Array2;
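/// Supported optimizers and their hyperparameters.
///
/// `Sgd` performs (optionally momentum-accelerated) gradient descent;
/// `Adam` keeps running first/second moment estimates with bias correction.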
#[derive(Debug, Clone)]
pub enum OptimizerType {
Sgd {
learning_rate: f32,
momentum: f32,
},
Adam {
learning_rate: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
},
}
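/// Internal per-parameter optimizer state, allocated lazily on the first `step` call.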
#[derive(Debug)]
enum OptimizerState {
Sgd {
velocity: Option<Array2<f32>>,
},
Adam {
m: Option<Array2<f32>>,
v: Option<Array2<f32>>,
t: usize,
},
}
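/// In-place gradient-descent optimizer over `ndarray` parameter matrices.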
pub struct Optimizer {
optimizer_type: OptimizerType,
state: OptimizerState,
}
impl Optimizer {
pub fn new(optimizer_type: OptimizerType) -> Self {
let state = match &optimizer_type {
OptimizerType::Sgd { .. } => OptimizerState::Sgd { velocity: None },
OptimizerType::Adam { .. } => OptimizerState::Adam {
m: None,
v: None,
t: 0,
},
};
Self {
optimizer_type,
state,
}
}
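/// Applies a single update to `params` in place using `grads`.
///
/// Returns an error if the parameter and gradient shapes differ or if the
/// optimizer type and its internal state are inconsistent.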
pub fn step(&mut self, params: &mut Array2<f32>, grads: &Array2<f32>) -> Result<()> {
if params.shape() != grads.shape() {
return Err(GnnError::dimension_mismatch(
format!("{:?}", params.shape()),
format!("{:?}", grads.shape()),
));
}
match (&self.optimizer_type, &mut self.state) {
(
OptimizerType::Sgd {
learning_rate,
momentum,
},
OptimizerState::Sgd { velocity },
) => Self::sgd_step_with_momentum(params, grads, *learning_rate, *momentum, velocity),
(
OptimizerType::Adam {
learning_rate,
beta1,
beta2,
epsilon,
},
OptimizerState::Adam { m, v, t },
) => Self::adam_step(
params,
grads,
*learning_rate,
*beta1,
*beta2,
*epsilon,
m,
v,
t,
),
_ => Err(GnnError::invalid_input("Optimizer type and state mismatch")),
}
}
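/// SGD update. With momentum mu: v <- mu * v + lr * g, params <- params - v.
/// With mu == 0 this reduces to params <- params - lr * g.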
fn sgd_step_with_momentum(
params: &mut Array2<f32>,
grads: &Array2<f32>,
learning_rate: f32,
momentum: f32,
velocity: &mut Option<Array2<f32>>,
) -> Result<()> {
if momentum == 0.0 {
*params -= &(grads * learning_rate);
} else {
if velocity.is_none() {
*velocity = Some(Array2::zeros(params.dim()));
}
if let Some(v) = velocity {
let new_velocity = v.mapv(|x| x * momentum) + grads * learning_rate;
*v = new_velocity;
*params -= &*v;
}
}
Ok(())
}
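/// Adam update:
/// m <- b1 * m + (1 - b1) * g, v <- b2 * v + (1 - b2) * g^2,
/// m_hat = m / (1 - b1^t), v_hat = v / (1 - b2^t),
/// params <- params - lr * m_hat / (sqrt(v_hat) + eps).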
#[allow(clippy::too_many_arguments)]
fn adam_step(
params: &mut Array2<f32>,
grads: &Array2<f32>,
learning_rate: f32,
beta1: f32,
beta2: f32,
epsilon: f32,
m: &mut Option<Array2<f32>>,
v: &mut Option<Array2<f32>>,
t: &mut usize,
) -> Result<()> {
if m.is_none() {
*m = Some(Array2::zeros(params.dim()));
}
if v.is_none() {
*v = Some(Array2::zeros(params.dim()));
}
*t += 1;
if let (Some(m_buf), Some(v_buf)) = (m, v) {
let new_m = m_buf.mapv(|x| x * beta1) + grads * (1.0 - beta1);
*m_buf = new_m;
let grads_squared = grads.mapv(|x| x * x);
let new_v = v_buf.mapv(|x| x * beta2) + grads_squared * (1.0 - beta2);
*v_buf = new_v;
let bias_correction1 = 1.0 - beta1.powi(*t as i32);
let m_hat = m_buf.mapv(|x| x / bias_correction1);
let bias_correction2 = 1.0 - beta2.powi(*t as i32);
let v_hat = v_buf.mapv(|x| x / bias_correction2);
let update = m_hat
.iter()
.zip(v_hat.iter())
.map(|(&m_val, &v_val)| learning_rate * m_val / (v_val.sqrt() + epsilon));
for (param, upd) in params.iter_mut().zip(update) {
*param -= upd;
}
}
Ok(())
}
}
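/// Loss functions available for training.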
#[derive(Debug, Clone, Copy)]
pub enum LossType {
Mse,
CrossEntropy,
BinaryCrossEntropy,
}
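/// Namespace for loss values and their gradients.
///
/// `EPS` guards logarithms and divisions against zero inputs, and `MAX_GRAD`
/// clamps gradients so updates stay finite.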
pub struct Loss;
impl Loss {
const EPS: f32 = 1e-7;
const MAX_GRAD: f32 = 1e6;
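/// Computes the scalar loss for `predictions` against `targets`.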
pub fn compute(
loss_type: LossType,
predictions: &Array2<f32>,
targets: &Array2<f32>,
) -> Result<f32> {
if predictions.shape() != targets.shape() {
return Err(GnnError::dimension_mismatch(
format!("{:?}", predictions.shape()),
format!("{:?}", targets.shape()),
));
}
if predictions.is_empty() {
return Err(GnnError::invalid_input(
"Cannot compute loss on empty arrays",
));
}
match loss_type {
LossType::Mse => Self::mse_forward(predictions, targets),
LossType::CrossEntropy => Self::cross_entropy_forward(predictions, targets),
LossType::BinaryCrossEntropy => Self::bce_forward(predictions, targets),
}
}
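/// Computes the gradient of the loss with respect to `predictions`.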
pub fn gradient(
loss_type: LossType,
predictions: &Array2<f32>,
targets: &Array2<f32>,
) -> Result<Array2<f32>> {
if predictions.shape() != targets.shape() {
return Err(GnnError::dimension_mismatch(
format!("{:?}", predictions.shape()),
format!("{:?}", targets.shape()),
));
}
if predictions.is_empty() {
return Err(GnnError::invalid_input(
"Cannot compute gradient on empty arrays",
));
}
match loss_type {
LossType::Mse => Self::mse_backward(predictions, targets),
LossType::CrossEntropy => Self::cross_entropy_backward(predictions, targets),
LossType::BinaryCrossEntropy => Self::bce_backward(predictions, targets),
}
}
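/// MSE: mean over all elements of (p - t)^2.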
fn mse_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
let diff = predictions - targets;
let squared = diff.mapv(|x| x * x);
Ok(squared.mean().unwrap_or(0.0))
}
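/// MSE gradient: 2 * (p - t) / N, where N is the total element count.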
fn mse_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
let n = predictions.len() as f32;
let diff = predictions - targets;
Ok(diff.mapv(|x| 2.0 * x / n))
}
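/// Cross-entropy for probability inputs: -sum(t * ln(max(p, EPS))) / batch_size.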
fn cross_entropy_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
let log_pred = predictions.mapv(|x| (x.max(Self::EPS)).ln());
let elementwise = targets * &log_pred;
let loss = -elementwise.sum() / predictions.nrows() as f32;
Ok(loss)
}
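/// Cross-entropy gradient: -t / max(p, EPS) / batch_size, clamped to +/-MAX_GRAD.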
fn cross_entropy_backward(
predictions: &Array2<f32>,
targets: &Array2<f32>,
) -> Result<Array2<f32>> {
let n = predictions.nrows() as f32;
let safe_pred = predictions.mapv(|x| x.max(Self::EPS));
let grad = targets / &safe_pred;
Ok(grad.mapv(|x| (-x / n).clamp(-Self::MAX_GRAD, Self::MAX_GRAD)))
}
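/// Binary cross-entropy, averaged over all elements, with predictions clamped
/// to [EPS, 1 - EPS] for numerical stability.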
fn bce_forward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<f32> {
let n = predictions.len() as f32;
let loss: f32 = predictions
.iter()
.zip(targets.iter())
.map(|(&p, &t)| {
let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
-(t * p_safe.ln() + (1.0 - t) * (1.0 - p_safe).ln())
})
.sum();
Ok(loss / n)
}
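/// BCE gradient: (-t / p + (1 - t) / (1 - p)) / N, clamped to +/-MAX_GRAD.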
fn bce_backward(predictions: &Array2<f32>, targets: &Array2<f32>) -> Result<Array2<f32>> {
let n = predictions.len() as f32;
let grad_vec: Vec<f32> = predictions
.iter()
.zip(targets.iter())
.map(|(&p, &t)| {
let p_safe = p.clamp(Self::EPS, 1.0 - Self::EPS);
let grad = (-t / p_safe + (1.0 - t) / (1.0 - p_safe)) / n;
grad.clamp(-Self::MAX_GRAD, Self::MAX_GRAD)
})
.collect();
Array2::from_shape_vec(predictions.dim(), grad_vec)
.map_err(|e| GnnError::training(format!("Failed to reshape gradient: {}", e)))
}
}
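/// Configuration for supervised training with `Optimizer` and `Loss`.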
#[derive(Debug, Clone)]
pub struct TrainingConfig {
pub epochs: usize,
pub batch_size: usize,
pub learning_rate: f32,
pub loss_type: LossType,
pub optimizer_type: OptimizerType,
}
impl Default for TrainingConfig {
fn default() -> Self {
Self {
epochs: 100,
batch_size: 32,
learning_rate: 0.001,
loss_type: LossType::Mse,
optimizer_type: OptimizerType::Adam {
learning_rate: 0.001,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
},
}
}
}
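/// Configuration for contrastive (InfoNCE) batch training: batch size,
/// number of negatives per anchor, softmax temperature, learning rate, and
/// a flush threshold for buffered updates.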
#[derive(Debug, Clone)]
pub struct TrainConfig {
pub batch_size: usize,
pub n_negatives: usize,
pub temperature: f32,
pub learning_rate: f32,
pub flush_threshold: usize,
}
impl Default for TrainConfig {
fn default() -> Self {
Self {
batch_size: 256,
n_negatives: 64,
temperature: 0.07,
learning_rate: 0.001,
flush_threshold: 1000,
}
}
}
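/// Configuration for online/incremental updates: how many local gradient
/// steps to take and whether the resulting updates are propagated.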
#[derive(Debug, Clone)]
pub struct OnlineConfig {
pub local_steps: usize,
pub propagate_updates: bool,
}
impl Default for OnlineConfig {
fn default() -> Self {
Self {
local_steps: 5,
propagate_updates: true,
}
}
}
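/// InfoNCE contrastive loss with cosine similarity.
///
/// For each positive, computes -log(exp(s_pos / T) / (exp(s_pos / T) +
/// sum_j exp(s_neg_j / T))) using the log-sum-exp trick, then averages over
/// positives. Returns 0.0 when there are no positives.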
pub fn info_nce_loss(
anchor: &[f32],
positives: &[&[f32]],
negatives: &[&[f32]],
temperature: f32,
) -> f32 {
if positives.is_empty() {
return 0.0;
}
let pos_sims: Vec<f32> = positives
.iter()
.map(|pos| cosine_similarity(anchor, pos) / temperature)
.collect();
let neg_sims: Vec<f32> = negatives
.iter()
.map(|neg| cosine_similarity(anchor, neg) / temperature)
.collect();
let mut total_loss = 0.0;
for &pos_sim in &pos_sims {
let mut all_logits = vec![pos_sim];
all_logits.extend(&neg_sims);
let max_logit = all_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let log_sum_exp = max_logit
+ all_logits
.iter()
.map(|&x| (x - max_logit).exp())
.sum::<f32>()
.ln();
total_loss -= pos_sim - log_sum_exp;
}
total_loss / positives.len() as f32
}
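/// InfoNCE loss for a node against its graph neighborhood: neighbors are
/// treated as positives and non-neighbors as negatives. Returns 0.0 when the
/// node has no neighbors.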
pub fn local_contrastive_loss(
node_embedding: &[f32],
neighbor_embeddings: &[Vec<f32>],
non_neighbor_embeddings: &[Vec<f32>],
temperature: f32,
) -> f32 {
if neighbor_embeddings.is_empty() {
return 0.0;
}
let positives: Vec<&[f32]> = neighbor_embeddings.iter().map(|v| v.as_slice()).collect();
let negatives: Vec<&[f32]> = non_neighbor_embeddings
.iter()
.map(|v| v.as_slice())
.collect();
info_nce_loss(node_embedding, &positives, &negatives, temperature)
}
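/// Plain SGD update on a flat embedding slice: e <- e - lr * g.
///
/// Panics if `embedding` and `grad` have different lengths.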
pub fn sgd_step(embedding: &mut [f32], grad: &[f32], learning_rate: f32) {
assert_eq!(
embedding.len(),
grad.len(),
"Embedding and gradient must have the same length"
);
for (emb, &g) in embedding.iter_mut().zip(grad.iter()) {
*emb -= learning_rate * g;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_train_config_default() {
let config = TrainConfig::default();
assert_eq!(config.batch_size, 256);
assert_eq!(config.n_negatives, 64);
assert_eq!(config.temperature, 0.07);
assert_eq!(config.learning_rate, 0.001);
assert_eq!(config.flush_threshold, 1000);
}
#[test]
fn test_online_config_default() {
let config = OnlineConfig::default();
assert_eq!(config.local_steps, 5);
assert!(config.propagate_updates);
}
#[test]
fn test_info_nce_loss_basic() {
let anchor = vec![1.0, 0.0, 0.0];
let positive = vec![0.9, 0.1, 0.0];
let negative1 = vec![0.0, 1.0, 0.0];
let negative2 = vec![0.0, 0.0, 1.0];
let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07);
assert!(loss > 0.0);
assert!(loss.is_finite());
}
#[test]
fn test_info_nce_loss_perfect_match() {
let anchor = vec![1.0, 0.0, 0.0];
let positive = vec![1.0, 0.0, 0.0];
let negative1 = vec![0.0, 1.0, 0.0];
let negative2 = vec![0.0, 0.0, 1.0];
let loss = info_nce_loss(&anchor, &[&positive], &[&negative1, &negative2], 0.07);
assert!(loss < 1.0);
assert!(loss.is_finite());
}
#[test]
fn test_info_nce_loss_no_positives() {
let anchor = vec![1.0, 0.0, 0.0];
let negative1 = vec![0.0, 1.0, 0.0];
let loss = info_nce_loss(&anchor, &[], &[&negative1], 0.07);
assert_eq!(loss, 0.0);
}
#[test]
fn test_info_nce_loss_temperature_effect() {
let anchor = vec![1.0, 0.0, 0.0];
let positive = vec![0.9, 0.1, 0.0];
let negative = vec![0.0, 1.0, 0.0];
let loss_low_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 0.07);
let loss_high_temp = info_nce_loss(&anchor, &[&positive], &[&negative], 1.0);
assert!(
loss_low_temp > 0.0 && loss_low_temp.is_finite(),
"Low temp loss should be positive and finite, got: {}",
loss_low_temp
);
assert!(
loss_high_temp > 0.0 && loss_high_temp.is_finite(),
"High temp loss should be positive and finite, got: {}",
loss_high_temp
);
assert!(loss_low_temp < 10.0, "Loss should not be too large");
assert!(loss_high_temp < 10.0, "Loss should not be too large");
}
#[test]
fn test_local_contrastive_loss_basic() {
let node = vec![1.0, 0.0, 0.0];
let neighbor = vec![0.9, 0.1, 0.0];
let non_neighbor1 = vec![0.0, 1.0, 0.0];
let non_neighbor2 = vec![0.0, 0.0, 1.0];
let loss =
local_contrastive_loss(&node, &[neighbor], &[non_neighbor1, non_neighbor2], 0.07);
assert!(loss > 0.0);
assert!(loss.is_finite());
}
#[test]
fn test_local_contrastive_loss_multiple_neighbors() {
let node = vec![1.0, 0.0, 0.0];
let neighbor1 = vec![0.9, 0.1, 0.0];
let neighbor2 = vec![0.95, 0.05, 0.0];
let non_neighbor = vec![0.0, 1.0, 0.0];
let loss = local_contrastive_loss(&node, &[neighbor1, neighbor2], &[non_neighbor], 0.07);
assert!(loss > 0.0);
assert!(loss.is_finite());
}
#[test]
fn test_local_contrastive_loss_no_neighbors() {
let node = vec![1.0, 0.0, 0.0];
let non_neighbor = vec![0.0, 1.0, 0.0];
let loss = local_contrastive_loss(&node, &[], &[non_neighbor], 0.07);
assert_eq!(loss, 0.0);
}
#[test]
fn test_sgd_step_basic() {
let mut embedding = vec![1.0, 2.0, 3.0];
let gradient = vec![0.1, -0.2, 0.3];
let learning_rate = 0.01;
sgd_step(&mut embedding, &gradient, learning_rate);
assert!((embedding[0] - 0.999).abs() < 1e-6);
assert!((embedding[1] - 2.002).abs() < 1e-6);
assert!((embedding[2] - 2.997).abs() < 1e-6);
}
#[test]
fn test_sgd_step_zero_gradient() {
let mut embedding = vec![1.0, 2.0, 3.0];
let original = embedding.clone();
let gradient = vec![0.0, 0.0, 0.0];
let learning_rate = 0.01;
sgd_step(&mut embedding, &gradient, learning_rate);
assert_eq!(embedding, original);
}
#[test]
fn test_sgd_step_zero_learning_rate() {
let mut embedding = vec![1.0, 2.0, 3.0];
let original = embedding.clone();
let gradient = vec![0.1, 0.2, 0.3];
let learning_rate = 0.0;
sgd_step(&mut embedding, &gradient, learning_rate);
assert_eq!(embedding, original);
}
#[test]
fn test_sgd_step_large_learning_rate() {
let mut embedding = vec![10.0, 20.0, 30.0];
let gradient = vec![1.0, 2.0, 3.0];
let learning_rate = 5.0;
sgd_step(&mut embedding, &gradient, learning_rate);
assert!((embedding[0] - 5.0).abs() < 1e-5);
assert!((embedding[1] - 10.0).abs() < 1e-5);
assert!((embedding[2] - 15.0).abs() < 1e-5);
}
#[test]
#[should_panic(expected = "Embedding and gradient must have the same length")]
fn test_sgd_step_mismatched_lengths() {
let mut embedding = vec![1.0, 2.0, 3.0];
let gradient = vec![0.1, 0.2];
sgd_step(&mut embedding, &gradient, 0.01);
}
#[test]
fn test_info_nce_loss_multiple_positives() {
let anchor = vec![1.0, 0.0, 0.0];
let positive1 = vec![0.9, 0.1, 0.0];
let positive2 = vec![0.95, 0.05, 0.0];
let negative = vec![0.0, 1.0, 0.0];
let loss = info_nce_loss(&anchor, &[&positive1, &positive2], &[&negative], 0.07);
assert!(loss > 0.0);
assert!(loss.is_finite());
}
#[test]
fn test_contrastive_loss_gradient_property() {
let anchor = vec![1.0, 0.0, 0.0];
let positive_far = vec![0.5, 0.5, 0.0];
let positive_close = vec![0.9, 0.1, 0.0];
let negative = vec![0.0, 1.0, 0.0];
let loss_far = info_nce_loss(&anchor, &[&positive_far], &[&negative], 0.07);
let loss_close = info_nce_loss(&anchor, &[&positive_close], &[&negative], 0.07);
assert!(loss_close < loss_far);
}
#[test]
fn test_sgd_optimizer_basic() {
let optimizer_type = OptimizerType::Sgd {
learning_rate: 0.1,
momentum: 0.0,
};
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
let result = optimizer.step(&mut params, &grads);
assert!(result.is_ok());
assert!((params[[0, 0]] - 0.99).abs() < 1e-6);
assert!((params[[0, 1]] - 1.98).abs() < 1e-6);
assert!((params[[1, 0]] - 2.97).abs() < 1e-6);
assert!((params[[1, 1]] - 3.96).abs() < 1e-6);
}
#[test]
fn test_sgd_optimizer_with_momentum() {
let optimizer_type = OptimizerType::Sgd {
learning_rate: 0.1,
momentum: 0.9,
};
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
let result = optimizer.step(&mut params, &grads);
assert!(result.is_ok());
assert!((params[[0, 0]] - 0.99).abs() < 1e-6);
let result = optimizer.step(&mut params, &grads);
assert!(result.is_ok());
assert!(params[[0, 0]] < 0.99);
}
#[test]
fn test_adam_optimizer_basic() {
let optimizer_type = OptimizerType::Adam {
learning_rate: 0.001,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
};
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
let original_params = params.clone();
let result = optimizer.step(&mut params, &grads);
assert!(result.is_ok());
assert!(params[[0, 0]] < original_params[[0, 0]]);
assert!(params[[0, 1]] < original_params[[0, 1]]);
assert!(params[[1, 0]] < original_params[[1, 0]]);
assert!(params[[1, 1]] < original_params[[1, 1]]);
assert!(params.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_adam_optimizer_multiple_steps() {
let optimizer_type = OptimizerType::Adam {
learning_rate: 0.01,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
};
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let grads = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
let initial_params = params.clone();
for _ in 0..10 {
let result = optimizer.step(&mut params, &grads);
assert!(result.is_ok());
assert!(params.iter().all(|&x| x.is_finite()));
}
assert!(params[[0, 0]] < initial_params[[0, 0]]);
assert!(params[[1, 1]] < initial_params[[1, 1]]);
for i in 0..2 {
for j in 0..2 {
assert!(params[[i, j]] < initial_params[[i, j]]);
}
}
}
#[test]
fn test_adam_bias_correction() {
let optimizer_type = OptimizerType::Adam {
learning_rate: 0.001,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
};
let mut optimizer = Optimizer::new(optimizer_type.clone());
let mut params = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
let grads = Array2::from_shape_vec((1, 1), vec![0.1]).unwrap();
let result = optimizer.step(&mut params, &grads);
assert!(result.is_ok());
// With bias correction, the very first Adam update should be close to the
// learning rate, since m_hat / sqrt(v_hat) is ~1 for a constant gradient.
let first_update = 1.0 - params[[0, 0]];
assert!(
(first_update - 0.001).abs() < 1e-4,
"First bias-corrected update should be close to lr, got {}",
first_update
);
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
for _ in 0..100 {
let _ = optimizer.step(&mut params, &grads);
}
assert!(params[[0, 0]].is_finite());
assert!(params[[0, 0]] < 1.0);
}
#[test]
fn test_optimizer_shape_mismatch() {
let optimizer_type = OptimizerType::Adam {
learning_rate: 0.001,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
};
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let grads = Array2::from_shape_vec((3, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
let result = optimizer.step(&mut params, &grads);
assert!(result.is_err());
if let Err(GnnError::DimensionMismatch { expected, actual }) = result {
assert!(expected.contains("2, 2"));
assert!(actual.contains("3, 2"));
} else {
panic!("Expected DimensionMismatch error");
}
}
#[test]
fn test_adam_convergence() {
let optimizer_type = OptimizerType::Adam {
learning_rate: 0.5,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
};
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
for _ in 0..200 {
let grads =
Array2::from_shape_vec((1, 2), vec![2.0 * params[[0, 0]], 2.0 * params[[0, 1]]])
.unwrap();
let _ = optimizer.step(&mut params, &grads);
}
assert!(params[[0, 0]].abs() < 0.5);
assert!(params[[0, 1]].abs() < 0.5);
}
#[test]
fn test_sgd_momentum_convergence() {
let optimizer_type = OptimizerType::Sgd {
learning_rate: 0.01,
momentum: 0.9,
};
let mut optimizer = Optimizer::new(optimizer_type);
let mut params = Array2::from_shape_vec((1, 2), vec![5.0, 5.0]).unwrap();
for _ in 0..200 {
let grads =
Array2::from_shape_vec((1, 2), vec![2.0 * params[[0, 0]], 2.0 * params[[0, 1]]])
.unwrap();
let _ = optimizer.step(&mut params, &grads);
}
assert!(params[[0, 0]].abs() < 0.5);
assert!(params[[0, 1]].abs() < 0.5);
}
#[test]
fn test_mse_loss_zero_when_equal() {
let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let target = pred.clone();
let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
assert!(
(loss - 0.0).abs() < 1e-6,
"MSE should be 0 when pred == target"
);
}
#[test]
fn test_mse_loss_positive() {
let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let target = Array2::from_shape_vec((2, 2), vec![2.0, 3.0, 4.0, 5.0]).unwrap();
let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
assert!((loss - 1.0).abs() < 1e-6, "MSE should be 1.0, got {}", loss);
}
#[test]
fn test_mse_loss_varying_diffs() {
let pred = Array2::from_shape_vec((1, 4), vec![0.0, 0.0, 0.0, 0.0]).unwrap();
let target = Array2::from_shape_vec((1, 4), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
assert!((loss - 7.5).abs() < 1e-6, "MSE should be 7.5, got {}", loss);
}
#[test]
fn test_mse_gradient_shape() {
let pred = Array2::from_shape_vec((2, 3), vec![0.0; 6]).unwrap();
let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
assert_eq!(grad.shape(), pred.shape());
}
#[test]
fn test_mse_gradient_direction() {
let pred = Array2::from_shape_vec((1, 2), vec![0.0, 2.0]).unwrap();
let target = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
assert!(
grad[[0, 0]] < 0.0,
"Gradient should be negative when pred < target"
);
assert!(
grad[[0, 1]] > 0.0,
"Gradient should be positive when pred > target"
);
}
#[test]
fn test_mse_gradient_zero_when_equal() {
let pred = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let target = pred.clone();
let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
assert!(
grad.iter().all(|&x| x.abs() < 1e-6),
"Gradient should be zero when pred == target"
);
}
#[test]
fn test_bce_loss_perfect_predictions() {
let pred = Array2::from_shape_vec((1, 2), vec![0.999, 0.001]).unwrap();
let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
assert!(
loss < 0.1,
"BCE should be low for good predictions, got {}",
loss
);
}
#[test]
fn test_bce_loss_bad_predictions() {
let pred = Array2::from_shape_vec((1, 2), vec![0.001, 0.999]).unwrap();
let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
assert!(
loss > 1.0,
"BCE should be high for bad predictions, got {}",
loss
);
}
#[test]
fn test_bce_loss_numerical_stability() {
let pred = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
let target = Array2::from_shape_vec((1, 2), vec![0.0, 1.0]).unwrap();
let loss = Loss::compute(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
assert!(
loss.is_finite(),
"BCE should be finite even with extreme values"
);
}
#[test]
fn test_bce_gradient_shape() {
let pred = Array2::from_shape_vec((3, 2), vec![0.5; 6]).unwrap();
let target = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap();
let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
assert_eq!(grad.shape(), pred.shape());
}
#[test]
fn test_bce_gradient_direction() {
let pred = Array2::from_shape_vec((1, 2), vec![0.3, 0.7]).unwrap();
let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
let grad = Loss::gradient(LossType::BinaryCrossEntropy, &pred, &target).unwrap();
assert!(
grad[[0, 0]] < 0.0,
"Gradient should be negative to increase pred towards 1"
);
assert!(
grad[[0, 1]] > 0.0,
"Gradient should be positive to decrease pred towards 0"
);
}
#[test]
fn test_cross_entropy_one_hot() {
let pred = Array2::from_shape_vec((2, 3), vec![0.7, 0.2, 0.1, 0.1, 0.8, 0.1]).unwrap();
let target = Array2::from_shape_vec((2, 3), vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).unwrap();
let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
assert!(
loss > 0.0 && loss < 1.0,
"CE should be reasonable for good predictions, got {}",
loss
);
}
#[test]
fn test_cross_entropy_wrong_class() {
let pred = Array2::from_shape_vec((1, 3), vec![0.1, 0.1, 0.8]).unwrap();
let target = Array2::from_shape_vec((1, 3), vec![1.0, 0.0, 0.0]).unwrap();
let loss = Loss::compute(LossType::CrossEntropy, &pred, &target).unwrap();
assert!(
loss > 1.0,
"CE should be high for wrong predictions, got {}",
loss
);
}
#[test]
fn test_cross_entropy_gradient_shape() {
let pred = Array2::from_shape_vec((2, 4), vec![0.25; 8]).unwrap();
let target =
Array2::from_shape_vec((2, 4), vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
let grad = Loss::gradient(LossType::CrossEntropy, &pred, &target).unwrap();
assert_eq!(grad.shape(), pred.shape());
}
#[test]
fn test_loss_dimension_mismatch_error() {
let pred = Array2::from_shape_vec((2, 2), vec![1.0; 4]).unwrap();
let target = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
let result = Loss::compute(LossType::Mse, &pred, &target);
assert!(result.is_err(), "Should error on dimension mismatch");
let result = Loss::gradient(LossType::Mse, &pred, &target);
assert!(
result.is_err(),
"Gradient should error on dimension mismatch"
);
}
#[test]
fn test_loss_empty_array_error() {
let pred = Array2::from_shape_vec((0, 2), vec![]).unwrap();
let target = Array2::from_shape_vec((0, 2), vec![]).unwrap();
let result = Loss::compute(LossType::Mse, &pred, &target);
assert!(result.is_err(), "Should error on empty arrays");
let result = Loss::gradient(LossType::Mse, &pred, &target);
assert!(result.is_err(), "Gradient should error on empty arrays");
}
#[test]
fn test_loss_gradient_numerical_check() {
let pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.8]).unwrap();
let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
let analytical_grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
let eps = 1e-5;
for i in 0..2 {
let mut pred_plus = pred.clone();
let mut pred_minus = pred.clone();
pred_plus[[0, i]] += eps;
pred_minus[[0, i]] -= eps;
let loss_plus = Loss::compute(LossType::Mse, &pred_plus, &target).unwrap();
let loss_minus = Loss::compute(LossType::Mse, &pred_minus, &target).unwrap();
let numerical_grad = (loss_plus - loss_minus) / (2.0 * eps);
let error = (analytical_grad[[0, i]] - numerical_grad).abs();
assert!(
error < 1e-3,
"Numerical gradient check failed: analytical={}, numerical={}",
analytical_grad[[0, i]],
numerical_grad
);
}
}
#[test]
fn test_training_loop_integration() {
let mut optimizer = Optimizer::new(OptimizerType::Sgd {
learning_rate: 0.1,
momentum: 0.0,
});
let target = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).unwrap();
let mut pred = Array2::from_shape_vec((1, 2), vec![0.5, 0.5]).unwrap();
let initial_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
for _ in 0..10 {
let grad = Loss::gradient(LossType::Mse, &pred, &target).unwrap();
optimizer.step(&mut pred, &grad).unwrap();
}
let final_loss = Loss::compute(LossType::Mse, &pred, &target).unwrap();
assert!(
final_loss < initial_loss,
"Loss should decrease during training"
);
}
}