Skip to main content

GradFn

Trait GradFn 

Source
// Trait for computing gradients during the backward pass.
pub trait GradFn<R: Runtime>: Send + Sync {
    // Required methods
    //   backward: gradients for the inputs given the gradient of the output;
    //   one entry per input, `None` where an input does not need a gradient.
    fn backward(
        &self,
        grad_output: &Tensor<R>,
    ) -> Result<Vec<Option<Tensor<R>>>>;
    //   inputs: IDs of the input tensors (used for topological sorting
    //   during the backward pass).
    fn inputs(&self) -> &[TensorId];
    //   name: human-readable name for debugging.
    fn name(&self) -> &'static str;

    // Provided methods
    //   backward_var: gradients as Vars for second-order differentiation.
    //   The default yields detached Vars with no gradient history, so
    //   override it when second-order derivatives must flow through the op.
    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> { ... }
    //   input_grad_fns: grad_fns of the inputs for graph traversal;
    //   one entry per input, `None` for a leaf tensor.
    fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> { ... }
    //   saved_tensors: tensors saved during the forward pass.
    fn saved_tensors(&self) -> &[Tensor<R>] { ... }
}
Expand description

Trait for computing gradients during backward pass

Each operation that participates in autograd has an associated GradFn that knows how to compute gradients for its inputs.

§Implementation Guide

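Only backward() is required to implement this trait; backward_var() is a provided method with a default implementation. If you want proper second-order differentiation support, you must implement both, i.e. override backward_var() in addition to implementing backward().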

The backward() method is used for first-order gradients (standard backprop). The backward_var() method is used for second-order gradients (Hessians and Hessian-vector products, HVPs).

§Example

// Example gradient function for a hypothetical op with a single input.
struct MyOpBackward<R: Runtime> {
    // Tensor kept on the struct; the commented-out line in backward() shows
    // how it could be combined with grad_output.
    saved_tensor: Tensor<R>,
}

// NOTE: only backward() and backward_var() are shown here. The trait also
// lists inputs() and name() as required methods, so a complete implementation
// must define them as well.
impl<R: Runtime> GradFn<R> for MyOpBackward<R> {
    // First-order path: returns one optional gradient per input.
    fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
        // Compute gradients using tensor ops
        // let grad = client.mul(grad_output, &self.saved_tensor)?;
        // Placeholder: returns a clone of the incoming gradient as the only entry.
        Ok(vec![Some(grad_output.clone())])
    }

    // Second-order path: same result layout as backward(), expressed as Vars.
    fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
        // Compute gradients using var_ops to maintain computation graph
        // Placeholder: returns a clone of the incoming Var as the only entry.
        Ok(vec![Some(grad_output.clone())])
    }
}

Required Methods§

Source

fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>>

Compute gradients for input tensors given the gradient of the output

Returns a vector of optional gradients, one per input. A None entry indicates that the corresponding input does not need a gradient.

Source

fn inputs(&self) -> &[TensorId]

Get the IDs of input tensors

Used for topological sorting during backward pass.

Source

fn name(&self) -> &'static str

Human-readable name for debugging

Provided Methods§

Source

fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>>

Compute gradients as Vars for second-order differentiation

This method enables higher-order derivatives by returning Vars that retain their computation history.

§Important: Override for Second-Order Derivatives

The default implementation creates detached Vars with no gradient history. This means second-order derivatives will NOT flow through operations that rely on the default implementation.

If your operation needs to support second-order differentiation (Hessians, Hessian-vector products), you must override this method to:

  1. Use var_ops (var_mul, var_add, etc.) instead of raw tensor operations
  2. Use Var::with_id_and_grad_fn() for saved tensors to preserve the chain
  3. Return Vars that maintain the computation graph

§Default Behavior

The default implementation calls backward() and wraps results in Vars with requires_grad=true but grad_fn=None. This is suitable for:

  • Operations that don’t need second-order derivatives
  • Leaf operations where the gradient chain naturally terminates
  • Initial development before adding full second-order support

§Arguments

  • grad_output - The gradient of the loss with respect to this op’s output

§Returns

A vector of optional Vars - one per input. Each Var can be differentiated again to compute second-order derivatives.

Source

fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>>

Get the grad_fns of input tensors for graph traversal

Returns a vector of optional grad_fns, one per input. A None entry indicates that the corresponding input is a leaf tensor.

Source

fn saved_tensors(&self) -> &[Tensor<R>]

Get tensors saved during forward pass

Some operations (like softmax) need their forward outputs in order to compute the backward pass.

Implementors§

Source§

impl<R: Runtime> GradFn<R> for AbsBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for AddBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for AddScalarBackward<R>

Source§

impl<R: Runtime> GradFn<R> for CatBackward<R>

Source§

impl<R: Runtime> GradFn<R> for CosBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for CumprodBackward<R>
where R::Client: CumulativeOps<R> + BinaryOps<R>,

Source§

impl<R: Runtime> GradFn<R> for CumsumBackward<R>
where R::Client: CumulativeOps<R>,

Source§

impl<R: Runtime> GradFn<R> for DetBackward<R>

Source§

impl<R: Runtime> GradFn<R> for DivBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for DivScalarBackward<R>
where R::Client: ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for ExpBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for ExpandBackward<R>
where R::Client: RuntimeClient<R> + TensorOps<R> + ReduceOps<R>,

Source§

impl<R: Runtime> GradFn<R> for FusedAddLayerNormBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + BinaryOps<R> + ReduceOps<R> + UnaryOps<R>,

Source§

impl<R: Runtime> GradFn<R> for FusedAddRmsNormBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + BinaryOps<R> + ReduceOps<R> + UnaryOps<R>,

Source§

impl<R: Runtime> GradFn<R> for GroupNormBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + ReduceOps<R> + BinaryOps<R> + UnaryOps<R>,

Source§

impl<R: Runtime> GradFn<R> for InverseBackward<R>
where R::Client: MatmulOps<R> + TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for LayerNormBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + BinaryOps<R> + ReduceOps<R> + UnaryOps<R>,

Source§

impl<R: Runtime> GradFn<R> for LogBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for MatmulBackward<R>
where R::Client: MatmulOps<R> + TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for MatmulBiasActivationBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + BinaryOps<R> + ReduceOps<R> + UnaryOps<R> + MatmulOps<R>,

Source§

impl<R: Runtime> GradFn<R> for MaxBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,

Source§

impl<R: Runtime> GradFn<R> for MeanBackward<R>
where R::Client: ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for MinBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + CompareOps<R> + ReduceOps<R>,

Source§

impl<R: Runtime> GradFn<R> for MulBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for MulScalarBackward<R>
where R::Client: ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for NegBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for PermuteBackward<R>

Source§

impl<R: Runtime> GradFn<R> for PowBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for PowScalarBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for RecipBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for ReshapeBackward<R>

Source§

impl<R: Runtime> GradFn<R> for RmsNormBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + BinaryOps<R> + ReduceOps<R> + UnaryOps<R>,

Source§

impl<R: Runtime> GradFn<R> for SinBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for SoftmaxBackward<R>
where R::Client: TensorOps<R> + ReduceOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for SolveBackward<R>

Source§

impl<R: Runtime> GradFn<R> for SqrtBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for SquareBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for StdBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + ReduceOps<R>,

Source§

impl<R: Runtime> GradFn<R> for SubBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime> GradFn<R> for SubScalarBackward<R>

Source§

impl<R: Runtime> GradFn<R> for SumBackward<R>

Source§

impl<R: Runtime> GradFn<R> for TanBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime> GradFn<R> for TransposeBackward<R>

Source§

impl<R: Runtime> GradFn<R> for VarBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + ReduceOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for CastBackward<R>

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for CholeskyBackward<R>

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for ClampBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R> + CompareOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for GatherBackward<R>
where R::Client: IndexingOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for LogSoftmaxBackward<R>
where R::Client: TensorOps<R> + UnaryOps<R> + ReduceOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for NarrowBackward<R>
where R::Client: RuntimeClient<R> + TensorOps<R> + ShapeOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for ReluBackward<R>
where R::Client: TensorOps<R> + CompareOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for SigmoidBackward<R>
where R::Client: TensorOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for SiluBackward<R>
where R::Client: TensorOps<R> + ActivationOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for SoftplusBackward<R>
where R::Client: TensorOps<R> + ActivationOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for TanhBackward<R>
where R::Client: TensorOps<R> + ScalarOps<R>,

Source§

impl<R: Runtime<DType = DType>> GradFn<R> for TraceBackward<R>