
Trait GradFn 

pub trait GradFn: Send + Sync {
    // Required methods
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}

Trait for functions that compute gradients during the backward pass.

Each differentiable operation creates a GradFn implementation that captures the necessary context for gradient computation.
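As a sketch of what "capturing context" can look like, the Rust below implements a backward node for element-wise multiplication z = x * y. Its partial derivatives (∂z/∂x = y, ∂z/∂y = x) depend on the forward inputs, so the node stores copies of them when the forward pass runs. The Tensor struct (a flat Vec<f32>), its mul_elem method, MulBackward and mul_forward are hypothetical stand-ins, not this crate's actual types; only the GradFn trait is copied from the declaration above.

```rust
// Hypothetical stand-in for the crate's Tensor type: a flat f32 buffer.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

impl Tensor {
    pub fn new(data: Vec<f32>) -> Self {
        Tensor { data }
    }

    // Element-wise product; assumes both tensors have the same length.
    pub fn mul_elem(&self, other: &Tensor) -> Tensor {
        assert_eq!(self.data.len(), other.data.len(), "shape mismatch");
        Tensor::new(self.data.iter().zip(&other.data).map(|(a, b)| a * b).collect())
    }
}

// The trait as declared above.
pub trait GradFn: Send + Sync {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}

// Hypothetical backward node for z = x * y. It keeps copies of both inputs
// because each partial derivative is the *other* input.
pub struct MulBackward {
    x: Tensor,
    y: Tensor,
}

impl GradFn for MulBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        // dz/dx = y and dz/dy = x, each scaled by the incoming gradient.
        vec![grad_output.mul_elem(&self.y), grad_output.mul_elem(&self.x)]
    }

    fn name(&self) -> &'static str {
        "MulBackward"
    }
}

// Hypothetical forward function: computes z and builds the GradFn
// that remembers the context needed later by backward.
pub fn mul_forward(x: &Tensor, y: &Tensor) -> (Tensor, Box<dyn GradFn>) {
    let z = x.mul_elem(y);
    let grad_fn: Box<dyn GradFn> = Box::new(MulBackward { x: x.clone(), y: y.clone() });
    (z, grad_fn)
}
```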

Example Implementation

For element-wise addition z = x + y:

  • ∂z/∂x = 1
  • ∂z/∂y = 1

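By the chain rule, each input's gradient is grad_output multiplied by that input's local derivative (here 1), so backward(grad_output) returns [grad_output, grad_output].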
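A minimal sketch of the addition case in Rust, assuming a hypothetical Tensor stand-in that derives Clone (the crate's real Tensor may share storage instead of copying). AddBackward is an illustrative name, not necessarily one this crate defines.

```rust
// Hypothetical stand-in for the crate's Tensor type.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

// The trait as declared above.
pub trait GradFn: Send + Sync {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}

// Hypothetical backward node for z = x + y.
pub struct AddBackward;

impl GradFn for AddBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        // dz/dx = 1 and dz/dy = 1, so both inputs receive grad_output unchanged.
        // Cloning relies on the stand-in Tensor deriving Clone.
        vec![grad_output.clone(), grad_output.clone()]
    }

    fn name(&self) -> &'static str {
        "AddBackward"
    }
}
```

Because addition's local derivatives are constants, this node needs no saved forward-pass state, which is why it can be a unit struct.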

Required Methods


fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>

Compute gradients with respect to inputs.

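Arguments
  • grad_output - Gradient of the final output (typically the loss) with respect to this operation's output, flowing back from downstream operations
Returns

Vector of gradients, one for each input tensor. The order must match the order in which the inputs were passed during the forward pass.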
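Order matters whenever the local derivatives differ between inputs. For z = x - y, backward must return [grad_output, -grad_output]; swapping them would send the negated gradient to x. The sketch below pairs such a node with a hypothetical engine step, route_gradients, which zips the returned vector with the ids of the forward-pass inputs and accumulates each gradient. Tensor, SubBackward and route_gradients are all illustrative assumptions, not this crate's API.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the crate's Tensor type.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

impl Tensor {
    // Element-wise negation.
    pub fn neg(&self) -> Tensor {
        Tensor { data: self.data.iter().map(|v| -v).collect() }
    }
}

// The trait as declared above.
pub trait GradFn: Send + Sync {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}

// Hypothetical backward node for z = x - y (forward called with inputs [x, y]).
pub struct SubBackward;

impl GradFn for SubBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        // Index 0 -> x (dz/dx = 1), index 1 -> y (dz/dy = -1).
        vec![grad_output.clone(), grad_output.neg()]
    }

    fn name(&self) -> &'static str {
        "SubBackward"
    }
}

// Hypothetical engine step: the i-th returned gradient belongs to the i-th
// forward input, identified here by input_ids[i]. Gradients reaching the
// same input from several nodes are summed.
pub fn route_gradients(
    node: &dyn GradFn,
    input_ids: &[usize],
    grad_output: &Tensor,
    accum: &mut HashMap<usize, Tensor>,
) {
    let grads = node.backward(grad_output);
    assert_eq!(grads.len(), input_ids.len(), "{}: one gradient per input", node.name());
    for (&id, g) in input_ids.iter().zip(grads) {
        match accum.get_mut(&id) {
            Some(acc) => {
                for (a, b) in acc.data.iter_mut().zip(&g.data) {
                    *a += *b;
                }
            }
            None => {
                accum.insert(id, g);
            }
        }
    }
}
```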


fn name(&self) -> &'static str

Human-readable name for debugging.
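One way name() can help with debugging, sketched with hypothetical helpers: checked_backward reports which node returned the wrong number of gradients, and trace renders a sequence of nodes for logging. IdentityBackward, BrokenBackward (deliberately wrong, to trigger the check) and the Tensor stand-in are assumptions for illustration only.

```rust
// Hypothetical stand-in for the crate's Tensor type.
#[derive(Clone, Debug, PartialEq)]
pub struct Tensor {
    pub data: Vec<f32>,
}

// The trait as declared above.
pub trait GradFn: Send + Sync {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor>;
    fn name(&self) -> &'static str;
}

// Hypothetical single-input node that passes the gradient through.
pub struct IdentityBackward;

impl GradFn for IdentityBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        vec![grad_output.clone()]
    }
    fn name(&self) -> &'static str {
        "IdentityBackward"
    }
}

// Deliberately buggy node: used with two inputs but returns one gradient.
pub struct BrokenBackward;

impl GradFn for BrokenBackward {
    fn backward(&self, grad_output: &Tensor) -> Vec<Tensor> {
        vec![grad_output.clone()]
    }
    fn name(&self) -> &'static str {
        "BrokenBackward"
    }
}

// Hypothetical validation step: name() identifies the faulty node.
pub fn checked_backward(
    node: &dyn GradFn,
    grad_output: &Tensor,
    num_inputs: usize,
) -> Result<Vec<Tensor>, String> {
    let grads = node.backward(grad_output);
    if grads.len() != num_inputs {
        return Err(format!(
            "{}: expected {} gradients, got {}",
            node.name(),
            num_inputs,
            grads.len()
        ));
    }
    Ok(grads)
}

// Renders nodes as "A -> B -> C" for logging.
pub fn trace(nodes: &[&dyn GradFn]) -> String {
    nodes.iter().map(|n| n.name()).collect::<Vec<_>>().join(" -> ")
}
```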

Implementors