baracuda_kernels/shape_layout/
flip.rs1use core::ffi::c_void;
5use core::marker::PhantomData;
6
7use baracuda_cutlass::{Error, Result};
8use baracuda_driver::Stream;
9use baracuda_kernels_types::{
10 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
11 PlanPreference, PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
12};
13
14#[derive(Copy, Clone, Debug)]
19pub struct FlipDescriptor<const N: usize> {
20 pub shape: [i32; N],
22 pub flip_axes: [bool; N],
24 pub element: ElementKind,
26}
27
28pub struct FlipArgs<'a, T: Element, const N: usize> {
30 pub x: TensorRef<'a, T, N>,
32 pub y: TensorMut<'a, T, N>,
34}
35
36pub struct FlipPlan<T: Element, const N: usize> {
56 desc: FlipDescriptor<N>,
57 sku: KernelSku,
58 _marker: PhantomData<T>,
59}
60
61impl<T: Element, const N: usize> FlipPlan<T, N> {
62 pub fn select(
64 _stream: &Stream,
65 desc: &FlipDescriptor<N>,
66 _pref: PlanPreference,
67 ) -> Result<Self> {
68 if desc.element != T::KIND {
69 return Err(Error::Unsupported(
70 "baracuda-kernels::FlipPlan: descriptor element != type parameter T",
71 ));
72 }
73 for &d in desc.shape.iter() {
74 if d < 0 {
75 return Err(Error::InvalidProblem(
76 "baracuda-kernels::FlipPlan: shape dims must be non-negative",
77 ));
78 }
79 }
80 let supported = matches!(
81 T::KIND,
82 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
83 );
84 if !supported {
85 return Err(Error::Unsupported(
86 "baracuda-kernels::FlipPlan: today only `f32`, `f16`, `bf16`, `f64` \
87 are wired; other dtypes land in future fanout",
88 ));
89 }
90 let precision_guarantee = PrecisionGuarantee {
91 math_precision: MathPrecision::F32,
92 accumulator: ElementKind::F32,
93 bit_stable_on_same_hardware: true,
94 deterministic: true,
95 };
96 let sku = KernelSku {
97 category: OpCategory::ShapeLayout,
98 op: ShapeLayoutKind::Flip as u16,
99 element: T::KIND,
100 aux_element: None,
101 layout: None,
102 epilogue: None,
103 arch: ArchSku::Sm80,
104 backend: BackendKind::Bespoke,
105 precision_guarantee,
106 };
107 Ok(Self {
108 desc: *desc,
109 sku,
110 _marker: PhantomData,
111 })
112 }
113
114 pub fn can_implement(&self, args: &FlipArgs<'_, T, N>) -> Result<()> {
116 if args.x.shape != self.desc.shape {
117 return Err(Error::InvalidProblem(
118 "baracuda-kernels::FlipPlan: X shape mismatch with descriptor",
119 ));
120 }
121 if args.y.shape != self.desc.shape {
122 return Err(Error::InvalidProblem(
123 "baracuda-kernels::FlipPlan: Y shape mismatch with descriptor",
124 ));
125 }
126 if N > 8 {
127 return Err(Error::Unsupported(
128 "baracuda-kernels::FlipPlan: tensor rank > 8 not supported",
129 ));
130 }
131 let numel = args.y.numel();
132 let x_len = args.x.data.len() as i64;
133 let y_len = args.y.data.len() as i64;
134 if x_len < numel || y_len < numel {
135 return Err(Error::BufferTooSmall {
136 needed: numel as usize,
137 got: x_len.min(y_len) as usize,
138 });
139 }
140 Ok(())
141 }
142
143 #[inline]
145 pub fn workspace_size(&self) -> usize {
146 0
147 }
148 #[inline]
150 pub fn sku(&self) -> KernelSku {
151 self.sku
152 }
153 #[inline]
155 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
156 self.sku.precision_guarantee
157 }
158
159 pub fn run(
161 &self,
162 stream: &Stream,
163 _workspace: Workspace<'_>,
164 args: FlipArgs<'_, T, N>,
165 ) -> Result<()> {
166 self.can_implement(&args)?;
167 let numel = args.y.numel();
168 if numel == 0 {
169 return Ok(());
170 }
171 let x_ptr = args.x.data.as_raw().0 as *const c_void;
172 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
173 let stream_ptr = stream.as_raw() as *mut c_void;
174
175 let mut flip_axes_i32 = [0i32; 8];
177 for d in 0..N {
178 flip_axes_i32[d] = if self.desc.flip_axes[d] { 1 } else { 0 };
179 }
180
181 let shape = self.desc.shape;
182 let stride_x = args.x.stride;
183 let stride_y = args.y.stride;
184 let rank = N as i32;
185
186 let status = match T::KIND {
187 ElementKind::F32 => unsafe {
188 baracuda_kernels_sys::baracuda_kernels_flip_f32_run(
189 numel,
190 rank,
191 shape.as_ptr(),
192 flip_axes_i32.as_ptr(),
193 stride_x.as_ptr(),
194 stride_y.as_ptr(),
195 x_ptr,
196 y_ptr,
197 core::ptr::null_mut(),
198 0,
199 stream_ptr,
200 )
201 },
202 ElementKind::F16 => unsafe {
203 baracuda_kernels_sys::baracuda_kernels_flip_f16_run(
204 numel,
205 rank,
206 shape.as_ptr(),
207 flip_axes_i32.as_ptr(),
208 stride_x.as_ptr(),
209 stride_y.as_ptr(),
210 x_ptr,
211 y_ptr,
212 core::ptr::null_mut(),
213 0,
214 stream_ptr,
215 )
216 },
217 ElementKind::Bf16 => unsafe {
218 baracuda_kernels_sys::baracuda_kernels_flip_bf16_run(
219 numel,
220 rank,
221 shape.as_ptr(),
222 flip_axes_i32.as_ptr(),
223 stride_x.as_ptr(),
224 stride_y.as_ptr(),
225 x_ptr,
226 y_ptr,
227 core::ptr::null_mut(),
228 0,
229 stream_ptr,
230 )
231 },
232 ElementKind::F64 => unsafe {
233 baracuda_kernels_sys::baracuda_kernels_flip_f64_run(
234 numel,
235 rank,
236 shape.as_ptr(),
237 flip_axes_i32.as_ptr(),
238 stride_x.as_ptr(),
239 stride_y.as_ptr(),
240 x_ptr,
241 y_ptr,
242 core::ptr::null_mut(),
243 0,
244 stream_ptr,
245 )
246 },
247 _ => {
248 return Err(Error::Unsupported(
249 "baracuda-kernels::FlipPlan::run: only f32/f16/bf16/f64 wired today",
250 ));
251 }
252 };
253 map_status(status)
254 }
255}
256
257fn map_status(code: i32) -> Result<()> {
258 match code {
259 0 => Ok(()),
260 1 => Err(Error::MisalignedOperand),
261 2 => Err(Error::InvalidProblem(
262 "baracuda-kernels-sys reported invalid problem",
263 )),
264 3 => Err(Error::Unsupported(
265 "baracuda-kernels-sys reported unsupported configuration",
266 )),
267 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
268 n => Err(Error::CutlassInternal(n)),
269 }
270}