Skip to main content

baracuda_kernels/shape_layout/
flip.rs

1//! `flip` plan — reverse along selected axes. Output shape == input
2//! shape. Today wired for `{f32, f64, f16, bf16}`.
3
4use core::ffi::c_void;
5use core::marker::PhantomData;
6
7use baracuda_cutlass::{Error, Result};
8use baracuda_driver::Stream;
9use baracuda_kernels_types::{
10    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
11    PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
12};
13
14/// Descriptor for a `flip` op.
15///
16/// `flip_axes[d]` is `true` if axis `d` should be reversed, `false`
17/// otherwise.
18#[derive(Copy, Clone, Debug)]
19pub struct FlipDescriptor<const N: usize> {
20    /// Input/output tensor shape (flip preserves shape).
21    pub shape: [i32; N],
22    /// Per-axis mask: `true` = reverse, `false` = no-op.
23    pub flip_axes: [bool; N],
24    /// Element type.
25    pub element: ElementKind,
26}
27
28/// Args bundle for a Flip launch.
29pub struct FlipArgs<'a, T: Element, const N: usize> {
30    /// Input.
31    pub x: TensorRef<'a, T, N>,
32    /// Output — same shape as input.
33    pub y: TensorMut<'a, T, N>,
34}
35
36/// `flip` plan.
37///
38/// `y = x.flip(axes)` — reverse along the selected axes (PyTorch
39/// `torch.flip`). Output shape equals input shape.
40///
41/// **When to use**: forward flip. Pair with
42/// [`FlipBackwardPlan`](crate::FlipBackwardPlan) — but note that
43/// `flip` is an involution, so the BW just re-flips along the same
44/// axes (reuses the FW kernel).
45///
46/// **Dtypes**: `{f32, f64, f16, bf16}`. Pure index permutation, no
47/// arithmetic.
48///
49/// **Shape limits**: rank in `[1, 8]`; `flip_axes` is a per-axis
50/// boolean mask.
51///
52/// **Workspace**: none.
53///
54/// **Precision guarantee**: deterministic, bit-stable, bit-exact.
55pub struct FlipPlan<T: Element, const N: usize> {
56    desc: FlipDescriptor<N>,
57    sku: KernelSku,
58    _marker: PhantomData<T>,
59}
60
61impl<T: Element, const N: usize> FlipPlan<T, N> {
62    /// Pick a kernel for `desc`.
63    pub fn select(
64        _stream: &Stream,
65        desc: &FlipDescriptor<N>,
66        _pref: PlanPreference,
67    ) -> Result<Self> {
68        if desc.element != T::KIND {
69            return Err(Error::Unsupported(
70                "baracuda-kernels::FlipPlan: descriptor element != type parameter T",
71            ));
72        }
73        for &d in desc.shape.iter() {
74            if d < 0 {
75                return Err(Error::InvalidProblem(
76                    "baracuda-kernels::FlipPlan: shape dims must be non-negative",
77                ));
78            }
79        }
80        let supported = matches!(
81            T::KIND,
82            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
83        );
84        if !supported {
85            return Err(Error::Unsupported(
86                "baracuda-kernels::FlipPlan: today only `f32`, `f16`, `bf16`, `f64` \
87                 are wired; other dtypes land in future fanout",
88            ));
89        }
90        let precision_guarantee = PrecisionGuarantee {
91            math_precision: MathPrecision::F32,
92            accumulator: ElementKind::F32,
93            bit_stable_on_same_hardware: true,
94            deterministic: true,
95        };
96        let sku = KernelSku {
97            category: OpCategory::ShapeLayout,
98            op: ShapeLayoutKind::Flip as u16,
99            element: T::KIND,
100            aux_element: None,
101            layout: None,
102            epilogue: None,
103            arch: ArchSku::Sm80,
104            backend: BackendKind::Bespoke,
105            precision_guarantee,
106        };
107        Ok(Self {
108            desc: *desc,
109            sku,
110            _marker: PhantomData,
111        })
112    }
113
114    /// Validate args.
115    pub fn can_implement(&self, args: &FlipArgs<'_, T, N>) -> Result<()> {
116        if args.x.shape != self.desc.shape {
117            return Err(Error::InvalidProblem(
118                "baracuda-kernels::FlipPlan: X shape mismatch with descriptor",
119            ));
120        }
121        if args.y.shape != self.desc.shape {
122            return Err(Error::InvalidProblem(
123                "baracuda-kernels::FlipPlan: Y shape mismatch with descriptor",
124            ));
125        }
126        if N > 8 {
127            return Err(Error::Unsupported(
128                "baracuda-kernels::FlipPlan: tensor rank > 8 not supported",
129            ));
130        }
131        let numel = args.y.numel();
132        let x_len = args.x.data.len() as i64;
133        let y_len = args.y.data.len() as i64;
134        if x_len < numel || y_len < numel {
135            return Err(Error::BufferTooSmall {
136                needed: numel as usize,
137                got: x_len.min(y_len) as usize,
138            });
139        }
140        Ok(())
141    }
142
143    /// Workspace size in bytes.
144    #[inline]
145    pub fn workspace_size(&self) -> usize {
146        0
147    }
148    /// Identity of the kernel this plan picked.
149    #[inline]
150    pub fn sku(&self) -> KernelSku {
151        self.sku
152    }
153    /// Numerical guarantees.
154    #[inline]
155    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
156        self.sku.precision_guarantee
157    }
158
159    /// Launch.
160    pub fn run(
161        &self,
162        stream: &Stream,
163        _workspace: Workspace<'_>,
164        args: FlipArgs<'_, T, N>,
165    ) -> Result<()> {
166        self.can_implement(&args)?;
167        let numel = args.y.numel();
168        if numel == 0 {
169            return Ok(());
170        }
171        let x_ptr = args.x.data.as_raw().0 as *const c_void;
172        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
173        let stream_ptr = stream.as_raw() as *mut c_void;
174
175        // Convert bool flip_axes to i32 array for the FFI.
176        let mut flip_axes_i32 = [0i32; 8];
177        for d in 0..N {
178            flip_axes_i32[d] = if self.desc.flip_axes[d] { 1 } else { 0 };
179        }
180
181        let shape = self.desc.shape;
182        let stride_x = args.x.stride;
183        let stride_y = args.y.stride;
184        let rank = N as i32;
185
186        let status = match T::KIND {
187            ElementKind::F32 => unsafe {
188                baracuda_kernels_sys::baracuda_kernels_flip_f32_run(
189                    numel,
190                    rank,
191                    shape.as_ptr(),
192                    flip_axes_i32.as_ptr(),
193                    stride_x.as_ptr(),
194                    stride_y.as_ptr(),
195                    x_ptr,
196                    y_ptr,
197                    core::ptr::null_mut(),
198                    0,
199                    stream_ptr,
200                )
201            },
202            ElementKind::F16 => unsafe {
203                baracuda_kernels_sys::baracuda_kernels_flip_f16_run(
204                    numel,
205                    rank,
206                    shape.as_ptr(),
207                    flip_axes_i32.as_ptr(),
208                    stride_x.as_ptr(),
209                    stride_y.as_ptr(),
210                    x_ptr,
211                    y_ptr,
212                    core::ptr::null_mut(),
213                    0,
214                    stream_ptr,
215                )
216            },
217            ElementKind::Bf16 => unsafe {
218                baracuda_kernels_sys::baracuda_kernels_flip_bf16_run(
219                    numel,
220                    rank,
221                    shape.as_ptr(),
222                    flip_axes_i32.as_ptr(),
223                    stride_x.as_ptr(),
224                    stride_y.as_ptr(),
225                    x_ptr,
226                    y_ptr,
227                    core::ptr::null_mut(),
228                    0,
229                    stream_ptr,
230                )
231            },
232            ElementKind::F64 => unsafe {
233                baracuda_kernels_sys::baracuda_kernels_flip_f64_run(
234                    numel,
235                    rank,
236                    shape.as_ptr(),
237                    flip_axes_i32.as_ptr(),
238                    stride_x.as_ptr(),
239                    stride_y.as_ptr(),
240                    x_ptr,
241                    y_ptr,
242                    core::ptr::null_mut(),
243                    0,
244                    stream_ptr,
245                )
246            },
247            _ => {
248                return Err(Error::Unsupported(
249                    "baracuda-kernels::FlipPlan::run: only f32/f16/bf16/f64 wired today",
250                ));
251            }
252        };
253        map_status(status)
254    }
255}
256
257fn map_status(code: i32) -> Result<()> {
258    match code {
259        0 => Ok(()),
260        1 => Err(Error::MisalignedOperand),
261        2 => Err(Error::InvalidProblem(
262            "baracuda-kernels-sys reported invalid problem",
263        )),
264        3 => Err(Error::Unsupported(
265            "baracuda-kernels-sys reported unsupported configuration",
266        )),
267        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
268        n => Err(Error::CutlassInternal(n)),
269    }
270}