Skip to main content

baracuda_kernels/shape_layout/
pad.rs

1//! `pad` plan — Category N entry point.
2//!
3//! Output shape per-axis is `input_shape[d] + pad_low[d] + pad_high[d]`
4//! (the FIRST Phase 3 plan where output shape differs from input). The
5//! kernel iterates output cells, computes input coord per axis via
6//! subtraction of `pad_low`, and either copies the input cell or writes
7//! the configured pad value.
8//!
9//! All four [`PadMode`]s ({Constant, Reflect, Replicate, Circular})
10//! are wired for `{f32, f16, bf16, f64}` — 16 (mode, dtype) cells.
11//! The descriptor's `value` field is consumed only by `Constant` mode;
12//! the other modes derive pad-region values from the input itself
13//! (mirror, clamp, or cyclic wrap respectively).
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory, PadMode,
22    PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
23};
24use half::{bf16, f16};
25
26/// Descriptor for a constant-pad op.
27///
28/// `input_shape` is the shape of the input tensor. `pad_low[d]` and
29/// `pad_high[d]` are the pad amounts on each side of axis `d`. Output
30/// shape is `input_shape[d] + pad_low[d] + pad_high[d]` per axis.
31/// `value` is the constant used in the pad region (for `Constant`
32/// mode). `element` must match `T::KIND` at `select` time.
33#[derive(Copy, Clone, Debug)]
34pub struct PadDescriptor<const N: usize> {
35    /// Padding mode — one of [`PadMode::Constant`] / [`PadMode::Reflect`]
36    /// / [`PadMode::Replicate`] / [`PadMode::Circular`]. All four are
37    /// wired for every supported dtype.
38    pub mode: PadMode,
39    /// Input tensor shape.
40    pub input_shape: [i32; N],
41    /// Pad amount on the low side of each axis. Non-negative.
42    pub pad_low: [i32; N],
43    /// Pad amount on the high side of each axis. Non-negative.
44    pub pad_high: [i32; N],
45    /// Constant value used in the pad region for `Constant` mode.
46    pub value: f32,
47    /// Element type of input and output.
48    pub element: ElementKind,
49}
50
51impl<const N: usize> PadDescriptor<N> {
52    /// Compute the output shape from input shape + pad amounts.
53    pub fn output_shape(&self) -> [i32; N] {
54        let mut out = [0i32; N];
55        for d in 0..N {
56            out[d] = self.input_shape[d] + self.pad_low[d] + self.pad_high[d];
57        }
58        out
59    }
60}
61
62/// Args bundle for a Pad launch.
63///
64/// `x.shape` must match `desc.input_shape`. `y.shape` must match the
65/// output shape derived from descriptor (`input_shape + pad_low +
66/// pad_high` per axis). Both can be strided views — the kernel walks
67/// per-axis strides.
68pub struct PadArgs<'a, T: Element, const N: usize> {
69    /// Input tensor.
70    pub x: TensorRef<'a, T, N>,
71    /// Output tensor — larger than input by the configured pad amounts.
72    pub y: TensorMut<'a, T, N>,
73}
74
75/// `pad` plan.
76///
77/// `y = F.pad(x, pad_low, pad_high, mode, value)` — per-axis low /
78/// high padding (PyTorch `torch.nn.functional.pad`).
79///
80/// **When to use**: forward pad. Pair with
81/// [`PadBackwardPlan`](crate::PadBackwardPlan) for autograd
82/// (slice-back of `Constant` mode; the other modes have
83/// scatter-add BWs not yet wired — see below).
84///
85/// **Dtypes**: `{f32, f64, f16, bf16}` — 16 (mode, dtype) cells.
86///
87/// **Modes**: all four [`PadMode`] variants — `Constant`, `Reflect`,
88/// `Replicate`, `Circular`. `value` is consumed only by `Constant`;
89/// the others derive pad-region values from the input.
90///
91/// **Shape limits**: rank in `[1, 8]`; `pad_low[d]`, `pad_high[d]`
92/// non-negative; output shape per axis is
93/// `input_shape[d] + pad_low[d] + pad_high[d]`.
94///
95/// **Workspace**: none.
96///
97/// **Precision guarantee**: deterministic, bit-stable, bit-exact (no
98/// arithmetic — pure index + copy / value-write).
99pub struct PadPlan<T: Element, const N: usize> {
100    desc: PadDescriptor<N>,
101    sku: KernelSku,
102    _marker: PhantomData<T>,
103}
104
105impl<T: Element, const N: usize> PadPlan<T, N> {
106    /// Pick a kernel for `desc`.
107    pub fn select(
108        _stream: &Stream,
109        desc: &PadDescriptor<N>,
110        _pref: PlanPreference,
111    ) -> Result<Self> {
112        if desc.element != T::KIND {
113            return Err(Error::Unsupported(
114                "baracuda-kernels::PadPlan: descriptor element != type parameter T",
115            ));
116        }
117        for d in 0..N {
118            if desc.input_shape[d] < 0 || desc.pad_low[d] < 0 || desc.pad_high[d] < 0 {
119                return Err(Error::InvalidProblem(
120                    "baracuda-kernels::PadPlan: input_shape / pad_low / pad_high \
121                     must be non-negative",
122                ));
123            }
124        }
125
126        // Full Pad matrix today: 4 modes × {f32, f16, bf16, f64}.
127        // Reflect / Replicate / Circular do not consume the
128        // descriptor's `value` field — pad-region values are derived
129        // from the input itself.
130        let dtype_in_scope = matches!(
131            T::KIND,
132            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
133        );
134        let mode_in_scope = matches!(
135            desc.mode,
136            PadMode::Constant | PadMode::Reflect | PadMode::Replicate | PadMode::Circular
137        );
138        if !(dtype_in_scope && mode_in_scope) {
139            return Err(Error::Unsupported(
140                "baracuda-kernels::PadPlan: supported matrix is \
141                 {Constant, Reflect, Replicate, Circular} × {f32, f16, bf16, f64}",
142            ));
143        }
144
145        let precision_guarantee = PrecisionGuarantee {
146            math_precision: MathPrecision::F32,
147            accumulator: ElementKind::F32,
148            // Pad does no arithmetic — pure copy + constant fill.
149            bit_stable_on_same_hardware: true,
150            deterministic: true,
151        };
152        let sku = KernelSku {
153            category: OpCategory::ShapeLayout,
154            op: ShapeLayoutKind::Pad as u16,
155            element: T::KIND,
156            aux_element: None,
157            layout: None,
158            epilogue: None,
159            arch: ArchSku::Sm80,
160            backend: BackendKind::Bespoke,
161            precision_guarantee,
162        };
163        Ok(Self {
164            desc: *desc,
165            sku,
166            _marker: PhantomData,
167        })
168    }
169
170    /// Validate args.
171    pub fn can_implement(&self, args: &PadArgs<'_, T, N>) -> Result<()> {
172        if args.x.shape != self.desc.input_shape {
173            return Err(Error::InvalidProblem(
174                "baracuda-kernels::PadPlan: X shape mismatch with descriptor input_shape",
175            ));
176        }
177        let expected_out = self.desc.output_shape();
178        if args.y.shape != expected_out {
179            return Err(Error::InvalidProblem(
180                "baracuda-kernels::PadPlan: Y shape mismatch with derived output shape \
181                 (= input_shape + pad_low + pad_high per axis)",
182            ));
183        }
184        if N > 8 {
185            return Err(Error::Unsupported(
186                "baracuda-kernels::PadPlan: tensor rank > 8 not supported",
187            ));
188        }
189        let y_numel = args.y.numel();
190        let x_numel = args.x.numel();
191        let x_len = args.x.data.len() as i64;
192        let y_len = args.y.data.len() as i64;
193        if y_len < y_numel {
194            return Err(Error::BufferTooSmall {
195                needed: y_numel as usize,
196                got: y_len as usize,
197            });
198        }
199        if x_len < x_numel {
200            return Err(Error::BufferTooSmall {
201                needed: x_numel as usize,
202                got: x_len as usize,
203            });
204        }
205        Ok(())
206    }
207
208    /// Workspace size in bytes. Always `0` for the trailblazer.
209    #[inline]
210    pub fn workspace_size(&self) -> usize {
211        0
212    }
213
214    /// Identity of the kernel this plan picked.
215    #[inline]
216    pub fn sku(&self) -> KernelSku {
217        self.sku
218    }
219
220    /// Numerical guarantees for this plan's kernel.
221    #[inline]
222    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
223        self.sku.precision_guarantee
224    }
225
226    /// Launch.
227    pub fn run(
228        &self,
229        stream: &Stream,
230        _workspace: Workspace<'_>,
231        args: PadArgs<'_, T, N>,
232    ) -> Result<()> {
233        self.can_implement(&args)?;
234        let output_numel = args.y.numel();
235        if output_numel == 0 {
236            return Ok(());
237        }
238        let x_ptr = args.x.data.as_raw().0 as *const c_void;
239        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
240        let stream_ptr = stream.as_raw() as *mut c_void;
241
242        let input_shape = self.desc.input_shape;
243        let output_shape = self.desc.output_shape();
244        let pad_low = self.desc.pad_low;
245        let stride_x = args.x.stride;
246        let stride_y = args.y.stride;
247        let rank = N as i32;
248
249        // Non-constant pad modes share a parameter shape (no `value`).
250        macro_rules! dispatch_mode {
251            ($sym:ident) => {{
252                unsafe {
253                    baracuda_kernels_sys::$sym(
254                        output_numel,
255                        rank,
256                        input_shape.as_ptr(),
257                        output_shape.as_ptr(),
258                        pad_low.as_ptr(),
259                        stride_x.as_ptr(),
260                        stride_y.as_ptr(),
261                        x_ptr,
262                        y_ptr,
263                        core::ptr::null_mut(),
264                        0,
265                        stream_ptr,
266                    )
267                }
268            }};
269        }
270
271        let status = match (self.desc.mode, T::KIND) {
272            (PadMode::Constant, ElementKind::F32) => unsafe {
273                baracuda_kernels_sys::baracuda_kernels_pad_constant_f32_run(
274                    output_numel,
275                    rank,
276                    input_shape.as_ptr(),
277                    output_shape.as_ptr(),
278                    pad_low.as_ptr(),
279                    stride_x.as_ptr(),
280                    stride_y.as_ptr(),
281                    x_ptr,
282                    y_ptr,
283                    self.desc.value,
284                    core::ptr::null_mut(),
285                    0,
286                    stream_ptr,
287                )
288            },
289            (PadMode::Constant, ElementKind::F16) => unsafe {
290                // Convert the descriptor's f32 value to f16, then pass
291                // the 16-bit pattern by value — ABI-compatible with the
292                // C side's `__half value` parameter on Windows x64
293                // (small POD struct → register).
294                let value_bits = f16::from_f32(self.desc.value).to_bits();
295                baracuda_kernels_sys::baracuda_kernels_pad_constant_f16_run(
296                    output_numel,
297                    rank,
298                    input_shape.as_ptr(),
299                    output_shape.as_ptr(),
300                    pad_low.as_ptr(),
301                    stride_x.as_ptr(),
302                    stride_y.as_ptr(),
303                    x_ptr,
304                    y_ptr,
305                    value_bits,
306                    core::ptr::null_mut(),
307                    0,
308                    stream_ptr,
309                )
310            },
311            (PadMode::Constant, ElementKind::Bf16) => unsafe {
312                let value_bits = bf16::from_f32(self.desc.value).to_bits();
313                baracuda_kernels_sys::baracuda_kernels_pad_constant_bf16_run(
314                    output_numel,
315                    rank,
316                    input_shape.as_ptr(),
317                    output_shape.as_ptr(),
318                    pad_low.as_ptr(),
319                    stride_x.as_ptr(),
320                    stride_y.as_ptr(),
321                    x_ptr,
322                    y_ptr,
323                    value_bits,
324                    core::ptr::null_mut(),
325                    0,
326                    stream_ptr,
327                )
328            },
329            (PadMode::Constant, ElementKind::F64) => unsafe {
330                baracuda_kernels_sys::baracuda_kernels_pad_constant_f64_run(
331                    output_numel,
332                    rank,
333                    input_shape.as_ptr(),
334                    output_shape.as_ptr(),
335                    pad_low.as_ptr(),
336                    stride_x.as_ptr(),
337                    stride_y.as_ptr(),
338                    x_ptr,
339                    y_ptr,
340                    self.desc.value as f64,
341                    core::ptr::null_mut(),
342                    0,
343                    stream_ptr,
344                )
345            },
346            // Reflect — mirror across boundary.
347            (PadMode::Reflect, ElementKind::F32) => {
348                dispatch_mode!(baracuda_kernels_pad_reflect_f32_run)
349            }
350            (PadMode::Reflect, ElementKind::F16) => {
351                dispatch_mode!(baracuda_kernels_pad_reflect_f16_run)
352            }
353            (PadMode::Reflect, ElementKind::Bf16) => {
354                dispatch_mode!(baracuda_kernels_pad_reflect_bf16_run)
355            }
356            (PadMode::Reflect, ElementKind::F64) => {
357                dispatch_mode!(baracuda_kernels_pad_reflect_f64_run)
358            }
359            // Replicate — clamp to edge.
360            (PadMode::Replicate, ElementKind::F32) => {
361                dispatch_mode!(baracuda_kernels_pad_replicate_f32_run)
362            }
363            (PadMode::Replicate, ElementKind::F16) => {
364                dispatch_mode!(baracuda_kernels_pad_replicate_f16_run)
365            }
366            (PadMode::Replicate, ElementKind::Bf16) => {
367                dispatch_mode!(baracuda_kernels_pad_replicate_bf16_run)
368            }
369            (PadMode::Replicate, ElementKind::F64) => {
370                dispatch_mode!(baracuda_kernels_pad_replicate_f64_run)
371            }
372            // Circular — cyclic wrap.
373            (PadMode::Circular, ElementKind::F32) => {
374                dispatch_mode!(baracuda_kernels_pad_circular_f32_run)
375            }
376            (PadMode::Circular, ElementKind::F16) => {
377                dispatch_mode!(baracuda_kernels_pad_circular_f16_run)
378            }
379            (PadMode::Circular, ElementKind::Bf16) => {
380                dispatch_mode!(baracuda_kernels_pad_circular_bf16_run)
381            }
382            (PadMode::Circular, ElementKind::F64) => {
383                dispatch_mode!(baracuda_kernels_pad_circular_f64_run)
384            }
385            _ => {
386                return Err(Error::Unsupported(
387                    "baracuda-kernels::PadPlan::run: this (mode, dtype) cell is not wired",
388                ));
389            }
390        };
391        map_status(status)
392    }
393}
394
395fn map_status(code: i32) -> Result<()> {
396    match code {
397        0 => Ok(()),
398        1 => Err(Error::MisalignedOperand),
399        2 => Err(Error::InvalidProblem(
400            "baracuda-kernels-sys reported invalid problem",
401        )),
402        3 => Err(Error::Unsupported(
403            "baracuda-kernels-sys reported unsupported configuration",
404        )),
405        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
406        n => Err(Error::CutlassInternal(n)),
407    }
408}