diffusion_rs_common/core/sort.rs

use crate::core::{Result, Tensor};
use rayon::prelude::*;

#[derive(Debug, Clone, Copy)]
struct ArgSort {
    asc: bool,
    last_dim: usize,
}

impl ArgSort {
    fn asort<T: crate::core::WithDType>(&self, vs: &[T], layout: &crate::core::Layout) -> Vec<u32> {
        #[allow(clippy::uninit_vec)]
        // Safety: the `set_len` is ok as every element is overwritten below
        // before it is read.
        let mut sort_indexes = unsafe {
            let el_count = layout.shape().elem_count();
            let mut v = Vec::with_capacity(el_count);
            v.set_len(el_count);
            v
        };
        if self.asc {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    // Incomparable pairs (e.g. involving NaN) are ordered as Greater.
                    indexes.sort_by(|&i, &j| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        } else {
            sort_indexes
                .par_chunks_exact_mut(self.last_dim)
                .zip(vs.par_chunks_exact(self.last_dim))
                .for_each(|(indexes, vs)| {
                    indexes
                        .iter_mut()
                        .enumerate()
                        .for_each(|(i, v)| *v = i as u32);
                    // The comparator arguments are swapped relative to the ascending
                    // branch, which reverses the comparison and sorts descending.
                    indexes.sort_by(|&j, &i| {
                        vs[i as usize]
                            .partial_cmp(&vs[j as usize])
                            .unwrap_or(std::cmp::Ordering::Greater)
                    })
                });
        }
        sort_indexes
    }
}

impl crate::core::CustomOp1 for ArgSort {
    fn name(&self) -> &'static str {
        "argsort"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::core::CpuStorage,
        layout: &crate::core::Layout,
    ) -> Result<(crate::core::CpuStorage, crate::core::Shape)> {
        let sort_indexes = match storage {
            crate::core::CpuStorage::U8(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I8(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::U32(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I16(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I32(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::I64(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::BF16(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F16(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F32(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F64(vs) => self.asort(vs, layout),
            crate::core::CpuStorage::F8E4M3(vs) => self.asort(vs, layout),
        };
        let sort_indexes = crate::core::CpuStorage::U32(sort_indexes);
        Ok((sort_indexes, layout.shape().into()))
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(
        &self,
        storage: &crate::core::CudaStorage,
        layout: &crate::core::Layout,
    ) -> Result<(crate::core::CudaStorage, crate::core::Shape)> {
        use crate::core::cuda_backend::cudarc::driver::{
            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        };
        use crate::core::cuda_backend::{
            kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr,
        };
        use crate::core::{CudaDevice, WithDType};

        #[allow(non_local_definitions)]
        impl Map1Any for ArgSort {
            fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
                &self,
                src: &CudaSlice<T>,
                dev: &CudaDevice,
                layout: &crate::core::Layout,
                _wrap: W,
            ) -> Result<S> {
                let slice = match layout.contiguous_offsets() {
                    None => crate::bail!("input has to be contiguous"),
                    Some((o1, o2)) => src.slice(o1..o2),
                };
                let elem_count = layout.shape().elem_count();
                let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
                let func = if self.asc {
                    dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
                } else {
                    dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
                };
                let ncols = self.last_dim;
                let nrows = elem_count / ncols;
                // The kernel sorts one row per block and expects the row width
                // padded to a power of two.
                let ncols_pad = next_power_of_2(ncols);
                let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
                let cfg = LaunchConfig {
                    grid_dim: (1, nrows as u32, 1),
                    block_dim: (ncols_pad as u32, 1, 1),
                    shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
                };
                unsafe { func.launch(cfg, params) }.w()?;
                Ok(S::U32(dst))
            }
        }

        use crate::core::backend::BackendStorage;
        let dev = storage.device();
        let slice = self.map(&storage.slice, dev, layout)?;
        let dst = crate::core::cuda_backend::CudaStorage {
            slice,
            device: dev.clone(),
        };
        Ok((dst, layout.shape().clone()))
    }

    #[cfg(feature = "metal")]
    fn metal_fwd(
        &self,
        storage: &crate::core::MetalStorage,
        layout: &crate::core::Layout,
    ) -> Result<(crate::core::MetalStorage, crate::core::Shape)> {
        use crate::core::backend::BackendStorage;
        use crate::core::DType;

        let name = {
            if self.asc {
                match storage.dtype() {
                    DType::BF16 => "asort_asc_bf16",
                    DType::F16 => "asort_asc_f16",
                    DType::F32 => "asort_asc_f32",
                    DType::F64 => "asort_asc_f64",
                    DType::I8 => "asort_asc_i8",
                    DType::U8 => "asort_asc_u8",
                    DType::U32 => "asort_asc_u32",
                    DType::I64 => "asort_asc_i64",
                    DType::I32 => "asort_asc_i32",
                    DType::I16 => "asort_asc_i16",
                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
                }
            } else {
                match storage.dtype() {
                    DType::BF16 => "asort_desc_bf16",
                    DType::F16 => "asort_desc_f16",
                    DType::F32 => "asort_desc_f32",
                    DType::F64 => "asort_desc_f64",
                    DType::I8 => "asort_desc_i8",
                    DType::U8 => "asort_desc_u8",
                    DType::U32 => "asort_desc_u32",
                    DType::I64 => "asort_desc_i64",
                    DType::I32 => "asort_desc_i32",
                    DType::I16 => "asort_desc_i16",
                    DType::F8E4M3 => crate::bail!("Metal device does not yet support F8E4M3."),
                }
            }
        };
        let device = storage.device();
        let kernels = device.kernels();
        let command_buffer = device.command_buffer()?;
        let el = layout.shape().elem_count();
        let ncols = self.last_dim;
        let nrows = el / ncols;
        let src = crate::core::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
        let dst = device.new_buffer(el, DType::U32, "asort")?;
        // Pad the row width to the next power of two, as on the CUDA path.
        let mut ncols_pad = 1;
        while ncols_pad < ncols {
            ncols_pad *= 2;
        }
        crate::metal_kernels::call_arg_sort(
            device.metal_device(),
            &command_buffer,
            kernels,
            name,
            nrows,
            ncols,
            ncols_pad,
            src,
            &dst,
        )
        .map_err(crate::core::Error::wrap)?;
        let dst = crate::core::MetalStorage::new(dst, device.clone(), el, DType::U32);
        Ok((dst, layout.shape().clone()))
    }
}

// Only used by the CUDA backend, hence `allow(unused)` when that feature is off.
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    let mut n = 1;
    while n < x {
        n *= 2
    }
    n
}

impl Tensor {
    /// Returns the indices that sort the tensor along the last dimension.
    ///
    /// If `asc` is `true`, sorting is in ascending order; otherwise it is in descending
    /// order. The sort is unstable, so there is no guarantee on the final order of equal
    /// elements.
    pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
        if !self.is_contiguous() {
            return Err(crate::core::Error::RequiresContiguous {
                op: "arg_sort_last_dim",
            });
        }
        let last_dim = match self.dims().last() {
            None => crate::bail!("empty last-dim in arg-sort"),
            Some(last_dim) => *last_dim,
        };
        // No need for a backward pass for arg sort.
        self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
    }

    /// Sorts the tensor along the last dimension, returning the sorted values together with
    /// the sorting indexes.
    ///
    /// If `asc` is `true`, sorting is in ascending order; otherwise it is in descending
    /// order. The sort is unstable, so there is no guarantee on the final order of equal
    /// elements.
    pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
        if !self.is_contiguous() {
            return Err(crate::core::Error::RequiresContiguous {
                op: "sort_last_dim",
            });
        }
        let asort = self.arg_sort_last_dim(asc)?;
        let sorted = self.gather(&asort, crate::core::D::Minus1)?;
        Ok((sorted, asort))
    }
}
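
// A minimal CPU usage sketch of the public API above, assuming the candle-style
// `Tensor::new`, `Device::Cpu`, and `to_vec1` helpers are available from
// `crate::core`, as elsewhere in this crate.
#[cfg(test)]
mod tests {
    use crate::core::{Device, Tensor};

    #[test]
    fn sort_last_dim_roundtrip() -> crate::core::Result<()> {
        let t = Tensor::new(&[3f32, 1., 2.], &Device::Cpu)?;
        // Ascending argsort returns the permutation that sorts the values.
        let idx = t.arg_sort_last_dim(true)?;
        assert_eq!(idx.to_vec1::<u32>()?, [1, 2, 0]);
        // `sort_last_dim` gathers with that permutation and returns both parts.
        let (sorted, idx) = t.sort_last_dim(false)?;
        assert_eq!(sorted.to_vec1::<f32>()?, [3., 2., 1.]);
        assert_eq!(idx.to_vec1::<u32>()?, [0, 2, 1]);
        Ok(())
    }
}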