diffusion_rs_common/core/
sort.rs

1use crate::core::{Result, Tensor};
2use rayon::prelude::*;
3
/// Parameters of the arg-sort custom op, which sorts a tensor along its last dimension.
#[derive(Clone, Copy, Debug)]
struct ArgSort {
    /// Sort in ascending order when `true`, descending order when `false`.
    asc: bool,
    /// Length of the last dimension, i.e. of every row that is sorted independently.
    last_dim: usize,
}
9
10impl ArgSort {
11    fn asort<T: crate::core::WithDType>(&self, vs: &[T], layout: &crate::core::Layout) -> Vec<u32> {
12        #[allow(clippy::uninit_vec)]
13        // Safety: indexes are set later in the parallelized section.
14        let mut sort_indexes = unsafe {
15            let el_count = layout.shape().elem_count();
16            let mut v = Vec::with_capacity(el_count);
17            v.set_len(el_count);
18            v
19        };
20        if self.asc {
21            sort_indexes
22                .par_chunks_exact_mut(self.last_dim)
23                .zip(vs.par_chunks_exact(self.last_dim))
24                .for_each(|(indexes, vs)| {
25                    indexes
26                        .iter_mut()
27                        .enumerate()
28                        .for_each(|(i, v)| *v = i as u32);
29                    indexes.sort_by(|&i, &j| {
30                        vs[i as usize]
31                            .partial_cmp(&vs[j as usize])
32                            .unwrap_or(std::cmp::Ordering::Greater)
33                    })
34                });
35        } else {
36            sort_indexes
37                .par_chunks_exact_mut(self.last_dim)
38                .zip(vs.par_chunks_exact(self.last_dim))
39                .for_each(|(indexes, vs)| {
40                    indexes
41                        .iter_mut()
42                        .enumerate()
43                        .for_each(|(i, v)| *v = i as u32);
44                    indexes.sort_by(|&j, &i| {
45                        vs[i as usize]
46                            .partial_cmp(&vs[j as usize])
47                            .unwrap_or(std::cmp::Ordering::Greater)
48                    })
49                });
50        }
51        sort_indexes
52    }
53}
54
55impl crate::core::CustomOp1 for ArgSort {
56    fn name(&self) -> &'static str {
57        "argsort"
58    }
59
60    fn cpu_fwd(
61        &self,
62        storage: &crate::core::CpuStorage,
63        layout: &crate::core::Layout,
64    ) -> Result<(crate::core::CpuStorage, crate::core::Shape)> {
65        let sort_indexes = match storage {
66            crate::core::CpuStorage::U8(vs) => self.asort(vs, layout),
67            crate::core::CpuStorage::I8(vs) => self.asort(vs, layout),
68            crate::core::CpuStorage::U32(vs) => self.asort(vs, layout),
69            crate::core::CpuStorage::I16(vs) => self.asort(vs, layout),
70            crate::core::CpuStorage::I32(vs) => self.asort(vs, layout),
71            crate::core::CpuStorage::I64(vs) => self.asort(vs, layout),
72            crate::core::CpuStorage::BF16(vs) => self.asort(vs, layout),
73            crate::core::CpuStorage::F16(vs) => self.asort(vs, layout),
74            crate::core::CpuStorage::F32(vs) => self.asort(vs, layout),
75            crate::core::CpuStorage::F64(vs) => self.asort(vs, layout),
76            crate::core::CpuStorage::F8E4M3(vs) => self.asort(vs, layout),
77        };
78        let sort_indexes = crate::core::CpuStorage::U32(sort_indexes);
79        Ok((sort_indexes, layout.shape().into()))
80    }
81
82    #[cfg(feature = "cuda")]
83    fn cuda_fwd(
84        &self,
85        storage: &crate::core::CudaStorage,
86        layout: &crate::core::Layout,
87    ) -> Result<(crate::core::CudaStorage, crate::core::Shape)> {
88        use crate::core::cuda_backend::cudarc::driver::{
89            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
90        };
91        use crate::core::cuda_backend::{
92            kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr,
93        };
94        use crate::core::{CudaDevice, WithDType};
95
96        #[allow(non_local_definitions)]
97        impl Map1Any for ArgSort {
98            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
99                &self,
100                src: &CudaSlice<T>,
101                dev: &CudaDevice,
102                layout: &crate::core::Layout,
103                _wrap: W,
104            ) -> Result<S> {
105                let slice = match layout.contiguous_offsets() {
106                    None => crate::bail!("input has to be contiguous"),
107                    Some((o1, o2)) => src.slice(o1..o2),
108                };
109                let elem_count = layout.shape().elem_count();
110                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
111                let func = if self.asc {
112                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
113                } else {
114                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
115                };
116                let ncols = self.last_dim;
117                let nrows = elem_count / ncols;
118                let ncols_pad = next_power_of_2(ncols);
119                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
120                let cfg = LaunchConfig {
121                    grid_dim: (1, nrows as u32, 1),
122                    block_dim: (ncols_pad as u32, 1, 1),
123                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
124                };
125                unsafe { func.launch(cfg, params) }.w()?;
126                Ok(S::U32(dst))
127            }
128        }
129
130        use crate::core::backend::BackendStorage;
131        let dev = storage.device();
132        let slice = self.map(&storage.slice, dev, layout)?;
133        let dst = crate::core::cuda_backend::CudaStorage {
134            slice,
135            device: dev.clone(),
136        };
137        Ok((dst, layout.shape().clone()))
138    }
139
140    #[cfg(feature = "metal")]
141    fn metal_fwd(
142        &self,
143        storage: &crate::core::MetalStorage,
144        layout: &crate::core::Layout,
145    ) -> Result<(crate::core::MetalStorage, crate::core::Shape)> {
146        use crate::core::backend::BackendStorage;
147        use crate::core::DType;
148
149        let name = {
150            if self.asc {
151                match storage.dtype() {
152                    DType::BF16 => "asort_asc_bf16",
153                    DType::F16 => "asort_asc_f16",
154                    DType::F32 => "asort_asc_f32",
155                    DType::F64 => "asort_asc_f64",
156                    DType::I8 => "asort_asc_i8",
157                    DType::U8 => "asort_asc_u8",
158                    DType::U32 => "asort_asc_u32",
159                    DType::I64 => "asort_asc_i64",
160                    DType::I32 => "asort_asc_i32",
161                    DType::I16 => "asort_asc_i16",
162                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
163                }
164            } else {
165                match storage.dtype() {
166                    DType::BF16 => "asort_desc_bf16",
167                    DType::F16 => "asort_desc_f16",
168                    DType::F32 => "asort_desc_f32",
169                    DType::F64 => "asort_desc_f64",
170                    DType::I8 => "asort_desc_i8",
171                    DType::U8 => "asort_desc_u8",
172                    DType::U32 => "asort_desc_u32",
173                    DType::I64 => "asort_desc_i64",
174                    DType::I32 => "asort_desc_i32",
175                    DType::I16 => "asort_desc_i16",
176                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
177                }
178            }
179        };
180        let device = storage.device();
181        let kernels = device.kernels();
182        let command_buffer = device.command_buffer()?;
183        let el = layout.shape().elem_count();
184        let ncols = self.last_dim;
185        let nrows = el / ncols;
186        let src = crate::core::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
187        let dst = device.new_buffer(el, DType::U32, "asort")?;
188        let mut ncols_pad = 1;
189        while ncols_pad < ncols {
190            ncols_pad *= 2;
191        }
192        crate::metal_kernels::call_arg_sort(
193            device.metal_device(),
194            &command_buffer,
195            kernels,
196            name,
197            nrows,
198            ncols,
199            ncols_pad,
200            src,
201            &dst,
202        )
203        .map_err(crate::core::Error::wrap)?;
204        let dst = crate::core::MetalStorage::new(dst, device.clone(), el, DType::U32);
205        Ok((dst, layout.shape().clone()))
206    }
207}
208
/// Returns the smallest power of two that is greater than or equal to `x` (1 for `x == 0`).
///
/// Delegates to [`usize::next_power_of_two`]: if the result would overflow `usize` it panics in
/// debug builds and returns 0 in release builds (the previous hand-rolled loop never terminated
/// in that case).
// Used by the feature-gated GPU backends only, hence dead code in CPU-only builds.
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    x.next_power_of_two()
}
217
218impl Tensor {
219    /// Returns the indices that sort the tensor along the last dimension.
220    ///
221    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
222    /// descending order. The sort is unstable so there is no guarantees on the final order when it
223    /// comes to ties.
224    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
225        if !self.is_contiguous() {
226            return Err(crate::core::Error::RequiresContiguous {
227                op: "arg_sort_last_dim",
228            });
229        }
230        let last_dim = match self.dims().last() {
231            None => crate::bail!("empty last-dim in arg-sort"),
232            Some(last_dim) => *last_dim,
233        };
234        // No need for a backward pass for arg sort.
235        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
236    }
237
238    /// Sorts the tensor along the last dimension, returns the sorted tensor together with the
239    /// sorted indexes.
240    ///
241    /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
242    /// descending order. The sort is unstable so there is no guarantees on the final order when it
243    /// comes to ties.
244    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
245        if !self.is_contiguous() {
246            return Err(crate::core::Error::RequiresContiguous {
247                op: "sort_last_dim",
248            });
249        }
250        let asort = self.arg_sort_last_dim(asc)?;
251        let sorted = self.gather(&asort, crate::core::D::Minus1)?;
252        Ok((sorted, asort))
253    }
254}