Skip to main content

baracuda_kernels/sort/
msort.rs

//! `msort` (stable sort) forward plan + backward (BW) plan.
//!
//! Same as [`crate::sort::SortPlan`] but with a stability guarantee —
//! equal keys preserve input order via a tie-break on the original index.
//! PyTorch equivalent: `torch.msort`.
//!
//! Trailblazer dtype coverage: forward (FW) `f32, f64, i32, i64`;
//! backward (BW) `f32, f64`.
8
9use core::ffi::c_void;
10use core::marker::PhantomData;
11
12use baracuda_cutlass::{Error, Result};
13use baracuda_driver::Stream;
14use baracuda_kernels_types::{
15    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
16    TensorRef, Workspace,
17};
18
19use super::map_status;
20use super::sort::{build_sku, validate_sort_args_2, validate_sort_desc};
21
22/// Descriptor for an `msort` op.
23#[derive(Copy, Clone, Debug)]
24pub struct MsortDescriptor {
25    /// Number of independent rows.
26    pub batch: i32,
27    /// Length of each row. Trailblazer cap: `≤ 1024`.
28    pub row_len: i32,
29    /// `true` = sort largest-first.
30    pub descending: bool,
31    /// Value element type.
32    pub element: ElementKind,
33}
34
35/// Args bundle for an `msort` launch.
36pub struct MsortArgs<'a, T: Element> {
37    /// Input `[batch, row_len]`.
38    pub input: TensorRef<'a, T, 2>,
39    /// Sorted values `[batch, row_len]`.
40    pub values: TensorMut<'a, T, 2>,
41    /// Sorted indices `[batch, row_len]` — saved for BW.
42    pub indices: TensorMut<'a, i32, 2>,
43}
44
45/// `msort` plan.
46///
47/// **Stable** row-wise sort: equal keys preserve input order via
48/// tie-break on the original index (PyTorch `torch.msort`).
49/// Functionally identical to [`SortPlan`](crate::SortPlan) plus the
50/// stability guarantee.
51///
52/// **When to use**: when stable order matters (e.g. unique-by-key
53/// pipelines). Pair with
54/// [`MsortBackwardPlan`](crate::MsortBackwardPlan).
55///
56/// **Dtypes**: FW `{f32, f64, i32, i64}`.
57///
58/// **Shape limits**: rank-2 `[batch, row_len]`; `row_len ≤ 1024`.
59///
60/// **Workspace**: none.
61///
62/// **Precision guarantee**: deterministic, bit-stable.
63pub struct MsortPlan<T: Element> {
64    desc: MsortDescriptor,
65    sku: KernelSku,
66    _marker: PhantomData<T>,
67}
68
69impl<T: Element> MsortPlan<T> {
70    /// Pick a kernel for `desc`.
71    pub fn select(
72        _stream: &Stream,
73        desc: &MsortDescriptor,
74        _pref: PlanPreference,
75    ) -> Result<Self> {
76        validate_sort_desc(desc.batch, desc.row_len, desc.element, T::KIND, "MsortPlan")?;
77        let sku = build_sku::<T>(SortKind::Msort);
78        Ok(Self {
79            desc: *desc,
80            sku,
81            _marker: PhantomData,
82        })
83    }
84
85    /// Validate args.
86    pub fn can_implement(&self, args: &MsortArgs<'_, T>) -> Result<()> {
87        validate_sort_args_2(
88            self.desc.batch,
89            self.desc.row_len,
90            args.input.shape,
91            args.values.shape,
92            args.indices.shape,
93            "MsortPlan",
94        )
95    }
96
97    /// Workspace size in bytes.
98    #[inline]
99    pub fn workspace_size(&self) -> usize {
100        0
101    }
102
103    /// Identity of the kernel this plan picked.
104    #[inline]
105    pub fn sku(&self) -> KernelSku {
106        self.sku
107    }
108
109    /// Numerical guarantees for this plan's kernel.
110    #[inline]
111    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
112        self.sku.precision_guarantee
113    }
114
115    /// Launch.
116    pub fn run(
117        &self,
118        stream: &Stream,
119        _workspace: Workspace<'_>,
120        args: MsortArgs<'_, T>,
121    ) -> Result<()> {
122        self.can_implement(&args)?;
123        if self.desc.batch == 0 || self.desc.row_len == 0 {
124            return Ok(());
125        }
126        let in_ptr = args.input.data.as_raw().0 as *const c_void;
127        let vals_ptr = args.values.data.as_raw().0 as *mut c_void;
128        let idx_ptr = args.indices.data.as_raw().0 as *mut c_void;
129        let stream_ptr = stream.as_raw() as *mut c_void;
130        let desc_flag = if self.desc.descending { 1 } else { 0 };
131
132        let status = match T::KIND {
133            ElementKind::F32 => unsafe {
134                baracuda_kernels_sys::baracuda_kernels_msort_f32_run(
135                    self.desc.batch,
136                    self.desc.row_len,
137                    desc_flag,
138                    in_ptr,
139                    vals_ptr,
140                    idx_ptr,
141                    core::ptr::null_mut(),
142                    0,
143                    stream_ptr,
144                )
145            },
146            ElementKind::F64 => unsafe {
147                baracuda_kernels_sys::baracuda_kernels_msort_f64_run(
148                    self.desc.batch,
149                    self.desc.row_len,
150                    desc_flag,
151                    in_ptr,
152                    vals_ptr,
153                    idx_ptr,
154                    core::ptr::null_mut(),
155                    0,
156                    stream_ptr,
157                )
158            },
159            ElementKind::I32 => unsafe {
160                baracuda_kernels_sys::baracuda_kernels_msort_i32_run(
161                    self.desc.batch,
162                    self.desc.row_len,
163                    desc_flag,
164                    in_ptr,
165                    vals_ptr,
166                    idx_ptr,
167                    core::ptr::null_mut(),
168                    0,
169                    stream_ptr,
170                )
171            },
172            ElementKind::I64 => unsafe {
173                baracuda_kernels_sys::baracuda_kernels_msort_i64_run(
174                    self.desc.batch,
175                    self.desc.row_len,
176                    desc_flag,
177                    in_ptr,
178                    vals_ptr,
179                    idx_ptr,
180                    core::ptr::null_mut(),
181                    0,
182                    stream_ptr,
183                )
184            },
185            _ => {
186                return Err(Error::Unsupported(
187                    "baracuda-kernels::MsortPlan::run reached an unimplemented dtype",
188                ));
189            }
190        };
191        map_status(status)
192    }
193}
194
195// ---- BW ----
196
197/// Descriptor for an `msort_backward` op.
198#[derive(Copy, Clone, Debug)]
199pub struct MsortBackwardDescriptor {
200    /// Number of independent rows.
201    pub batch: i32,
202    /// Length of each row.
203    pub row_len: i32,
204    /// Grad element type.
205    pub element: ElementKind,
206}
207
208/// Args bundle for an `msort_backward` launch.
209pub struct MsortBackwardArgs<'a, T: Element> {
210    /// Upstream grad of sorted-values output `[batch, row_len]`.
211    pub dy: TensorRef<'a, T, 2>,
212    /// Saved indices from FW `[batch, row_len]`.
213    pub indices: TensorRef<'a, i32, 2>,
214    /// Grad of the input `[batch, row_len]`.
215    pub dx: TensorMut<'a, T, 2>,
216}
217
218/// `msort_backward` plan.
219///
220/// Adjoint of [`crate::MsortPlan`]. Same permutation-scatter as
221/// [`SortBackwardPlan`](crate::SortBackwardPlan) (the stability tie-
222/// break only affected the FW indices, not the BW math).
223///
224/// **When to use**: BW for [`MsortPlan`](crate::MsortPlan).
225///
226/// **Dtypes**: `{f32, f64}`.
227///
228/// **Shape limits**: rank-2 `[batch, row_len]`; `row_len ≤ 1024`.
229///
230/// **Workspace**: none.
231///
232/// **Precision guarantee**: deterministic, bit-stable.
233pub struct MsortBackwardPlan<T: Element> {
234    desc: MsortBackwardDescriptor,
235    sku: KernelSku,
236    _marker: PhantomData<T>,
237}
238
239impl<T: Element> MsortBackwardPlan<T> {
240    /// Pick a kernel for `desc`.
241    pub fn select(
242        _stream: &Stream,
243        desc: &MsortBackwardDescriptor,
244        _pref: PlanPreference,
245    ) -> Result<Self> {
246        validate_sort_desc(
247            desc.batch,
248            desc.row_len,
249            desc.element,
250            T::KIND,
251            "MsortBackwardPlan",
252        )?;
253        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
254            return Err(Error::Unsupported(
255                "baracuda-kernels::MsortBackwardPlan: today only f32 / f64 grads supported",
256            ));
257        }
258        let sku = build_sku::<T>(SortKind::MsortBackward);
259        Ok(Self {
260            desc: *desc,
261            sku,
262            _marker: PhantomData,
263        })
264    }
265
266    /// Validate args.
267    pub fn can_implement(&self, args: &MsortBackwardArgs<'_, T>) -> Result<()> {
268        let expected = [self.desc.batch, self.desc.row_len];
269        if args.dy.shape != expected
270            || args.indices.shape != expected
271            || args.dx.shape != expected
272        {
273            return Err(Error::InvalidProblem(
274                "baracuda-kernels::MsortBackwardPlan: tensor shapes != [batch, row_len]",
275            ));
276        }
277        Ok(())
278    }
279
280    /// Workspace size in bytes.
281    #[inline]
282    pub fn workspace_size(&self) -> usize {
283        0
284    }
285
286    /// Identity of the kernel this plan picked.
287    #[inline]
288    pub fn sku(&self) -> KernelSku {
289        self.sku
290    }
291
292    /// Numerical guarantees for this plan's kernel.
293    #[inline]
294    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
295        self.sku.precision_guarantee
296    }
297
298    /// Launch.
299    pub fn run(
300        &self,
301        stream: &Stream,
302        _workspace: Workspace<'_>,
303        args: MsortBackwardArgs<'_, T>,
304    ) -> Result<()> {
305        self.can_implement(&args)?;
306        if self.desc.batch == 0 || self.desc.row_len == 0 {
307            return Ok(());
308        }
309        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
310        let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
311        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
312        let stream_ptr = stream.as_raw() as *mut c_void;
313
314        let status = match T::KIND {
315            ElementKind::F32 => unsafe {
316                baracuda_kernels_sys::baracuda_kernels_msort_backward_f32_run(
317                    self.desc.batch,
318                    self.desc.row_len,
319                    dy_ptr,
320                    idx_ptr,
321                    dx_ptr,
322                    core::ptr::null_mut(),
323                    0,
324                    stream_ptr,
325                )
326            },
327            ElementKind::F64 => unsafe {
328                baracuda_kernels_sys::baracuda_kernels_msort_backward_f64_run(
329                    self.desc.batch,
330                    self.desc.row_len,
331                    dy_ptr,
332                    idx_ptr,
333                    dx_ptr,
334                    core::ptr::null_mut(),
335                    0,
336                    stream_ptr,
337                )
338            },
339            _ => {
340                return Err(Error::Unsupported(
341                    "baracuda-kernels::MsortBackwardPlan::run reached an unimplemented dtype",
342                ));
343            }
344        };
345        map_status(status)
346    }
347}