use std::marker::PhantomData;
use burn::prelude::Backend;
use burn::tensor::Tensor as BTensor;
use glowstick::{op::matmul, Shape, TensorShape};
use crate::Tensor;
/// Multiplies two or more operands via the `Matmul` trait.
///
/// `matmul!(a, b)` expands to `(a, b).matmul()`, with
/// `$crate::op::matmul::Matmul` brought into scope for the call.
/// With more operands the product is folded from the left, so
/// `matmul!(a, b, c)` is `matmul!(matmul!(a, b), c)`.
/// A trailing comma is accepted, e.g. `matmul!(a, b,)`.
#[macro_export]
macro_rules! matmul {
    ($t1:expr, $t2:expr $(,)?) => {{
        use $crate::op::matmul::Matmul;
        ($t1, $t2).matmul()
    }};
    ($t1:expr, $t2:expr, $($rest:expr),+ $(,)?) => {{
        // Multiply the first pair, then recurse with the result as the new left operand.
        $crate::matmul![$crate::matmul!($t1, $t2), $($rest),+]
    }};
}
/// Matrix multiplication over a bundle of operands held by `Self`.
///
/// In this module the trait is implemented on a pair of tensors, which is the
/// form the `matmul!` macro builds before calling `.matmul()`.
pub trait Matmul {
    /// Type produced by the multiplication.
    type Out;

    /// Consumes `self` and returns the result of multiplying its operands.
    fn matmul(self) -> Self::Out;
}
/// Shape-tagged matrix multiplication of two `burn` tensors of the same rank `N`.
///
/// Both operand shapes must implement `matmul::Operand`, and the pair `(S1, S2)`
/// must implement `matmul::Compatible`; because these are trait bounds, an
/// incompatible pair of shapes is rejected by the compiler. The result is tagged
/// with `TensorShape<<(S1, S2) as matmul::Compatible>::Out>`.
impl<B, S1, S2, const N: usize> Matmul
    for (Tensor<BTensor<B, N>, S1>, Tensor<BTensor<B, N>, S2>)
where
    B: Backend,
    S1: Shape + matmul::Operand,
    S2: Shape + matmul::Operand,
    (S1, S2): matmul::Compatible,
{
    type Out = Tensor<BTensor<B, N>, TensorShape<<(S1, S2) as matmul::Compatible>::Out>>;

    fn matmul(self) -> Self::Out {
        let (lhs, rhs) = self;
        // Strip the shape tags, multiply the raw burn tensors, then re-tag the
        // product with the output shape computed at the type level.
        let product = lhs.into_inner().matmul(rhs.into_inner());
        Tensor(product, PhantomData)
    }
}