baracuda_kernels/shape_layout/
repeat.rs1use core::ffi::c_void;
7use core::marker::PhantomData;
8
9use baracuda_cutlass::{Error, Result};
10use baracuda_driver::Stream;
11use baracuda_kernels_types::{
12 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
13 PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
14};
15
16#[derive(Copy, Clone, Debug)]
18pub struct RepeatDescriptor<const N: usize> {
19 pub input_shape: [i32; N],
21 pub repeats: [i32; N],
23 pub element: ElementKind,
25}
26
27impl<const N: usize> RepeatDescriptor<N> {
28 pub fn output_shape(&self) -> [i32; N] {
30 let mut out = [0i32; N];
31 for d in 0..N {
32 out[d] = self.input_shape[d] * self.repeats[d];
33 }
34 out
35 }
36}
37
38pub struct RepeatArgs<'a, T: Element, const N: usize> {
40 pub x: TensorRef<'a, T, N>,
42 pub y: TensorMut<'a, T, N>,
44}
45
46pub struct RepeatPlan<T: Element, const N: usize> {
63 desc: RepeatDescriptor<N>,
64 sku: KernelSku,
65 _marker: PhantomData<T>,
66}
67
68impl<T: Element, const N: usize> RepeatPlan<T, N> {
69 pub fn select(
71 _stream: &Stream,
72 desc: &RepeatDescriptor<N>,
73 _pref: PlanPreference,
74 ) -> Result<Self> {
75 if desc.element != T::KIND {
76 return Err(Error::Unsupported(
77 "baracuda-kernels::RepeatPlan: descriptor element != type parameter T",
78 ));
79 }
80 for d in 0..N {
81 if desc.input_shape[d] < 0 {
82 return Err(Error::InvalidProblem(
83 "baracuda-kernels::RepeatPlan: input_shape dims must be non-negative",
84 ));
85 }
86 if desc.repeats[d] < 1 {
87 return Err(Error::InvalidProblem(
88 "baracuda-kernels::RepeatPlan: repeats[d] must be >= 1",
89 ));
90 }
91 }
92 if !matches!(
93 T::KIND,
94 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
95 ) {
96 return Err(Error::Unsupported(
97 "baracuda-kernels::RepeatPlan: supported dtypes are \
98 `{f32, f16, bf16, f64}`",
99 ));
100 }
101 let precision_guarantee = PrecisionGuarantee {
102 math_precision: MathPrecision::F32,
103 accumulator: ElementKind::F32,
104 bit_stable_on_same_hardware: true,
105 deterministic: true,
106 };
107 let sku = KernelSku {
108 category: OpCategory::ShapeLayout,
109 op: ShapeLayoutKind::Repeat as u16,
110 element: T::KIND,
111 aux_element: None,
112 layout: None,
113 epilogue: None,
114 arch: ArchSku::Sm80,
115 backend: BackendKind::Bespoke,
116 precision_guarantee,
117 };
118 Ok(Self {
119 desc: *desc,
120 sku,
121 _marker: PhantomData,
122 })
123 }
124
125 pub fn can_implement(&self, args: &RepeatArgs<'_, T, N>) -> Result<()> {
127 if args.x.shape != self.desc.input_shape {
128 return Err(Error::InvalidProblem(
129 "baracuda-kernels::RepeatPlan: X shape mismatch",
130 ));
131 }
132 let expected_out = self.desc.output_shape();
133 if args.y.shape != expected_out {
134 return Err(Error::InvalidProblem(
135 "baracuda-kernels::RepeatPlan: Y shape mismatch with derived output \
136 (output[d] = input.shape[d] * repeats[d])",
137 ));
138 }
139 if N > 8 {
140 return Err(Error::Unsupported(
141 "baracuda-kernels::RepeatPlan: tensor rank > 8 not supported",
142 ));
143 }
144 let x_numel = args.x.numel();
145 let y_numel = args.y.numel();
146 let x_len = args.x.data.len() as i64;
147 let y_len = args.y.data.len() as i64;
148 if x_len < x_numel {
149 return Err(Error::BufferTooSmall {
150 needed: x_numel as usize,
151 got: x_len as usize,
152 });
153 }
154 if y_len < y_numel {
155 return Err(Error::BufferTooSmall {
156 needed: y_numel as usize,
157 got: y_len as usize,
158 });
159 }
160 Ok(())
161 }
162
163 #[inline]
165 pub fn workspace_size(&self) -> usize {
166 0
167 }
168 #[inline]
170 pub fn sku(&self) -> KernelSku {
171 self.sku
172 }
173 #[inline]
175 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
176 self.sku.precision_guarantee
177 }
178
179 pub fn run(
181 &self,
182 stream: &Stream,
183 _workspace: Workspace<'_>,
184 args: RepeatArgs<'_, T, N>,
185 ) -> Result<()> {
186 self.can_implement(&args)?;
187 let output_numel = args.y.numel();
188 if output_numel == 0 {
189 return Ok(());
190 }
191 let x_ptr = args.x.data.as_raw().0 as *const c_void;
192 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
193 let stream_ptr = stream.as_raw() as *mut c_void;
194
195 let input_shape = self.desc.input_shape;
196 let output_shape = self.desc.output_shape();
197 let stride_x = args.x.stride;
198 let stride_y = args.y.stride;
199 let rank = N as i32;
200
201 macro_rules! dispatch {
203 ($sym:ident) => {{
204 unsafe {
205 baracuda_kernels_sys::$sym(
206 output_numel,
207 rank,
208 input_shape.as_ptr(),
209 output_shape.as_ptr(),
210 stride_x.as_ptr(),
211 stride_y.as_ptr(),
212 x_ptr,
213 y_ptr,
214 core::ptr::null_mut(),
215 0,
216 stream_ptr,
217 )
218 }
219 }};
220 }
221
222 let status = match T::KIND {
223 ElementKind::F32 => dispatch!(baracuda_kernels_repeat_f32_run),
224 ElementKind::F16 => dispatch!(baracuda_kernels_repeat_f16_run),
225 ElementKind::Bf16 => dispatch!(baracuda_kernels_repeat_bf16_run),
226 ElementKind::F64 => dispatch!(baracuda_kernels_repeat_f64_run),
227 _ => {
228 return Err(Error::Unsupported(
229 "baracuda-kernels::RepeatPlan::run: this dtype is not wired",
230 ));
231 }
232 };
233 map_status(status)
234 }
235}
236
237fn map_status(code: i32) -> Result<()> {
238 match code {
239 0 => Ok(()),
240 1 => Err(Error::MisalignedOperand),
241 2 => Err(Error::InvalidProblem(
242 "baracuda-kernels-sys reported invalid problem",
243 )),
244 3 => Err(Error::Unsupported(
245 "baracuda-kernels-sys reported unsupported configuration",
246 )),
247 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
248 n => Err(Error::CutlassInternal(n)),
249 }
250}