candle_core/
sort.rs

1use crate::{Result, Tensor};
2use rayon::prelude::*;
3
4#[derive(Debug, Clone, Copy)]
5struct ArgSort {
6    asc: bool,
7    last_dim: usize,
8}
9
10impl ArgSort {
11    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
12        #[allow(clippy::uninit_vec)]
13        // Safety: indexes are set later in the parallelized section.
14        let mut sort_indexes = unsafe {
15            let el_count = layout.shape().elem_count();
16            let mut v = Vec::with_capacity(el_count);
17            v.set_len(el_count);
18            v
19        };
20        if self.asc {
21            sort_indexes
22                .par_chunks_exact_mut(self.last_dim)
23                .zip(vs.par_chunks_exact(self.last_dim))
24                .for_each(|(indexes, vs)| {
25                    indexes
26                        .iter_mut()
27                        .enumerate()
28                        .for_each(|(i, v)| *v = i as u32);
29                    indexes.sort_by(|&i, &j| {
30                        vs[i as usize]
31                            .partial_cmp(&vs[j as usize])
32                            .unwrap_or(std::cmp::Ordering::Greater)
33                    })
34                });
35        } else {
36            sort_indexes
37                .par_chunks_exact_mut(self.last_dim)
38                .zip(vs.par_chunks_exact(self.last_dim))
39                .for_each(|(indexes, vs)| {
40                    indexes
41                        .iter_mut()
42                        .enumerate()
43                        .for_each(|(i, v)| *v = i as u32);
44                    indexes.sort_by(|&j, &i| {
45                        vs[i as usize]
46                            .partial_cmp(&vs[j as usize])
47                            .unwrap_or(std::cmp::Ordering::Greater)
48                    })
49                });
50        }
51        sort_indexes
52    }
53}
54
#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

    impl crate::cuda_backend::Map1Any for ArgSort {
        /// Launches the `asort_asc`/`asort_desc` kernel (from `kernels::SORT`) for dtype `T`.
        ///
        /// Launch shape: one block per row (`grid_dim.x = nrows`), `min(ncols_pad, 1024)`
        /// threads per block, and `ncols_pad` `u32`s of dynamic shared memory, where
        /// `ncols_pad` is `ncols` rounded up to a power of two. Returns a new `u32` buffer.
        ///
        /// # Errors
        /// Fails if the layout is not contiguous, if the last dimension has size zero, or if
        /// allocation, kernel loading or the launch fails.
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            use cudarc::driver::PushKernelArg;

            // Only the tensor's own contiguous region is sorted, honoring its start offset.
            let slice = match layout.contiguous_offsets() {
                None => crate::bail!("input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let ncols = self.last_dim;
            // `nrows` below divides by `ncols`: report an error rather than panicking on an
            // integer division by zero.
            if ncols == 0 {
                crate::bail!("argsort: zero-sized last dimension is not supported on cuda");
            }
            let elem_count = layout.shape().elem_count();
            // NOTE(review): `dst` is left uninitialized; this presumes the SORT kernels write
            // every one of the `elem_count` indices — confirm against the kernel source.
            let dst = unsafe { dev.alloc::<u32>(elem_count)? };
            let func = if self.asc {
                dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
            } else {
                dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
            };
            let nrows = elem_count / ncols;
            let ncols_pad = next_power_of_2(ncols);
            // Limit block dim to 1024 threads, which is the maximum on modern CUDA gpus.
            let block_dim = ncols_pad.min(1024);
            let cfg = LaunchConfig {
                grid_dim: (nrows as u32, 1, 1),
                block_dim: (block_dim as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            let stream = dev.cuda_stream();
            let mut builder = stream.launch_builder(&func);
            let ncols = ncols as i32;
            let ncols_pad = ncols_pad as i32;
            // Arguments are pushed positionally; this order must match the kernel's parameter
            // list (presumably `src, dst, ncols, ncols_pad` — keep in sync with the kernel).
            builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
            unsafe { builder.launch(cfg) }.w()?;
            Ok(S::U32(dst))
        }
    }
}
105
106impl crate::CustomOp1 for ArgSort {
107    fn name(&self) -> &'static str {
108        "argsort"
109    }
110
111    fn cpu_fwd(
112        &self,
113        storage: &crate::CpuStorage,
114        layout: &crate::Layout,
115    ) -> Result<(crate::CpuStorage, crate::Shape)> {
116        let sort_indexes = match storage {
117            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
118            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
119            crate::CpuStorage::I16(vs) => self.asort(vs, layout),
120            crate::CpuStorage::I32(vs) => self.asort(vs, layout),
121            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
122            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
123            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
124            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
125            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
126            crate::CpuStorage::F8E4M3(vs) => self.asort(vs, layout),
127            // Dummy types don't support sorting
128            crate::CpuStorage::F6E2M3(_) => {
129                return Err(
130                    crate::Error::UnsupportedDTypeForOp(crate::DType::F6E2M3, "argsort").bt(),
131                )
132            }
133            crate::CpuStorage::F6E3M2(_) => {
134                return Err(
135                    crate::Error::UnsupportedDTypeForOp(crate::DType::F6E3M2, "argsort").bt(),
136                )
137            }
138            crate::CpuStorage::F4(_) => {
139                return Err(crate::Error::UnsupportedDTypeForOp(crate::DType::F4, "argsort").bt())
140            }
141            crate::CpuStorage::F8E8M0(_) => {
142                return Err(
143                    crate::Error::UnsupportedDTypeForOp(crate::DType::F8E8M0, "argsort").bt(),
144                )
145            }
146        };
147        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
148        Ok((sort_indexes, layout.shape().into()))
149    }
150
151    #[cfg(feature = "cuda")]
152    fn cuda_fwd(
153        &self,
154        storage: &crate::CudaStorage,
155        layout: &crate::Layout,
156    ) -> Result<(crate::CudaStorage, crate::Shape)> {
157        use crate::backend::BackendStorage;
158        use crate::cuda_backend::Map1Any;
159        let dev = storage.device();
160        let slice = self.map(&storage.slice, dev, layout)?;
161        let dst = crate::cuda_backend::CudaStorage {
162            slice,
163            device: dev.clone(),
164        };
165        Ok((dst, layout.shape().clone()))
166    }
167
168    #[cfg(feature = "metal")]
169    fn metal_fwd(
170        &self,
171        storage: &crate::MetalStorage,
172        layout: &crate::Layout,
173    ) -> Result<(crate::MetalStorage, crate::Shape)> {
174        use crate::backend::BackendStorage;
175        use crate::DType;
176
177        let name = {
178            if self.asc {
179                match storage.dtype() {
180                    DType::BF16 => "asort_asc_bf16",
181                    DType::F16 => "asort_asc_f16",
182                    DType::F32 => "asort_asc_f32",
183                    DType::F64 => "asort_asc_f64",
184                    DType::U8 => "asort_asc_u8",
185                    DType::U32 => "asort_asc_u32",
186                    DType::I16 => "asort_asc_i16",
187                    DType::I32 => "asort_asc_i32",
188                    DType::I64 => "asort_asc_i64",
189                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
190                    DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
191                        return Err(
192                            crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(),
193                        )
194                    }
195                }
196            } else {
197                match storage.dtype() {
198                    DType::BF16 => "asort_desc_bf16",
199                    DType::F16 => "asort_desc_f16",
200                    DType::F32 => "asort_desc_f32",
201                    DType::F64 => "asort_desc_f64",
202                    DType::U8 => "asort_desc_u8",
203                    DType::U32 => "asort_desc_u32",
204                    DType::I16 => "asort_desc_i16",
205                    DType::I32 => "asort_desc_i32",
206                    DType::I64 => "asort_desc_i64",
207                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
208                    DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
209                        return Err(
210                            crate::Error::UnsupportedDTypeForOp(storage.dtype(), "argsort").bt(),
211                        )
212                    }
213                }
214            }
215        };
216        let device = storage.device();
217        let kernels = device.kernels();
218        let command_encoder = device.command_encoder()?;
219        let el = layout.shape().elem_count();
220        let ncols = self.last_dim;
221        let nrows = el / ncols;
222        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
223        let dst = device.new_buffer(el, DType::U32, "asort")?;
224        let mut ncols_pad = 1;
225        while ncols_pad < ncols {
226            ncols_pad *= 2;
227        }
228        candle_metal_kernels::call_arg_sort(
229            device.metal_device(),
230            &command_encoder,
231            kernels,
232            name,
233            nrows,
234            ncols,
235            ncols_pad,
236            src,
237            &dst,
238        )
239        .map_err(crate::Error::wrap)?;
240        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
241        Ok((dst, layout.shape().clone()))
242    }
243}
244
/// Returns the smallest power of two that is `>= x` (and `1` for `x == 0`).
///
/// Delegates to `usize::next_power_of_two`, which also avoids the release-mode infinite loop a
/// doubling loop falls into once the counter wraps to zero on overflow.
// Only called from feature-gated GPU backend code, so it is unused in CPU-only builds.
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    x.next_power_of_two()
}
253
254impl Tensor {
255    /// Returns the indices that sort the tensor along the last dimension.
256    ///
257    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
258    /// descending order. The sort is unstable so there is no guarantees on the final order when it
259    /// comes to ties.
260    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
261        if !self.is_contiguous() {
262            return Err(crate::Error::RequiresContiguous {
263                op: "arg_sort_last_dim",
264            });
265        }
266        let last_dim = match self.dims().last() {
267            None => crate::bail!("empty last-dim in arg-sort"),
268            Some(last_dim) => *last_dim,
269        };
270        // No need for a backward pass for arg sort.
271        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
272    }
273
274    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
275    /// sorted indexes.
276    ///
277    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
278    /// descending order. The sort is unstable so there is no guarantees on the final order when it
279    /// comes to ties.
280    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
281        if !self.is_contiguous() {
282            return Err(crate::Error::RequiresContiguous {
283                op: "sort_last_dim",
284            });
285        }
286        let asort = self.arg_sort_last_dim(asc)?;
287        let sorted = self.gather(&asort, crate::D::Minus1)?;
288        Ok((sorted, asort))
289    }
290}