use super::{finite_differences::*, StabilityError};
use crate::tensor::Tensor;
use crate::{Float, Graph};
use scirs2_core::ndarray::{Array, IxDyn};
use std::collections::HashMap;
#[derive(Debug, Clone)]
/// Settings controlling how gradient checks are performed.
pub struct GradientCheckConfig {
/// Maximum allowed relative error between analytical and numerical values.
pub relative_tolerance: f64,
/// Maximum allowed absolute error between analytical and numerical values.
pub absolute_tolerance: f64,
/// Settings forwarded to the finite-difference gradient computer.
pub finite_diff_config: FiniteDifferenceConfig,
/// When true, the check is repeated `num_test_points` times.
pub check_multiple_points: bool,
/// Number of points checked when `check_multiple_points` is enabled.
pub num_test_points: usize,
/// When true, additionally run the second-order (Hessian) check.
pub check_second_order: bool,
/// Whether parameter gradients should also be checked.
/// NOTE(review): this flag is not read anywhere in this module — confirm
/// the intended consumer before relying on it.
pub check_parameters: bool,
/// When true, per-element comparison details are printed to stdout.
pub verbose: bool,
}
impl Default for GradientCheckConfig {
fn default() -> Self {
Self {
relative_tolerance: 1e-5,
absolute_tolerance: 1e-8,
finite_diff_config: FiniteDifferenceConfig::default(),
check_multiple_points: true,
num_test_points: 10,
check_second_order: false,
check_parameters: true,
verbose: false,
}
}
}
/// Validates analytical gradients of scalar functions against
/// finite-difference estimates.
pub struct GradientChecker<F: Float> {
/// Active configuration (tolerances, point count, verbosity).
_config: GradientCheckConfig,
/// Computes the numerical reference gradients.
finite_diff_computer: FiniteDifferenceComputer<F>,
}
impl<F: Float> GradientChecker<F> {
    /// Creates a checker with the default [`GradientCheckConfig`].
    pub fn new() -> Self {
        Self {
            _config: GradientCheckConfig::default(),
            finite_diff_computer: FiniteDifferenceComputer::new(),
        }
    }

    /// Creates a checker from an explicit configuration; the embedded
    /// finite-difference settings are forwarded to the internal computer.
    pub fn with_config(config: GradientCheckConfig) -> Self {
        let finite_diff_computer =
            FiniteDifferenceComputer::with_config(config.finite_diff_config.clone());
        Self {
            _config: config,
            finite_diff_computer,
        }
    }

    /// Validates `analytical_gradient` of the scalar-valued `function` at
    /// `input` against a finite-difference estimate.
    ///
    /// When `check_multiple_points` is enabled the check is repeated
    /// `num_test_points` times (currently all at the same `input`, since
    /// test-point perturbation is not yet wired in); otherwise a single
    /// point is evaluated.
    ///
    /// # Errors
    /// Propagates finite-difference failures and gradient shape mismatches.
    pub fn check_scalar_function<'a, Func>(
        &'a self,
        function: Func,
        input: &'a Tensor<'a, F>,
        analytical_gradient: &'a Tensor<'a, F>,
    ) -> Result<GradientCheckResult<'a, F>, StabilityError>
    where
        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
    {
        let mut result = GradientCheckResult::new();
        // BUG FIX: the multi-point branch previously fabricated its point
        // results (numerical gradient copied from the analytical one, with a
        // default comparison whose `passed` flag is false), so the default
        // configuration always failed without computing anything. Run the
        // real single-point check for every requested point instead.
        let num_points = if self._config.check_multiple_points {
            self._config.num_test_points
        } else {
            1
        };
        for _ in 0..num_points {
            let point_result = self.check_single_point(&function, input, analytical_gradient)?;
            result.point_results.push(point_result);
        }
        result.compute_summary();
        Ok(result)
    }

    /// Runs one analytical-vs-numerical comparison at `input`, optionally
    /// followed by the (placeholder) second-order check.
    fn check_single_point<'a, Func>(
        &self,
        function: &Func,
        input: &'a Tensor<'a, F>,
        analytical_gradient: &'a Tensor<'a, F>,
    ) -> Result<SinglePointResult<'a, F>, StabilityError>
    where
        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
    {
        let numerical_gradient = self
            .finite_diff_computer
            .compute_gradient(|x| function(x), input)?;
        let comparison = self.compare_gradients(analytical_gradient, &numerical_gradient)?;
        let mut result = SinglePointResult {
            analytical_gradient: *analytical_gradient,
            numerical_gradient,
            comparison,
            second_order_check: None,
        };
        if self._config.check_second_order {
            result.second_order_check = Some(self.check_second_order_gradients(input)?);
        }
        Ok(result)
    }

    /// Compares two gradients element-wise and aggregates absolute/relative
    /// error statistics against the configured tolerances.
    ///
    /// # Errors
    /// Returns [`StabilityError::ComputationError`] when the shapes differ.
    fn compare_gradients(
        &self,
        analytical: &Tensor<F>,
        numerical: &Tensor<F>,
    ) -> Result<GradientComparison, StabilityError> {
        if analytical.shape() != numerical.shape() {
            return Err(StabilityError::ComputationError(
                "Gradient shapes do not match".to_string(),
            ));
        }
        let analytical_data = analytical.data();
        let numerical_data = numerical.data();
        let num_elements = analytical_data.len();
        // `GradientComparison::default()` starts with zero errors, an empty
        // element list and `passed: false`, matching the explicit original.
        let mut comparison = GradientComparison::default();
        let mut total_abs_error = 0.0;
        let mut total_rel_error = 0.0;
        for (i, (a, n)) in analytical_data.iter().zip(numerical_data.iter()).enumerate() {
            let analytical_val = a.to_f64().expect("Operation failed");
            let numerical_val = n.to_f64().expect("Operation failed");
            let abs_error = (analytical_val - numerical_val).abs();
            // Fall back to the absolute error when the analytical value is
            // numerically zero, to avoid dividing by ~0.
            let rel_error = if analytical_val.abs() > 1e-15 {
                abs_error / analytical_val.abs()
            } else {
                abs_error
            };
            comparison.max_absolute_error = comparison.max_absolute_error.max(abs_error);
            comparison.max_relative_error = comparison.max_relative_error.max(rel_error);
            total_abs_error += abs_error;
            total_rel_error += rel_error;
            comparison.element_wise_errors.push(ElementWiseError {
                index: i,
                analytical_value: analytical_val,
                numerical_value: numerical_val,
                absolute_error: abs_error,
                relative_error: rel_error,
            });
        }
        // BUG FIX: guard the division so empty gradients no longer produce
        // NaN mean errors (previously 0.0 / 0.0).
        if num_elements > 0 {
            comparison.mean_absolute_error = total_abs_error / num_elements as f64;
            comparison.mean_relative_error = total_rel_error / num_elements as f64;
        }
        comparison.passed = comparison.max_absolute_error < self._config.absolute_tolerance
            && comparison.max_relative_error < self._config.relative_tolerance;
        if self._config.verbose {
            self.print_comparison_details(&comparison);
        }
        Ok(comparison)
    }

    /// Placeholder for Hessian validation.
    ///
    /// NOTE(review): always reports success with zero error — no Hessian is
    /// actually computed yet; confirm before relying on this result.
    fn check_second_order_gradients(
        &self,
        _input: &Tensor<F>,
    ) -> Result<SecondOrderCheck, StabilityError> {
        Ok(SecondOrderCheck {
            hessian_comparison: HessianComparison {
                max_error: 0.0,
                passed: true,
            },
            symmetry_check: SymmetryCheck {
                max_asymmetry: 0.0,
                passed: true,
            },
        })
    }

    /// Placeholder for perturbed test-point generation.
    ///
    /// NOTE(review): currently returns `input` unchanged; the perturbation
    /// scale and seed-derived factor are computed but never applied.
    #[allow(dead_code)]
    fn generate_test_point<'a>(
        &self,
        input: &'a Tensor<'a, F>,
        seed: usize,
    ) -> Result<Tensor<'a, F>, StabilityError> {
        let _perturbation_scale = F::from(1e-6).expect("Failed to convert constant to float");
        let _scale_factor = F::from((seed as f64 * 0.1_f64).sin()).expect("Operation failed");
        Ok(*input)
    }

    /// Placeholder for evaluating the analytical gradient at a point.
    ///
    /// NOTE(review): currently echoes `input` and ignores the function.
    /// (The parameter was previously misnamed `self_function` and unused.)
    #[allow(dead_code)]
    fn compute_analytical_gradient_at_point<'a, Func>(
        _function: &Func,
        input: &'a Tensor<'a, F>,
    ) -> Result<Tensor<'a, F>, StabilityError>
    where
        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
    {
        Ok(*input)
    }

    /// Prints the per-comparison error statistics, listing the elements that
    /// exceed the configured tolerances when the check failed.
    fn print_comparison_details(&self, comparison: &GradientComparison) {
        println!("Gradient Check Details:");
        println!(
            " Max Absolute Error: {:.2e}",
            comparison.max_absolute_error
        );
        println!(
            " Max Relative Error: {:.2e}",
            comparison.max_relative_error
        );
        println!(
            " Mean Absolute Error: {:.2e}",
            comparison.mean_absolute_error
        );
        println!(
            " Mean Relative Error: {:.2e}",
            comparison.mean_relative_error
        );
        println!(" Passed: {}", comparison.passed);
        if !comparison.passed {
            println!(" Failed Elements:");
            for error in &comparison.element_wise_errors {
                if error.absolute_error > self._config.absolute_tolerance
                    || error.relative_error > self._config.relative_tolerance
                {
                    println!(" Index {}: analytical={:.6e}, numerical={:.6e}, abs_err={:.2e}, rel_err={:.2e}",
                        error.index, error.analytical_value, error.numerical_value,
                        error.absolute_error, error.relative_error);
                }
            }
        }
    }
}
impl<F: Float> Default for GradientChecker<F> {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
/// Outcome of a gradient check across one or more test points.
pub struct GradientCheckResult<'a, F: Float> {
/// One entry per evaluated test point.
pub point_results: Vec<SinglePointResult<'a, F>>,
/// True when every tested point passed.
pub overall_passed: bool,
/// Error statistics aggregated over all points.
pub summary_statistics: SummaryStatistics,
}
impl<F: Float> GradientCheckResult<'_, F> {
    /// Builds an empty result: no points, failing status, zeroed statistics.
    fn new() -> Self {
        Self {
            point_results: Vec::new(),
            overall_passed: false,
            summary_statistics: SummaryStatistics::default(),
        }
    }

    /// Aggregates the per-point comparisons into summary statistics and
    /// derives the overall pass flag. A result with no points is left in
    /// its default (failing, all-zero) state.
    fn compute_summary(&mut self) {
        let num_points = self.point_results.len();
        if num_points == 0 {
            return;
        }
        // Re-create the comparison iterator on demand for each statistic.
        let comparisons = || self.point_results.iter().map(|r| &r.comparison);
        let passed_count = comparisons().filter(|c| c.passed).count();
        let sum_max_abs: f64 = comparisons().map(|c| c.max_absolute_error).sum();
        let sum_max_rel: f64 = comparisons().map(|c| c.max_relative_error).sum();
        let worst_abs = comparisons()
            .map(|c| c.max_absolute_error)
            .fold(0.0, f64::max);
        let worst_rel = comparisons()
            .map(|c| c.max_relative_error)
            .fold(0.0, f64::max);
        self.summary_statistics = SummaryStatistics {
            mean_max_absolute_error: sum_max_abs / num_points as f64,
            mean_max_relative_error: sum_max_rel / num_points as f64,
            pass_rate: passed_count as f64 / num_points as f64,
            worst_case_absolute_error: worst_abs,
            worst_case_relative_error: worst_rel,
        };
        self.overall_passed = passed_count == num_points;
    }

    /// Prints a human-readable summary of the whole check to stdout.
    pub fn print_summary(&self) {
        println!("Gradient Check Summary:");
        println!(" Overall Passed: {}", self.overall_passed);
        println!(" Points Tested: {}", self.point_results.len());
        println!(
            " Pass Rate: {:.1}%",
            self.summary_statistics.pass_rate * 100.0
        );
        println!(
            " Mean Max Absolute Error: {:.2e}",
            self.summary_statistics.mean_max_absolute_error
        );
        println!(
            " Mean Max Relative Error: {:.2e}",
            self.summary_statistics.mean_max_relative_error
        );
        println!(
            " Worst Case Absolute Error: {:.2e}",
            self.summary_statistics.worst_case_absolute_error
        );
        println!(
            " Worst Case Relative Error: {:.2e}",
            self.summary_statistics.worst_case_relative_error
        );
    }
}
#[derive(Debug, Clone)]
/// Result of comparing gradients at one test point.
pub struct SinglePointResult<'a, F: Float> {
/// The gradient supplied by the caller.
pub analytical_gradient: Tensor<'a, F>,
/// The finite-difference reference gradient.
pub numerical_gradient: Tensor<'a, F>,
/// Element-wise and aggregate error comparison.
pub comparison: GradientComparison,
/// Present only when the second-order check was requested.
pub second_order_check: Option<SecondOrderCheck>,
}
#[derive(Debug, Clone, Default)]
/// Error statistics from comparing an analytical gradient with a
/// numerical one.
pub struct GradientComparison {
/// Largest element-wise absolute error.
pub max_absolute_error: f64,
/// Largest element-wise relative error.
pub max_relative_error: f64,
/// Mean of the element-wise absolute errors.
pub mean_absolute_error: f64,
/// Mean of the element-wise relative errors.
pub mean_relative_error: f64,
/// Detailed per-element errors, one entry per gradient element.
pub element_wise_errors: Vec<ElementWiseError>,
/// True when both max errors are below their configured tolerances.
pub passed: bool,
}
#[derive(Debug, Clone)]
/// Error record for a single gradient element.
pub struct ElementWiseError {
/// Flat index of the element within the gradient.
pub index: usize,
/// Analytical gradient value at this element.
pub analytical_value: f64,
/// Numerical (finite-difference) value at this element.
pub numerical_value: f64,
/// |analytical - numerical|.
pub absolute_error: f64,
/// Absolute error scaled by |analytical| (or the absolute error itself
/// when the analytical value is ~0).
pub relative_error: f64,
}
#[derive(Debug, Clone, Default)]
/// Statistics aggregated over all tested points.
pub struct SummaryStatistics {
/// Mean of the per-point maximum absolute errors.
pub mean_max_absolute_error: f64,
/// Mean of the per-point maximum relative errors.
pub mean_max_relative_error: f64,
/// Fraction of points that passed, in [0, 1].
pub pass_rate: f64,
/// Largest maximum absolute error over all points.
pub worst_case_absolute_error: f64,
/// Largest maximum relative error over all points.
pub worst_case_relative_error: f64,
}
#[derive(Debug, Clone)]
/// Result of the second-order (Hessian) validation.
pub struct SecondOrderCheck {
/// Hessian-vs-numerical comparison outcome.
pub hessian_comparison: HessianComparison,
/// Hessian symmetry validation outcome.
pub symmetry_check: SymmetryCheck,
}
#[derive(Debug, Clone)]
/// Outcome of comparing an analytical Hessian against a reference.
pub struct HessianComparison {
/// Largest observed error.
pub max_error: f64,
/// True when the error is within tolerance.
pub passed: bool,
}
#[derive(Debug, Clone)]
/// Outcome of checking Hessian symmetry (H ≈ Hᵀ).
pub struct SymmetryCheck {
/// Largest observed asymmetry.
pub max_asymmetry: f64,
/// True when the asymmetry is within tolerance.
pub passed: bool,
}
/// Gradient checker for vector-valued functions (Jacobian validation).
pub struct VectorFunctionChecker<F: Float> {
#[allow(dead_code)]
// Scalar checker reused per output component; not yet exercised by the
// placeholder `check_jacobian` implementation.
base_checker: GradientChecker<F>,
}
impl<F: Float> Default for VectorFunctionChecker<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float> VectorFunctionChecker<F> {
    /// Creates a checker backed by a default [`GradientChecker`].
    pub fn new() -> Self {
        Self {
            base_checker: GradientChecker::new(),
        }
    }

    /// Checks `analytical_jacobian` against `function`, one component
    /// result per output row.
    ///
    /// NOTE(review): placeholder — every component result is marked as
    /// passed without evaluating `function` at `input`; confirm before
    /// relying on the outcome.
    pub fn check_jacobian<'a, Func>(
        &self,
        function: Func,
        input: &'a Tensor<F>,
        analytical_jacobian: &'a Array<F, IxDyn>,
    ) -> Result<JacobianCheckResult<'a, F>, StabilityError>
    where
        Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
    {
        // Silence unused-argument warnings until the real check is wired in.
        let _ = (&function, input);
        let output_dims = analytical_jacobian.shape()[0];
        let component_results: Vec<_> = (0..output_dims)
            .map(|_| {
                let mut component = GradientCheckResult::new();
                component.overall_passed = true;
                component
            })
            .collect();
        let overall_passed = component_results.iter().all(|r| r.overall_passed);
        Ok(JacobianCheckResult {
            component_results,
            overall_passed,
        })
    }

    /// Materialises one row of `jacobian` as a tensor on `graph`.
    ///
    /// NOTE(review): placeholder — returns a zero-filled row regardless of
    /// `_row`.
    #[allow(dead_code)]
    fn extract_jacobian_row<'a>(
        &self,
        jacobian: &Array<F, IxDyn>,
        _row: usize,
        graph: &'a Graph<F>,
    ) -> Result<Tensor<'a, F>, StabilityError> {
        let width = jacobian.shape()[1];
        Ok(Tensor::from_vec(vec![F::zero(); width], vec![width], graph))
    }
}
#[derive(Debug, Clone)]
/// Result of a Jacobian check: one gradient-check result per output row.
pub struct JacobianCheckResult<'a, F: Float> {
/// Per-output-component results.
pub component_results: Vec<GradientCheckResult<'a, F>>,
/// True when every component passed.
pub overall_passed: bool,
}
/// Gradient checker for named model parameters against a loss function.
pub struct ParameterGradientChecker<F: Float> {
#[allow(dead_code)]
// Scalar checker intended for per-parameter checks; not yet exercised by
// the placeholder `check_parameter_gradients` implementation.
base_checker: GradientChecker<F>,
}
impl<F: Float> Default for ParameterGradientChecker<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float> ParameterGradientChecker<F> {
    /// Creates a checker backed by a default [`GradientChecker`].
    pub fn new() -> Self {
        Self {
            base_checker: GradientChecker::new(),
        }
    }

    /// Checks the analytical gradient of every parameter that has an entry
    /// in `analytical_gradients`.
    ///
    /// NOTE(review): placeholder — `loss_function` is never evaluated and
    /// every matched parameter is marked as passed; confirm before relying
    /// on the outcome.
    pub fn check_parameter_gradients<'a, Func>(
        &self,
        loss_function: Func,
        parameters: &'a HashMap<String, Tensor<'a, F>>,
        analytical_gradients: &'a HashMap<String, Tensor<'a, F>>,
    ) -> Result<ParameterCheckResult<'a, F>, StabilityError>
    where
        Func:
            for<'b> Fn(&'b HashMap<String, Tensor<'b, F>>) -> Result<Tensor<'b, F>, StabilityError>,
    {
        // Silence the unused-argument warning until the loss is evaluated.
        let _ = &loss_function;
        let parameter_results: HashMap<_, _> = parameters
            .keys()
            .filter(|name| analytical_gradients.contains_key(*name))
            .map(|name| {
                let mut entry = GradientCheckResult::new();
                entry.overall_passed = true;
                (name.clone(), entry)
            })
            .collect();
        let overall_passed = parameter_results.values().all(|r| r.overall_passed);
        Ok(ParameterCheckResult {
            parameter_results,
            overall_passed,
        })
    }
}
#[derive(Debug, Clone)]
/// Result of checking gradients for a set of named parameters.
pub struct ParameterCheckResult<'a, F: Float> {
/// Per-parameter results, keyed by parameter name.
pub parameter_results: HashMap<String, GradientCheckResult<'a, F>>,
/// True when every parameter passed.
pub overall_passed: bool,
}
impl<F: Float> ParameterCheckResult<'_, F> {
    /// Prints a human-readable per-parameter summary to stdout; failing
    /// parameters additionally show their pass rate and worst error.
    pub fn print_summary(&self) {
        println!("Parameter Gradient Check Summary:");
        println!(" Overall Passed: {}", self.overall_passed);
        println!(" Parameters Checked: {}", self.parameter_results.len());
        for (param_name, result) in &self.parameter_results {
            let status = if result.overall_passed {
                "PASSED"
            } else {
                "FAILED"
            };
            println!(" {}: {}", param_name, status);
            if !result.overall_passed {
                println!(
                    " Pass Rate: {:.1}%",
                    result.summary_statistics.pass_rate * 100.0
                );
                println!(
                    " Max Error: {:.2e}",
                    result.summary_statistics.worst_case_absolute_error
                );
            }
        }
    }
}
/// Convenience wrapper: runs a gradient check with default settings and
/// reports only the overall pass/fail outcome.
///
/// # Errors
/// Propagates any error produced by the underlying check.
#[allow(dead_code)]
pub fn check_gradient<F: Float, Func>(
    function: Func,
    input: &Tensor<F>,
    analytical_gradient: &Tensor<F>,
) -> Result<bool, StabilityError>
where
    Func: for<'a> Fn(&Tensor<'a, F>) -> Result<Tensor<'a, F>, StabilityError>,
{
    let checker = GradientChecker::new();
    let outcome = checker.check_scalar_function(function, input, analytical_gradient)?;
    Ok(outcome.overall_passed)
}
#[allow(dead_code)]
/// Intended entry point for a fully configurable gradient check.
///
/// NOTE(review): placeholder — every argument is ignored and an empty
/// result marked as passed is returned. A real implementation would
/// delegate to `GradientChecker::with_config(_config)`, but the `&'a self`
/// receiver on `check_scalar_function` ties the result's lifetime to a
/// local checker, which blocks that delegation as written. Do not rely on
/// this function's result until that is resolved.
pub fn comprehensive_gradient_check<'a, F: Float, Func>(
_function: Func,
input: &'a Tensor<'a, F>,
_analytical_gradient: &'a Tensor<'a, F>,
_config: GradientCheckConfig,
) -> Result<GradientCheckResult<'a, F>, StabilityError>
where
Func: for<'b> Fn(&Tensor<'b, F>) -> Result<Tensor<'b, F>, StabilityError>,
{
// Fabricated success: no check is performed.
let mut result = GradientCheckResult::new();
result.overall_passed = true;
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
// Struct-update syntax overrides only the named fields; the rest come
// from `Default`.
#[test]
fn test_gradient_check_config() {
let config = GradientCheckConfig {
relative_tolerance: 1e-6,
check_multiple_points: false,
verbose: true,
..Default::default()
};
assert_eq!(config.relative_tolerance, 1e-6);
assert!(!config.check_multiple_points);
assert!(config.verbose);
}
// Smoke test: both constructors build without panicking.
#[test]
fn test_gradient_checker_creation() {
let _checker = GradientChecker::<f32>::new();
let config = GradientCheckConfig::default();
let _checker_with_config = GradientChecker::<f32>::with_config(config);
}
// An empty result starts failing, and summarising it is a no-op.
#[test]
fn test_gradient_check_result() {
let mut result: GradientCheckResult<f64> = GradientCheckResult::new();
assert!(!result.overall_passed);
assert_eq!(result.point_results.len(), 0);
result.compute_summary();
assert_eq!(result.summary_statistics.pass_rate, 0.0);
}
// Smoke test: Jacobian checker constructs.
#[test]
fn test_vector_function_checker() {
let _checker = VectorFunctionChecker::<f32>::new();
}
// Smoke test: parameter checker constructs.
#[test]
fn test_parameter_gradient_checker() {
let _checker = ParameterGradientChecker::<f32>::new();
}
}