baracuda_kernels/shape_layout/
repeat_backward.rs1use core::ffi::c_void;
17use core::marker::PhantomData;
18
19use baracuda_cutlass::{Error, Result};
20use baracuda_driver::Stream;
21use baracuda_kernels_types::{
22 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
23 PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
24};
25
/// Problem description for the backward pass of a per-axis `repeat`.
///
/// The plan validates this descriptor in `RepeatBackwardPlan::select`. The
/// gradient tensors passed at run time must have shapes `input_shape` (for `dx`) and
/// `dy_shape()` (for `dy`).
///
/// NOTE(review): the actual gradient reduction is performed inside
/// `baracuda_kernels_sys`; presumably each `dx` element accumulates its repeated copies
/// in `dy` — confirm against the kernel.
#[derive(Copy, Clone, Debug)]
pub struct RepeatBackwardDescriptor<const N: usize> {
    /// Shape of the forward input, which is also the required shape of `dx`.
    /// Every entry must be non-negative.
    pub input_shape: [i32; N],
    /// Per-axis repeat factor. Every entry must be `>= 1`. `dy` has extent
    /// `input_shape[d] * repeats[d]` along axis `d`.
    pub repeats: [i32; N],
    /// Element type of both gradients; must equal the plan's `T::KIND`.
    pub element: ElementKind,
}
39
40impl<const N: usize> RepeatBackwardDescriptor<N> {
41 pub fn dy_shape(&self) -> [i32; N] {
44 let mut out = [0i32; N];
45 for d in 0..N {
46 out[d] = self.input_shape[d] * self.repeats[d];
47 }
48 out
49 }
50}
51
/// Runtime tensors for one `RepeatBackwardPlan::run` call.
///
/// `RepeatBackwardPlan::can_implement` checks the following before launching:
/// - `dy.shape` equals the descriptor's `dy_shape()`.
/// - `dx.shape` equals the descriptor's `input_shape`.
/// - Both backing buffers hold at least `numel()` elements.
pub struct RepeatBackwardArgs<'a, T: Element, const N: usize> {
    /// Upstream gradient (read-only), shaped `input_shape[d] * repeats[d]` per axis.
    pub dy: TensorRef<'a, T, N>,
    /// Gradient with respect to the forward input (written), shaped `input_shape`.
    pub dx: TensorMut<'a, T, N>,
}
64
/// Validated, reusable launch plan for the repeat-backward kernel over element type `T`
/// and rank `N`.
///
/// Build one with `RepeatBackwardPlan::select`, then call `run` any number of times with
/// tensors matching the stored descriptor.
pub struct RepeatBackwardPlan<T: Element, const N: usize> {
    // Descriptor copied in at `select` time; all run-time shape checks compare against it.
    desc: RepeatBackwardDescriptor<N>,
    // Kernel identity / precision metadata reported through `sku()`.
    sku: KernelSku,
    // Ties the plan to `T` without storing a value of it.
    _marker: PhantomData<T>,
}
91
92impl<T: Element, const N: usize> RepeatBackwardPlan<T, N> {
93 pub fn select(
95 _stream: &Stream,
96 desc: &RepeatBackwardDescriptor<N>,
97 _pref: PlanPreference,
98 ) -> Result<Self> {
99 if desc.element != T::KIND {
100 return Err(Error::Unsupported(
101 "baracuda-kernels::RepeatBackwardPlan: descriptor element != type parameter T",
102 ));
103 }
104 for d in 0..N {
105 if desc.input_shape[d] < 0 {
106 return Err(Error::InvalidProblem(
107 "baracuda-kernels::RepeatBackwardPlan: input_shape dims must be \
108 non-negative",
109 ));
110 }
111 if desc.repeats[d] < 1 {
112 return Err(Error::InvalidProblem(
113 "baracuda-kernels::RepeatBackwardPlan: repeats[d] must be >= 1",
114 ));
115 }
116 }
117 let supported = matches!(
118 T::KIND,
119 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
120 );
121 if !supported {
122 return Err(Error::Unsupported(
123 "baracuda-kernels::RepeatBackwardPlan: today only `f32`, `f16`, `bf16`, \
124 `f64` are wired",
125 ));
126 }
127 let precision_guarantee = PrecisionGuarantee {
128 math_precision: MathPrecision::F32,
129 accumulator: ElementKind::F32,
130 bit_stable_on_same_hardware: false,
132 deterministic: true,
133 };
134 let sku = KernelSku {
135 category: OpCategory::ShapeLayout,
136 op: ShapeLayoutKind::Repeat as u16,
137 element: T::KIND,
138 aux_element: None,
139 layout: None,
140 epilogue: None,
141 arch: ArchSku::Sm80,
142 backend: BackendKind::Bespoke,
143 precision_guarantee,
144 };
145 Ok(Self {
146 desc: *desc,
147 sku,
148 _marker: PhantomData,
149 })
150 }
151
152 pub fn can_implement(&self, args: &RepeatBackwardArgs<'_, T, N>) -> Result<()> {
154 if args.dx.shape != self.desc.input_shape {
155 return Err(Error::InvalidProblem(
156 "baracuda-kernels::RepeatBackwardPlan: dx shape mismatch with descriptor \
157 input_shape",
158 ));
159 }
160 let expected_dy = self.desc.dy_shape();
161 if args.dy.shape != expected_dy {
162 return Err(Error::InvalidProblem(
163 "baracuda-kernels::RepeatBackwardPlan: dy shape mismatch with derived \
164 output shape (= input_shape[d] * repeats[d] per axis)",
165 ));
166 }
167 if N > 8 {
168 return Err(Error::Unsupported(
169 "baracuda-kernels::RepeatBackwardPlan: tensor rank > 8 not supported",
170 ));
171 }
172 let dx_numel = args.dx.numel();
173 let dy_numel = args.dy.numel();
174 if (args.dx.data.len() as i64) < dx_numel {
175 return Err(Error::BufferTooSmall {
176 needed: dx_numel as usize,
177 got: args.dx.data.len(),
178 });
179 }
180 if (args.dy.data.len() as i64) < dy_numel {
181 return Err(Error::BufferTooSmall {
182 needed: dy_numel as usize,
183 got: args.dy.data.len(),
184 });
185 }
186 Ok(())
187 }
188
189 #[inline]
191 pub fn workspace_size(&self) -> usize {
192 0
193 }
194 #[inline]
196 pub fn sku(&self) -> KernelSku {
197 self.sku
198 }
199 #[inline]
201 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
202 self.sku.precision_guarantee
203 }
204
205 pub fn run(
207 &self,
208 stream: &Stream,
209 _workspace: Workspace<'_>,
210 args: RepeatBackwardArgs<'_, T, N>,
211 ) -> Result<()> {
212 self.can_implement(&args)?;
213 let input_numel = args.dx.numel();
214 if input_numel == 0 {
215 return Ok(());
216 }
217 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
218 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
219 let stream_ptr = stream.as_raw() as *mut c_void;
220
221 let input_shape = self.desc.input_shape;
222 let repeats = self.desc.repeats;
223 let stride_dy = args.dy.stride;
224 let stride_dx = args.dx.stride;
225 let rank = N as i32;
226
227 macro_rules! dispatch {
229 ($sym:ident) => {{
230 unsafe {
231 baracuda_kernels_sys::$sym(
232 input_numel,
233 rank,
234 input_shape.as_ptr(),
235 repeats.as_ptr(),
236 stride_dy.as_ptr(),
237 stride_dx.as_ptr(),
238 dy_ptr,
239 dx_ptr,
240 core::ptr::null_mut(),
241 0,
242 stream_ptr,
243 )
244 }
245 }};
246 }
247
248 let status = match T::KIND {
249 ElementKind::F32 => dispatch!(baracuda_kernels_repeat_backward_f32_run),
250 ElementKind::F16 => dispatch!(baracuda_kernels_repeat_backward_f16_run),
251 ElementKind::Bf16 => dispatch!(baracuda_kernels_repeat_backward_bf16_run),
252 ElementKind::F64 => dispatch!(baracuda_kernels_repeat_backward_f64_run),
253 _ => {
254 return Err(Error::Unsupported(
255 "baracuda-kernels::RepeatBackwardPlan::run: only f32/f16/bf16/f64 \
256 wired today",
257 ));
258 }
259 };
260 map_status(status)
261 }
262}
263
264fn map_status(code: i32) -> Result<()> {
265 match code {
266 0 => Ok(()),
267 1 => Err(Error::MisalignedOperand),
268 2 => Err(Error::InvalidProblem(
269 "baracuda-kernels-sys reported invalid problem",
270 )),
271 3 => Err(Error::Unsupported(
272 "baracuda-kernels-sys reported unsupported configuration",
273 )),
274 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
275 n => Err(Error::CutlassInternal(n)),
276 }
277}