use std::marker::PhantomData;
use burn::{
prelude::Backend,
tensor::{BasicOps, Tensor as BTensor, TensorKind},
};
use glowstick::{num::Unsigned, op::cat_dyn, Shape};
use crate::Tensor;
/// Concatenates a `Vec` of shape-tagged tensors along a type-level axis.
///
/// Invocation: `cat!(tensors, I => D)`, where
/// - `tensors` is an expression yielding the values to join,
/// - `I` is a type naming the concatenation axis (the `Dyn` implementation in this
///   module turns it into a runtime index through `Unsigned::USIZE`),
/// - `D` is a type describing the resulting dimension (NOTE(review): presumably its
///   size, e.g. `glowstick::Dyn<_>` — confirm against glowstick's `cat_dyn`).
///
/// The macro only packs `(tensors, PhantomData<I>, PhantomData<D>)` into a tuple and
/// hands it to `Cat::cat`. The choice of implementation, and therefore the output
/// shape, is left entirely to trait resolution.
#[macro_export]
macro_rules! cat {
    ($ts:expr, $i:ty => $d:ty) => {
        $crate::op::cat::Cat::cat((
            $ts,
            std::marker::PhantomData::<$i>,
            std::marker::PhantomData::<$d>,
        ))
    };
}
/// Type-level concatenation: consumes a bundle that describes the inputs and the
/// concatenation parameters, and produces the concatenated value.
///
/// In this module it is implemented for the tuple
/// `(Vec<tensor>, PhantomData<axis>, PhantomData<dim>)`, which is exactly what the
/// `cat!` macro builds. Each implementation picks its own `Out`, so the shape of the
/// result can be computed at compile time.
pub trait Cat {
    /// Type produced by the concatenation.
    type Out;

    /// Performs the concatenation, consuming `self`.
    ///
    /// Marked `#[must_use]` because `self` (the input tensors) is consumed. Dropping
    /// the result therefore throws away both the inputs and the work.
    #[must_use]
    fn cat(self) -> Self::Out;
}
/// Concatenates shape-tagged burn tensors along the type-level axis `I`, for the case
/// where the resulting dimension is described by a `glowstick::Dyn<D>` marker.
///
/// The output shape tag is whatever `cat_dyn::Compatible` computes for
/// `(S, I, Dyn<D>)`. At runtime the work is delegated to burn's `Tensor::cat`, using
/// `I::USIZE` as the axis index.
///
/// # Panics
///
/// NOTE(review): burn's `Tensor::cat` is expected to panic on an empty input vector,
/// and on shape mismatches outside the concatenation axis. Confirm this against the
/// burn version in use.
impl<B, Dt, S, I, D, const N: usize> Cat
    for (
        Vec<Tensor<BTensor<B, N, Dt>, S>>,
        PhantomData<I>,
        PhantomData<glowstick::Dyn<D>>,
    )
where
    B: Backend,
    Dt: TensorKind<B> + BasicOps<B>,
    S: Shape,
    I: Unsigned,
    (S, I, glowstick::Dyn<D>): cat_dyn::Compatible,
{
    type Out = Tensor<
        BTensor<B, N, Dt>,
        <(S, I, glowstick::Dyn<D>) as cat_dyn::Compatible>::Out,
    >;

    fn cat(self) -> Self::Out {
        let (wrapped, _axis_marker, _dim_marker) = self;

        // Strip the compile-time shape tags; burn only needs the raw tensors.
        let raw: Vec<BTensor<B, N, Dt>> = wrapped.into_iter().map(Tensor::into_inner).collect();

        // Lower the type-level axis to the runtime index burn expects.
        let axis = I::USIZE;

        Tensor(BTensor::cat(raw, axis), PhantomData)
    }
}