//! Argsort along the last dimension, implemented as a [`crate::CustomOp1`]
//! with CPU, CUDA and Metal forward passes.

use crate::{Result, Tensor};
use rayon::prelude::*;

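/// Custom op holding the sort configuration: `asc` selects ascending vs
/// descending order, and `last_dim` is the length of the dimension being
/// sorted (each `last_dim`-sized row is sorted independently).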
#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
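    // CPU argsort: rows are processed in parallel with rayon and each row is
    // sorted independently. Comparisons that fail (e.g. involving NaN) fall
    // back to `Ordering::Greater` so the sort never panics; the resulting
    // position of NaNs is unspecified.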
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // SAFETY: every element of the buffer is overwritten below before it
        // is read, so skipping the zero-initialization is sound.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    // Start from the identity permutation...
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    // ...then sort the indices by the values they point to.
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    // Same as the ascending branch with the closure arguments
                    // swapped, which reverses the comparison.
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

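// CUDA backend: one block sorts one row in shared memory, which is why the row
// length is padded to the next power of two below.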
#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

    impl crate::cuda_backend::Map1Any for ArgSort {
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            use cudarc::driver::PushKernelArg;

            // The kernel requires a contiguous input.
            let slice = match layout.contiguous_offsets() {
                None => crate::bail!("input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let elem_count = layout.shape().elem_count();
            // Output buffer for the u32 indices, fully written by the kernel.
            let dst = unsafe { dev.alloc::<u32>(elem_count)? };
            let func = if self.asc {
                dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
            } else {
                dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
            };
            let ncols = self.last_dim;
            let nrows = elem_count / ncols;
            let ncols_pad = next_power_of_2(ncols);
            let block_dim = ncols_pad.min(1024);
            // One block per row; the shared memory holds the padded row of indices.
            let cfg = LaunchConfig {
                grid_dim: (nrows as u32, 1, 1),
                block_dim: (block_dim as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            let stream = dev.cuda_stream();
            let mut builder = stream.launch_builder(&func);
            let ncols = ncols as i32;
            let ncols_pad = ncols_pad as i32;
            builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
            unsafe { builder.launch(cfg) }.w()?;
            Ok(S::U32(dst))
        }
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

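    // CPU forward pass: dispatch on the concrete dtype and run the rayon-based
    // argsort above; the packed low-bit float formats are rejected.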
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I16(vs) => self.asort(vs, layout),
            crate::CpuStorage::I32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
            crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout),
            crate::CpuStorage::F6E2M3(_) => {
                return Err(
                    crate::Error::UnsupportedDTypeForOp(crate::DType::F6E2M3, "argsort").bt(),
                )
            }
            crate::CpuStorage::F6E3M2(_) => {
                return Err(
                    crate::Error::UnsupportedDTypeForOp(crate::DType::F6E3M2, "argsort").bt(),
                )
            }
            crate::CpuStorage::F4(_) => {
                return Err(crate::Error::UnsupportedDTypeForOp(crate::DType::F4, "argsort").bt())
            }
            crate::CpuStorage::F8E8M0(_) => {
                return Err(
                    crate::Error::UnsupportedDTypeForOp(crate::DType::F8E8M0, "argsort").bt(),
                )
            }
        };
        // The indices are returned as a u32 tensor with the input's shape.
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

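    // CUDA forward pass: delegates to the `Map1Any` impl above, which launches
    // the sort kernel and returns the u32 indices.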
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::cuda_backend::Map1Any;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

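    // Metal forward pass: pick the kernel specialized for the dtype and sort
    // direction, then dispatch it with the row length padded to a power of two.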
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I16 => "asort_asc_i16",
                    DType::I32 => "asort_asc_i32",
                    DType::I64 => "asort_asc_i64",
                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
                    DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                        return Err(
                            crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(),
                        )
                    }
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I16 => "asort_desc_i16",
                    DType::I32 => "asort_desc_i32",
                    DType::I64 => "asort_desc_i64",
                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
                    DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
                        return Err(
                            crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(),
                        )
                    }
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_encoder = device.command_encoder()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        // Pad the row length to the next power of two, as expected by the kernel.
        let ncols_pad = next_power_of_2(ncols);
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_encoder,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

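/// Smallest power of two greater than or equal to `x` (1 when `x` is 0). Used
/// to size the GPU launches, whose sort kernels expect power-of-two rows.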
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
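    /// Returns the indices that would sort the tensor along its last
    /// dimension, as a `u32` tensor with the same shape.
    ///
    /// If `asc` is `true` the indices correspond to an ascending order,
    /// otherwise to a descending one; the relative order of equal elements is
    /// not guaranteed. A minimal usage sketch (assuming this crate is consumed
    /// as `candle_core`):
    ///
    /// ```rust
    /// # use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
    /// let indexes = t.arg_sort_last_dim(true)?;
    /// assert_eq!(indexes.to_vec1::<u32>()?, [1, 2, 0]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```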
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No backward pass: sorting indices have no useful gradient.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

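    /// Sorts the tensor along its last dimension and returns both the sorted
    /// values and the sorting indices (the same indices `arg_sort_last_dim`
    /// produces).
    ///
    /// If `asc` is `true` the values are in ascending order, otherwise in
    /// descending order. A minimal usage sketch (assuming this crate is
    /// consumed as `candle_core`):
    ///
    /// ```rust
    /// # use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
    /// let (sorted, indexes) = t.sort_last_dim(false)?;
    /// assert_eq!(sorted.to_vec1::<f32>()?, [3., 2., 1.]);
    /// assert_eq!(indexes.to_vec1::<u32>()?, [0, 2, 1]);
    /// # Ok::<(), candle_core::Error>(())
    /// ```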
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        // Reorder the values with the computed indices.
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}
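
// A small CPU smoke test sketched here for illustration; it relies only on
// this crate's public `Tensor`, `Device` and `to_vec2` APIs, and the expected
// values follow from the semantics documented above.
#[cfg(test)]
mod tests {
    use crate::{Device, Tensor};

    #[test]
    fn arg_sort_and_sort_last_dim() -> crate::Result<()> {
        let t = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &Device::Cpu)?;
        // Ascending argsort of each row.
        let indexes = t.arg_sort_last_dim(true)?;
        assert_eq!(indexes.to_vec2::<u32>()?, [[1, 2, 0], [0, 2, 1]]);
        // Descending sort returns both the reordered values and the indices.
        let (sorted, indexes) = t.sort_last_dim(false)?;
        assert_eq!(sorted.to_vec2::<f32>()?, [[3., 2., 1.], [5., 4., 0.]]);
        assert_eq!(indexes.to_vec2::<u32>()?, [[0, 2, 1], [1, 2, 0]]);
        Ok(())
    }
}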