use std::ops::Mul;
/// A binary operation between `&self` and a borrowed `OtherGrad` whose result
/// type is named by the implementor.
///
/// The trait only fixes the signature. The implementations generated by
/// `impl_forward_mul!` in this file forward to [`std::ops::Mul`] on references.
///
/// # Type parameters
///
/// * `SelfInput` - appears only in the trait header; no item declared here
///   refers to it. NOTE(review): presumably a marker tying the implementor to
///   the input it was derived from — confirm against callers.
/// * `OtherGrad` - type of the right-hand operand, taken by shared reference.
pub trait ForwardMul<SelfInput, OtherGrad> {
    /// Type of the value returned by [`ForwardMul::forward_mul`].
    type ResultGrad;

    /// Combines `self` with `other` and returns the result by value.
    ///
    /// Both operands are borrowed, so neither is consumed by the call.
    fn forward_mul(&self, other: &OtherGrad) -> Self::ResultGrad;
}
/// Generates [`ForwardMul`] implementations for each type in a comma-separated list.
///
/// For every listed type `T`, the expansion implements `ForwardMul<SelfInput, Rhs>`
/// for any `SelfInput` and for every `Rhs` such that `&T * &Rhs` is defined through
/// [`Mul`] for arbitrary borrow lifetimes. The associated `ResultGrad` is the
/// `Output` of that multiplication.
///
/// The list must not end with a trailing comma.
macro_rules! impl_forward_mul {
    ($($scalar:ty),*) => {
        $(
            impl<SelfInput, Rhs, Product> ForwardMul<SelfInput, Rhs> for $scalar
            where
                // Higher-ranked over both lifetimes so that the borrows received by
                // `forward_mul` satisfy the bound whatever their lifetimes are.
                for<'lhs, 'rhs> &'lhs $scalar: Mul<&'rhs Rhs, Output = Product>,
            {
                type ResultGrad = Product;

                fn forward_mul(&self, other: &Rhs) -> Self::ResultGrad {
                    // Reference-by-reference multiplication; equivalent to `self * other`.
                    Mul::mul(self, other)
                }
            }
        )*
    };
}
// Primitive numeric types that receive the generated `ForwardMul` implementations.
// NOTE(review): `i128` and `u128` are not listed — confirm whether that is intentional.
impl_forward_mul!(
    // floating point
    f32, f64,
    // signed integers
    i8, i16, i32, i64, isize,
    // unsigned integers
    u8, u16, u32, u64, usize
);