Skip to main content

baracuda_kernels/shape_layout/
repeat_backward.rs

1//! `repeat` backward plan — Category N (Phase 3 BW).
2//!
3//! Backward of `y = repeat(x, repeats)` (PyTorch `torch.repeat`):
4//! `dx[c_in] = sum_{k} dy[c_in + k * input_shape]` per axis — i.e. every
5//! `dy` cell whose `c_out[d] % input_shape[d] == c_in[d]` for all `d`
6//! contributes to `dx[c_in]`. One thread per dx cell loops the per-axis
7//! repeats grid (`prod(repeats[d])` cells) and accumulates. f16 / bf16
8//! accumulate in f32 inside the kernel for numerical stability; f32 /
9//! f64 accumulate in their native dtype.
10//!
11//! Not bit-stable across same-hardware reruns in principle (the grid
12//! iteration order is fixed today, but summation order matters in FP
13//! semantics, so we conservatively report
14//! `bit_stable_on_same_hardware: false`).
15
16use core::ffi::c_void;
17use core::marker::PhantomData;
18
19use baracuda_cutlass::{Error, Result};
20use baracuda_driver::Stream;
21use baracuda_kernels_types::{
22    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
23    PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
24};
25
26/// Descriptor for a `repeat` backward op.
27///
28/// Mirrors [`crate::RepeatDescriptor`] for the params the BW needs:
29/// `input_shape` (= dx shape) and `repeats` (per-axis factor).
30#[derive(Copy, Clone, Debug)]
31pub struct RepeatBackwardDescriptor<const N: usize> {
32    /// Input tensor shape (= dx shape).
33    pub input_shape: [i32; N],
34    /// Per-axis repeat factors. Must be `>= 1` (same as forward).
35    pub repeats: [i32; N],
36    /// Element type of dy and dx.
37    pub element: ElementKind,
38}
39
40impl<const N: usize> RepeatBackwardDescriptor<N> {
41    /// Compute the dy shape (= forward output shape):
42    /// `input_shape[d] * repeats[d]` per axis.
43    pub fn dy_shape(&self) -> [i32; N] {
44        let mut out = [0i32; N];
45        for d in 0..N {
46            out[d] = self.input_shape[d] * self.repeats[d];
47        }
48        out
49    }
50}
51
52/// Args bundle for a Repeat backward launch.
53///
54/// `dy.shape` must match the forward output shape (`input_shape[d] *
55/// repeats[d]` per axis). `dx.shape` must match `desc.input_shape`. No
56/// saved forward tensors are needed — the BW formula is a pure sum over
57/// dy.
58pub struct RepeatBackwardArgs<'a, T: Element, const N: usize> {
59    /// Upstream gradient — full forward output shape.
60    pub dy: TensorRef<'a, T, N>,
61    /// Gradient w.r.t. the forward input — input shape.
62    pub dx: TensorMut<'a, T, N>,
63}
64
65/// `repeat` backward plan.
66///
67/// Adjoint of [`crate::RepeatPlan`]:
68/// `dx[c_in] = Σ_k dy[c_in + k · input_shape]` — every `dy` cell that
69/// maps back to `c_in` under the FW's modulo contributes. One thread
70/// per `dx` cell sweeps the `prod(repeats[d])` contributing cells.
71/// f16 / bf16 accumulate in f32 internally; f32 / f64 accumulate in
72/// their native dtype.
73///
74/// **When to use**: BW for [`RepeatPlan`](crate::RepeatPlan).
75///
76/// **Dtypes**: `{f32, f64, f16, bf16}`.
77///
78/// **Shape limits**: rank in `[1, 8]`; `repeats[d] ≥ 1`.
79///
80/// **Workspace**: none.
81///
82/// **Precision guarantee**: deterministic (no atomics — one thread
83/// per output cell, deterministic iteration order). Conservatively
84/// reported as **not bit-stable** because summation order matters in
85/// FP semantics and a future refactor might reorder the inner loop.
86pub struct RepeatBackwardPlan<T: Element, const N: usize> {
87    desc: RepeatBackwardDescriptor<N>,
88    sku: KernelSku,
89    _marker: PhantomData<T>,
90}
91
92impl<T: Element, const N: usize> RepeatBackwardPlan<T, N> {
93    /// Pick a kernel for `desc`.
94    pub fn select(
95        _stream: &Stream,
96        desc: &RepeatBackwardDescriptor<N>,
97        _pref: PlanPreference,
98    ) -> Result<Self> {
99        if desc.element != T::KIND {
100            return Err(Error::Unsupported(
101                "baracuda-kernels::RepeatBackwardPlan: descriptor element != type parameter T",
102            ));
103        }
104        for d in 0..N {
105            if desc.input_shape[d] < 0 {
106                return Err(Error::InvalidProblem(
107                    "baracuda-kernels::RepeatBackwardPlan: input_shape dims must be \
108                     non-negative",
109                ));
110            }
111            if desc.repeats[d] < 1 {
112                return Err(Error::InvalidProblem(
113                    "baracuda-kernels::RepeatBackwardPlan: repeats[d] must be >= 1",
114                ));
115            }
116        }
117        let supported = matches!(
118            T::KIND,
119            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
120        );
121        if !supported {
122            return Err(Error::Unsupported(
123                "baracuda-kernels::RepeatBackwardPlan: today only `f32`, `f16`, `bf16`, \
124                 `f64` are wired",
125            ));
126        }
127        let precision_guarantee = PrecisionGuarantee {
128            math_precision: MathPrecision::F32,
129            accumulator: ElementKind::F32,
130            // Sum order matters in FP — not bit-stable in principle.
131            bit_stable_on_same_hardware: false,
132            deterministic: true,
133        };
134        let sku = KernelSku {
135            category: OpCategory::ShapeLayout,
136            op: ShapeLayoutKind::Repeat as u16,
137            element: T::KIND,
138            aux_element: None,
139            layout: None,
140            epilogue: None,
141            arch: ArchSku::Sm80,
142            backend: BackendKind::Bespoke,
143            precision_guarantee,
144        };
145        Ok(Self {
146            desc: *desc,
147            sku,
148            _marker: PhantomData,
149        })
150    }
151
152    /// Validate args.
153    pub fn can_implement(&self, args: &RepeatBackwardArgs<'_, T, N>) -> Result<()> {
154        if args.dx.shape != self.desc.input_shape {
155            return Err(Error::InvalidProblem(
156                "baracuda-kernels::RepeatBackwardPlan: dx shape mismatch with descriptor \
157                 input_shape",
158            ));
159        }
160        let expected_dy = self.desc.dy_shape();
161        if args.dy.shape != expected_dy {
162            return Err(Error::InvalidProblem(
163                "baracuda-kernels::RepeatBackwardPlan: dy shape mismatch with derived \
164                 output shape (= input_shape[d] * repeats[d] per axis)",
165            ));
166        }
167        if N > 8 {
168            return Err(Error::Unsupported(
169                "baracuda-kernels::RepeatBackwardPlan: tensor rank > 8 not supported",
170            ));
171        }
172        let dx_numel = args.dx.numel();
173        let dy_numel = args.dy.numel();
174        if (args.dx.data.len() as i64) < dx_numel {
175            return Err(Error::BufferTooSmall {
176                needed: dx_numel as usize,
177                got: args.dx.data.len(),
178            });
179        }
180        if (args.dy.data.len() as i64) < dy_numel {
181            return Err(Error::BufferTooSmall {
182                needed: dy_numel as usize,
183                got: args.dy.data.len(),
184            });
185        }
186        Ok(())
187    }
188
189    /// Workspace size in bytes. Always `0`.
190    #[inline]
191    pub fn workspace_size(&self) -> usize {
192        0
193    }
194    /// Identity of the kernel this plan picked.
195    #[inline]
196    pub fn sku(&self) -> KernelSku {
197        self.sku
198    }
199    /// Numerical guarantees.
200    #[inline]
201    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
202        self.sku.precision_guarantee
203    }
204
205    /// Launch.
206    pub fn run(
207        &self,
208        stream: &Stream,
209        _workspace: Workspace<'_>,
210        args: RepeatBackwardArgs<'_, T, N>,
211    ) -> Result<()> {
212        self.can_implement(&args)?;
213        let input_numel = args.dx.numel();
214        if input_numel == 0 {
215            return Ok(());
216        }
217        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
218        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
219        let stream_ptr = stream.as_raw() as *mut c_void;
220
221        let input_shape = self.desc.input_shape;
222        let repeats = self.desc.repeats;
223        let stride_dy = args.dy.stride;
224        let stride_dx = args.dx.stride;
225        let rank = N as i32;
226
227        // All four FFI symbols share the same parameter shape.
228        macro_rules! dispatch {
229            ($sym:ident) => {{
230                unsafe {
231                    baracuda_kernels_sys::$sym(
232                        input_numel,
233                        rank,
234                        input_shape.as_ptr(),
235                        repeats.as_ptr(),
236                        stride_dy.as_ptr(),
237                        stride_dx.as_ptr(),
238                        dy_ptr,
239                        dx_ptr,
240                        core::ptr::null_mut(),
241                        0,
242                        stream_ptr,
243                    )
244                }
245            }};
246        }
247
248        let status = match T::KIND {
249            ElementKind::F32 => dispatch!(baracuda_kernels_repeat_backward_f32_run),
250            ElementKind::F16 => dispatch!(baracuda_kernels_repeat_backward_f16_run),
251            ElementKind::Bf16 => dispatch!(baracuda_kernels_repeat_backward_bf16_run),
252            ElementKind::F64 => dispatch!(baracuda_kernels_repeat_backward_f64_run),
253            _ => {
254                return Err(Error::Unsupported(
255                    "baracuda-kernels::RepeatBackwardPlan::run: only f32/f16/bf16/f64 \
256                     wired today",
257                ));
258            }
259        };
260        map_status(status)
261    }
262}
263
264fn map_status(code: i32) -> Result<()> {
265    match code {
266        0 => Ok(()),
267        1 => Err(Error::MisalignedOperand),
268        2 => Err(Error::InvalidProblem(
269            "baracuda-kernels-sys reported invalid problem",
270        )),
271        3 => Err(Error::Unsupported(
272            "baracuda-kernels-sys reported unsupported configuration",
273        )),
274        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
275        n => Err(Error::CutlassInternal(n)),
276    }
277}