1use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
23};
24
25#[derive(Copy, Clone, Debug)]
36pub struct WhereDescriptor<const N: usize> {
37 pub shape: [i32; N],
39 pub element: ElementKind,
41}
42
43pub struct WhereArgs<'a, T: Element, const N: usize> {
49 pub cond: TensorRef<'a, u8, N>,
51 pub a: TensorRef<'a, T, N>,
53 pub b: TensorRef<'a, T, N>,
55 pub y: TensorMut<'a, T, N>,
57}
58
59pub struct WherePlan<T: Element, const N: usize> {
64 desc: WhereDescriptor<N>,
65 sku: KernelSku,
66 _marker: PhantomData<T>,
67}
68
69impl<T: Element, const N: usize> WherePlan<T, N> {
70 pub fn select(
73 _stream: &Stream,
74 desc: &WhereDescriptor<N>,
75 _pref: PlanPreference,
76 ) -> Result<Self> {
77 if desc.element != T::KIND {
78 return Err(Error::Unsupported(
79 "baracuda-kernels::WherePlan: descriptor element != type parameter T",
80 ));
81 }
82 for &d in desc.shape.iter() {
83 if d < 0 {
84 return Err(Error::InvalidProblem(
85 "baracuda-kernels::WherePlan: shape dims must be non-negative",
86 ));
87 }
88 }
89
90 let supported = matches!(
92 T::KIND,
93 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
94 );
95 if !supported {
96 return Err(Error::Unsupported(
97 "baracuda-kernels::WherePlan: value dtype must be one of \
98 {F32, F16, Bf16, F64}",
99 ));
100 }
101
102 let (math_precision, accumulator) = match T::KIND {
107 ElementKind::F16 => (MathPrecision::F16, ElementKind::F16),
108 ElementKind::Bf16 => (MathPrecision::Bf16, ElementKind::Bf16),
109 ElementKind::F64 => (MathPrecision::F64, ElementKind::F64),
110 _ => (MathPrecision::F32, ElementKind::F32),
111 };
112 let precision_guarantee = PrecisionGuarantee {
113 math_precision,
114 accumulator,
115 bit_stable_on_same_hardware: true,
116 deterministic: true,
117 };
118 let sku = KernelSku {
119 category: OpCategory::TernaryElementwise,
120 op: 4,
126 element: T::KIND,
127 aux_element: None,
131 layout: None,
132 epilogue: None,
133 arch: ArchSku::Sm80,
134 backend: BackendKind::Bespoke,
135 precision_guarantee,
136 };
137 Ok(Self {
138 desc: *desc,
139 sku,
140 _marker: PhantomData,
141 })
142 }
143
144 pub fn can_implement(&self, args: &WhereArgs<'_, T, N>) -> Result<()> {
146 if args.y.shape != self.desc.shape {
147 return Err(Error::InvalidProblem(
148 "baracuda-kernels::WherePlan: Y shape mismatch with descriptor",
149 ));
150 }
151
152 for d in 0..N {
154 let y_dim = self.desc.shape[d];
155 let checks = [
156 (args.cond.shape[d], args.cond.stride[d]),
157 (args.a.shape[d], args.a.stride[d]),
158 (args.b.shape[d], args.b.stride[d]),
159 ];
160 for (op_dim, op_stride) in checks {
161 if op_dim != y_dim && !(op_dim == 1 && op_stride == 0) {
162 return Err(Error::InvalidProblem(
163 "baracuda-kernels::WherePlan: input axis not broadcast-compatible \
164 with output (require shape[d] == y.shape[d], OR \
165 shape[d] == 1 AND stride[d] == 0)",
166 ));
167 }
168 }
169 }
170
171 if N > 8 {
172 return Err(Error::Unsupported(
173 "baracuda-kernels::WherePlan: tensor rank > 8 not supported",
174 ));
175 }
176
177 let y_numel = args.y.numel();
178 let cond_numel = args.cond.numel();
179 let a_numel = args.a.numel();
180 let b_numel = args.b.numel();
181 let cond_len = args.cond.data.len() as i64;
182 let a_len = args.a.data.len() as i64;
183 let b_len = args.b.data.len() as i64;
184 let y_len = args.y.data.len() as i64;
185 if y_len < y_numel {
186 return Err(Error::BufferTooSmall {
187 needed: y_numel as usize,
188 got: y_len as usize,
189 });
190 }
191 if cond_len < cond_numel {
192 return Err(Error::BufferTooSmall {
193 needed: cond_numel as usize,
194 got: cond_len as usize,
195 });
196 }
197 if a_len < a_numel {
198 return Err(Error::BufferTooSmall {
199 needed: a_numel as usize,
200 got: a_len as usize,
201 });
202 }
203 if b_len < b_numel {
204 return Err(Error::BufferTooSmall {
205 needed: b_numel as usize,
206 got: b_len as usize,
207 });
208 }
209 Ok(())
210 }
211
212 #[inline]
214 pub fn workspace_size(&self) -> usize {
215 0
216 }
217
218 #[inline]
220 pub fn sku(&self) -> KernelSku {
221 self.sku
222 }
223
224 #[inline]
226 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
227 self.sku.precision_guarantee
228 }
229
230 pub fn run(
232 &self,
233 stream: &Stream,
234 _workspace: Workspace<'_>,
235 args: WhereArgs<'_, T, N>,
236 ) -> Result<()> {
237 self.can_implement(&args)?;
238 let numel = args.y.numel();
239 if numel == 0 {
240 return Ok(());
241 }
242 let cond_ptr = args.cond.data.as_raw().0 as *const c_void;
243 let a_ptr = args.a.data.as_raw().0 as *const c_void;
244 let b_ptr = args.b.data.as_raw().0 as *const c_void;
245 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
246 let stream_ptr = stream.as_raw() as *mut c_void;
247
248 let all_contig_same_shape = args.cond.shape == args.y.shape
249 && args.a.shape == args.y.shape
250 && args.b.shape == args.y.shape
251 && args.cond.is_contiguous()
252 && args.a.is_contiguous()
253 && args.b.is_contiguous()
254 && args.y.is_contiguous();
255
256 if !all_contig_same_shape {
257 return self.run_strided(
258 stream_ptr, cond_ptr, a_ptr, b_ptr, y_ptr, numel, &args,
259 );
260 }
261
262 let status = match T::KIND {
263 ElementKind::F32 => unsafe {
264 baracuda_kernels_sys::baracuda_kernels_where_f32_run(
265 numel,
266 cond_ptr,
267 a_ptr,
268 b_ptr,
269 y_ptr,
270 core::ptr::null_mut(),
271 0,
272 stream_ptr,
273 )
274 },
275 ElementKind::F16 => unsafe {
276 baracuda_kernels_sys::baracuda_kernels_where_f16_run(
277 numel,
278 cond_ptr,
279 a_ptr,
280 b_ptr,
281 y_ptr,
282 core::ptr::null_mut(),
283 0,
284 stream_ptr,
285 )
286 },
287 ElementKind::Bf16 => unsafe {
288 baracuda_kernels_sys::baracuda_kernels_where_bf16_run(
289 numel,
290 cond_ptr,
291 a_ptr,
292 b_ptr,
293 y_ptr,
294 core::ptr::null_mut(),
295 0,
296 stream_ptr,
297 )
298 },
299 ElementKind::F64 => unsafe {
300 baracuda_kernels_sys::baracuda_kernels_where_f64_run(
301 numel,
302 cond_ptr,
303 a_ptr,
304 b_ptr,
305 y_ptr,
306 core::ptr::null_mut(),
307 0,
308 stream_ptr,
309 )
310 },
311 _ => {
312 return Err(Error::Unsupported(
313 "baracuda-kernels::WherePlan::run reached an unimplemented dtype \
314 — select() should have caught this",
315 ));
316 }
317 };
318 map_status(status)
319 }
320
321 fn run_strided(
323 &self,
324 stream_ptr: *mut c_void,
325 cond_ptr: *const c_void,
326 a_ptr: *const c_void,
327 b_ptr: *const c_void,
328 y_ptr: *mut c_void,
329 numel: i64,
330 args: &WhereArgs<'_, T, N>,
331 ) -> Result<()> {
332 let shape = args.y.shape;
333 let stride_cond = args.cond.stride;
334 let stride_a = args.a.stride;
335 let stride_b = args.b.stride;
336 let stride_y = args.y.stride;
337 let rank = N as i32;
338
339 let status = match T::KIND {
340 ElementKind::F32 => unsafe {
341 baracuda_kernels_sys::baracuda_kernels_where_f32_strided_run(
342 numel,
343 rank,
344 shape.as_ptr(),
345 stride_cond.as_ptr(),
346 stride_a.as_ptr(),
347 stride_b.as_ptr(),
348 stride_y.as_ptr(),
349 cond_ptr,
350 a_ptr,
351 b_ptr,
352 y_ptr,
353 core::ptr::null_mut(),
354 0,
355 stream_ptr,
356 )
357 },
358 ElementKind::F16 => unsafe {
359 baracuda_kernels_sys::baracuda_kernels_where_f16_strided_run(
360 numel,
361 rank,
362 shape.as_ptr(),
363 stride_cond.as_ptr(),
364 stride_a.as_ptr(),
365 stride_b.as_ptr(),
366 stride_y.as_ptr(),
367 cond_ptr,
368 a_ptr,
369 b_ptr,
370 y_ptr,
371 core::ptr::null_mut(),
372 0,
373 stream_ptr,
374 )
375 },
376 ElementKind::Bf16 => unsafe {
377 baracuda_kernels_sys::baracuda_kernels_where_bf16_strided_run(
378 numel,
379 rank,
380 shape.as_ptr(),
381 stride_cond.as_ptr(),
382 stride_a.as_ptr(),
383 stride_b.as_ptr(),
384 stride_y.as_ptr(),
385 cond_ptr,
386 a_ptr,
387 b_ptr,
388 y_ptr,
389 core::ptr::null_mut(),
390 0,
391 stream_ptr,
392 )
393 },
394 ElementKind::F64 => unsafe {
395 baracuda_kernels_sys::baracuda_kernels_where_f64_strided_run(
396 numel,
397 rank,
398 shape.as_ptr(),
399 stride_cond.as_ptr(),
400 stride_a.as_ptr(),
401 stride_b.as_ptr(),
402 stride_y.as_ptr(),
403 cond_ptr,
404 a_ptr,
405 b_ptr,
406 y_ptr,
407 core::ptr::null_mut(),
408 0,
409 stream_ptr,
410 )
411 },
412 _ => {
413 return Err(Error::Unsupported(
414 "baracuda-kernels::WherePlan: strided path reached unimplemented dtype \
415 — select() should have caught this",
416 ));
417 }
418 };
419 map_status(status)
420 }
421}
422
423fn map_status(code: i32) -> Result<()> {
424 match code {
425 0 => Ok(()),
426 1 => Err(Error::MisalignedOperand),
427 2 => Err(Error::InvalidProblem(
428 "baracuda-kernels-sys reported invalid problem",
429 )),
430 3 => Err(Error::Unsupported(
431 "baracuda-kernels-sys reported unsupported configuration",
432 )),
433 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
434 n => Err(Error::CutlassInternal(n)),
435 }
436}