pub trait GradFn<R: Runtime>: Send + Sync {
// Required methods
fn backward(
&self,
grad_output: &Tensor<R>,
) -> Result<Vec<Option<Tensor<R>>>>;
fn inputs(&self) -> &[TensorId];
fn name(&self) -> &'static str;
// Provided methods
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> { ... }
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> { ... }
fn saved_tensors(&self) -> &[Tensor<R>] { ... }
}
Trait for computing gradients during backward pass
Each operation that participates in autograd has an associated
GradFn that knows how to compute gradients for its inputs.
§Implementation Guide
When implementing this trait, you must implement backward(), and you must also
override backward_var() if you want proper second-order differentiation support.
The backward() method is used for first-order gradients (standard backprop).
The backward_var() method is used for second-order gradients (Hessians, HVPs).
§Example
struct MyOpBackward<R: Runtime> {
    saved_tensor: Tensor<R>,
    input_ids: Vec<TensorId>,
}

impl<R: Runtime> GradFn<R> for MyOpBackward<R> {
    fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
        // Compute gradients using tensor ops, e.g.:
        // let grad = client.mul(grad_output, &self.saved_tensor)?;
        Ok(vec![Some(grad_output.clone())])
    }

    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
        // Compute gradients using var_ops to maintain the computation graph
        Ok(vec![Some(grad_output.clone())])
    }

    fn inputs(&self) -> &[TensorId] {
        // Identify this op's inputs for graph traversal
        &self.input_ids
    }

    fn name(&self) -> &'static str {
        "my_op_backward"
    }
}
Required Methods§
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>>
Compute gradients for input tensors given the gradient of the output
Returns a vector of optional gradients - one per input.
None indicates that input doesn’t need a gradient.
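To illustrate the per-input Option convention, here is a hedged sketch using plain f64 values as a stand-in for Tensor<R>; the mul_backward helper and its signature are hypothetical, not part of this crate:

```rust
// Stand-in backward for z = a * b, with f64 in place of Tensor<R>.
// Returns one Option per input, in input order; None means that input
// doesn't need a gradient (e.g. it was created with requires_grad = false).
fn mul_backward(
    grad_output: f64,
    a: f64,
    b: f64,
    a_needs_grad: bool,
    b_needs_grad: bool,
) -> Vec<Option<f64>> {
    vec![
        a_needs_grad.then(|| grad_output * b), // dL/da = dL/dz * b
        b_needs_grad.then(|| grad_output * a), // dL/db = dL/dz * a
    ]
}
```

The length of the returned vector always matches the number of inputs, even when some entries are None, so the engine can pair gradients with inputs positionally.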
Provided Methods§
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>
Compute gradients as Vars for second-order differentiation
This method enables higher-order derivatives by returning Vars
that retain their computation history.
§Important: Override for Second-Order Derivatives
The default implementation creates detached Vars with no gradient history. This means second-order derivatives will NOT flow through operations that rely on the default implementation.
If your operation needs to support second-order differentiation (Hessians, Hessian-vector products), you must override this method to:
- Use var_ops (var_mul, var_add, etc.) instead of raw tensor operations
- Use Var::with_id_and_grad_fn() for saved tensors to preserve the chain
- Return Vars that maintain the computation graph
§Default Behavior
The default implementation calls backward() and wraps results in Vars
with requires_grad=true but grad_fn=None. This is suitable for:
- Operations that don’t need second-order derivatives
- Leaf operations where the gradient chain naturally terminates
- Initial development before adding full second-order support
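The distinction between the default and an overriding implementation can be sketched with heavily simplified, hypothetical stand-ins for Var and GradFn (ToyVar and ToyGradFn below are illustrations only, not this crate's types):

```rust
use std::sync::Arc;

// Hypothetical, heavily simplified stand-ins for Var and GradFn.
struct ToyGradFn;
struct ToyVar {
    requires_grad: bool,
    grad_fn: Option<Arc<ToyGradFn>>,
}

// What the default backward_var effectively produces: requires_grad is
// true, but grad_fn is None, so a second backward pass stops here.
fn detached(requires_grad: bool) -> ToyVar {
    ToyVar { requires_grad, grad_fn: None }
}

// What an overridden backward_var should produce: the result carries a
// grad_fn, so the gradient itself can be differentiated again.
fn graph_preserving(grad_fn: Arc<ToyGradFn>) -> ToyVar {
    ToyVar { requires_grad: true, grad_fn: Some(grad_fn) }
}
```

A backward pass over the gradient graph follows grad_fn links; once a Var with grad_fn = None is reached, differentiation terminates there, which is why the default silently yields zero second-order gradients.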
§Arguments
grad_output - The gradient of the loss with respect to this op’s output
§Returns
A vector of optional Vars - one per input. Each Var can be
differentiated again to compute second-order derivatives.
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>>
Get the grad_fns of input tensors for graph traversal
Returns a vector of optional grad_fns - one per input.
None indicates a leaf tensor.
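The backward engine can use these per-input links to walk the graph. A hedged sketch of such a walk, with a hypothetical simplified Node standing in for a grad_fn (mirroring name() and input_grad_fns()):

```rust
use std::sync::Arc;

// Hypothetical simplified node: an op name plus its inputs' grad_fns,
// where None marks a leaf tensor, mirroring input_grad_fns().
struct Node {
    name: &'static str,
    inputs: Vec<Option<Arc<Node>>>,
}

// Depth-first walk over the backward graph, collecting op names in
// visit order; None entries terminate the recursion at leaves.
fn walk(node: &Node, out: &mut Vec<&'static str>) {
    out.push(node.name);
    for input in &node.inputs {
        if let Some(parent) = input {
            walk(parent, out);
        }
    }
}
```

A real traversal would also deduplicate shared nodes (the same Arc may feed several ops) and accumulate gradients per tensor; this sketch only shows the shape of the walk.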
fn saved_tensors(&self) -> &[Tensor<R>]
Get tensors saved during forward pass
Some operations (like softmax) need forward outputs for backward.
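The softmax case can be made concrete with a hedged sketch using Vec<f64> in place of Tensor<R>: given the saved forward output y, the backward pass is dx = y * (dy - sum(dy * y)), so no recomputation of the forward pass is needed (softmax_backward below is an illustration, not this crate's API):

```rust
// Softmax backward computed from the *saved forward output* y alone,
// with slices of f64 standing in for Tensor<R>:
//   dx_i = y_i * (dy_i - sum_j(dy_j * y_j))
fn softmax_backward(saved_output: &[f64], grad_output: &[f64]) -> Vec<f64> {
    // Inner product sum_j(dy_j * y_j), shared by every output element.
    let dot: f64 = saved_output
        .iter()
        .zip(grad_output)
        .map(|(y, dy)| y * dy)
        .sum();
    saved_output
        .iter()
        .zip(grad_output)
        .map(|(y, dy)| y * (dy - dot))
        .collect()
}
```

Saving the forward output in the GradFn (and exposing it via saved_tensors()) trades memory for avoiding this recomputation; a useful sanity check is that softmax input gradients always sum to zero.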