1use core::ffi::c_void;
15use core::marker::PhantomData;
16
17use baracuda_cutlass::{Error, Result};
18use baracuda_driver::Stream;
19use baracuda_kernels_types::{
20 ArchSku, BackendKind, Bool, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
21 PlanPreference, PrecisionGuarantee, ReduceKind, TensorMut, TensorRef, Workspace,
22};
23
24#[derive(Copy, Clone, Debug)]
26pub struct BoolReduceDescriptor<const N: usize> {
27 pub kind: ReduceKind,
29 pub input_shape: [i32; N],
31 pub reduce_axis: u8,
33 pub element: ElementKind,
35}
36
37impl<const N: usize> BoolReduceDescriptor<N> {
38 pub fn output_shape(&self) -> [i32; N] {
40 let mut out = self.input_shape;
41 out[self.reduce_axis as usize] = 1;
42 out
43 }
44}
45
46pub struct BoolReduceArgs<'a, T: Element, const N: usize> {
48 pub x: TensorRef<'a, T, N>,
50 pub y: TensorMut<'a, Bool, N>,
52}
53
54pub struct BoolReducePlan<T: Element, const N: usize> {
56 desc: BoolReduceDescriptor<N>,
57 sku: KernelSku,
58 _marker: PhantomData<T>,
59}
60
61impl<T: Element, const N: usize> BoolReducePlan<T, N> {
62 pub fn select(
64 _stream: &Stream,
65 desc: &BoolReduceDescriptor<N>,
66 _pref: PlanPreference,
67 ) -> Result<Self> {
68 if desc.element != T::KIND {
69 return Err(Error::Unsupported(
70 "baracuda-kernels::BoolReducePlan: descriptor element != type parameter T",
71 ));
72 }
73 if (desc.reduce_axis as usize) >= N {
74 return Err(Error::InvalidProblem(
75 "baracuda-kernels::BoolReducePlan: reduce_axis must be < rank",
76 ));
77 }
78 for &d in desc.input_shape.iter() {
79 if d < 0 {
80 return Err(Error::InvalidProblem(
81 "baracuda-kernels::BoolReducePlan: input_shape dims must be non-negative",
82 ));
83 }
84 }
85 let kind_in_scope = matches!(desc.kind, ReduceKind::Any | ReduceKind::All);
86 if !kind_in_scope {
87 return Err(Error::Unsupported(
88 "baracuda-kernels::BoolReducePlan: kind must be Any or All",
89 ));
90 }
91 let dtype_in_scope = matches!(
92 T::KIND,
93 ElementKind::F32
94 | ElementKind::F16
95 | ElementKind::Bf16
96 | ElementKind::F64
97 | ElementKind::I32
98 | ElementKind::I64
99 | ElementKind::Bool
100 );
101 if !dtype_in_scope {
102 return Err(Error::Unsupported(
103 "baracuda-kernels::BoolReducePlan: supported input dtypes are \
104 {f32, f16, bf16, f64, i32, i64, Bool}",
105 ));
106 }
107 let precision_guarantee = PrecisionGuarantee {
110 math_precision: MathPrecision::F32,
111 accumulator: ElementKind::Bool,
112 bit_stable_on_same_hardware: true,
113 deterministic: true,
114 };
115 let sku = KernelSku {
116 category: OpCategory::Reduction,
117 op: desc.kind as u16,
118 element: T::KIND,
119 aux_element: Some(ElementKind::Bool),
123 layout: None,
124 epilogue: None,
125 arch: ArchSku::Sm80,
126 backend: BackendKind::Bespoke,
127 precision_guarantee,
128 };
129 Ok(Self {
130 desc: *desc,
131 sku,
132 _marker: PhantomData,
133 })
134 }
135
136 pub fn can_implement(&self, args: &BoolReduceArgs<'_, T, N>) -> Result<()> {
138 if args.x.shape != self.desc.input_shape {
139 return Err(Error::InvalidProblem(
140 "baracuda-kernels::BoolReducePlan: X shape mismatch with descriptor",
141 ));
142 }
143 let expected_out = self.desc.output_shape();
144 if args.y.shape != expected_out {
145 return Err(Error::InvalidProblem(
146 "baracuda-kernels::BoolReducePlan: Y shape mismatch with derived output \
147 shape (input shape with reduce_axis collapsed to 1)",
148 ));
149 }
150 if N > 8 {
151 return Err(Error::Unsupported(
152 "baracuda-kernels::BoolReducePlan: tensor rank > 8 not supported",
153 ));
154 }
155 let y_numel = args.y.numel();
156 let x_numel = args.x.numel();
157 let x_len = args.x.data.len() as i64;
158 let y_len = args.y.data.len() as i64;
159 if y_len < y_numel {
160 return Err(Error::BufferTooSmall {
161 needed: y_numel as usize,
162 got: y_len as usize,
163 });
164 }
165 if x_len < x_numel {
166 return Err(Error::BufferTooSmall {
167 needed: x_numel as usize,
168 got: x_len as usize,
169 });
170 }
171 Ok(())
172 }
173
174 #[inline]
176 pub fn workspace_size(&self) -> usize {
177 0
178 }
179 #[inline]
181 pub fn sku(&self) -> KernelSku {
182 self.sku
183 }
184 #[inline]
186 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
187 self.sku.precision_guarantee
188 }
189
190 pub fn run(
192 &self,
193 stream: &Stream,
194 _workspace: Workspace<'_>,
195 args: BoolReduceArgs<'_, T, N>,
196 ) -> Result<()> {
197 self.can_implement(&args)?;
198 let output_numel = args.y.numel();
199 if output_numel == 0 {
200 return Ok(());
201 }
202 let x_ptr = args.x.data.as_raw().0 as *const c_void;
203 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
204 let stream_ptr = stream.as_raw() as *mut c_void;
205
206 let output_shape = self.desc.output_shape();
207 let stride_x = args.x.stride;
208 let stride_y = args.y.stride;
209 let rank = N as i32;
210 let reduce_axis = self.desc.reduce_axis as i32;
211 let reduce_extent = self.desc.input_shape[self.desc.reduce_axis as usize];
212 let reduce_stride_x = args.x.stride[self.desc.reduce_axis as usize];
213
214 macro_rules! dispatch {
215 ($sym:ident) => {{
216 unsafe {
217 baracuda_kernels_sys::$sym(
218 output_numel,
219 rank,
220 output_shape.as_ptr(),
221 stride_x.as_ptr(),
222 stride_y.as_ptr(),
223 reduce_axis,
224 reduce_extent,
225 reduce_stride_x,
226 x_ptr,
227 y_ptr,
228 core::ptr::null_mut(),
229 0,
230 stream_ptr,
231 )
232 }
233 }};
234 }
235
236 let status = match (self.desc.kind, T::KIND) {
237 (ReduceKind::Any, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_any_f32_run),
239 (ReduceKind::Any, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_any_f16_run),
240 (ReduceKind::Any, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_any_bf16_run),
241 (ReduceKind::Any, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_any_f64_run),
242 (ReduceKind::Any, ElementKind::I32) => dispatch!(baracuda_kernels_reduce_any_i32_run),
243 (ReduceKind::Any, ElementKind::I64) => dispatch!(baracuda_kernels_reduce_any_i64_run),
244 (ReduceKind::Any, ElementKind::Bool) => dispatch!(baracuda_kernels_reduce_any_bool_run),
245 (ReduceKind::All, ElementKind::F32) => dispatch!(baracuda_kernels_reduce_all_f32_run),
247 (ReduceKind::All, ElementKind::F16) => dispatch!(baracuda_kernels_reduce_all_f16_run),
248 (ReduceKind::All, ElementKind::Bf16) => dispatch!(baracuda_kernels_reduce_all_bf16_run),
249 (ReduceKind::All, ElementKind::F64) => dispatch!(baracuda_kernels_reduce_all_f64_run),
250 (ReduceKind::All, ElementKind::I32) => dispatch!(baracuda_kernels_reduce_all_i32_run),
251 (ReduceKind::All, ElementKind::I64) => dispatch!(baracuda_kernels_reduce_all_i64_run),
252 (ReduceKind::All, ElementKind::Bool) => dispatch!(baracuda_kernels_reduce_all_bool_run),
253 _ => {
254 return Err(Error::Unsupported(
255 "baracuda-kernels::BoolReducePlan::run: only `{Any, All} × \
256 {f32, f16, bf16, f64, i32, i64, Bool}` wired",
257 ));
258 }
259 };
260 map_status(status)
261 }
262}
263
264fn map_status(code: i32) -> Result<()> {
265 match code {
266 0 => Ok(()),
267 1 => Err(Error::MisalignedOperand),
268 2 => Err(Error::InvalidProblem(
269 "baracuda-kernels-sys reported invalid problem",
270 )),
271 3 => Err(Error::Unsupported(
272 "baracuda-kernels-sys reported unsupported configuration",
273 )),
274 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
275 n => Err(Error::CutlassInternal(n)),
276 }
277}