use core::cmp::Ordering;
use crate::{
Backend, DType, TensorData,
element::{ElementConversion, ElementOrdered},
tensor::{BasicOps, IntElem, IntTensor},
};
use alloc::{vec, vec::Vec};
use burn_std::{IntDType, bf16, f16, reader::try_read_sync};
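/// Error message raised when tensor data cannot be read back synchronously.
const READ_SYNC_ERROR_MSG: &str = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";

/// Dispatches the given sort function `$fn` on the element type matching `$data.dtype`.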
macro_rules! sort_dispatch_dtype {
($fn:ident, $data:ident, $($args:expr),*) => {
match $data.dtype {
DType::F64 => $fn::<B, f64>($data, $($args),*),
DType::F32 | DType::Flex32 => $fn::<B, f32>($data, $($args),*),
DType::F16 => $fn::<B, f16>($data, $($args),*),
DType::BF16 => $fn::<B, bf16>($data, $($args),*),
DType::I64 => $fn::<B, i64>($data, $($args),*),
DType::I32 => $fn::<B, i32>($data, $($args),*),
DType::I16 => $fn::<B, i16>($data, $($args),*),
DType::I8 => $fn::<B, i8>($data, $($args),*),
DType::U64 => $fn::<B, u64>($data, $($args),*),
DType::U32 => $fn::<B, u32>($data, $($args),*),
DType::U16 => $fn::<B, u16>($data, $($args),*),
DType::U8 => $fn::<B, u8>($data, $($args),*),
DType::Bool(_) | DType::QFloat(_) => unimplemented!("not supported for sorting operations"),
}
};
}
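/// Sort the elements of the input `tensor` by value along the given dimension.
///
/// The sort is unstable: the order of equal elements is not guaranteed to be preserved.
///
/// # Panics
///
/// Panics if the backend cannot read the tensor data back synchronously.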
pub fn sort<B: Backend, K: BasicOps<B>>(
tensor: K::Primitive,
dim: usize,
descending: bool,
) -> K::Primitive {
let device = K::device(&tensor);
let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
let data = try_read_sync(K::into_data_async(tensor))
.expect(msg)
.expect(msg);
let dtype = data.dtype;
let data = sort_dispatch_dtype!(sort_data, data, dim, descending);
K::from_data(data, &device, dtype)
}
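/// Sort the elements of `data` by value along the given dimension.
///
/// A 1-D tensor is sorted directly; higher ranks are handled lane by lane in `sort_slice`.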
pub fn sort_data<B: Backend, E: ElementOrdered>(
mut data: TensorData,
dim: usize,
descending: bool,
) -> TensorData {
let dims = data.shape.clone();
let data_slice = data.as_mut_slice().unwrap();
if dims.len() == 1 {
data_slice.sort_unstable_by(|&a, &b| compare(&a, &b, descending));
} else {
sort_slice::<B, E>(data_slice, &dims, dim, None, false, descending);
}
data
}
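/// Sort the elements of the input `tensor` by value along the given dimension,
/// returning both the sorted values and the indices they held in the original
/// tensor. The returned index tensor uses `indices_dtype` as its element type.
///
/// The sort is unstable: the order of equal elements is not guaranteed to be preserved.
///
/// # Panics
///
/// Panics if the backend cannot read the tensor data back synchronously.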
pub fn sort_with_indices<B: Backend, K: BasicOps<B>>(
tensor: K::Primitive,
dim: usize,
descending: bool,
indices_dtype: IntDType,
) -> (K::Primitive, IntTensor<B>) {
let device = K::device(&tensor);
let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
let data = try_read_sync(K::into_data_async(tensor))
.expect(msg)
.expect(msg);
let dtype = data.dtype;
let (values, indices) = sort_dispatch_dtype!(sort_data_with_indices, data, dim, descending);
(
K::from_data(values, &device, dtype),
B::int_from_data(indices.convert_dtype(indices_dtype.into()), &device),
)
}
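/// Sort `data` by value along the given dimension, also returning the indices
/// the sorted elements held in the original tensor.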
fn sort_data_with_indices<B: Backend, E: ElementOrdered>(
mut data: TensorData,
dim: usize,
descending: bool,
) -> (TensorData, TensorData) {
let dims = data.shape.clone();
let mut indices_data = dim_indices::<B>(&dims, dim);
let data_slice = data.as_mut_slice().unwrap();
if dims.len() == 1 {
indices_data.sort_unstable_by(|&a, &b| {
compare(
&data_slice[a.elem::<i64>() as usize],
&data_slice[b.elem::<i64>() as usize],
descending,
)
});
        let mut indices = indices_data
            .iter()
            .map(|i| i.elem::<i64>() as usize)
            .collect::<Vec<_>>();
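        // Apply the sorted permutation to the values in place by walking each
        // permutation cycle and swapping as we go.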
for idx in 0..indices.len() {
if indices[idx] != idx {
let mut current_idx = idx;
loop {
let target_idx = indices[current_idx];
indices[current_idx] = current_idx;
if indices[target_idx] == target_idx {
break;
}
data_slice.swap(current_idx, target_idx);
current_idx = target_idx;
}
}
}
} else {
sort_slice::<B, E>(
data_slice,
&dims,
dim,
Some(&mut indices_data),
true,
descending,
);
}
(data, TensorData::new(indices_data, dims))
}
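/// Return the indices that sort the elements of the input `tensor` by value
/// along the given dimension. The returned index tensor uses `out_dtype` as
/// its element type.
///
/// The sort is unstable: the order of equal elements is not guaranteed to be preserved.
///
/// # Panics
///
/// Panics if the backend cannot read the tensor data back synchronously.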
pub fn argsort<B: Backend, K: BasicOps<B>>(
tensor: K::Primitive,
dim: usize,
descending: bool,
out_dtype: IntDType,
) -> IntTensor<B> {
let device = K::device(&tensor);
let msg = "Failed to synchronously read tensor data. This operation is not supported until this backend has a GPU sorting implementation.";
let data = try_read_sync(K::into_data_async(tensor))
.expect(msg)
.expect(msg);
let data = sort_dispatch_dtype!(argsort_data, data, dim, descending);
B::int_from_data(data.convert_dtype(out_dtype.into()), &device)
}
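/// Compute the indices that sort `data` by value along the given dimension,
/// leaving the data itself unchanged.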
fn argsort_data<B: Backend, E: ElementOrdered>(
mut data: TensorData,
dim: usize,
descending: bool,
) -> TensorData {
let dims = data.shape.clone();
let mut indices_data = dim_indices::<B>(&dims, dim);
if dims.len() == 1 {
let slice = data.as_slice::<E>().unwrap();
indices_data.sort_unstable_by(|&a, &b| {
compare(
&slice[a.elem::<i64>() as usize],
&slice[b.elem::<i64>() as usize],
descending,
)
});
} else {
sort_slice::<B, E>(
data.as_mut_slice().unwrap(),
&dims,
dim,
Some(&mut indices_data),
false,
descending,
);
}
TensorData::new(indices_data, dims)
}
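/// Sort every 1-D lane of an n-D tensor along dimension `dim`, in place.
///
/// When `indices` is provided, the index buffer is permuted alongside the values;
/// `permute_both` controls whether the values are moved as well (`sort_with_indices`)
/// or only the indices are (`argsort`).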
fn sort_slice<B: Backend, E: ElementOrdered>(
data: &mut [E],
dims: &[usize],
dim: usize,
mut indices: Option<&mut [IntElem<B>]>,
permute_both: bool,
descending: bool,
) {
let ndims = dims.len();
let strides = compute_strides(dims);
let mut sort_dims = dims.to_vec();
sort_dims[dim] = 1;
let strides_out = compute_strides(&sort_dims);
let num_sorts: usize = dims
.iter()
.enumerate()
.filter(|&(i, _)| i != dim)
.map(|(_, d)| d)
.product();
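    // Sort each lane independently.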
for id in 0..num_sorts {
let mut index_offset = 0;
let mut stride_dim = 0;
let mut shape_dim = 0;
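        // Compute the flat base offset of this lane from its coordinates in the
        // collapsed output shape (the sorted dimension is reduced to size 1).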
for d in 0..ndims {
let stride_input = strides[d];
let stride_output = strides_out[d];
let shape_output = sort_dims[d];
let num_block = id / stride_output % shape_output;
if d != dim {
index_offset += num_block * stride_input;
} else {
let shape_input = dims[d];
stride_dim = stride_input;
shape_dim = shape_input;
index_offset += num_block;
}
}
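        // Gather (lane position, flat index, value) triples for this lane.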
let mut elements = (0..shape_dim)
.map(|d| {
let flat_index = d * stride_dim + index_offset;
let elem = data[flat_index];
(d, flat_index, elem)
})
.collect::<Vec<_>>();
elements.sort_unstable_by(|&(_, _, a), &(_, _, b)| compare(&a, &b, descending));
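        // Apply the sorted permutation in place by walking each cycle, swapping
        // values and/or indices at their flat positions.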
for idx in 0..elements.len() {
if elements[idx].0 != idx {
let mut current_idx = idx;
loop {
let target_idx = elements[current_idx].0;
elements[current_idx].0 = current_idx;
if elements[target_idx].0 == target_idx {
break;
}
if indices.is_none() || permute_both {
data.swap(elements[current_idx].1, elements[target_idx].1);
}
if let Some(ref mut indices_data) = indices {
indices_data.swap(elements[current_idx].1, elements[target_idx].1);
}
current_idx = target_idx;
}
}
}
}
}
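/// Compute the row-major (contiguous) strides for the given shape,
/// e.g. dims `[2, 3, 4]` yield strides `[12, 4, 1]`.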
fn compute_strides(dims: &[usize]) -> Vec<usize> {
let mut strides = vec![0; dims.len()];
let mut current = 1;
dims.iter().enumerate().rev().for_each(|(index, val)| {
strides[index] = current;
current *= val;
});
strides
}
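/// Build the initial index buffer along dimension `dim`, broadcast over all other
/// dimensions, e.g. dims `[2, 3]` with `dim = 1` yield `[0, 1, 2, 0, 1, 2]`.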
fn dim_indices<B: Backend>(dims: &[usize], dim: usize) -> Vec<IntElem<B>> {
if dims.len() == 1 {
(0..dims[dim])
.map(|i| (i as i64).elem::<IntElem<B>>())
.collect::<Vec<_>>()
} else {
let numel_leading_dims: usize = dims[..dim].iter().product();
let numel_trailing_dims: usize = dims[dim + 1..].iter().product();
(0..dims[dim])
.map(|i| [(i as i64).elem::<IntElem<B>>()].repeat(numel_trailing_dims))
.collect::<Vec<_>>()
.concat()
.repeat(numel_leading_dims)
}
}
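/// Compare two elements, inverting the ordering when `descending` is set.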
fn compare<E: ElementOrdered>(a: &E, b: &E, descending: bool) -> Ordering {
if descending { b.cmp(a) } else { a.cmp(b) }
}
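
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity checks for the backend-independent helpers. These are
    // illustrative sketches; they assume `i64: ElementOrdered`, which the dtype
    // dispatch above already relies on.

    #[test]
    fn compute_strides_is_row_major() {
        assert_eq!(compute_strides(&[2, 3, 4]), [12, 4, 1]);
        assert_eq!(compute_strides(&[5]), [1]);
    }

    #[test]
    fn compare_inverts_ordering_when_descending() {
        assert_eq!(compare(&1i64, &2i64, false), Ordering::Less);
        assert_eq!(compare(&1i64, &2i64, true), Ordering::Greater);
        assert_eq!(compare(&1i64, &1i64, true), Ordering::Equal);
    }
}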