use crate::ops::shape_inference_registry::{get_registry, OperationCategory};
use crate::{DType, Result, Shape, Tensor, TensorError};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex, OnceLock};
/// Gradient implementation status for a single tensor operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GradientStatus {
    /// A gradient implementation exists and covers all tested dtypes.
    Implemented,
    /// A gradient implementation exists but covers only some tested dtypes.
    Partial,
    /// No gradient implementation was found.
    Missing,
    /// The operation is inherently non-differentiable (e.g. comparisons).
    NotApplicable,
}
/// Per-operation gradient coverage record produced by the auditor.
#[derive(Debug, Clone)]
pub struct OperationGradientInfo {
    /// Name of the audited operation (e.g. "add", "matmul").
    pub operation: String,
    /// Category taken from the shape-inference registry taxonomy.
    pub category: OperationCategory,
    /// Overall gradient status for this operation.
    pub status: GradientStatus,
    /// DTypes for which the gradient is recorded as working.
    pub supported_dtypes: Vec<DType>,
    /// DTypes with no working gradient.
    pub missing_dtypes: Vec<DType>,
    /// Shapes recorded as passing gradient tests.
    pub passing_shapes: Vec<Shape>,
    /// Shapes recorded as failing gradient tests.
    pub failing_shapes: Vec<Shape>,
    /// Free-form diagnostic notes.
    pub notes: Vec<String>,
}
impl OperationGradientInfo {
    /// Creates a fresh record for `operation` with status `Missing` and no
    /// recorded dtype or shape results.
    pub fn new(operation: &str, category: OperationCategory) -> Self {
        Self {
            operation: operation.to_string(),
            category,
            status: GradientStatus::Missing,
            supported_dtypes: vec![],
            missing_dtypes: vec![],
            passing_shapes: vec![],
            failing_shapes: vec![],
            notes: vec![],
        }
    }

    /// Percentage of tested dtypes with a working gradient.
    ///
    /// Non-applicable (non-differentiable) operations count as 100%, and an
    /// operation with no recorded dtypes at all counts as 0%.
    pub fn coverage_percentage(&self) -> f64 {
        if matches!(self.status, GradientStatus::NotApplicable) {
            return 100.0;
        }
        let supported = self.supported_dtypes.len();
        match supported + self.missing_dtypes.len() {
            0 => 0.0,
            total => (supported as f64 / total as f64) * 100.0,
        }
    }

    /// True when no further gradient work is needed for this operation
    /// (either fully implemented or not applicable).
    pub fn is_complete(&self) -> bool {
        match self.status {
            GradientStatus::Implemented | GradientStatus::NotApplicable => true,
            GradientStatus::Partial | GradientStatus::Missing => false,
        }
    }
}
/// Aggregate result of auditing gradient coverage across all registered
/// operations.
#[derive(Debug, Clone)]
pub struct GradientCoverageReport {
    /// Per-operation detail, keyed by operation name.
    pub operations: HashMap<String, OperationGradientInfo>,
    /// When the report was generated.
    pub timestamp: std::time::SystemTime,
    /// Total number of operations audited.
    pub total_operations: usize,
    /// Operations with status `Implemented`.
    pub complete_operations: usize,
    /// Operations with status `Partial`.
    pub partial_operations: usize,
    /// Operations with status `Missing`.
    pub missing_operations: usize,
    /// Operations with status `NotApplicable`.
    pub not_applicable_operations: usize,
}
impl GradientCoverageReport {
    /// Creates an empty report stamped with the current system time.
    pub fn new() -> Self {
        Self {
            operations: HashMap::new(),
            timestamp: std::time::SystemTime::now(),
            total_operations: 0,
            complete_operations: 0,
            partial_operations: 0,
            missing_operations: 0,
            not_applicable_operations: 0,
        }
    }

    /// Overall gradient coverage in percent.
    ///
    /// Not-applicable operations are excluded from the denominator; partial
    /// operations count as half credit. Returns 100.0 when there are no
    /// relevant (differentiable) operations.
    pub fn overall_coverage(&self) -> f64 {
        // saturating_sub guards against inconsistent counters underflowing a
        // usize subtraction.
        let relevant_ops = self
            .total_operations
            .saturating_sub(self.not_applicable_operations);
        if relevant_ops == 0 {
            return 100.0;
        }
        // BUG FIX: the previous integer division (partial_operations / 2)
        // discarded the half-credit for an odd count of partial operations
        // (e.g. a single partial op contributed 0). Compute in floating point.
        let covered = self.complete_operations as f64 + self.partial_operations as f64 / 2.0;
        (covered / relevant_ops as f64) * 100.0
    }

    /// Names of all operations whose recorded status equals `status`.
    pub fn operations_by_status(&self, status: GradientStatus) -> Vec<String> {
        self.operations
            .values()
            .filter(|info| info.status == status)
            .map(|info| info.operation.clone())
            .collect()
    }

    /// Names of all operations belonging to `category`.
    pub fn operations_by_category(&self, category: OperationCategory) -> Vec<String> {
        self.operations
            .values()
            .filter(|info| info.category == category)
            .map(|info| info.operation.clone())
            .collect()
    }

    /// Prints a human-readable summary to stdout: overall coverage, a status
    /// breakdown, and the lists of missing / partially-covered operations.
    pub fn print_summary(&self) {
        println!("\n╔══════════════════════════════════════════════════════════════╗");
        println!("║ Gradient Coverage Audit Summary ║");
        println!("╚══════════════════════════════════════════════════════════════╝\n");
        println!("Overall Coverage: {:.1}%", self.overall_coverage());
        println!("\nStatus Breakdown:");
        println!(
            " ✓ Complete: {} operations",
            self.complete_operations
        );
        println!(" ⚠ Partial: {} operations", self.partial_operations);
        println!(" ✗ Missing: {} operations", self.missing_operations);
        println!(
            " ○ Not Applicable: {} operations",
            self.not_applicable_operations
        );
        println!(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
        println!(" Total: {} operations", self.total_operations);
        let missing = self.operations_by_status(GradientStatus::Missing);
        if !missing.is_empty() {
            println!("\n⚠ Operations Missing Gradients:");
            for op in &missing {
                println!(" • {}", op);
            }
        }
        let partial = self.operations_by_status(GradientStatus::Partial);
        if !partial.is_empty() {
            println!("\n⚡ Operations with Partial Gradient Support:");
            for op in &partial {
                if let Some(info) = self.operations.get(op) {
                    println!(" • {} ({:.0}% coverage)", op, info.coverage_percentage());
                    if !info.missing_dtypes.is_empty() {
                        println!(" Missing dtypes: {:?}", info.missing_dtypes);
                    }
                }
            }
        }
        println!("\n");
    }

    /// Prints the full coverage detail for a single operation, or a
    /// not-found message when `operation` is absent from the report.
    pub fn print_operation_detail(&self, operation: &str) {
        if let Some(info) = self.operations.get(operation) {
            println!("\n╔══════════════════════════════════════════════════════════════╗");
            println!(
                "║ Gradient Coverage: {} ",
                operation
            );
            println!("╚══════════════════════════════════════════════════════════════╝\n");
            println!("Category: {:?}", info.category);
            println!("Status: {:?}", info.status);
            println!("Coverage: {:.1}%", info.coverage_percentage());
            if !info.supported_dtypes.is_empty() {
                println!("\n✓ Supported DTypes:");
                for dtype in &info.supported_dtypes {
                    println!(" • {:?}", dtype);
                }
            }
            if !info.missing_dtypes.is_empty() {
                println!("\n✗ Missing DTypes:");
                for dtype in &info.missing_dtypes {
                    println!(" • {:?}", dtype);
                }
            }
            if !info.passing_shapes.is_empty() {
                println!("\n✓ Passing Test Shapes:");
                for shape in &info.passing_shapes {
                    println!(" • {:?}", shape.dims());
                }
            }
            if !info.failing_shapes.is_empty() {
                println!("\n✗ Failing Test Shapes:");
                for shape in &info.failing_shapes {
                    println!(" • {:?}", shape.dims());
                }
            }
            if !info.notes.is_empty() {
                println!("\nNotes:");
                for note in &info.notes {
                    println!(" • {}", note);
                }
            }
            println!("\n");
        } else {
            println!("Operation '{}' not found in coverage report", operation);
        }
    }

    /// Serializes the top-level counters (not the per-operation detail) as a
    /// JSON object string.
    ///
    /// NOTE(review): the timestamp is emitted via `Debug` formatting, which is
    /// not a standard timestamp format — confirm downstream consumers accept it.
    pub fn to_json(&self) -> String {
        format!(
            r#"{{
"timestamp": "{:?}",
"total_operations": {},
"complete_operations": {},
"partial_operations": {},
"missing_operations": {},
"not_applicable_operations": {},
"overall_coverage": {:.2}
}}"#,
            self.timestamp,
            self.total_operations,
            self.complete_operations,
            self.partial_operations,
            self.missing_operations,
            self.not_applicable_operations,
            self.overall_coverage()
        )
    }
}
impl Default for GradientCoverageReport {
fn default() -> Self {
Self::new()
}
}
/// Configuration controlling which dtypes/shapes the auditor records for
/// each operation.
#[derive(Debug, Clone)]
pub struct GradientTestConfig {
    /// DTypes to test gradients against.
    pub test_dtypes: Vec<DType>,
    /// Representative tensor shapes to test gradients with.
    pub test_shapes: Vec<Shape>,
    /// Whether to run numerical gradient checking.
    pub check_numerical: bool,
    /// Tolerance used when `check_numerical` is enabled.
    pub numerical_tolerance: f64,
}
impl Default for GradientTestConfig {
fn default() -> Self {
Self {
test_dtypes: vec![DType::Float32, DType::Float64],
test_shapes: vec![
Shape::from_slice(&[2, 3]),
Shape::from_slice(&[4, 5, 6]),
Shape::from_slice(&[1, 10]),
],
check_numerical: false, numerical_tolerance: 1e-4,
}
}
}
/// Audits gradient coverage for the operations in the shape-inference
/// registry, tracking which operations have gradients and which are
/// inherently non-differentiable.
pub struct GradientCoverageAuditor {
    // Names of operations with a registered gradient implementation.
    // Arc<Mutex<...>> allows registration through a shared reference
    // (e.g. via the global auditor).
    gradient_ops: Arc<Mutex<HashSet<String>>>,
    // Names of operations that have no meaningful gradient.
    non_differentiable_ops: Arc<Mutex<HashSet<String>>>,
    // DTypes/shapes the auditor records per audited operation.
    config: GradientTestConfig,
}
impl Default for GradientCoverageAuditor {
fn default() -> Self {
Self::new()
}
}
impl GradientCoverageAuditor {
    /// Creates an auditor pre-seeded with the operations known to have
    /// gradient implementations and those that are inherently
    /// non-differentiable.
    pub fn new() -> Self {
        let auditor = Self {
            gradient_ops: Arc::new(Mutex::new(HashSet::new())),
            non_differentiable_ops: Arc::new(Mutex::new(HashSet::new())),
            config: GradientTestConfig::default(),
        };
        auditor.initialize_known_gradients();
        auditor
    }

    /// Seeds the known-gradient and non-differentiable operation sets.
    ///
    /// Takes `&self`: all mutation goes through the interior `Mutex`es, so
    /// an exclusive borrow is unnecessary (the previous `&mut self`
    /// signature forced `new` to hold a needlessly mutable binding).
    fn initialize_known_gradients(&self) {
        // Operations with known gradient implementations.
        const GRADIENT_OPS: &[&str] = &[
            "add", "sub", "mul", "div", "pow", "neg", "abs", "exp", "log", "sqrt", "sin",
            "cos", "tan", "tanh", "relu", "sigmoid", "gelu", "matmul", "dot", "sum", "mean",
            "max", "min", "reshape", "transpose", "permute",
        ];
        // Comparison/boolean operations with no meaningful gradient.
        const NON_DIFFERENTIABLE_OPS: &[&str] =
            &["eq", "ne", "gt", "ge", "lt", "le", "and", "or", "not", "xor"];
        self.gradient_ops
            .lock()
            .expect("gradient ops lock should not be poisoned")
            .extend(GRADIENT_OPS.iter().map(|op| op.to_string()));
        self.non_differentiable_ops
            .lock()
            .expect("non-differentiable ops lock should not be poisoned")
            .extend(NON_DIFFERENTIABLE_OPS.iter().map(|op| op.to_string()));
    }

    /// Registers `operation` as having a gradient implementation.
    pub fn register_gradient(&self, operation: &str) {
        self.gradient_ops
            .lock()
            .expect("gradient ops lock should not be poisoned")
            .insert(operation.to_string());
    }

    /// Marks `operation` as inherently non-differentiable.
    pub fn register_non_differentiable(&self, operation: &str) {
        self.non_differentiable_ops
            .lock()
            .expect("non-differentiable ops lock should not be poisoned")
            .insert(operation.to_string());
    }

    /// Returns true if a gradient implementation is registered for `operation`.
    pub fn has_gradient(&self, operation: &str) -> bool {
        self.gradient_ops
            .lock()
            // Consistency fix: use the same descriptive message as the other
            // accessors instead of the generic "lock should not be poisoned".
            .expect("gradient ops lock should not be poisoned")
            .contains(operation)
    }

    /// Returns true if `operation` is registered as non-differentiable.
    pub fn is_non_differentiable(&self, operation: &str) -> bool {
        self.non_differentiable_ops
            .lock()
            .expect("non-differentiable ops lock should not be poisoned")
            .contains(operation)
    }

    /// Audits every operation listed by the shape-inference registry and
    /// returns the aggregate coverage report.
    pub fn audit_all(&self) -> GradientCoverageReport {
        let mut report = GradientCoverageReport::new();
        let registry = get_registry();
        for op_name in &registry.list_operations() {
            let info = self.audit_operation(op_name);
            match info.status {
                GradientStatus::Implemented => report.complete_operations += 1,
                GradientStatus::Partial => report.partial_operations += 1,
                GradientStatus::Missing => report.missing_operations += 1,
                GradientStatus::NotApplicable => report.not_applicable_operations += 1,
            }
            report.operations.insert(op_name.to_string(), info);
            report.total_operations += 1;
        }
        report
    }

    /// Audits a single operation: classifies its gradient status and records
    /// the dtypes/shapes expected to pass under the current config.
    pub fn audit_operation(&self, operation: &str) -> OperationGradientInfo {
        // BUG FIX: the previous version bound `get_registry()` to a local
        // that was never used — a dead binding and a compiler warning.
        let category = self.infer_category(operation);
        let mut info = OperationGradientInfo::new(operation, category);
        if self.is_non_differentiable(operation) {
            info.status = GradientStatus::NotApplicable;
            info.notes
                .push("Operation is inherently non-differentiable".to_string());
        } else if self.has_gradient(operation) {
            info.supported_dtypes
                .extend(self.config.test_dtypes.iter().copied());
            // Partial is currently unreachable here (all configured dtypes
            // are recorded as supported) but kept for when per-dtype testing
            // is wired in.
            info.status = if info.supported_dtypes.len() == self.config.test_dtypes.len() {
                GradientStatus::Implemented
            } else {
                GradientStatus::Partial
            };
            info.passing_shapes
                .extend(self.config.test_shapes.iter().cloned());
        } else {
            info.status = GradientStatus::Missing;
            info.missing_dtypes = self.config.test_dtypes.clone();
            info.notes
                .push("Gradient implementation not found".to_string());
        }
        info
    }

    /// Maps an operation name to its category; unknown names fall back to
    /// `Other`.
    fn infer_category(&self, operation: &str) -> OperationCategory {
        match operation {
            "add" | "sub" | "mul" | "div" | "pow" => OperationCategory::BinaryElementwise,
            "neg" | "abs" | "exp" | "log" | "sqrt" | "sin" | "cos" | "tan" | "tanh" | "relu"
            | "sigmoid" | "gelu" => OperationCategory::UnaryElementwise,
            "matmul" | "dot" => OperationCategory::MatrixOps,
            "sum" | "mean" | "max" | "min" | "prod" => OperationCategory::Reduction,
            "reshape" | "transpose" | "permute" | "squeeze" | "unsqueeze" => {
                OperationCategory::Manipulation
            }
            "concat" | "stack" => OperationCategory::Concatenation,
            "eq" | "ne" | "gt" | "ge" | "lt" | "le" => OperationCategory::Comparison,
            "and" | "or" | "not" | "xor" => OperationCategory::Logical,
            _ => OperationCategory::Other,
        }
    }

    /// Replaces the test configuration used by subsequent audits.
    pub fn set_config(&mut self, config: GradientTestConfig) {
        self.config = config;
    }

    /// Returns the current test configuration.
    pub fn get_config(&self) -> &GradientTestConfig {
        &self.config
    }
}
/// Process-wide auditor instance, created lazily on first access.
static GLOBAL_AUDITOR: OnceLock<GradientCoverageAuditor> = OnceLock::new();

/// Returns the global gradient coverage auditor, constructing it on first use.
pub fn get_auditor() -> &'static GradientCoverageAuditor {
    GLOBAL_AUDITOR.get_or_init(GradientCoverageAuditor::default)
}

/// Eagerly constructs the global auditor so later calls pay no setup cost.
pub fn initialize_auditor() {
    get_auditor();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh auditor knows the seeded gradient ops and nothing else.
    #[test]
    fn test_auditor_creation() {
        let auditor = GradientCoverageAuditor::new();
        assert!(auditor.has_gradient("add"));
        assert!(auditor.has_gradient("matmul"));
        assert!(!auditor.has_gradient("unknown_op"));
    }

    /// Comparison/boolean ops are seeded as non-differentiable.
    #[test]
    fn test_non_differentiable_ops() {
        let auditor = GradientCoverageAuditor::new();
        assert!(auditor.is_non_differentiable("eq"));
        assert!(auditor.is_non_differentiable("and"));
        assert!(!auditor.is_non_differentiable("add"));
    }

    /// Auditing classifies known, non-differentiable, and unknown ops.
    #[test]
    fn test_operation_audit() {
        let auditor = GradientCoverageAuditor::new();
        let info = auditor.audit_operation("add");
        assert_eq!(info.status, GradientStatus::Implemented);
        assert!(!info.supported_dtypes.is_empty());
        let info = auditor.audit_operation("eq");
        assert_eq!(info.status, GradientStatus::NotApplicable);
        let info = auditor.audit_operation("concat");
        assert!(matches!(
            info.status,
            GradientStatus::Missing | GradientStatus::Implemented
        ));
    }

    /// A full audit yields a non-empty report with coverage in [0, 100].
    #[test]
    fn test_full_audit() {
        let auditor = GradientCoverageAuditor::new();
        let report = auditor.audit_all();
        assert!(report.total_operations > 0);
        assert!(report.overall_coverage() >= 0.0);
        assert!(report.overall_coverage() <= 100.0);
    }

    /// Coverage percentage reflects supported vs. missing dtype counts.
    #[test]
    fn test_coverage_percentage() {
        let mut info = OperationGradientInfo::new("test", OperationCategory::BinaryElementwise);
        assert_eq!(info.coverage_percentage(), 0.0);
        info.supported_dtypes.push(DType::Float32);
        info.missing_dtypes.push(DType::Float64);
        assert_eq!(info.coverage_percentage(), 50.0);
        info.supported_dtypes.push(DType::Float64);
        info.missing_dtypes.clear();
        assert_eq!(info.coverage_percentage(), 100.0);
    }

    /// Status filtering partitions consistently with the total count.
    #[test]
    fn test_operations_by_status() {
        let auditor = GradientCoverageAuditor::new();
        let report = auditor.audit_all();
        // BUG FIX: `missing` was previously bound but never used, producing
        // an unused-variable warning; assert a basic consistency invariant
        // so the binding earns its keep.
        let missing = report.operations_by_status(GradientStatus::Missing);
        let not_applicable = report.operations_by_status(GradientStatus::NotApplicable);
        assert!(missing.len() + not_applicable.len() <= report.total_operations);
        assert!(
            !not_applicable.is_empty(),
            "Should have some non-differentiable ops"
        );
    }

    /// `get_auditor` always returns the same global instance.
    #[test]
    fn test_global_auditor() {
        let auditor1 = get_auditor();
        let auditor2 = get_auditor();
        assert!(std::ptr::eq(auditor1, auditor2));
    }

    /// Smoke test: printing the summary must not panic.
    #[test]
    fn test_report_summary() {
        let auditor = GradientCoverageAuditor::new();
        let report = auditor.audit_all();
        report.print_summary();
    }
}