Skip to main content

baracuda_kernels/embedding/
embedding_bag.rs

1//! `embedding_bag` FW plan — Category M.
2//!
3//! Per-bag reduction over a flat `indices` array partitioned by the
4//! `offsets` table. For each bag `b`:
5//! - `start = offsets[b]`
6//! - `end   = offsets[b + 1] if b + 1 < num_bags else total_indices`
7//! - `out[b, :] = reduce(weight[indices[k], :] for k in start..end)`
8//!
9//! Reducer is selected by [`EmbeddingBagMode`]:
10//! - `Sum`: pure addition.
11//! - `Mean`: sum, then divide by the post-skip bag size. If every
12//!   entry in the bag was padding / OOB the row is emitted as zero
13//!   (no divide by zero).
14//!
15//! `padding_idx` rows are dropped from the reduction (also excluded
16//! from the Mean divisor). Empty bags (`start == end`) emit zero rows.
17//!
//! `Max` mode is not part of this plan: it needs per-feature argmax
//! tracking on FW so the BW can scatter into the contributing rows, and
//! is served by the separate `EmbeddingBagMaxPlan` instead.
20//!
//! Trailblazer dtype coverage: `f32, f64, f16, bf16` for weight /
//! output, with `i32` or `i64` indices (offsets are always `i32`).
//! f16 / bf16 accumulate in f32 internally before casting back to T at
//! write.
23
24use core::ffi::c_void;
25use core::marker::PhantomData;
26
27use baracuda_cutlass::{Error, Result};
28use baracuda_driver::Stream;
29use baracuda_kernels_types::{
30    ArchSku, BackendKind, Element, ElementKind, EmbeddingKind, IndexElement, IndexElementKind,
31    KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
32    TensorRef, Workspace,
33};
34
35use crate::indexing::gather::map_status;
36
37use super::PADDING_DISABLED;
38
/// How the rows gathered for one bag are combined by `embedding_bag`.
///
/// This enum deliberately does **not** carry `#[non_exhaustive]`: `Sum`
/// and `Mean` are the complete set for this op family. `Max` is served
/// by the separate [`super::EmbeddingBagMaxPlan`] (with its own FFI
/// surface) rather than by a third variant here, so adding a variant
/// would be an intentional breaking change.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum EmbeddingBagMode {
    /// Plain sum: `out[b, :] = Σ weight[indices[k], :]` over every `k`
    /// belonging to bag `b`.
    Sum,
    /// Average: the bag's sum divided by `bag_size(b)`, where
    /// `bag_size` only counts indices that are neither padding nor
    /// out of bounds.
    Mean,
}
54
55impl EmbeddingBagMode {
56    /// FFI tag matching `kModeSum` / `kModeMean` in the .cuh header.
57    #[inline]
58    pub(crate) fn ffi_tag(self) -> i32 {
59        match self {
60            EmbeddingBagMode::Sum => 0,
61            EmbeddingBagMode::Mean => 1,
62        }
63    }
64
65    /// Maps to the corresponding [`EmbeddingKind`] discriminant for SKU
66    /// tagging.
67    #[inline]
68    fn kind(self) -> EmbeddingKind {
69        match self {
70            EmbeddingBagMode::Sum => EmbeddingKind::EmbeddingBagSum,
71            EmbeddingBagMode::Mean => EmbeddingKind::EmbeddingBagMean,
72        }
73    }
74}
75
76/// Descriptor for an `embedding_bag` op.
77#[derive(Copy, Clone, Debug)]
78pub struct EmbeddingBagDescriptor {
79    /// Vocabulary size — extent of `weight` along axis 0.
80    pub num_embeddings: i32,
81    /// Embedding dimension — extent of `weight` along axis 1.
82    pub embedding_dim: i32,
83    /// Number of bags — extent of `offsets` and of `output` along axis 0.
84    pub num_bags: i32,
85    /// Total flat-index length — extent of `indices` along axis 0.
86    pub total_indices: i32,
87    /// Reduction mode.
88    pub mode: EmbeddingBagMode,
89    /// Optional padding index. When `Some(p)`, indices equal to `p` (or
90    /// negative / OOB) are dropped from the bag's reduction. Excluded
91    /// from the Mean divisor.
92    pub padding_idx: Option<i32>,
93    /// Value element type.
94    pub element: ElementKind,
95}
96
97/// Args bundle for an `embedding_bag` launch.
98///
99/// Phase 11.5: `I: IndexElement` generic (`i32` or `i64`) for the
100/// index tensor. `offsets` stays i32 — bag boundaries fit comfortably.
101pub struct EmbeddingBagArgs<'a, T: Element, I: IndexElement = i32> {
102    /// Weight matrix `[V, D]`. Row-major contiguous.
103    pub weight: TensorRef<'a, T, 2>,
104    /// Flat index tensor `[total_indices]`. `i32` (legacy) or `i64`
105    /// (PyTorch default).
106    pub indices: TensorRef<'a, I, 1>,
107    /// Per-bag start offset table `[num_bags]`, i32. `offsets[0]` should
108    /// be 0; `offsets[b+1] - offsets[b]` is bag `b`'s size; the last
109    /// bag's implicit end is `total_indices`.
110    pub offsets: TensorRef<'a, i32, 1>,
111    /// Output `[num_bags, D]`. Row-major contiguous.
112    pub output: TensorMut<'a, T, 2>,
113}
114
115/// `embedding_bag` plan.
116///
117/// Per-bag reduction over a flat `indices` array partitioned by the
118/// `offsets` table (PyTorch `torch.nn.functional.embedding_bag`). For
119/// each bag `b`: `out[b, :] = reduce(weight[indices[k], :])` for
120/// `k ∈ offsets[b]..offsets[b+1]` (last bag's end is `total_indices`).
121///
122/// **When to use**: forward pooled embedding lookup (e.g. continuous
123/// bag-of-words). Pair with
124/// [`EmbeddingBagBackwardPlan`](crate::EmbeddingBagBackwardPlan) for
125/// autograd. For non-pooled lookup, use
126/// [`EmbeddingPlan`](crate::EmbeddingPlan).
127///
128/// **Dtypes**: weight / output `{f32, f64, f16, bf16}`; indices and
129/// offsets always `i32`. f16 / bf16 accumulate in f32 internally
130/// before the cast back to T at write.
131///
132/// **Shape limits**: `weight` is `[V, D]`, `indices` is
133/// `[total_indices]`, `offsets` is `[num_bags]`, `output` is
134/// `[num_bags, D]`.
135///
136/// **Workspace**: none.
137///
138/// **Precision guarantee**: deterministic, bit-stable on same
139/// hardware. No atomics on FW.
140///
141/// **Index policy**: `padding_idx` (or negative / OOB) indices are
142/// dropped from the bag's reduction; excluded from the Mean divisor.
143/// Empty bags (`start == end`) emit zero rows; an all-padding bag in
144/// Mean mode also emits zero (no divide-by-zero).
145///
146/// **Known limitations**: `Max` mode is deferred — it needs
147/// per-feature argmax tracking on FW so the BW can scatter into the
148/// contributing rows.
149pub struct EmbeddingBagPlan<T: Element> {
150    desc: EmbeddingBagDescriptor,
151    sku: KernelSku,
152    _marker: PhantomData<T>,
153}
154
155impl<T: Element> EmbeddingBagPlan<T> {
156    /// Pick a kernel for `desc`.
157    pub fn select(
158        _stream: &Stream,
159        desc: &EmbeddingBagDescriptor,
160        _pref: PlanPreference,
161    ) -> Result<Self> {
162        if desc.element != T::KIND {
163            return Err(Error::Unsupported(
164                "baracuda-kernels::EmbeddingBagPlan: descriptor element != type parameter T",
165            ));
166        }
167        if desc.num_embeddings < 0
168            || desc.embedding_dim < 0
169            || desc.num_bags < 0
170            || desc.total_indices < 0
171        {
172            return Err(Error::InvalidProblem(
173                "baracuda-kernels::EmbeddingBagPlan: num_embeddings / embedding_dim / num_bags / \
174                 total_indices must be non-negative",
175            ));
176        }
177        let supported = matches!(
178            T::KIND,
179            ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
180        );
181        if !supported {
182            return Err(Error::Unsupported(
183                "baracuda-kernels::EmbeddingBagPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
184            ));
185        }
186        let precision_guarantee = PrecisionGuarantee {
187            math_precision: if T::KIND == ElementKind::F64 {
188                MathPrecision::F64
189            } else {
190                MathPrecision::F32
191            },
192            accumulator: if T::KIND == ElementKind::F64 {
193                ElementKind::F64
194            } else {
195                ElementKind::F32
196            },
197            // No atomics on FW — same input → same output bit pattern.
198            bit_stable_on_same_hardware: true,
199            deterministic: true,
200        };
201        let sku = KernelSku {
202            category: OpCategory::Embedding,
203            op: desc.mode.kind() as u16,
204            element: T::KIND,
205            aux_element: Some(ElementKind::I32),
206            layout: None,
207            epilogue: None,
208            arch: ArchSku::Sm80,
209            backend: BackendKind::Bespoke,
210            precision_guarantee,
211        };
212        Ok(Self {
213            desc: *desc,
214            sku,
215            _marker: PhantomData,
216        })
217    }
218
219    /// Validate args.
220    pub fn can_implement<I: IndexElement>(&self, args: &EmbeddingBagArgs<'_, T, I>) -> Result<()> {
221        if args.weight.shape[0] != self.desc.num_embeddings
222            || args.weight.shape[1] != self.desc.embedding_dim
223        {
224            return Err(Error::InvalidProblem(
225                "baracuda-kernels::EmbeddingBagPlan: weight shape mismatch with descriptor",
226            ));
227        }
228        if args.indices.shape[0] != self.desc.total_indices {
229            return Err(Error::InvalidProblem(
230                "baracuda-kernels::EmbeddingBagPlan: indices.shape[0] != total_indices",
231            ));
232        }
233        if args.offsets.shape[0] != self.desc.num_bags {
234            return Err(Error::InvalidProblem(
235                "baracuda-kernels::EmbeddingBagPlan: offsets.shape[0] != num_bags",
236            ));
237        }
238        if args.output.shape[0] != self.desc.num_bags
239            || args.output.shape[1] != self.desc.embedding_dim
240        {
241            return Err(Error::InvalidProblem(
242                "baracuda-kernels::EmbeddingBagPlan: output shape must be [num_bags, embedding_dim]",
243            ));
244        }
245        let weight_len = args.weight.data.len() as i64;
246        let idx_len = args.indices.data.len() as i64;
247        let off_len = args.offsets.data.len() as i64;
248        let out_len = args.output.data.len() as i64;
249        let weight_numel = args.weight.numel();
250        let idx_numel = args.indices.numel();
251        let off_numel = args.offsets.numel();
252        let out_numel = args.output.numel();
253        if weight_len < weight_numel {
254            return Err(Error::BufferTooSmall {
255                needed: weight_numel as usize,
256                got: weight_len as usize,
257            });
258        }
259        if idx_len < idx_numel {
260            return Err(Error::BufferTooSmall {
261                needed: idx_numel as usize,
262                got: idx_len as usize,
263            });
264        }
265        if off_len < off_numel {
266            return Err(Error::BufferTooSmall {
267                needed: off_numel as usize,
268                got: off_len as usize,
269            });
270        }
271        if out_len < out_numel {
272            return Err(Error::BufferTooSmall {
273                needed: out_numel as usize,
274                got: out_len as usize,
275            });
276        }
277        Ok(())
278    }
279
280    /// Workspace size in bytes (zero).
281    #[inline]
282    pub fn workspace_size(&self) -> usize {
283        0
284    }
285
286    /// Identity of the kernel this plan picked.
287    #[inline]
288    pub fn sku(&self) -> KernelSku {
289        self.sku
290    }
291
292    /// Numerical guarantees for this plan's kernel.
293    #[inline]
294    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
295        self.sku.precision_guarantee
296    }
297
298    /// Launch.
299    ///
300    /// Phase 11.5: generic over `I: IndexElement`.
301    pub fn run<I: IndexElement>(
302        &self,
303        stream: &Stream,
304        _workspace: Workspace<'_>,
305        args: EmbeddingBagArgs<'_, T, I>,
306    ) -> Result<()> {
307        self.can_implement(&args)?;
308        if self.desc.num_bags == 0 || self.desc.embedding_dim == 0 {
309            return Ok(());
310        }
311        let weight_ptr = args.weight.data.as_raw().0 as *const c_void;
312        let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
313        let off_ptr = args.offsets.data.as_raw().0 as *const c_void;
314        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
315        let stream_ptr = stream.as_raw() as *mut c_void;
316        // Phase 11.5: padding_idx widens to i64 across FFI.
317        let padding_idx: i64 = self.desc.padding_idx.unwrap_or(PADDING_DISABLED) as i64;
318        let mode = self.desc.mode.ffi_tag();
319
320        let status = match (T::KIND, I::KIND) {
321            (ElementKind::F32, IndexElementKind::I32) => unsafe {
322                baracuda_kernels_sys::baracuda_kernels_embedding_bag_f32_run(
323                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
324                    self.desc.num_bags, mode, padding_idx,
325                    weight_ptr, idx_ptr, off_ptr, out_ptr,
326                    core::ptr::null_mut(), 0, stream_ptr,
327                )
328            },
329            (ElementKind::F64, IndexElementKind::I32) => unsafe {
330                baracuda_kernels_sys::baracuda_kernels_embedding_bag_f64_run(
331                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
332                    self.desc.num_bags, mode, padding_idx,
333                    weight_ptr, idx_ptr, off_ptr, out_ptr,
334                    core::ptr::null_mut(), 0, stream_ptr,
335                )
336            },
337            (ElementKind::F16, IndexElementKind::I32) => unsafe {
338                baracuda_kernels_sys::baracuda_kernels_embedding_bag_f16_run(
339                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
340                    self.desc.num_bags, mode, padding_idx,
341                    weight_ptr, idx_ptr, off_ptr, out_ptr,
342                    core::ptr::null_mut(), 0, stream_ptr,
343                )
344            },
345            (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
346                baracuda_kernels_sys::baracuda_kernels_embedding_bag_bf16_run(
347                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
348                    self.desc.num_bags, mode, padding_idx,
349                    weight_ptr, idx_ptr, off_ptr, out_ptr,
350                    core::ptr::null_mut(), 0, stream_ptr,
351                )
352            },
353            (ElementKind::F32, IndexElementKind::I64) => unsafe {
354                baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_f32_run(
355                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
356                    self.desc.num_bags, mode, padding_idx,
357                    weight_ptr, idx_ptr, off_ptr, out_ptr,
358                    core::ptr::null_mut(), 0, stream_ptr,
359                )
360            },
361            (ElementKind::F64, IndexElementKind::I64) => unsafe {
362                baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_f64_run(
363                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
364                    self.desc.num_bags, mode, padding_idx,
365                    weight_ptr, idx_ptr, off_ptr, out_ptr,
366                    core::ptr::null_mut(), 0, stream_ptr,
367                )
368            },
369            (ElementKind::F16, IndexElementKind::I64) => unsafe {
370                baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_f16_run(
371                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
372                    self.desc.num_bags, mode, padding_idx,
373                    weight_ptr, idx_ptr, off_ptr, out_ptr,
374                    core::ptr::null_mut(), 0, stream_ptr,
375                )
376            },
377            (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
378                baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_bf16_run(
379                    self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
380                    self.desc.num_bags, mode, padding_idx,
381                    weight_ptr, idx_ptr, off_ptr, out_ptr,
382                    core::ptr::null_mut(), 0, stream_ptr,
383                )
384            },
385            _ => {
386                return Err(Error::Unsupported(
387                    "baracuda-kernels::EmbeddingBagPlan::run reached an unimplemented dtype \
388                     — select() should have caught this",
389                ));
390            }
391        };
392        map_status(status)
393    }
394}