use core::cmp::Ordering;
use crate::{
Backend, TensorData,
element::{Element, ElementComparison, ElementConversion},
tensor::{BasicOps, Device, IntElem, IntTensor, TensorKind},
};
use alloc::{vec, vec::Vec};
use burn_std::reader::try_read_sync;
/// Sorts the elements of `tensor` along `dim`.
///
/// Falls back to a host-side sort: the tensor is read back synchronously,
/// sorted on the CPU, then re-uploaded to the original device.
///
/// # Panics
/// Panics if the backend cannot read the tensor data synchronously.
pub fn sort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> K::Primitive
where
    <K as BasicOps<B>>::Elem: Element,
{
    const SYNC_READ_MSG: &str = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
    // Remember where the tensor lives so the result lands on the same device.
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor))
        .expect(SYNC_READ_MSG)
        .expect(SYNC_READ_MSG);
    sort_data::<B, K>(data, dim, &device, descending)
}
/// Sorts host-resident `data` along `dim` and materializes the result on `device`.
pub fn sort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> K::Primitive
where
    <K as BasicOps<B>>::Elem: Element,
{
    let shape = data.shape.clone();
    let values = data.as_mut_slice().unwrap();
    match shape.len() {
        // A single dimension is just a flat sort of the whole buffer.
        1 => values.sort_unstable_by(|&x, &y| compare(&x, &y, descending)),
        // Otherwise sort every 1-d lane along `dim` independently.
        _ => sort_slice::<B, K>(values, &shape, dim, None, false, descending),
    }
    K::from_data(data, device)
}
/// Sorts the elements of `tensor` along `dim`, returning both the sorted
/// tensor and an int tensor with the original index of each sorted element.
///
/// Falls back to a host-side sort (synchronous read-back + CPU sort).
///
/// # Panics
/// Panics if the backend cannot read the tensor data synchronously.
pub fn sort_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> (K::Primitive, IntTensor<B>)
where
    <K as BasicOps<B>>::Elem: Element,
{
    const SYNC_READ_MSG: &str = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
    // Remember where the tensor lives so both outputs land on the same device.
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor))
        .expect(SYNC_READ_MSG)
        .expect(SYNC_READ_MSG);
    sort_data_with_indices::<B, K>(data, dim, &device, descending)
}
/// Sorts host-resident `data` along `dim` and returns both the sorted tensor
/// and the indices describing where each sorted element originally lived.
///
/// Fixes two needless allocations in the previous version: `indices_data` was
/// cloned in full just to be iterated by reference, and `data.shape` was
/// cloned a second time even though `dims` already holds an identical copy
/// (only the values are mutated below, never the shape).
fn sort_data_with_indices<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> (K::Primitive, IntTensor<B>)
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    // Per-position coordinate along `dim`, tiled to the full shape.
    let mut indices_data = dim_indices::<B>(&dims, dim);
    let data_slice = data.as_mut_slice().unwrap();
    if dims.len() == 1 {
        // Argsort: order the indices by the values they point to.
        indices_data.sort_unstable_by(|&a, &b| {
            compare(
                &data_slice[a.elem::<i64>() as usize],
                &data_slice[b.elem::<i64>() as usize],
                descending,
            )
        });
        // Scratch copy of the permutation as plain usize (no clone of the
        // element vector needed — iterate it by reference).
        let mut indices = indices_data
            .iter()
            .map(|i| i.elem::<i64>() as usize)
            .collect::<Vec<_>>();
        // Apply the permutation to the data in place by following its cycles.
        for idx in 0..indices.len() {
            if indices[idx] != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = indices[current_idx];
                    // Mark this slot as settled.
                    indices[current_idx] = current_idx;
                    if indices[target_idx] == target_idx {
                        break;
                    }
                    data_slice.swap(current_idx, target_idx);
                    current_idx = target_idx;
                }
            }
        }
    } else {
        // N-d: sort each 1-d lane along `dim`, permuting values and indices together.
        sort_slice::<B, K>(
            data_slice,
            &dims,
            dim,
            Some(&mut indices_data),
            true,
            descending,
        );
    }
    (
        K::from_data(data, device),
        // `dims` still matches `data.shape`: only the values were mutated above.
        B::int_from_data(TensorData::new(indices_data, dims), device),
    )
}
/// Returns the indices that would sort `tensor` along `dim`.
///
/// Falls back to a host-side sort (synchronous read-back + CPU sort).
///
/// # Panics
/// Panics if the backend cannot read the tensor data synchronously.
pub fn argsort<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    tensor: K::Primitive,
    dim: usize,
    descending: bool,
) -> IntTensor<B>
where
    <K as BasicOps<B>>::Elem: Element,
{
    const SYNC_READ_MSG: &str = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
    // Remember where the tensor lives so the index tensor lands on the same device.
    let device = K::device(&tensor);
    let data = try_read_sync(K::into_data_async(tensor))
        .expect(SYNC_READ_MSG)
        .expect(SYNC_READ_MSG);
    argsort_data::<B, K>(data, dim, &device, descending)
}
/// Computes, on the host, the indices that would sort `data` along `dim`,
/// and materializes them as an int tensor on `device`. The values themselves
/// are left unsorted.
fn argsort_data<B: Backend, K: TensorKind<B> + BasicOps<B>>(
    mut data: TensorData,
    dim: usize,
    device: &Device<B>,
    descending: bool,
) -> IntTensor<B>
where
    <K as BasicOps<B>>::Elem: Element,
{
    let dims = data.shape.clone();
    // Per-position coordinate along `dim`, tiled to the full shape.
    let mut indices = dim_indices::<B>(&dims, dim);
    match dims.len() {
        // 1-d: order the index vector by the values it points at.
        1 => {
            let values = data.as_slice::<<K as BasicOps<B>>::Elem>().unwrap();
            indices.sort_unstable_by(|&lhs, &rhs| {
                let a = &values[lhs.elem::<i64>() as usize];
                let b = &values[rhs.elem::<i64>() as usize];
                compare(a, b, descending)
            });
        }
        // N-d: permute only the index buffer, lane by lane along `dim`.
        _ => sort_slice::<B, K>(
            data.as_mut_slice().unwrap(),
            &dims,
            dim,
            Some(&mut indices),
            false,
            descending,
        ),
    }
    B::int_from_data(TensorData::new(indices, data.shape), device)
}
/// Sorts `data` — a flat, row-major buffer of shape `dims` — in place along
/// dimension `dim`, treating every 1-d lane along `dim` as an independent
/// sequence.
///
/// * `indices` — optional parallel buffer (same flat layout) that receives
///   the argsort permutation; it is permuted with the same swaps as the values.
/// * `permute_both` — when `indices` is `Some`, also permute `data` itself
///   (sort-with-indices). When `false` with `Some(indices)`, only the index
///   buffer is permuted and `data` is left untouched (plain argsort).
/// * `descending` — sort order.
fn sort_slice<B: Backend, K: BasicOps<B>>(
    data: &mut [<K as BasicOps<B>>::Elem],
    dims: &[usize],
    dim: usize,
    mut indices: Option<&mut [IntElem<B>]>,
    permute_both: bool,
    descending: bool,
) where
    <K as BasicOps<B>>::Elem: Element,
{
    let ndims = dims.len();
    // Row-major strides of the input shape.
    let strides = compute_strides(dims);
    // Collapse the sorted dimension to extent 1: enumerating this reduced
    // shape yields exactly one flat id per independent lane.
    let mut sort_dims = dims.to_vec();
    sort_dims[dim] = 1;
    let strides_out = compute_strides(&sort_dims);
    // Number of independent 1-d sorts = product of all non-`dim` extents.
    let num_sorts: usize = dims
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != dim)
        .map(|(_, d)| d)
        .product();
    for id in 0..num_sorts {
        // Decompose the lane id into per-dimension block coordinates (via the
        // reduced-shape strides) to recover the lane's base flat offset, plus
        // the stride and length along the sorted dimension.
        let mut index_offset = 0;
        let mut stride_dim = 0;
        let mut shape_dim = 0;
        for d in 0..ndims {
            let stride_input = strides[d];
            let stride_output = strides_out[d];
            let shape_output = sort_dims[d];
            let num_block = id / stride_output % shape_output;
            if d != dim {
                index_offset += num_block * stride_input;
            } else {
                let shape_input = dims[d];
                stride_dim = stride_input;
                shape_dim = shape_input;
                // `shape_output` is 1 here, so `num_block` is 0; kept for symmetry.
                index_offset += num_block;
            }
        }
        // Snapshot the lane as (lane position, flat index, value) triples,
        // then sort by value.
        let mut elements = (0..shape_dim)
            .map(|d| {
                let flat_index = d * stride_dim + index_offset;
                let elem = data[flat_index];
                (d, flat_index, elem)
            })
            .collect::<Vec<_>>();
        elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending));
        // Apply the resulting permutation in place by following its cycles,
        // swapping through the stored flat indices.
        for idx in 0..elements.len() {
            if elements[idx].0 != idx {
                let mut current_idx = idx;
                loop {
                    let target_idx = elements[current_idx].0;
                    // Mark this slot as settled.
                    elements[current_idx].0 = current_idx;
                    if elements[target_idx].0 == target_idx {
                        break;
                    }
                    // Move the values only for plain sort or sort-with-indices;
                    // argsort (`Some(indices)` + `permute_both == false`) keeps
                    // the data untouched.
                    if indices.is_none() || permute_both {
                        data.swap(elements[current_idx].1, elements[target_idx].1);
                    }
                    // Keep the index buffer in lockstep with the value swaps.
                    if let Some(ref mut indices_data) = indices {
                        indices_data.swap(elements[current_idx].1, elements[target_idx].1);
                    }
                    current_idx = target_idx;
                }
            }
        }
    }
}
/// Computes row-major (C-contiguous) strides for the shape `dims`:
/// the last dimension has stride 1, each earlier dimension's stride is the
/// product of all later extents.
fn compute_strides(dims: &[usize]) -> Vec<usize> {
    let mut running = 1;
    // Walk the shape from the innermost dimension outwards, emitting the
    // running product *before* folding in the current extent.
    let mut strides: Vec<usize> = dims
        .iter()
        .rev()
        .map(|&extent| {
            let stride = running;
            running *= extent;
            stride
        })
        .collect();
    strides.reverse();
    strides
}
/// Builds, for every flat (row-major) position of a tensor with shape `dims`,
/// its coordinate along dimension `dim`, as backend int elements.
fn dim_indices<B: Backend>(dims: &[usize], dim: usize) -> Vec<IntElem<B>> {
    let to_elem = |i: usize| (i as i64).elem::<IntElem<B>>();
    if dims.len() == 1 {
        return (0..dims[dim]).map(to_elem).collect();
    }
    // Row-major layout: each coordinate value repeats once per trailing
    // position, and that whole pattern repeats once per leading position.
    let leading: usize = dims[..dim].iter().product();
    let trailing: usize = dims[dim + 1..].iter().product();
    let mut out = Vec::with_capacity(leading * dims[dim] * trailing);
    for _ in 0..leading {
        for coord in 0..dims[dim] {
            for _ in 0..trailing {
                out.push(to_elem(coord));
            }
        }
    }
    out
}
/// Compares two elements, flipping the operand order to invert the ordering
/// when a descending sort is requested.
fn compare<E: ElementComparison>(a: &E, b: &E, descending: bool) -> Ordering {
    match descending {
        true => b.cmp(a),
        false => a.cmp(b),
    }
}