use std::marker::PhantomData;
use burn::prelude::Backend;
use burn::tensor::{Bool, Tensor as BTensor};
use glowstick::Shape;
use crate::Tensor;
/// Builds a shape-typed boolean mask through burn's `Tensor::tril_mask`.
///
/// Usage: `tril_mask!(offset, device, BackendTy, [dims...])`
///
/// - `$o`: diagonal offset. It ends up in the `i64` slot of the [`TrilMask`] impl, so it
///   must be (or infer to) an `i64`.
/// - `$d`: a reference to the backend device (`&B::Device`).
/// - `$b`: the burn backend type.
/// - `[$($ds)+]`: the dimension list. It is expanded twice: through `reshape_tys!` into
///   the type-level shape carried by the result, and through `reshape_val!` into the
///   runtime `[usize; M]` handed to the trait method.
///
/// External paths are written absolutely (`::std`, `::glowstick`). The expansion is
/// resolved at the call site, and absolute paths stop a same-named local module or item
/// from shadowing them.
#[macro_export]
macro_rules! tril_mask {
    ($o:expr, $d:expr, $b:ty, [$($ds:tt)+]) => {{
        use $crate::op::tril_mask::TrilMask;
        (
            ::std::marker::PhantomData::<$b>,
            ::std::marker::PhantomData::<::glowstick::TensorShape<$crate::reshape_tys!($($ds)+)>>,
            $o,
            $d,
        )
            .tril_mask($crate::reshape_val!($($ds)+).into_array())
    }};
}
/// Builds a `tril`-style mask from a runtime shape with `M` dimensions.
///
/// This trait is what the `tril_mask!` macro calls. The macro assembles a marker tuple
/// (backend marker, shape marker, offset, device) and invokes [`TrilMask::tril_mask`]
/// on it.
pub trait TrilMask<const M: usize> {
    /// The value produced by [`TrilMask::tril_mask`].
    type Out;

    /// Consumes `self` and builds the mask for the given runtime `shape`.
    ///
    /// The returned value is the method's only product, so ignoring it is always a
    /// mistake. `#[must_use]` makes the compiler flag that.
    #[must_use]
    fn tril_mask(self, shape: [usize; M]) -> Self::Out;
}
/// Implementation for the marker tuple `(backend, shape, offset, device)` built by
/// `tril_mask!`.
///
/// `B` (backend) and `S` (type-level shape) travel as zero-sized `PhantomData` markers.
/// Only the offset and device are consumed at runtime. The burn mask is wrapped in the
/// crate's shape-tagged `Tensor` so `S` is recorded in the result type.
impl<B, S, const M: usize> TrilMask<M>
    for (PhantomData<B>, PhantomData<S>, i64, &<B as Backend>::Device)
where
    B: Backend,
    S: Shape,
{
    type Out = Tensor<BTensor<B, M, Bool>, S>;

    fn tril_mask(self, shape: [usize; M]) -> Self::Out {
        // The two markers carry types, not data, so they are discarded here.
        let (_, _, offset, device) = self;
        let mask = BTensor::<B, M, Bool>::tril_mask(shape, offset, device);
        Tensor(mask, PhantomData)
    }
}