1use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory, PadMode,
22 PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
23};
24use half::{bf16, f16};
25
/// Describes a rank-`N` padding problem: the mode, the input extents and the per-axis pads.
///
/// A descriptor is validated by `PadPlan::select` and then stored (by copy) inside the plan.
#[derive(Copy, Clone, Debug)]
pub struct PadDescriptor<const N: usize> {
    /// Padding mode. `PadPlan::select` accepts `Constant`, `Reflect`, `Replicate` and
    /// `Circular`.
    pub mode: PadMode,
    /// Extent of the input tensor along each axis. `PadPlan::select` rejects negative entries.
    pub input_shape: [i32; N],
    /// Per-axis pad amount that is added to the output extent and forwarded to the kernel
    /// launcher. `PadPlan::select` rejects negative entries.
    /// NOTE(review): presumably the number of elements inserted before the first input
    /// element of each axis — confirm against the kernel.
    pub pad_low: [i32; N],
    /// Per-axis pad amount that is added to the output extent. It is not passed to the kernel
    /// launcher directly; it only reaches it through the derived output shape.
    /// `PadPlan::select` rejects negative entries.
    pub pad_high: [i32; N],
    /// Fill value. `PadPlan::run` forwards it only for `PadMode::Constant`, converted to the
    /// element type (f16 / bf16 bit pattern, or widened to f64) where that is not `f32`.
    pub value: f32,
    /// Element kind this descriptor targets. Must equal `T::KIND` of the `PadPlan<T, N>` built
    /// from it, otherwise `PadPlan::select` returns `Error::Unsupported`.
    pub element: ElementKind,
}
50
51impl<const N: usize> PadDescriptor<N> {
52 pub fn output_shape(&self) -> [i32; N] {
54 let mut out = [0i32; N];
55 for d in 0..N {
56 out[d] = self.input_shape[d] + self.pad_low[d] + self.pad_high[d];
57 }
58 out
59 }
60}
61
/// Tensor arguments for one `PadPlan::run` call.
pub struct PadArgs<'a, T: Element, const N: usize> {
    /// Input tensor. `PadPlan::can_implement` requires its shape to equal the descriptor's
    /// `input_shape`.
    pub x: TensorRef<'a, T, N>,
    /// Output tensor. `PadPlan::can_implement` requires its shape to equal
    /// `PadDescriptor::output_shape()`.
    pub y: TensorMut<'a, T, N>,
}
74
/// A validated padding plan for element type `T` and tensor rank `N`.
///
/// Built by `PadPlan::select`; executed with `PadPlan::run`.
pub struct PadPlan<T: Element, const N: usize> {
    /// Copy of the descriptor the plan was selected for.
    desc: PadDescriptor<N>,
    /// Kernel identity record assembled in `select`.
    sku: KernelSku,
    /// Ties the plan to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
104
105impl<T: Element, const N: usize> PadPlan<T, N> {
106 pub fn select(
108 _stream: &Stream,
109 desc: &PadDescriptor<N>,
110 _pref: PlanPreference,
111 ) -> Result<Self> {
112 if desc.element != T::KIND {
113 return Err(Error::Unsupported(
114 "baracuda-kernels::PadPlan: descriptor element != type parameter T",
115 ));
116 }
117 for d in 0..N {
118 if desc.input_shape[d] < 0 || desc.pad_low[d] < 0 || desc.pad_high[d] < 0 {
119 return Err(Error::InvalidProblem(
120 "baracuda-kernels::PadPlan: input_shape / pad_low / pad_high \
121 must be non-negative",
122 ));
123 }
124 }
125
126 let dtype_in_scope = matches!(
131 T::KIND,
132 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
133 );
134 let mode_in_scope = matches!(
135 desc.mode,
136 PadMode::Constant | PadMode::Reflect | PadMode::Replicate | PadMode::Circular
137 );
138 if !(dtype_in_scope && mode_in_scope) {
139 return Err(Error::Unsupported(
140 "baracuda-kernels::PadPlan: supported matrix is \
141 {Constant, Reflect, Replicate, Circular} × {f32, f16, bf16, f64}",
142 ));
143 }
144
145 let precision_guarantee = PrecisionGuarantee {
146 math_precision: MathPrecision::F32,
147 accumulator: ElementKind::F32,
148 bit_stable_on_same_hardware: true,
150 deterministic: true,
151 };
152 let sku = KernelSku {
153 category: OpCategory::ShapeLayout,
154 op: ShapeLayoutKind::Pad as u16,
155 element: T::KIND,
156 aux_element: None,
157 layout: None,
158 epilogue: None,
159 arch: ArchSku::Sm80,
160 backend: BackendKind::Bespoke,
161 precision_guarantee,
162 };
163 Ok(Self {
164 desc: *desc,
165 sku,
166 _marker: PhantomData,
167 })
168 }
169
170 pub fn can_implement(&self, args: &PadArgs<'_, T, N>) -> Result<()> {
172 if args.x.shape != self.desc.input_shape {
173 return Err(Error::InvalidProblem(
174 "baracuda-kernels::PadPlan: X shape mismatch with descriptor input_shape",
175 ));
176 }
177 let expected_out = self.desc.output_shape();
178 if args.y.shape != expected_out {
179 return Err(Error::InvalidProblem(
180 "baracuda-kernels::PadPlan: Y shape mismatch with derived output shape \
181 (= input_shape + pad_low + pad_high per axis)",
182 ));
183 }
184 if N > 8 {
185 return Err(Error::Unsupported(
186 "baracuda-kernels::PadPlan: tensor rank > 8 not supported",
187 ));
188 }
189 let y_numel = args.y.numel();
190 let x_numel = args.x.numel();
191 let x_len = args.x.data.len() as i64;
192 let y_len = args.y.data.len() as i64;
193 if y_len < y_numel {
194 return Err(Error::BufferTooSmall {
195 needed: y_numel as usize,
196 got: y_len as usize,
197 });
198 }
199 if x_len < x_numel {
200 return Err(Error::BufferTooSmall {
201 needed: x_numel as usize,
202 got: x_len as usize,
203 });
204 }
205 Ok(())
206 }
207
208 #[inline]
210 pub fn workspace_size(&self) -> usize {
211 0
212 }
213
214 #[inline]
216 pub fn sku(&self) -> KernelSku {
217 self.sku
218 }
219
220 #[inline]
222 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
223 self.sku.precision_guarantee
224 }
225
226 pub fn run(
228 &self,
229 stream: &Stream,
230 _workspace: Workspace<'_>,
231 args: PadArgs<'_, T, N>,
232 ) -> Result<()> {
233 self.can_implement(&args)?;
234 let output_numel = args.y.numel();
235 if output_numel == 0 {
236 return Ok(());
237 }
238 let x_ptr = args.x.data.as_raw().0 as *const c_void;
239 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
240 let stream_ptr = stream.as_raw() as *mut c_void;
241
242 let input_shape = self.desc.input_shape;
243 let output_shape = self.desc.output_shape();
244 let pad_low = self.desc.pad_low;
245 let stride_x = args.x.stride;
246 let stride_y = args.y.stride;
247 let rank = N as i32;
248
249 macro_rules! dispatch_mode {
251 ($sym:ident) => {{
252 unsafe {
253 baracuda_kernels_sys::$sym(
254 output_numel,
255 rank,
256 input_shape.as_ptr(),
257 output_shape.as_ptr(),
258 pad_low.as_ptr(),
259 stride_x.as_ptr(),
260 stride_y.as_ptr(),
261 x_ptr,
262 y_ptr,
263 core::ptr::null_mut(),
264 0,
265 stream_ptr,
266 )
267 }
268 }};
269 }
270
271 let status = match (self.desc.mode, T::KIND) {
272 (PadMode::Constant, ElementKind::F32) => unsafe {
273 baracuda_kernels_sys::baracuda_kernels_pad_constant_f32_run(
274 output_numel,
275 rank,
276 input_shape.as_ptr(),
277 output_shape.as_ptr(),
278 pad_low.as_ptr(),
279 stride_x.as_ptr(),
280 stride_y.as_ptr(),
281 x_ptr,
282 y_ptr,
283 self.desc.value,
284 core::ptr::null_mut(),
285 0,
286 stream_ptr,
287 )
288 },
289 (PadMode::Constant, ElementKind::F16) => unsafe {
290 let value_bits = f16::from_f32(self.desc.value).to_bits();
295 baracuda_kernels_sys::baracuda_kernels_pad_constant_f16_run(
296 output_numel,
297 rank,
298 input_shape.as_ptr(),
299 output_shape.as_ptr(),
300 pad_low.as_ptr(),
301 stride_x.as_ptr(),
302 stride_y.as_ptr(),
303 x_ptr,
304 y_ptr,
305 value_bits,
306 core::ptr::null_mut(),
307 0,
308 stream_ptr,
309 )
310 },
311 (PadMode::Constant, ElementKind::Bf16) => unsafe {
312 let value_bits = bf16::from_f32(self.desc.value).to_bits();
313 baracuda_kernels_sys::baracuda_kernels_pad_constant_bf16_run(
314 output_numel,
315 rank,
316 input_shape.as_ptr(),
317 output_shape.as_ptr(),
318 pad_low.as_ptr(),
319 stride_x.as_ptr(),
320 stride_y.as_ptr(),
321 x_ptr,
322 y_ptr,
323 value_bits,
324 core::ptr::null_mut(),
325 0,
326 stream_ptr,
327 )
328 },
329 (PadMode::Constant, ElementKind::F64) => unsafe {
330 baracuda_kernels_sys::baracuda_kernels_pad_constant_f64_run(
331 output_numel,
332 rank,
333 input_shape.as_ptr(),
334 output_shape.as_ptr(),
335 pad_low.as_ptr(),
336 stride_x.as_ptr(),
337 stride_y.as_ptr(),
338 x_ptr,
339 y_ptr,
340 self.desc.value as f64,
341 core::ptr::null_mut(),
342 0,
343 stream_ptr,
344 )
345 },
346 (PadMode::Reflect, ElementKind::F32) => {
348 dispatch_mode!(baracuda_kernels_pad_reflect_f32_run)
349 }
350 (PadMode::Reflect, ElementKind::F16) => {
351 dispatch_mode!(baracuda_kernels_pad_reflect_f16_run)
352 }
353 (PadMode::Reflect, ElementKind::Bf16) => {
354 dispatch_mode!(baracuda_kernels_pad_reflect_bf16_run)
355 }
356 (PadMode::Reflect, ElementKind::F64) => {
357 dispatch_mode!(baracuda_kernels_pad_reflect_f64_run)
358 }
359 (PadMode::Replicate, ElementKind::F32) => {
361 dispatch_mode!(baracuda_kernels_pad_replicate_f32_run)
362 }
363 (PadMode::Replicate, ElementKind::F16) => {
364 dispatch_mode!(baracuda_kernels_pad_replicate_f16_run)
365 }
366 (PadMode::Replicate, ElementKind::Bf16) => {
367 dispatch_mode!(baracuda_kernels_pad_replicate_bf16_run)
368 }
369 (PadMode::Replicate, ElementKind::F64) => {
370 dispatch_mode!(baracuda_kernels_pad_replicate_f64_run)
371 }
372 (PadMode::Circular, ElementKind::F32) => {
374 dispatch_mode!(baracuda_kernels_pad_circular_f32_run)
375 }
376 (PadMode::Circular, ElementKind::F16) => {
377 dispatch_mode!(baracuda_kernels_pad_circular_f16_run)
378 }
379 (PadMode::Circular, ElementKind::Bf16) => {
380 dispatch_mode!(baracuda_kernels_pad_circular_bf16_run)
381 }
382 (PadMode::Circular, ElementKind::F64) => {
383 dispatch_mode!(baracuda_kernels_pad_circular_f64_run)
384 }
385 _ => {
386 return Err(Error::Unsupported(
387 "baracuda-kernels::PadPlan::run: this (mode, dtype) cell is not wired",
388 ));
389 }
390 };
391 map_status(status)
392 }
393}
394
395fn map_status(code: i32) -> Result<()> {
396 match code {
397 0 => Ok(()),
398 1 => Err(Error::MisalignedOperand),
399 2 => Err(Error::InvalidProblem(
400 "baracuda-kernels-sys reported invalid problem",
401 )),
402 3 => Err(Error::Unsupported(
403 "baracuda-kernels-sys reported unsupported configuration",
404 )),
405 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
406 n => Err(Error::CutlassInternal(n)),
407 }
408}