baracuda_kernels/shape_layout/
roll.rs1use core::ffi::c_void;
5use core::marker::PhantomData;
6
7use baracuda_cutlass::{Error, Result};
8use baracuda_driver::Stream;
9use baracuda_kernels_types::{
10 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
11 PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
12};
13
14#[derive(Copy, Clone, Debug)]
19pub struct RollDescriptor<const N: usize> {
20 pub shape: [i32; N],
22 pub shifts: [i32; N],
24 pub element: ElementKind,
26}
27
28pub struct RollArgs<'a, T: Element, const N: usize> {
30 pub x: TensorRef<'a, T, N>,
32 pub y: TensorMut<'a, T, N>,
34}
35
36pub struct RollPlan<T: Element, const N: usize> {
54 desc: RollDescriptor<N>,
55 sku: KernelSku,
56 _marker: PhantomData<T>,
57}
58
59impl<T: Element, const N: usize> RollPlan<T, N> {
60 pub fn select(
62 _stream: &Stream,
63 desc: &RollDescriptor<N>,
64 _pref: PlanPreference,
65 ) -> Result<Self> {
66 if desc.element != T::KIND {
67 return Err(Error::Unsupported(
68 "baracuda-kernels::RollPlan: descriptor element != type parameter T",
69 ));
70 }
71 for &d in desc.shape.iter() {
72 if d < 0 {
73 return Err(Error::InvalidProblem(
74 "baracuda-kernels::RollPlan: shape dims must be non-negative",
75 ));
76 }
77 }
78 let supported = matches!(
79 T::KIND,
80 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
81 );
82 if !supported {
83 return Err(Error::Unsupported(
84 "baracuda-kernels::RollPlan: today only `f32`, `f16`, `bf16`, `f64` \
85 are wired; other dtypes land in future fanout",
86 ));
87 }
88 let precision_guarantee = PrecisionGuarantee {
89 math_precision: MathPrecision::F32,
90 accumulator: ElementKind::F32,
91 bit_stable_on_same_hardware: true,
92 deterministic: true,
93 };
94 let sku = KernelSku {
95 category: OpCategory::ShapeLayout,
96 op: ShapeLayoutKind::Roll as u16,
97 element: T::KIND,
98 aux_element: None,
99 layout: None,
100 epilogue: None,
101 arch: ArchSku::Sm80,
102 backend: BackendKind::Bespoke,
103 precision_guarantee,
104 };
105 Ok(Self {
106 desc: *desc,
107 sku,
108 _marker: PhantomData,
109 })
110 }
111
112 pub fn can_implement(&self, args: &RollArgs<'_, T, N>) -> Result<()> {
114 if args.x.shape != self.desc.shape {
115 return Err(Error::InvalidProblem(
116 "baracuda-kernels::RollPlan: X shape mismatch",
117 ));
118 }
119 if args.y.shape != self.desc.shape {
120 return Err(Error::InvalidProblem(
121 "baracuda-kernels::RollPlan: Y shape mismatch",
122 ));
123 }
124 if N > 8 {
125 return Err(Error::Unsupported(
126 "baracuda-kernels::RollPlan: tensor rank > 8 not supported",
127 ));
128 }
129 let numel = args.y.numel();
130 let x_len = args.x.data.len() as i64;
131 let y_len = args.y.data.len() as i64;
132 if x_len < numel || y_len < numel {
133 return Err(Error::BufferTooSmall {
134 needed: numel as usize,
135 got: x_len.min(y_len) as usize,
136 });
137 }
138 Ok(())
139 }
140
141 #[inline]
143 pub fn workspace_size(&self) -> usize {
144 0
145 }
146 #[inline]
148 pub fn sku(&self) -> KernelSku {
149 self.sku
150 }
151 #[inline]
153 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
154 self.sku.precision_guarantee
155 }
156
157 pub fn run(
159 &self,
160 stream: &Stream,
161 _workspace: Workspace<'_>,
162 args: RollArgs<'_, T, N>,
163 ) -> Result<()> {
164 self.can_implement(&args)?;
165 let numel = args.y.numel();
166 if numel == 0 {
167 return Ok(());
168 }
169 let x_ptr = args.x.data.as_raw().0 as *const c_void;
170 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
171 let stream_ptr = stream.as_raw() as *mut c_void;
172
173 let shape = self.desc.shape;
174 let shifts = self.desc.shifts;
175 let stride_x = args.x.stride;
176 let stride_y = args.y.stride;
177 let rank = N as i32;
178
179 let status = match T::KIND {
180 ElementKind::F32 => unsafe {
181 baracuda_kernels_sys::baracuda_kernels_roll_f32_run(
182 numel,
183 rank,
184 shape.as_ptr(),
185 shifts.as_ptr(),
186 stride_x.as_ptr(),
187 stride_y.as_ptr(),
188 x_ptr,
189 y_ptr,
190 core::ptr::null_mut(),
191 0,
192 stream_ptr,
193 )
194 },
195 ElementKind::F16 => unsafe {
196 baracuda_kernels_sys::baracuda_kernels_roll_f16_run(
197 numel,
198 rank,
199 shape.as_ptr(),
200 shifts.as_ptr(),
201 stride_x.as_ptr(),
202 stride_y.as_ptr(),
203 x_ptr,
204 y_ptr,
205 core::ptr::null_mut(),
206 0,
207 stream_ptr,
208 )
209 },
210 ElementKind::Bf16 => unsafe {
211 baracuda_kernels_sys::baracuda_kernels_roll_bf16_run(
212 numel,
213 rank,
214 shape.as_ptr(),
215 shifts.as_ptr(),
216 stride_x.as_ptr(),
217 stride_y.as_ptr(),
218 x_ptr,
219 y_ptr,
220 core::ptr::null_mut(),
221 0,
222 stream_ptr,
223 )
224 },
225 ElementKind::F64 => unsafe {
226 baracuda_kernels_sys::baracuda_kernels_roll_f64_run(
227 numel,
228 rank,
229 shape.as_ptr(),
230 shifts.as_ptr(),
231 stride_x.as_ptr(),
232 stride_y.as_ptr(),
233 x_ptr,
234 y_ptr,
235 core::ptr::null_mut(),
236 0,
237 stream_ptr,
238 )
239 },
240 _ => {
241 return Err(Error::Unsupported(
242 "baracuda-kernels::RollPlan::run: only f32/f16/bf16/f64 wired today",
243 ));
244 }
245 };
246 map_status(status)
247 }
248}
249
250fn map_status(code: i32) -> Result<()> {
251 match code {
252 0 => Ok(()),
253 1 => Err(Error::MisalignedOperand),
254 2 => Err(Error::InvalidProblem(
255 "baracuda-kernels-sys reported invalid problem",
256 )),
257 3 => Err(Error::Unsupported(
258 "baracuda-kernels-sys reported unsupported configuration",
259 )),
260 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
261 n => Err(Error::CutlassInternal(n)),
262 }
263}