Skip to main content

baracuda_kernels/shape_layout/
roll.rs

1//! `roll` plan — cyclic shift along axes. Output shape == input shape.
2//! Today wired for `{f32, f64, f16, bf16}`.
3
4use core::ffi::c_void;
5use core::marker::PhantomData;
6
7use baracuda_cutlass::{Error, Result};
8use baracuda_driver::Stream;
9use baracuda_kernels_types::{
10    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
11    PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
12};
13
14/// Descriptor for a `roll` op.
15///
16/// `shifts[d]` is the shift amount along axis `d`. Can be negative.
17/// Normalized to `[0, shape[d])` by the kernel.
18#[derive(Copy, Clone, Debug)]
19pub struct RollDescriptor<const N: usize> {
20    /// Input/output tensor shape.
21    pub shape: [i32; N],
22    /// Per-axis shift amounts (positive or negative).
23    pub shifts: [i32; N],
24    /// Element type.
25    pub element: ElementKind,
26}
27
28/// Args bundle for a Roll launch.
29pub struct RollArgs<'a, T: Element, const N: usize> {
30    /// Input.
31    pub x: TensorRef<'a, T, N>,
32    /// Output — same shape as input.
33    pub y: TensorMut<'a, T, N>,
34}
35
36/// `roll` plan.
37///
38/// `y = x.roll(shifts, axes)` — cyclic shift along axes (PyTorch
39/// `torch.roll`). Output shape equals input shape.
40///
41/// **When to use**: forward cyclic shift. Pair with
42/// [`RollBackwardPlan`](crate::RollBackwardPlan), which reuses the
43/// FW kernel with negated shifts.
44///
45/// **Dtypes**: `{f32, f64, f16, bf16}`. Pure index permutation.
46///
47/// **Shape limits**: rank in `[1, 8]`; per-axis shifts can be
48/// negative (kernel normalizes to `[0, shape[d])`).
49///
50/// **Workspace**: none.
51///
52/// **Precision guarantee**: deterministic, bit-stable, bit-exact.
53pub struct RollPlan<T: Element, const N: usize> {
54    desc: RollDescriptor<N>,
55    sku: KernelSku,
56    _marker: PhantomData<T>,
57}
58
59impl<T: Element, const N: usize> RollPlan<T, N> {
60    /// Pick a kernel for `desc`.
61    pub fn select(
62        _stream: &Stream,
63        desc: &RollDescriptor<N>,
64        _pref: PlanPreference,
65    ) -> Result<Self> {
66        if desc.element != T::KIND {
67            return Err(Error::Unsupported(
68                "baracuda-kernels::RollPlan: descriptor element != type parameter T",
69            ));
70        }
71        for &d in desc.shape.iter() {
72            if d < 0 {
73                return Err(Error::InvalidProblem(
74                    "baracuda-kernels::RollPlan: shape dims must be non-negative",
75                ));
76            }
77        }
78        let supported = matches!(
79            T::KIND,
80            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
81        );
82        if !supported {
83            return Err(Error::Unsupported(
84                "baracuda-kernels::RollPlan: today only `f32`, `f16`, `bf16`, `f64` \
85                 are wired; other dtypes land in future fanout",
86            ));
87        }
88        let precision_guarantee = PrecisionGuarantee {
89            math_precision: MathPrecision::F32,
90            accumulator: ElementKind::F32,
91            bit_stable_on_same_hardware: true,
92            deterministic: true,
93        };
94        let sku = KernelSku {
95            category: OpCategory::ShapeLayout,
96            op: ShapeLayoutKind::Roll as u16,
97            element: T::KIND,
98            aux_element: None,
99            layout: None,
100            epilogue: None,
101            arch: ArchSku::Sm80,
102            backend: BackendKind::Bespoke,
103            precision_guarantee,
104        };
105        Ok(Self {
106            desc: *desc,
107            sku,
108            _marker: PhantomData,
109        })
110    }
111
112    /// Validate args.
113    pub fn can_implement(&self, args: &RollArgs<'_, T, N>) -> Result<()> {
114        if args.x.shape != self.desc.shape {
115            return Err(Error::InvalidProblem(
116                "baracuda-kernels::RollPlan: X shape mismatch",
117            ));
118        }
119        if args.y.shape != self.desc.shape {
120            return Err(Error::InvalidProblem(
121                "baracuda-kernels::RollPlan: Y shape mismatch",
122            ));
123        }
124        if N > 8 {
125            return Err(Error::Unsupported(
126                "baracuda-kernels::RollPlan: tensor rank > 8 not supported",
127            ));
128        }
129        let numel = args.y.numel();
130        let x_len = args.x.data.len() as i64;
131        let y_len = args.y.data.len() as i64;
132        if x_len < numel || y_len < numel {
133            return Err(Error::BufferTooSmall {
134                needed: numel as usize,
135                got: x_len.min(y_len) as usize,
136            });
137        }
138        Ok(())
139    }
140
141    /// Workspace size in bytes.
142    #[inline]
143    pub fn workspace_size(&self) -> usize {
144        0
145    }
146    /// Identity of the kernel this plan picked.
147    #[inline]
148    pub fn sku(&self) -> KernelSku {
149        self.sku
150    }
151    /// Numerical guarantees.
152    #[inline]
153    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
154        self.sku.precision_guarantee
155    }
156
157    /// Launch.
158    pub fn run(
159        &self,
160        stream: &Stream,
161        _workspace: Workspace<'_>,
162        args: RollArgs<'_, T, N>,
163    ) -> Result<()> {
164        self.can_implement(&args)?;
165        let numel = args.y.numel();
166        if numel == 0 {
167            return Ok(());
168        }
169        let x_ptr = args.x.data.as_raw().0 as *const c_void;
170        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
171        let stream_ptr = stream.as_raw() as *mut c_void;
172
173        let shape = self.desc.shape;
174        let shifts = self.desc.shifts;
175        let stride_x = args.x.stride;
176        let stride_y = args.y.stride;
177        let rank = N as i32;
178
179        let status = match T::KIND {
180            ElementKind::F32 => unsafe {
181                baracuda_kernels_sys::baracuda_kernels_roll_f32_run(
182                    numel,
183                    rank,
184                    shape.as_ptr(),
185                    shifts.as_ptr(),
186                    stride_x.as_ptr(),
187                    stride_y.as_ptr(),
188                    x_ptr,
189                    y_ptr,
190                    core::ptr::null_mut(),
191                    0,
192                    stream_ptr,
193                )
194            },
195            ElementKind::F16 => unsafe {
196                baracuda_kernels_sys::baracuda_kernels_roll_f16_run(
197                    numel,
198                    rank,
199                    shape.as_ptr(),
200                    shifts.as_ptr(),
201                    stride_x.as_ptr(),
202                    stride_y.as_ptr(),
203                    x_ptr,
204                    y_ptr,
205                    core::ptr::null_mut(),
206                    0,
207                    stream_ptr,
208                )
209            },
210            ElementKind::Bf16 => unsafe {
211                baracuda_kernels_sys::baracuda_kernels_roll_bf16_run(
212                    numel,
213                    rank,
214                    shape.as_ptr(),
215                    shifts.as_ptr(),
216                    stride_x.as_ptr(),
217                    stride_y.as_ptr(),
218                    x_ptr,
219                    y_ptr,
220                    core::ptr::null_mut(),
221                    0,
222                    stream_ptr,
223                )
224            },
225            ElementKind::F64 => unsafe {
226                baracuda_kernels_sys::baracuda_kernels_roll_f64_run(
227                    numel,
228                    rank,
229                    shape.as_ptr(),
230                    shifts.as_ptr(),
231                    stride_x.as_ptr(),
232                    stride_y.as_ptr(),
233                    x_ptr,
234                    y_ptr,
235                    core::ptr::null_mut(),
236                    0,
237                    stream_ptr,
238                )
239            },
240            _ => {
241                return Err(Error::Unsupported(
242                    "baracuda-kernels::RollPlan::run: only f32/f16/bf16/f64 wired today",
243                ));
244            }
245        };
246        map_status(status)
247    }
248}
249
250fn map_status(code: i32) -> Result<()> {
251    match code {
252        0 => Ok(()),
253        1 => Err(Error::MisalignedOperand),
254        2 => Err(Error::InvalidProblem(
255            "baracuda-kernels-sys reported invalid problem",
256        )),
257        3 => Err(Error::Unsupported(
258            "baracuda-kernels-sys reported unsupported configuration",
259        )),
260        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
261        n => Err(Error::CutlassInternal(n)),
262    }
263}