1use core::ffi::c_void;
25use core::marker::PhantomData;
26
27use baracuda_cutlass::{Error, Result};
28use baracuda_driver::Stream;
29use baracuda_kernels_types::{
30 ArchSku, BackendKind, Element, ElementKind, EmbeddingKind, IndexElement, IndexElementKind,
31 KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
32 TensorRef, Workspace,
33};
34
35use crate::indexing::gather::map_status;
36
37use super::PADDING_DISABLED;
38
/// Reduction applied to the rows gathered for each bag.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EmbeddingBagMode {
    /// Reduce each bag by summing its rows.
    Sum,
    /// Reduce each bag by averaging its rows (how empty bags and padding rows are
    /// counted is decided by the kernel, not visible here).
    Mean,
}
54
55impl EmbeddingBagMode {
56 #[inline]
58 pub(crate) fn ffi_tag(self) -> i32 {
59 match self {
60 EmbeddingBagMode::Sum => 0,
61 EmbeddingBagMode::Mean => 1,
62 }
63 }
64
65 #[inline]
68 fn kind(self) -> EmbeddingKind {
69 match self {
70 EmbeddingBagMode::Sum => EmbeddingKind::EmbeddingBagSum,
71 EmbeddingBagMode::Mean => EmbeddingKind::EmbeddingBagMean,
72 }
73 }
74}
75
76#[derive(Copy, Clone, Debug)]
78pub struct EmbeddingBagDescriptor {
79 pub num_embeddings: i32,
81 pub embedding_dim: i32,
83 pub num_bags: i32,
85 pub total_indices: i32,
87 pub mode: EmbeddingBagMode,
89 pub padding_idx: Option<i32>,
93 pub element: ElementKind,
95}
96
97pub struct EmbeddingBagArgs<'a, T: Element, I: IndexElement = i32> {
102 pub weight: TensorRef<'a, T, 2>,
104 pub indices: TensorRef<'a, I, 1>,
107 pub offsets: TensorRef<'a, i32, 1>,
111 pub output: TensorMut<'a, T, 2>,
113}
114
115pub struct EmbeddingBagPlan<T: Element> {
150 desc: EmbeddingBagDescriptor,
151 sku: KernelSku,
152 _marker: PhantomData<T>,
153}
154
155impl<T: Element> EmbeddingBagPlan<T> {
156 pub fn select(
158 _stream: &Stream,
159 desc: &EmbeddingBagDescriptor,
160 _pref: PlanPreference,
161 ) -> Result<Self> {
162 if desc.element != T::KIND {
163 return Err(Error::Unsupported(
164 "baracuda-kernels::EmbeddingBagPlan: descriptor element != type parameter T",
165 ));
166 }
167 if desc.num_embeddings < 0
168 || desc.embedding_dim < 0
169 || desc.num_bags < 0
170 || desc.total_indices < 0
171 {
172 return Err(Error::InvalidProblem(
173 "baracuda-kernels::EmbeddingBagPlan: num_embeddings / embedding_dim / num_bags / \
174 total_indices must be non-negative",
175 ));
176 }
177 let supported = matches!(
178 T::KIND,
179 ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
180 );
181 if !supported {
182 return Err(Error::Unsupported(
183 "baracuda-kernels::EmbeddingBagPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
184 ));
185 }
186 let precision_guarantee = PrecisionGuarantee {
187 math_precision: if T::KIND == ElementKind::F64 {
188 MathPrecision::F64
189 } else {
190 MathPrecision::F32
191 },
192 accumulator: if T::KIND == ElementKind::F64 {
193 ElementKind::F64
194 } else {
195 ElementKind::F32
196 },
197 bit_stable_on_same_hardware: true,
199 deterministic: true,
200 };
201 let sku = KernelSku {
202 category: OpCategory::Embedding,
203 op: desc.mode.kind() as u16,
204 element: T::KIND,
205 aux_element: Some(ElementKind::I32),
206 layout: None,
207 epilogue: None,
208 arch: ArchSku::Sm80,
209 backend: BackendKind::Bespoke,
210 precision_guarantee,
211 };
212 Ok(Self {
213 desc: *desc,
214 sku,
215 _marker: PhantomData,
216 })
217 }
218
219 pub fn can_implement<I: IndexElement>(&self, args: &EmbeddingBagArgs<'_, T, I>) -> Result<()> {
221 if args.weight.shape[0] != self.desc.num_embeddings
222 || args.weight.shape[1] != self.desc.embedding_dim
223 {
224 return Err(Error::InvalidProblem(
225 "baracuda-kernels::EmbeddingBagPlan: weight shape mismatch with descriptor",
226 ));
227 }
228 if args.indices.shape[0] != self.desc.total_indices {
229 return Err(Error::InvalidProblem(
230 "baracuda-kernels::EmbeddingBagPlan: indices.shape[0] != total_indices",
231 ));
232 }
233 if args.offsets.shape[0] != self.desc.num_bags {
234 return Err(Error::InvalidProblem(
235 "baracuda-kernels::EmbeddingBagPlan: offsets.shape[0] != num_bags",
236 ));
237 }
238 if args.output.shape[0] != self.desc.num_bags
239 || args.output.shape[1] != self.desc.embedding_dim
240 {
241 return Err(Error::InvalidProblem(
242 "baracuda-kernels::EmbeddingBagPlan: output shape must be [num_bags, embedding_dim]",
243 ));
244 }
245 let weight_len = args.weight.data.len() as i64;
246 let idx_len = args.indices.data.len() as i64;
247 let off_len = args.offsets.data.len() as i64;
248 let out_len = args.output.data.len() as i64;
249 let weight_numel = args.weight.numel();
250 let idx_numel = args.indices.numel();
251 let off_numel = args.offsets.numel();
252 let out_numel = args.output.numel();
253 if weight_len < weight_numel {
254 return Err(Error::BufferTooSmall {
255 needed: weight_numel as usize,
256 got: weight_len as usize,
257 });
258 }
259 if idx_len < idx_numel {
260 return Err(Error::BufferTooSmall {
261 needed: idx_numel as usize,
262 got: idx_len as usize,
263 });
264 }
265 if off_len < off_numel {
266 return Err(Error::BufferTooSmall {
267 needed: off_numel as usize,
268 got: off_len as usize,
269 });
270 }
271 if out_len < out_numel {
272 return Err(Error::BufferTooSmall {
273 needed: out_numel as usize,
274 got: out_len as usize,
275 });
276 }
277 Ok(())
278 }
279
280 #[inline]
282 pub fn workspace_size(&self) -> usize {
283 0
284 }
285
286 #[inline]
288 pub fn sku(&self) -> KernelSku {
289 self.sku
290 }
291
292 #[inline]
294 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
295 self.sku.precision_guarantee
296 }
297
298 pub fn run<I: IndexElement>(
302 &self,
303 stream: &Stream,
304 _workspace: Workspace<'_>,
305 args: EmbeddingBagArgs<'_, T, I>,
306 ) -> Result<()> {
307 self.can_implement(&args)?;
308 if self.desc.num_bags == 0 || self.desc.embedding_dim == 0 {
309 return Ok(());
310 }
311 let weight_ptr = args.weight.data.as_raw().0 as *const c_void;
312 let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
313 let off_ptr = args.offsets.data.as_raw().0 as *const c_void;
314 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
315 let stream_ptr = stream.as_raw() as *mut c_void;
316 let padding_idx: i64 = self.desc.padding_idx.unwrap_or(PADDING_DISABLED) as i64;
318 let mode = self.desc.mode.ffi_tag();
319
320 let status = match (T::KIND, I::KIND) {
321 (ElementKind::F32, IndexElementKind::I32) => unsafe {
322 baracuda_kernels_sys::baracuda_kernels_embedding_bag_f32_run(
323 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
324 self.desc.num_bags, mode, padding_idx,
325 weight_ptr, idx_ptr, off_ptr, out_ptr,
326 core::ptr::null_mut(), 0, stream_ptr,
327 )
328 },
329 (ElementKind::F64, IndexElementKind::I32) => unsafe {
330 baracuda_kernels_sys::baracuda_kernels_embedding_bag_f64_run(
331 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
332 self.desc.num_bags, mode, padding_idx,
333 weight_ptr, idx_ptr, off_ptr, out_ptr,
334 core::ptr::null_mut(), 0, stream_ptr,
335 )
336 },
337 (ElementKind::F16, IndexElementKind::I32) => unsafe {
338 baracuda_kernels_sys::baracuda_kernels_embedding_bag_f16_run(
339 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
340 self.desc.num_bags, mode, padding_idx,
341 weight_ptr, idx_ptr, off_ptr, out_ptr,
342 core::ptr::null_mut(), 0, stream_ptr,
343 )
344 },
345 (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
346 baracuda_kernels_sys::baracuda_kernels_embedding_bag_bf16_run(
347 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
348 self.desc.num_bags, mode, padding_idx,
349 weight_ptr, idx_ptr, off_ptr, out_ptr,
350 core::ptr::null_mut(), 0, stream_ptr,
351 )
352 },
353 (ElementKind::F32, IndexElementKind::I64) => unsafe {
354 baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_f32_run(
355 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
356 self.desc.num_bags, mode, padding_idx,
357 weight_ptr, idx_ptr, off_ptr, out_ptr,
358 core::ptr::null_mut(), 0, stream_ptr,
359 )
360 },
361 (ElementKind::F64, IndexElementKind::I64) => unsafe {
362 baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_f64_run(
363 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
364 self.desc.num_bags, mode, padding_idx,
365 weight_ptr, idx_ptr, off_ptr, out_ptr,
366 core::ptr::null_mut(), 0, stream_ptr,
367 )
368 },
369 (ElementKind::F16, IndexElementKind::I64) => unsafe {
370 baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_f16_run(
371 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
372 self.desc.num_bags, mode, padding_idx,
373 weight_ptr, idx_ptr, off_ptr, out_ptr,
374 core::ptr::null_mut(), 0, stream_ptr,
375 )
376 },
377 (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
378 baracuda_kernels_sys::baracuda_kernels_embedding_bag_i64idx_bf16_run(
379 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
380 self.desc.num_bags, mode, padding_idx,
381 weight_ptr, idx_ptr, off_ptr, out_ptr,
382 core::ptr::null_mut(), 0, stream_ptr,
383 )
384 },
385 _ => {
386 return Err(Error::Unsupported(
387 "baracuda-kernels::EmbeddingBagPlan::run reached an unimplemented dtype \
388 — select() should have caught this",
389 ));
390 }
391 };
392 map_status(status)
393 }
394}