use std::marker::PhantomData;
use burn::tensor::{BasicOps, BroadcastArgs, Tensor as BTensor};
use burn::{prelude::Backend, tensor::TensorKind};
use glowstick::{op::broadcast, Shape};
use crate::Tensor;
/// Expands a shape-checked tensor to a broadcast-compatible target shape.
///
/// Two invocation forms are accepted:
///
/// * `expand!(tensor, [dims...])` — the bracketed dimension tokens are handed
///   to `reshape_tys!` to build the target `glowstick::TensorShape` type, and
///   to `reshape_val!` to build the runtime shape value (converted with
///   `into_array()`) that is passed to `Expand::expand`.
/// * `expand!(tensor, other)` — the runtime shape is read from `other` via
///   `other.inner().shape().dims()` and passed to `Expand::expand` on the
///   `(tensor, other)` pair. `other` is expected to be a reference to a
///   tensor, matching the `Expand` impl for `(Tensor, &Tensor)`.
///
/// Every operand expression is evaluated exactly once.
#[macro_export]
macro_rules! expand {
    ($t:expr,[$($ds:tt)+]) => {{
        type S = glowstick::TensorShape<$crate::reshape_tys!($($ds)+)>;
        use $crate::op::expand::Expand;
        (
            $t,
            std::marker::PhantomData::<S>,
        )
            .expand($crate::reshape_val!($($ds)+).into_array())
    }};
    ($t1:expr,$t2:expr) => {{
        use $crate::op::expand::Expand;
        // Bind both operands once. Pasting `$t2` a second time to read its
        // dims would evaluate the caller's expression twice. A `match` is used
        // instead of `let` so temporaries created by the operand expressions
        // stay alive for the whole expansion.
        match ($t1, $t2) {
            (source, like) => {
                let dims = like.inner().shape().dims();
                (source, like).expand(dims)
            }
        }
    }};
}
/// Expansion of a shape-checked tensor, driven by a runtime shape argument.
///
/// * `A` — the runtime shape argument type; it must satisfy burn's
///   `BroadcastArgs<N, M>`.
/// * `N` — rank of the tensor being expanded (as used by the impls in this
///   module).
/// * `M` — rank of the expanded result (as used by the impls in this module).
pub trait Expand<A: BroadcastArgs<N, M>, const N: usize, const M: usize> {
    /// The tensor type produced by the expansion.
    type Out;

    /// Consumes `self` and expands it to `shape`.
    fn expand(self, shape: A) -> Self::Out;
}
/// Expands an owned tensor of rank `N` against a borrowed tensor of rank `M`.
///
/// The borrowed tensor contributes only its type (`S2` and `M`); its value is
/// not read by `expand`. The impl is available only when
/// `(S2, S1): broadcast::Compatible`, and the resulting shape type is that
/// trait's `Out`. The result keeps the source tensor's kind `D1`.
impl<B, S1, S2, D1, D2, const N: usize, const M: usize> Expand<[usize; M], N, M>
    for (
        Tensor<BTensor<B, N, D1>, S1>,
        &Tensor<BTensor<B, M, D2>, S2>,
    )
where
    B: Backend,
    S1: Shape,
    S2: Shape,
    D1: TensorKind<B> + BasicOps<B>,
    D2: TensorKind<B>,
    (S2, S1): broadcast::Compatible,
{
    type Out = Tensor<BTensor<B, M, D1>, <(S2, S1) as broadcast::Compatible>::Out>;

    /// Unwraps the source tensor, expands it to `shape` with burn's `expand`,
    /// and re-wraps the result with the broadcast shape type.
    fn expand(self, shape: [usize; M]) -> Self::Out {
        // Only the first element is consumed; the reference is a type witness.
        let (source, _like) = self;
        let expanded: BTensor<B, M, D1> = source.into_inner().expand(shape);
        Tensor(expanded, PhantomData)
    }
}
/// Expands an owned tensor of rank `N` to a target shape type `S2` supplied
/// through a `PhantomData` marker.
///
/// The runtime shape is given as `[i32; M]`. The tensor's kind parameter is
/// left at burn's default for both the input and the output. The impl is
/// available only when `(S2, S1): broadcast::Compatible`, and the resulting
/// shape type is that trait's `Out`.
impl<B, S1, S2, const N: usize, const M: usize> Expand<[i32; M], N, M>
    for (Tensor<BTensor<B, N>, S1>, PhantomData<S2>)
where
    B: Backend,
    S1: Shape,
    S2: Shape,
    (S2, S1): broadcast::Compatible,
{
    type Out = Tensor<BTensor<B, M>, <(S2, S1) as broadcast::Compatible>::Out>;

    /// Unwraps the source tensor, expands it to `shape` with burn's `expand`,
    /// and re-wraps the result with the broadcast shape type.
    fn expand(self, shape: [i32; M]) -> Self::Out {
        // The marker carries no data; it exists only to name `S2`.
        let (source, _target) = self;
        let expanded: BTensor<B, M> = source.into_inner().expand(shape);
        Tensor(expanded, PhantomData)
    }
}