#![allow(non_snake_case)]
use super::super::*;
#[test]
fn test_safe_cholesky_solve_positive_definite() {
    // A 2x2 symmetric positive-definite system: the solver must return an x
    // with A*x close to b.
    let matrix = Matrix::from_vec(2, 2, vec![4.0, 2.0, 2.0, 3.0]).expect("valid dimensions");
    let rhs = Vector::from_slice(&[6.0, 5.0]);
    let solution = safe_cholesky_solve(&matrix, &rhs, 1e-8, 10).expect("should solve");
    assert_eq!(solution.len(), 2);
    // Recompute A*x by hand and check the residual component-wise.
    for row in 0..2 {
        let ax_row = matrix.get(row, 0) * solution[0] + matrix.get(row, 1) * solution[1];
        assert!((ax_row - rhs[row]).abs() < 1e-5);
    }
}
#[test]
fn test_safe_cholesky_solve_identity() {
    // Solving I*x = b must return b itself.
    let identity = Matrix::eye(3);
    let rhs = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let solution = safe_cholesky_solve(&identity, &rhs, 1e-8, 10).expect("should solve");
    for (i, expected) in [1.0, 2.0, 3.0].iter().copied().enumerate() {
        assert!((solution[i] - expected).abs() < 1e-6);
    }
}
#[test]
fn test_safe_cholesky_solve_ill_conditioned() {
    // Nearly singular diagonal (condition number ~1e10): regularization should
    // still produce a solution whose well-conditioned component is accurate.
    let diag = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1e-10]).expect("valid dimensions");
    let rhs = Vector::from_slice(&[1.0, 1.0]);
    let solution =
        safe_cholesky_solve(&diag, &rhs, 1e-8, 10).expect("should solve with regularization");
    assert_eq!(solution.len(), 2);
    // Only the well-conditioned component is checked; the tiny-pivot component
    // is dominated by the regularization term.
    assert!((solution[0] - 1.0).abs() < 1e-3);
}
#[test]
fn test_safe_cholesky_solve_not_positive_definite() {
    // An indefinite matrix (one negative eigenvalue): regularization may or
    // may not rescue the solve, so both Ok and Err are acceptable outcomes.
    // But when a solution IS returned, it must be well-formed and finite.
    // (The original had a dead empty `else {}` branch; removed.)
    let A = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, -0.5]).expect("valid dimensions");
    let b = Vector::from_slice(&[1.0, 1.0]);
    if let Ok(x) = safe_cholesky_solve(&A, &b, 1e-4, 10) {
        assert_eq!(x.len(), 2);
        assert!(x[0].is_finite());
        assert!(x[1].is_finite());
    }
}
#[test]
fn test_safe_cholesky_solve_zero_matrix() {
    // The all-zeros matrix is singular; with regularization the solver is
    // still expected to succeed rather than return an error.
    let zeros = Matrix::from_vec(2, 2, vec![0.0; 4]).expect("valid dimensions");
    let rhs = Vector::from_slice(&[1.0, 1.0]);
    assert!(safe_cholesky_solve(&zeros, &rhs, 1e-4, 10).is_ok());
}
#[test]
fn test_safe_cholesky_solve_small_initial_lambda() {
    // A tiny starting lambda must not perturb an already well-conditioned solve.
    let identity = Matrix::eye(2);
    let rhs = Vector::from_slice(&[1.0, 1.0]);
    let solution = safe_cholesky_solve(&identity, &rhs, 1e-12, 10).expect("should solve");
    for component in 0..2 {
        assert!((solution[component] - 1.0).abs() < 1e-6);
    }
}
#[test]
fn test_safe_cholesky_solve_max_attempts() {
    // The identity is positive definite, so a single attempt suffices even
    // with max_attempts capped at 1.
    let identity = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]).expect("valid dimensions");
    let rhs = Vector::from_slice(&[1.0, 1.0]);
    let solution = safe_cholesky_solve(&identity, &rhs, 1e-8, 1).expect("should solve");
    assert_eq!(solution.len(), 2);
}
#[test]
fn test_safe_cholesky_solve_large_system() {
let n = 5;
let mut data = vec![0.0; n * n];
for i in 0..n {
data[i * n + i] = 2.0; if i > 0 {
data[i * n + (i - 1)] = 1.0; data[(i - 1) * n + i] = 1.0; }
}
let A = Matrix::from_vec(n, n, data).expect("valid dimensions");
let b = Vector::from_slice(&[1.0, 1.0, 1.0, 1.0, 1.0]);
let x = safe_cholesky_solve(&A, &b, 1e-8, 10).expect("should solve");
assert_eq!(x.len(), 5);
}
#[test]
fn test_safe_cholesky_solve_symmetric() {
    // 3x3 symmetric tridiagonal PD matrix; confirm a solution of the right
    // size comes back.
    let data = vec![2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 0.0, 1.0, 2.0];
    let tridiag = Matrix::from_vec(3, 3, data).expect("valid dimensions");
    let rhs = Vector::from_slice(&[1.0, 2.0, 1.0]);
    let solution = safe_cholesky_solve(&tridiag, &rhs, 1e-8, 10).expect("should solve");
    assert_eq!(solution.len(), 3);
}
#[test]
fn test_safe_cholesky_solve_lambda_escalation() {
    // Near-singular correlation matrix with a tiny starting lambda: the solver
    // may need several lambda escalations, but must finish with finite values.
    let near_singular =
        Matrix::from_vec(2, 2, vec![1.0, 0.999, 0.999, 1.0]).expect("valid dimensions");
    let rhs = Vector::from_slice(&[1.0, 1.0]);
    let solution = safe_cholesky_solve(&near_singular, &rhs, 1e-10, 15).expect("should solve");
    assert_eq!(solution.len(), 2);
    assert!(solution[0].is_finite());
    assert!(solution[1].is_finite());
}
#[test]
fn test_sgd_new() {
    // A fresh SGD carries the given learning rate and no momentum.
    let opt = SGD::new(0.01);
    assert!((opt.learning_rate() - 0.01).abs() < 1e-6);
    assert!((opt.momentum() - 0.0).abs() < 1e-6);
    assert!(!opt.has_momentum());
}
#[test]
fn test_sgd_with_momentum() {
    // with_momentum sets the coefficient and flips has_momentum on.
    let opt = SGD::new(0.01).with_momentum(0.9);
    assert!((opt.learning_rate() - 0.01).abs() < 1e-6);
    assert!((opt.momentum() - 0.9).abs() < 1e-6);
    assert!(opt.has_momentum());
}
#[test]
fn test_sgd_step_basic() {
    // Vanilla SGD: each parameter moves by -lr * gradient.
    let mut opt = SGD::new(0.1);
    let mut weights = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let grads = Vector::from_slice(&[1.0, 2.0, 3.0]);
    opt.step(&mut weights, &grads);
    for (idx, want) in [0.9, 1.8, 2.7].iter().copied().enumerate() {
        assert!((weights[idx] - want).abs() < 1e-6);
    }
}
#[test]
fn test_sgd_step_with_momentum() {
    // Expected values are consistent with v = momentum*v + lr*g:
    // step 1: v = 0.1           -> p = 0.9
    // step 2: v = 0.09 + 0.1    -> p = 0.71
    let mut opt = SGD::new(0.1).with_momentum(0.9);
    let mut weights = Vector::from_slice(&[1.0, 1.0]);
    let grads = Vector::from_slice(&[1.0, 1.0]);
    opt.step(&mut weights, &grads);
    assert!((weights[0] - 0.9).abs() < 1e-6);
    opt.step(&mut weights, &grads);
    assert!((weights[0] - 0.71).abs() < 1e-6);
}
#[test]
fn test_sgd_momentum_accumulation() {
    // With a constant gradient, momentum makes each step at least as large as
    // the previous one (monotone non-decreasing step size).
    let mut opt = SGD::new(0.1).with_momentum(0.9);
    let mut weights = Vector::from_slice(&[0.0]);
    let grads = Vector::from_slice(&[1.0]);
    let mut last_step = 0.0;
    for _ in 0..10 {
        let before = weights[0];
        opt.step(&mut weights, &grads);
        let this_step = before - weights[0];
        assert!(this_step >= last_step);
        last_step = this_step;
    }
}
#[test]
fn test_sgd_reset() {
    // reset() clears accumulated velocity: the next step on fresh parameters
    // matches a first-ever step (no momentum carried over).
    let mut opt = SGD::new(0.1).with_momentum(0.9);
    let mut weights = Vector::from_slice(&[1.0]);
    let grads = Vector::from_slice(&[1.0]);
    opt.step(&mut weights, &grads);
    opt.reset();
    let mut fresh = Vector::from_slice(&[1.0]);
    opt.step(&mut fresh, &grads);
    assert!((fresh[0] - 0.9).abs() < 1e-6);
}
#[test]
fn test_sgd_zero_gradient() {
    // A zero gradient must leave parameters untouched.
    let mut opt = SGD::new(0.1);
    let mut weights = Vector::from_slice(&[1.0, 2.0]);
    let grads = Vector::from_slice(&[0.0, 0.0]);
    opt.step(&mut weights, &grads);
    assert!((weights[0] - 1.0).abs() < 1e-6);
    assert!((weights[1] - 2.0).abs() < 1e-6);
}
#[test]
fn test_sgd_negative_gradients() {
    // A negative gradient moves the parameter upward: 1.0 - 0.1*(-1.0) = 1.1.
    let mut opt = SGD::new(0.1);
    let mut weights = Vector::from_slice(&[1.0]);
    let grads = Vector::from_slice(&[-1.0]);
    opt.step(&mut weights, &grads);
    assert!((weights[0] - 1.1).abs() < 1e-6);
}
#[test]
#[should_panic(expected = "same length")]
fn test_sgd_mismatched_lengths() {
    // A parameter/gradient length mismatch must panic with a "same length"
    // message.
    let mut opt = SGD::new(0.1);
    let mut weights = Vector::from_slice(&[1.0, 2.0]);
    let grads = Vector::from_slice(&[1.0]); // one element short
    opt.step(&mut weights, &grads);
}
#[test]
fn test_sgd_large_learning_rate() {
    // lr = 10 with gradient 0.1 drives the parameter exactly to zero.
    let mut opt = SGD::new(10.0);
    let mut weights = Vector::from_slice(&[1.0]);
    let grads = Vector::from_slice(&[0.1]);
    opt.step(&mut weights, &grads);
    assert!(weights[0].abs() < 1e-6);
}
#[test]
fn test_sgd_small_learning_rate() {
    // A tiny learning rate makes a correspondingly tiny move: 1.0 - 0.001.
    let mut opt = SGD::new(0.001);
    let mut weights = Vector::from_slice(&[1.0]);
    let grads = Vector::from_slice(&[1.0]);
    opt.step(&mut weights, &grads);
    assert!((weights[0] - 0.999).abs() < 1e-6);
}
#[test]
fn test_sgd_clone() {
    // Cloning must preserve both hyperparameters.
    let original = SGD::new(0.01).with_momentum(0.9);
    let copy = original.clone();
    assert!((copy.learning_rate() - original.learning_rate()).abs() < 1e-6);
    assert!((copy.momentum() - original.momentum()).abs() < 1e-6);
}
#[test]
fn test_sgd_multiple_steps() {
    // Ten steps of size lr*g = 0.1 each move 10.0 down to 9.0.
    let mut opt = SGD::new(0.1);
    let mut weights = Vector::from_slice(&[10.0]);
    let grads = Vector::from_slice(&[1.0]);
    (0..10).for_each(|_| opt.step(&mut weights, &grads));
    assert!((weights[0] - 9.0).abs() < 1e-4);
}
#[test]
fn test_sgd_velocity_reinitialization() {
    // After stepping with 2-dimensional state, stepping with 3-dimensional
    // state must rebuild the velocity buffer rather than reuse the stale one:
    // the first step on the new parameters behaves like a fresh momentum step.
    let mut opt = SGD::new(0.1).with_momentum(0.9);
    let mut small = Vector::from_slice(&[1.0, 2.0]);
    let small_grads = Vector::from_slice(&[1.0, 1.0]);
    opt.step(&mut small, &small_grads);
    let mut larger = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let larger_grads = Vector::from_slice(&[1.0, 1.0, 1.0]);
    opt.step(&mut larger, &larger_grads);
    assert!((larger[0] - 0.9).abs() < 1e-6);
}
#[test]
fn test_adam_new() {
    // Defaults: beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, no steps taken.
    let opt = Adam::new(0.001);
    assert!((opt.learning_rate() - 0.001).abs() < 1e-9);
    assert!((opt.beta1() - 0.9).abs() < 1e-9);
    assert!((opt.beta2() - 0.999).abs() < 1e-9);
    assert!((opt.epsilon() - 1e-8).abs() < 1e-15);
    assert_eq!(opt.steps(), 0);
}
#[test]
fn test_adam_with_beta1() {
    // with_beta1 overrides the default first-moment decay.
    let opt = Adam::new(0.001).with_beta1(0.95);
    assert!((opt.beta1() - 0.95).abs() < 1e-9);
}
#[test]
fn test_adam_with_beta2() {
    // with_beta2 overrides the default second-moment decay.
    let opt = Adam::new(0.001).with_beta2(0.9999);
    assert!((opt.beta2() - 0.9999).abs() < 1e-9);
}
#[test]
fn test_adam_with_epsilon() {
    // with_epsilon overrides the default numerical-stability constant.
    let opt = Adam::new(0.001).with_epsilon(1e-7);
    assert!((opt.epsilon() - 1e-7).abs() < 1e-15);
}
#[test]
fn test_adam_step_basic() {
    // Positive gradients must decrease both parameters and bump the step count.
    let mut opt = Adam::new(0.001);
    let mut weights = Vector::from_slice(&[1.0, 2.0]);
    let grads = Vector::from_slice(&[0.1, 0.2]);
    opt.step(&mut weights, &grads);
    assert!(weights[0] < 1.0);
    assert!(weights[1] < 2.0);
    assert_eq!(opt.steps(), 1);
}
#[test]
fn test_adam_multiple_steps() {
    // Five steps against a constant positive gradient must move the parameter
    // downward and record exactly five steps.
    let mut opt = Adam::new(0.001);
    let mut weights = Vector::from_slice(&[1.0]);
    let grads = Vector::from_slice(&[1.0]);
    let start = weights[0];
    for _ in 0..5 {
        opt.step(&mut weights, &grads);
    }
    assert!(weights[0] < start);
    assert_eq!(opt.steps(), 5);
}
#[test]
fn test_adam_bias_correction() {
    // Run one optimizer for a single step and a fresh one for two steps on the
    // same constant gradient. Both take identical first steps, so the
    // difference between the two final parameter values isolates the size of
    // the second step. Bias correction should keep the two step sizes in the
    // same ballpark (second step no more than ~2x the first).
    let gradients = Vector::from_slice(&[1.0]);
    let mut opt_a = Adam::new(0.01);
    let mut params_a = Vector::from_slice(&[10.0]);
    opt_a.step(&mut params_a, &gradients);
    let first_step = 10.0 - params_a[0];

    let mut opt_b = Adam::new(0.01);
    let mut params_b = Vector::from_slice(&[10.0]);
    opt_b.step(&mut params_b, &gradients);
    opt_b.step(&mut params_b, &gradients);
    let second_step = params_a[0] - params_b[0];

    assert!(first_step > second_step * 0.5);
}
#[test]
fn test_adam_reset() {
    // reset() must zero the step counter.
    let mut opt = Adam::new(0.001);
    let mut weights = Vector::from_slice(&[1.0]);
    let grads = Vector::from_slice(&[1.0]);
    opt.step(&mut weights, &grads);
    assert_eq!(opt.steps(), 1);
    opt.reset();
    assert_eq!(opt.steps(), 0);
}
#[test]
fn test_adam_zero_gradient() {
    // With zero gradients the parameters should stay (approximately) put.
    let mut opt = Adam::new(0.001);
    let mut weights = Vector::from_slice(&[1.0, 2.0]);
    let grads = Vector::from_slice(&[0.0, 0.0]);
    opt.step(&mut weights, &grads);
    assert!((weights[0] - 1.0).abs() < 0.01);
    assert!((weights[1] - 2.0).abs() < 0.01);
}
#[test]
fn test_adam_negative_gradients() {
    // A negative gradient must push the parameter upward.
    let mut opt = Adam::new(0.001);
    let mut weights = Vector::from_slice(&[1.0]);
    let grads = Vector::from_slice(&[-1.0]);
    opt.step(&mut weights, &grads);
    assert!(weights[0] > 1.0);
}
#[path = "core_adam_errors.rs"]
mod core_adam_errors;
#[path = "core_lbfgs.rs"]
mod core_lbfgs;
#[path = "core_conjugate_gradient.rs"]
mod core_conjugate_gradient;