use std::marker::PhantomData;
use burn::tensor::{
BasicOps, Int, Numeric,
Tensor as BTensor,
};
use burn::{prelude::Backend, tensor::TensorKind};
use glowstick::cmp::Greater;
use glowstick::{
num::Unsigned, Shape,
};
use crate::Tensor;
/// Sorts a value in descending order along a dimension named by a type.
///
/// `sort_descending_with_indices![value, Dim]` pairs `value` with a
/// `PhantomData::<Dim>` marker and calls
/// `SortDescendingWithIndices::sort_descending_with_indices` on that pair, so
/// the result type is whatever the matching trait impl declares as `Out`.
///
/// With more than one dimension type, the macro is re-invoked on the result of
/// the previous invocation for each remaining dimension, left to right.
///
/// NOTE(review): the chained form only compiles when the result of the
/// previous step, paired with `PhantomData`, itself implements the trait. The
/// tensor impl in this module returns a tuple of two tensors — confirm that a
/// suitable impl exists before relying on the multi-dimension form.
#[macro_export]
macro_rules! sort_descending_with_indices {
    // Base case: exactly one dimension type.
    ($tensor:expr, $dim:ty) => {{
        use $crate::op::sort_descending_with_indices::SortDescendingWithIndices;
        ($tensor, std::marker::PhantomData::<$dim>).sort_descending_with_indices()
    }};
    // Recursive case: handle the first dimension, then feed that result back
    // in together with the remaining dimension types.
    ($tensor:expr, $dim:ty, $($rest:ty),+) => {{
        $crate::sort_descending_with_indices!(
            $crate::sort_descending_with_indices!($tensor, $dim),
            $($rest),+
        )
    }};
}
/// Consuming "sort descending and also report indices" operation.
///
/// Implementors take `self` by value and produce [`Self::Out`]; the trait
/// itself places no constraints on what that output looks like.
pub trait SortDescendingWithIndices {
    /// Result type produced by the operation.
    type Out;

    /// Consumes `self` and returns the sorted result described by [`Self::Out`].
    fn sort_descending_with_indices(self) -> Self::Out;
}
/// Descending sort with indices for a shape-typed tensor paired with a
/// type-level dimension marker.
///
/// The wrapped backend tensor is sorted along the axis `Dim::USIZE`. Both
/// outputs carry the same shape type `S` as the input: the first keeps the
/// input's element kind `D`, the second is an `Int` tensor of indices.
impl<B, D, S, const N: usize, Dim> SortDescendingWithIndices
    for (Tensor<BTensor<B, N, D>, S>, PhantomData<Dim>)
where
    B: Backend,
    D: TensorKind<B> + BasicOps<B> + Numeric<B>,
    S: Shape,
    Dim: Unsigned,
    // Compile-time guard on the (rank, dimension) pair; presumably this
    // rejects a `Dim` that is not strictly below the shape's rank.
    (<S as Shape>::Rank, Dim): Greater,
{
    type Out = (Tensor<BTensor<B, N, D>, S>, Tensor<BTensor<B, N, Int>, S>);

    fn sort_descending_with_indices(self) -> Self::Out {
        // The marker only carries the dimension at the type level.
        let (tensor, _dim_marker) = self;
        let axis = <Dim as Unsigned>::USIZE;

        // Delegate to the backend tensor's own sort, then re-wrap both
        // results with the unchanged shape type.
        let (sorted, indices) = tensor.into_inner().sort_descending_with_indices(axis);
        (Tensor(sorted, PhantomData), Tensor(indices, PhantomData))
    }
}