Skip to main content

baracuda_kernels/sort/
kthvalue.rs

1//! `kthvalue` plan — returns the k-th smallest value + its index along
2//! the last dimension.
3//!
4//! PyTorch `torch.kthvalue(x, k, dim)` returns `(value, index)` for
5//! the **k-th smallest** value (1-indexed in PyTorch). Our descriptor
6//! uses 0-indexed `k` (the user passes `k = 0` for the smallest).
7//!
8//! Composition: invokes [`crate::sort::TopkPlan`] with `largest=false`
9//! and `k = desc.k + 1`, then reads cell `(k)` of the bottom-(k+1)
10//! result. The composition lives at the Rust plan layer — no separate
11//! kthvalue kernel is shipped.
12//!
13//! Trailblazer dtype coverage: `f32, f64`.
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::{DeviceBuffer, Stream};
20use baracuda_kernels_types::{
21    contiguous_stride, Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee,
22    SortKind, TensorMut, TensorRef, Workspace,
23};
24
25use super::sort::build_sku;
26use super::topk::{TopkArgs, TopkDescriptor, TopkPlan};
27use super::TOPK_MAX_K;
28
29/// Descriptor for a `kthvalue` op.
30///
31/// `k` is 0-indexed (k = 0 → smallest value, k = row_len - 1 →
32/// largest).
33#[derive(Copy, Clone, Debug)]
34pub struct KthvalueDescriptor {
35    /// Number of independent rows.
36    pub batch: i32,
37    /// Length of each row.
38    pub row_len: i32,
39    /// Which order statistic to return (0-indexed). Trailblazer cap:
40    /// `k < 64` (composes a bottom-(k+1) topk).
41    pub k: i32,
42    /// Value element type.
43    pub element: ElementKind,
44}
45
46/// Args bundle for a `kthvalue` launch.
47pub struct KthvalueArgs<'a, T: Element> {
48    /// Input `[batch, row_len]`.
49    pub input: TensorRef<'a, T, 2>,
50    /// Output values `[batch]` (one cell per row).
51    pub values: TensorMut<'a, T, 1>,
52    /// Output indices `[batch]` (one i32 per row).
53    pub indices: TensorMut<'a, i32, 1>,
54}
55
56/// `kthvalue` plan.
57///
58/// Returns the k-th smallest value and its index along the last axis
59/// (PyTorch `torch.kthvalue`; 0-indexed `k` here, vs PyTorch's
60/// 1-indexed). Composed at the plan layer as a bottom-(k+1)
61/// [`TopkPlan`](crate::TopkPlan), reading cell `(k)` of the result.
62///
63/// **When to use**: order-statistic queries (median, quantile pickup
64/// in fixed K range). Pair with
65/// [`KthvalueBackwardPlan`](crate::KthvalueBackwardPlan).
66///
67/// **Dtypes**: `{f32, f64}`.
68///
69/// **Shape limits**: input `[batch, row_len]`; outputs `[batch]`;
70/// `row_len ≤ 1024`; `k < 64` (composes a bottom-(k+1) topk).
71///
72/// **Workspace**: zero in [`Workspace`]; plan internally allocates a
73/// scratch `[batch, k+1]` topk-result buffer per launch.
74///
75/// **Precision guarantee**: deterministic, bit-stable (inherits topk's
76/// fixed-network guarantee).
77pub struct KthvaluePlan<T: Element> {
78    desc: KthvalueDescriptor,
79    sku: KernelSku,
80    _marker: PhantomData<T>,
81}
82
83impl<T: Element> KthvaluePlan<T> {
84    /// Pick a kernel for `desc`.
85    pub fn select(
86        _stream: &Stream,
87        desc: &KthvalueDescriptor,
88        _pref: PlanPreference,
89    ) -> Result<Self> {
90        if desc.element != T::KIND {
91            return Err(Error::Unsupported(
92                "baracuda-kernels::KthvaluePlan: descriptor element != type parameter T",
93            ));
94        }
95        if desc.batch < 0 || desc.row_len < 0 || desc.k < 0 {
96            return Err(Error::InvalidProblem(
97                "baracuda-kernels::KthvaluePlan: batch / row_len / k must be non-negative",
98            ));
99        }
100        if desc.k >= desc.row_len {
101            return Err(Error::InvalidProblem(
102                "baracuda-kernels::KthvaluePlan: k must be < row_len (0-indexed)",
103            ));
104        }
105        if desc.k + 1 > TOPK_MAX_K {
106            return Err(Error::Unsupported(
107                "baracuda-kernels::KthvaluePlan: k+1 > 64 not supported (composes topk)",
108            ));
109        }
110        if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
111            return Err(Error::Unsupported(
112                "baracuda-kernels::KthvaluePlan: today only f32 / f64 wired (TopkPlan limit)",
113            ));
114        }
115        let sku = build_sku::<T>(SortKind::Kthvalue);
116        Ok(Self {
117            desc: *desc,
118            sku,
119            _marker: PhantomData,
120        })
121    }
122
123    /// Validate args.
124    pub fn can_implement(&self, args: &KthvalueArgs<'_, T>) -> Result<()> {
125        if args.input.shape != [self.desc.batch, self.desc.row_len] {
126            return Err(Error::InvalidProblem(
127                "baracuda-kernels::KthvaluePlan: input shape != [batch, row_len]",
128            ));
129        }
130        if args.values.shape != [self.desc.batch] {
131            return Err(Error::InvalidProblem(
132                "baracuda-kernels::KthvaluePlan: values shape != [batch]",
133            ));
134        }
135        if args.indices.shape != [self.desc.batch] {
136            return Err(Error::InvalidProblem(
137                "baracuda-kernels::KthvaluePlan: indices shape != [batch]",
138            ));
139        }
140        Ok(())
141    }
142
143    /// Workspace size in bytes. Internal device buffers are allocated
144    /// fresh at run() time.
145    #[inline]
146    pub fn workspace_size(&self) -> usize {
147        0
148    }
149
150    /// Identity of the kernel this plan picked.
151    #[inline]
152    pub fn sku(&self) -> KernelSku {
153        self.sku
154    }
155
156    /// Numerical guarantees for this plan's kernel.
157    #[inline]
158    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
159        self.sku.precision_guarantee
160    }
161
162    /// Launch. Composes a bottom-(k+1) topk; reads the last cell as the
163    /// k-th smallest. Allocates two intermediate device buffers and
164    /// round-trips the bottom-(k+1) cells through host memory to
165    /// extract the (k)-th slot per row (the data is small — batch *
166    /// (k+1) cells with `k+1 ≤ 64`).
167    pub fn run(
168        &self,
169        stream: &Stream,
170        _workspace: Workspace<'_>,
171        args: KthvalueArgs<'_, T>,
172    ) -> Result<()> {
173        self.can_implement(&args)?;
174        if self.desc.batch == 0 {
175            return Ok(());
176        }
177
178        let kp1 = self.desc.k + 1;
179        let topk_desc = TopkDescriptor {
180            batch: self.desc.batch,
181            row_len: self.desc.row_len,
182            k: kp1,
183            largest: false,
184            element: T::KIND,
185        };
186        let topk_plan = TopkPlan::<T>::select(stream, &topk_desc, PlanPreference::default())?;
187
188        let ctx = stream.context();
189        let total = (self.desc.batch as usize) * (kp1 as usize);
190        let mut topk_vals: DeviceBuffer<T> = DeviceBuffer::zeros(ctx, total).map_err(|_| {
191            Error::InvalidProblem(
192                "baracuda-kernels::KthvaluePlan: failed to allocate topk values buffer",
193            )
194        })?;
195        let mut topk_idx: DeviceBuffer<i32> = DeviceBuffer::zeros(ctx, total).map_err(|_| {
196            Error::InvalidProblem(
197                "baracuda-kernels::KthvaluePlan: failed to allocate topk indices buffer",
198            )
199        })?;
200
201        let topk_args = TopkArgs::<T> {
202            input: args.input,
203            values: TensorMut {
204                data: topk_vals.as_slice_mut(),
205                shape: [self.desc.batch, kp1],
206                stride: contiguous_stride([self.desc.batch, kp1]),
207            },
208            indices: TensorMut {
209                data: topk_idx.as_slice_mut(),
210                shape: [self.desc.batch, kp1],
211                stride: contiguous_stride([self.desc.batch, kp1]),
212            },
213        };
214        topk_plan.run(stream, Workspace::None, topk_args)?;
215        stream
216            .synchronize()
217            .map_err(|_| Error::CutlassInternal(-1))?;
218
219        // Bring the bottom-(k+1) tiles host-side as raw bytes, pick
220        // the k-th cell per row, and ship the two compacted [batch]
221        // vectors back to the device. We use byte buffers to avoid a
222        // `T: Default` bound on `Element` (which doesn't exist).
223        let val_bytes = total * core::mem::size_of::<T>();
224        let idx_bytes_total = total * core::mem::size_of::<i32>();
225        let mut host_vals: Vec<u8> = vec![0u8; val_bytes];
226        let mut host_idx_bytes: Vec<u8> = vec![0u8; idx_bytes_total];
227        unsafe {
228            copy_d2h_async(
229                host_vals.as_mut_ptr() as *mut c_void,
230                topk_vals.as_raw().0,
231                val_bytes,
232                stream,
233            )?;
234            copy_d2h_async(
235                host_idx_bytes.as_mut_ptr() as *mut c_void,
236                topk_idx.as_raw().0,
237                idx_bytes_total,
238                stream,
239            )?;
240        }
241        stream
242            .synchronize()
243            .map_err(|_| Error::CutlassInternal(-1))?;
244
245        let out_val_bytes = (self.desc.batch as usize) * core::mem::size_of::<T>();
246        let out_idx_bytes = (self.desc.batch as usize) * core::mem::size_of::<i32>();
247        let mut out_vals: Vec<u8> = vec![0u8; out_val_bytes];
248        let mut out_idx: Vec<u8> = vec![0u8; out_idx_bytes];
249        let stride_v = core::mem::size_of::<T>();
250        let stride_i = core::mem::size_of::<i32>();
251        for row in 0..self.desc.batch as usize {
252            let src_v_off = (row * (kp1 as usize) + self.desc.k as usize) * stride_v;
253            let src_i_off = (row * (kp1 as usize) + self.desc.k as usize) * stride_i;
254            let dst_v_off = row * stride_v;
255            let dst_i_off = row * stride_i;
256            out_vals[dst_v_off..dst_v_off + stride_v]
257                .copy_from_slice(&host_vals[src_v_off..src_v_off + stride_v]);
258            out_idx[dst_i_off..dst_i_off + stride_i]
259                .copy_from_slice(&host_idx_bytes[src_i_off..src_i_off + stride_i]);
260        }
261
262        unsafe {
263            copy_h2d_async(
264                args.values.data.as_raw().0 as *mut c_void,
265                out_vals.as_ptr() as *const c_void,
266                out_val_bytes,
267                stream,
268            )?;
269            copy_h2d_async(
270                args.indices.data.as_raw().0 as *mut c_void,
271                out_idx.as_ptr() as *const c_void,
272                out_idx_bytes,
273                stream,
274            )?;
275        }
276        // Keep the host buffers alive until the H2D completes.
277        stream
278            .synchronize()
279            .map_err(|_| Error::CutlassInternal(-1))?;
280        drop(out_vals);
281        drop(out_idx);
282        drop(host_vals);
283        drop(host_idx_bytes);
284        Ok(())
285    }
286}
287
288/// H2D copy helper — same pattern as `linalg/qr.rs` (no `cu` dep in
289/// this crate; we declare the symbol locally).
290unsafe fn copy_h2d_async(
291    dst: *mut c_void,
292    src: *const c_void,
293    bytes: usize,
294    stream: &Stream,
295) -> Result<()> {
296    if bytes == 0 {
297        return Ok(());
298    }
299    #[allow(non_camel_case_types)]
300    type CUresult = i32;
301    unsafe extern "system" {
302        fn cuMemcpyHtoDAsync_v2(
303            dst_device: u64,
304            src_host: *const c_void,
305            byte_count: usize,
306            h_stream: *mut c_void,
307        ) -> CUresult;
308    }
309    let status =
310        unsafe { cuMemcpyHtoDAsync_v2(dst as u64, src, bytes, stream.as_raw() as *mut c_void) };
311    if status != 0 {
312        return Err(Error::CutlassInternal(-status));
313    }
314    Ok(())
315}
316
317/// D2H copy helper.
318unsafe fn copy_d2h_async(
319    dst: *mut c_void,
320    src: u64,
321    bytes: usize,
322    stream: &Stream,
323) -> Result<()> {
324    if bytes == 0 {
325        return Ok(());
326    }
327    #[allow(non_camel_case_types)]
328    type CUresult = i32;
329    unsafe extern "system" {
330        fn cuMemcpyDtoHAsync_v2(
331            dst_host: *mut c_void,
332            src_device: u64,
333            byte_count: usize,
334            h_stream: *mut c_void,
335        ) -> CUresult;
336    }
337    let status =
338        unsafe { cuMemcpyDtoHAsync_v2(dst, src, bytes, stream.as_raw() as *mut c_void) };
339    if status != 0 {
340        return Err(Error::CutlassInternal(-status));
341    }
342    Ok(())
343}