Skip to main content

baracuda_kernels/sort/
mod.rs

//! Sorting / order-statistics op family — Category O.
//!
//! Phase 9 of the baracuda-kernels comprehensive plan. Ships the
//! block-bitonic trailblazer family:
//!
//! - [`SortPlan`] / [`SortBackwardPlan`] / [`ArgsortPlan`] /
//!   [`MsortPlan`] / [`MsortBackwardPlan`] — block-bitonic sort, one
//!   CUDA block per row. Trailblazer cap: `row_len ≤ 1024`. Larger
//!   rows are reserved for a future tile-radix follow-up; until then the
//!   Plan returns `Unsupported` for `row_len > 1024`.
//! - [`TopkPlan`] / [`TopkBackwardPlan`] / [`KthvaluePlan`] /
//!   [`KthvalueBackwardPlan`] — block-bitonic select; trailblazer
//!   cap: `k ≤ 64` (LLM-inference range).
//! - [`UniquePlan`] / [`UniqueConsecutivePlan`] — set-valued; no backward
//!   (BW) pass. `unique` chains sort + consecutive-dedup at the plan layer.
//! - [`HistogramPlan`] / [`HistogramddPlan`] / [`BincountPlan`] —
//!   atomic-bin accumulation; forward (FW) only. `histogramdd` returns
//!   `Unsupported` for `ndim > 1` in the trailblazer.
//! - [`SearchsortedPlan`] — per-query binary search; FW only.
//!
//! **Saved-indices contract for sort / msort / topk / kthvalue BW.**
//! The FW pass emits both the output values AND their source indices in a
//! single launch (FW Args carry `values` and `indices` as required outputs).
//! BW Args receive those **saved indices** verbatim, so nothing is
//! recomputed at BW time. The Plan's selector pins the indices dtype to
//! `i32` across every kernel SKU in this family.
27
28pub mod argsort;
29pub mod bincount;
30pub mod histogram;
31pub mod histogramdd;
32pub mod kthvalue;
33pub mod kthvalue_backward;
34pub mod msort;
35pub mod searchsorted;
36pub mod sort;
37pub mod sort_backward;
38pub mod topk;
39pub mod topk_backward;
40pub mod unique;
41pub mod unique_consecutive;
42
43pub use argsort::{ArgsortArgs, ArgsortDescriptor, ArgsortPlan};
44pub use bincount::{BincountArgs, BincountDescriptor, BincountPlan};
45pub use histogram::{HistogramArgs, HistogramDescriptor, HistogramPlan};
46pub use histogramdd::{HistogramddArgs, HistogramddDescriptor, HistogramddPlan};
47pub use kthvalue::{KthvalueArgs, KthvalueDescriptor, KthvaluePlan};
48pub use kthvalue_backward::{
49    KthvalueBackwardArgs, KthvalueBackwardDescriptor, KthvalueBackwardPlan,
50};
51pub use msort::{MsortArgs, MsortBackwardArgs, MsortBackwardDescriptor, MsortBackwardPlan,
52    MsortDescriptor, MsortPlan};
53pub use searchsorted::{SearchsortedArgs, SearchsortedDescriptor, SearchsortedPlan};
54pub use sort::{SortArgs, SortDescriptor, SortPlan};
55pub use sort_backward::{SortBackwardArgs, SortBackwardDescriptor, SortBackwardPlan};
56pub use topk::{TopkArgs, TopkDescriptor, TopkPlan};
57pub use topk_backward::{TopkBackwardArgs, TopkBackwardDescriptor, TopkBackwardPlan};
58pub use unique::{UniqueArgs, UniqueDescriptor, UniquePlan};
59pub use unique_consecutive::{
60    UniqueConsecutiveArgs, UniqueConsecutiveDescriptor, UniqueConsecutivePlan,
61};
62
63use baracuda_cutlass::{Error, Result};
64
/// Upper bound on `row_len` for the block-bitonic sort trailblazer
/// (one CUDA block per row); longer rows are rejected as unsupported.
///
/// Keep this in sync with `MAX_ROW` in `baracuda_sort.cuh`.
pub const SORT_MAX_ROW: i32 = 1 << 10;
/// Upper bound on `k` for the block-bitonic top-k trailblazer.
///
/// Keep this in sync with `MAX_K` in `baracuda_topk.cuh`.
pub const TOPK_MAX_K: i32 = 1 << 6;
71
72/// Shared status-code mapper for the sort family.
73pub(crate) fn map_status(code: i32) -> Result<()> {
74    match code {
75        0 => Ok(()),
76        1 => Err(Error::MisalignedOperand),
77        2 => Err(Error::InvalidProblem(
78            "baracuda-kernels-sys::sort reported invalid problem",
79        )),
80        3 => Err(Error::Unsupported(
81            "baracuda-kernels-sys::sort reported unsupported configuration \
82             (e.g. row_len > 1024 or k > 64)",
83        )),
84        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
85        n => Err(Error::CutlassInternal(n)),
86    }
87}