```rust
#[non_exhaustive]
#[repr(u16)]
pub enum EmbeddingKind {
    Embedding = 0,
    EmbeddingBackward = 1,
    EmbeddingBagSum = 2,
    EmbeddingBagMean = 3,
    EmbeddingBagSumBackward = 4,
    EmbeddingBagMeanBackward = 5,
    EmbeddingBagMax = 6,
    EmbeddingBagMaxBackward = 7,
}
```
Embedding-family op discriminant (Category M in the comprehensive plan).
Stored as a u16 in crate::KernelSku::op when
category == OpCategory::Embedding. Phase 7 Milestone 7.5 wires up:
- Self::Embedding (FW + BW): row lookup out[i, :] = weight[indices[i], :], with an optional padding_idx that emits an all-zero row in the forward pass and skips accumulation in the backward pass.
- Self::EmbeddingBagSum / Self::EmbeddingBagMean (FW + BW): bag-reduced row lookup, out[b, :] = reduce(weight[indices[k], :] for k in offsets[b]..offsets[b+1]). The mode selects the reducer: sum, or sum divided by the bag size.
EmbeddingBagMax is deferred because its backward pass needs argmax tracking.
The index dtype is i32 only (i64 is deferred). Forward kernels are emitted for
f32, f64, f16, and bf16 (pure copy or reduce); backward kernels, which accumulate with atomicAdd, are emitted for f32 and f64 only.
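Because the enum is #[repr(u16)] with explicit discriminants, storing a variant in a u16 op field is a plain cast, and reading it back is a match over the known values. The sketch below mirrors the declaration locally so it compiles on its own; crate::KernelSku and OpCategory are not reproduced, and the helper names embedding_kind_to_u16 / embedding_kind_from_u16 are hypothetical, not part of the crate's API.

```rust
/// Local mirror of the declaration above so this sketch is self-contained.
#[non_exhaustive]
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EmbeddingKind {
    Embedding = 0,
    EmbeddingBackward = 1,
    EmbeddingBagSum = 2,
    EmbeddingBagMean = 3,
    EmbeddingBagSumBackward = 4,
    EmbeddingBagMeanBackward = 5,
    EmbeddingBagMax = 6,
    EmbeddingBagMaxBackward = 7,
}

/// Hypothetical encoder: a fieldless #[repr(u16)] enum casts directly to its discriminant.
pub fn embedding_kind_to_u16(kind: EmbeddingKind) -> u16 {
    kind as u16
}

/// Hypothetical decoder: maps a raw op value back to a variant.
/// Unknown values return None instead of panicking, which suits an enum
/// whose set of variants may grow (it is #[non_exhaustive]).
pub fn embedding_kind_from_u16(op: u16) -> Option<EmbeddingKind> {
    use EmbeddingKind::*;
    Some(match op {
        0 => Embedding,
        1 => EmbeddingBackward,
        2 => EmbeddingBagSum,
        3 => EmbeddingBagMean,
        4 => EmbeddingBagSumBackward,
        5 => EmbeddingBagMeanBackward,
        6 => EmbeddingBagMax,
        7 => EmbeddingBagMaxBackward,
        _ => return None,
    })
}
```

Returning Option keeps decoding total, so an op value written by a newer build does not crash an older reader.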
Variants (Non-exhaustive)
This enum is marked as non-exhaustive.
Embedding = 0
embedding(weight, indices, padding_idx) —
out[i, :] = weight[indices[i], :]. PyTorch
torch.nn.functional.embedding.
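A minimal CPU reference for the forward semantics, assuming a row-major f32 weight of shape [num_embeddings, dim] and padding_idx modelled as Option<i32>. The function name embedding_forward is hypothetical and this is not the crate's GPU kernel; it follows the rule stated on this page that the padding index yields an all-zero output row.

```rust
/// Hypothetical CPU reference for EmbeddingKind::Embedding.
/// `weight` is row-major [num_embeddings, dim]; the result is row-major [indices.len(), dim].
/// Every non-padding index must lie in 0..num_embeddings (slicing panics otherwise).
pub fn embedding_forward(
    weight: &[f32],
    dim: usize,
    indices: &[i32],
    padding_idx: Option<i32>,
) -> Vec<f32> {
    let mut out = vec![0.0f32; indices.len() * dim];
    for (i, &idx) in indices.iter().enumerate() {
        if Some(idx) == padding_idx {
            continue; // padding row: `out` is already zero-initialised
        }
        // out[i, :] = weight[indices[i], :]
        let src = idx as usize * dim;
        out[i * dim..(i + 1) * dim].copy_from_slice(&weight[src..src + dim]);
    }
    out
}
```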
EmbeddingBackward = 1
Gradient of Self::Embedding:
dweight[indices[i], :] += dout[i, :] (atomicAdd), skipping
rows where indices[i] == padding_idx.
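The backward kernel needs atomicAdd because several positions i can share the same index, so their gradients land in the same dweight row. A single-threaded sketch, under the same layout assumptions as the forward reference, replaces the atomics with a plain +=; the result matches up to floating-point summation order. The name embedding_backward is hypothetical.

```rust
/// Hypothetical sequential reference for EmbeddingKind::EmbeddingBackward.
/// `dout` is row-major [indices.len(), dim]; the result is row-major [num_embeddings, dim].
pub fn embedding_backward(
    dout: &[f32],
    dim: usize,
    indices: &[i32],
    padding_idx: Option<i32>,
    num_embeddings: usize,
) -> Vec<f32> {
    let mut dweight = vec![0.0f32; num_embeddings * dim];
    for (i, &idx) in indices.iter().enumerate() {
        if Some(idx) == padding_idx {
            continue; // the padding row receives no gradient
        }
        // dweight[indices[i], :] += dout[i, :]  (an atomicAdd on the GPU)
        let dst = idx as usize * dim;
        for j in 0..dim {
            dweight[dst + j] += dout[i * dim + j];
        }
    }
    dweight
}
```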
EmbeddingBagSum = 2
embedding_bag(weight, indices, offsets, mode=Sum).
PyTorch torch.nn.functional.embedding_bag with mode='sum'.
EmbeddingBagMean = 3
embedding_bag(weight, indices, offsets, mode=Mean).
PyTorch torch.nn.functional.embedding_bag with mode='mean'.
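A minimal CPU sketch of the bag forward for both modes. It assumes offsets holds num_bags + 1 entries so that bag b covers indices[offsets[b]..offsets[b+1]] exactly as in the formula above (the layout PyTorch uses with include_last_offset=True), and that an empty bag produces a zero row. BagMode and embedding_bag_forward are hypothetical names.

```rust
/// Hypothetical stand-in for the Sum / Mean variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BagMode {
    Sum,
    Mean,
}

/// Hypothetical CPU reference for EmbeddingBagSum / EmbeddingBagMean.
/// `weight`: row-major [num_embeddings, dim]; `offsets`: num_bags + 1 entries.
/// Returns row-major [num_bags, dim]; an empty bag yields a zero row.
pub fn embedding_bag_forward(
    weight: &[f32],
    dim: usize,
    indices: &[i32],
    offsets: &[i32],
    mode: BagMode,
) -> Vec<f32> {
    let num_bags = offsets.len().saturating_sub(1);
    let mut out = vec![0.0f32; num_bags * dim];
    for b in 0..num_bags {
        let (start, end) = (offsets[b] as usize, offsets[b + 1] as usize);
        let row = &mut out[b * dim..(b + 1) * dim];
        // Sum the selected weight rows into this bag's output row.
        for &idx in &indices[start..end] {
            let src = idx as usize * dim;
            for (o, w) in row.iter_mut().zip(&weight[src..src + dim]) {
                *o += *w;
            }
        }
        // Mean mode divides the sum by the bag size (skipped for empty bags).
        let bag_size = end - start;
        if mode == BagMode::Mean && bag_size > 0 {
            for o in row.iter_mut() {
                *o /= bag_size as f32;
            }
        }
    }
    out
}
```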
EmbeddingBagSumBackward = 4
Gradient of embedding_bag (Sum-mode):
dweight[indices[k], :] += dout[b, :] for k in bag b (atomicAdd).
EmbeddingBagMeanBackward = 5
Gradient of embedding_bag (Mean-mode):
dweight[indices[k], :] += dout[b, :] / bag_size(b) (atomicAdd).
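The two bag gradients differ only in the per-bag scale: Sum passes dout[b, :] through unchanged, while Mean divides it by bag_size(b) before scattering it into every row that appeared in bag b. A sequential sketch with the same assumed offsets layout (num_bags + 1 entries) follows; plain += stands in for atomicAdd, and the names are hypothetical.

```rust
/// Hypothetical stand-in for the Sum / Mean variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BagMode {
    Sum,
    Mean,
}

/// Hypothetical sequential reference for EmbeddingBagSumBackward / EmbeddingBagMeanBackward.
/// `dout` is row-major [num_bags, dim]; the result is row-major [num_embeddings, dim].
pub fn embedding_bag_backward(
    dout: &[f32],
    dim: usize,
    indices: &[i32],
    offsets: &[i32],
    mode: BagMode,
    num_embeddings: usize,
) -> Vec<f32> {
    let mut dweight = vec![0.0f32; num_embeddings * dim];
    let num_bags = offsets.len().saturating_sub(1);
    for b in 0..num_bags {
        let (start, end) = (offsets[b] as usize, offsets[b + 1] as usize);
        // Only used inside the loop below, which never runs for an empty bag.
        let bag_size = (end - start) as f32;
        for &idx in &indices[start..end] {
            let dst = idx as usize * dim;
            for j in 0..dim {
                let g = dout[b * dim + j];
                // dweight[indices[k], :] += dout[b, :] (divided by bag_size in Mean mode)
                dweight[dst + j] += match mode {
                    BagMode::Sum => g,
                    BagMode::Mean => g / bag_size,
                };
            }
        }
    }
    dweight
}
```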
EmbeddingBagMax = 6
embedding_bag(weight, indices, offsets, mode=Max). Reserved.
Max mode requires the forward pass to record, for each output feature, the
index of the row that supplied the maximum, so that the backward pass can
scatter the gradient into only that row. This needs a different plan shape, so it is deferred.
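To illustrate why Max mode needs argmax tracking, here is a hypothetical sketch (this mode is deferred in the crate, so none of this reflects a planned design): the forward pass returns an extra i32 tensor naming the winning row per output element, and the backward pass scatters each dout element into exactly that one row. It assumes the same offsets layout as the sketches above, marks empty bags with -1 and a zero output, and keeps the first row on ties; all of these are assumptions.

```rust
/// Hypothetical Max-mode forward: returns (out, argmax), where argmax[b * dim + j]
/// is the weight row that supplied out[b, j], or -1 for an empty bag.
pub fn embedding_bag_max_forward(
    weight: &[f32],
    dim: usize,
    indices: &[i32],
    offsets: &[i32],
) -> (Vec<f32>, Vec<i32>) {
    let num_bags = offsets.len().saturating_sub(1);
    let mut out = vec![0.0f32; num_bags * dim];
    let mut argmax = vec![-1i32; num_bags * dim];
    for b in 0..num_bags {
        let (start, end) = (offsets[b] as usize, offsets[b + 1] as usize);
        for &idx in &indices[start..end] {
            let src = idx as usize * dim;
            for j in 0..dim {
                let k = b * dim + j;
                let v = weight[src + j];
                // Strict `>` keeps the first contributing row on ties.
                if argmax[k] < 0 || v > out[k] {
                    out[k] = v;
                    argmax[k] = idx;
                }
            }
        }
    }
    (out, argmax)
}

/// Hypothetical Max-mode backward: each dout element goes to exactly one weight entry.
pub fn embedding_bag_max_backward(
    dout: &[f32],
    dim: usize,
    argmax: &[i32],
    num_embeddings: usize,
) -> Vec<f32> {
    let mut dweight = vec![0.0f32; num_embeddings * dim];
    for (k, &row) in argmax.iter().enumerate() {
        if row < 0 {
            continue; // empty bag: nothing to scatter
        }
        let j = k % dim; // feature column of this output element
        dweight[row as usize * dim + j] += dout[k];
    }
    dweight
}
```

The extra argmax output is what changes the plan shape: unlike Sum and Mean, the backward pass cannot be derived from indices and offsets alone.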
EmbeddingBagMaxBackward = 7
Gradient of embedding_bag (Max-mode) — reserved.
Trait Implementations
impl Clone for EmbeddingKind
fn clone(&self) -> EmbeddingKind
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Copy for EmbeddingKind
impl Debug for EmbeddingKind
impl Eq for EmbeddingKind
impl Hash for EmbeddingKind
impl PartialEq for EmbeddingKind
fn eq(&self, other: &EmbeddingKind) -> bool
Tests for self and other values to be equal, and is used by ==.