candle_core/
sort.rs

1use crate::{Result, Tensor};
2use rayon::prelude::*;
3
/// Parameters of the arg-sort custom op applied along the last dimension of a tensor.
#[derive(Clone, Copy, Debug)]
struct ArgSort {
    /// `true` sorts in ascending order, `false` in descending order.
    asc: bool,
    /// Number of elements in each row being sorted (the size of the last dimension).
    last_dim: usize,
}
9
10impl ArgSort {
11    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
12        #[allow(clippy::uninit_vec)]
13        // Safety: indexes are set later in the parallelized section.
14        let mut sort_indexes = unsafe {
15            let el_count = layout.shape().elem_count();
16            let mut v = Vec::with_capacity(el_count);
17            v.set_len(el_count);
18            v
19        };
20        if self.asc {
21            sort_indexes
22                .par_chunks_exact_mut(self.last_dim)
23                .zip(vs.par_chunks_exact(self.last_dim))
24                .for_each(|(indexes, vs)| {
25                    indexes
26                        .iter_mut()
27                        .enumerate()
28                        .for_each(|(i, v)| *v = i as u32);
29                    indexes.sort_by(|&i, &j| {
30                        vs[i as usize]
31                            .partial_cmp(&vs[j as usize])
32                            .unwrap_or(std::cmp::Ordering::Greater)
33                    })
34                });
35        } else {
36            sort_indexes
37                .par_chunks_exact_mut(self.last_dim)
38                .zip(vs.par_chunks_exact(self.last_dim))
39                .for_each(|(indexes, vs)| {
40                    indexes
41                        .iter_mut()
42                        .enumerate()
43                        .for_each(|(i, v)| *v = i as u32);
44                    indexes.sort_by(|&j, &i| {
45                        vs[i as usize]
46                            .partial_cmp(&vs[j as usize])
47                            .unwrap_or(std::cmp::Ordering::Greater)
48                    })
49                });
50        }
51        sort_indexes
52    }
53}
54
#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

    impl crate::cuda_backend::Map1Any for ArgSort {
        /// Launches the `asort_asc` / `asort_desc` kernel from the `SORT` kernel module on
        /// `src` and returns the resulting `u32` index buffer.
        ///
        /// Fails with "input has to be contiguous" when the layout has no contiguous offsets.
        /// The `_wrap` callback is unused because the output dtype is always `u32`.
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            use cudarc::driver::PushKernelArg;

            // Only the part of the buffer covered by the layout is handed to the kernel.
            let input = match layout.contiguous_offsets() {
                Some((start, end)) => src.slice(start..end),
                None => crate::bail!("input has to be contiguous"),
            };
            let elem_count = layout.shape().elem_count();
            // NOTE(review): the output is allocated through an `unsafe` call; presumably the
            // kernel writes every one of the `elem_count` entries - confirm against the kernel.
            let dst = unsafe { dev.alloc::<u32>(elem_count)? };
            let kernel_root = if self.asc { "asort_asc" } else { "asort_desc" };
            let func = dev.get_or_load_func(&kernel_name::<T>(kernel_root), &kernels::SORT)?;

            let ncols = self.last_dim;
            let nrows = elem_count / ncols;
            // The block width is the column count rounded up to a power of two.
            let ncols_pad = next_power_of_2(ncols);
            // One block per row (grid y), `ncols_pad` threads per block, and one `u32` of
            // shared memory per thread.
            let cfg = LaunchConfig {
                grid_dim: (1, nrows as u32, 1),
                block_dim: (ncols_pad as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            // The kernel receives both sizes as 32-bit signed integers.
            let (ncols, ncols_pad) = (ncols as i32, ncols_pad as i32);
            let stream = dev.cuda_stream();
            let mut builder = stream.launch_builder(&func);
            builder.arg(&input).arg(&dst).arg(&ncols).arg(&ncols_pad);
            unsafe { builder.launch(cfg) }.w()?;
            Ok(S::U32(dst))
        }
    }
}
103
104impl crate::CustomOp1 for ArgSort {
105    fn name(&self) -> &'static str {
106        "argsort"
107    }
108
109    fn cpu_fwd(
110        &self,
111        storage: &crate::CpuStorage,
112        layout: &crate::Layout,
113    ) -> Result<(crate::CpuStorage, crate::Shape)> {
114        let sort_indexes = match storage {
115            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
116            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
117            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
118            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
119            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
120            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
121            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
122        };
123        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
124        Ok((sort_indexes, layout.shape().into()))
125    }
126
127    #[cfg(feature = "cuda")]
128    fn cuda_fwd(
129        &self,
130        storage: &crate::CudaStorage,
131        layout: &crate::Layout,
132    ) -> Result<(crate::CudaStorage, crate::Shape)> {
133        use crate::backend::BackendStorage;
134        use crate::cuda_backend::Map1Any;
135        let dev = storage.device();
136        let slice = self.map(&storage.slice, dev, layout)?;
137        let dst = crate::cuda_backend::CudaStorage {
138            slice,
139            device: dev.clone(),
140        };
141        Ok((dst, layout.shape().clone()))
142    }
143
144    #[cfg(feature = "metal")]
145    fn metal_fwd(
146        &self,
147        storage: &crate::MetalStorage,
148        layout: &crate::Layout,
149    ) -> Result<(crate::MetalStorage, crate::Shape)> {
150        use crate::backend::BackendStorage;
151        use crate::DType;
152
153        let name = {
154            if self.asc {
155                match storage.dtype() {
156                    DType::BF16 => "asort_asc_bf16",
157                    DType::F16 => "asort_asc_f16",
158                    DType::F32 => "asort_asc_f32",
159                    DType::F64 => "asort_asc_f64",
160                    DType::U8 => "asort_asc_u8",
161                    DType::U32 => "asort_asc_u32",
162                    DType::I64 => "asort_asc_i64",
163                }
164            } else {
165                match storage.dtype() {
166                    DType::BF16 => "asort_desc_bf16",
167                    DType::F16 => "asort_desc_f16",
168                    DType::F32 => "asort_desc_f32",
169                    DType::F64 => "asort_desc_f64",
170                    DType::U8 => "asort_desc_u8",
171                    DType::U32 => "asort_desc_u32",
172                    DType::I64 => "asort_desc_i64",
173                }
174            }
175        };
176        let device = storage.device();
177        let kernels = device.kernels();
178        let command_buffer = device.command_buffer()?;
179        let el = layout.shape().elem_count();
180        let ncols = self.last_dim;
181        let nrows = el / ncols;
182        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
183        let dst = device.new_buffer(el, DType::U32, "asort")?;
184        let mut ncols_pad = 1;
185        while ncols_pad < ncols {
186            ncols_pad *= 2;
187        }
188        candle_metal_kernels::call_arg_sort(
189            device.metal_device(),
190            &command_buffer,
191            kernels,
192            name,
193            nrows,
194            ncols,
195            ncols_pad,
196            src,
197            &dst,
198        )
199        .map_err(crate::Error::wrap)?;
200        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
201        Ok((dst, layout.shape().clone()))
202    }
203}
204
/// Returns the smallest power of two that is greater than or equal to `x`.
///
/// Both `0` and `1` map to `1`. Delegates to the standard library. If the result does not fit
/// in a `usize`, this panics in debug builds and returns `0` in release builds, whereas the
/// previous hand-written doubling loop would never terminate in release builds.
// Only referenced from feature-gated backend code, so it is unused in some configurations.
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    x.next_power_of_two()
}
213
214impl Tensor {
215    /// Returns the indices that sort the tensor along the last dimension.
216    ///
217    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
218    /// descending order. The sort is unstable so there is no guarantees on the final order when it
219    /// comes to ties.
220    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
221        if !self.is_contiguous() {
222            return Err(crate::Error::RequiresContiguous {
223                op: "arg_sort_last_dim",
224            });
225        }
226        let last_dim = match self.dims().last() {
227            None => crate::bail!("empty last-dim in arg-sort"),
228            Some(last_dim) => *last_dim,
229        };
230        // No need for a backward pass for arg sort.
231        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
232    }
233
234    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
235    /// sorted indexes.
236    ///
237    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
238    /// descending order. The sort is unstable so there is no guarantees on the final order when it
239    /// comes to ties.
240    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
241        if !self.is_contiguous() {
242            return Err(crate::Error::RequiresContiguous {
243                op: "sort_last_dim",
244            });
245        }
246        let asort = self.arg_sort_last_dim(asc)?;
247        let sorted = self.gather(&asort, crate::D::Minus1)?;
248        Ok((sorted, asort))
249    }
250}