1use core::ffi::c_void;
20use core::marker::PhantomData;
21
22use baracuda_cutlass::{Error, Result};
23use baracuda_driver::Stream;
24use baracuda_kernels_types::{
25 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
26 TensorRef, Workspace,
27};
28
29use super::map_status;
30use super::sort::build_sku;
31
32#[derive(Copy, Clone, Debug)]
34pub struct ArgsortDescriptor {
35 pub batch: i32,
37 pub row_len: i32,
41 pub descending: bool,
43 pub element: ElementKind,
45}
46
47pub struct ArgsortArgs<'a, T: Element> {
49 pub input: TensorRef<'a, T, 2>,
51 pub indices: TensorMut<'a, i32, 2>,
53}
54
55pub struct ArgsortPlan<T: Element> {
74 desc: ArgsortDescriptor,
75 sku: KernelSku,
76 _marker: PhantomData<T>,
77}
78
79impl<T: Element> ArgsortPlan<T> {
80 pub fn select(
82 _stream: &Stream,
83 desc: &ArgsortDescriptor,
84 _pref: PlanPreference,
85 ) -> Result<Self> {
86 if desc.element != T::KIND {
89 return Err(Error::Unsupported(
90 "baracuda-kernels::ArgsortPlan: descriptor element != type parameter T",
91 ));
92 }
93 if desc.batch < 0 || desc.row_len < 0 {
94 return Err(Error::InvalidProblem(
95 "baracuda-kernels::ArgsortPlan: batch / row_len must be non-negative",
96 ));
97 }
98 if !matches!(
99 desc.element,
100 ElementKind::F32 | ElementKind::F64 | ElementKind::I32 | ElementKind::I64
101 ) {
102 return Err(Error::Unsupported(
103 "baracuda-kernels::ArgsortPlan: today only f32 / f64 / i32 / i64 wired",
104 ));
105 }
106 let sku = build_sku::<T>(SortKind::Argsort);
107 Ok(Self {
108 desc: *desc,
109 sku,
110 _marker: PhantomData,
111 })
112 }
113
114 pub fn can_implement(&self, args: &ArgsortArgs<'_, T>) -> Result<()> {
116 let expected = [self.desc.batch, self.desc.row_len];
117 if args.input.shape != expected {
118 return Err(Error::InvalidProblem(
119 "baracuda-kernels::ArgsortPlan: input shape != [batch, row_len]",
120 ));
121 }
122 if args.indices.shape != expected {
123 return Err(Error::InvalidProblem(
124 "baracuda-kernels::ArgsortPlan: indices shape != [batch, row_len]",
125 ));
126 }
127 Ok(())
128 }
129
130 #[inline]
137 pub fn workspace_size(&self) -> usize {
138 if self.desc.row_len <= 1024 {
139 return 0;
140 }
141 let batch = self.desc.batch;
142 let row_len = self.desc.row_len;
143 if batch == 0 || row_len == 0 {
144 return 0;
145 }
146 match T::KIND {
147 ElementKind::F32 => unsafe {
148 baracuda_kernels_sys::baracuda_kernels_argsort_f32_big_workspace_size(
149 batch, row_len,
150 )
151 },
152 ElementKind::F64 => unsafe {
153 baracuda_kernels_sys::baracuda_kernels_argsort_f64_big_workspace_size(
154 batch, row_len,
155 )
156 },
157 ElementKind::I32 => unsafe {
158 baracuda_kernels_sys::baracuda_kernels_argsort_i32_big_workspace_size(
159 batch, row_len,
160 )
161 },
162 ElementKind::I64 => unsafe {
163 baracuda_kernels_sys::baracuda_kernels_argsort_i64_big_workspace_size(
164 batch, row_len,
165 )
166 },
167 _ => 0,
168 }
169 }
170
171 #[inline]
173 pub fn sku(&self) -> KernelSku {
174 self.sku
175 }
176
177 #[inline]
179 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
180 self.sku.precision_guarantee
181 }
182
183 pub fn run(
185 &self,
186 stream: &Stream,
187 workspace: Workspace<'_>,
188 args: ArgsortArgs<'_, T>,
189 ) -> Result<()> {
190 self.can_implement(&args)?;
191 if self.desc.batch == 0 || self.desc.row_len == 0 {
192 return Ok(());
193 }
194 let in_ptr = args.input.data.as_raw().0 as *const c_void;
195 let idx_ptr = args.indices.data.as_raw().0 as *mut c_void;
196 let stream_ptr = stream.as_raw() as *mut c_void;
197 let desc_flag = if self.desc.descending { 1 } else { 0 };
198
199 let use_big = self.desc.row_len > 1024;
203 let (ws_ptr, ws_bytes) = if use_big {
204 let needed = self.workspace_size();
205 match workspace {
206 Workspace::None => {
207 if needed == 0 {
208 (core::ptr::null_mut::<c_void>(), 0usize)
209 } else {
210 return Err(Error::WorkspaceTooSmall { needed, got: 0 });
211 }
212 }
213 Workspace::Borrowed(slice) => {
214 let got = slice.len();
215 if got < needed {
216 return Err(Error::WorkspaceTooSmall { needed, got });
217 }
218 (slice.as_raw().0 as *mut c_void, got)
219 }
220 }
221 } else {
222 let _ = workspace;
224 (core::ptr::null_mut::<c_void>(), 0usize)
225 };
226
227 let status = match (T::KIND, use_big) {
228 (ElementKind::F32, false) => unsafe {
229 baracuda_kernels_sys::baracuda_kernels_argsort_f32_run(
230 self.desc.batch,
231 self.desc.row_len,
232 desc_flag,
233 in_ptr,
234 idx_ptr,
235 core::ptr::null_mut(),
236 0,
237 stream_ptr,
238 )
239 },
240 (ElementKind::F64, false) => unsafe {
241 baracuda_kernels_sys::baracuda_kernels_argsort_f64_run(
242 self.desc.batch,
243 self.desc.row_len,
244 desc_flag,
245 in_ptr,
246 idx_ptr,
247 core::ptr::null_mut(),
248 0,
249 stream_ptr,
250 )
251 },
252 (ElementKind::I32, false) => unsafe {
253 baracuda_kernels_sys::baracuda_kernels_argsort_i32_run(
254 self.desc.batch,
255 self.desc.row_len,
256 desc_flag,
257 in_ptr,
258 idx_ptr,
259 core::ptr::null_mut(),
260 0,
261 stream_ptr,
262 )
263 },
264 (ElementKind::I64, false) => unsafe {
265 baracuda_kernels_sys::baracuda_kernels_argsort_i64_run(
266 self.desc.batch,
267 self.desc.row_len,
268 desc_flag,
269 in_ptr,
270 idx_ptr,
271 core::ptr::null_mut(),
272 0,
273 stream_ptr,
274 )
275 },
276 (ElementKind::F32, true) => unsafe {
277 baracuda_kernels_sys::baracuda_kernels_argsort_f32_big_run(
278 self.desc.batch,
279 self.desc.row_len,
280 desc_flag,
281 in_ptr,
282 idx_ptr,
283 ws_ptr,
284 ws_bytes,
285 stream_ptr,
286 )
287 },
288 (ElementKind::F64, true) => unsafe {
289 baracuda_kernels_sys::baracuda_kernels_argsort_f64_big_run(
290 self.desc.batch,
291 self.desc.row_len,
292 desc_flag,
293 in_ptr,
294 idx_ptr,
295 ws_ptr,
296 ws_bytes,
297 stream_ptr,
298 )
299 },
300 (ElementKind::I32, true) => unsafe {
301 baracuda_kernels_sys::baracuda_kernels_argsort_i32_big_run(
302 self.desc.batch,
303 self.desc.row_len,
304 desc_flag,
305 in_ptr,
306 idx_ptr,
307 ws_ptr,
308 ws_bytes,
309 stream_ptr,
310 )
311 },
312 (ElementKind::I64, true) => unsafe {
313 baracuda_kernels_sys::baracuda_kernels_argsort_i64_big_run(
314 self.desc.batch,
315 self.desc.row_len,
316 desc_flag,
317 in_ptr,
318 idx_ptr,
319 ws_ptr,
320 ws_bytes,
321 stream_ptr,
322 )
323 },
324 _ => {
325 return Err(Error::Unsupported(
326 "baracuda-kernels::ArgsortPlan::run reached an unimplemented dtype \
327 — select() should have caught this",
328 ));
329 }
330 };
331 map_status(status)
332 }
333}