Enum LossKind 

#[non_exhaustive]
#[repr(u16)]
pub enum LossKind {
    Mse = 0,
    Nll = 1,
    CrossEntropy = 2,
    Bce = 3,
    KlDiv = 4,
    L1 = 5,
    SmoothL1 = 6,
    HingeEmbedding = 7,
    MarginRanking = 8,
    TripletMargin = 9,
    Ctc = 10,
    PoissonNll = 11,
    Huber = 12,
    BceWithLogits = 13,
    GaussianNll = 14,
    CosineEmbedding = 15,
    MultiMargin = 16,
    MultilabelMargin = 17,
    MultilabelSoftMargin = 18,
    FusedLinearCrossEntropy = 19,
}

Loss op discriminant — category R from the comprehensive plan.

Stored as a u16 in crate::KernelSku::op when category == OpCategory::Loss. Each variant currently has its own Plan type because the argument shapes differ: MSE, BCE, and KLDiv take two tensor inputs of the same dtype, whereas NLL and CrossEntropy take an input of dtype T plus an i64 target-index tensor. All variants share the LossReduction enum, which selects the per-cell, mean, or sum output shape.

Currently wired: {Mse, Nll, CrossEntropy, Bce, KlDiv} × {f32, f16, bf16, f64}, forward and backward. HingeEmbedding, L1, SmoothL1, MarginRanking, TripletMargin, Ctc, and PoissonNll are reserved discriminants for future fanout.
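
As an illustration of the u16 storage described above, the sketch below round-trips a discriminant through a raw u16. The enum is a local mirror of the first five variants only, and loss_kind_to_raw and loss_kind_from_raw are hypothetical helper names; nothing here is confirmed to be exported by the crate.

```rust
// Local mirror of the first five variants, for illustration only.
#[non_exhaustive]
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LossKind {
    Mse = 0,
    Nll = 1,
    CrossEntropy = 2,
    Bce = 3,
    KlDiv = 4,
}

/// Hypothetical helper: encode a variant as its raw discriminant.
/// With `#[repr(u16)]` this is a plain cast.
pub fn loss_kind_to_raw(kind: LossKind) -> u16 {
    kind as u16
}

/// Hypothetical helper: decode a raw op code, rejecting unknown values
/// instead of transmuting them.
pub fn loss_kind_from_raw(raw: u16) -> Option<LossKind> {
    match raw {
        0 => Some(LossKind::Mse),
        1 => Some(LossKind::Nll),
        2 => Some(LossKind::CrossEntropy),
        3 => Some(LossKind::Bce),
        4 => Some(LossKind::KlDiv),
        _ => None,
    }
}
```

Decoding through an explicit match rather than a transmute keeps out-of-range op codes from producing an invalid enum value.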

Variants (Non-exhaustive)

This enum is marked as non-exhaustive
Non-exhaustive enums could have additional variants added in future. Therefore, when matching against variants of non-exhaustive enums, an extra wildcard arm must be added to account for any future variants.
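
A minimal sketch of the wildcard-arm requirement, using a local mirror of a few variants. The helper takes_index_target is hypothetical; it only encodes the dtype split stated in the description (NLL and CrossEntropy take an i64 target, MSE, BCE, and KLDiv take a same-dtype target). Inside the defining crate the compiler does not insist on the wildcard arm, but downstream crates must include it.

```rust
// Local mirror of a few variants, for illustration only.
#[non_exhaustive]
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LossKind {
    Mse = 0,
    Nll = 1,
    CrossEntropy = 2,
    Bce = 3,
    KlDiv = 4,
    L1 = 5,
}

/// Hypothetical helper: `Some(true)` if the loss takes an i64 class-index
/// target, `Some(false)` if the target shares the input dtype, and `None`
/// for variants this caller does not know about.
pub fn takes_index_target(kind: LossKind) -> Option<bool> {
    match kind {
        LossKind::Nll | LossKind::CrossEntropy => Some(true),
        LossKind::Mse | LossKind::Bce | LossKind::KlDiv => Some(false),
        // Wildcard arm: covers variants not handled here and any added later.
        _ => None,
    }
}
```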

Mse = 0

y = mean((pred - target)²) (or sum / per-cell). PyTorch torch.nn.functional.mse_loss.
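
The three output shapes can be sketched on the CPU as follows. Reduction is a hypothetical stand-in for the crate's LossReduction enum (its real variant names are not shown on this page), and mse_loss is an illustrative reference function, not the crate's kernel.

```rust
/// Hypothetical stand-in for `LossReduction`; real variant names may differ.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Reduction {
    PerCell,
    Mean,
    Sum,
}

/// Reference MSE on flat slices. Returns one value per element for
/// `PerCell`, otherwise a single-element vector.
pub fn mse_loss(pred: &[f32], target: &[f32], reduction: Reduction) -> Vec<f32> {
    assert_eq!(pred.len(), target.len(), "pred and target must have the same length");
    // Squared error for every element.
    let per_cell: Vec<f32> = pred
        .iter()
        .zip(target)
        .map(|(&p, &t)| (p - t) * (p - t))
        .collect();
    match reduction {
        Reduction::PerCell => per_cell,
        Reduction::Sum => vec![per_cell.iter().sum::<f32>()],
        Reduction::Mean => vec![per_cell.iter().sum::<f32>() / per_cell.len() as f32],
    }
}
```

For pred = [1, 2, 3] and target = [1, 0, 5], the per-cell output is [0, 4, 4], the sum is 8, and the mean is 8/3.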

Nll = 1

y = -mean_i(input[i, target[i]]), gathering along the class (feature) axis. PyTorch torch.nn.functional.nll_loss. Heterogeneous dtypes: input is T, target is i64.

CrossEntropy = 2

y = NLLLoss(LogSoftmax(input), target) — fused for numerical stability. PyTorch torch.nn.functional.cross_entropy. Today wired for class-index target only (i64); soft-target CE is reserved.
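
Why fusing helps numerically can be seen in a small CPU reference. The sketch below assumes a row-major [N, C] logits slice, class-index targets, and mean reduction; cross_entropy_mean is a hypothetical name, and the code illustrates the standard max-subtracted log-sum-exp rather than the crate's actual kernel.

```rust
/// Reference cross-entropy with mean reduction over N rows.
/// `logits` is row-major [N, num_classes]; `targets` holds N class indices.
pub fn cross_entropy_mean(logits: &[f32], num_classes: usize, targets: &[i64]) -> f32 {
    assert!(num_classes > 0, "need at least one class");
    assert_eq!(logits.len(), targets.len() * num_classes, "shape mismatch");
    let mut total = 0.0f32;
    for (row, &t) in logits.chunks_exact(num_classes).zip(targets) {
        assert!(t >= 0 && (t as usize) < num_classes, "target out of range");
        // Subtract the row maximum before exponentiating so exp() cannot overflow.
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let log_sum_exp = max + row.iter().map(|&x| (x - max).exp()).sum::<f32>().ln();
        // -log_softmax(row)[t] = log_sum_exp - row[t]
        total += log_sum_exp - row[t as usize];
    }
    total / targets.len() as f32
}
```

Computing softmax first and taking its log afterwards would overflow for a logit such as 1000.0, whereas the fused form stays finite.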

Bce = 3

y = -mean(target·log(pred) + (1-target)·log(1-pred)). PyTorch torch.nn.functional.binary_cross_entropy. Caller ensures pred ∈ (0, 1).

KlDiv = 4

y = mean(target·(log(target) - input)). PyTorch torch.nn.functional.kl_div with the “input is log-prob” convention.

L1 = 5

y = mean(|pred - target|) (or sum / per-cell). PyTorch torch.nn.functional.l1_loss.

SmoothL1 = 6

Smooth L1 / “Huber-with-β” loss. PyTorch torch.nn.functional.smooth_l1_loss.

HingeEmbedding = 7

y = mean(input if t==1 else max(0, margin - input)). PyTorch torch.nn.functional.hinge_embedding_loss. Heterogeneous-dtype: input is T, target is i64 (±1).

MarginRanking = 8

y = mean(max(0, -t · (x1 - x2) + margin)). PyTorch torch.nn.functional.margin_ranking_loss. Target t is T (±1).

TripletMargin = 9

y = mean(max(0, ||a-p||_p - ||a-n||_p + margin)). PyTorch torch.nn.functional.triplet_margin_loss. 2-D input [N, D].

Ctc = 10

Reserved — torch.nn.functional.ctc_loss.

PoissonNll = 11

y = mean(exp(input) - target · input) (default log_input=true). PyTorch torch.nn.functional.poisson_nll_loss.

Huber = 12

Huber loss (separate from SmoothL1 — PyTorch torch.nn.functional.huber_loss).
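
The difference between the SmoothL1 and Huber discriminants can be sketched per element, following the definitions in the PyTorch documentation. The function names are illustrative, beta and delta are assumed to be strictly positive, and this is not the crate's kernel.

```rust
/// Per-element Smooth L1 on diff = pred - target, with threshold beta > 0.
pub fn smooth_l1(diff: f32, beta: f32) -> f32 {
    let a = diff.abs();
    if a < beta {
        0.5 * a * a / beta
    } else {
        a - 0.5 * beta
    }
}

/// Per-element Huber loss on diff = pred - target, with threshold delta > 0.
pub fn huber(diff: f32, delta: f32) -> f32 {
    let a = diff.abs();
    if a <= delta {
        0.5 * a * a
    } else {
        delta * (a - 0.5 * delta)
    }
}
```

With beta equal to delta, the two differ only by a constant factor: huber(d, delta) = delta · smooth_l1(d, delta).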

BceWithLogits = 13

Numerically stable BCE for raw logits. PyTorch torch.nn.functional.binary_cross_entropy_with_logits.
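
One common way to obtain this stability is the identity loss = max(x, 0) - x·t + log(1 + exp(-|x|)), which never exponentiates a large positive number. The sketch below uses that identity; whether the crate's kernel uses the same form is an assumption, and the function names are illustrative.

```rust
/// Per-element BCE on a raw logit `x` with target `t` in [0, 1].
/// Uses max(x, 0) - x*t + ln(1 + exp(-|x|)); exp() only sees values <= 0.
pub fn bce_with_logits(x: f32, t: f32) -> f32 {
    x.max(0.0) - x * t + (-x.abs()).exp().ln_1p()
}

/// Mean-reduced BCE-with-logits over flat slices.
pub fn bce_with_logits_mean(logits: &[f32], targets: &[f32]) -> f32 {
    assert_eq!(logits.len(), targets.len(), "shape mismatch");
    let total: f32 = logits
        .iter()
        .zip(targets)
        .map(|(&x, &t)| bce_with_logits(x, t))
        .sum();
    total / logits.len() as f32
}
```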

GaussianNll = 14

Gaussian NLL. PyTorch torch.nn.GaussianNLLLoss.

CosineEmbedding = 15

y = (1 - cos(x1, x2)) if t==1 else max(0, cos(x1, x2) - margin), then mean. PyTorch torch.nn.functional.cosine_embedding_loss. 2-D input [N, D]. Target is T (±1.0).

MultiMargin = 16

y = mean_i Σ_{j != t_i} max(0, margin - input[i, t_i] + input[i, j])^p / C. PyTorch torch.nn.functional.multi_margin_loss. Input [N, C], target [N] i64 class indices.
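
The formula above translates directly into a CPU reference. This sketch assumes a row-major [N, C] input, mean reduction, and no per-class weights; multi_margin_loss is an illustrative function, not the crate's kernel.

```rust
/// Reference multi-margin loss with mean reduction.
/// `input` is row-major [N, num_classes]; `targets` holds N class indices.
pub fn multi_margin_loss(
    input: &[f32],
    num_classes: usize,
    targets: &[i64],
    p: i32,
    margin: f32,
) -> f32 {
    assert!(num_classes > 0, "need at least one class");
    assert_eq!(input.len(), targets.len() * num_classes, "shape mismatch");
    let mut total = 0.0f32;
    for (row, &t) in input.chunks_exact(num_classes).zip(targets) {
        assert!(t >= 0 && (t as usize) < num_classes, "target out of range");
        let t = t as usize;
        // Sum the hinge terms over every class except the target class.
        let mut row_sum = 0.0f32;
        for (j, &x) in row.iter().enumerate() {
            if j != t {
                row_sum += (margin - row[t] + x).max(0.0).powi(p);
            }
        }
        // Divide by C, as in the formula.
        total += row_sum / num_classes as f32;
    }
    total / targets.len() as f32
}
```

For input [[0.1, 0.2, 0.4, 0.8]], target 3, p = 1, and margin = 1, the hinge terms are 0.3, 0.4, and 0.6, so the loss is 1.3 / 4 = 0.325.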

MultilabelMargin = 17

Multi-label margin loss. PyTorch torch.nn.functional.multilabel_margin_loss. Input [N, C], target [N, C] i64 (positive class indices followed by -1 padding sentinel).

MultilabelSoftMargin = 18

y = mean(-mean_c(target·log(sigmoid(x)) + (1-target)·log(1-sigmoid(x)))). PyTorch torch.nn.functional.multilabel_soft_margin_loss. Input [N, C], target [N, C] T.

FusedLinearCrossEntropy = 19

Fused Linear Cross-Entropy. loss = CE(input @ weight^T, target) without materializing the [BT, V] logits tensor: the projection GEMM and the cross-entropy reduction run together in a chunked outer loop. The forward pass already computes grad_input and grad_weight; the backward call only multiplies them by the upstream dy scalar. Algorithm: LinkedIn Liger-Kernel (liger_kernel/ops/fused_linear_cross_entropy.py). Saves ~5-10 GiB at vocab=128K, BT=16K (Llama-3-class) by streaming logits in chunk_size-row tiles.
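
The memory figure can be sanity-checked with simple arithmetic. The sketch assumes 128K means 131072, 16K means 16384, and logits stored as f32 or bf16; the chunk_size of 1024 rows in the note below is a made-up example value. None of these specifics are confirmed by the source beyond the rounded numbers it quotes.

```rust
/// Bytes needed to hold a [rows, vocab] logits tensor.
pub fn logits_bytes(rows: u64, vocab: u64, bytes_per_elem: u64) -> u64 {
    rows * vocab * bytes_per_elem
}

/// Convert a byte count to GiB (2^30 bytes).
pub fn to_gib(bytes: u64) -> f64 {
    bytes as f64 / (1u64 << 30) as f64
}

/// Assumed problem size: BT = 16 * 1024 rows, V = 128 * 1024 classes.
pub const BT: u64 = 16 * 1024;
pub const VOCAB: u64 = 128 * 1024;
```

Under these assumptions the full logits tensor takes to_gib(logits_bytes(BT, VOCAB, 4)) = 8 GiB in f32 or 4 GiB in bf16, while a single 1024-row f32 chunk takes 0.5 GiB, which is in the same ballpark as the ~5-10 GiB quoted above.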

Trait Implementations

impl Clone for LossKind

fn clone(&self) -> LossKind

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Copy for LossKind

impl Debug for LossKind

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter.

impl Eq for LossKind

impl Hash for LossKind

fn hash<__H>(&self, state: &mut __H)
where __H: Hasher,

Feeds this value into the given Hasher.

fn hash_slice<H>(data: &[Self], state: &mut H)
where H: Hasher, Self: Sized,

Feeds a slice of this type into the given Hasher.

impl PartialEq for LossKind

fn eq(&self, other: &LossKind) -> bool

Tests for self and other values to be equal, and is used by ==.

fn ne(&self, other: &Rhs) -> bool

Tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.

impl StructuralPartialEq for LossKind

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬 This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.