Skip to main content

baracuda_kernels/sort/
sort_backward.rs

1//! `sort_backward` plan — scatter `dy` via the saved indices.
2//!
3//! `dx[batch, indices[batch, i]] = dy[batch, i]`. Launcher zeros `dx`
4//! before the scatter. Trailblazer dtype coverage: `f32, f64`.
5
6use core::ffi::c_void;
7use core::marker::PhantomData;
8
9use baracuda_cutlass::{Error, Result};
10use baracuda_driver::Stream;
11use baracuda_kernels_types::{
12    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
13    TensorRef, Workspace,
14};
15
16use super::map_status;
17use super::sort::{build_sku, validate_sort_desc};
18
19/// Descriptor for a `sort_backward` op.
20#[derive(Copy, Clone, Debug)]
21pub struct SortBackwardDescriptor {
22    /// Number of independent rows.
23    pub batch: i32,
24    /// Length of each row.
25    pub row_len: i32,
26    /// Value (gradient) element type.
27    pub element: ElementKind,
28}
29
30/// Args bundle for a `sort_backward` launch.
31pub struct SortBackwardArgs<'a, T: Element> {
32    /// Upstream grad of sorted-values output `[batch, row_len]`.
33    pub dy: TensorRef<'a, T, 2>,
34    /// Saved indices from FW `[batch, row_len]`.
35    pub indices: TensorRef<'a, i32, 2>,
36    /// Grad of the original input `[batch, row_len]` (output).
37    pub dx: TensorMut<'a, T, 2>,
38}
39
40/// `sort_backward` plan.
41///
42/// Adjoint of [`crate::SortPlan`]: scatters `d_values[b, p]` to
43/// `d_input[b, indices[b, p]]`. Pure index-routed permutation —
44/// each input position receives exactly one gradient, so no atomics
45/// needed.
46///
47/// **When to use**: BW for [`SortPlan`](crate::SortPlan). Consumes
48/// the FW's saved `indices` verbatim.
49///
50/// **Dtypes**: `{f32, f64, i32, i64}` (matches FW).
51///
52/// **Shape limits**: rank-2 `[batch, row_len]`; `row_len ≤ 1024`.
53///
54/// **Workspace**: none.
55///
56/// **Precision guarantee**: deterministic, bit-stable.
57pub struct SortBackwardPlan<T: Element> {
58    desc: SortBackwardDescriptor,
59    sku: KernelSku,
60    _marker: PhantomData<T>,
61}
62
63impl<T: Element> SortBackwardPlan<T> {
64    /// Pick a kernel for `desc`.
65    pub fn select(
66        _stream: &Stream,
67        desc: &SortBackwardDescriptor,
68        _pref: PlanPreference,
69    ) -> Result<Self> {
70        validate_sort_desc(
71            desc.batch,
72            desc.row_len,
73            desc.element,
74            T::KIND,
75            "SortBackwardPlan",
76        )?;
77        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
78            return Err(Error::Unsupported(
79                "baracuda-kernels::SortBackwardPlan: today only f32 / f64 grads supported",
80            ));
81        }
82        let sku = build_sku::<T>(SortKind::SortBackward);
83        Ok(Self {
84            desc: *desc,
85            sku,
86            _marker: PhantomData,
87        })
88    }
89
90    /// Validate args.
91    pub fn can_implement(&self, args: &SortBackwardArgs<'_, T>) -> Result<()> {
92        let expected = [self.desc.batch, self.desc.row_len];
93        if args.dy.shape != expected
94            || args.indices.shape != expected
95            || args.dx.shape != expected
96        {
97            return Err(Error::InvalidProblem(
98                "baracuda-kernels::SortBackwardPlan: tensor shapes != [batch, row_len]",
99            ));
100        }
101        Ok(())
102    }
103
104    /// Workspace size in bytes.
105    #[inline]
106    pub fn workspace_size(&self) -> usize {
107        0
108    }
109
110    /// Identity of the kernel this plan picked.
111    #[inline]
112    pub fn sku(&self) -> KernelSku {
113        self.sku
114    }
115
116    /// Numerical guarantees for this plan's kernel.
117    #[inline]
118    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
119        self.sku.precision_guarantee
120    }
121
122    /// Launch.
123    pub fn run(
124        &self,
125        stream: &Stream,
126        _workspace: Workspace<'_>,
127        args: SortBackwardArgs<'_, T>,
128    ) -> Result<()> {
129        self.can_implement(&args)?;
130        if self.desc.batch == 0 || self.desc.row_len == 0 {
131            return Ok(());
132        }
133        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
134        let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
135        let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
136        let stream_ptr = stream.as_raw() as *mut c_void;
137
138        let status = match T::KIND {
139            ElementKind::F32 => unsafe {
140                baracuda_kernels_sys::baracuda_kernels_sort_backward_f32_run(
141                    self.desc.batch,
142                    self.desc.row_len,
143                    dy_ptr,
144                    idx_ptr,
145                    dx_ptr,
146                    core::ptr::null_mut(),
147                    0,
148                    stream_ptr,
149                )
150            },
151            ElementKind::F64 => unsafe {
152                baracuda_kernels_sys::baracuda_kernels_sort_backward_f64_run(
153                    self.desc.batch,
154                    self.desc.row_len,
155                    dy_ptr,
156                    idx_ptr,
157                    dx_ptr,
158                    core::ptr::null_mut(),
159                    0,
160                    stream_ptr,
161                )
162            },
163            _ => {
164                return Err(Error::Unsupported(
165                    "baracuda-kernels::SortBackwardPlan::run reached an unimplemented dtype \
166                     — select() should have caught this",
167                ));
168            }
169        };
170        map_status(status)
171    }
172}