Struct GradGraph

pub struct GradGraph {
    pub nodes: Vec<Rc<RefCell<GradNode>>>,
}

The reverse-mode AD tape/graph.
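To illustrate what a reverse-mode tape does, here is a minimal scalar-only sketch. It is not GradGraph's implementation: it stores plain f64 values in a Vec instead of Rc<RefCell<GradNode>> tensors, and the names Tape, TapeNode and TapeOp are hypothetical. Nodes are appended only after their inputs, so walking the indices in reverse visits every node after all of its consumers, which is what makes a single backward sweep sufficient.

```rust
// Hypothetical scalar-only reverse-mode tape; not the GradGraph implementation.
#[derive(Clone, Copy)]
enum TapeOp {
    Leaf,
    Add(usize, usize),
    Mul(usize, usize),
    Sin(usize),
}

struct TapeNode {
    value: f64,
    grad: f64,
    op: TapeOp,
}

#[derive(Default)]
struct Tape {
    nodes: Vec<TapeNode>,
}

impl Tape {
    // Append a node and return its index, mirroring the usize handles GradGraph returns.
    fn push(&mut self, value: f64, op: TapeOp) -> usize {
        self.nodes.push(TapeNode { value, grad: 0.0, op });
        self.nodes.len() - 1
    }

    fn leaf(&mut self, value: f64) -> usize {
        self.push(value, TapeOp::Leaf)
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.nodes[a].value + self.nodes[b].value;
        self.push(v, TapeOp::Add(a, b))
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.nodes[a].value * self.nodes[b].value;
        self.push(v, TapeOp::Mul(a, b))
    }

    fn sin(&mut self, a: usize) -> usize {
        let v = self.nodes[a].value.sin();
        self.push(v, TapeOp::Sin(a))
    }

    // Seed d(loss)/d(loss) = 1, then propagate adjoints from the loss back to the leaves.
    fn backward(&mut self, loss: usize) {
        for node in self.nodes.iter_mut() {
            node.grad = 0.0;
        }
        self.nodes[loss].grad = 1.0;
        for i in (0..=loss).rev() {
            let g = self.nodes[i].grad;
            let op = self.nodes[i].op;
            match op {
                TapeOp::Leaf => {}
                TapeOp::Add(a, b) => {
                    self.nodes[a].grad += g;
                    self.nodes[b].grad += g;
                }
                TapeOp::Mul(a, b) => {
                    let (va, vb) = (self.nodes[a].value, self.nodes[b].value);
                    self.nodes[a].grad += g * vb;
                    self.nodes[b].grad += g * va;
                }
                TapeOp::Sin(a) => {
                    let va = self.nodes[a].value;
                    self.nodes[a].grad += g * va.cos();
                }
            }
        }
    }
}
```

For f(x, y) = x * y + sin(x) at (2, 3), the sweep yields df/dx = 3 + cos(2) and df/dy = 2.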

Fields

nodes: Vec<Rc<RefCell<GradNode>>>

Implementations

impl GradGraph

pub fn new() -> Self

pub fn input(&mut self, tensor: Tensor) -> usize

Create an input node (data, no gradient).

pub fn parameter(&mut self, tensor: Tensor) -> usize

Create a parameter node (trainable, accumulates gradients).

pub fn add(&mut self, a: usize, b: usize) -> usize

Element-wise addition.

pub fn sub(&mut self, a: usize, b: usize) -> usize

Element-wise subtraction.

pub fn mul(&mut self, a: usize, b: usize) -> usize

Element-wise multiplication.

pub fn matmul(&mut self, a: usize, b: usize) -> usize

Matrix multiplication.

pub fn sum(&mut self, a: usize) -> usize

Sum all elements.

pub fn mean(&mut self, a: usize) -> usize

Mean of all elements.

pub fn sin(&mut self, a: usize) -> usize

Element-wise sine.

pub fn cos(&mut self, a: usize) -> usize

Element-wise cosine.

pub fn sqrt(&mut self, a: usize) -> usize

Element-wise square root.

pub fn pow(&mut self, a: usize, n: f64) -> usize

Element-wise power with constant exponent.

pub fn sigmoid(&mut self, a: usize) -> usize

Sigmoid activation: 1 / (1 + exp(-x)).

pub fn relu(&mut self, a: usize) -> usize

ReLU activation: max(0, x).

pub fn tanh_act(&mut self, a: usize) -> usize

Tanh activation.

pub fn abs(&mut self, a: usize) -> usize

Element-wise absolute value.

pub fn log2(&mut self, a: usize) -> usize

Element-wise log base 2.

pub fn softmax(&mut self, a: usize) -> usize

Softmax along the last axis (a 1-D tensor is treated as a flat vector). For numerical stability, the maximum is subtracted before exponentiating: softmax(x_i) = exp(x_i - max(x)) / sum_j exp(x_j - max(x)).
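The max-subtraction trick can be seen in this standalone sketch over a plain slice (a hypothetical helper, not the crate's code, which operates on Tensor values). Because exp(-max(x)) is a common factor of the numerator and the denominator, it cancels, so the result is unchanged while every exponent stays at or below zero.

```rust
// Hypothetical standalone helper: numerically stable softmax over a flat slice.
fn softmax(x: &[f64]) -> Vec<f64> {
    // Largest logit; subtracting it keeps every exponent <= 0 and avoids overflow.
    let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = x.iter().map(|&v| (v - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
```

With logits [1000.0, 1000.0], a naive exp would overflow to infinity, whereas this version returns [0.5, 0.5].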

pub fn cross_entropy(&mut self, logits: usize, targets: usize) -> usize

Cross-entropy loss: -sum(targets * log(softmax(logits))). It is computed internally with the numerically stable log-sum-exp form. Returns a scalar tensor of shape [1].
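A scalar sketch of the log-sum-exp formulation, assuming plain f64 slices in place of Tensor nodes (both function names are hypothetical). It relies on log(softmax(x)_i) = x_i - logsumexp(x), with logsumexp(x) = max(x) + ln(sum_j exp(x_j - max(x))), so no probability is ever computed and then passed to a logarithm.

```rust
// Hypothetical helper: log(sum(exp(x))) computed without overflow.
fn log_sum_exp(x: &[f64]) -> f64 {
    let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    max + x.iter().map(|&v| (v - max).exp()).sum::<f64>().ln()
}

// Hypothetical helper: -sum(t_i * (x_i - logsumexp(x))), equal to
// -sum(t_i * log(softmax(x)_i)) but stable for large logits.
fn cross_entropy(logits: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(logits.len(), targets.len());
    let lse = log_sum_exp(logits);
    -logits
        .iter()
        .zip(targets)
        .map(|(&x, &t)| t * (x - lse))
        .sum::<f64>()
}
```

Uniform logits over four classes with a one-hot target give a loss of ln(4); a confidently correct prediction such as logits [1000, 0] with target [1, 0] gives a loss of about 0 instead of NaN.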

pub fn layer_norm(&mut self, a: usize) -> usize

Layer normalization: normalize input to zero mean and unit variance. y = (x - mean(x)) / sqrt(var(x) + eps), where eps = 1e-5.
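A standalone sketch of this formula over a flat slice. It assumes the population (divide-by-n) variance, which the description does not specify, and uses the same eps = 1e-5; layer_norm here is a hypothetical helper, not the crate's method.

```rust
// Same epsilon as documented for layer_norm.
const LN_EPS: f64 = 1e-5;

// Hypothetical helper: y = (x - mean(x)) / sqrt(var(x) + eps) over all elements.
fn layer_norm(x: &[f64]) -> Vec<f64> {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    // Population variance (divide by n); an assumption for this sketch.
    let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
    let denom = (var + LN_EPS).sqrt();
    x.iter().map(|&v| (v - mean) / denom).collect()
}
```

The output has mean 0 and a variance of var / (var + eps), which is just under 1; eps only matters when the input is nearly constant.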

pub fn batch_norm(&mut self, a: usize) -> usize

Batch normalization: normalize along the first axis (the batch dimension). For a tensor of shape [batch, features], each feature is normalized across the batch: y = (x - mean) / sqrt(var + eps), where mean and var are computed per feature over the batch axis and eps = 1e-5. For 1-D inputs, it behaves identically to layer_norm.
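For a row-major [batch, features] buffer, per-feature normalization across the batch can be sketched as below. The flat-slice layout, the population variance, and the function name batch_norm are assumptions for illustration, not the crate's implementation.

```rust
// Same epsilon as documented for batch_norm.
const BN_EPS: f64 = 1e-5;

// Hypothetical helper: normalize each column (feature) of a row-major
// [batch, features] buffer using that column's mean and variance.
fn batch_norm(x: &[f64], batch: usize, features: usize) -> Vec<f64> {
    assert_eq!(x.len(), batch * features);
    let mut y = vec![0.0; x.len()];
    for f in 0..features {
        // Values of feature f across all rows (the batch axis).
        let col = (0..batch).map(|b| x[b * features + f]);
        let mean = col.clone().sum::<f64>() / batch as f64;
        let var = col.map(|v| (v - mean).powi(2)).sum::<f64>() / batch as f64;
        let denom = (var + BN_EPS).sqrt();
        for b in 0..batch {
            y[b * features + f] = (x[b * features + f] - mean) / denom;
        }
    }
    y
}
```

With features = 1, every element belongs to the same column, so the sketch reduces to normalizing the whole vector, which matches the 1-D behavior described above.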

pub fn clamp(&mut self, a: usize, min: f64, max: f64) -> usize

Element-wise clamp to [min, max].

pub fn where_cond(&mut self, cond: usize, on_true: usize, on_false: usize) -> usize

Conditional select: where(cond, on_true, on_false). cond is a mask tensor of 0.0/1.0 values: each output element is taken from on_true where cond is 1.0 and from on_false where it is 0.0.
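The select, and the way gradients would flow back through it, can be sketched on plain slices. The backward rule shown (the upstream gradient goes only to the branch that was selected, and the mask receives none) is the standard rule for a select and is assumed here rather than taken from the crate; both function names are hypothetical.

```rust
// Hypothetical forward: pick on_true[i] where cond[i] is nonzero, else on_false[i].
fn where_forward(cond: &[f64], on_true: &[f64], on_false: &[f64]) -> Vec<f64> {
    cond.iter()
        .zip(on_true.iter().zip(on_false))
        .map(|(&c, (&t, &f))| if c != 0.0 { t } else { f })
        .collect()
}

// Hypothetical backward: route each upstream gradient to the selected branch;
// the other branch gets 0 and the mask itself gets no gradient.
fn where_backward(cond: &[f64], grad_out: &[f64]) -> (Vec<f64>, Vec<f64>) {
    let d_true: Vec<f64> = cond
        .iter()
        .zip(grad_out)
        .map(|(&c, &g)| if c != 0.0 { g } else { 0.0 })
        .collect();
    let d_false: Vec<f64> = cond
        .iter()
        .zip(grad_out)
        .map(|(&c, &g)| if c != 0.0 { 0.0 } else { g })
        .collect();
    (d_true, d_false)
}
```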

pub fn reshape(&mut self, a: usize, new_shape: &[usize]) -> usize

Reshape a tensor. Stores the original shape for backward.

pub fn transpose_op(&mut self, a: usize) -> usize

Transpose a 2-D tensor.

pub fn cat(&mut self, inputs: &[usize], axis: usize) -> usize

Concatenate tensors along an axis. All input tensors must have the same shape except along the concatenation axis.

pub fn gather(&mut self, a: usize, indices: &[usize], axis: usize) -> usize

Gather elements along an axis using indices. For a 1-D tensor, returns tensor[indices].

pub fn div(&mut self, a: usize, b: usize) -> usize

Element-wise division: a / b. The corresponding GradOp::Div(a, b) has a backward implementation.

pub fn neg(&mut self, a: usize) -> usize

Element-wise negation: -a. The corresponding GradOp::Neg(a) has a backward implementation.

pub fn scalar_mul(&mut self, a: usize, s: f64) -> usize

Scalar multiplication: a * s, where s is an f64 constant. The corresponding GradOp::ScalarMul(a, s) has a backward implementation.

pub fn exp(&mut self, a: usize) -> usize

Element-wise exponential: exp(a). The corresponding GradOp::Exp(a) has a backward implementation.

pub fn ln(&mut self, a: usize) -> usize

Element-wise natural logarithm: ln(a). The corresponding GradOp::Ln(a) has a backward implementation.

pub fn value(&self, idx: usize) -> f64

Get the scalar value from a 1-element tensor node.

pub fn tensor(&self, idx: usize) -> Tensor

Get the tensor at a node.

pub fn set_tensor(&self, idx: usize, tensor: Tensor)

Set the tensor at a node (for parameter updates).

pub fn grad(&self, idx: usize) -> Option<Tensor>

Get the gradient at a node.

pub fn zero_grad(&self)

Zero out all gradients.

pub fn clip_grad(&self, max_norm: f64)

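Clamp every gradient element to [-max_norm, max_norm]. Despite the parameter name, this is element-wise value clipping rather than norm clipping (see clip_grad_norm); it limits exploding gradients during backpropagation.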

pub fn clip_grad_norm(&self, max_norm: f64) -> f64

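Clip gradients by global norm: if the L2 norm of all gradients taken together, ||grads||_2, exceeds max_norm, every gradient is scaled by max_norm / ||grads||_2 so that the global norm equals max_norm. The squared norm is accumulated sequentially in a fixed order, so the result is deterministic.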
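Both clipping strategies, element-wise (clip_grad) and global-norm (clip_grad_norm), are sketched below on gradients held as Vec<Vec<f64>>, one inner vector per parameter. The function names are hypothetical, and returning the pre-clipping norm is an assumption: the description above does not say what the f64 return value of clip_grad_norm holds.

```rust
// Hypothetical element-wise clipping: clamp each entry to [-max_abs, max_abs].
fn clip_elementwise(grads: &mut [Vec<f64>], max_abs: f64) {
    for v in grads.iter_mut() {
        for g in v.iter_mut() {
            *g = (*g).clamp(-max_abs, max_abs);
        }
    }
}

// Hypothetical global-norm clipping. Returns the norm measured before clipping
// (an assumption about clip_grad_norm's return value).
fn clip_global_norm(grads: &mut [Vec<f64>], max_norm: f64) -> f64 {
    // Sum squares in a fixed sequential order so the result is reproducible.
    let mut sq = 0.0_f64;
    for v in grads.iter() {
        for &x in v {
            sq += x * x;
        }
    }
    let norm = sq.sqrt();
    if norm > max_norm {
        // One shared scale factor preserves the direction of the full gradient.
        let scale = max_norm / norm;
        for v in grads.iter_mut() {
            for x in v.iter_mut() {
                *x *= scale;
            }
        }
    }
    norm
}
```

The design difference matters: element-wise clipping can change the gradient's direction, while global-norm clipping only shrinks its length. For example, gradients [3, 4] and [0] have a global norm of 5; clipping to 1.0 scales them to [0.6, 0.8] and [0].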

pub fn backward(&self, loss_idx: usize)

Run backward pass from a loss node.

pub fn jacobian(&mut self, output_idx: usize, param_idx: usize) -> Tensor

Compute the Jacobian of a vector-valued output node with respect to a parameter node. Returns a 2D tensor of shape [output_dim, param_dim].

Strategy: run backward once per output element i, seeded with the one-hot tensor e_i; the resulting gradient with respect to the parameter is row i of the Jacobian.
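The one-hot seeding strategy can be sketched independently of the graph. A backward pass seeded with a tensor s (as in backward_with_seed) computes the vector-Jacobian product s^T J, so seeding with e_i yields row i of J. In the sketch, the example function y = [x0 * x1, sin(x0)] and its hand-written VJP are hypothetical stand-ins for a real graph, as is the function name jacobian_from_vjp.

```rust
// Hypothetical helper: build an [output_dim, param_dim] Jacobian from any
// vector-Jacobian product by seeding it with each one-hot vector e_i.
fn jacobian_from_vjp<F>(output_dim: usize, vjp: F) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    (0..output_dim)
        .map(|i| {
            let mut seed = vec![0.0; output_dim];
            seed[i] = 1.0;
            // Row i = gradient of output i with respect to the parameter.
            vjp(&seed[..])
        })
        .collect()
}

// Hypothetical stand-in for a seeded backward pass of y = [x0 * x1, sin(x0)]:
// returns seed^T * J, where J = [[x1, x0], [cos(x0), 0]].
fn example_vjp(x: &[f64], seed: &[f64]) -> Vec<f64> {
    vec![seed[0] * x[1] + seed[1] * x[0].cos(), seed[0] * x[0]]
}
```

The cost is one backward pass per output element, which is why this approach suits outputs with few elements.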

pub fn hessian_diag(&mut self, loss_idx: usize, param_idx: usize, eps: f64) -> Tensor

Compute the diagonal of the Hessian of a scalar loss with respect to a parameter node. Uses finite differences on the gradient (compute grad, perturb, re-compute grad).

Returns a tensor of the same shape as the parameter.

pub fn hessian(&mut self, loss_idx: usize, param_idx: usize) -> Tensor

Compute the full Hessian matrix of a scalar loss with respect to a parameter node.

Returns a 2D tensor of shape [param_dim, param_dim] where H[i, j] = d²loss / (dp_i dp_j).

Strategy: For each parameter element i, perturb param[i] by +eps and -eps, re-run the forward pass to update intermediate node values, then run backward() to get the gradient vector. The i-th row of the Hessian is (grad_plus - grad_minus) / (2 * eps). Uses eps = 1e-5 for accurate central differences.
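The central-difference strategy can be sketched with a closure standing in for "re-run the forward pass, then backward()"; the closure and the function name hessian_central are hypothetical. Each coordinate is perturbed by +eps and -eps, and the difference of the two gradients divided by 2 * eps gives one Hessian row.

```rust
// Hypothetical helper: full Hessian by central differences of a gradient function.
// Row i = (grad(p + eps * e_i) - grad(p - eps * e_i)) / (2 * eps).
fn hessian_central<G>(params: &[f64], grad: G, eps: f64) -> Vec<Vec<f64>>
where
    G: Fn(&[f64]) -> Vec<f64>,
{
    let n = params.len();
    let mut h = Vec::with_capacity(n);
    let mut p = params.to_vec();
    for i in 0..n {
        let orig = p[i];
        p[i] = orig + eps;
        let g_plus = grad(&p[..]);
        p[i] = orig - eps;
        let g_minus = grad(&p[..]);
        // Restore the coordinate before perturbing the next one.
        p[i] = orig;
        let row: Vec<f64> = g_plus
            .iter()
            .zip(&g_minus)
            .map(|(a, b)| (a - b) / (2.0 * eps))
            .collect();
        h.push(row);
    }
    h
}
```

For loss = x0^2 * x1 + x1^3 at (1, 2), the gradient is [2 * x0 * x1, x0^2 + 3 * x1^2] and the exact Hessian is [[4, 2], [2, 12]]. Because each gradient component is at most quadratic in the parameters, central differences recover it up to rounding error.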

pub fn double_backward(&mut self, loss_idx: usize, param_idx: usize) -> Tensor

Compute the second derivative of a scalar loss with respect to a parameter node.

Implemented with finite differences on the backward pass: the parameter is perturbed by +eps and -eps, the forward and backward passes are re-run, and d(grad)/d(param) is computed numerically as a central difference. For a scalar parameter this approximates the second derivative d²loss/dparam²; for a smooth loss, the truncation error of a central difference is of order eps².

Returns a tensor of the same shape as the parameter containing second derivatives.

pub fn vmap_forward(&mut self, input_idx: usize, batch_data: &[Tensor]) -> Vec<usize>

Vectorized map (batched evaluation) over a batch dimension.

For each tensor in batch_data, sets the input node input_idx to that tensor, re-evaluates all downstream nodes by re-running the forward pass (recomputing tensor values from the graph structure), and records the output node index after each evaluation.

Returns a Vec<usize> of output node indices (one per batch element). After calling this, g.value(results[k]) returns the output for batch element k.

Note: this is a simple batched-evaluation helper that mutates node tensors in place. After vmap_forward returns, the original nodes hold the values for the last batch element only; each per-element result survives in the snapshot node that results[k] refers to.

Implementation: for each batch element, set the input tensor, recompute the subgraph by replaying each op from input_idx up to the final node, and record the final node’s value in a fresh parameter node.
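The replay idea can be sketched with a hypothetical recorded op list over scalar values. Unlike vmap_forward, which returns the indices of snapshot nodes, this sketch returns the snapshotted values directly; ReplayOp and replay_batch are assumptions for illustration. As with vmap_forward, the shared value buffer is overwritten on every pass, so after the call it holds only the last batch element's values.

```rust
// Hypothetical recorded op: node k is computed from earlier nodes.
#[derive(Clone, Copy)]
enum ReplayOp {
    Input,
    Square(usize),
    AddConst(usize, f64),
}

// Hypothetical helper: for each batch element, overwrite the input value,
// replay every op after input_idx in order, and snapshot the final node.
fn replay_batch(ops: &[ReplayOp], values: &mut [f64], input_idx: usize, batch: &[f64]) -> Vec<f64> {
    assert_eq!(ops.len(), values.len());
    let mut outputs = Vec::with_capacity(batch.len());
    for &x in batch {
        values[input_idx] = x;
        for k in input_idx + 1..ops.len() {
            values[k] = match ops[k] {
                // Other inputs keep their current value.
                ReplayOp::Input => values[k],
                ReplayOp::Square(a) => values[a] * values[a],
                ReplayOp::AddConst(a, c) => values[a] + c,
            };
        }
        // Copy the result out before the next pass overwrites it.
        outputs.push(values[ops.len() - 1]);
    }
    outputs
}
```

For the graph x -> x^2 -> x^2 + 1 and the batch [1, 2, 3], the snapshots are [2, 5, 10], while the buffer is left holding [3, 9, 10].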

pub fn backward_with_seed(&mut self, loss_idx: usize, seed: &Tensor)

Backward pass with a custom gradient seed tensor (for Jacobian computation).

Trait Implementations

impl Default for GradGraph

fn default() -> Self

Returns the “default value” for a type.

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.

impl<T> Pointable for T

const ALIGN: usize

The alignment of pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.