1use core::ffi::c_void;
20use core::marker::PhantomData;
21
22use baracuda_cutlass::{Error, Result};
23use baracuda_driver::Stream;
24use baracuda_kernels_types::{
25 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
26 PlanPreference, PrecisionGuarantee, SortKind, TensorMut, TensorRef, Workspace,
27};
28
29use super::{map_status, SORT_MAX_ROW};
30
31#[derive(Copy, Clone, Debug)]
33pub struct SortDescriptor {
34 pub batch: i32,
36 pub row_len: i32,
38 pub descending: bool,
40 pub element: ElementKind,
42}
43
44pub struct SortArgs<'a, T: Element> {
46 pub input: TensorRef<'a, T, 2>,
48 pub values: TensorMut<'a, T, 2>,
50 pub indices: TensorMut<'a, i32, 2>,
52}
53
54pub struct SortPlan<T: Element> {
78 desc: SortDescriptor,
79 sku: KernelSku,
80 _marker: PhantomData<T>,
81}
82
83impl<T: Element> SortPlan<T> {
84 pub fn select(
86 _stream: &Stream,
87 desc: &SortDescriptor,
88 _pref: PlanPreference,
89 ) -> Result<Self> {
90 validate_sort_desc(desc.batch, desc.row_len, desc.element, T::KIND, "SortPlan")?;
91 let sku = build_sku::<T>(SortKind::Sort);
92 Ok(Self {
93 desc: *desc,
94 sku,
95 _marker: PhantomData,
96 })
97 }
98
99 pub fn can_implement(&self, args: &SortArgs<'_, T>) -> Result<()> {
101 validate_sort_args_2(
102 self.desc.batch,
103 self.desc.row_len,
104 args.input.shape,
105 args.values.shape,
106 args.indices.shape,
107 "SortPlan",
108 )
109 }
110
111 #[inline]
113 pub fn workspace_size(&self) -> usize {
114 0
115 }
116
117 #[inline]
119 pub fn sku(&self) -> KernelSku {
120 self.sku
121 }
122
123 #[inline]
125 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
126 self.sku.precision_guarantee
127 }
128
129 pub fn run(
131 &self,
132 stream: &Stream,
133 _workspace: Workspace<'_>,
134 args: SortArgs<'_, T>,
135 ) -> Result<()> {
136 self.can_implement(&args)?;
137 if self.desc.batch == 0 || self.desc.row_len == 0 {
138 return Ok(());
139 }
140 let in_ptr = args.input.data.as_raw().0 as *const c_void;
141 let vals_ptr = args.values.data.as_raw().0 as *mut c_void;
142 let idx_ptr = args.indices.data.as_raw().0 as *mut c_void;
143 let stream_ptr = stream.as_raw() as *mut c_void;
144 let desc_flag = if self.desc.descending { 1 } else { 0 };
145
146 let status = match T::KIND {
147 ElementKind::F32 => unsafe {
148 baracuda_kernels_sys::baracuda_kernels_sort_f32_run(
149 self.desc.batch,
150 self.desc.row_len,
151 desc_flag,
152 in_ptr,
153 vals_ptr,
154 idx_ptr,
155 core::ptr::null_mut(),
156 0,
157 stream_ptr,
158 )
159 },
160 ElementKind::F64 => unsafe {
161 baracuda_kernels_sys::baracuda_kernels_sort_f64_run(
162 self.desc.batch,
163 self.desc.row_len,
164 desc_flag,
165 in_ptr,
166 vals_ptr,
167 idx_ptr,
168 core::ptr::null_mut(),
169 0,
170 stream_ptr,
171 )
172 },
173 ElementKind::I32 => unsafe {
174 baracuda_kernels_sys::baracuda_kernels_sort_i32_run(
175 self.desc.batch,
176 self.desc.row_len,
177 desc_flag,
178 in_ptr,
179 vals_ptr,
180 idx_ptr,
181 core::ptr::null_mut(),
182 0,
183 stream_ptr,
184 )
185 },
186 ElementKind::I64 => unsafe {
187 baracuda_kernels_sys::baracuda_kernels_sort_i64_run(
188 self.desc.batch,
189 self.desc.row_len,
190 desc_flag,
191 in_ptr,
192 vals_ptr,
193 idx_ptr,
194 core::ptr::null_mut(),
195 0,
196 stream_ptr,
197 )
198 },
199 _ => {
200 return Err(Error::Unsupported(
201 "baracuda-kernels::SortPlan::run reached an unimplemented dtype \
202 — select() should have caught this",
203 ));
204 }
205 };
206 map_status(status)
207 }
208}
209
210pub(crate) fn validate_sort_desc(
214 batch: i32,
215 row_len: i32,
216 descriptor_element: ElementKind,
217 expected_element: ElementKind,
218 _plan_name: &'static str,
219) -> Result<()> {
220 if descriptor_element != expected_element {
221 return Err(Error::Unsupported(
222 "baracuda-kernels::sort: descriptor element != type parameter T",
223 ));
224 }
225 if batch < 0 || row_len < 0 {
226 return Err(Error::InvalidProblem(
227 "baracuda-kernels::sort: batch / row_len must be non-negative",
228 ));
229 }
230 if row_len > SORT_MAX_ROW {
231 return Err(Error::Unsupported(
232 "baracuda-kernels::sort: row_len > 1024 not supported in the \
233 block-bitonic trailblazer (tile-radix follow-up reserved)",
234 ));
235 }
236 if !matches!(
237 descriptor_element,
238 ElementKind::F32 | ElementKind::F64 | ElementKind::I32 | ElementKind::I64
239 ) {
240 return Err(Error::Unsupported(
241 "baracuda-kernels::sort: today only f32 / f64 / i32 / i64 wired",
242 ));
243 }
244 Ok(())
245}
246
247pub(crate) fn validate_sort_args_2(
249 batch: i32,
250 row_len: i32,
251 in_shape: [i32; 2],
252 vals_shape: [i32; 2],
253 idx_shape: [i32; 2],
254 _plan_name: &'static str,
255) -> Result<()> {
256 let expected = [batch, row_len];
257 if in_shape != expected {
258 return Err(Error::InvalidProblem(
259 "baracuda-kernels::sort: input shape != [batch, row_len]",
260 ));
261 }
262 if vals_shape != expected {
263 return Err(Error::InvalidProblem(
264 "baracuda-kernels::sort: values shape != [batch, row_len]",
265 ));
266 }
267 if idx_shape != expected {
268 return Err(Error::InvalidProblem(
269 "baracuda-kernels::sort: indices shape != [batch, row_len]",
270 ));
271 }
272 Ok(())
273}
274
275pub(crate) fn build_sku<T: Element>(op: SortKind) -> KernelSku {
277 let precision_guarantee = PrecisionGuarantee {
278 math_precision: if T::KIND == ElementKind::F64 {
279 MathPrecision::F64
280 } else {
281 MathPrecision::F32
282 },
283 accumulator: T::KIND,
284 bit_stable_on_same_hardware: true,
291 deterministic: true,
292 };
293 KernelSku {
294 category: OpCategory::Sorting,
295 op: op as u16,
296 element: T::KIND,
297 aux_element: Some(ElementKind::I32),
298 layout: None,
299 epilogue: None,
300 arch: ArchSku::Sm80,
301 backend: BackendKind::Bespoke,
302 precision_guarantee,
303 }
304}