use crate::{Result, Tensor, TensorError};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::marker::PhantomData;
/// Settings that control numerical gradient verification.
#[derive(Debug, Clone)]
pub struct GradientCheckConfig {
    /// Perturbation size used for finite differences.
    pub epsilon: f64,
    /// Relative error tolerance.
    pub rtol: f64,
    /// Absolute error tolerance.
    pub atol: f64,
    /// Whether to record the index of every failing element.
    pub check_elementwise: bool,
    /// Use central differences (more accurate) instead of forward differences.
    pub use_central_difference: bool,
    /// Optional cap on how many elements to sample when checking.
    pub max_samples: Option<usize>,
    /// Seed for reproducible sampling.
    pub random_seed: Option<u64>,
}

impl Default for GradientCheckConfig {
    /// Balanced defaults: central differences with moderate tolerances,
    /// checking every element.
    fn default() -> Self {
        GradientCheckConfig {
            epsilon: 1e-5,
            rtol: 1e-3,
            atol: 1e-5,
            check_elementwise: false,
            use_central_difference: true,
            max_samples: None,
            random_seed: None,
        }
    }
}

impl GradientCheckConfig {
    /// Tight tolerances with per-element failure tracking.
    pub fn strict() -> Self {
        Self {
            epsilon: 1e-6,
            rtol: 1e-4,
            atol: 1e-6,
            check_elementwise: true,
            ..Self::default()
        }
    }

    /// Loose tolerances with seeded sampling of at most 100 elements.
    pub fn relaxed() -> Self {
        Self {
            epsilon: 1e-4,
            rtol: 1e-2,
            atol: 1e-4,
            max_samples: Some(100),
            random_seed: Some(42),
            ..Self::default()
        }
    }

    /// Default tolerances but cheaper forward differences and at most
    /// 50 seeded samples.
    pub fn fast() -> Self {
        Self {
            use_central_difference: false,
            max_samples: Some(50),
            random_seed: Some(42),
            ..Self::default()
        }
    }
}
/// Outcome of comparing a numerical gradient against an analytical one.
#[derive(Debug, Clone)]
pub struct GradientCheckResult {
    /// True when every compared element was within tolerance.
    pub passed: bool,
    /// Largest relative error observed.
    pub max_relative_error: f64,
    /// Largest absolute error observed.
    pub max_absolute_error: f64,
    /// Total number of gradient elements compared.
    pub num_elements_checked: usize,
    /// Number of elements that exceeded both tolerances.
    pub num_failures: usize,
    /// Indices of failing elements (populated only in element-wise mode).
    pub failed_indices: Vec<usize>,
    /// Human-readable description of the failure, if any.
    pub error_message: Option<String>,
}

impl GradientCheckResult {
    /// Whether the gradient check passed.
    pub fn is_ok(&self) -> bool {
        self.passed
    }

    /// Percentage of checked elements that failed; 0.0 when nothing was checked.
    pub fn failure_rate(&self) -> f64 {
        match self.num_elements_checked {
            0 => 0.0,
            n => (self.num_failures as f64 / n as f64) * 100.0,
        }
    }

    /// Multi-line, human-readable report of the check outcome.
    pub fn summary(&self) -> String {
        if self.passed {
            format!(
                "✓ Gradient check passed\nElements checked: {}\nMax relative error: {:.2e}\nMax absolute error: {:.2e}",
                self.num_elements_checked, self.max_relative_error, self.max_absolute_error
            )
        } else {
            format!(
                "✗ Gradient check FAILED\nElements checked: {}\nFailures: {} ({:.2}%)\nMax relative error: {:.2e}\nMax absolute error: {:.2e}\n{}",
                self.num_elements_checked,
                self.num_failures,
                self.failure_rate(),
                self.max_relative_error,
                self.max_absolute_error,
                self.error_message.as_deref().unwrap_or("")
            )
        }
    }
}
/// Validates analytical gradients by comparing them against
/// finite-difference estimates of the forward function's gradient.
pub struct NumericalGradientChecker<T> {
    // Tolerances and finite-difference settings used by all checks.
    config: GradientCheckConfig,
    // Ties the element type `T` to the checker without storing a value.
    _phantom: PhantomData<T>,
}
// NOTE(review): `max_samples` and `random_seed` from the config are currently
// not honored — every element is always checked. Confirm whether sampling is
// still planned or the config fields should be removed.
impl<T> NumericalGradientChecker<T>
where
    T: Float + FromPrimitive + Clone + Send + Sync + Default + 'static,
{
    /// Creates a checker that uses `config` for all subsequent checks.
    pub fn new(config: GradientCheckConfig) -> Self {
        Self {
            config,
            _phantom: PhantomData,
        }
    }

    /// Builds a tensor of shape `dims` from flat `data`.
    ///
    /// # Errors
    /// Returns an error if `data.len()` does not match the product of `dims`.
    fn tensor_from_vec(dims: &[usize], data: Vec<T>) -> Result<Tensor<T>> {
        let array = scirs2_core::ndarray::Array::from_shape_vec(dims.to_vec(), data)
            .map_err(|e| TensorError::invalid_argument(format!("Shape mismatch: {}", e)))?
            .into_dyn();
        Ok(Tensor::from_array(array))
    }

    /// Reads output element `i`, returning an error (instead of panicking)
    /// when the forward output has fewer elements than the input.
    fn output_at(output: &Tensor<T>, i: usize) -> Result<T> {
        output.data().get(i).cloned().ok_or_else(|| {
            TensorError::invalid_argument(format!(
                "Forward output has no element {}; gradient checking assumes an element-wise function",
                i
            ))
        })
    }

    /// Estimates the gradient of `func` at `input` with finite differences.
    ///
    /// `func` is assumed to be element-wise: the derivative for input element
    /// `i` is read from output element `i`. Uses central differences
    /// `(f(x+e) - f(x-e)) / 2e` when configured, otherwise forward
    /// differences `(f(x+e) - f(x)) / e`.
    ///
    /// # Errors
    /// Propagates failures from `func`, from tensor construction, and from
    /// converting `epsilon` into `T`.
    pub fn compute_numerical_gradient<F>(&self, input: &Tensor<T>, func: F) -> Result<Tensor<T>>
    where
        F: Fn(&Tensor<T>) -> Result<Tensor<T>>,
    {
        let input_data = input.data();
        let input_shape = input.shape();
        let dims = input_shape.dims().to_vec();
        let epsilon = T::from_f64(self.config.epsilon).ok_or_else(|| {
            TensorError::invalid_operation_simple("Failed to convert epsilon".to_string())
        })?;
        // f(x) does not depend on which element is perturbed, so for forward
        // differences evaluate it once up front instead of once per element.
        let f_base = if self.config.use_central_difference {
            None
        } else {
            Some(func(input)?)
        };
        let mut gradient_data = Vec::with_capacity(input_data.len());
        for i in 0..input_data.len() {
            let grad = if self.config.use_central_difference {
                // Central difference: perturb element i in both directions.
                let mut plus = input_data.to_vec();
                plus[i] = plus[i] + epsilon;
                let mut minus = input_data.to_vec();
                minus[i] = minus[i] - epsilon;
                let f_plus = func(&Self::tensor_from_vec(&dims, plus)?)?;
                let f_minus = func(&Self::tensor_from_vec(&dims, minus)?)?;
                let diff = Self::output_at(&f_plus, i)? - Self::output_at(&f_minus, i)?;
                diff / (epsilon + epsilon)
            } else {
                // Forward difference against the precomputed baseline f(x).
                let mut plus = input_data.to_vec();
                plus[i] = plus[i] + epsilon;
                let f_plus = func(&Self::tensor_from_vec(&dims, plus)?)?;
                let f_x = f_base
                    .as_ref()
                    .expect("f_base is always Some for forward differences");
                let diff = Self::output_at(&f_plus, i)? - Self::output_at(f_x, i)?;
                diff / epsilon
            };
            gradient_data.push(grad);
        }
        Self::tensor_from_vec(&dims, gradient_data)
    }

    /// Compares a numerical gradient against an analytical one element-wise.
    ///
    /// An element fails only when its relative error exceeds `rtol` AND its
    /// absolute error exceeds `atol`, so tiny values near zero are not
    /// penalized by the relative criterion alone.
    ///
    /// # Errors
    /// Returns an error if the two tensors have different shapes.
    pub fn compare_gradients(
        &self,
        numerical: &Tensor<T>,
        analytical: &Tensor<T>,
    ) -> Result<GradientCheckResult> {
        if numerical.shape() != analytical.shape() {
            return Err(TensorError::invalid_argument(format!(
                "Shape mismatch: numerical {:?} vs analytical {:?}",
                numerical.shape(),
                analytical.shape()
            )));
        }
        let num_data = numerical.data();
        let ana_data = analytical.data();
        let rtol = self.config.rtol;
        let atol = self.config.atol;
        let mut max_rel_error = 0.0;
        let mut max_abs_error = 0.0;
        let mut num_failures = 0;
        let mut failed_indices = Vec::new();
        for i in 0..num_data.len() {
            let num_val = num_data[i].to_f64().unwrap_or(0.0);
            let ana_val = ana_data[i].to_f64().unwrap_or(0.0);
            let abs_error = (num_val - ana_val).abs();
            // Fall back to the absolute error as the "relative" error when the
            // reference value is (near) zero, to avoid dividing by ~0.
            let rel_error = if ana_val.abs() > 1e-10 {
                abs_error / ana_val.abs()
            } else {
                abs_error
            };
            max_rel_error = max_rel_error.max(rel_error);
            max_abs_error = max_abs_error.max(abs_error);
            if rel_error > rtol && abs_error > atol {
                num_failures += 1;
                // Per-element indices are retained only on request to bound memory.
                if self.config.check_elementwise {
                    failed_indices.push(i);
                }
            }
        }
        let passed = num_failures == 0;
        let error_message = (!passed).then(|| {
            format!(
                "Gradient mismatch: {} of {} elements exceed tolerance (rtol={}, atol={})",
                num_failures,
                num_data.len(),
                rtol,
                atol
            )
        });
        Ok(GradientCheckResult {
            passed,
            max_relative_error: max_rel_error,
            max_absolute_error: max_abs_error,
            num_elements_checked: num_data.len(),
            num_failures,
            failed_indices,
            error_message,
        })
    }

    /// Runs a full gradient check: numerically estimates the gradient of
    /// `forward` at `input` and compares it against `gradient(input)`.
    ///
    /// # Errors
    /// Propagates failures from either closure or from shape mismatches.
    pub fn check<F, G>(
        &self,
        input: &Tensor<T>,
        forward: F,
        gradient: G,
    ) -> Result<GradientCheckResult>
    where
        F: Fn(&Tensor<T>) -> Result<Tensor<T>>,
        G: Fn(&Tensor<T>) -> Result<Tensor<T>>,
    {
        let numerical_grad = self.compute_numerical_gradient(input, forward)?;
        let analytical_grad = gradient(input)?;
        self.compare_gradients(&numerical_grad, &analytical_grad)
    }
}
/// Convenience wrapper: runs a gradient check with an explicit configuration.
pub fn check_gradients<T, F, G>(
    input: &Tensor<T>,
    forward: F,
    gradient: G,
    config: &GradientCheckConfig,
) -> Result<GradientCheckResult>
where
    T: Float + FromPrimitive + Clone + Send + Sync + Default + 'static,
    F: Fn(&Tensor<T>) -> Result<Tensor<T>>,
    G: Fn(&Tensor<T>) -> Result<Tensor<T>>,
{
    NumericalGradientChecker::new(config.clone()).check(input, forward, gradient)
}
/// Convenience wrapper: runs a gradient check with the default configuration.
pub fn quick_check_gradients<T, F, G>(
    input: &Tensor<T>,
    forward: F,
    gradient: G,
) -> Result<GradientCheckResult>
where
    T: Float + FromPrimitive + Clone + Send + Sync + Default + 'static,
    F: Fn(&Tensor<T>) -> Result<Tensor<T>>,
    G: Fn(&Tensor<T>) -> Result<Tensor<T>>,
{
    let config = GradientCheckConfig::default();
    check_gradients(input, forward, gradient, &config)
}
// Unit tests exercise the public `check_gradients` entry point with small
// element-wise closures over f32 tensors.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // f(x) = 2x has the constant analytical gradient 2; the check must pass.
    #[test]
    fn test_gradient_check_linear() {
        let input = Tensor::from_array(array![1.0, 2.0, 3.0].into_dyn());
        let forward = |x: &Tensor<f32>| {
            let data: Vec<f32> = x.data().iter().map(|&v| v * 2.0).collect();
            let result_array = scirs2_core::ndarray::Array::from_vec(data).into_dyn();
            Ok(Tensor::from_array(result_array))
        };
        let gradient = |_x: &Tensor<f32>| {
            let grad = array![2.0, 2.0, 2.0].into_dyn();
            Ok(Tensor::from_array(grad))
        };
        let config = GradientCheckConfig::relaxed();
        let result = check_gradients(&input, forward, gradient, &config)
            .expect("test: check_gradients should succeed");
        assert!(
            result.passed,
            "Gradient check should pass for linear function: {}",
            result.summary()
        );
    }

    // f(x) = x^2 has gradient 2x; exercises a non-constant gradient end to end.
    #[test]
    fn test_gradient_check_square() {
        let input = Tensor::from_array(array![1.0, 2.0, 3.0].into_dyn());
        let forward = |x: &Tensor<f32>| {
            let data: Vec<f32> = x.data().iter().map(|&v| v * v).collect();
            let result_array = scirs2_core::ndarray::Array::from_vec(data).into_dyn();
            Ok(Tensor::from_array(result_array))
        };
        let gradient = |x: &Tensor<f32>| {
            let data: Vec<f32> = x.data().iter().map(|&v| 2.0 * v).collect();
            let grad_array = scirs2_core::ndarray::Array::from_vec(data).into_dyn();
            Ok(Tensor::from_array(grad_array))
        };
        let config = GradientCheckConfig::relaxed();
        let result = check_gradients(&input, forward, gradient, &config)
            .expect("test: check_gradients should succeed");
        assert!(
            result.passed,
            "Gradient check should pass for square function: {}",
            result.summary()
        );
    }

    // A deliberately wrong gradient (3x instead of 2x) must be rejected.
    #[test]
    fn test_gradient_check_incorrect_gradient() {
        let input = Tensor::from_array(array![1.0, 2.0, 3.0].into_dyn());
        let forward = |x: &Tensor<f32>| {
            let data: Vec<f32> = x.data().iter().map(|&v| v * v).collect();
            let result_array = scirs2_core::ndarray::Array::from_vec(data).into_dyn();
            Ok(Tensor::from_array(result_array))
        };
        let wrong_gradient = |x: &Tensor<f32>| {
            let data: Vec<f32> = x.data().iter().map(|&v| 3.0 * v).collect();
            let grad_array = scirs2_core::ndarray::Array::from_vec(data).into_dyn();
            Ok(Tensor::from_array(grad_array))
        };
        let config = GradientCheckConfig::default();
        let result = check_gradients(&input, forward, wrong_gradient, &config)
            .expect("test: check_gradients should succeed");
        assert!(
            !result.passed,
            "Gradient check should fail for incorrect gradient"
        );
        assert!(result.num_failures > 0);
    }

    // A gradient off by 1% passes relaxed tolerances (rtol=1e-2) but fails
    // strict ones (rtol=1e-4).
    #[test]
    fn test_gradient_check_config_tolerances() {
        let input = Tensor::from_array(array![1.0].into_dyn());
        let forward = |x: &Tensor<f32>| {
            let data: Vec<f32> = x.data().iter().map(|&v| v * v).collect();
            let result_array = scirs2_core::ndarray::Array::from_vec(data).into_dyn();
            Ok(Tensor::from_array(result_array))
        };
        let slightly_off_gradient = |x: &Tensor<f32>| {
            let data: Vec<f32> = x.data().iter().map(|&v| 2.0 * v * 1.01).collect();
            let grad_array = scirs2_core::ndarray::Array::from_vec(data).into_dyn();
            Ok(Tensor::from_array(grad_array))
        };
        let relaxed = GradientCheckConfig::relaxed();
        let result = check_gradients(&input, forward, slightly_off_gradient, &relaxed)
            .expect("test: check_gradients should succeed");
        assert!(result.passed, "Should pass with relaxed tolerances");
        let strict = GradientCheckConfig::strict();
        let result = check_gradients(&input, forward, slightly_off_gradient, &strict)
            .expect("test: check_gradients should succeed");
        assert!(!result.passed, "Should fail with strict tolerances");
    }

    // summary() / failure_rate() formatting for a hand-built failing result.
    #[test]
    fn test_gradient_check_result_summary() {
        let result = GradientCheckResult {
            passed: false,
            max_relative_error: 0.05,
            max_absolute_error: 0.01,
            num_elements_checked: 100,
            num_failures: 10,
            failed_indices: vec![],
            error_message: Some("Test error".to_string()),
        };
        let summary = result.summary();
        assert!(summary.contains("FAILED"));
        assert!(summary.contains("10.00%"));
        assert_eq!(result.failure_rate(), 10.0);
    }
}