//! Gradient checking utilities: compare analytical gradients against
//! finite-difference approximations of increasing order and cost.

use std::collections::HashMap;
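/// Settings for a gradient check. `epsilon` is the finite-difference step;
/// an element is flagged only when it exceeds *both* `abs_tolerance` and
/// `rel_tolerance` (see `GradientError::exceeds_tolerance`). Individual
/// fields can be tuned with the `with_*` builder methods.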
#[derive(Debug, Clone)]
pub struct GradCheckConfig {
pub epsilon: f64,
pub rel_tolerance: f64,
pub abs_tolerance: f64,
pub verbose: bool,
pub max_errors_to_report: usize,
}
impl Default for GradCheckConfig {
fn default() -> Self {
GradCheckConfig {
epsilon: 1e-5,
rel_tolerance: 1e-3,
abs_tolerance: 1e-5,
verbose: false,
max_errors_to_report: 10,
}
}
}
impl GradCheckConfig {
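    /// A smaller step, tighter tolerances, and verbose reporting.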
pub fn strict() -> Self {
GradCheckConfig {
epsilon: 1e-6,
rel_tolerance: 1e-4,
abs_tolerance: 1e-6,
verbose: true,
max_errors_to_report: 10,
}
}
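    /// A larger step and looser tolerances, e.g. for noisy or
    /// low-precision objectives.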
pub fn relaxed() -> Self {
GradCheckConfig {
epsilon: 1e-4,
rel_tolerance: 1e-2,
abs_tolerance: 1e-4,
verbose: false,
max_errors_to_report: 10,
}
}
pub fn with_verbose(mut self, verbose: bool) -> Self {
self.verbose = verbose;
self
}
pub fn with_epsilon(mut self, epsilon: f64) -> Self {
self.epsilon = epsilon;
self
}
pub fn with_rel_tolerance(mut self, tolerance: f64) -> Self {
self.rel_tolerance = tolerance;
self
}
pub fn with_abs_tolerance(mut self, tolerance: f64) -> Self {
self.abs_tolerance = tolerance;
self
}
}
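/// Aggregate outcome of checking one parameter. Note that `avg_error` is
/// averaged over the flagged elements only, not over all `num_params`
/// elements.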
#[derive(Debug, Clone)]
pub struct GradCheckResult {
pub num_params: usize,
pub num_errors: usize,
pub max_error: f64,
pub max_rel_error: f64,
pub avg_error: f64,
pub passed: bool,
pub errors: Vec<GradientError>,
}
impl GradCheckResult {
pub fn new(num_params: usize) -> Self {
GradCheckResult {
num_params,
num_errors: 0,
max_error: 0.0,
max_rel_error: 0.0,
avg_error: 0.0,
passed: true,
errors: Vec::new(),
}
}
pub fn add_error(&mut self, error: GradientError) {
self.num_errors += 1;
self.max_error = self.max_error.max(error.abs_error);
self.max_rel_error = self.max_rel_error.max(error.rel_error);
self.passed = false;
self.errors.push(error);
}
pub fn finalize(mut self) -> Self {
if !self.errors.is_empty() {
let total_error: f64 = self.errors.iter().map(|e| e.abs_error).sum();
self.avg_error = total_error / self.errors.len() as f64;
}
self
}
pub fn summary(&self) -> String {
format!(
"Gradient Check: {} params, {} errors, max_error={:.2e}, max_rel_error={:.2e}, avg_error={:.2e}, {}",
self.num_params,
self.num_errors,
self.max_error,
self.max_rel_error,
self.avg_error,
if self.passed { "PASSED" } else { "FAILED" }
)
}
pub fn print_errors(&self, max_to_print: usize) {
if self.errors.is_empty() {
println!("✓ All gradients passed!");
return;
}
println!("\n✗ Gradient errors found:");
        for (i, error) in self.errors.iter().take(max_to_print).enumerate() {
            println!(
                "  [{}] Param {}[{}]: analytical={:.6e}, numerical={:.6e}, abs_err={:.2e}, rel_err={:.2e}",
                i + 1,
                error.param_id,
                error.index,
                error.analytical_grad,
                error.numerical_grad,
                error.abs_error,
                error.rel_error
            );
        }
if self.errors.len() > max_to_print {
println!(" ... and {} more errors", self.errors.len() - max_to_print);
}
}
}
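/// One flagged gradient element. The relative error is measured against
/// `|numerical_grad|`, falling back to the absolute error when the
/// numerical gradient is effectively zero (|n| <= 1e-10).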
#[derive(Debug, Clone)]
pub struct GradientError {
pub param_id: String,
pub index: usize,
pub analytical_grad: f64,
pub numerical_grad: f64,
pub abs_error: f64,
pub rel_error: f64,
}
impl GradientError {
pub fn new(param_id: String, index: usize, analytical: f64, numerical: f64) -> Self {
        let abs_error = (analytical - numerical).abs();
        // Fall back to the absolute error when the numerical gradient is
        // effectively zero, to avoid dividing by a vanishing denominator.
        let rel_error = if numerical.abs() > 1e-10 {
            abs_error / numerical.abs()
        } else {
            abs_error
        };
GradientError {
param_id,
index,
analytical_grad: analytical,
numerical_grad: numerical,
abs_error,
rel_error,
}
}
pub fn exceeds_tolerance(&self, config: &GradCheckConfig) -> bool {
self.abs_error > config.abs_tolerance && self.rel_error > config.rel_tolerance
}
}
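/// Central-difference gradient: for each coordinate i,
///
///     df/dx_i ≈ (f(x + ε·e_i) - f(x - ε·e_i)) / (2ε)
///
/// where e_i is the i-th unit vector. Truncation error is O(ε²) at a cost
/// of 2n evaluations of `forward_fn`.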
pub fn numerical_gradient_central(
forward_fn: impl Fn(&[f64]) -> f64,
x: &[f64],
epsilon: f64,
) -> Vec<f64> {
let mut grad = vec![0.0; x.len()];
for i in 0..x.len() {
let mut x_plus = x.to_vec();
x_plus[i] += epsilon;
let f_plus = forward_fn(&x_plus);
let mut x_minus = x.to_vec();
x_minus[i] -= epsilon;
let f_minus = forward_fn(&x_minus);
grad[i] = (f_plus - f_minus) / (2.0 * epsilon);
}
grad
}
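/// Forward-difference gradient: df/dx_i ≈ (f(x + ε·e_i) - f_x) / ε, where
/// `f_x` is the precomputed value f(x). Only O(ε) accurate, but it needs
/// just n extra evaluations, making it the cheapest option here.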
pub fn numerical_gradient_forward(
forward_fn: impl Fn(&[f64]) -> f64,
x: &[f64],
f_x: f64,
epsilon: f64,
) -> Vec<f64> {
let mut grad = vec![0.0; x.len()];
for i in 0..x.len() {
let mut x_plus = x.to_vec();
x_plus[i] += epsilon;
let f_plus = forward_fn(&x_plus);
grad[i] = (f_plus - f_x) / epsilon;
}
grad
}
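/// Fourth-order central difference (five-point stencil):
///
///     df/dx_i ≈ (-f(x + 2ε·e_i) + 8f(x + ε·e_i) - 8f(x - ε·e_i) + f(x - 2ε·e_i)) / (12ε)
///
/// Truncation error is O(ε⁴) at a cost of 4n evaluations of `forward_fn`.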
pub fn numerical_gradient_fourth_order(
forward_fn: impl Fn(&[f64]) -> f64,
x: &[f64],
epsilon: f64,
) -> Vec<f64> {
let mut grad = vec![0.0; x.len()];
for i in 0..x.len() {
let mut x_plus2 = x.to_vec();
x_plus2[i] += 2.0 * epsilon;
let f_plus2 = forward_fn(&x_plus2);
let mut x_plus = x.to_vec();
x_plus[i] += epsilon;
let f_plus = forward_fn(&x_plus);
let mut x_minus = x.to_vec();
x_minus[i] -= epsilon;
let f_minus = forward_fn(&x_minus);
let mut x_minus2 = x.to_vec();
x_minus2[i] -= 2.0 * epsilon;
let f_minus2 = forward_fn(&x_minus2);
grad[i] = (-f_plus2 + 8.0 * f_plus - 8.0 * f_minus + f_minus2) / (12.0 * epsilon);
}
grad
}
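/// Richardson extrapolation: given central-difference estimates D(ε) and
/// D(ε/2), the combination (4·D(ε/2) - D(ε)) / 3 cancels the O(ε²) error
/// term, leaving O(ε⁴) accuracy for 4n evaluations of `forward_fn`.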
pub fn numerical_gradient_richardson(
forward_fn: impl Fn(&[f64]) -> f64,
x: &[f64],
epsilon: f64,
) -> Vec<f64> {
let grad_h = numerical_gradient_central(&forward_fn, x, epsilon);
let grad_h_half = numerical_gradient_central(&forward_fn, x, epsilon / 2.0);
grad_h_half
.iter()
.zip(grad_h.iter())
.map(|(&g_half, &g_full)| (4.0 * g_half - g_full) / 3.0)
.collect()
}
/// NOTE: despite its name, this is *not* a true complex-step derivative.
/// The complex-step method evaluates f(x + i·h) and reads the gradient off
/// the imaginary part, which avoids subtractive cancellation entirely; that
/// is impossible with a real-valued `forward_fn`. What this function
/// actually computes is a central difference with a very small step
/// (`epsilon * 1e-8`), which is vulnerable to floating-point cancellation
/// and should be used with generous tolerances.
pub fn numerical_gradient_complex_step(
    forward_fn: impl Fn(&[f64]) -> f64,
    x: &[f64],
    epsilon: f64,
) -> Vec<f64> {
    // The tiny step is loop-invariant; compute it once.
    let eps_tiny = epsilon * 1e-8;
    let mut grad = vec![0.0; x.len()];
    for i in 0..x.len() {
        let mut x_plus = x.to_vec();
        x_plus[i] += eps_tiny;
        let f_plus = forward_fn(&x_plus);
        let mut x_minus = x.to_vec();
        x_minus[i] -= eps_tiny;
        let f_minus = forward_fn(&x_minus);
        grad[i] = (f_plus - f_minus) / (2.0 * eps_tiny);
    }
    grad
}
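/// Sweeps a ladder of step sizes and keeps the gradient whose components
/// have the smallest variance. This is a crude stability heuristic, not a
/// true error estimate: for a one-element `x` the variance is always zero,
/// so the first (largest) step is kept.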
pub fn numerical_gradient_adaptive(forward_fn: impl Fn(&[f64]) -> f64, x: &[f64]) -> Vec<f64> {
    let epsilons = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7];
let mut best_grad = Vec::new();
let mut min_variance = f64::MAX;
for &eps in &epsilons {
let grad = numerical_gradient_central(&forward_fn, x, eps);
if !grad.is_empty() {
let mean: f64 = grad.iter().sum::<f64>() / grad.len() as f64;
let variance: f64 =
grad.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / grad.len() as f64;
if variance < min_variance || best_grad.is_empty() {
min_variance = variance;
best_grad = grad;
}
}
}
best_grad
}
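/// Compares analytical and numerical gradients element-wise, returning the
/// elements that exceed both the absolute and relative tolerances.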
pub fn compare_gradients(
param_id: String,
analytical: &[f64],
numerical: &[f64],
config: &GradCheckConfig,
) -> Vec<GradientError> {
    assert_eq!(
        analytical.len(),
        numerical.len(),
        "gradient length mismatch for parameter '{}'",
        param_id
    );
let mut errors = Vec::new();
for (i, (&a, &n)) in analytical.iter().zip(numerical.iter()).enumerate() {
let error = GradientError::new(param_id.clone(), i, a, n);
if error.exceeds_tolerance(config) {
errors.push(error);
}
}
errors
}
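/// Stateful checker that verifies one parameter at a time and accumulates
/// the per-parameter results for a final summary.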
pub struct GradientChecker {
config: GradCheckConfig,
results: HashMap<String, GradCheckResult>,
}
impl GradientChecker {
pub fn new(config: GradCheckConfig) -> Self {
GradientChecker {
config,
results: HashMap::new(),
}
}
pub fn with_defaults() -> Self {
Self::new(GradCheckConfig::default())
}
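    /// Verifies `analytical_grad` for one parameter against a
    /// central-difference gradient of `forward_fn`, records the result
    /// under `param_id`, and returns it.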
pub fn check_parameter(
&mut self,
param_id: String,
forward_fn: impl Fn(&[f64]) -> f64,
x: &[f64],
analytical_grad: &[f64],
) -> GradCheckResult {
let numerical_grad = numerical_gradient_central(&forward_fn, x, self.config.epsilon);
let errors = compare_gradients(
param_id.clone(),
analytical_grad,
&numerical_grad,
&self.config,
);
let mut result = GradCheckResult::new(x.len());
for error in errors {
result.add_error(error);
}
let result = result.finalize();
if self.config.verbose {
println!("Checking parameter '{}':", param_id);
println!(" {}", result.summary());
if !result.passed {
result.print_errors(self.config.max_errors_to_report);
}
}
self.results.insert(param_id, result.clone());
result
}
pub fn results(&self) -> &HashMap<String, GradCheckResult> {
&self.results
}
pub fn all_passed(&self) -> bool {
self.results.values().all(|r| r.passed)
}
pub fn total_errors(&self) -> usize {
self.results.values().map(|r| r.num_errors).sum()
}
pub fn print_summary(&self) {
println!("\n=== Gradient Check Summary ===");
for (param_id, result) in &self.results {
println!("{}: {}", param_id, result.summary());
}
println!(
"\nTotal: {} parameters, {} errors",
self.results.len(),
self.total_errors()
);
if self.all_passed() {
println!("✓ All gradient checks PASSED");
} else {
println!("✗ Some gradient checks FAILED");
}
}
}
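/// One-shot convenience check with default settings; returns the failure
/// summary as the `Err` value.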
pub fn quick_check(
forward_fn: impl Fn(&[f64]) -> f64,
x: &[f64],
analytical_grad: &[f64],
) -> Result<(), String> {
let config = GradCheckConfig::default();
let numerical = numerical_gradient_central(&forward_fn, x, config.epsilon);
let errors = compare_gradients(
"quick_check".to_string(),
analytical_grad,
&numerical,
&config,
);
if errors.is_empty() {
Ok(())
} else {
let mut result = GradCheckResult::new(x.len());
for error in errors {
result.add_error(error);
}
Err(result.finalize().summary())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_grad_check_config_default() {
let config = GradCheckConfig::default();
assert!(config.epsilon > 0.0);
assert!(config.rel_tolerance > 0.0);
assert!(config.abs_tolerance > 0.0);
}
#[test]
fn test_grad_check_config_strict() {
let strict = GradCheckConfig::strict();
let default = GradCheckConfig::default();
assert!(strict.epsilon <= default.epsilon);
assert!(strict.rel_tolerance <= default.rel_tolerance);
}
#[test]
fn test_grad_check_config_builder() {
let config = GradCheckConfig::default()
.with_epsilon(1e-4)
.with_verbose(true)
.with_rel_tolerance(1e-2);
assert_eq!(config.epsilon, 1e-4);
assert!(config.verbose);
assert_eq!(config.rel_tolerance, 1e-2);
}
#[test]
fn test_numerical_gradient_simple() {
let f = |x: &[f64]| x[0] * x[0];
let x = vec![3.0];
let grad = numerical_gradient_central(f, &x, 1e-5);
assert!((grad[0] - 6.0).abs() < 1e-4);
}
#[test]
fn test_numerical_gradient_multivariate() {
let f = |xy: &[f64]| xy[0] * xy[0] + xy[1] * xy[1];
let xy = vec![3.0, 4.0];
let grad = numerical_gradient_central(f, &xy, 1e-5);
assert!((grad[0] - 6.0).abs() < 1e-4);
assert!((grad[1] - 8.0).abs() < 1e-4);
}
#[test]
fn test_gradient_error_creation() {
let error = GradientError::new("param1".to_string(), 0, 1.0, 1.01);
assert_eq!(error.param_id, "param1");
assert_eq!(error.index, 0);
assert_eq!(error.analytical_grad, 1.0);
assert_eq!(error.numerical_grad, 1.01);
assert!(error.abs_error > 0.0);
assert!(error.rel_error > 0.0);
}
#[test]
fn test_gradient_error_exceeds_tolerance() {
let config = GradCheckConfig::default();
let error1 = GradientError::new("p1".to_string(), 0, 1.0, 2.0);
assert!(error1.exceeds_tolerance(&config));
let error2 = GradientError::new("p2".to_string(), 0, 1.0, 1.0000001);
assert!(!error2.exceeds_tolerance(&config));
}
#[test]
fn test_grad_check_result() {
let mut result = GradCheckResult::new(10);
assert!(result.passed);
assert_eq!(result.num_errors, 0);
result.add_error(GradientError::new("p1".to_string(), 0, 1.0, 2.0));
assert!(!result.passed);
assert_eq!(result.num_errors, 1);
let final_result = result.finalize();
assert!(final_result.avg_error > 0.0);
}
#[test]
fn test_compare_gradients() {
let config = GradCheckConfig::default();
let analytical = vec![1.0, 2.0, 3.0];
let numerical = vec![1.0, 2.0, 3.0];
let errors = compare_gradients("test".to_string(), &analytical, &numerical, &config);
assert_eq!(errors.len(), 0);
let numerical2 = vec![1.0, 2.5, 3.0];
let errors2 = compare_gradients("test".to_string(), &analytical, &numerical2, &config);
assert!(!errors2.is_empty());
}
#[test]
fn test_gradient_checker() {
let mut checker = GradientChecker::new(GradCheckConfig::default());
let f = |x: &[f64]| x[0] * x[0];
let x = vec![3.0];
let analytical = vec![6.0];
let result = checker.check_parameter("x".to_string(), f, &x, &analytical);
assert!(result.passed);
assert!(checker.all_passed());
}
#[test]
fn test_quick_check() {
let f = |x: &[f64]| x[0] * x[0];
let x = vec![3.0];
let grad = vec![6.0];
assert!(quick_check(f, &x, &grad).is_ok());
let bad_grad = vec![7.0];
assert!(quick_check(f, &x, &bad_grad).is_err());
}
#[test]
fn test_forward_gradient() {
let f = |x: &[f64]| x[0] * x[0];
let x = vec![3.0];
let f_x = f(&x);
let grad = numerical_gradient_forward(f, &x, f_x, 1e-5);
assert!((grad[0] - 6.0).abs() < 1e-3);
}
#[test]
fn test_fourth_order_gradient() {
let f = |x: &[f64]| x[0].powi(3);
let x = vec![2.0];
let grad = numerical_gradient_fourth_order(f, &x, 1e-3);
assert!((grad[0] - 12.0).abs() < 1e-5);
}
#[test]
fn test_fourth_order_multivariate() {
let f = |xy: &[f64]| xy[0].powi(3) + xy[1].powi(3);
let xy = vec![2.0, 3.0];
let grad = numerical_gradient_fourth_order(f, &xy, 1e-3);
        assert!((grad[0] - 12.0).abs() < 1e-5);
        assert!((grad[1] - 27.0).abs() < 1e-5);
    }
#[test]
fn test_richardson_extrapolation() {
let f = |x: &[f64]| x[0].powi(4);
let x = vec![2.0];
let grad = numerical_gradient_richardson(f, &x, 1e-3);
assert!((grad[0] - 32.0).abs() < 1e-6);
}
#[test]
fn test_richardson_multivariate() {
let f = |xy: &[f64]| xy[0].powi(4) + xy[1].powi(4);
let xy = vec![2.0, 1.5];
let grad = numerical_gradient_richardson(f, &xy, 1e-3);
        assert!((grad[0] - 32.0).abs() < 1e-6);
        assert!((grad[1] - 13.5).abs() < 1e-6);
    }
#[test]
fn test_complex_step_approximation() {
let f = |x: &[f64]| x[0] * x[0] + 2.0 * x[0] + 1.0;
let x = vec![3.0];
let grad = numerical_gradient_complex_step(f, &x, 1e-5);
assert!((grad[0] - 8.0).abs() < 0.1);
}
#[test]
fn test_adaptive_gradient() {
let f = |x: &[f64]| x[0] * x[0];
let x = vec![3.0];
let grad = numerical_gradient_adaptive(f, &x);
assert!((grad[0] - 6.0).abs() < 1e-4);
}
#[test]
fn test_adaptive_multivariate() {
let f = |xyz: &[f64]| xyz[0] * xyz[0] + xyz[1] * xyz[1] + xyz[2] * xyz[2];
let xyz = vec![1.0, 2.0, 3.0];
let grad = numerical_gradient_adaptive(f, &xyz);
assert!((grad[0] - 2.0).abs() < 1e-4);
assert!((grad[1] - 4.0).abs() < 1e-4);
assert!((grad[2] - 6.0).abs() < 1e-4);
}
#[test]
fn test_gradient_method_comparison() {
let f = |x: &[f64]| x[0].sin();
let x = vec![1.0_f64];
let expected = 1.0_f64.cos();
let grad_central = numerical_gradient_central(f, &x, 1e-5);
let grad_fourth = numerical_gradient_fourth_order(f, &x, 1e-3);
let grad_richardson = numerical_gradient_richardson(f, &x, 1e-3);
assert!((grad_central[0] - expected).abs() < 1e-5);
assert!((grad_fourth[0] - expected).abs() < 1e-6);
assert!((grad_richardson[0] - expected).abs() < 1e-7);
}
#[test]
fn test_gradient_stability_near_zero() {
let f = |x: &[f64]| x[0] * x[0] + 1e-10;
let x = vec![1e-8_f64];
let expected = 2.0 * 1e-8;
let grad = numerical_gradient_adaptive(f, &x);
assert!((grad[0] - expected).abs() < 1e-9);
}
#[test]
fn test_gradient_nonpolynomial() {
let f = |x: &[f64]| x[0].exp();
let x = vec![1.0_f64];
let expected = 1.0_f64.exp();
let grad_fourth = numerical_gradient_fourth_order(f, &x, 1e-4);
assert!((grad_fourth[0] - expected).abs() < 1e-6);
let grad_richardson = numerical_gradient_richardson(f, &x, 1e-4);
assert!((grad_richardson[0] - expected).abs() < 1e-7);
}
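    // Added illustration: an element is flagged only when BOTH the absolute
    // and the relative tolerance are exceeded, so a large relative error on
    // a tiny gradient component is deliberately ignored.
    #[test]
    fn test_tolerance_requires_both_abs_and_rel() {
        let config = GradCheckConfig::default();
        let error = GradientError::new("tiny".to_string(), 0, 1e-8, 2e-8);
        assert!(error.rel_error > config.rel_tolerance); // rel = 0.5
        assert!(error.abs_error < config.abs_tolerance); // abs = 1e-8
        assert!(!error.exceeds_tolerance(&config));
    }
    // Added illustration of the step-size tradeoff: for a smooth function,
    // shrinking ε from 1e-1 to 1e-5 reduces the O(ε²) truncation error of
    // the central difference by many orders of magnitude.
    #[test]
    fn test_step_size_truncation_error() {
        let f = |x: &[f64]| x[0].exp();
        let x = vec![1.0_f64];
        let expected = 1.0_f64.exp();
        let err_coarse = (numerical_gradient_central(f, &x, 1e-1)[0] - expected).abs();
        let err_fine = (numerical_gradient_central(f, &x, 1e-5)[0] - expected).abs();
        assert!(err_coarse > 1e-4); // truncation-dominated: ≈ e·ε²/6 ≈ 4.5e-3
        assert!(err_fine < 1e-8);
        assert!(err_fine < err_coarse);
    }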
}