use std::marker::PhantomData;
use burn::prelude::Backend;
use burn::tensor::Tensor as BTensor;
use glowstick::{
num::{Unsigned, U0, U1},
op::narrow,
Shape,
};
use crate::Tensor;
/// Takes the mean of a tensor over one or more dimensions given as types.
///
/// Each dimension is a type (a type-level unsigned integer such as `U1`), not a
/// runtime value. The macro pairs the tensor with a `PhantomData` of that type and
/// calls [`MeanDim::mean_dim`](crate::op::mean_dim::MeanDim) on the pair.
///
/// When several dimensions are listed, they are reduced one at a time from left to
/// right: `mean_dim![t, A, B]` expands to `mean_dim![mean_dim![t, A], B]`.
///
/// A trailing comma is accepted. rustfmt may add one when it splits a long
/// `mean_dim![...]` call across lines.
#[macro_export]
macro_rules! mean_dim {
    [$t:expr, $i:ty $(,)?] => {{
        // Call the trait method by its full path so the caller's scope gets no
        // extra `use` and cannot be shadowed by a local `MeanDim` name.
        $crate::op::mean_dim::MeanDim::mean_dim(($t, ::std::marker::PhantomData::<$i>))
    }};
    [$t:expr, $i:ty, $($is:ty),+ $(,)?] => {{
        // Reduce the first dimension, then recurse on the remaining ones.
        $crate::mean_dim![$crate::mean_dim![$t, $i], $($is),+]
    }};
}
/// Mean reduction over a single dimension, chosen at the type level.
///
/// The `mean_dim!` macro calls this on a `(tensor, PhantomData<Dim>)` pair, where
/// `Dim` names the dimension to reduce. Implementors decide the concrete tensor
/// and shape types.
pub trait MeanDim {
    /// The tensor type produced by the reduction.
    type Out;

    /// Consumes `self` and returns the mean-reduced result.
    fn mean_dim(self) -> Self::Out;
}
/// Mean over dimension `Dim` for a shape-tagged burn tensor.
///
/// The runtime reduction is burn's `mean_dim` on the inner tensor, with the axis
/// read from the type-level `Dim::USIZE`. The tensor rank `N` stays the same in the
/// output type. The output's shape type is `S` narrowed at `Dim` (start `U0`,
/// length `U1`). That means the reduced dimension is assumed to be kept with
/// size 1, not removed.
impl<B, S, const N: usize, Dim> MeanDim for (Tensor<BTensor<B, N>, S>, PhantomData<Dim>)
where
    B: Backend,
    S: Shape,
    Dim: Unsigned,
    (S, Dim, U0, U1): narrow::Compatible,
{
    type Out = Tensor<BTensor<B, N>, <(S, Dim, U0, U1) as narrow::Compatible>::Out>;

    fn mean_dim(self) -> Self::Out {
        // The PhantomData half only carries `Dim` at the type level.
        let (tagged, _dim_marker) = self;
        let axis = <Dim as Unsigned>::USIZE;
        let reduced = tagged.into_inner().mean_dim(axis);
        Tensor(reduced, PhantomData)
    }
}