burn_tensor/tensor/api/sort.rs

use core::cmp::Ordering;

use crate::{
    BasicOps, Device, Element, ElementComparison, ElementConversion, TensorData, TensorKind,
    backend::Backend,
    ops::{IntElem, IntTensor},
};
use alloc::{vec, vec::Vec};
use burn_common::reader::try_read_sync;

/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value.
///
/// # Remarks
///
/// This is a fallback solution used only when the backend doesn't provide its own implementation.
/// Ideally, the operation is implemented by the backend, and the backend implementation is
/// resolved by static dispatch. It is not designed for direct use; importing or calling this
/// function directly is not recommended.
pub fn sort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> K::Primitive
where
    <K as BasicOps<B>>::Elem: Element,
{
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
    sort_data::<B, K>(data, dim, &device, descending)
}
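
// A minimal usage sketch, assuming the public `Tensor::sort` wrapper (which dispatches
// to a native backend implementation when one exists and to this fallback otherwise):
//
//     let tensor = Tensor::<B, 1>::from_floats([3.0, 1.0, 2.0], &device);
//     let sorted = tensor.sort(0); // [1.0, 2.0, 3.0]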

/// Sort the elements of `data` by value along a given dimension and create the
/// backend tensor primitive from the sorted data.
///
/// This sort is unstable (i.e., may reorder equal elements).
pub fn sort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> K::Primitive
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    let data_slice = data.as_mut_slice().unwrap();
    if dims.len() == 1 {
        // 1D sort: the whole buffer is one contiguous run, so sort it directly
        data_slice.sort_unstable_by(|&a, &b| compare(&a, &b, descending));
    } else {
        sort_slice::<B, K>(data_slice, &dims, dim, None, false, descending);
    }

    K::from_data(data, device)
}

/// Sort the elements of the input `tensor` by value along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where the elements are sorted by value,
/// along with the corresponding indices, which map the sorted elements back to positions in
/// the original input tensor.
///
/// # Remarks
///
/// This is a fallback solution used only when the backend doesn't provide its own implementation.
/// Ideally, the operation is implemented by the backend, and the backend implementation is
/// resolved by static dispatch. It is not designed for direct use; importing or calling this
/// function directly is not recommended.
pub fn sort_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> (K::Primitive, IntTensor<B>)
where
    <K as BasicOps<B>>::Elem: Element,
{
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");
    sort_data_with_indices::<B, K>(data, dim, &device, descending)
}
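
// Sketch of the expected output, assuming the public `Tensor::sort_with_indices` wrapper:
// sorting [3, 1, 2] ascending yields values [1, 2, 3] and indices [1, 2, 0], since the
// smallest value came from position 1, the next from position 2, and the largest from 0.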

fn sort_data_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> (K::Primitive, IntTensor<B>)
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    let mut indices_data = dim_indices::<B>(&dims, dim);
    let data_slice = data.as_mut_slice().unwrap();
    if dims.len() == 1 {
        // 1D sort: sort the indices by comparing the elements they point to
        indices_data.sort_unstable_by(|&a, &b| {
            compare(
                &data_slice[a.elem::<i64>() as usize],
                &data_slice[b.elem::<i64>() as usize],
                descending,
            )
        });

        // Permute the data in-place by the sorted indices, following each permutation
        // cycle and swapping elements until the cycle closes. For example, sorting
        // [3, 1, 2] ascending yields indices [1, 2, 0]; following the single cycle
        // 0 -> 1 -> 2 -> 0 swaps the data into [1, 2, 3].
        let mut indices = indices_data
            .iter()
            .map(|i| i.elem::<i64>() as usize)
            .collect::<Vec<_>>();
        for idx in 0..indices.len() {
            if indices[idx] != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = indices[current_idx];
                    indices[current_idx] = current_idx;
                    if indices[target_idx] == target_idx {
                        // Already in the correct position; the cycle is closed
                        break;
                    }

                    // Move the next element of the cycle into place
                    data_slice.swap(current_idx, target_idx);
                    current_idx = target_idx;
                }
            }
        }
    } else {
        sort_slice::<B, K>(
            data_slice,
            &dims,
            dim,
            Some(&mut indices_data),
            true,
            descending,
        );
    }

    let shape = data.shape.clone();
    (
        K::from_data(data, device),
        B::int_from_data(TensorData::new(indices_data, shape), device),
    )
}

/// Returns the indices that sort the elements of the input `tensor` along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `tensor` - The input tensor.
/// * `dim` - The axis along which to sort.
/// * `descending` - The sorting order.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, whose elements are the indices that
/// sort the input; the indices map back to positions in the original input tensor.
///
/// # Remarks
///
/// This is a fallback solution used only when the backend doesn't provide its own implementation.
/// Ideally, the operation is implemented by the backend, and the backend implementation is
/// resolved by static dispatch. It is not designed for direct use; importing or calling this
/// function directly is not recommended.
pub fn argsort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> IntTensor<B>
where
    <K as BasicOps<B>>::Elem: Element,
{
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor)).expect("Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.");

    argsort_data::<B, K>(data, dim, &device, descending)
}
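
// Sketch of the expected output, assuming the public `Tensor::argsort` wrapper: argsort
// of [3, 1, 2] ascending is [1, 2, 0], i.e. the permutation that, when used to gather
// from the input, produces the sorted values [1, 2, 3].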

fn argsort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> IntTensor<B>
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    let mut indices_data = dim_indices::<B>(&dims, dim);
    if dims.len() == 1 {
        // 1D sort: sort the indices by comparing the elements they point to
        let slice = data.as_slice::<<K as BasicOps<B>>::Elem>().unwrap();
        indices_data.sort_unstable_by(|&a, &b| {
            compare(
                &slice[a.elem::<i64>() as usize],
                &slice[b.elem::<i64>() as usize],
                descending,
            )
        });
    } else {
        sort_slice::<B, K>(
            data.as_mut_slice().unwrap(),
            &dims,
            dim,
            Some(&mut indices_data),
            false,
            descending,
        );
    }

    B::int_from_data(TensorData::new(indices_data, data.shape), device)
}

/// Sort the elements by value along a given dimension.
///
/// When `indices` is not provided, the `data` is sorted.
/// Otherwise, the `indices` are sorted based on the value of the elements in `data`,
/// and if `permute_both` is enabled, the data is sorted as well.
///
/// This sort is unstable (i.e., may reorder equal elements).
fn sort_slice<B: Backend, K: BasicOps<B>>(
    data: &mut [<K as BasicOps<B>>::Elem],
    dims: &[usize],
    dim: usize,
    mut indices: Option<&mut [IntElem<B>]>,
    permute_both: bool,
    descending: bool,
) where
    <K as BasicOps<B>>::Elem: Element,
{
    let ndims = dims.len();
    let strides = compute_strides(dims);
    // Shape used to enumerate the groups to sort: collapsing the sorted dimension
    // to 1 leaves one entry per independent run of elements along `dim`
    let mut sort_dims = dims.to_vec();
    sort_dims[dim] = 1;
    let strides_out = compute_strides(&sort_dims);

    // Number of groups to sort (product of all dimensions except `dim`)
    let num_sorts: usize = dims
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != dim)
        .map(|(_, d)| d)
        .product();

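    // For example, with `dims = [2, 3]` and `dim = 1`, there are two groups (one per
    // row): group 0 sorts the run at flat indices {0, 1, 2} and group 1 sorts {3, 4, 5}.
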
    // TODO: run each sort in parallel
    // run_par!(|| {
    //     iter_range_par!(0, num_sorts).for_each(|id| {...})
    for id in 0..num_sorts {
        // Compute the flat base offset of group `id` in the input buffer, along with
        // the stride and length of the dimension being sorted
        let mut index_offset = 0;
        let mut stride_dim = 0;
        let mut shape_dim = 0;
        for d in 0..ndims {
            let stride_input = strides[d];
            let stride_output = strides_out[d];
            let shape_output = sort_dims[d];

            // Coordinate of this group along dimension `d`
            let num_block = id / stride_output % shape_output;

            if d != dim {
                index_offset += num_block * stride_input;
            } else {
                let shape_input = dims[d];
                stride_dim = stride_input;
                shape_dim = shape_input;
                index_offset += num_block;
            }
        }

        // For each group, sort the indices based on the element values.
        // NOTE: Sorting methods like `sort_unstable_by` are in-place, but we need to sort
        // different views/groups of the underlying data, so the swap is performed on the
        // elements of the (local index, flat index, element value) collection.
        let mut elements = (0..shape_dim)
            .map(|d| {
                let flat_index = d * stride_dim + index_offset;
                let elem = data[flat_index];
                (d, flat_index, elem)
            })
            .collect::<Vec<_>>();

        elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending));

        // Permute the data and/or indices in-place by the sorted order, following each
        // permutation cycle and swapping elements until the cycle closes
        for idx in 0..elements.len() {
            if elements[idx].0 != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = elements[current_idx].0;
                    elements[current_idx].0 = current_idx;
                    if elements[target_idx].0 == target_idx {
                        // Already in the correct position; the cycle is closed
                        break;
                    }

                    if indices.is_none() || permute_both {
                        // Permute the data elements
                        data.swap(elements[current_idx].1, elements[target_idx].1);
                    }

                    if let Some(ref mut indices_data) = indices {
                        // Permute the data element indices
                        indices_data.swap(elements[current_idx].1, elements[target_idx].1);
                    }

                    current_idx = target_idx;
                }
            }
        }
    }
}

/// Computes the row-major strides (steps per dimension) used when traversing an array;
/// e.g., `dims = [2, 3, 4]` gives `[12, 4, 1]`.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![0; dims.len()];
    let mut current = 1;

    dims.iter().enumerate().rev().for_each(|(index, val)| {
        strides[index] = current;
        current *= val;
    });

    strides
}

/// Generates the indices for each element along the specified dimension;
/// e.g., `dims = [2, 3]` with `dim = 1` gives `[0, 1, 2, 0, 1, 2]`.
fn dim_indices<B: Backend>(dims: &[usize], dim: usize) -> Vec<IntElem<B>> {
    if dims.len() == 1 {
        (0..dims[dim])
            .map(|i| (i as i64).elem::<IntElem<B>>())
            .collect::<Vec<_>>()
    } else {
        // Repeat each index along the trailing dimensions, then repeat the whole
        // pattern for the leading dimensions
        let numel_leading_dims: usize = dims[..dim].iter().product();
        let numel_trailing_dims: usize = dims[dim + 1..].iter().product();
        (0..dims[dim])
            .map(|i| [(i as i64).elem::<IntElem<B>>()].repeat(numel_trailing_dims))
            .collect::<Vec<_>>()
            .concat()
            .repeat(numel_leading_dims)
    }
}

/// Compare two elements, reversing the ordering when `descending` is set
fn compare<E: ElementComparison>(a: &E, b: &E, descending: bool) -> Ordering {
    if descending { b.cmp(a) } else { a.cmp(b) }
}
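
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity sketch for the stride computation above: strides are row-major,
    // so the last dimension moves fastest. (Added for illustration; the expected values
    // follow directly from the definition of `compute_strides`.)
    #[test]
    fn compute_strides_is_row_major() {
        assert_eq!(compute_strides(&[2, 3, 4]), [12, 4, 1]);
        assert_eq!(compute_strides(&[4, 2]), [2, 1]);
        assert_eq!(compute_strides(&[5]), [1]);
    }
}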