```rust
pub trait GradFn: Send + Sync {
    // Required methods
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}
```
Trait for functions that compute gradients during the backward pass.
Each differentiable operation creates a `GradFn` implementation
that captures the context it needs to compute its gradients.
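To show what "captured context" can mean, here is a minimal, self-contained sketch for element-wise multiplication z = x * y, where ∂z/∂x = y and ∂z/∂y = x, so the gradient function must save both inputs during the forward pass. The `Tensor` struct (a flat `Vec<f32>`), `Tensor::mul`, `MulBackward`, and `mul_forward` are hypothetical stand-ins written for this illustration, not this crate's actual types or API; only the `GradFn` signature is copied from the declaration above.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: a flat buffer of f32 values.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

impl Tensor {
    pub fn new(data: Vec<f32>) -> Self {
        Tensor { data }
    }

    /// Element-wise product; assumes both tensors have the same length.
    pub fn mul(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.data.len(), other.data.len(), "shape mismatch");
        Tensor::new(self.data.iter().zip(&other.data).map(|(a, b)| a * b).collect())
    }
}

/// Same signature as the documented trait.
pub trait GradFn: Send + Sync {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}

/// Hypothetical gradient function for z = x * y (element-wise).
/// It keeps copies of both inputs because dz/dx = y and dz/dy = x.
pub struct MulBackward {
    x: Tensor,
    y: Tensor,
}

impl GradFn for MulBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        // Chain rule: grad_x = grad_output * y, grad_y = grad_output * x.
        vec![grad_output.mul(&self.y), grad_output.mul(&self.x)]
    }

    fn name(&self) -> &'static str {
        "MulBackward"
    }
}

/// Hypothetical forward op: returns x * y together with the GradFn
/// that remembers what the backward pass will need.
pub fn mul_forward(x: &Tensor, y: &Tensor) -> (Tensor, Box<dyn GradFn>) {
    let z = x.mul(y);
    let grad_fn: Box<dyn GradFn> = Box::new(MulBackward { x: x.clone(), y: y.clone() });
    (z, grad_fn)
}
```

The context is saved at forward time because, by the time `backward` runs, the original inputs may no longer be reachable from the output tensor alone.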
Example Implementation
For element-wise addition z = x + y:
- ∂z/∂x = 1
- ∂z/∂y = 1
By the chain rule, each input's gradient is grad_output × 1, so backward(grad_output) returns [grad_output, grad_output], one gradient per input.
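The addition example translates into a very small implementation. This is a self-contained sketch, assuming a hypothetical `Tensor` stand-in (a flat `Vec<f32>` that implements `Clone`) and hypothetical `AddBackward` and `add_forward` names; the real crate's tensor type and op constructors may differ. It assumes both inputs have the same shape.

```rust
/// Hypothetical stand-in for the crate's `Tensor`: a flat buffer of f32 values.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

impl Tensor {
    pub fn new(data: Vec<f32>) -> Self {
        Tensor { data }
    }

    /// Element-wise sum; assumes both tensors have the same length.
    pub fn add(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.data.len(), other.data.len(), "shape mismatch");
        Tensor::new(self.data.iter().zip(&other.data).map(|(a, b)| a + b).collect())
    }
}

/// Same signature as the documented trait.
pub trait GradFn: Send + Sync {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}

/// Hypothetical gradient function for z = x + y.
/// Both local derivatives are the constant 1, so no context is saved.
pub struct AddBackward;

impl GradFn for AddBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        // grad_output * 1 == grad_output, so each input gets its own copy.
        vec![grad_output.clone(), grad_output.clone()]
    }

    fn name(&self) -> &'static str {
        "AddBackward"
    }
}

/// Hypothetical forward op: returns x + y and its gradient function.
pub fn add_forward(x: &Tensor, y: &Tensor) -> (Tensor, Box<dyn GradFn>) {
    let grad_fn: Box<dyn GradFn> = Box::new(AddBackward);
    (x.add(y), grad_fn)
}
```

Because addition's local derivatives are constants, `AddBackward` is a unit struct; an operation whose derivatives depend on its inputs would instead have to store them. If inputs were broadcast to a common shape, each returned gradient would also need to be summed back down to its input's original shape, which this sketch does not handle.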