1use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_types::{
34 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
35 PlanPreference, PrecisionGuarantee, ScanKind, TensorMut, TensorRef, Workspace,
36};
37
38#[derive(Copy, Clone, Debug)]
40pub struct ScanDescriptor<const N: usize> {
41 pub kind: ScanKind,
43 pub input_shape: [i32; N],
45 pub scan_axis: u8,
47 pub reverse: bool,
50 pub element: ElementKind,
52}
53
54pub struct ScanArgs<'a, T: Element, const N: usize> {
56 pub x: TensorRef<'a, T, N>,
58 pub y: TensorMut<'a, T, N>,
60}
61
62pub struct ScanPlan<T: Element, const N: usize> {
68 desc: ScanDescriptor<N>,
69 sku: KernelSku,
70 _marker: PhantomData<T>,
71}
72
73impl<T: Element, const N: usize> ScanPlan<T, N> {
74 pub fn select(
78 _stream: &Stream,
79 desc: &ScanDescriptor<N>,
80 _pref: PlanPreference,
81 ) -> Result<Self> {
82 if desc.element != T::KIND {
83 return Err(Error::Unsupported(
84 "baracuda-kernels::ScanPlan: descriptor element != T",
85 ));
86 }
87 if (desc.scan_axis as usize) >= N {
88 return Err(Error::InvalidProblem(
89 "baracuda-kernels::ScanPlan: scan_axis out of range for rank N",
90 ));
91 }
92 for &d in desc.input_shape.iter() {
93 if d < 0 {
94 return Err(Error::InvalidProblem(
95 "baracuda-kernels::ScanPlan: shape dims must be non-negative",
96 ));
97 }
98 }
99 if N > 8 {
100 return Err(Error::Unsupported(
101 "baracuda-kernels::ScanPlan: tensor rank > 8 not supported \
102 (kernel param block fixes MAX_RANK = 8)",
103 ));
104 }
105
106 let dtype_in_fp_family = matches!(
109 T::KIND,
110 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
111 );
112 let kind_supported = matches!(
113 desc.kind,
114 ScanKind::Cumsum
115 | ScanKind::Cumprod
116 | ScanKind::Cummax
117 | ScanKind::Cummin
118 | ScanKind::LogCumsumExp
119 );
120 let supported = kind_supported && dtype_in_fp_family;
121 if !supported {
122 return Err(Error::Unsupported(
123 "baracuda-kernels::ScanPlan: wired today: \
124 `{Cumsum, Cumprod, Cummax, Cummin, LogCumsumExp} × {f32, f16, bf16, f64}`",
125 ));
126 }
127
128 let precision_guarantee = PrecisionGuarantee {
129 math_precision: MathPrecision::F32,
130 accumulator: ElementKind::F32,
131 bit_stable_on_same_hardware: true,
134 deterministic: true,
135 };
136 let sku = KernelSku {
137 category: OpCategory::Scan,
138 op: desc.kind as u16,
139 element: T::KIND,
140 aux_element: None,
141 layout: None,
142 epilogue: None,
143 arch: ArchSku::Sm80,
144 backend: BackendKind::Bespoke,
145 precision_guarantee,
146 };
147 Ok(Self {
148 desc: *desc,
149 sku,
150 _marker: PhantomData,
151 })
152 }
153
154 pub fn can_implement(&self, args: &ScanArgs<'_, T, N>) -> Result<()> {
156 if args.x.shape != self.desc.input_shape {
157 return Err(Error::InvalidProblem(
158 "baracuda-kernels::ScanPlan: x shape mismatch",
159 ));
160 }
161 if args.y.shape != self.desc.input_shape {
162 return Err(Error::InvalidProblem(
163 "baracuda-kernels::ScanPlan: y shape mismatch (scans are \
164 length-preserving — y.shape must equal x.shape)",
165 ));
166 }
167 let numel = args.x.numel();
168 let x_len = args.x.data.len() as i64;
169 let y_len = args.y.data.len() as i64;
170 if x_len < numel || y_len < numel {
171 return Err(Error::BufferTooSmall {
172 needed: numel as usize,
173 got: x_len.min(y_len) as usize,
174 });
175 }
176 Ok(())
177 }
178
179 #[inline]
181 pub fn workspace_size(&self) -> usize {
182 0
183 }
184 #[inline]
186 pub fn sku(&self) -> KernelSku {
187 self.sku
188 }
189 #[inline]
193 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
194 self.sku.precision_guarantee
195 }
196
197 pub fn run(
200 &self,
201 stream: &Stream,
202 _workspace: Workspace<'_>,
203 args: ScanArgs<'_, T, N>,
204 ) -> Result<()> {
205 self.can_implement(&args)?;
206 let numel = args.x.numel();
207 if numel == 0 {
208 return Ok(());
209 }
210 let x_ptr = args.x.data.as_raw().0 as *const c_void;
211 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
212 let stream_ptr = stream.as_raw() as *mut c_void;
213
214 let axis = self.desc.scan_axis as usize;
215 let shape = self.desc.input_shape;
216 let stride_x = args.x.stride;
217 let stride_y = args.y.stride;
218 let rank = N as i32;
219 let scan_extent = shape[axis];
220 let scan_stride_x = stride_x[axis];
221 let reverse = if self.desc.reverse { 1i32 } else { 0 };
222
223 macro_rules! dispatch {
224 ($sym:ident) => {
225 unsafe {
226 baracuda_kernels_sys::$sym(
227 numel,
228 rank,
229 shape.as_ptr(),
230 stride_x.as_ptr(),
231 stride_y.as_ptr(),
232 axis as i32,
233 scan_extent,
234 scan_stride_x,
235 reverse,
236 x_ptr,
237 y_ptr,
238 core::ptr::null_mut(),
239 0,
240 stream_ptr,
241 )
242 }
243 };
244 }
245
246 let status = match (self.desc.kind, T::KIND) {
247 (ScanKind::Cumsum, ElementKind::F32) => dispatch!(baracuda_kernels_scan_cumsum_f32_run),
248 (ScanKind::Cumsum, ElementKind::F16) => dispatch!(baracuda_kernels_scan_cumsum_f16_run),
249 (ScanKind::Cumsum, ElementKind::Bf16) => {
250 dispatch!(baracuda_kernels_scan_cumsum_bf16_run)
251 }
252 (ScanKind::Cumsum, ElementKind::F64) => dispatch!(baracuda_kernels_scan_cumsum_f64_run),
253 (ScanKind::Cumprod, ElementKind::F32) => {
254 dispatch!(baracuda_kernels_scan_cumprod_f32_run)
255 }
256 (ScanKind::Cumprod, ElementKind::F16) => {
257 dispatch!(baracuda_kernels_scan_cumprod_f16_run)
258 }
259 (ScanKind::Cumprod, ElementKind::Bf16) => {
260 dispatch!(baracuda_kernels_scan_cumprod_bf16_run)
261 }
262 (ScanKind::Cumprod, ElementKind::F64) => {
263 dispatch!(baracuda_kernels_scan_cumprod_f64_run)
264 }
265 (ScanKind::Cummax, ElementKind::F32) => dispatch!(baracuda_kernels_scan_cummax_f32_run),
266 (ScanKind::Cummax, ElementKind::F16) => dispatch!(baracuda_kernels_scan_cummax_f16_run),
267 (ScanKind::Cummax, ElementKind::Bf16) => {
268 dispatch!(baracuda_kernels_scan_cummax_bf16_run)
269 }
270 (ScanKind::Cummax, ElementKind::F64) => dispatch!(baracuda_kernels_scan_cummax_f64_run),
271 (ScanKind::Cummin, ElementKind::F32) => dispatch!(baracuda_kernels_scan_cummin_f32_run),
272 (ScanKind::Cummin, ElementKind::F16) => dispatch!(baracuda_kernels_scan_cummin_f16_run),
273 (ScanKind::Cummin, ElementKind::Bf16) => {
274 dispatch!(baracuda_kernels_scan_cummin_bf16_run)
275 }
276 (ScanKind::Cummin, ElementKind::F64) => dispatch!(baracuda_kernels_scan_cummin_f64_run),
277 (ScanKind::LogCumsumExp, ElementKind::F32) => {
278 dispatch!(baracuda_kernels_scan_log_cumsum_exp_f32_run)
279 }
280 (ScanKind::LogCumsumExp, ElementKind::F16) => {
281 dispatch!(baracuda_kernels_scan_log_cumsum_exp_f16_run)
282 }
283 (ScanKind::LogCumsumExp, ElementKind::Bf16) => {
284 dispatch!(baracuda_kernels_scan_log_cumsum_exp_bf16_run)
285 }
286 (ScanKind::LogCumsumExp, ElementKind::F64) => {
287 dispatch!(baracuda_kernels_scan_log_cumsum_exp_f64_run)
288 }
289 _ => {
290 return Err(Error::Unsupported(
291 "baracuda-kernels::ScanPlan::run reached an unimplemented \
292 (kind, dtype) pair — select() should have caught this",
293 ));
294 }
295 };
296 map_status(status)
297 }
298}
299
300fn map_status(code: i32) -> Result<()> {
301 match code {
302 0 => Ok(()),
303 1 => Err(Error::MisalignedOperand),
304 2 => Err(Error::InvalidProblem(
305 "baracuda-kernels-sys reported invalid problem",
306 )),
307 3 => Err(Error::Unsupported(
308 "baracuda-kernels-sys reported unsupported configuration",
309 )),
310 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
311 n => Err(Error::CutlassInternal(n)),
312 }
313}