Skip to main content

baracuda_kernels/embedding/
embedding_bag_max.rs

1//! `embedding_bag` Max-mode FW plan — Category M. Phase 25.
2//!
3//! Per-bag max-reduction with per-feature argmax tracking. For each
4//! bag `b` and each feature `d`:
5//!
6//! - `out[b, d]      = max(weight[indices[k], d])` for `k ∈ bag b`,
7//!   excluding padding / OOB indices.
8//! - `out_index[b, d] = the (first-occurrence) indices[k]` that
9//!   contributed the max value, or `-1` if the bag was empty / all
10//!   padded.
11//!
12//! The `out_index` tensor is the saved-state input to
13//! [`crate::embedding::EmbeddingBagMaxBackwardPlan`].
14//!
15//! Trailblazer dtype coverage: `f32, f64, f16, bf16` (matches the
16//! Sum/Mean FWs). Indices: `i32` + `i64`. f16 / bf16 accumulate in f32.
17//!
18//! **Tie-break**: first occurrence (lowest `k` in the bag). PyTorch
19//! chooses the last occurrence; we document the divergence here.
20
21use core::ffi::c_void;
22use core::marker::PhantomData;
23
24use baracuda_cutlass::{Error, Result};
25use baracuda_driver::Stream;
26use baracuda_kernels_types::{
27    ArchSku, BackendKind, Element, ElementKind, EmbeddingKind, IndexElement, IndexElementKind,
28    KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
29    TensorRef, Workspace,
30};
31
32use crate::indexing::gather::map_status;
33
34use super::PADDING_DISABLED;
35
36/// Descriptor for an `embedding_bag` Max-mode op.
37#[derive(Copy, Clone, Debug)]
38pub struct EmbeddingBagMaxDescriptor {
39    /// Vocabulary size — extent of `weight` along axis 0.
40    pub num_embeddings: i32,
41    /// Embedding dimension — extent of `weight` along axis 1.
42    pub embedding_dim: i32,
43    /// Number of bags — extent of `offsets` and of `out`, `out_index`
44    /// along axis 0.
45    pub num_bags: i32,
46    /// Total flat-index length — extent of `indices`.
47    pub total_indices: i32,
48    /// Optional padding index. Indices matching `p` (or negative / OOB)
49    /// are dropped from the bag.
50    pub padding_idx: Option<i32>,
51    /// Value element type.
52    pub element: ElementKind,
53}
54
55/// Args bundle for an `embedding_bag` Max-mode launch.
56pub struct EmbeddingBagMaxArgs<'a, T: Element, I: IndexElement = i32> {
57    /// Weight matrix `[V, D]`.
58    pub weight: TensorRef<'a, T, 2>,
59    /// Flat index tensor `[total_indices]`.
60    pub indices: TensorRef<'a, I, 1>,
61    /// Per-bag start offset table `[num_bags]`, i32.
62    pub offsets: TensorRef<'a, i32, 1>,
63    /// Output max values `[num_bags, D]`.
64    pub output: TensorMut<'a, T, 2>,
65    /// Output per-(b, d) contributing-row index `[num_bags, D]`,
66    /// always i32. `-1` for empty / all-padded bags.
67    pub output_index: TensorMut<'a, i32, 2>,
68}
69
70/// `embedding_bag` Max-mode FW plan. Phase 25.
71///
72/// Per-bag max with per-feature argmax tracking. Pair with
73/// [`crate::EmbeddingBagMaxBackwardPlan`] for autograd — the BW pass
74/// scatters `dout[b, :]` into `dweight[output_index[b, :], :]`.
75///
76/// **Dtypes**: weight / output `{f32, f64, f16, bf16}`; index buffers
77/// `i32` / `i64`. `output_index` is always `i32`.
78///
79/// **Shape limits**: same as the Sum/Mean FW plus
80/// `output_index` `[num_bags, D]`.
81///
82/// **Workspace**: none.
83///
84/// **Precision guarantee**: deterministic, bit-stable. No atomics.
85///
86/// **Index policy**: padding / OOB indices skipped; empty / all-padded
87/// bag emits zero output and `-1` in every `output_index` cell.
88///
89/// **Tie-break**: first occurrence — diverges from PyTorch (last).
90pub struct EmbeddingBagMaxPlan<T: Element> {
91    desc: EmbeddingBagMaxDescriptor,
92    sku: KernelSku,
93    _marker: PhantomData<T>,
94}
95
96impl<T: Element> EmbeddingBagMaxPlan<T> {
97    /// Pick a kernel for `desc`.
98    pub fn select(
99        _stream: &Stream,
100        desc: &EmbeddingBagMaxDescriptor,
101        _pref: PlanPreference,
102    ) -> Result<Self> {
103        if desc.element != T::KIND {
104            return Err(Error::Unsupported(
105                "baracuda-kernels::EmbeddingBagMaxPlan: descriptor element != type parameter T",
106            ));
107        }
108        if desc.num_embeddings < 0
109            || desc.embedding_dim < 0
110            || desc.num_bags < 0
111            || desc.total_indices < 0
112        {
113            return Err(Error::InvalidProblem(
114                "baracuda-kernels::EmbeddingBagMaxPlan: num_embeddings / embedding_dim / \
115                 num_bags / total_indices must be non-negative",
116            ));
117        }
118        let supported = matches!(
119            T::KIND,
120            ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
121        );
122        if !supported {
123            return Err(Error::Unsupported(
124                "baracuda-kernels::EmbeddingBagMaxPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
125            ));
126        }
127        let precision_guarantee = PrecisionGuarantee {
128            math_precision: if T::KIND == ElementKind::F64 {
129                MathPrecision::F64
130            } else {
131                MathPrecision::F32
132            },
133            accumulator: if T::KIND == ElementKind::F64 {
134                ElementKind::F64
135            } else {
136                ElementKind::F32
137            },
138            bit_stable_on_same_hardware: true,
139            deterministic: true,
140        };
141        let sku = KernelSku {
142            category: OpCategory::Embedding,
143            op: EmbeddingKind::EmbeddingBagMax as u16,
144            element: T::KIND,
145            aux_element: Some(ElementKind::I32),
146            layout: None,
147            epilogue: None,
148            arch: ArchSku::Sm80,
149            backend: BackendKind::Bespoke,
150            precision_guarantee,
151        };
152        Ok(Self {
153            desc: *desc,
154            sku,
155            _marker: PhantomData,
156        })
157    }
158
159    /// Validate args.
160    pub fn can_implement<I: IndexElement>(
161        &self,
162        args: &EmbeddingBagMaxArgs<'_, T, I>,
163    ) -> Result<()> {
164        if args.weight.shape[0] != self.desc.num_embeddings
165            || args.weight.shape[1] != self.desc.embedding_dim
166        {
167            return Err(Error::InvalidProblem(
168                "baracuda-kernels::EmbeddingBagMaxPlan: weight shape mismatch with descriptor",
169            ));
170        }
171        if args.indices.shape[0] != self.desc.total_indices {
172            return Err(Error::InvalidProblem(
173                "baracuda-kernels::EmbeddingBagMaxPlan: indices.shape[0] != total_indices",
174            ));
175        }
176        if args.offsets.shape[0] != self.desc.num_bags {
177            return Err(Error::InvalidProblem(
178                "baracuda-kernels::EmbeddingBagMaxPlan: offsets.shape[0] != num_bags",
179            ));
180        }
181        if args.output.shape[0] != self.desc.num_bags
182            || args.output.shape[1] != self.desc.embedding_dim
183        {
184            return Err(Error::InvalidProblem(
185                "baracuda-kernels::EmbeddingBagMaxPlan: output shape must be [num_bags, embedding_dim]",
186            ));
187        }
188        if args.output_index.shape[0] != self.desc.num_bags
189            || args.output_index.shape[1] != self.desc.embedding_dim
190        {
191            return Err(Error::InvalidProblem(
192                "baracuda-kernels::EmbeddingBagMaxPlan: output_index shape must be [num_bags, embedding_dim]",
193            ));
194        }
195        Ok(())
196    }
197
198    /// Workspace size in bytes (zero).
199    #[inline]
200    pub fn workspace_size(&self) -> usize {
201        0
202    }
203
204    /// Identity of the kernel.
205    #[inline]
206    pub fn sku(&self) -> KernelSku {
207        self.sku
208    }
209
210    /// Numerical guarantees.
211    #[inline]
212    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
213        self.sku.precision_guarantee
214    }
215
216    /// Launch.
217    pub fn run<I: IndexElement>(
218        &self,
219        stream: &Stream,
220        _workspace: Workspace<'_>,
221        args: EmbeddingBagMaxArgs<'_, T, I>,
222    ) -> Result<()> {
223        self.can_implement(&args)?;
224        if self.desc.num_bags == 0 || self.desc.embedding_dim == 0 {
225            return Ok(());
226        }
227        let weight_ptr = args.weight.data.as_raw().0 as *const c_void;
228        let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
229        let off_ptr = args.offsets.data.as_raw().0 as *const c_void;
230        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
231        let out_idx_ptr = args.output_index.data.as_raw().0 as *mut c_void;
232        let stream_ptr = stream.as_raw() as *mut c_void;
233        let padding_idx: i64 = self.desc.padding_idx.unwrap_or(PADDING_DISABLED) as i64;
234
235        let status = match (T::KIND, I::KIND) {
236            (ElementKind::F32, IndexElementKind::I32) => unsafe {
237                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_f32_run(
238                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
239                    self.desc.num_bags, padding_idx,
240                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
241                    core::ptr::null_mut(), 0, stream_ptr,
242                )
243            },
244            (ElementKind::F64, IndexElementKind::I32) => unsafe {
245                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_f64_run(
246                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
247                    self.desc.num_bags, padding_idx,
248                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
249                    core::ptr::null_mut(), 0, stream_ptr,
250                )
251            },
252            (ElementKind::F16, IndexElementKind::I32) => unsafe {
253                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_f16_run(
254                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
255                    self.desc.num_bags, padding_idx,
256                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
257                    core::ptr::null_mut(), 0, stream_ptr,
258                )
259            },
260            (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
261                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_bf16_run(
262                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
263                    self.desc.num_bags, padding_idx,
264                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
265                    core::ptr::null_mut(), 0, stream_ptr,
266                )
267            },
268            (ElementKind::F32, IndexElementKind::I64) => unsafe {
269                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_f32_run(
270                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
271                    self.desc.num_bags, padding_idx,
272                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
273                    core::ptr::null_mut(), 0, stream_ptr,
274                )
275            },
276            (ElementKind::F64, IndexElementKind::I64) => unsafe {
277                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_f64_run(
278                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
279                    self.desc.num_bags, padding_idx,
280                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
281                    core::ptr::null_mut(), 0, stream_ptr,
282                )
283            },
284            (ElementKind::F16, IndexElementKind::I64) => unsafe {
285                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_f16_run(
286                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
287                    self.desc.num_bags, padding_idx,
288                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
289                    core::ptr::null_mut(), 0, stream_ptr,
290                )
291            },
292            (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
293                baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_bf16_run(
294                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
295                    self.desc.num_bags, padding_idx,
296                    weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
297                    core::ptr::null_mut(), 0, stream_ptr,
298                )
299            },
300            _ => {
301                return Err(Error::Unsupported(
302                    "baracuda-kernels::EmbeddingBagMaxPlan::run reached an unimplemented dtype",
303                ));
304            }
305        };
306        map_status(status)
307    }
308}