//! Embedding op family — Category M.
//!
//! Phase 7 Milestone 7.5 of the baracuda-kernels comprehensive plan.
//! Each op gets its own plan: the FW / BW shapes share structure with
//! `index_select`, but they carry an optional `padding_idx` (and, for
//! `embedding_bag`, a bag-reduction mode plus a per-bag offset table)
//! that makes composing them from the existing indexing plans awkward.
//!
//! Ops shipped:
//! - [`EmbeddingPlan`] FW: `out[i, :] = weight[indices[i], :]` with
//! optional `padding_idx` zeroing matching rows.
//! - [`EmbeddingBackwardPlan`] BW: `dweight[indices[i], :] += dout[i, :]`
//! (atomicAdd), skipping the padding row.
//! - [`EmbeddingBagPlan`] FW: per-bag reduction over the index range
//! `offsets[b]..offsets[b+1]`. Modes: `Sum` / `Mean`.
//! - [`EmbeddingBagBackwardPlan`] BW: atomicAdd of `dout[b, :] / divisor`
//! into `dweight[indices[k], :]` for each k in the bag, where `divisor`
//! is 1 for `Sum` and the bag size for `Mean`.
//! - [`EmbeddingBagMaxPlan`] FW (Phase 25): per-bag max-reduction with
//! per-feature argmax tracking. Writes the max value to the output and
//! the contributing `indices[k]` (i32) to `out_index`.
//! - [`EmbeddingBagMaxBackwardPlan`] BW (Phase 25): scatter-add
//! `dout[b, d]` into `dweight[out_index[b, d], d]` via atomicAdd.
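//!
//! The bag semantics above can be pinned down with a small host-side
//! reference. This is a sketch under stated assumptions, not the crate's
//! API or kernel code: every name in it (`BagMode`, `row`, `divisor`,
//! `bag_fw`, `bag_bw`, `bag_max_fw`, `bag_max_bw`) is hypothetical, buffers
//! are row-major `f32` slices with `dim` features per row, and `offsets` is
//! assumed to hold `bags + 1` entries. It further assumes that `Mean`
//! divides by the full bag length (skipped indices included), that an empty
//! max-bag yields `0.0` with index `-1`, and that ties keep the first
//! maximum. A serial `+=` stands in for the kernel's atomicAdd.
//!
//! ```rust
//! // Hypothetical CPU reference for the bag ops; not the crate's API.
//! #[derive(Clone, Copy)]
//! enum BagMode { Sum, Mean }
//!
//! // Negative or out-of-range indices resolve to `None` ("skip").
//! fn row(idx: i32, num_rows: usize) -> Option<usize> {
//!     usize::try_from(idx).ok().filter(|&r| r < num_rows)
//! }
//!
//! // Assumed Mean divisor: full bag length, clamped to 1 for empty bags.
//! fn divisor(mode: BagMode, len: usize) -> f32 {
//!     match mode {
//!         BagMode::Sum => 1.0,
//!         BagMode::Mean => len.max(1) as f32,
//!     }
//! }
//!
//! // FW: out[b, :] = reduce(weight[indices[k], :]) for k in offsets[b]..offsets[b+1].
//! fn bag_fw(weight: &[f32], dim: usize, indices: &[i32], offsets: &[usize], mode: BagMode) -> Vec<f32> {
//!     let num_rows = weight.len() / dim;
//!     let bags = offsets.len() - 1;
//!     let mut out = vec![0.0f32; bags * dim];
//!     for b in 0..bags {
//!         let (lo, hi) = (offsets[b], offsets[b + 1]);
//!         for &idx in &indices[lo..hi] {
//!             if let Some(r) = row(idx, num_rows) {
//!                 for d in 0..dim {
//!                     out[b * dim + d] += weight[r * dim + d];
//!                 }
//!             }
//!         }
//!         // Sum first, then scale once by the mode's divisor.
//!         let div = divisor(mode, hi - lo);
//!         for v in &mut out[b * dim..(b + 1) * dim] {
//!             *v /= div;
//!         }
//!     }
//!     out
//! }
//!
//! // BW: dweight[indices[k], :] += dout[b, :] / divisor (serial `+=` replaces atomicAdd).
//! fn bag_bw(dout: &[f32], dim: usize, indices: &[i32], offsets: &[usize], mode: BagMode, num_rows: usize) -> Vec<f32> {
//!     let mut dweight = vec![0.0f32; num_rows * dim];
//!     for b in 0..offsets.len() - 1 {
//!         let (lo, hi) = (offsets[b], offsets[b + 1]);
//!         let div = divisor(mode, hi - lo);
//!         for &idx in &indices[lo..hi] {
//!             if let Some(r) = row(idx, num_rows) {
//!                 for d in 0..dim {
//!                     dweight[r * dim + d] += dout[b * dim + d] / div;
//!                 }
//!             }
//!         }
//!     }
//!     dweight
//! }
//!
//! // Max FW: per-feature max plus the winning `indices[k]` in `out_index`.
//! fn bag_max_fw(weight: &[f32], dim: usize, indices: &[i32], offsets: &[usize]) -> (Vec<f32>, Vec<i32>) {
//!     let num_rows = weight.len() / dim;
//!     let bags = offsets.len() - 1;
//!     let mut out = vec![0.0f32; bags * dim];
//!     let mut out_index = vec![-1i32; bags * dim];
//!     for b in 0..bags {
//!         for &idx in &indices[offsets[b]..offsets[b + 1]] {
//!             if let Some(r) = row(idx, num_rows) {
//!                 for d in 0..dim {
//!                     let (o, v) = (b * dim + d, weight[r * dim + d]);
//!                     // First valid entry always wins; later ones need a strictly larger value.
//!                     if out_index[o] < 0 || v > out[o] {
//!                         out[o] = v;
//!                         out_index[o] = idx;
//!                     }
//!                 }
//!             }
//!         }
//!     }
//!     (out, out_index)
//! }
//!
//! // Max BW: dweight[out_index[b, d], d] += dout[b, d]; index -1 is skipped.
//! fn bag_max_bw(dout: &[f32], dim: usize, out_index: &[i32], num_rows: usize) -> Vec<f32> {
//!     let mut dweight = vec![0.0f32; num_rows * dim];
//!     for (o, &idx) in out_index.iter().enumerate() {
//!         if let Some(r) = row(idx, num_rows) {
//!             dweight[r * dim + o % dim] += dout[o];
//!         }
//!     }
//!     dweight
//! }
//! ```
//!
//! In this sketch the max backward pass reads only `out_index` and `dout`,
//! matching the `EmbeddingBagMaxBackwardPlan` description above.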
//!
//! Trailblazer dtype coverage:
//! - FW (`Embedding`, `EmbeddingBag`): `f32, f64, f16, bf16` (pure
//! copy / accumulator-typed sum).
//! - BW: `f32, f64` only, because the backward kernels rely on native
//! floating-point `atomicAdd` for these types.
//!
//! Index dtype is `i32` only. Negative or out-of-range indices are
//! treated as "skip" (no PyTorch-style wrap-around). The
//! padding-disabled sentinel is `i32::MIN` (mapped from `Option::None`).
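//!
//! A minimal host-side sketch of these rules for the plain `Embedding`
//! FW / BW pair, assuming row-major `f32` buffers and zero-filled output
//! rows for skipped indices (what the kernel writes there is not stated
//! here). The helpers `padding_arg`, `resolve`, `embedding_fw` and
//! `embedding_bw` are hypothetical, and `PADDING_DISABLED` is redeclared
//! locally to mirror the module constant so the example is self-contained.
//! A serial `+=` stands in for the kernel's atomicAdd.
//!
//! ```rust
//! // Local mirror of the module's sentinel (assumed equal to i32::MIN).
//! const PADDING_DISABLED: i32 = i32::MIN;
//!
//! // Map the caller-facing `Option<i32>` onto the kernel sentinel.
//! fn padding_arg(padding_idx: Option<i32>) -> i32 {
//!     padding_idx.unwrap_or(PADDING_DISABLED)
//! }
//!
//! // Negative or out-of-range indices resolve to `None`: skip, no wrap-around.
//! fn resolve(idx: i32, num_rows: usize) -> Option<usize> {
//!     usize::try_from(idx).ok().filter(|&r| r < num_rows)
//! }
//!
//! // FW: out[i, :] = weight[indices[i], :]; padding and skipped rows stay zero.
//! fn embedding_fw(weight: &[f32], dim: usize, indices: &[i32], padding: i32) -> Vec<f32> {
//!     let num_rows = weight.len() / dim;
//!     let mut out = vec![0.0f32; indices.len() * dim];
//!     for (i, &idx) in indices.iter().enumerate() {
//!         if idx == padding {
//!             continue;
//!         }
//!         if let Some(r) = resolve(idx, num_rows) {
//!             out[i * dim..(i + 1) * dim].copy_from_slice(&weight[r * dim..(r + 1) * dim]);
//!         }
//!     }
//!     out
//! }
//!
//! // BW: dweight[indices[i], :] += dout[i, :], skipping padding and invalid rows.
//! fn embedding_bw(dout: &[f32], dim: usize, indices: &[i32], padding: i32, num_rows: usize) -> Vec<f32> {
//!     let mut dweight = vec![0.0f32; num_rows * dim];
//!     for (i, &idx) in indices.iter().enumerate() {
//!         if idx == padding {
//!             continue;
//!         }
//!         if let Some(r) = resolve(idx, num_rows) {
//!             for d in 0..dim {
//!                 dweight[r * dim + d] += dout[i * dim + d];
//!             }
//!         }
//!     }
//!     dweight
//! }
//! ```
//!
//! Because `i32::MIN` is never a valid row, the sentinel can share the
//! ordinary `idx == padding` comparison without a separate enabled flag.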
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
/// Sentinel passed to the kernel when the caller does not supply a
/// `padding_idx`. Matches `kPaddingDisabled` in
/// `kernels/include/baracuda_embedding.cuh`.
pub const PADDING_DISABLED: i32 = i32::MIN;