use crate::{Scirs2Exec, Scirs2Tensor};
use scirs2_core::ndarray::ArrayD;
use tensorlogic_infer::{ExecutorError, TlAutodiff};
use tensorlogic_ir::EinsumGraph;
/// Tunable parameters for finite-difference gradient checking.
#[derive(Debug, Clone, Copy)]
pub struct GradientCheckConfig {
    /// Step size `h` used for the central difference
    /// `(f(x + h) - f(x - h)) / (2 * h)`.
    pub epsilon: f64,
    /// Relative tolerance: the largest relative difference must be strictly
    /// below this value for the relative criterion to hold.
    pub rtol: f64,
    /// Absolute tolerance: the largest absolute difference must be strictly
    /// below this value for the absolute criterion to hold.
    pub atol: f64,
}

impl Default for GradientCheckConfig {
    /// Returns the default settings: `epsilon = 1e-5`, `rtol = 1e-3`,
    /// `atol = 1e-5`.
    fn default() -> Self {
        let epsilon = 1e-5;
        let rtol = 1e-3;
        let atol = 1e-5;
        Self {
            epsilon,
            rtol,
            atol,
        }
    }
}
/// Summary of comparing an analytical gradient against a numeric one for a
/// single tensor.
#[derive(Debug)]
pub struct GradientCheckResult {
    /// Name of the tensor whose gradients were compared.
    pub tensor_name: String,
    /// Largest element-wise absolute difference reported by the comparison.
    pub max_abs_diff: f64,
    /// Largest element-wise relative difference reported by the comparison.
    pub max_rel_diff: f64,
    /// Verdict recorded when the result was produced.
    pub passed: bool,
    /// Number of elements that were compared.
    pub num_elements: usize,
}

impl GradientCheckResult {
    /// Re-evaluates the stored differences against `config`.
    ///
    /// Returns `true` when the maximum absolute difference is strictly below
    /// `config.atol`, or the maximum relative difference is strictly below
    /// `config.rtol`. Either criterion is sufficient on its own. A NaN
    /// difference never satisfies its criterion.
    pub fn is_close(&self, config: &GradientCheckConfig) -> bool {
        let within_abs = self.max_abs_diff < config.atol;
        let within_rel = self.max_rel_diff < config.rtol;
        within_abs || within_rel
    }
}
/// Estimates the gradient of the scalar loss `sum(forward(graph))` with
/// respect to the tensor registered as `tensor_name`, using central finite
/// differences.
///
/// Each element is perturbed by `+config.epsilon` and `-config.epsilon` in
/// turn, the graph is re-executed for both, and the slope
/// `(loss_plus - loss_minus) / (2 * epsilon)` is stored at the same position
/// of the returned tensor. This costs two forward passes per element.
///
/// The tensor originally stored under `tensor_name` is put back into the
/// executor before returning, on success and on failure alike.
///
/// # Errors
///
/// Returns [`ExecutorError::TensorNotFound`] if `tensor_name` is not present
/// in `executor.tensors`, and propagates any error raised by
/// `executor.forward`.
///
/// # Notes
///
/// `config.epsilon` is used as a divisor; a zero step yields non-finite
/// gradient values.
pub fn compute_numeric_gradient(
    graph: &EinsumGraph,
    executor: &mut Scirs2Exec,
    tensor_name: &str,
    config: &GradientCheckConfig,
) -> Result<Scirs2Tensor, ExecutorError> {
    let original = executor
        .tensors
        .get(tensor_name)
        .ok_or_else(|| ExecutorError::TensorNotFound(tensor_name.to_string()))?
        .clone();

    let outcome = central_difference_sweep(graph, executor, tensor_name, &original, config);

    // Restore the caller's tensor unconditionally: a failed forward pass must
    // not leave a perturbed copy behind in the executor.
    executor.add_tensor(tensor_name, original);
    outcome
}

/// Runs the per-element central-difference loop; the caller is responsible
/// for restoring the executor's tensor afterwards.
fn central_difference_sweep(
    graph: &EinsumGraph,
    executor: &mut Scirs2Exec,
    tensor_name: &str,
    original: &Scirs2Tensor,
    config: &GradientCheckConfig,
) -> Result<Scirs2Tensor, ExecutorError> {
    // Work on a standard-layout copy so flat index `idx` addresses the same
    // logical element in the perturbed input and in the gradient, and so that
    // slicing cannot fail for transposed or otherwise strided inputs.
    let base: Scirs2Tensor = original.as_standard_layout().into_owned();
    let mut numeric_grad: Scirs2Tensor = ArrayD::zeros(base.shape());
    let denominator = 2.0 * config.epsilon;

    // A freshly allocated array iterates in the same row-major order as the
    // flat slice of `base`.
    for (idx, slot) in numeric_grad.iter_mut().enumerate() {
        let loss_plus = perturbed_loss(graph, executor, tensor_name, &base, idx, config.epsilon)?;
        let loss_minus =
            perturbed_loss(graph, executor, tensor_name, &base, idx, -config.epsilon)?;
        *slot = (loss_plus - loss_minus) / denominator;
    }

    Ok(numeric_grad)
}

/// Installs a copy of `base` with element `idx` shifted by `delta`, runs the
/// forward pass, and returns the sum of the output as the scalar loss.
fn perturbed_loss(
    graph: &EinsumGraph,
    executor: &mut Scirs2Exec,
    tensor_name: &str,
    base: &Scirs2Tensor,
    idx: usize,
    delta: f64,
) -> Result<f64, ExecutorError> {
    let mut perturbed = base.clone();
    perturbed
        .as_slice_mut()
        .expect("base tensor was converted to standard contiguous layout")[idx] += delta;
    executor.add_tensor(tensor_name, perturbed);
    Ok(executor.forward(graph)?.sum())
}
/// Compares an analytical gradient with a numeric reference element by
/// element and summarizes the worst-case differences.
///
/// For every element pair `(a, n)` the absolute difference is `|a - n|`. The
/// relative difference is `|a - n| / |n|` when `|n|` exceeds a small floor;
/// otherwise the absolute difference is used so that near-zero references do
/// not blow up the ratio.
///
/// The check passes when the maximum absolute difference is strictly below
/// `config.atol`, or the maximum relative difference is strictly below
/// `config.rtol`.
///
/// If any difference is NaN (for example a NaN gradient, or `inf - inf`),
/// the check fails and both reported maxima are NaN. `f64::max` ignores NaN
/// operands, so without this guard such elements would be silently skipped.
///
/// # Panics
///
/// Panics if `analytical` and `numeric` do not have the same shape.
pub fn compare_gradients(
    analytical: &Scirs2Tensor,
    numeric: &Scirs2Tensor,
    tensor_name: &str,
    config: &GradientCheckConfig,
) -> GradientCheckResult {
    // Below this magnitude the numeric value is treated as zero and the
    // relative difference falls back to the absolute one.
    const REL_DENOM_FLOOR: f64 = 1e-10;

    assert_eq!(
        analytical.shape(),
        numeric.shape(),
        "Gradient shapes must match"
    );

    let num_elements = analytical.len();
    let mut max_abs_diff: f64 = 0.0;
    let mut max_rel_diff: f64 = 0.0;
    let mut saw_nan = false;

    for (a, n) in analytical.iter().zip(numeric.iter()) {
        let abs_diff = (a - n).abs();
        let rel_diff = if n.abs() > REL_DENOM_FLOOR {
            abs_diff / n.abs()
        } else {
            abs_diff
        };
        saw_nan |= abs_diff.is_nan() || rel_diff.is_nan();
        max_abs_diff = max_abs_diff.max(abs_diff);
        max_rel_diff = max_rel_diff.max(rel_diff);
    }

    if saw_nan {
        // Surface the undefined comparison instead of reporting maxima that
        // skipped the offending elements.
        max_abs_diff = f64::NAN;
        max_rel_diff = f64::NAN;
    }

    let passed = !saw_nan && (max_abs_diff < config.atol || max_rel_diff < config.rtol);

    GradientCheckResult {
        tensor_name: tensor_name.to_string(),
        max_abs_diff,
        max_rel_diff,
        passed,
        num_elements,
    }
}
/// Checks the executor's analytical gradients against finite-difference
/// estimates for the tensors of `graph`.
///
/// The scalar loss is the sum of the forward output: the backward pass is
/// seeded with a tensor of ones shaped like the output, matching the
/// `sum()`-based loss used by [`compute_numeric_gradient`].
///
/// A tensor is checked only when it is present in `executor.tensors` and the
/// backward tape holds a gradient at the tensor's index in `graph.tensors`;
/// all other tensors are skipped and produce no entry in the result. The
/// analytical gradient is looked up first so the numeric sweep (two forward
/// passes per element) is never run just to be discarded.
///
/// `config` falls back to [`GradientCheckConfig::default`] when `None`.
///
/// # Errors
///
/// Propagates errors from the forward pass, the backward pass, and
/// [`compute_numeric_gradient`].
///
/// # Panics
///
/// Panics (via [`compare_gradients`]) if an analytical gradient's shape
/// differs from the shape of the corresponding numeric gradient.
pub fn check_gradients(
    graph: &EinsumGraph,
    executor: &mut Scirs2Exec,
    config: Option<GradientCheckConfig>,
) -> Result<Vec<GradientCheckResult>, ExecutorError> {
    let config = config.unwrap_or_default();

    let output = executor.forward(graph)?;
    let loss_grad = Scirs2Tensor::ones(output.raw_dim());
    let analytical_tape = executor.backward(graph, &loss_grad)?;

    let mut results = Vec::new();
    for (idx, tensor_name) in graph.tensors.iter().enumerate() {
        if !executor.tensors.contains_key(tensor_name) {
            continue;
        }
        // Cheap lookup first: without an analytical gradient there is nothing
        // to compare against, so the expensive sweep would be wasted.
        let analytical_grad = match analytical_tape.tensors.get(idx) {
            Some(Some(grad)) => grad,
            _ => continue,
        };
        let numeric_grad = compute_numeric_gradient(graph, executor, tensor_name, &config)?;
        results.push(compare_gradients(
            analytical_grad,
            &numeric_grad,
            tensor_name,
            &config,
        ));
    }

    Ok(results)
}
#[cfg(all(test, feature = "integration-tests"))]
mod tests {
    use super::*;
    use tensorlogic_compiler::compile_to_einsum;
    use tensorlogic_ir::{TLExpr, Term};

    /// Compiles `expr`, binds the first two graph tensors to 2x2 inputs built
    /// from `x_data` and `y_data`, and runs the gradient check with the
    /// default configuration.
    ///
    /// Asserts that at least one tensor was actually checked: the per-result
    /// assertions in the callers would otherwise pass vacuously on an empty
    /// result list.
    fn run_binary_check(
        expr: &TLExpr,
        x_data: Vec<f64>,
        y_data: Vec<f64>,
    ) -> Vec<GradientCheckResult> {
        let graph = compile_to_einsum(expr).expect("unwrap");
        let mut executor = Scirs2Exec::new();
        let x_tensor = Scirs2Exec::from_vec(x_data, vec![2, 2]).expect("unwrap");
        let y_tensor = Scirs2Exec::from_vec(y_data, vec![2, 2]).expect("unwrap");
        executor.add_tensor(graph.tensors[0].clone(), x_tensor);
        executor.add_tensor(graph.tensors[1].clone(), y_tensor);
        let results = check_gradients(&graph, &mut executor, None).expect("unwrap");
        assert!(
            !results.is_empty(),
            "Gradient check produced no results; nothing was verified"
        );
        results
    }

    /// Gradients of element-wise addition must match finite differences.
    #[test]
    fn test_gradient_check_add() {
        let x = TLExpr::pred("x", vec![Term::var("i"), Term::var("j")]);
        let y = TLExpr::pred("y", vec![Term::var("i"), Term::var("j")]);
        let expr = TLExpr::add(x, y);
        let results = run_binary_check(
            &expr,
            vec![1.0, 2.0, 3.0, 4.0],
            vec![0.5, 0.5, 1.0, 1.0],
        );
        for result in results {
            println!(
                "Tensor: {}, Max abs diff: {:.6e}, Max rel diff: {:.6e}, Passed: {}",
                result.tensor_name, result.max_abs_diff, result.max_rel_diff, result.passed
            );
            assert!(
                result.passed,
                "Gradient check failed for {}",
                result.tensor_name
            );
        }
    }

    /// Gradients of element-wise multiplication must match finite differences.
    #[test]
    fn test_gradient_check_multiply() {
        let x = TLExpr::pred("x", vec![Term::var("i"), Term::var("j")]);
        let y = TLExpr::pred("y", vec![Term::var("i"), Term::var("j")]);
        let expr = TLExpr::mul(x, y);
        let results = run_binary_check(
            &expr,
            vec![2.0, 3.0, 4.0, 5.0],
            vec![1.0, 2.0, 3.0, 4.0],
        );
        for result in results {
            println!(
                "Tensor: {}, Max abs diff: {:.6e}, Max rel diff: {:.6e}",
                result.tensor_name, result.max_abs_diff, result.max_rel_diff
            );
            assert!(
                result.passed,
                "Gradient check failed for {}",
                result.tensor_name
            );
        }
    }

    /// Gradients of element-wise division must match finite differences.
    /// Divisors are kept well away from zero.
    #[test]
    fn test_gradient_check_divide() {
        let x = TLExpr::pred("x", vec![Term::var("i"), Term::var("j")]);
        let y = TLExpr::pred("y", vec![Term::var("i"), Term::var("j")]);
        let expr = TLExpr::div(x, y);
        let results = run_binary_check(
            &expr,
            vec![6.0, 8.0, 10.0, 12.0],
            vec![2.0, 2.0, 2.0, 3.0],
        );
        for result in results {
            println!(
                "Tensor: {}, Max abs diff: {:.6e}, Max rel diff: {:.6e}",
                result.tensor_name, result.max_abs_diff, result.max_rel_diff
            );
            assert!(
                result.passed,
                "Gradient check failed for {}",
                result.tensor_name
            );
        }
    }
}