use crate::{Result, Tensor};
use rayon::prelude::*;

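/// Custom op computing, for each row along the last dimension, the indices that would sort
/// that row. `asc` selects ascending vs. descending order and `last_dim` is the size of the
/// dimension being sorted.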
#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        let mut sort_indexes = unsafe {
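            // Safety: u32 has no invalid bit patterns and every element is overwritten by the
            // per-row index fill below before being read.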
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
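                    // Ascending sort: compare the row values at the two candidate indices,
                    // treating incomparable values (e.g. NaN) as greater.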
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
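                    // Descending sort: same comparison with the closure arguments swapped so
                    // that larger values come first; incomparable values are treated as greater.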
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

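    // Dispatches the `asort_asc`/`asort_desc` CUDA kernels, one block per row of the input.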
    impl crate::cuda_backend::Map1Any for ArgSort {
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            let slice = match layout.contiguous_offsets() {
                None => crate::bail!("input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let elem_count = layout.shape().elem_count();
            let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
            let func = if self.asc {
                dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
            } else {
                dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
            };
            let ncols = self.last_dim;
            let nrows = elem_count / ncols;
            let ncols_pad = next_power_of_2(ncols);
            let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
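            // One block per row; the block width and shared memory are sized for the row padded
            // to the next power of two, as expected by the sort kernel.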
            let cfg = LaunchConfig {
                grid_dim: (1, nrows as u32, 1),
                block_dim: (ncols_pad as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            unsafe { func.launch(cfg, params) }.w()?;
            Ok(S::U32(dst))
        }
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

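    // CPU path: dispatch on the concrete dtype and compute the indices with the
    // rayon-parallel `asort` above.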
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

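    // CUDA path: run the sort kernel through the `Map1Any` implementation from the `cuda`
    // module above.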
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::cuda_backend::Map1Any;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

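        // Pick the Metal kernel specialized for the dtype and the sort direction.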
        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
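        // Pad the row length to the next power of two, as expected by the Metal sort kernel.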
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

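/// Smallest power of two greater than or equal to `x`, used to pad rows for the GPU sort
/// kernels (only referenced when the `cuda` feature is enabled, hence the `unused` allow).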
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
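    /// Returns the indices that sort the tensor along its last dimension.
    ///
    /// If `asc` is `true` the indices correspond to an ascending order, otherwise to a
    /// descending one. The sort is unstable, so there is no guarantee on the ordering of ties.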
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
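        // No backward pass is needed for arg sort.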
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

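    /// Sorts the tensor along its last dimension and returns the sorted values together with
    /// the sorting indices.
    ///
    /// If `asc` is `true` the values are sorted in ascending order, otherwise in descending
    /// order. The sort is unstable, so there is no guarantee on the ordering of ties.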
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}