pub struct GradWithLossOptions {
pub zero_missing_wrt: bool,
}
Compute the reverse-mode gradient graph and the loss value.
Returns a graph whose outputs are
[loss, aux₁, aux₂, …, grad_wrt[0], grad_wrt[1], …].
The first forward output is treated as the scalar loss and is
differentiated. Any additional forward outputs (aux₁ …) are
mirrored from the forward graph and emitted unchanged — gradients are
not propagated through them. This is the canonical hook for emitting
training-side statistics (BatchNorm batch mean/variance, debug probes,
…) alongside the loss in a single forward+backward pass.
The returned graph contains a copy of the entire forward graph so
activations needed by gradient kernels are recomputed from inputs;
it also exposes a new Op::Input named "d_output" which the
caller seeds with the upstream gradient of the loss (typically a
scalar 1.0 for “differentiate the loss directly”). Auxiliary outputs
have no d_output-equivalent — by construction they don’t contribute
to the gradient path.
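The output ordering described above can be sketched as plain index arithmetic. This is a minimal illustration, not part of the crate: `OutputLayout` and its methods are hypothetical names, and the only thing taken from the documentation is the order `[loss, aux…, grad_wrt…]` and the requirement that the forward graph has at least one output.

```rust
/// Hypothetical helper (not part of the crate): positions of each output
/// group in the returned graph, following the documented order
/// `[loss, aux..., grad_wrt...]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OutputLayout {
    /// Number of auxiliary forward outputs (forward outputs minus the loss).
    pub n_aux: usize,
    /// Number of parameters being differentiated against.
    pub n_wrt: usize,
}

impl OutputLayout {
    /// `n_forward_outputs` counts the loss plus all auxiliary outputs.
    /// Returns `None` when the forward graph has no outputs, since the
    /// first output must be the loss.
    pub fn new(n_forward_outputs: usize, n_wrt: usize) -> Option<Self> {
        let n_aux = n_forward_outputs.checked_sub(1)?;
        Some(Self { n_aux, n_wrt })
    }

    /// The loss is always the first output.
    pub fn loss(&self) -> usize {
        0
    }

    /// Index of the `i`-th auxiliary output, if it exists.
    pub fn aux(&self, i: usize) -> Option<usize> {
        if i < self.n_aux { Some(1 + i) } else { None }
    }

    /// Index of the gradient for `wrt[j]`, if it exists.
    pub fn grad(&self, j: usize) -> Option<usize> {
        if j < self.n_wrt { Some(1 + self.n_aux + j) } else { None }
    }

    /// Total number of outputs in the returned graph.
    pub fn total(&self) -> usize {
        1 + self.n_aux + self.n_wrt
    }
}
```

With three forward outputs (a loss and two auxiliaries) and two `wrt` parameters, this places the auxiliaries at indices 1 and 2 and the gradients at indices 3 and 4.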
§Limitations
- Forward graph must have ≥ 1 output. The first is the loss.
- All ops in the forward graph must have an implemented VJP rule.
Hitting an op without one is a panic, not a silent miscompute.
Options for grad_with_loss_opts.
Fields§
§zero_missing_wrt: bool
When true, parameters in wrt with no gradient path receive an
explicit zero tensor instead of panicking (e.g. an unused logit_bias).
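The effect of this flag on a single `wrt` parameter might look like the sketch below. It does not use the crate: the struct is a local stand-in with the same field, `resolve_grad` is a hypothetical helper, and gradients are modeled as a flat `Vec<f32>` purely for illustration.

```rust
/// Local stand-in mirroring the documented struct; the real type lives
/// in the crate.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GradWithLossOptions {
    pub zero_missing_wrt: bool,
}

/// Hypothetical helper: decide what to emit for one `wrt` parameter.
/// `grad` is `None` when no gradient path reaches the parameter, and
/// `numel` is the number of elements in the parameter.
pub fn resolve_grad(
    opts: GradWithLossOptions,
    name: &str,
    grad: Option<Vec<f32>>,
    numel: usize,
) -> Vec<f32> {
    match grad {
        // A gradient path exists: pass it through unchanged.
        Some(g) => g,
        // No path, lenient mode: emit an explicit zero tensor.
        None if opts.zero_missing_wrt => vec![0.0; numel],
        // No path, strict mode: fail loudly rather than guess.
        None => panic!("no gradient path to wrt parameter `{}`", name),
    }
}
```

Calling `resolve_grad` with `zero_missing_wrt: true` and `None` yields zeros of the requested length; with `zero_missing_wrt: false` the same call panics.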
Implementations§
impl GradWithLossOptions
pub const STRICT: GradWithLossOptions
pub const TRAINING: GradWithLossOptions
Trait Implementations§
impl Clone for GradWithLossOptions
fn clone(&self) -> GradWithLossOptions
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Copy for GradWithLossOptions
impl Debug for GradWithLossOptions
impl Eq for GradWithLossOptions
impl PartialEq for GradWithLossOptions
fn eq(&self, other: &GradWithLossOptions) -> bool
Tests for self and other values to be equal, and is used by ==.
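Taken together, the trait implementations listed above mean the options behave like a small value type. The sketch below uses a local stand-in struct with the same derives rather than the crate's type, and `describe` is a hypothetical consumer added only to show passing the value by copy.

```rust
/// Local stand-in with the same field and the same derived traits as the
/// documented struct.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GradWithLossOptions {
    pub zero_missing_wrt: bool,
}

/// Hypothetical consumer that takes the options by value. Because the
/// type is `Copy`, the caller's value remains usable after the call.
pub fn describe(opts: GradWithLossOptions) -> String {
    // `Debug` is derived, so `{:?}` prints the struct name and field.
    format!("{:?}", opts)
}
```

A value can be passed to `describe` and then still compared with `==` afterwards, since passing it copies rather than moves.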