candle_core/sort.rs

use crate::{Result, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // Safety: indexes are set later in the parallelized section.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
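        // Each row of `last_dim` elements is sorted independently and in parallel: the indexes
        // are initialized to 0..last_dim and then sorted by the values they point to.
        // `partial_cmp` falls back to `Greater` for incomparable values such as NaN.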
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
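                    // Same as the ascending branch, with the closure arguments swapped to
                    // reverse the comparison and obtain a descending order.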
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

#[cfg(feature = "cuda")]
mod cuda {
    use super::*;
    use crate::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
    };
    use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
    use crate::{CudaDevice, WithDType};

    impl crate::cuda_backend::Map1Any for ArgSort {
        fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &crate::Layout,
            _wrap: W,
        ) -> Result<S> {
            let slice = match layout.contiguous_offsets() {
                None => crate::bail!("input has to be contiguous"),
                Some((o1, o2)) => src.slice(o1..o2),
            };
            let elem_count = layout.shape().elem_count();
            let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
            let func = if self.asc {
                dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
            } else {
                dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
            };
            let ncols = self.last_dim;
            let nrows = elem_count / ncols;
            let ncols_pad = next_power_of_2(ncols);
            let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
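            // One block per row; the block width and the shared-memory scratch space are sized
            // to the column count padded to the next power of two, as expected by the sort kernel.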
            let cfg = LaunchConfig {
                grid_dim: (1, nrows as u32, 1),
                block_dim: (ncols_pad as u32, 1, 1),
                shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
            };
            unsafe { func.launch(cfg, params) }.w()?;
            Ok(S::U32(dst))
        }
    }
}

impl crate::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, crate::Shape)> {
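        // Dispatch on the input dtype; whatever the input type, the returned indexes are u32.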
        let sort_indexes = match storage {
            crate::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::CpuStorage::F64(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::cuda_backend::Map1Any;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        use crate::DType;

        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
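        // The Metal kernel also expects the column count padded to the next power of two.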
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        candle_metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::Error::wrap)?;
        let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

#[allow(unused)]
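// Used by the cuda backend to size the kernel launch; the allow above keeps builds without the
// `cuda` feature warning-free.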
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there are no guarantees on the final order when
    /// it comes to ties.
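    ///
    /// A minimal usage sketch (the values and the CPU device below are only illustrative):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// # fn main() -> candle_core::Result<()> {
    /// let t = Tensor::new(&[[3f32, 1., 2.], [0., 5., 4.]], &Device::Cpu)?;
    /// let indexes = t.arg_sort_last_dim(true)?;
    /// assert_eq!(indexes.to_vec2::<u32>()?, [[1, 2, 0], [0, 2, 1]]);
    /// # Ok(())
    /// # }
    /// ```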
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

    /// Sorts the tensor along the last dimension and returns the sorted tensor together with the
    /// sorted indexes.
    ///
    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
    /// descending order. The sort is unstable so there are no guarantees on the final order when
    /// it comes to ties.
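    ///
    /// A minimal usage sketch (illustrative values, CPU device):
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    /// # fn main() -> candle_core::Result<()> {
    /// let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
    /// let (sorted, indexes) = t.sort_last_dim(false)?;
    /// assert_eq!(sorted.to_vec1::<f32>()?, [3., 2., 1.]);
    /// assert_eq!(indexes.to_vec1::<u32>()?, [0, 2, 1]);
    /// # Ok(())
    /// # }
    /// ```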
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::D::Minus1)?;
        Ok((sorted, asort))
    }
}