use core::cmp::Ordering;

use crate::{
    BasicOps, Device, Element, ElementComparison, ElementConversion, TensorData, TensorKind,
    backend::Backend,
    ops::{IntElem, IntTensor},
};
use alloc::{vec, vec::Vec};
use burn_common::reader::try_read_sync;

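/// Sorts the elements of the tensor by value along the given dimension.
///
/// The tensor data is read back synchronously and sorted on the CPU; the sort is
/// unstable, so equal elements may be reordered.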
pub fn sort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> K::Primitive
where
    <K as BasicOps<B>>::Elem: Element,
{
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
    sort_data::<B, K>(data, dim, &device, descending)
}

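/// Sorts the elements of the raw `TensorData` by value along the given dimension and
/// builds a tensor of kind `K` from the result.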
pub fn sort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> K::Primitive
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    let data_slice = data.as_mut_slice().unwrap();
    if dims.len() == 1 {
        data_slice.sort_unstable_by(|&a, &b| compare(&a, &b, descending));
    } else {
        sort_slice::<B, K>(data_slice, &dims, dim, None, false, descending);
    }

    K::from_data(data, device)
}

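/// Sorts the elements of the tensor by value along the given dimension, returning both
/// the sorted values and the corresponding indices into the original tensor.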
pub fn sort_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> (K::Primitive, IntTensor<B>)
where
    <K as BasicOps<B>>::Elem: Element,
{
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
    sort_data_with_indices::<B, K>(data, dim, &device, descending)
}

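/// Sorts the raw `TensorData` by value along the given dimension, returning the sorted
/// data together with an integer tensor of the original positions along that dimension.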
fn sort_data_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> (K::Primitive, IntTensor<B>)
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    let mut indices_data = dim_indices::<B>(&dims, dim);
    let data_slice = data.as_mut_slice().unwrap();
    if dims.len() == 1 {
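        // Sort the index vector by comparing the data values it points to.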
        indices_data.sort_unstable_by(|&a, &b| {
            compare(
                &data_slice[a.elem::<i64>() as usize],
                &data_slice[b.elem::<i64>() as usize],
                descending,
            )
        });

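        // Apply the sorted permutation to the data in place by following permutation cycles.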
        let mut indices = indices_data
            .clone()
            .iter()
            .map(|i| i.elem::<i64>() as usize)
            .collect::<Vec<_>>();
        for idx in 0..indices.len() {
            if indices[idx] != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = indices[current_idx];
                    indices[current_idx] = current_idx;
                    if indices[target_idx] == target_idx {
                        break;
                    }

                    data_slice.swap(current_idx, target_idx);
                    current_idx = target_idx;
                }
            }
        }
    } else {
        sort_slice::<B, K>(
            data_slice,
            &dims,
            dim,
            Some(&mut indices_data),
            true,
            descending,
        );
    }

    let shape = data.shape.clone();
    (
        K::from_data(data, device),
        B::int_from_data(TensorData::new(indices_data, shape), device),
    )
}

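/// Returns the indices that sort the elements of the tensor by value along the given
/// dimension, without returning the sorted values themselves.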
pub fn argsort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> IntTensor<B>
where
    <K as BasicOps<B>>::Elem: Element,
{
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");

    argsort_data::<B, K>(data, dim, &device, descending)
}

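/// Computes the sorting indices for the raw `TensorData` along the given dimension and
/// returns them as an integer tensor with the same shape.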
fn argsort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> IntTensor<B>
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    let mut indices_data = dim_indices::<B>(&dims, dim);
    if dims.len() == 1 {
        let slice = data.as_slice::<<K as BasicOps<B>>::Elem>().unwrap();
        indices_data.sort_unstable_by(|&a, &b| {
            compare(
                &slice[a.elem::<i64>() as usize],
                &slice[b.elem::<i64>() as usize],
                descending,
            )
        });
    } else {
        sort_slice::<B, K>(
            data.as_mut_slice().unwrap(),
            &dims,
            dim,
            Some(&mut indices_data),
            false,
            descending,
        );
    }

    B::int_from_data(TensorData::new(indices_data, data.shape), device)
}

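/// Sorts each 1-D lane along dimension `dim` of a contiguous, row-major buffer,
/// permuting the data and/or an optional indices buffer accordingly.
///
/// When `indices` is `None`, only the data is permuted. When `indices` is provided, the
/// indices buffer is permuted alongside it, and the data itself is only permuted if
/// `permute_both` is `true` (argsort passes `false` to leave the data untouched).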
fn sort_slice<B: Backend, K: BasicOps<B>>(
    data: &mut [<K as BasicOps<B>>::Elem],
    dims: &[usize],
    dim: usize,
    mut indices: Option<&mut [IntElem<B>]>,
    permute_both: bool,
    descending: bool,
) where
    <K as BasicOps<B>>::Elem: Element,
{
    let ndims = dims.len();
    let strides = compute_strides(dims);
    let mut sort_dims = dims.to_vec();
    sort_dims[dim] = 1;
    let strides_out = compute_strides(&sort_dims);

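    // Number of independent 1-D lanes to sort: the product of all dimensions except `dim`.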
    let num_sorts: usize = dims
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != dim)
        .map(|(_, d)| d)
        .product();

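    // For each lane, recover its base flat offset in the input buffer, plus the stride
    // and length along `dim`, by decomposing the lane id over the reduced shape.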
    for id in 0..num_sorts {
        let mut index_offset = 0;
        let mut stride_dim = 0;
        let mut shape_dim = 0;
        for d in 0..ndims {
            let stride_input = strides[d];
            let stride_output = strides_out[d];
            let shape_output = sort_dims[d];

            let num_block = id / stride_output % shape_output;

            if d != dim {
                index_offset += num_block * stride_input;
            } else {
                let shape_input = dims[d];
                stride_dim = stride_input;
                shape_dim = shape_input;
                index_offset += num_block;
            }
        }

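        // Collect (position along dim, flat index, value) triples for this lane and sort
        // them by value.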
        let mut elements = (0..shape_dim)
            .map(|d| {
                let flat_index = d * stride_dim + index_offset;
                let elem = data[flat_index];
                (d, flat_index, elem)
            })
            .collect::<Vec<_>>();

        elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending));

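        // Apply the sorted order back to the underlying buffers by following permutation
        // cycles, swapping data and/or indices at the lane's flat positions as requested.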
        for idx in 0..elements.len() {
            if elements[idx].0 != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = elements[current_idx].0;
                    elements[current_idx].0 = current_idx;
                    if elements[target_idx].0 == target_idx {
                        break;
                    }

                    if indices.is_none() || permute_both {
                        data.swap(elements[current_idx].1, elements[target_idx].1);
                    }

                    if let Some(ref mut indices_data) = indices {
                        indices_data.swap(elements[current_idx].1, elements[target_idx].1);
                    }

                    current_idx = target_idx;
                }
            }
        }
    }
}

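/// Computes row-major (C-contiguous) strides for the given shape, with the innermost
/// dimension having stride 1.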
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; dims.len()];
    let mut current = 1;

    dims.iter().enumerate().rev().for_each(|(index, val)| {
        strides[index] = current;
        current *= val;
    });

    strides
}

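/// Builds a flat, row-major buffer with the shape given by `dims` in which every element
/// holds its own coordinate along `dim`.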
fn dim_indices<B: Backend>(dims: &[usize], dim: usize) -> Vec<IntElem<B>> {
    if dims.len() == 1 {
        (0..dims[dim])
            .map(|i| (i as i64).elem::<IntElem<B>>())
            .collect::<Vec<_>>()
    } else {
        let numel_leading_dims: usize = dims[..dim].iter().product();
        let numel_trailing_dims: usize = dims[dim + 1..].iter().product();
        (0..dims[dim])
            .map(|i| [(i as i64).elem::<IntElem<B>>()].repeat(numel_trailing_dims))
            .collect::<Vec<_>>()
            .concat()
            .repeat(numel_leading_dims)
    }
}

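/// Compares two elements, reversing the ordering when `descending` is `true`.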
fn compare<E: ElementComparison>(a: &E, b: &E, descending: bool) -> Ordering {
    if descending { b.cmp(a) } else { a.cmp(b) }
}