Skip to main content

baracuda_kernels/sort/
sort.rs

1//! `sort` plan — Category O trailblazer.
2//!
3//! `sort(x, dim=last, descending)` returns sorted values AND sorted
4//! indices along the last dimension of `x`. PyTorch `torch.sort`.
5//!
6//! Trailblazer dtype coverage: `f32, f64, i32, i64`.
7//!
8//! **Saved-indices contract.** The FW emits both `values` AND
9//! `indices` (i32) in one launch. BW reads the saved indices to route
10//! the upstream grad back to the original positions. Callers must
11//! retain the indices output for the BW pass.
12//!
13//! Trailblazer cap: `row_len ≤ 1024` (one CUDA block per row, bitonic
14//! network in shared memory). Larger rows return
15//! `Error::Unsupported` — a tile-radix follow-up is reserved.
16//!
17//! BW: see [`crate::sort::SortBackwardPlan`].
18
19use core::ffi::c_void;
20use core::marker::PhantomData;
21
22use baracuda_cutlass::{Error, Result};
23use baracuda_driver::Stream;
24use baracuda_kernels_types::{
25    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
26    PlanPreference, PrecisionGuarantee, SortKind, TensorMut, TensorRef, Workspace,
27};
28
29use super::{map_status, SORT_MAX_ROW};
30
31/// Descriptor for a `sort` op.
32#[derive(Copy, Clone, Debug)]
33pub struct SortDescriptor {
34    /// Number of independent rows to sort.
35    pub batch: i32,
36    /// Length of each row. Trailblazer cap: `≤ 1024`.
37    pub row_len: i32,
38    /// `true` = sort largest-first; `false` = sort smallest-first.
39    pub descending: bool,
40    /// Value element type.
41    pub element: ElementKind,
42}
43
44/// Args bundle for a `sort` launch.
45pub struct SortArgs<'a, T: Element> {
46    /// Input `[batch, row_len]`.
47    pub input: TensorRef<'a, T, 2>,
48    /// Sorted values output `[batch, row_len]`.
49    pub values: TensorMut<'a, T, 2>,
50    /// Sorted indices output `[batch, row_len]` — saved for BW.
51    pub indices: TensorMut<'a, i32, 2>,
52}
53
54/// `sort` plan.
55///
56/// `sort(x, dim=last, descending)` — returns sorted values AND
57/// sorted indices along the last axis (PyTorch `torch.sort`).
58///
59/// **When to use**: forward row-wise sort. Pair with
60/// [`SortBackwardPlan`](crate::SortBackwardPlan) for autograd; for
61/// indices-only output use [`ArgsortPlan`](crate::ArgsortPlan).
62///
63/// **Dtypes**: `{f32, f64, i32, i64}`; indices always `i32`.
64///
65/// **Shape limits**: rank-2 `[batch, row_len]`; `row_len ≤ 1024`
66/// (one CUDA block per row, bitonic network in shared memory).
67/// Larger rows return `Unsupported` — tile-radix follow-up reserved.
68///
69/// **Workspace**: none.
70///
71/// **Precision guarantee**: deterministic, bit-stable. Block-bitonic
72/// is a fixed comparator network — no atomics, no reductions.
73///
74/// **Saved-indices contract**: FW emits both `values` and `indices`
75/// in a single launch. BW reads the saved indices verbatim; callers
76/// must retain `indices` for autograd.
77pub struct SortPlan<T: Element> {
78    desc: SortDescriptor,
79    sku: KernelSku,
80    _marker: PhantomData<T>,
81}
82
83impl<T: Element> SortPlan<T> {
84    /// Pick a kernel for `desc`.
85    pub fn select(
86        _stream: &Stream,
87        desc: &SortDescriptor,
88        _pref: PlanPreference,
89    ) -> Result<Self> {
90        validate_sort_desc(desc.batch, desc.row_len, desc.element, T::KIND, "SortPlan")?;
91        let sku = build_sku::<T>(SortKind::Sort);
92        Ok(Self {
93            desc: *desc,
94            sku,
95            _marker: PhantomData,
96        })
97    }
98
99    /// Validate args.
100    pub fn can_implement(&self, args: &SortArgs<'_, T>) -> Result<()> {
101        validate_sort_args_2(
102            self.desc.batch,
103            self.desc.row_len,
104            args.input.shape,
105            args.values.shape,
106            args.indices.shape,
107            "SortPlan",
108        )
109    }
110
111    /// Workspace size in bytes.
112    #[inline]
113    pub fn workspace_size(&self) -> usize {
114        0
115    }
116
117    /// Identity of the kernel this plan picked.
118    #[inline]
119    pub fn sku(&self) -> KernelSku {
120        self.sku
121    }
122
123    /// Numerical guarantees for this plan's kernel.
124    #[inline]
125    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
126        self.sku.precision_guarantee
127    }
128
129    /// Launch.
130    pub fn run(
131        &self,
132        stream: &Stream,
133        _workspace: Workspace<'_>,
134        args: SortArgs<'_, T>,
135    ) -> Result<()> {
136        self.can_implement(&args)?;
137        if self.desc.batch == 0 || self.desc.row_len == 0 {
138            return Ok(());
139        }
140        let in_ptr = args.input.data.as_raw().0 as *const c_void;
141        let vals_ptr = args.values.data.as_raw().0 as *mut c_void;
142        let idx_ptr = args.indices.data.as_raw().0 as *mut c_void;
143        let stream_ptr = stream.as_raw() as *mut c_void;
144        let desc_flag = if self.desc.descending { 1 } else { 0 };
145
146        let status = match T::KIND {
147            ElementKind::F32 => unsafe {
148                baracuda_kernels_sys::baracuda_kernels_sort_f32_run(
149                    self.desc.batch,
150                    self.desc.row_len,
151                    desc_flag,
152                    in_ptr,
153                    vals_ptr,
154                    idx_ptr,
155                    core::ptr::null_mut(),
156                    0,
157                    stream_ptr,
158                )
159            },
160            ElementKind::F64 => unsafe {
161                baracuda_kernels_sys::baracuda_kernels_sort_f64_run(
162                    self.desc.batch,
163                    self.desc.row_len,
164                    desc_flag,
165                    in_ptr,
166                    vals_ptr,
167                    idx_ptr,
168                    core::ptr::null_mut(),
169                    0,
170                    stream_ptr,
171                )
172            },
173            ElementKind::I32 => unsafe {
174                baracuda_kernels_sys::baracuda_kernels_sort_i32_run(
175                    self.desc.batch,
176                    self.desc.row_len,
177                    desc_flag,
178                    in_ptr,
179                    vals_ptr,
180                    idx_ptr,
181                    core::ptr::null_mut(),
182                    0,
183                    stream_ptr,
184                )
185            },
186            ElementKind::I64 => unsafe {
187                baracuda_kernels_sys::baracuda_kernels_sort_i64_run(
188                    self.desc.batch,
189                    self.desc.row_len,
190                    desc_flag,
191                    in_ptr,
192                    vals_ptr,
193                    idx_ptr,
194                    core::ptr::null_mut(),
195                    0,
196                    stream_ptr,
197                )
198            },
199            _ => {
200                return Err(Error::Unsupported(
201                    "baracuda-kernels::SortPlan::run reached an unimplemented dtype \
202                     — select() should have caught this",
203                ));
204            }
205        };
206        map_status(status)
207    }
208}
209
210// ---- Shared descriptor / args / SKU helpers (used by argsort / msort too) ----
211
212/// Validate descriptor fields shared across sort / argsort / msort / topk.
213pub(crate) fn validate_sort_desc(
214    batch: i32,
215    row_len: i32,
216    descriptor_element: ElementKind,
217    expected_element: ElementKind,
218    _plan_name: &'static str,
219) -> Result<()> {
220    if descriptor_element != expected_element {
221        return Err(Error::Unsupported(
222            "baracuda-kernels::sort: descriptor element != type parameter T",
223        ));
224    }
225    if batch < 0 || row_len < 0 {
226        return Err(Error::InvalidProblem(
227            "baracuda-kernels::sort: batch / row_len must be non-negative",
228        ));
229    }
230    if row_len > SORT_MAX_ROW {
231        return Err(Error::Unsupported(
232            "baracuda-kernels::sort: row_len > 1024 not supported in the \
233             block-bitonic trailblazer (tile-radix follow-up reserved)",
234        ));
235    }
236    if !matches!(
237        descriptor_element,
238        ElementKind::F32 | ElementKind::F64 | ElementKind::I32 | ElementKind::I64
239    ) {
240        return Err(Error::Unsupported(
241            "baracuda-kernels::sort: today only f32 / f64 / i32 / i64 wired",
242        ));
243    }
244    Ok(())
245}
246
247/// Validate value+indices shapes for sort / msort args.
248pub(crate) fn validate_sort_args_2(
249    batch: i32,
250    row_len: i32,
251    in_shape: [i32; 2],
252    vals_shape: [i32; 2],
253    idx_shape: [i32; 2],
254    _plan_name: &'static str,
255) -> Result<()> {
256    let expected = [batch, row_len];
257    if in_shape != expected {
258        return Err(Error::InvalidProblem(
259            "baracuda-kernels::sort: input shape != [batch, row_len]",
260        ));
261    }
262    if vals_shape != expected {
263        return Err(Error::InvalidProblem(
264            "baracuda-kernels::sort: values shape != [batch, row_len]",
265        ));
266    }
267    if idx_shape != expected {
268        return Err(Error::InvalidProblem(
269            "baracuda-kernels::sort: indices shape != [batch, row_len]",
270        ));
271    }
272    Ok(())
273}
274
275/// Construct a `KernelSku` for a sort-family plan.
276pub(crate) fn build_sku<T: Element>(op: SortKind) -> KernelSku {
277    let precision_guarantee = PrecisionGuarantee {
278        math_precision: if T::KIND == ElementKind::F64 {
279            MathPrecision::F64
280        } else {
281            MathPrecision::F32
282        },
283        accumulator: T::KIND,
284        // Block-bitonic is fully deterministic: per-row work is
285        // serialized within one block (no inter-block reduction), and
286        // ties are broken (for msort) by original index. Histogram /
287        // bincount / unique-consecutive use atomic counters so their
288        // output order is not deterministic — those plans re-tag this
289        // field through their own builder.
290        bit_stable_on_same_hardware: true,
291        deterministic: true,
292    };
293    KernelSku {
294        category: OpCategory::Sorting,
295        op: op as u16,
296        element: T::KIND,
297        aux_element: Some(ElementKind::I32),
298        layout: None,
299        epilogue: None,
300        arch: ArchSku::Sm80,
301        backend: BackendKind::Bespoke,
302        precision_guarantee,
303    }
304}