1use core::ffi::c_void;
10use core::marker::PhantomData;
11
12use baracuda_cutlass::{Error, Result};
13use baracuda_driver::Stream;
14use baracuda_kernels_types::{
15 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
16 TensorRef, Workspace,
17};
18
19use super::map_status;
20use super::sort::{build_sku, validate_sort_args_2, validate_sort_desc};
21
22#[derive(Copy, Clone, Debug)]
24pub struct MsortDescriptor {
25 pub batch: i32,
27 pub row_len: i32,
29 pub descending: bool,
31 pub element: ElementKind,
33}
34
35pub struct MsortArgs<'a, T: Element> {
37 pub input: TensorRef<'a, T, 2>,
39 pub values: TensorMut<'a, T, 2>,
41 pub indices: TensorMut<'a, i32, 2>,
43}
44
45pub struct MsortPlan<T: Element> {
64 desc: MsortDescriptor,
65 sku: KernelSku,
66 _marker: PhantomData<T>,
67}
68
69impl<T: Element> MsortPlan<T> {
70 pub fn select(
72 _stream: &Stream,
73 desc: &MsortDescriptor,
74 _pref: PlanPreference,
75 ) -> Result<Self> {
76 validate_sort_desc(desc.batch, desc.row_len, desc.element, T::KIND, "MsortPlan")?;
77 let sku = build_sku::<T>(SortKind::Msort);
78 Ok(Self {
79 desc: *desc,
80 sku,
81 _marker: PhantomData,
82 })
83 }
84
85 pub fn can_implement(&self, args: &MsortArgs<'_, T>) -> Result<()> {
87 validate_sort_args_2(
88 self.desc.batch,
89 self.desc.row_len,
90 args.input.shape,
91 args.values.shape,
92 args.indices.shape,
93 "MsortPlan",
94 )
95 }
96
97 #[inline]
99 pub fn workspace_size(&self) -> usize {
100 0
101 }
102
103 #[inline]
105 pub fn sku(&self) -> KernelSku {
106 self.sku
107 }
108
109 #[inline]
111 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
112 self.sku.precision_guarantee
113 }
114
115 pub fn run(
117 &self,
118 stream: &Stream,
119 _workspace: Workspace<'_>,
120 args: MsortArgs<'_, T>,
121 ) -> Result<()> {
122 self.can_implement(&args)?;
123 if self.desc.batch == 0 || self.desc.row_len == 0 {
124 return Ok(());
125 }
126 let in_ptr = args.input.data.as_raw().0 as *const c_void;
127 let vals_ptr = args.values.data.as_raw().0 as *mut c_void;
128 let idx_ptr = args.indices.data.as_raw().0 as *mut c_void;
129 let stream_ptr = stream.as_raw() as *mut c_void;
130 let desc_flag = if self.desc.descending { 1 } else { 0 };
131
132 let status = match T::KIND {
133 ElementKind::F32 => unsafe {
134 baracuda_kernels_sys::baracuda_kernels_msort_f32_run(
135 self.desc.batch,
136 self.desc.row_len,
137 desc_flag,
138 in_ptr,
139 vals_ptr,
140 idx_ptr,
141 core::ptr::null_mut(),
142 0,
143 stream_ptr,
144 )
145 },
146 ElementKind::F64 => unsafe {
147 baracuda_kernels_sys::baracuda_kernels_msort_f64_run(
148 self.desc.batch,
149 self.desc.row_len,
150 desc_flag,
151 in_ptr,
152 vals_ptr,
153 idx_ptr,
154 core::ptr::null_mut(),
155 0,
156 stream_ptr,
157 )
158 },
159 ElementKind::I32 => unsafe {
160 baracuda_kernels_sys::baracuda_kernels_msort_i32_run(
161 self.desc.batch,
162 self.desc.row_len,
163 desc_flag,
164 in_ptr,
165 vals_ptr,
166 idx_ptr,
167 core::ptr::null_mut(),
168 0,
169 stream_ptr,
170 )
171 },
172 ElementKind::I64 => unsafe {
173 baracuda_kernels_sys::baracuda_kernels_msort_i64_run(
174 self.desc.batch,
175 self.desc.row_len,
176 desc_flag,
177 in_ptr,
178 vals_ptr,
179 idx_ptr,
180 core::ptr::null_mut(),
181 0,
182 stream_ptr,
183 )
184 },
185 _ => {
186 return Err(Error::Unsupported(
187 "baracuda-kernels::MsortPlan::run reached an unimplemented dtype",
188 ));
189 }
190 };
191 map_status(status)
192 }
193}
194
195#[derive(Copy, Clone, Debug)]
199pub struct MsortBackwardDescriptor {
200 pub batch: i32,
202 pub row_len: i32,
204 pub element: ElementKind,
206}
207
208pub struct MsortBackwardArgs<'a, T: Element> {
210 pub dy: TensorRef<'a, T, 2>,
212 pub indices: TensorRef<'a, i32, 2>,
214 pub dx: TensorMut<'a, T, 2>,
216}
217
218pub struct MsortBackwardPlan<T: Element> {
234 desc: MsortBackwardDescriptor,
235 sku: KernelSku,
236 _marker: PhantomData<T>,
237}
238
239impl<T: Element> MsortBackwardPlan<T> {
240 pub fn select(
242 _stream: &Stream,
243 desc: &MsortBackwardDescriptor,
244 _pref: PlanPreference,
245 ) -> Result<Self> {
246 validate_sort_desc(
247 desc.batch,
248 desc.row_len,
249 desc.element,
250 T::KIND,
251 "MsortBackwardPlan",
252 )?;
253 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
254 return Err(Error::Unsupported(
255 "baracuda-kernels::MsortBackwardPlan: today only f32 / f64 grads supported",
256 ));
257 }
258 let sku = build_sku::<T>(SortKind::MsortBackward);
259 Ok(Self {
260 desc: *desc,
261 sku,
262 _marker: PhantomData,
263 })
264 }
265
266 pub fn can_implement(&self, args: &MsortBackwardArgs<'_, T>) -> Result<()> {
268 let expected = [self.desc.batch, self.desc.row_len];
269 if args.dy.shape != expected
270 || args.indices.shape != expected
271 || args.dx.shape != expected
272 {
273 return Err(Error::InvalidProblem(
274 "baracuda-kernels::MsortBackwardPlan: tensor shapes != [batch, row_len]",
275 ));
276 }
277 Ok(())
278 }
279
280 #[inline]
282 pub fn workspace_size(&self) -> usize {
283 0
284 }
285
286 #[inline]
288 pub fn sku(&self) -> KernelSku {
289 self.sku
290 }
291
292 #[inline]
294 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
295 self.sku.precision_guarantee
296 }
297
298 pub fn run(
300 &self,
301 stream: &Stream,
302 _workspace: Workspace<'_>,
303 args: MsortBackwardArgs<'_, T>,
304 ) -> Result<()> {
305 self.can_implement(&args)?;
306 if self.desc.batch == 0 || self.desc.row_len == 0 {
307 return Ok(());
308 }
309 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
310 let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
311 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
312 let stream_ptr = stream.as_raw() as *mut c_void;
313
314 let status = match T::KIND {
315 ElementKind::F32 => unsafe {
316 baracuda_kernels_sys::baracuda_kernels_msort_backward_f32_run(
317 self.desc.batch,
318 self.desc.row_len,
319 dy_ptr,
320 idx_ptr,
321 dx_ptr,
322 core::ptr::null_mut(),
323 0,
324 stream_ptr,
325 )
326 },
327 ElementKind::F64 => unsafe {
328 baracuda_kernels_sys::baracuda_kernels_msort_backward_f64_run(
329 self.desc.batch,
330 self.desc.row_len,
331 dy_ptr,
332 idx_ptr,
333 dx_ptr,
334 core::ptr::null_mut(),
335 0,
336 stream_ptr,
337 )
338 },
339 _ => {
340 return Err(Error::Unsupported(
341 "baracuda-kernels::MsortBackwardPlan::run reached an unimplemented dtype",
342 ));
343 }
344 };
345 map_status(status)
346 }
347}