1use core::ffi::c_void;
22use core::marker::PhantomData;
23
24use baracuda_cutlass::{Error, Result};
25use baracuda_driver::Stream;
26use baracuda_kernels_types::{
27 ArchSku, BackendKind, Element, ElementKind, EmbeddingKind, IndexElement, IndexElementKind,
28 KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
29 TensorRef, Workspace,
30};
31
32use crate::indexing::gather::map_status;
33
34use super::PADDING_DISABLED;
35
36#[derive(Copy, Clone, Debug)]
38pub struct EmbeddingBagMaxDescriptor {
39 pub num_embeddings: i32,
41 pub embedding_dim: i32,
43 pub num_bags: i32,
46 pub total_indices: i32,
48 pub padding_idx: Option<i32>,
51 pub element: ElementKind,
53}
54
55pub struct EmbeddingBagMaxArgs<'a, T: Element, I: IndexElement = i32> {
57 pub weight: TensorRef<'a, T, 2>,
59 pub indices: TensorRef<'a, I, 1>,
61 pub offsets: TensorRef<'a, i32, 1>,
63 pub output: TensorMut<'a, T, 2>,
65 pub output_index: TensorMut<'a, i32, 2>,
68}
69
70pub struct EmbeddingBagMaxPlan<T: Element> {
91 desc: EmbeddingBagMaxDescriptor,
92 sku: KernelSku,
93 _marker: PhantomData<T>,
94}
95
96impl<T: Element> EmbeddingBagMaxPlan<T> {
97 pub fn select(
99 _stream: &Stream,
100 desc: &EmbeddingBagMaxDescriptor,
101 _pref: PlanPreference,
102 ) -> Result<Self> {
103 if desc.element != T::KIND {
104 return Err(Error::Unsupported(
105 "baracuda-kernels::EmbeddingBagMaxPlan: descriptor element != type parameter T",
106 ));
107 }
108 if desc.num_embeddings < 0
109 || desc.embedding_dim < 0
110 || desc.num_bags < 0
111 || desc.total_indices < 0
112 {
113 return Err(Error::InvalidProblem(
114 "baracuda-kernels::EmbeddingBagMaxPlan: num_embeddings / embedding_dim / \
115 num_bags / total_indices must be non-negative",
116 ));
117 }
118 let supported = matches!(
119 T::KIND,
120 ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
121 );
122 if !supported {
123 return Err(Error::Unsupported(
124 "baracuda-kernels::EmbeddingBagMaxPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
125 ));
126 }
127 let precision_guarantee = PrecisionGuarantee {
128 math_precision: if T::KIND == ElementKind::F64 {
129 MathPrecision::F64
130 } else {
131 MathPrecision::F32
132 },
133 accumulator: if T::KIND == ElementKind::F64 {
134 ElementKind::F64
135 } else {
136 ElementKind::F32
137 },
138 bit_stable_on_same_hardware: true,
139 deterministic: true,
140 };
141 let sku = KernelSku {
142 category: OpCategory::Embedding,
143 op: EmbeddingKind::EmbeddingBagMax as u16,
144 element: T::KIND,
145 aux_element: Some(ElementKind::I32),
146 layout: None,
147 epilogue: None,
148 arch: ArchSku::Sm80,
149 backend: BackendKind::Bespoke,
150 precision_guarantee,
151 };
152 Ok(Self {
153 desc: *desc,
154 sku,
155 _marker: PhantomData,
156 })
157 }
158
159 pub fn can_implement<I: IndexElement>(
161 &self,
162 args: &EmbeddingBagMaxArgs<'_, T, I>,
163 ) -> Result<()> {
164 if args.weight.shape[0] != self.desc.num_embeddings
165 || args.weight.shape[1] != self.desc.embedding_dim
166 {
167 return Err(Error::InvalidProblem(
168 "baracuda-kernels::EmbeddingBagMaxPlan: weight shape mismatch with descriptor",
169 ));
170 }
171 if args.indices.shape[0] != self.desc.total_indices {
172 return Err(Error::InvalidProblem(
173 "baracuda-kernels::EmbeddingBagMaxPlan: indices.shape[0] != total_indices",
174 ));
175 }
176 if args.offsets.shape[0] != self.desc.num_bags {
177 return Err(Error::InvalidProblem(
178 "baracuda-kernels::EmbeddingBagMaxPlan: offsets.shape[0] != num_bags",
179 ));
180 }
181 if args.output.shape[0] != self.desc.num_bags
182 || args.output.shape[1] != self.desc.embedding_dim
183 {
184 return Err(Error::InvalidProblem(
185 "baracuda-kernels::EmbeddingBagMaxPlan: output shape must be [num_bags, embedding_dim]",
186 ));
187 }
188 if args.output_index.shape[0] != self.desc.num_bags
189 || args.output_index.shape[1] != self.desc.embedding_dim
190 {
191 return Err(Error::InvalidProblem(
192 "baracuda-kernels::EmbeddingBagMaxPlan: output_index shape must be [num_bags, embedding_dim]",
193 ));
194 }
195 Ok(())
196 }
197
198 #[inline]
200 pub fn workspace_size(&self) -> usize {
201 0
202 }
203
204 #[inline]
206 pub fn sku(&self) -> KernelSku {
207 self.sku
208 }
209
210 #[inline]
212 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
213 self.sku.precision_guarantee
214 }
215
216 pub fn run<I: IndexElement>(
218 &self,
219 stream: &Stream,
220 _workspace: Workspace<'_>,
221 args: EmbeddingBagMaxArgs<'_, T, I>,
222 ) -> Result<()> {
223 self.can_implement(&args)?;
224 if self.desc.num_bags == 0 || self.desc.embedding_dim == 0 {
225 return Ok(());
226 }
227 let weight_ptr = args.weight.data.as_raw().0 as *const c_void;
228 let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
229 let off_ptr = args.offsets.data.as_raw().0 as *const c_void;
230 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
231 let out_idx_ptr = args.output_index.data.as_raw().0 as *mut c_void;
232 let stream_ptr = stream.as_raw() as *mut c_void;
233 let padding_idx: i64 = self.desc.padding_idx.unwrap_or(PADDING_DISABLED) as i64;
234
235 let status = match (T::KIND, I::KIND) {
236 (ElementKind::F32, IndexElementKind::I32) => unsafe {
237 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_f32_run(
238 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
239 self.desc.num_bags, padding_idx,
240 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
241 core::ptr::null_mut(), 0, stream_ptr,
242 )
243 },
244 (ElementKind::F64, IndexElementKind::I32) => unsafe {
245 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_f64_run(
246 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
247 self.desc.num_bags, padding_idx,
248 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
249 core::ptr::null_mut(), 0, stream_ptr,
250 )
251 },
252 (ElementKind::F16, IndexElementKind::I32) => unsafe {
253 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_f16_run(
254 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
255 self.desc.num_bags, padding_idx,
256 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
257 core::ptr::null_mut(), 0, stream_ptr,
258 )
259 },
260 (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
261 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_bf16_run(
262 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
263 self.desc.num_bags, padding_idx,
264 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
265 core::ptr::null_mut(), 0, stream_ptr,
266 )
267 },
268 (ElementKind::F32, IndexElementKind::I64) => unsafe {
269 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_f32_run(
270 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
271 self.desc.num_bags, padding_idx,
272 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
273 core::ptr::null_mut(), 0, stream_ptr,
274 )
275 },
276 (ElementKind::F64, IndexElementKind::I64) => unsafe {
277 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_f64_run(
278 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
279 self.desc.num_bags, padding_idx,
280 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
281 core::ptr::null_mut(), 0, stream_ptr,
282 )
283 },
284 (ElementKind::F16, IndexElementKind::I64) => unsafe {
285 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_f16_run(
286 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
287 self.desc.num_bags, padding_idx,
288 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
289 core::ptr::null_mut(), 0, stream_ptr,
290 )
291 },
292 (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
293 baracuda_kernels_sys::baracuda_kernels_embedding_bag_max_i64idx_bf16_run(
294 self.desc.total_indices, self.desc.num_embeddings, self.desc.embedding_dim,
295 self.desc.num_bags, padding_idx,
296 weight_ptr, idx_ptr, off_ptr, out_ptr, out_idx_ptr,
297 core::ptr::null_mut(), 0, stream_ptr,
298 )
299 },
300 _ => {
301 return Err(Error::Unsupported(
302 "baracuda-kernels::EmbeddingBagMaxPlan::run reached an unimplemented dtype",
303 ));
304 }
305 };
306 map_status(status)
307 }
308}