1use core::ffi::c_void;
11use core::marker::PhantomData;
12
13use baracuda_cutlass::{Error, Result};
14use baracuda_driver::Stream;
15use baracuda_kernels_types::{
16 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
17 PlanPreference, PrecisionGuarantee, SortKind, TensorMut, TensorRef, Workspace,
18};
19
20use super::map_status;
21
22#[derive(Copy, Clone, Debug)]
24pub struct SearchsortedDescriptor {
25 pub num_queries: i64,
27 pub len_sorted: i32,
29 pub right: bool,
31 pub element: ElementKind,
33}
34
35pub struct SearchsortedArgs<'a, T: Element> {
37 pub sorted_seq: TensorRef<'a, T, 1>,
39 pub values: TensorRef<'a, T, 1>,
41 pub output: TensorMut<'a, i32, 1>,
43}
44
45pub struct SearchsortedPlan<T: Element> {
65 desc: SearchsortedDescriptor,
66 sku: KernelSku,
67 _marker: PhantomData<T>,
68}
69
70impl<T: Element> SearchsortedPlan<T> {
71 pub fn select(
73 _stream: &Stream,
74 desc: &SearchsortedDescriptor,
75 _pref: PlanPreference,
76 ) -> Result<Self> {
77 if desc.element != T::KIND {
78 return Err(Error::Unsupported(
79 "baracuda-kernels::SearchsortedPlan: descriptor element != type parameter T",
80 ));
81 }
82 if desc.num_queries < 0 || desc.len_sorted < 0 {
83 return Err(Error::InvalidProblem(
84 "baracuda-kernels::SearchsortedPlan: num_queries / len_sorted must be \
85 non-negative",
86 ));
87 }
88 if !matches!(
89 T::KIND,
90 ElementKind::F32 | ElementKind::F64 | ElementKind::I32 | ElementKind::I64
91 ) {
92 return Err(Error::Unsupported(
93 "baracuda-kernels::SearchsortedPlan: today only f32 / f64 / i32 / i64 wired",
94 ));
95 }
96 let precision_guarantee = PrecisionGuarantee {
97 math_precision: if T::KIND == ElementKind::F64 {
98 MathPrecision::F64
99 } else {
100 MathPrecision::F32
101 },
102 accumulator: ElementKind::I32,
103 bit_stable_on_same_hardware: true,
104 deterministic: true,
105 };
106 let sku = KernelSku {
107 category: OpCategory::Sorting,
108 op: SortKind::Searchsorted as u16,
109 element: T::KIND,
110 aux_element: Some(ElementKind::I32),
111 layout: None,
112 epilogue: None,
113 arch: ArchSku::Sm80,
114 backend: BackendKind::Bespoke,
115 precision_guarantee,
116 };
117 Ok(Self {
118 desc: *desc,
119 sku,
120 _marker: PhantomData,
121 })
122 }
123
124 pub fn can_implement(&self, args: &SearchsortedArgs<'_, T>) -> Result<()> {
126 if args.sorted_seq.shape != [self.desc.len_sorted] {
127 return Err(Error::InvalidProblem(
128 "baracuda-kernels::SearchsortedPlan: sorted_seq shape != [len_sorted]",
129 ));
130 }
131 if (args.values.shape[0] as i64) != self.desc.num_queries {
132 return Err(Error::InvalidProblem(
133 "baracuda-kernels::SearchsortedPlan: values shape != [num_queries]",
134 ));
135 }
136 if (args.output.shape[0] as i64) != self.desc.num_queries {
137 return Err(Error::InvalidProblem(
138 "baracuda-kernels::SearchsortedPlan: output shape != [num_queries]",
139 ));
140 }
141 Ok(())
142 }
143
144 #[inline]
146 pub fn workspace_size(&self) -> usize {
147 0
148 }
149
150 #[inline]
152 pub fn sku(&self) -> KernelSku {
153 self.sku
154 }
155
156 #[inline]
158 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
159 self.sku.precision_guarantee
160 }
161
162 pub fn run(
164 &self,
165 stream: &Stream,
166 _workspace: Workspace<'_>,
167 args: SearchsortedArgs<'_, T>,
168 ) -> Result<()> {
169 self.can_implement(&args)?;
170 if self.desc.num_queries == 0 {
171 return Ok(());
172 }
173 let seq_ptr = args.sorted_seq.data.as_raw().0 as *const c_void;
174 let val_ptr = args.values.data.as_raw().0 as *const c_void;
175 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
176 let stream_ptr = stream.as_raw() as *mut c_void;
177 let right_flag = if self.desc.right { 1 } else { 0 };
178
179 let status = match T::KIND {
180 ElementKind::F32 => unsafe {
181 baracuda_kernels_sys::baracuda_kernels_searchsorted_f32_run(
182 self.desc.num_queries,
183 self.desc.len_sorted,
184 right_flag,
185 seq_ptr,
186 val_ptr,
187 out_ptr,
188 core::ptr::null_mut(),
189 0,
190 stream_ptr,
191 )
192 },
193 ElementKind::F64 => unsafe {
194 baracuda_kernels_sys::baracuda_kernels_searchsorted_f64_run(
195 self.desc.num_queries,
196 self.desc.len_sorted,
197 right_flag,
198 seq_ptr,
199 val_ptr,
200 out_ptr,
201 core::ptr::null_mut(),
202 0,
203 stream_ptr,
204 )
205 },
206 ElementKind::I32 => unsafe {
207 baracuda_kernels_sys::baracuda_kernels_searchsorted_i32_run(
208 self.desc.num_queries,
209 self.desc.len_sorted,
210 right_flag,
211 seq_ptr,
212 val_ptr,
213 out_ptr,
214 core::ptr::null_mut(),
215 0,
216 stream_ptr,
217 )
218 },
219 ElementKind::I64 => unsafe {
220 baracuda_kernels_sys::baracuda_kernels_searchsorted_i64_run(
221 self.desc.num_queries,
222 self.desc.len_sorted,
223 right_flag,
224 seq_ptr,
225 val_ptr,
226 out_ptr,
227 core::ptr::null_mut(),
228 0,
229 stream_ptr,
230 )
231 },
232 _ => {
233 return Err(Error::Unsupported(
234 "baracuda-kernels::SearchsortedPlan::run reached an unimplemented dtype",
235 ));
236 }
237 };
238 map_status(status)
239 }
240}