use crate::{Result, Tensor};
use rayon::prelude::*;

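/// Custom op computing, for each row along the last dimension, the indices that
/// would sort that row. `asc` selects ascending vs. descending order, `last_dim`
/// is the size of the dimension being sorted.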
#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
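    /// CPU kernel: returns, as a flat `Vec<u32>`, the per-row argsort indices
    /// for `vs`, processing rows of length `self.last_dim` in parallel via
    /// rayon.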
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        let mut sort_indexes = unsafe {
            // Skipping initialization is fine here: every element is written
            // below before being read.
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    // Same comparator as above with its arguments swapped, which
                    // reverses the resulting order.
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

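    // Kernel dispatch for the GPU argsort: one CUDA block handles one row, with
    // the block width and the shared-memory scratch padded to a power of two.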
    impl crate::cuda_backend::Map1Any for ArgSort {
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            use cudarc::driver::PushKernelArg;

            let slice = match layout.contiguous_offsets() {
                None => crate::bail!("input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let elem_count = layout.shape().elem_count();
            // The output indices are always u32.
            let dst = unsafe { dev.alloc::<u32>(elem_count)? };
            let func = if self.asc {
                dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
            } else {
                dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
            };
            let ncols = self.last_dim;
            let nrows = elem_count / ncols;
            let ncols_pad = next_power_of_2(ncols);
            let cfg = LaunchConfig {
                // One block per row; the block width is the padded row length.
                grid_dim: (1, nrows as u32, 1),
                block_dim: (ncols_pad as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            let stream = dev.cuda_stream();
            let mut builder = stream.launch_builder(&func);
            let ncols = ncols as i32;
            let ncols_pad = ncols_pad as i32;
            builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
            unsafe { builder.launch(cfg) }.w()?;
            Ok(S::U32(dst))
        }
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

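    // CPU path: dispatch on the concrete storage dtype, then run the
    // rayon-based argsort kernel.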
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

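    // CUDA path: delegate to the Map1Any impl above and keep the result on the
    // same device.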
    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::cuda_backend::Map1Any;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

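    // Metal path: select the dtype/direction-specific kernel and dispatch it
    // via candle_metal_kernels, padding the row length to a power of two as in
    // the CUDA path.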
    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        // Pick the kernel variant from the sort direction and the input dtype.
        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        let ncols_pad = next_power_of_2(ncols);
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

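/// Rounds `x` up to the smallest power of two that is `>= x`, used to size the
/// GPU sort kernels' block width and scratch space.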
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
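    /// Returns the indices that sort the tensor along its last dimension.
    ///
    /// If `asc` is `true`, the indices correspond to ascending order, otherwise
    /// to descending order. The input has to be contiguous.
    ///
    /// A minimal usage sketch (assuming the crate is consumed as `candle_core`;
    /// marked `ignore` so it is not compiled as a doc-test):
    ///
    /// ```ignore
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
    /// let idx = t.arg_sort_last_dim(true)?;
    /// assert_eq!(idx.to_vec1::<u32>()?, vec![1, 2, 0]);
    /// ```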
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No backward pass is needed as the op only produces integer indices.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

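    /// Sorts the tensor along its last dimension and returns both the sorted
    /// values and the sorting indices, in that order.
    ///
    /// If `asc` is `true`, sorting is in ascending order, otherwise descending.
    /// The input has to be contiguous.
    ///
    /// A minimal usage sketch (same assumptions as in `arg_sort_last_dim`):
    ///
    /// ```ignore
    /// use candle_core::{Device, Tensor};
    /// let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
    /// let (sorted, idx) = t.sort_last_dim(true)?;
    /// assert_eq!(sorted.to_vec1::<f32>()?, vec![1., 2., 3.]);
    /// assert_eq!(idx.to_vec1::<u32>()?, vec![1, 2, 0]);
    /// ```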
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        // Sort the values by gathering them at the argsort indices.
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}