pub struct GradGraph {
    pub nodes: Vec<Rc<RefCell<GradNode>>>,
}

The reverse-mode AD tape/graph.

Fields

nodes: Vec<Rc<RefCell<GradNode>>>

Implementations

impl GradGraph
pub fn new() -> Self
pub fn parameter(&mut self, tensor: Tensor) -> usize
Create a parameter node (trainable, accumulates gradients).
pub fn softmax(&mut self, a: usize) -> usize
Softmax along the last axis (treats tensor as a flat vector for 1-D). Uses numerically stable log-sum-exp: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
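As a standalone check of the stable formula above, here is a plain-Rust sketch of the same computation over a flat slice (an illustration, not the graph op itself; the free function name is mine):

```rust
/// Numerically stable softmax: subtract the max before exponentiating
/// so `exp` cannot overflow even for large logits.
fn softmax(x: &[f64]) -> Vec<f64> {
    let max = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = x.iter().map(|&v| (v - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}
```

With the max subtracted, inputs like `[1000.0, 1000.0]` still produce `[0.5, 0.5]` instead of `inf / inf`.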
pub fn cross_entropy(&mut self, logits: usize, targets: usize) -> usize
Cross-entropy loss: -sum(targets * log(softmax(logits))) Uses numerically stable log-sum-exp internally. Returns a scalar [1] tensor.
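The stable log-sum-exp route can be sketched in plain Rust (a free-function illustration of the formula, not the node op): log_softmax(x)_i = x_i - lse(x) with lse(x) = max(x) + ln(sum(exp(x - max(x)))).

```rust
/// Cross-entropy via log-sum-exp: -sum(t_i * (x_i - lse(x))).
/// The max shift keeps the intermediate `exp` calls finite.
fn cross_entropy(logits: &[f64], targets: &[f64]) -> f64 {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let lse = logits.iter().map(|&v| (v - max).exp()).sum::<f64>().ln() + max;
    -logits.iter().zip(targets).map(|(&l, &t)| t * (l - lse)).sum::<f64>()
}
```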
pub fn layer_norm(&mut self, a: usize) -> usize
Layer normalization: normalize input to zero mean and unit variance. y = (x - mean(x)) / sqrt(var(x) + eps), where eps = 1e-5.
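The formula above, written out as a standalone function over a flat slice (my own sketch, using the same eps the docs state):

```rust
/// Layer norm: y = (x - mean(x)) / sqrt(var(x) + eps), eps = 1e-5.
/// Uses the population variance (divide by n, not n - 1).
fn layer_norm(x: &[f64]) -> Vec<f64> {
    const EPS: f64 = 1e-5;
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
    x.iter().map(|&v| (v - mean) / (var + EPS).sqrt()).collect()
}
```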
pub fn batch_norm(&mut self, a: usize) -> usize
Batch normalization: normalize along the first axis (batch dimension). For a tensor of shape [batch, features], normalizes each feature across the batch. y = (x - mean(x)) / sqrt(var(x) + eps), where eps = 1e-5. For 1-D inputs, behaves identically to layer_norm.
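The per-feature-across-the-batch statistics can be sketched on a `[batch, features]` layout modeled as `Vec<Vec<f64>>` (an illustration of the stated behavior, not the op's actual tensor code):

```rust
/// Batch norm over shape [batch, features]: each column (feature) is
/// normalized using its mean and variance taken across the batch rows.
fn batch_norm(x: &[Vec<f64>]) -> Vec<Vec<f64>> {
    const EPS: f64 = 1e-5;
    let batch = x.len() as f64;
    let features = x[0].len();
    let mut out = vec![vec![0.0; features]; x.len()];
    for f in 0..features {
        let mean = x.iter().map(|row| row[f]).sum::<f64>() / batch;
        let var = x.iter().map(|row| (row[f] - mean).powi(2)).sum::<f64>() / batch;
        for (i, row) in x.iter().enumerate() {
            out[i][f] = (row[f] - mean) / (var + EPS).sqrt();
        }
    }
    out
}
```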
pub fn clamp(&mut self, a: usize, min: f64, max: f64) -> usize
Element-wise clamp to [min, max].
pub fn where_cond(
    &mut self,
    cond: usize,
    on_true: usize,
    on_false: usize,
) -> usize
Conditional select: where(cond, on_true, on_false). cond is a tensor of 0.0/1.0 values acting as a mask.
pub fn reshape(&mut self, a: usize, new_shape: &[usize]) -> usize
Reshape a tensor. Stores the original shape for backward.
pub fn transpose_op(&mut self, a: usize) -> usize
Transpose a 2-D tensor.
pub fn cat(&mut self, inputs: &[usize], axis: usize) -> usize
Concatenate tensors along an axis. All input tensors must have the same shape except along the concatenation axis.
pub fn gather(&mut self, a: usize, indices: &[usize], axis: usize) -> usize
Gather elements along an axis using indices. For a 1-D tensor, returns tensor[indices].
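For the 1-D case, a plain-Rust sketch of the forward gather, together with the scatter-add a reverse-mode tape would naturally use to route gradients back (the backward pairing is my assumption here, not quoted from the docs):

```rust
/// Forward: tensor[indices].
fn gather_1d(t: &[f64], indices: &[usize]) -> Vec<f64> {
    indices.iter().map(|&i| t[i]).collect()
}

/// Backward sketch: gradients scatter-add into the gathered positions,
/// so repeated indices accumulate.
fn gather_backward_1d(grad_out: &[f64], indices: &[usize], src_len: usize) -> Vec<f64> {
    let mut grad_in = vec![0.0; src_len];
    for (&i, &g) in indices.iter().zip(grad_out) {
        grad_in[i] += g;
    }
    grad_in
}
```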
pub fn div(&mut self, a: usize, b: usize) -> usize
Element-wise division: a / b. The backward pass for GradOp::Div(a, b) is already implemented.
pub fn neg(&mut self, a: usize) -> usize
Element-wise negation: -a. The backward pass for GradOp::Neg(a) is already implemented.
pub fn scalar_mul(&mut self, a: usize, s: f64) -> usize
Scalar multiply: a * s (where s is an f64 constant). The backward pass for GradOp::ScalarMul(a, s) is already implemented.
pub fn exp(&mut self, a: usize) -> usize
Element-wise exponential: exp(a). The backward pass for GradOp::Exp(a) is already implemented.
pub fn ln(&mut self, a: usize) -> usize
Element-wise natural logarithm: ln(a). The backward pass for GradOp::Ln(a) is already implemented.
pub fn set_tensor(&self, idx: usize, tensor: Tensor)
Set the tensor at a node (for parameter updates).
pub fn clip_grad(&self, max_norm: f64)
Clip all gradients to [-max_norm, max_norm] (element-wise).
This prevents gradient explosion during backpropagation.
pub fn clip_grad_norm(&self, max_norm: f64) -> f64
Clip gradients by global norm: if ||grads||_2 > max_norm, scale all gradients so the global norm equals max_norm. Deterministic via sequential accumulation.
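The global-norm rule can be checked standalone; this sketch models the gradients as a slice of vectors and, assuming the method's return value is the pre-clip norm (my reading of the `-> f64` signature), mirrors that:

```rust
/// Global-norm clipping: compute ||grads||_2 over ALL gradient tensors
/// (sequential accumulation, so the result is deterministic); if it
/// exceeds max_norm, scale every gradient by max_norm / total.
fn clip_grad_norm(grads: &mut [Vec<f64>], max_norm: f64) -> f64 {
    let total: f64 = grads
        .iter()
        .flat_map(|g| g.iter())
        .map(|v| v * v)
        .sum::<f64>()
        .sqrt();
    if total > max_norm {
        let scale = max_norm / total;
        for g in grads.iter_mut() {
            for v in g.iter_mut() {
                *v *= scale;
            }
        }
    }
    total
}
```

Unlike element-wise clip_grad, this preserves the direction of the overall gradient vector and only shrinks its length.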
pub fn jacobian(&mut self, output_idx: usize, param_idx: usize) -> Tensor
Compute the Jacobian of a vector-valued output node with respect to a parameter node. Returns a 2D tensor of shape [output_dim, param_dim].
Strategy: run backward once per output element with a one-hot seed.
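The one-hot-seed strategy is independent of the tape itself: given a vector-Jacobian product (what one seeded backward pass computes), seeding with each basis vector e_i recovers row i of the Jacobian. A sketch over an abstract vjp closure (names are mine):

```rust
/// Build the full Jacobian from a vector-Jacobian product: row i is
/// vjp(e_i), i.e. one backward pass seeded with a one-hot vector.
fn jacobian_from_vjp<F>(vjp: F, out_dim: usize) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    (0..out_dim)
        .map(|i| {
            let mut seed = vec![0.0; out_dim];
            seed[i] = 1.0; // one-hot seed for output element i
            vjp(&seed)     // row i of the Jacobian
        })
        .collect()
}
```

This costs one backward pass per output element, which is why reverse mode is cheap for scalar losses but not for wide outputs.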
pub fn hessian_diag(
    &mut self,
    loss_idx: usize,
    param_idx: usize,
    eps: f64,
) -> Tensor
Compute the diagonal of the Hessian of a scalar loss with respect to a parameter node. Uses finite differences on the gradient (compute grad, perturb, re-compute grad).
Returns a tensor of the same shape as the parameter.
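The "finite differences on the gradient" idea reduces to a central difference per element: d²L/dp_i² ≈ (g(p + eps·e_i)_i − g(p − eps·e_i)_i) / (2·eps). A sketch over an abstract gradient closure (the free function is mine):

```rust
/// Hessian diagonal via central differences of the gradient function:
/// perturb one coordinate, re-evaluate the gradient, difference entry i.
fn hessian_diag<G>(grad: G, p: &[f64], eps: f64) -> Vec<f64>
where
    G: Fn(&[f64]) -> Vec<f64>,
{
    (0..p.len())
        .map(|i| {
            let mut plus = p.to_vec();
            plus[i] += eps;
            let mut minus = p.to_vec();
            minus[i] -= eps;
            (grad(&plus)[i] - grad(&minus)[i]) / (2.0 * eps)
        })
        .collect()
}
```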
pub fn hessian(&mut self, loss_idx: usize, param_idx: usize) -> Tensor
Compute the full Hessian matrix of a scalar loss with respect to a parameter node.
Returns a 2D tensor of shape [param_dim, param_dim] where H[i, j] = d²loss / (dp_i dp_j).
Strategy: For each parameter element i, perturb param[i] by +eps and -eps, re-run the forward pass to update intermediate node values, then run backward() to get the gradient vector. The i-th row of the Hessian is (grad_plus - grad_minus) / (2 * eps). Uses eps = 1e-5 for accurate central differences.
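The row-by-row strategy can be checked against a quadratic whose Hessian is known in closed form; this sketch takes the gradient as a closure rather than re-running a tape (an illustration of the numerics, not the method's actual implementation):

```rust
/// Full Hessian via central differences: row i is
/// (grad(p + eps*e_i) - grad(p - eps*e_i)) / (2*eps).
fn hessian<G>(grad: G, p: &[f64], eps: f64) -> Vec<Vec<f64>>
where
    G: Fn(&[f64]) -> Vec<f64>,
{
    (0..p.len())
        .map(|i| {
            let mut plus = p.to_vec();
            plus[i] += eps;
            let mut minus = p.to_vec();
            minus[i] -= eps;
            let (gp, gm) = (grad(&plus), grad(&minus));
            gp.iter()
                .zip(&gm)
                .map(|(a, b)| (a - b) / (2.0 * eps))
                .collect::<Vec<f64>>()
        })
        .collect()
}
```

For a twice-differentiable loss the result should come out (numerically) symmetric, which is a cheap sanity check on the finite-difference step size.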
pub fn double_backward(&mut self, loss_idx: usize, param_idx: usize) -> Tensor
Compute the second derivative of a scalar loss with respect to a parameter node.
Implements double_backward via finite differences on the backward pass: perturbs the parameter by +eps/-eps, re-runs the forward and backward pass, and computes d(grad)/d(param) numerically. For a scalar param this gives the exact second derivative d²loss/dparam².
Returns a tensor of the same shape as the parameter containing second derivatives.
pub fn vmap_forward(
    &mut self,
    input_idx: usize,
    batch_data: &[Tensor],
) -> Vec<usize>
Vectorized map (batched evaluation) over a batch dimension.
For each tensor in batch_data, sets the input node input_idx to that tensor,
re-evaluates all downstream nodes by re-running the forward pass (recomputing
tensor values from the graph structure), and records the output node index
after each evaluation.
Returns a Vec<usize> of output node indices, one per batch element; after calling this, g.value(results[k]) returns the output for batch element k (each result is stored in a snapshot tensor inside the returned node).
Note: This is a simple batched-evaluation helper that mutates node tensors in-place, so once it returns, the graph itself holds the values for the LAST batch element; read per-element results through g.value(results[k]).
Implementation: for each batch element, set the input tensor, re-run the forward pass over the subgraph downstream of input_idx by replaying each op, and record the final node's value in a fresh parameter node.
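Stripped of the tape, the pattern is just "re-run the forward function once per batch element and collect the outputs", which a closure-based sketch makes concrete (names are mine):

```rust
/// Batched evaluation: apply `forward` to each batch element in turn
/// and collect one output per element, mirroring vmap_forward's loop.
fn vmap_eval<F>(forward: F, batch: &[Vec<f64>]) -> Vec<Vec<f64>>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    batch.iter().map(|x| forward(x)).collect()
}
```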
pub fn backward_with_seed(&mut self, loss_idx: usize, seed: &Tensor)
Backward pass with a custom gradient seed tensor (for Jacobian computation).
Trait Implementations

Auto Trait Implementations
impl Freeze for GradGraph
impl !RefUnwindSafe for GradGraph
impl !Send for GradGraph
impl !Sync for GradGraph
impl Unpin for GradGraph
impl UnsafeUnpin for GradGraph
impl !UnwindSafe for GradGraph