Skip to main content

baracuda_kernels/shape_layout/
repeat.rs

//! `repeat` plan — per-axis tiling (each output extent is >= the input extent).
//! Mirrors PyTorch `Tensor.repeat(*repeats)`: `output.shape[d] =
//! input.shape[d] * repeats[d]`. The kernel walks output cells and computes
//! input coords as `output_coord[d] % input.shape[d]`.
5
6use core::ffi::c_void;
7use core::marker::PhantomData;
8
9use baracuda_cutlass::{Error, Result};
10use baracuda_driver::Stream;
11use baracuda_kernels_types::{
12    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
13    PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
14};
15
16/// Descriptor for a `repeat` op.
17#[derive(Copy, Clone, Debug)]
18pub struct RepeatDescriptor<const N: usize> {
19    /// Input tensor shape.
20    pub input_shape: [i32; N],
21    /// Per-axis repeat factors. Must be `>= 1`.
22    pub repeats: [i32; N],
23    /// Element type.
24    pub element: ElementKind,
25}
26
27impl<const N: usize> RepeatDescriptor<N> {
28    /// Compute the output shape: `input.shape[d] * repeats[d]` per axis.
29    pub fn output_shape(&self) -> [i32; N] {
30        let mut out = [0i32; N];
31        for d in 0..N {
32            out[d] = self.input_shape[d] * self.repeats[d];
33        }
34        out
35    }
36}
37
38/// Args bundle for a Repeat launch.
39pub struct RepeatArgs<'a, T: Element, const N: usize> {
40    /// Input.
41    pub x: TensorRef<'a, T, N>,
42    /// Output — shape matches `desc.output_shape()`.
43    pub y: TensorMut<'a, T, N>,
44}
45
46/// `repeat` plan.
47///
48/// Per-axis tile: `output.shape[d] = input.shape[d] * repeats[d]`
49/// (PyTorch `torch.repeat`). Kernel walks output cells and computes
50/// input coords as `output_coord[d] % input.shape[d]`.
51///
52/// **When to use**: forward repeat. Pair with
53/// [`RepeatBackwardPlan`](crate::RepeatBackwardPlan).
54///
55/// **Dtypes**: `{f32, f64, f16, bf16}`. Pure load + store.
56///
57/// **Shape limits**: rank in `[1, 8]`; `repeats[d] ≥ 1`.
58///
59/// **Workspace**: none.
60///
61/// **Precision guarantee**: deterministic, bit-stable, bit-exact.
62pub struct RepeatPlan<T: Element, const N: usize> {
63    desc: RepeatDescriptor<N>,
64    sku: KernelSku,
65    _marker: PhantomData<T>,
66}
67
68impl<T: Element, const N: usize> RepeatPlan<T, N> {
69    /// Pick a kernel for `desc`.
70    pub fn select(
71        _stream: &Stream,
72        desc: &RepeatDescriptor<N>,
73        _pref: PlanPreference,
74    ) -> Result<Self> {
75        if desc.element != T::KIND {
76            return Err(Error::Unsupported(
77                "baracuda-kernels::RepeatPlan: descriptor element != type parameter T",
78            ));
79        }
80        for d in 0..N {
81            if desc.input_shape[d] < 0 {
82                return Err(Error::InvalidProblem(
83                    "baracuda-kernels::RepeatPlan: input_shape dims must be non-negative",
84                ));
85            }
86            if desc.repeats[d] < 1 {
87                return Err(Error::InvalidProblem(
88                    "baracuda-kernels::RepeatPlan: repeats[d] must be >= 1",
89                ));
90            }
91        }
92        if !matches!(
93            T::KIND,
94            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
95        ) {
96            return Err(Error::Unsupported(
97                "baracuda-kernels::RepeatPlan: supported dtypes are \
98                 `{f32, f16, bf16, f64}`",
99            ));
100        }
101        let precision_guarantee = PrecisionGuarantee {
102            math_precision: MathPrecision::F32,
103            accumulator: ElementKind::F32,
104            bit_stable_on_same_hardware: true,
105            deterministic: true,
106        };
107        let sku = KernelSku {
108            category: OpCategory::ShapeLayout,
109            op: ShapeLayoutKind::Repeat as u16,
110            element: T::KIND,
111            aux_element: None,
112            layout: None,
113            epilogue: None,
114            arch: ArchSku::Sm80,
115            backend: BackendKind::Bespoke,
116            precision_guarantee,
117        };
118        Ok(Self {
119            desc: *desc,
120            sku,
121            _marker: PhantomData,
122        })
123    }
124
125    /// Validate args.
126    pub fn can_implement(&self, args: &RepeatArgs<'_, T, N>) -> Result<()> {
127        if args.x.shape != self.desc.input_shape {
128            return Err(Error::InvalidProblem(
129                "baracuda-kernels::RepeatPlan: X shape mismatch",
130            ));
131        }
132        let expected_out = self.desc.output_shape();
133        if args.y.shape != expected_out {
134            return Err(Error::InvalidProblem(
135                "baracuda-kernels::RepeatPlan: Y shape mismatch with derived output \
136                 (output[d] = input.shape[d] * repeats[d])",
137            ));
138        }
139        if N > 8 {
140            return Err(Error::Unsupported(
141                "baracuda-kernels::RepeatPlan: tensor rank > 8 not supported",
142            ));
143        }
144        let x_numel = args.x.numel();
145        let y_numel = args.y.numel();
146        let x_len = args.x.data.len() as i64;
147        let y_len = args.y.data.len() as i64;
148        if x_len < x_numel {
149            return Err(Error::BufferTooSmall {
150                needed: x_numel as usize,
151                got: x_len as usize,
152            });
153        }
154        if y_len < y_numel {
155            return Err(Error::BufferTooSmall {
156                needed: y_numel as usize,
157                got: y_len as usize,
158            });
159        }
160        Ok(())
161    }
162
163    /// Workspace size in bytes.
164    #[inline]
165    pub fn workspace_size(&self) -> usize {
166        0
167    }
168    /// Identity of the kernel this plan picked.
169    #[inline]
170    pub fn sku(&self) -> KernelSku {
171        self.sku
172    }
173    /// Numerical guarantees.
174    #[inline]
175    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
176        self.sku.precision_guarantee
177    }
178
179    /// Launch.
180    pub fn run(
181        &self,
182        stream: &Stream,
183        _workspace: Workspace<'_>,
184        args: RepeatArgs<'_, T, N>,
185    ) -> Result<()> {
186        self.can_implement(&args)?;
187        let output_numel = args.y.numel();
188        if output_numel == 0 {
189            return Ok(());
190        }
191        let x_ptr = args.x.data.as_raw().0 as *const c_void;
192        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
193        let stream_ptr = stream.as_raw() as *mut c_void;
194
195        let input_shape = self.desc.input_shape;
196        let output_shape = self.desc.output_shape();
197        let stride_x = args.x.stride;
198        let stride_y = args.y.stride;
199        let rank = N as i32;
200
201        // Every Repeat FFI symbol shares the same parameter shape.
202        macro_rules! dispatch {
203            ($sym:ident) => {{
204                unsafe {
205                    baracuda_kernels_sys::$sym(
206                        output_numel,
207                        rank,
208                        input_shape.as_ptr(),
209                        output_shape.as_ptr(),
210                        stride_x.as_ptr(),
211                        stride_y.as_ptr(),
212                        x_ptr,
213                        y_ptr,
214                        core::ptr::null_mut(),
215                        0,
216                        stream_ptr,
217                    )
218                }
219            }};
220        }
221
222        let status = match T::KIND {
223            ElementKind::F32 => dispatch!(baracuda_kernels_repeat_f32_run),
224            ElementKind::F16 => dispatch!(baracuda_kernels_repeat_f16_run),
225            ElementKind::Bf16 => dispatch!(baracuda_kernels_repeat_bf16_run),
226            ElementKind::F64 => dispatch!(baracuda_kernels_repeat_f64_run),
227            _ => {
228                return Err(Error::Unsupported(
229                    "baracuda-kernels::RepeatPlan::run: this dtype is not wired",
230                ));
231            }
232        };
233        map_status(status)
234    }
235}
236
237fn map_status(code: i32) -> Result<()> {
238    match code {
239        0 => Ok(()),
240        1 => Err(Error::MisalignedOperand),
241        2 => Err(Error::InvalidProblem(
242            "baracuda-kernels-sys reported invalid problem",
243        )),
244        3 => Err(Error::Unsupported(
245            "baracuda-kernels-sys reported unsupported configuration",
246        )),
247        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
248        n => Err(Error::CutlassInternal(n)),
249    }
250}