1use core::ffi::c_void;
32use core::marker::PhantomData;
33
34use baracuda_cutlass::{Error, Result};
35use baracuda_driver::Stream;
36use baracuda_kernels_types::{
37 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
38 PlanPreference, PrecisionGuarantee, ReduceKind, TensorMut, TensorRef, Workspace,
39};
40
41#[derive(Copy, Clone, Debug)]
48pub struct ReduceDescriptor<const N: usize> {
49 pub kind: ReduceKind,
51 pub input_shape: [i32; N],
53 pub reduce_axis: u8,
55 pub element: ElementKind,
57 pub correction: i32,
61}
62
63impl<const N: usize> ReduceDescriptor<N> {
64 pub fn output_shape(&self) -> [i32; N] {
66 let mut out = self.input_shape;
67 out[self.reduce_axis as usize] = 1;
68 out
69 }
70}
71
72pub struct ReduceArgs<'a, T: Element, const N: usize> {
78 pub x: TensorRef<'a, T, N>,
80 pub y: TensorMut<'a, T, N>,
82}
83
84pub struct ReducePlan<T: Element, const N: usize> {
90 desc: ReduceDescriptor<N>,
91 sku: KernelSku,
92 _marker: PhantomData<T>,
93}
94
95impl<T: Element, const N: usize> ReducePlan<T, N> {
96 pub fn select(
98 _stream: &Stream,
99 desc: &ReduceDescriptor<N>,
100 _pref: PlanPreference,
101 ) -> Result<Self> {
102 if desc.element != T::KIND {
103 return Err(Error::Unsupported(
104 "baracuda-kernels::ReducePlan: descriptor element != type parameter T",
105 ));
106 }
107 if (desc.reduce_axis as usize) >= N {
108 return Err(Error::InvalidProblem(
109 "baracuda-kernels::ReducePlan: reduce_axis must be < rank",
110 ));
111 }
112 for &d in desc.input_shape.iter() {
113 if d < 0 {
114 return Err(Error::InvalidProblem(
115 "baracuda-kernels::ReducePlan: input_shape dims must be non-negative",
116 ));
117 }
118 }
119
120 let kind_in_scope = matches!(
128 desc.kind,
129 ReduceKind::Sum
130 | ReduceKind::Mean
131 | ReduceKind::Max
132 | ReduceKind::Min
133 | ReduceKind::Prod
134 | ReduceKind::Norm2
135 | ReduceKind::LogSumExp
136 | ReduceKind::Var
137 | ReduceKind::Std
138 );
139 let dtype_in_scope = matches!(
140 T::KIND,
141 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
142 );
143 let supported = kind_in_scope && dtype_in_scope;
144 if !supported {
145 return Err(Error::Unsupported(
146 "baracuda-kernels::ReducePlan: supported matrix is \
147 {Sum, Mean, Max, Min, Prod, Norm2, LogSumExp, Var, Std} × \
148 {f32, f16, bf16, f64}; other (kind, dtype) pairs land \
149 in later fanout (Argmax/Argmin via ArgReducePlan; trace \
150 via TracePlan)",
151 ));
152 }
153
154 let precision_guarantee = PrecisionGuarantee {
158 math_precision: MathPrecision::F32,
159 accumulator: ElementKind::F32,
160 bit_stable_on_same_hardware: true,
161 deterministic: true,
162 };
163 let sku = KernelSku {
164 category: OpCategory::Reduction,
165 op: desc.kind as u16,
166 element: T::KIND,
167 aux_element: None,
168 layout: None,
169 epilogue: None,
170 arch: ArchSku::Sm80,
171 backend: BackendKind::Bespoke,
172 precision_guarantee,
173 };
174 Ok(Self {
175 desc: *desc,
176 sku,
177 _marker: PhantomData,
178 })
179 }
180
181 pub fn can_implement(&self, args: &ReduceArgs<'_, T, N>) -> Result<()> {
183 if args.x.shape != self.desc.input_shape {
184 return Err(Error::InvalidProblem(
185 "baracuda-kernels::ReducePlan: X shape mismatch with descriptor input_shape",
186 ));
187 }
188 let expected_out = self.desc.output_shape();
189 if args.y.shape != expected_out {
190 return Err(Error::InvalidProblem(
191 "baracuda-kernels::ReducePlan: Y shape mismatch with derived output shape \
192 (input shape with reduce_axis collapsed to 1)",
193 ));
194 }
195 if N > 8 {
196 return Err(Error::Unsupported(
197 "baracuda-kernels::ReducePlan: tensor rank > 8 not supported",
198 ));
199 }
200 let y_numel = args.y.numel();
201 let x_numel = args.x.numel();
202 let x_len = args.x.data.len() as i64;
203 let y_len = args.y.data.len() as i64;
204 if y_len < y_numel {
205 return Err(Error::BufferTooSmall {
206 needed: y_numel as usize,
207 got: y_len as usize,
208 });
209 }
210 if x_len < x_numel {
211 return Err(Error::BufferTooSmall {
212 needed: x_numel as usize,
213 got: x_len as usize,
214 });
215 }
216 Ok(())
217 }
218
219 #[inline]
221 pub fn workspace_size(&self) -> usize {
222 0
223 }
224
225 #[inline]
227 pub fn sku(&self) -> KernelSku {
228 self.sku
229 }
230
231 #[inline]
233 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
234 self.sku.precision_guarantee
235 }
236
237 pub fn run(
239 &self,
240 stream: &Stream,
241 _workspace: Workspace<'_>,
242 args: ReduceArgs<'_, T, N>,
243 ) -> Result<()> {
244 self.can_implement(&args)?;
245 let output_numel = args.y.numel();
246 if output_numel == 0 {
247 return Ok(());
248 }
249 let x_ptr = args.x.data.as_raw().0 as *const c_void;
250 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
251 let stream_ptr = stream.as_raw() as *mut c_void;
252
253 let output_shape = self.desc.output_shape();
254 let stride_x = args.x.stride;
255 let stride_y = args.y.stride;
256 let rank = N as i32;
257 let reduce_axis = self.desc.reduce_axis as i32;
258 let reduce_extent = self.desc.input_shape[self.desc.reduce_axis as usize];
259 let reduce_stride_x = args.x.stride[self.desc.reduce_axis as usize];
260
261 macro_rules! dispatch {
265 ($sym:ident) => {{
266 unsafe {
267 baracuda_kernels_sys::$sym(
268 output_numel,
269 rank,
270 output_shape.as_ptr(),
271 stride_x.as_ptr(),
272 stride_y.as_ptr(),
273 reduce_axis,
274 reduce_extent,
275 reduce_stride_x,
276 x_ptr,
277 y_ptr,
278 core::ptr::null_mut(),
279 0,
280 stream_ptr,
281 )
282 }
283 }};
284 }
285
286 let status = match (self.desc.kind, T::KIND) {
287 (ReduceKind::Sum, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_sum_f32_run),
289 (ReduceKind::Sum, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_sum_f16_run),
290 (ReduceKind::Sum, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_sum_bf16_run),
291 (ReduceKind::Sum, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_sum_f64_run),
292 (ReduceKind::Mean, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_mean_f32_run),
294 (ReduceKind::Mean, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_mean_f16_run),
295 (ReduceKind::Mean, ElementKind::Bf16) => {
296 dispatch!(baracuda_kernels_reduce_mean_bf16_run)
297 }
298 (ReduceKind::Mean, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_mean_f64_run),
299 (ReduceKind::Max, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_max_f32_run),
301 (ReduceKind::Max, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_max_f16_run),
302 (ReduceKind::Max, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_max_bf16_run),
303 (ReduceKind::Max, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_max_f64_run),
304 (ReduceKind::Min, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_min_f32_run),
306 (ReduceKind::Min, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_min_f16_run),
307 (ReduceKind::Min, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_min_bf16_run),
308 (ReduceKind::Min, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_min_f64_run),
309 (ReduceKind::Prod, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_prod_f32_run),
311 (ReduceKind::Prod, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_prod_f16_run),
312 (ReduceKind::Prod, ElementKind::Bf16) => {
313 dispatch!(baracuda_kernels_reduce_prod_bf16_run)
314 }
315 (ReduceKind::Prod, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_prod_f64_run),
316 (ReduceKind::Norm2, ElementKind::F32) => {
319 dispatch!(baracuda_kernels_reduce_norm2_f32_run)
320 }
321 (ReduceKind::Norm2, ElementKind::F16) => {
322 dispatch!(baracuda_kernels_reduce_norm2_f16_run)
323 }
324 (ReduceKind::Norm2, ElementKind::Bf16) => {
325 dispatch!(baracuda_kernels_reduce_norm2_bf16_run)
326 }
327 (ReduceKind::Norm2, ElementKind::F64) => {
328 dispatch!(baracuda_kernels_reduce_norm2_f64_run)
329 }
330 (ReduceKind::LogSumExp, ElementKind::F32) => {
334 dispatch!(baracuda_kernels_reduce_logsumexp_f32_run)
335 }
336 (ReduceKind::LogSumExp, ElementKind::F16) => {
337 dispatch!(baracuda_kernels_reduce_logsumexp_f16_run)
338 }
339 (ReduceKind::LogSumExp, ElementKind::Bf16) => {
340 dispatch!(baracuda_kernels_reduce_logsumexp_bf16_run)
341 }
342 (ReduceKind::LogSumExp, ElementKind::F64) => {
343 dispatch!(baracuda_kernels_reduce_logsumexp_f64_run)
344 }
345 (ReduceKind::Var, ElementKind::F32) => unsafe {
350 baracuda_kernels_sys::baracuda_kernels_reduce_var_f32_run(
351 output_numel, rank, output_shape.as_ptr(),
352 stride_x.as_ptr(), stride_y.as_ptr(),
353 reduce_axis, reduce_extent, reduce_stride_x,
354 self.desc.correction,
355 x_ptr, y_ptr,
356 core::ptr::null_mut(), 0, stream_ptr,
357 )
358 },
359 (ReduceKind::Var, ElementKind::F16) => unsafe {
360 baracuda_kernels_sys::baracuda_kernels_reduce_var_f16_run(
361 output_numel, rank, output_shape.as_ptr(),
362 stride_x.as_ptr(), stride_y.as_ptr(),
363 reduce_axis, reduce_extent, reduce_stride_x,
364 self.desc.correction,
365 x_ptr, y_ptr,
366 core::ptr::null_mut(), 0, stream_ptr,
367 )
368 },
369 (ReduceKind::Var, ElementKind::Bf16) => unsafe {
370 baracuda_kernels_sys::baracuda_kernels_reduce_var_bf16_run(
371 output_numel, rank, output_shape.as_ptr(),
372 stride_x.as_ptr(), stride_y.as_ptr(),
373 reduce_axis, reduce_extent, reduce_stride_x,
374 self.desc.correction,
375 x_ptr, y_ptr,
376 core::ptr::null_mut(), 0, stream_ptr,
377 )
378 },
379 (ReduceKind::Var, ElementKind::F64) => unsafe {
380 baracuda_kernels_sys::baracuda_kernels_reduce_var_f64_run(
381 output_numel, rank, output_shape.as_ptr(),
382 stride_x.as_ptr(), stride_y.as_ptr(),
383 reduce_axis, reduce_extent, reduce_stride_x,
384 self.desc.correction,
385 x_ptr, y_ptr,
386 core::ptr::null_mut(), 0, stream_ptr,
387 )
388 },
389 (ReduceKind::Std, ElementKind::F32) => unsafe {
390 baracuda_kernels_sys::baracuda_kernels_reduce_std_f32_run(
391 output_numel, rank, output_shape.as_ptr(),
392 stride_x.as_ptr(), stride_y.as_ptr(),
393 reduce_axis, reduce_extent, reduce_stride_x,
394 self.desc.correction,
395 x_ptr, y_ptr,
396 core::ptr::null_mut(), 0, stream_ptr,
397 )
398 },
399 (ReduceKind::Std, ElementKind::F16) => unsafe {
400 baracuda_kernels_sys::baracuda_kernels_reduce_std_f16_run(
401 output_numel, rank, output_shape.as_ptr(),
402 stride_x.as_ptr(), stride_y.as_ptr(),
403 reduce_axis, reduce_extent, reduce_stride_x,
404 self.desc.correction,
405 x_ptr, y_ptr,
406 core::ptr::null_mut(), 0, stream_ptr,
407 )
408 },
409 (ReduceKind::Std, ElementKind::Bf16) => unsafe {
410 baracuda_kernels_sys::baracuda_kernels_reduce_std_bf16_run(
411 output_numel, rank, output_shape.as_ptr(),
412 stride_x.as_ptr(), stride_y.as_ptr(),
413 reduce_axis, reduce_extent, reduce_stride_x,
414 self.desc.correction,
415 x_ptr, y_ptr,
416 core::ptr::null_mut(), 0, stream_ptr,
417 )
418 },
419 (ReduceKind::Std, ElementKind::F64) => unsafe {
420 baracuda_kernels_sys::baracuda_kernels_reduce_std_f64_run(
421 output_numel, rank, output_shape.as_ptr(),
422 stride_x.as_ptr(), stride_y.as_ptr(),
423 reduce_axis, reduce_extent, reduce_stride_x,
424 self.desc.correction,
425 x_ptr, y_ptr,
426 core::ptr::null_mut(), 0, stream_ptr,
427 )
428 },
429 _ => {
430 return Err(Error::Unsupported(
431 "baracuda-kernels::ReducePlan::run: this (kind, dtype) cell is not yet wired",
432 ));
433 }
434 };
435 map_status(status)
436 }
437}
438
439fn map_status(code: i32) -> Result<()> {
440 match code {
441 0 => Ok(()),
442 1 => Err(Error::MisalignedOperand),
443 2 => Err(Error::InvalidProblem(
444 "baracuda-kernels-sys reported invalid problem",
445 )),
446 3 => Err(Error::Unsupported(
447 "baracuda-kernels-sys reported unsupported configuration",
448 )),
449 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
450 n => Err(Error::CutlassInternal(n)),
451 }
452}