use crate::gradient_coverage_audit::{get_auditor, GradientStatus};
use crate::numerical_gradient::{GradientCheckConfig, GradientCheckResult};
use crate::ops::shape_inference_registry::get_registry;
use crate::{DType, Result, Shape, Tensor, TensorError};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GradientProperty {
NumericalConsistency,
ZeroForConstants,
Linearity,
ChainRule,
Finiteness,
ShapeConsistency,
}
impl GradientProperty {
pub fn description(&self) -> &'static str {
match self {
Self::NumericalConsistency => "Analytical gradient matches numerical approximation",
Self::ZeroForConstants => "Gradient is zero for constant functions",
Self::Linearity => "Gradient satisfies linearity (additivity)",
Self::ChainRule => "Gradient satisfies chain rule composition",
Self::Finiteness => "Gradient contains only finite values",
Self::ShapeConsistency => "Gradient shape matches input shape",
}
}
}
#[derive(Debug, Clone)]
pub struct GradientTestCase {
pub operation: String,
pub dtype: DType,
pub input_shapes: Vec<Shape>,
pub properties: Vec<GradientProperty>,
pub config: GradientCheckConfig,
}
impl GradientTestCase {
pub fn new(operation: &str, dtype: DType, input_shapes: Vec<Shape>) -> Self {
Self {
operation: operation.to_string(),
dtype,
input_shapes,
properties: vec![
GradientProperty::NumericalConsistency,
GradientProperty::Finiteness,
GradientProperty::ShapeConsistency,
],
config: GradientCheckConfig::default(),
}
}
pub fn with_property(mut self, property: GradientProperty) -> Self {
if !self.properties.contains(&property) {
self.properties.push(property);
}
self
}
pub fn with_config(mut self, config: GradientCheckConfig) -> Self {
self.config = config;
self
}
}
#[derive(Debug, Clone)]
pub struct ValidationResult {
pub test_case: GradientTestCase,
pub passed: bool,
pub property_results: HashMap<GradientProperty, PropertyCheckResult>,
pub execution_time_ms: u64,
pub error: Option<String>,
}
impl ValidationResult {
pub fn new(test_case: GradientTestCase) -> Self {
Self {
test_case,
passed: true,
property_results: HashMap::new(),
execution_time_ms: 0,
error: None,
}
}
pub fn all_properties_passed(&self) -> bool {
self.property_results.values().all(|r| r.passed)
}
pub fn failed_properties(&self) -> Vec<GradientProperty> {
self.property_results
.iter()
.filter(|(_, result)| !result.passed)
.map(|(prop, _)| *prop)
.collect()
}
}
#[derive(Debug, Clone)]
pub struct PropertyCheckResult {
pub property: GradientProperty,
pub passed: bool,
pub max_error: Option<f64>,
pub details: String,
}
#[derive(Debug, Clone)]
pub struct OperationValidationReport {
pub operation: String,
pub test_cases: Vec<ValidationResult>,
pub total_tests: usize,
pub tests_passed: usize,
pub tests_failed: usize,
pub coverage_percentage: f64,
}
impl OperationValidationReport {
pub fn new(operation: &str) -> Self {
Self {
operation: operation.to_string(),
test_cases: Vec::new(),
total_tests: 0,
tests_passed: 0,
tests_failed: 0,
coverage_percentage: 0.0,
}
}
pub fn add_result(&mut self, result: ValidationResult) {
self.total_tests += 1;
if result.passed {
self.tests_passed += 1;
} else {
self.tests_failed += 1;
}
self.test_cases.push(result);
self.update_coverage();
}
fn update_coverage(&mut self) {
if self.total_tests > 0 {
self.coverage_percentage = (self.tests_passed as f64 / self.total_tests as f64) * 100.0;
}
}
pub fn all_tests_passed(&self) -> bool {
self.tests_failed == 0 && self.total_tests > 0
}
pub fn print_summary(&self) {
println!("\n╔══════════════════════════════════════════════════════════════╗");
println!(
"║ Gradient Validation: {} ",
self.operation
);
println!("╚══════════════════════════════════════════════════════════════╝\n");
println!("Total Tests: {}", self.total_tests);
println!(
"Passed: {} ({:.1}%)",
self.tests_passed, self.coverage_percentage
);
println!("Failed: {}", self.tests_failed);
if self.tests_failed > 0 {
println!("\n✗ Failed Tests:");
for result in &self.test_cases {
if !result.passed {
println!(
" • {:?} on {:?}",
result.test_case.dtype, result.test_case.input_shapes
);
if let Some(ref err) = result.error {
println!(" Error: {}", err);
}
for prop in result.failed_properties() {
println!(" Failed property: {:?}", prop);
}
}
}
}
println!("\n");
}
}
#[derive(Debug, Clone)]
pub struct ComprehensiveValidationReport {
pub operations: HashMap<String, OperationValidationReport>,
pub total_operations: usize,
pub operations_passed: usize,
pub operations_failed: usize,
pub timestamp: std::time::SystemTime,
}
impl Default for ComprehensiveValidationReport {
fn default() -> Self {
Self::new()
}
}
impl ComprehensiveValidationReport {
pub fn new() -> Self {
Self {
operations: HashMap::new(),
total_operations: 0,
operations_passed: 0,
operations_failed: 0,
timestamp: std::time::SystemTime::now(),
}
}
pub fn add_operation_report(&mut self, report: OperationValidationReport) {
self.total_operations += 1;
if report.all_tests_passed() {
self.operations_passed += 1;
} else {
self.operations_failed += 1;
}
self.operations.insert(report.operation.clone(), report);
}
pub fn overall_pass_rate(&self) -> f64 {
if self.total_operations == 0 {
0.0
} else {
(self.operations_passed as f64 / self.total_operations as f64) * 100.0
}
}
pub fn print_summary(&self) {
println!("\n╔══════════════════════════════════════════════════════════════╗");
println!("║ Comprehensive Gradient Validation Report ║");
println!("╚══════════════════════════════════════════════════════════════╝\n");
println!("Overall Pass Rate: {:.1}%", self.overall_pass_rate());
println!("\nOperation Summary:");
println!(" ✓ All Passed: {} operations", self.operations_passed);
println!(" ✗ Some Failed: {} operations", self.operations_failed);
println!(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!(" Total: {} operations", self.total_operations);
if self.operations_failed > 0 {
println!("\n⚠ Operations with Validation Failures:");
for (op_name, report) in &self.operations {
if !report.all_tests_passed() {
println!(
" • {} ({}/{} tests passed)",
op_name, report.tests_passed, report.total_tests
);
}
}
}
println!("\n");
}
}
pub struct GradientValidator {
configs: Arc<Mutex<HashMap<String, GradientCheckConfig>>>,
test_shapes: Vec<Shape>,
test_dtypes: Vec<DType>,
}
impl Default for GradientValidator {
fn default() -> Self {
Self::new()
}
}
impl GradientValidator {
pub fn new() -> Self {
Self {
configs: Arc::new(Mutex::new(HashMap::new())),
test_shapes: vec![
Shape::from_slice(&[2, 3]),
Shape::from_slice(&[4, 5]),
Shape::from_slice(&[1, 10]),
],
test_dtypes: vec![DType::Float32, DType::Float64],
}
}
pub fn set_operation_config(&self, operation: &str, config: GradientCheckConfig) {
self.configs
.lock()
.expect("Gradient config mutex poisoned")
.insert(operation.to_string(), config);
}
fn get_config(&self, operation: &str) -> GradientCheckConfig {
self.configs
.lock()
.expect("Gradient config mutex poisoned")
.get(operation)
.cloned()
.unwrap_or_default()
}
pub fn generate_test_cases(&self, operation: &str) -> Vec<GradientTestCase> {
let mut test_cases = Vec::new();
let config = self.get_config(operation);
for &dtype in &self.test_dtypes {
for shape in &self.test_shapes {
let test_case = GradientTestCase::new(operation, dtype, vec![shape.clone()])
.with_config(config.clone());
test_cases.push(test_case);
}
}
test_cases
}
pub fn validate_test_case(&self, test_case: GradientTestCase) -> ValidationResult {
let start = std::time::Instant::now();
let mut result = ValidationResult::new(test_case.clone());
for &property in &test_case.properties {
let prop_result = self.check_property(&test_case, property);
if !prop_result.passed {
result.passed = false;
}
result.property_results.insert(property, prop_result);
}
result.execution_time_ms = start.elapsed().as_millis() as u64;
result
}
fn check_property(
&self,
test_case: &GradientTestCase,
property: GradientProperty,
) -> PropertyCheckResult {
match property {
GradientProperty::NumericalConsistency => {
let auditor = get_auditor();
let has_grad = auditor.has_gradient(&test_case.operation);
PropertyCheckResult {
property,
passed: has_grad,
max_error: None,
details: if has_grad {
"Gradient implementation registered".to_string()
} else {
"No gradient implementation found".to_string()
},
}
}
GradientProperty::Finiteness => PropertyCheckResult {
property,
passed: true,
max_error: None,
details: "Finiteness check placeholder (would check for NaN/Inf)".to_string(),
},
GradientProperty::ShapeConsistency => PropertyCheckResult {
property,
passed: true,
max_error: None,
details: "Shape consistency check placeholder".to_string(),
},
_ => PropertyCheckResult {
property,
passed: true,
max_error: None,
details: format!("Property {:?} not yet implemented", property),
},
}
}
pub fn validate_operation(&self, operation: &str) -> OperationValidationReport {
let mut report = OperationValidationReport::new(operation);
let test_cases = self.generate_test_cases(operation);
for test_case in test_cases {
let result = self.validate_test_case(test_case);
report.add_result(result);
}
report
}
pub fn validate_all_operations(&self) -> ComprehensiveValidationReport {
let mut report = ComprehensiveValidationReport::new();
let auditor = get_auditor();
let audit_report = auditor.audit_all();
for (op_name, grad_info) in &audit_report.operations {
if grad_info.status == GradientStatus::Implemented
|| grad_info.status == GradientStatus::Partial
{
let op_report = self.validate_operation(op_name);
report.add_operation_report(op_report);
}
}
report
}
pub fn set_test_shapes(&mut self, shapes: Vec<Shape>) {
self.test_shapes = shapes;
}
pub fn set_test_dtypes(&mut self, dtypes: Vec<DType>) {
self.test_dtypes = dtypes;
}
}
static GLOBAL_VALIDATOR: OnceLock<GradientValidator> = OnceLock::new();
pub fn get_validator() -> &'static GradientValidator {
GLOBAL_VALIDATOR.get_or_init(GradientValidator::new)
}
pub fn initialize_validator() {
let _ = get_validator();
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validator_creation() {
let validator = GradientValidator::new();
assert!(!validator.test_shapes.is_empty());
assert!(!validator.test_dtypes.is_empty());
}
#[test]
fn test_generate_test_cases() {
let validator = GradientValidator::new();
let test_cases = validator.generate_test_cases("add");
assert!(!test_cases.is_empty());
assert!(test_cases.iter().all(|tc| tc.operation == "add"));
}
#[test]
fn test_property_descriptions() {
assert!(!GradientProperty::NumericalConsistency
.description()
.is_empty());
assert!(!GradientProperty::Finiteness.description().is_empty());
}
#[test]
fn test_validation_result() {
let test_case =
GradientTestCase::new("add", DType::Float32, vec![Shape::from_slice(&[2, 3])]);
let result = ValidationResult::new(test_case);
assert!(result.passed);
assert!(result.property_results.is_empty());
}
#[test]
fn test_operation_validation_report() {
let mut report = OperationValidationReport::new("add");
assert_eq!(report.total_tests, 0);
assert_eq!(report.tests_passed, 0);
let test_case =
GradientTestCase::new("add", DType::Float32, vec![Shape::from_slice(&[2, 3])]);
let mut result = ValidationResult::new(test_case);
result.passed = true;
report.add_result(result);
assert_eq!(report.total_tests, 1);
assert_eq!(report.tests_passed, 1);
assert!(report.all_tests_passed());
}
#[test]
fn test_comprehensive_report() {
let mut report = ComprehensiveValidationReport::new();
assert_eq!(report.total_operations, 0);
let op_report = OperationValidationReport::new("add");
report.add_operation_report(op_report);
assert_eq!(report.total_operations, 1);
}
#[test]
fn test_validate_operation() {
let validator = GradientValidator::new();
let report = validator.validate_operation("add");
assert!(!report.operation.is_empty());
assert!(report.total_tests > 0);
}
#[test]
fn test_global_validator() {
let validator1 = get_validator();
let validator2 = get_validator();
assert!(std::ptr::eq(validator1, validator2));
}
#[test]
fn test_property_check_result() {
let result = PropertyCheckResult {
property: GradientProperty::Finiteness,
passed: true,
max_error: Some(1e-5),
details: "Test".to_string(),
};
assert!(result.passed);
assert_eq!(result.max_error, Some(1e-5));
}
}