Skip to main content

baracuda_kernels/sort/
searchsorted.rs

1//! `searchsorted` plan — per-query binary search in a 1-D sorted array.
2//!
3//! `searchsorted(sorted_seq[L], values[N], right) -> output[N]` (i32).
4//! `right == false` (default) returns `lower_bound`; `right == true`
5//! returns `upper_bound`. PyTorch `torch.searchsorted`.
6//!
7//! Trailblazer dtype coverage: `f32, f64, i32, i64`. No BW
8//! (set-valued / non-differentiable).
9
10use core::ffi::c_void;
11use core::marker::PhantomData;
12
13use baracuda_cutlass::{Error, Result};
14use baracuda_driver::Stream;
15use baracuda_kernels_types::{
16    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
17    PlanPreference, PrecisionGuarantee, SortKind, TensorMut, TensorRef, Workspace,
18};
19
20use super::map_status;
21
22/// Descriptor for a `searchsorted` op.
23#[derive(Copy, Clone, Debug)]
24pub struct SearchsortedDescriptor {
25    /// Number of query values.
26    pub num_queries: i64,
27    /// Length of the sorted sequence.
28    pub len_sorted: i32,
29    /// `false` = lower_bound (default); `true` = upper_bound.
30    pub right: bool,
31    /// Element type of both sorted_seq and values.
32    pub element: ElementKind,
33}
34
35/// Args bundle for a `searchsorted` launch.
36pub struct SearchsortedArgs<'a, T: Element> {
37    /// Sorted sequence `[len_sorted]`.
38    pub sorted_seq: TensorRef<'a, T, 1>,
39    /// Query values `[num_queries]`.
40    pub values: TensorRef<'a, T, 1>,
41    /// Output positions `[num_queries]` (i32).
42    pub output: TensorMut<'a, i32, 1>,
43}
44
45/// `searchsorted` plan.
46///
47/// `searchsorted(sorted_seq[L], values[N], right) -> output[N]`
48/// (PyTorch `torch.searchsorted`). `right == false` returns
49/// `lower_bound`, `right == true` returns `upper_bound`.
50///
51/// **When to use**: per-query binary search into a sorted 1-D array.
52/// Useful for histogram-binning / quantile-bucketing. No BW.
53///
54/// **Dtypes**: `{f32, f64, i32, i64}` for both sorted_seq and
55/// values; output always `i32`.
56///
57/// **Shape limits**: sorted_seq `[len_sorted]`; values, output
58/// `[num_queries]`.
59///
60/// **Workspace**: none.
61///
62/// **Precision guarantee**: deterministic, bit-stable. Pure binary
63/// search.
64pub struct SearchsortedPlan<T: Element> {
65    desc: SearchsortedDescriptor,
66    sku: KernelSku,
67    _marker: PhantomData<T>,
68}
69
70impl<T: Element> SearchsortedPlan<T> {
71    /// Pick a kernel for `desc`.
72    pub fn select(
73        _stream: &Stream,
74        desc: &SearchsortedDescriptor,
75        _pref: PlanPreference,
76    ) -> Result<Self> {
77        if desc.element != T::KIND {
78            return Err(Error::Unsupported(
79                "baracuda-kernels::SearchsortedPlan: descriptor element != type parameter T",
80            ));
81        }
82        if desc.num_queries < 0 || desc.len_sorted < 0 {
83            return Err(Error::InvalidProblem(
84                "baracuda-kernels::SearchsortedPlan: num_queries / len_sorted must be \
85                 non-negative",
86            ));
87        }
88        if !matches!(
89            T::KIND,
90            ElementKind::F32 | ElementKind::F64 | ElementKind::I32 | ElementKind::I64
91        ) {
92            return Err(Error::Unsupported(
93                "baracuda-kernels::SearchsortedPlan: today only f32 / f64 / i32 / i64 wired",
94            ));
95        }
96        let precision_guarantee = PrecisionGuarantee {
97            math_precision: if T::KIND == ElementKind::F64 {
98                MathPrecision::F64
99            } else {
100                MathPrecision::F32
101            },
102            accumulator: ElementKind::I32,
103            bit_stable_on_same_hardware: true,
104            deterministic: true,
105        };
106        let sku = KernelSku {
107            category: OpCategory::Sorting,
108            op: SortKind::Searchsorted as u16,
109            element: T::KIND,
110            aux_element: Some(ElementKind::I32),
111            layout: None,
112            epilogue: None,
113            arch: ArchSku::Sm80,
114            backend: BackendKind::Bespoke,
115            precision_guarantee,
116        };
117        Ok(Self {
118            desc: *desc,
119            sku,
120            _marker: PhantomData,
121        })
122    }
123
124    /// Validate args.
125    pub fn can_implement(&self, args: &SearchsortedArgs<'_, T>) -> Result<()> {
126        if args.sorted_seq.shape != [self.desc.len_sorted] {
127            return Err(Error::InvalidProblem(
128                "baracuda-kernels::SearchsortedPlan: sorted_seq shape != [len_sorted]",
129            ));
130        }
131        if (args.values.shape[0] as i64) != self.desc.num_queries {
132            return Err(Error::InvalidProblem(
133                "baracuda-kernels::SearchsortedPlan: values shape != [num_queries]",
134            ));
135        }
136        if (args.output.shape[0] as i64) != self.desc.num_queries {
137            return Err(Error::InvalidProblem(
138                "baracuda-kernels::SearchsortedPlan: output shape != [num_queries]",
139            ));
140        }
141        Ok(())
142    }
143
144    /// Workspace size in bytes.
145    #[inline]
146    pub fn workspace_size(&self) -> usize {
147        0
148    }
149
150    /// Identity of the kernel this plan picked.
151    #[inline]
152    pub fn sku(&self) -> KernelSku {
153        self.sku
154    }
155
156    /// Numerical guarantees for this plan's kernel.
157    #[inline]
158    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
159        self.sku.precision_guarantee
160    }
161
162    /// Launch.
163    pub fn run(
164        &self,
165        stream: &Stream,
166        _workspace: Workspace<'_>,
167        args: SearchsortedArgs<'_, T>,
168    ) -> Result<()> {
169        self.can_implement(&args)?;
170        if self.desc.num_queries == 0 {
171            return Ok(());
172        }
173        let seq_ptr = args.sorted_seq.data.as_raw().0 as *const c_void;
174        let val_ptr = args.values.data.as_raw().0 as *const c_void;
175        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
176        let stream_ptr = stream.as_raw() as *mut c_void;
177        let right_flag = if self.desc.right { 1 } else { 0 };
178
179        let status = match T::KIND {
180            ElementKind::F32 => unsafe {
181                baracuda_kernels_sys::baracuda_kernels_searchsorted_f32_run(
182                    self.desc.num_queries,
183                    self.desc.len_sorted,
184                    right_flag,
185                    seq_ptr,
186                    val_ptr,
187                    out_ptr,
188                    core::ptr::null_mut(),
189                    0,
190                    stream_ptr,
191                )
192            },
193            ElementKind::F64 => unsafe {
194                baracuda_kernels_sys::baracuda_kernels_searchsorted_f64_run(
195                    self.desc.num_queries,
196                    self.desc.len_sorted,
197                    right_flag,
198                    seq_ptr,
199                    val_ptr,
200                    out_ptr,
201                    core::ptr::null_mut(),
202                    0,
203                    stream_ptr,
204                )
205            },
206            ElementKind::I32 => unsafe {
207                baracuda_kernels_sys::baracuda_kernels_searchsorted_i32_run(
208                    self.desc.num_queries,
209                    self.desc.len_sorted,
210                    right_flag,
211                    seq_ptr,
212                    val_ptr,
213                    out_ptr,
214                    core::ptr::null_mut(),
215                    0,
216                    stream_ptr,
217                )
218            },
219            ElementKind::I64 => unsafe {
220                baracuda_kernels_sys::baracuda_kernels_searchsorted_i64_run(
221                    self.desc.num_queries,
222                    self.desc.len_sorted,
223                    right_flag,
224                    seq_ptr,
225                    val_ptr,
226                    out_ptr,
227                    core::ptr::null_mut(),
228                    0,
229                    stream_ptr,
230                )
231            },
232            _ => {
233                return Err(Error::Unsupported(
234                    "baracuda-kernels::SearchsortedPlan::run reached an unimplemented dtype",
235                ));
236            }
237        };
238        map_status(status)
239    }
240}