#[non_exhaustive]
#[repr(u16)]
pub enum LossKind {
    Mse = 0,
    Nll = 1,
    CrossEntropy = 2,
    Bce = 3,
    KlDiv = 4,
    L1 = 5,
    SmoothL1 = 6,
    HingeEmbedding = 7,
    MarginRanking = 8,
    TripletMargin = 9,
    Ctc = 10,
    PoissonNll = 11,
    Huber = 12,
    BceWithLogits = 13,
    GaussianNll = 14,
    CosineEmbedding = 15,
    MultiMargin = 16,
    MultilabelMargin = 17,
    MultilabelSoftMargin = 18,
    FusedLinearCrossEntropy = 19,
}
Loss op discriminant (category R from the comprehensive plan).
Stored as a u16 in crate::KernelSku::op when
category == OpCategory::Loss. Each variant currently has its own Plan type
because argument shapes differ: MSE, BCE, and KLDiv take two
same-dtype tensor inputs, while NLL and CrossEntropy take a T input plus an
i64 target-index tensor. All variants share the LossReduction
enum, which selects a per-cell, mean, or sum output shape.
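Because the enum is a fieldless `#[repr(u16)]` type, writing a variant into `KernelSku::op` is a plain `as u16` cast; reading it back is safest through a checked match rather than a transmute. A minimal sketch, assuming a trimmed, hypothetical mirror enum `LossKindSketch` (not the crate's type) and a hypothetical `TryFrom<u16>` impl that the crate may or may not provide:

```rust
use std::convert::TryFrom;

/// Hypothetical, trimmed mirror of `LossKind` (first five variants only),
/// used to show how a `#[repr(u16)]` discriminant round-trips.
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LossKindSketch {
    Mse = 0,
    Nll = 1,
    CrossEntropy = 2,
    Bce = 3,
    KlDiv = 4,
}

impl LossKindSketch {
    /// Encode: a fieldless `#[repr(u16)]` enum casts losslessly to `u16`.
    pub fn to_op(self) -> u16 {
        self as u16
    }
}

impl TryFrom<u16> for LossKindSketch {
    type Error = u16;

    /// Decode: reject unknown values instead of transmuting them, so a
    /// corrupted or newer op code cannot produce an invalid enum value.
    fn try_from(op: u16) -> Result<Self, u16> {
        match op {
            0 => Ok(Self::Mse),
            1 => Ok(Self::Nll),
            2 => Ok(Self::CrossEntropy),
            3 => Ok(Self::Bce),
            4 => Ok(Self::KlDiv),
            other => Err(other),
        }
    }
}
```

Returning the unrecognized value as the error lets a caller report exactly which op code it could not decode.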
Currently wired: {Mse, Nll, CrossEntropy, Bce, KlDiv} × {f32, f16, bf16, f64}, forward and backward (FW + BW). HingeEmbedding, L1, SmoothL1, MarginRanking,
TripletMargin, Ctc, and PoissonNll are reserved
discriminants for future fanout.
Variants (Non-exhaustive)
This enum is marked as non-exhaustive
Mse = 0
y = mean((pred - target)²) (or sum / per-cell). PyTorch
torch.nn.functional.mse_loss.
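The three reduction modes can be illustrated with MSE. This is a CPU reference sketch in f64; `Reduction` is a hypothetical stand-in for LossReduction, whose real variant names may differ:

```rust
/// Hypothetical stand-in for `LossReduction`; real variant names may differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Reduction {
    PerCell, // one output value per element
    Mean,
    Sum,
}

/// Reference MSE: squared error per cell, then the requested reduction.
pub fn mse(pred: &[f64], target: &[f64], reduction: Reduction) -> Vec<f64> {
    assert_eq!(pred.len(), target.len(), "pred and target must match");
    let cells: Vec<f64> = pred
        .iter()
        .zip(target)
        .map(|(p, t)| (p - t) * (p - t))
        .collect();
    match reduction {
        Reduction::PerCell => cells,
        Reduction::Sum => vec![cells.iter().sum::<f64>()],
        Reduction::Mean => vec![cells.iter().sum::<f64>() / cells.len() as f64],
    }
}
```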
Nll = 1
y = -mean_i(input[i, target[i]]): the log-probability of each row's
target class, gathered along the class axis. PyTorch
torch.nn.functional.nll_loss. Heterogeneous-dtype: input T,
target i64.
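A reference sketch of the NLL gather with mean reduction, assuming a row-major [N, C] input of log-probabilities. The function name is hypothetical, and PyTorch's `ignore_index` and per-class `weight` options are not modeled:

```rust
use std::convert::TryFrom;

/// Hypothetical reference NLL with mean reduction.
/// `input` is row-major [n, c] log-probs; `target` has one i64 class per row.
pub fn nll_loss(input: &[f64], n: usize, c: usize, target: &[i64]) -> f64 {
    assert_eq!(input.len(), n * c);
    assert_eq!(target.len(), n);
    let mut total = 0.0;
    for (i, &t) in target.iter().enumerate() {
        let t = usize::try_from(t).expect("negative class index");
        assert!(t < c, "class index out of range");
        // Gather the log-prob of row i's target class and negate it.
        total -= input[i * c + t];
    }
    total / n as f64
}
```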
CrossEntropy = 2
y = NLLLoss(LogSoftmax(input), target), fused for numerical
stability. PyTorch torch.nn.functional.cross_entropy. Currently wired
only for class-index targets (i64); soft-target CE is reserved.
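The stability benefit of fusing can be seen in a reference sketch that computes log-softmax with the log-sum-exp trick and then gathers the target, as NLL does. The function name is hypothetical; f64 logits and usize targets stand in for T and i64:

```rust
/// Hypothetical reference cross-entropy with class-index targets, mean reduction.
/// `logits` is row-major [n, c].
pub fn cross_entropy(logits: &[f64], n: usize, c: usize, target: &[usize]) -> f64 {
    assert_eq!(logits.len(), n * c);
    assert_eq!(target.len(), n);
    let mut total = 0.0;
    for (i, &t) in target.iter().enumerate() {
        let row = &logits[i * c..(i + 1) * c];
        // Subtract the row max so no exponent is positive (no overflow).
        let m = row.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let lse = m + row.iter().map(|x| (x - m).exp()).sum::<f64>().ln();
        // -log_softmax(row)[t] = logsumexp(row) - row[t]
        total += lse - row[t];
    }
    total / n as f64
}
```

With logits of 1000.0 a naive `exp` would overflow to infinity; after the max subtraction every exponent is at most 0.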
Bce = 3
y = -mean(target·log(pred) + (1-target)·log(1-pred)). PyTorch
torch.nn.functional.binary_cross_entropy. The caller must ensure
pred ∈ (0, 1).
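A reference sketch of BCE with mean reduction (hypothetical function name). PyTorch documents that BCELoss clamps its log outputs to be at least -100; the sketch does the same, so a pred of exactly 0 or 1 yields a large but finite loss:

```rust
/// Hypothetical reference BCE with mean reduction.
pub fn bce(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len());
    let sum: f64 = pred
        .iter()
        .zip(target)
        .map(|(&p, &t)| {
            debug_assert!((0.0..=1.0).contains(&p), "pred must be a probability");
            // Clamp each log term at -100 so ln(0) does not become -inf.
            let log_p = p.ln().max(-100.0);
            let log_1mp = (1.0 - p).ln().max(-100.0);
            -(t * log_p + (1.0 - t) * log_1mp)
        })
        .sum();
    sum / pred.len() as f64
}
```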
KlDiv = 4
y = mean(target·(log(target) - input)). PyTorch
torch.nn.functional.kl_div, with input given as log-probabilities and
target as probabilities (log_target=False).
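A reference sketch of the pointwise KL term with mean reduction (hypothetical function name). Cells with target == 0 contribute 0 under the usual 0·log 0 = 0 convention. A mean over all elements is what PyTorch's `reduction='mean'` computes; the per-sample KL in the mathematical sense corresponds to `reduction='batchmean'`:

```rust
/// Hypothetical reference KL divergence: `input` is log-probs, `target` is probs.
pub fn kl_div(input: &[f64], target: &[f64]) -> f64 {
    assert_eq!(input.len(), target.len());
    let sum: f64 = input
        .iter()
        .zip(target)
        // 0 * log 0 is taken as 0, so zero-probability targets are skipped.
        .map(|(&x, &t)| if t > 0.0 { t * (t.ln() - x) } else { 0.0 })
        .sum();
    sum / input.len() as f64
}
```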
L1 = 5
y = mean(|pred - target|) (or sum / per-cell). PyTorch
torch.nn.functional.l1_loss.
SmoothL1 = 6
Smooth L1 loss: quadratic for |pred - target| < β and linear beyond it;
equal to Huber loss with δ = β, divided by β. PyTorch
torch.nn.functional.smooth_l1_loss.
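Per-element formulas for Smooth L1 and Huber, written as a sketch with d = pred - target and no reduction applied (hypothetical function names, following PyTorch's definitions). The assertions check that smooth_l1 with threshold β equals huber with δ = β divided by β:

```rust
/// Hypothetical per-element Smooth L1 with threshold `beta` (> 0).
pub fn smooth_l1(d: f64, beta: f64) -> f64 {
    let a = d.abs();
    // Quadratic near zero, linear in the tails.
    if a < beta { 0.5 * a * a / beta } else { a - 0.5 * beta }
}

/// Hypothetical per-element Huber loss with threshold `delta` (> 0).
pub fn huber(d: f64, delta: f64) -> f64 {
    let a = d.abs();
    if a < delta { 0.5 * a * a } else { delta * (a - 0.5 * delta) }
}
```

The two differ only by the 1/β scale, which is why the enum keeps them as separate variants rather than aliasing one to the other.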
HingeEmbedding = 7
y = mean(input if t==1 else max(0, margin - input)). PyTorch
torch.nn.functional.hinge_embedding_loss. Heterogeneous-dtype:
input is T, target is i64 (±1).
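A reference sketch with mean reduction and i64 ±1 targets, as the variant description specifies (hypothetical function name; PyTorch's default margin is 1.0):

```rust
/// Hypothetical reference hinge-embedding loss with mean reduction.
pub fn hinge_embedding(input: &[f64], target: &[i64], margin: f64) -> f64 {
    assert_eq!(input.len(), target.len());
    let sum: f64 = input
        .iter()
        .zip(target)
        .map(|(&x, &t)| match t {
            1 => x,                         // similar pair: penalize the distance
            -1 => (margin - x).max(0.0),    // dissimilar pair: hinge at margin
            _ => panic!("target must be +1 or -1"),
        })
        .sum();
    sum / input.len() as f64
}
```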
MarginRanking = 8
y = mean(max(0, -t · (x1 - x2) + margin)). PyTorch
torch.nn.functional.margin_ranking_loss. Target t is T (±1).
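A reference sketch with mean reduction (hypothetical function name). A target of 1 means x1 should rank above x2, so the pair contributes nothing once x1 - x2 is at least the margin:

```rust
/// Hypothetical reference margin-ranking loss with mean reduction.
pub fn margin_ranking(x1: &[f64], x2: &[f64], t: &[f64], margin: f64) -> f64 {
    assert!(x1.len() == x2.len() && x2.len() == t.len());
    let sum: f64 = x1
        .iter()
        .zip(x2)
        .zip(t)
        .map(|((&a, &b), &y)| (-y * (a - b) + margin).max(0.0))
        .sum();
    sum / x1.len() as f64
}
```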
TripletMargin = 9
y = mean(max(0, ||a-p||_p - ||a-n||_p + margin)). PyTorch
torch.nn.functional.triplet_margin_loss. 2-D input [N, D].
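A reference sketch for row-major [N, D] inputs with mean reduction (hypothetical function names). PyTorch adds a small eps inside its pairwise distance; the sketch omits it:

```rust
/// p-norm of (a - b), without PyTorch's small eps term.
fn p_dist(a: &[f64], b: &[f64], p: f64) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs().powf(p))
        .sum::<f64>()
        .powf(1.0 / p)
}

/// Hypothetical reference triplet-margin loss; inputs are row-major [n, d].
pub fn triplet_margin(anchor: &[f64], pos: &[f64], neg: &[f64], d: usize, margin: f64, p: f64) -> f64 {
    let n = anchor.len() / d;
    assert!(anchor.len() == n * d && pos.len() == n * d && neg.len() == n * d);
    let mut total = 0.0;
    for i in 0..n {
        let r = i * d..(i + 1) * d;
        let dp = p_dist(&anchor[r.clone()], &pos[r.clone()], p);
        let dn = p_dist(&anchor[r.clone()], &neg[r], p);
        // Anchor should be closer to the positive than the negative by `margin`.
        total += (dp - dn + margin).max(0.0);
    }
    total / n as f64
}
```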
Ctc = 10
Reserved — torch.nn.functional.ctc_loss.
PoissonNll = 11
y = mean(exp(input) - target · input), with log_input=True (the PyTorch
default). PyTorch torch.nn.functional.poisson_nll_loss.
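A reference sketch for the log_input = true, full = false case with mean reduction (hypothetical function name). The input holds log-rates, so exp(input) is the predicted Poisson rate:

```rust
/// Hypothetical reference Poisson NLL (log_input = true, full = false).
pub fn poisson_nll_log_input(input: &[f64], target: &[f64]) -> f64 {
    assert_eq!(input.len(), target.len());
    let sum: f64 = input
        .iter()
        .zip(target)
        .map(|(&x, &t)| x.exp() - t * x) // rate - target * log(rate)
        .sum();
    sum / input.len() as f64
}
```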
Huber = 12
Huber loss, a separate variant from SmoothL1. PyTorch
torch.nn.functional.huber_loss.
BceWithLogits = 13
Numerically stable BCE for raw logits. PyTorch
torch.nn.functional.binary_cross_entropy_with_logits.
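BceWithLogits avoids evaluating log(sigmoid(x)) directly. Expanding the BCE terms gives the per-element form max(x, 0) - x·t + log(1 + exp(-|x|)), which never exponentiates a large positive number. A reference sketch with mean reduction (hypothetical function name):

```rust
/// Hypothetical reference BCE-with-logits with mean reduction.
pub fn bce_with_logits(logits: &[f64], target: &[f64]) -> f64 {
    assert_eq!(logits.len(), target.len());
    let sum: f64 = logits
        .iter()
        .zip(target)
        // exp(-|x|) <= 1, so ln_1p never sees an overflowed argument.
        .map(|(&x, &t)| x.max(0.0) - x * t + (-x.abs()).exp().ln_1p())
        .sum();
    sum / logits.len() as f64
}
```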
GaussianNll = 14
Gaussian NLL. PyTorch torch.nn.functional.gaussian_nll_loss
(torch.nn.GaussianNLLLoss).
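In PyTorch's formulation with full=False, the per-element Gaussian NLL is 0.5·(log(max(var, eps)) + (input - target)² / max(var, eps)). A reference sketch with mean reduction (hypothetical function name; PyTorch's default eps is 1e-6):

```rust
/// Hypothetical reference Gaussian NLL (full = false) with mean reduction.
pub fn gaussian_nll(input: &[f64], target: &[f64], var: &[f64], eps: f64) -> f64 {
    assert!(input.len() == target.len() && target.len() == var.len());
    let sum: f64 = input
        .iter()
        .zip(target)
        .zip(var)
        .map(|((&mu, &y), &v)| {
            let v = v.max(eps); // clamp variance away from zero
            0.5 * (v.ln() + (mu - y).powi(2) / v)
        })
        .sum();
    sum / input.len() as f64
}
```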
CosineEmbedding = 15
y = (1 - cos(x1, x2)) if t==1 else max(0, cos(x1, x2) - margin),
then mean. PyTorch torch.nn.functional.cosine_embedding_loss.
2-D input [N, D]. Target is T (±1.0).
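A reference sketch for row-major [N, D] inputs with mean reduction (hypothetical function names). PyTorch guards the norm product with a small epsilon; the sketch omits it, so a zero vector would produce NaN:

```rust
/// Cosine similarity of two equal-length vectors (no epsilon guard).
fn cosine(a: &[f64], b: &[f64]) -> f64 {
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let na = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let nb = b.iter().map(|x| x * x).sum::<f64>().sqrt();
    dot / (na * nb)
}

/// Hypothetical reference cosine-embedding loss; x1, x2 are row-major [n, d].
pub fn cosine_embedding(x1: &[f64], x2: &[f64], t: &[f64], d: usize, margin: f64) -> f64 {
    let n = t.len();
    assert!(x1.len() == n * d && x2.len() == n * d);
    let mut total = 0.0;
    for i in 0..n {
        let c = cosine(&x1[i * d..(i + 1) * d], &x2[i * d..(i + 1) * d]);
        // t = +1: pull together; t = -1: push apart past the margin.
        total += if t[i] == 1.0 { 1.0 - c } else { (c - margin).max(0.0) };
    }
    total / n as f64
}
```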
MultiMargin = 16
y = mean_i Σ_{j != t_i} max(0, margin - input[i, t_i] + input[i, j])^p / C.
PyTorch torch.nn.functional.multi_margin_loss. Input [N, C],
target [N] i64 class indices.
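A reference sketch of the formula above for row-major [N, C] input (hypothetical function name; usize targets stand in for i64). With input [0.1, 0.2, 0.4, 0.8], target 3, p = 1, and margin 1, the three non-target terms are 0.3, 0.4, and 0.6, giving 1.3 / 4 = 0.325:

```rust
/// Hypothetical reference multi-class margin loss with mean reduction.
pub fn multi_margin(input: &[f64], n: usize, c: usize, target: &[usize], p: i32, margin: f64) -> f64 {
    assert_eq!(input.len(), n * c);
    assert_eq!(target.len(), n);
    let mut total = 0.0;
    for i in 0..n {
        let row = &input[i * c..(i + 1) * c];
        let ti = target[i];
        let mut s = 0.0;
        for j in 0..c {
            if j != ti {
                // Penalize any class scoring within `margin` of the target.
                s += (margin - row[ti] + row[j]).max(0.0).powi(p);
            }
        }
        total += s / c as f64; // per-sample division by C
    }
    total / n as f64
}
```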
MultilabelMargin = 17
Multi-label margin loss. PyTorch
torch.nn.functional.multilabel_margin_loss. Input [N, C],
target [N, C] i64 (positive class indices followed by -1
padding sentinel).
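Per row, every positive class j (the target prefix before the first -1) is compared against every non-positive class i, summing max(0, 1 - (x[j] - x[i])) and dividing by C; entries after the first -1 are ignored. A reference sketch with mean reduction over rows (hypothetical function name):

```rust
/// Hypothetical reference multilabel margin loss; input and target are [n, c].
pub fn multilabel_margin(input: &[f64], n: usize, c: usize, target: &[i64]) -> f64 {
    assert!(input.len() == n * c && target.len() == n * c);
    let mut total = 0.0;
    for r in 0..n {
        let x = &input[r * c..(r + 1) * c];
        // Positive classes: the prefix of the target row before the first -1.
        let pos: Vec<usize> = target[r * c..(r + 1) * c]
            .iter()
            .take_while(|&&t| t >= 0)
            .map(|&t| t as usize)
            .collect();
        let mut is_pos = vec![false; c];
        for &j in &pos {
            is_pos[j] = true;
        }
        let mut s = 0.0;
        for &j in &pos {
            for i in 0..c {
                if !is_pos[i] {
                    s += (1.0 - (x[j] - x[i])).max(0.0);
                }
            }
        }
        total += s / c as f64;
    }
    total / n as f64
}
```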
MultilabelSoftMargin = 18
y = mean(-mean_c(target·log(sigmoid(x)) + (1-target)·log(1-sigmoid(x)))).
PyTorch torch.nn.functional.multilabel_soft_margin_loss.
Input [N, C], target [N, C] T.
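A reference sketch (hypothetical function names) that uses the identities log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x), with softplus computed stably, so large |x| never reaches log(0):

```rust
/// Numerically stable softplus: ln(1 + e^x).
fn softplus(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}

/// Hypothetical reference multilabel soft-margin loss; input/target are [n, c].
pub fn multilabel_soft_margin(input: &[f64], n: usize, c: usize, target: &[f64]) -> f64 {
    assert!(input.len() == n * c && target.len() == n * c);
    let mut total = 0.0;
    for r in 0..n {
        let mut row = 0.0;
        for k in r * c..(r + 1) * c {
            let (x, y) = (input[k], target[k]);
            // -(y·log σ(x) + (1-y)·log(1-σ(x)))
            row += y * softplus(-x) + (1.0 - y) * softplus(x);
        }
        total += row / c as f64; // mean over classes
    }
    total / n as f64 // mean over rows
}
```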
FusedLinearCrossEntropy = 19
Fused Linear Cross-Entropy: loss = CE(input @ weight^T, target),
computed without materializing the [BT, V] logits tensor (BT = batch ×
sequence tokens, V = vocab size). The projection GEMM and the cross-entropy
reduction run together in a chunked outer loop. grad_input and grad_weight
are produced during the forward pass, so the backward call only multiplies
them by the upstream scalar dy. Algorithm: LinkedIn Liger-Kernel
(liger_kernel/ops/fused_linear_cross_entropy.py).
Streaming logits in chunk_size-row tiles saves ~5-10 GiB at vocab = 128K,
BT = 16K (Llama-3-class).
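For scale: at BT = 16,384 rows and V = 131,072 (taking K = 1024), a full logits tensor holds 2^31 ≈ 2.1 billion elements, which is 8 GiB in f32 or 4 GiB in f16/bf16, before counting its gradient. The chunked scheme can be sketched on the CPU as follows. This is a hypothetical f64 reference (`fused_linear_ce`, `FlceOut`, and `fused_linear_ce_backward` are illustrative names, not the crate's or Liger-Kernel's API) that keeps one [chunk, V] tile alive and overwrites it in place with d(loss)/d(logits) = (softmax - one_hot) / BT under mean reduction:

```rust
/// Mean loss plus gradients already computed during the forward pass.
pub struct FlceOut {
    pub loss: f64,
    pub grad_input: Vec<f64>,  // [bt, h]
    pub grad_weight: Vec<f64>, // [v, h]
}

/// Hypothetical chunked fused linear CE. x: [bt, h], w: [v, h] (row-major),
/// target: [bt] class indices. Only a [chunk, v] logits tile is ever allocated.
pub fn fused_linear_ce(x: &[f64], w: &[f64], target: &[usize], bt: usize, h: usize, v: usize, chunk: usize) -> FlceOut {
    assert!(x.len() == bt * h && w.len() == v * h && target.len() == bt && chunk > 0);
    let mut out = FlceOut { loss: 0.0, grad_input: vec![0.0; bt * h], grad_weight: vec![0.0; v * h] };
    let scale = 1.0 / bt as f64; // mean reduction
    let mut tile = vec![0.0; chunk * v]; // reused logits / grad-logits buffer
    for start in (0..bt).step_by(chunk) {
        let rows = chunk.min(bt - start);
        for r in 0..rows {
            let xr = &x[(start + r) * h..(start + r + 1) * h];
            let logits = &mut tile[r * v..(r + 1) * v];
            // Projection for this row: logits = x_row @ w^T.
            for (k, l) in logits.iter_mut().enumerate() {
                *l = xr.iter().zip(&w[k * h..(k + 1) * h]).map(|(a, b)| a * b).sum();
            }
            // Stable log-sum-exp, loss term, then grad-logits in place.
            let m = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
            let sum_exp: f64 = logits.iter().map(|l| (l - m).exp()).sum();
            let lse = m + sum_exp.ln();
            let t = target[start + r];
            out.loss += (lse - logits[t]) * scale;
            for l in logits.iter_mut() {
                *l = (*l - lse).exp() * scale; // softmax / bt
            }
            logits[t] -= scale; // minus one_hot / bt
        }
        // grad_input += g @ w and grad_weight += g^T @ x for this chunk.
        for r in 0..rows {
            let row = start + r;
            for k in 0..v {
                let g = tile[r * v + k];
                for j in 0..h {
                    out.grad_input[row * h + j] += g * w[k * h + j];
                    out.grad_weight[k * h + j] += g * x[row * h + j];
                }
            }
        }
    }
    out
}

/// Backward: gradients already exist, so only scale by upstream dy.
pub fn fused_linear_ce_backward(out: &FlceOut, dy: f64) -> (Vec<f64>, Vec<f64>) {
    let gi = out.grad_input.iter().map(|g| g * dy).collect();
    let gw = out.grad_weight.iter().map(|g| g * dy).collect();
    (gi, gw)
}
```

Reusing the tile keeps peak extra memory at O(chunk · V) instead of O(BT · V); the trade-off is that grad_weight must be held as a full [V, H] accumulator across all chunks.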