Skip to main content

baracuda_kernels/sort/
argsort.rs

1//! `argsort` plan — sorted indices only (no values output).
2//!
3//! `argsort(x, dim=last, descending)` returns just the i32 permutation
4//! that would sort `x` along the last dim. PyTorch `torch.argsort`.
5//!
6//! Trailblazer dtype coverage: `f32, f64, i32, i64`.
7//! Non-differentiable (set-valued indices) — no BW.
8//!
9//! # Multi-block radix (Phase 40 — Fuel ask Gap 6b)
10//!
11//! `row_len ≤ 1024` is served by the block-bitonic kernel (no workspace
12//! required). For `row_len > 1024` the plan transparently dispatches to
13//! a CUB-segmented-radix-sort kernel which DOES require a workspace blob
14//! (queried via [`ArgsortPlan::workspace_size`]). Callers that only ever
15//! pass `row_len ≤ 1024` will see a `workspace_size()` of `0` and may
//! continue to pass [`Workspace::None`]. The dispatch happens internally
//! at `run()` time. The kernel SKU, however, is fixed at `select()` time
//! from the element type alone, so it does not distinguish the two paths.
18
19use core::ffi::c_void;
20use core::marker::PhantomData;
21
22use baracuda_cutlass::{Error, Result};
23use baracuda_driver::Stream;
24use baracuda_kernels_types::{
25    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
26    TensorRef, Workspace,
27};
28
29use super::map_status;
30use super::sort::build_sku;
31
32/// Descriptor for an `argsort` op.
33#[derive(Copy, Clone, Debug)]
34pub struct ArgsortDescriptor {
35    /// Number of independent rows.
36    pub batch: i32,
37    /// Length of each row. `≤ 1024` uses the block-bitonic kernel;
38    /// `> 1024` uses the multi-block CUB radix kernel (workspace
39    /// required — call [`ArgsortPlan::workspace_size`]).
40    pub row_len: i32,
41    /// `true` = sort largest-first.
42    pub descending: bool,
43    /// Value element type (input).
44    pub element: ElementKind,
45}
46
47/// Args bundle for an `argsort` launch.
48pub struct ArgsortArgs<'a, T: Element> {
49    /// Input `[batch, row_len]`.
50    pub input: TensorRef<'a, T, 2>,
51    /// Sorted indices output `[batch, row_len]`.
52    pub indices: TensorMut<'a, i32, 2>,
53}
54
55/// `argsort` plan.
56///
57/// Returns only sorted indices along the last axis (PyTorch
58/// `torch.argsort`). No values output, no BW (indices are
59/// non-differentiable).
60///
61/// **When to use**: when only the permutation is needed (gather
62/// downstream tensors via [`GatherPlan`](crate::GatherPlan) using
63/// these indices). For sorted values + indices use
64/// [`SortPlan`](crate::SortPlan).
65///
66/// **Dtypes**: input `{f32, f64, i32, i64}`; output always `i32`.
67///
68/// **Shape limits**: rank-2 `[batch, row_len]`; `row_len ≤ 1024`.
69///
70/// **Workspace**: none.
71///
72/// **Precision guarantee**: deterministic, bit-stable.
73pub struct ArgsortPlan<T: Element> {
74    desc: ArgsortDescriptor,
75    sku: KernelSku,
76    _marker: PhantomData<T>,
77}
78
79impl<T: Element> ArgsortPlan<T> {
80    /// Pick a kernel for `desc`.
81    pub fn select(
82        _stream: &Stream,
83        desc: &ArgsortDescriptor,
84        _pref: PlanPreference,
85    ) -> Result<Self> {
86        // Local validator (broader than `validate_sort_desc`): allows
87        // `row_len > 1024` because the multi-block radix path covers it.
88        if desc.element != T::KIND {
89            return Err(Error::Unsupported(
90                "baracuda-kernels::ArgsortPlan: descriptor element != type parameter T",
91            ));
92        }
93        if desc.batch < 0 || desc.row_len < 0 {
94            return Err(Error::InvalidProblem(
95                "baracuda-kernels::ArgsortPlan: batch / row_len must be non-negative",
96            ));
97        }
98        if !matches!(
99            desc.element,
100            ElementKind::F32 | ElementKind::F64 | ElementKind::I32 | ElementKind::I64
101        ) {
102            return Err(Error::Unsupported(
103                "baracuda-kernels::ArgsortPlan: today only f32 / f64 / i32 / i64 wired",
104            ));
105        }
106        let sku = build_sku::<T>(SortKind::Argsort);
107        Ok(Self {
108            desc: *desc,
109            sku,
110            _marker: PhantomData,
111        })
112    }
113
114    /// Validate args.
115    pub fn can_implement(&self, args: &ArgsortArgs<'_, T>) -> Result<()> {
116        let expected = [self.desc.batch, self.desc.row_len];
117        if args.input.shape != expected {
118            return Err(Error::InvalidProblem(
119                "baracuda-kernels::ArgsortPlan: input shape != [batch, row_len]",
120            ));
121        }
122        if args.indices.shape != expected {
123            return Err(Error::InvalidProblem(
124                "baracuda-kernels::ArgsortPlan: indices shape != [batch, row_len]",
125            ));
126        }
127        Ok(())
128    }
129
130    /// Workspace size in bytes.
131    ///
132    /// `0` when `row_len ≤ 1024` (block-bitonic, in-SMEM). Non-zero
133    /// when `row_len > 1024` (multi-block radix path needs scratch for
134    /// CUB's `DeviceSegmentedRadixSort` plus keys/indices/offset
135    /// buffers). The exact bytes depend on `(batch, row_len, T)`.
136    #[inline]
137    pub fn workspace_size(&self) -> usize {
138        if self.desc.row_len <= 1024 {
139            return 0;
140        }
141        let batch = self.desc.batch;
142        let row_len = self.desc.row_len;
143        if batch == 0 || row_len == 0 {
144            return 0;
145        }
146        match T::KIND {
147            ElementKind::F32 => unsafe {
148                baracuda_kernels_sys::baracuda_kernels_argsort_f32_big_workspace_size(
149                    batch, row_len,
150                )
151            },
152            ElementKind::F64 => unsafe {
153                baracuda_kernels_sys::baracuda_kernels_argsort_f64_big_workspace_size(
154                    batch, row_len,
155                )
156            },
157            ElementKind::I32 => unsafe {
158                baracuda_kernels_sys::baracuda_kernels_argsort_i32_big_workspace_size(
159                    batch, row_len,
160                )
161            },
162            ElementKind::I64 => unsafe {
163                baracuda_kernels_sys::baracuda_kernels_argsort_i64_big_workspace_size(
164                    batch, row_len,
165                )
166            },
167            _ => 0,
168        }
169    }
170
171    /// Identity of the kernel this plan picked.
172    #[inline]
173    pub fn sku(&self) -> KernelSku {
174        self.sku
175    }
176
177    /// Numerical guarantees for this plan's kernel.
178    #[inline]
179    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
180        self.sku.precision_guarantee
181    }
182
183    /// Launch.
184    pub fn run(
185        &self,
186        stream: &Stream,
187        workspace: Workspace<'_>,
188        args: ArgsortArgs<'_, T>,
189    ) -> Result<()> {
190        self.can_implement(&args)?;
191        if self.desc.batch == 0 || self.desc.row_len == 0 {
192            return Ok(());
193        }
194        let in_ptr = args.input.data.as_raw().0 as *const c_void;
195        let idx_ptr = args.indices.data.as_raw().0 as *mut c_void;
196        let stream_ptr = stream.as_raw() as *mut c_void;
197        let desc_flag = if self.desc.descending { 1 } else { 0 };
198
199        // Phase 40 dispatch: `row_len > 1024` → multi-block radix path
200        // (requires non-empty workspace); otherwise block-bitonic
201        // (workspace ignored).
202        let use_big = self.desc.row_len > 1024;
203        let (ws_ptr, ws_bytes) = if use_big {
204            let needed = self.workspace_size();
205            match workspace {
206                Workspace::None => {
207                    if needed == 0 {
208                        (core::ptr::null_mut::<c_void>(), 0usize)
209                    } else {
210                        return Err(Error::WorkspaceTooSmall { needed, got: 0 });
211                    }
212                }
213                Workspace::Borrowed(slice) => {
214                    let got = slice.len();
215                    if got < needed {
216                        return Err(Error::WorkspaceTooSmall { needed, got });
217                    }
218                    (slice.as_raw().0 as *mut c_void, got)
219                }
220            }
221        } else {
222            // Bitonic path ignores workspace; pass null/0 for safety.
223            let _ = workspace;
224            (core::ptr::null_mut::<c_void>(), 0usize)
225        };
226
227        let status = match (T::KIND, use_big) {
228            (ElementKind::F32, false) => unsafe {
229                baracuda_kernels_sys::baracuda_kernels_argsort_f32_run(
230                    self.desc.batch,
231                    self.desc.row_len,
232                    desc_flag,
233                    in_ptr,
234                    idx_ptr,
235                    core::ptr::null_mut(),
236                    0,
237                    stream_ptr,
238                )
239            },
240            (ElementKind::F64, false) => unsafe {
241                baracuda_kernels_sys::baracuda_kernels_argsort_f64_run(
242                    self.desc.batch,
243                    self.desc.row_len,
244                    desc_flag,
245                    in_ptr,
246                    idx_ptr,
247                    core::ptr::null_mut(),
248                    0,
249                    stream_ptr,
250                )
251            },
252            (ElementKind::I32, false) => unsafe {
253                baracuda_kernels_sys::baracuda_kernels_argsort_i32_run(
254                    self.desc.batch,
255                    self.desc.row_len,
256                    desc_flag,
257                    in_ptr,
258                    idx_ptr,
259                    core::ptr::null_mut(),
260                    0,
261                    stream_ptr,
262                )
263            },
264            (ElementKind::I64, false) => unsafe {
265                baracuda_kernels_sys::baracuda_kernels_argsort_i64_run(
266                    self.desc.batch,
267                    self.desc.row_len,
268                    desc_flag,
269                    in_ptr,
270                    idx_ptr,
271                    core::ptr::null_mut(),
272                    0,
273                    stream_ptr,
274                )
275            },
276            (ElementKind::F32, true) => unsafe {
277                baracuda_kernels_sys::baracuda_kernels_argsort_f32_big_run(
278                    self.desc.batch,
279                    self.desc.row_len,
280                    desc_flag,
281                    in_ptr,
282                    idx_ptr,
283                    ws_ptr,
284                    ws_bytes,
285                    stream_ptr,
286                )
287            },
288            (ElementKind::F64, true) => unsafe {
289                baracuda_kernels_sys::baracuda_kernels_argsort_f64_big_run(
290                    self.desc.batch,
291                    self.desc.row_len,
292                    desc_flag,
293                    in_ptr,
294                    idx_ptr,
295                    ws_ptr,
296                    ws_bytes,
297                    stream_ptr,
298                )
299            },
300            (ElementKind::I32, true) => unsafe {
301                baracuda_kernels_sys::baracuda_kernels_argsort_i32_big_run(
302                    self.desc.batch,
303                    self.desc.row_len,
304                    desc_flag,
305                    in_ptr,
306                    idx_ptr,
307                    ws_ptr,
308                    ws_bytes,
309                    stream_ptr,
310                )
311            },
312            (ElementKind::I64, true) => unsafe {
313                baracuda_kernels_sys::baracuda_kernels_argsort_i64_big_run(
314                    self.desc.batch,
315                    self.desc.row_len,
316                    desc_flag,
317                    in_ptr,
318                    idx_ptr,
319                    ws_ptr,
320                    ws_bytes,
321                    stream_ptr,
322                )
323            },
324            _ => {
325                return Err(Error::Unsupported(
326                    "baracuda-kernels::ArgsortPlan::run reached an unimplemented dtype \
327                     — select() should have caught this",
328                ));
329            }
330        };
331        map_status(status)
332    }
333}