use crate::prelude_dev::*;
use core::ops::{Mul, Rem};
use num::{One, Zero};
/// Matrix multiplication of two tensor views, returning a freshly created
/// tensor.
///
/// This is the unwrapping counterpart of `matmul_f`: it forwards `a` and `b`
/// to `op_refa_refb_matmul` with a scaling factor of `TC::one()` and then
/// calls `rstsr_unwrap` on the returned `Result`.
///
/// # Parameters
///
/// - `a`: left operand (anything implementing `TensorViewAPI`).
/// - `b`: right operand, living on the same backend type `B`.
pub fn matmul<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
) -> Tensor<TC, B, DC>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Zero + One,
    B: DeviceCreationAnyAPI<TC>,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    // A plain product uses a unit scaling factor.
    let unit_scale = TC::one();
    let product = op_refa_refb_matmul(a, b, unit_scale);
    product.rstsr_unwrap()
}
/// Matrix multiplication written into an existing output tensor `c`.
///
/// Unwrapping counterpart of `matmul_from_f`: all arguments are handed to
/// `op_mutc_refa_refb_matmul` unchanged and the returned `Result` is
/// unwrapped with `rstsr_unwrap`.
///
/// # Parameters
///
/// - `c`: mutable output view, updated in place.
/// - `a`, `b`: left and right operands.
/// - `alpha`, `beta`: scalars forwarded to the backend's matmul. The unit
///   test `test_matmul_from` in this file is consistent with
///   `c = alpha * (a @ b) + beta * c`; the arithmetic itself is defined by
///   the backend's `DeviceMatMulAPI` implementation.
pub fn matmul_from<TA, TB, TC, DA, DB, DC, B>(
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
    beta: TC,
) where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let outcome = op_mutc_refa_refb_matmul(c, a, b, alpha, beta);
    outcome.rstsr_unwrap()
}
/// Core in-place matmul driver: validates devices, then dispatches to the
/// backend.
///
/// Views are taken of all three tensors, the output's device is required to
/// be the same device as both inputs, and the raw storages plus layouts are
/// passed to `DeviceMatMulAPI::matmul` on the output's device together with
/// `alpha` and `beta`.
///
/// # Errors
///
/// - `DeviceMismatch` when `c`'s device is not `same_device` as `a`'s or
///   `b`'s device.
/// - Whatever error the backend's `matmul` returns.
pub fn op_mutc_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
    mut c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
    beta: TC,
) -> Result<()>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let a = a.view();
    let b = b.view();
    let mut c = c.view_mut();

    // The output must share its device with both operands.
    rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
    rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;

    // Read-only operands: layout and raw storage.
    let (layout_a, storage_a) = (a.layout(), a.raw());
    let (layout_b, storage_b) = (b.layout(), b.raw());

    // Device and layout of `c` are cloned so that no shared borrow of `c`
    // is alive while its raw storage is borrowed mutably below.
    let device = c.device().clone();
    let layout_c = c.layout().clone();
    let storage_c = c.raw_mut();

    device.matmul(storage_c, &layout_c, storage_a, layout_a, storage_b, layout_b, alpha, beta)
}
/// Fallible matmul that allocates its own output tensor.
///
/// Steps performed:
/// 1. take views of `a` and `b` and require that they are on the same
///    device;
/// 2. derive the output layout with
///    `LayoutMatMulConfig::<DA, DB>::layout_matmul`, using the device's
///    default order;
/// 3. create the output with `empty` on `a`'s device and convert its
///    dimension type with `into_dim_f`;
/// 4. fill it through `op_mutc_refa_refb_matmul`, passing `alpha` and
///    `TC::zero()` as `beta`.
///
/// # Errors
///
/// - `DeviceMismatch` when `a` and `b` are not on the same device.
/// - Any error from the layout computation, the dimension conversion, or
///   the in-place matmul.
pub fn op_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
) -> Result<Tensor<TC, B, DC>>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TC: Zero,
    B: DeviceCreationAnyAPI<TC>,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let (a, b) = (a.view(), b.view());
    rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;

    // Output layout follows the device's default ordering.
    let default_order = a.device().default_order();
    let cfg = LayoutMatMulConfig::<DA, DB>::layout_matmul(a.layout(), b.layout(), default_order)?;
    let lc = cfg.lc;

    // SAFETY: NOTE(review) — `empty` is presumably `unsafe` because the
    // returned buffer is uninitialized. The only use of `c` before it is
    // returned is the matmul below with `beta = TC::zero()`, which is
    // expected to overwrite every element without reading the old contents;
    // confirm that each backend's `matmul` honours this.
    let mut c: Tensor<TC, B, _> = unsafe { empty((lc, a.device())) }.into_dim_f()?;
    op_mutc_refa_refb_matmul(&mut c, &a, &b, alpha, TC::zero())?;
    Ok(c)
}
/// Fallible matmul of `a` and `b` stored into the caller-provided output
/// `c`.
///
/// Forwards to `op_mutc_refa_refb_matmul` with `alpha = TC::one()` and
/// `beta = TC::zero()`, and returns its `Result` untouched.
///
/// # Errors
///
/// Propagates any error reported by `op_mutc_refa_refb_matmul`.
pub fn matmul_with_output_f<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
) -> Result<()>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TC: Zero + One,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let (alpha, beta) = (TC::one(), TC::zero());
    op_mutc_refa_refb_matmul(c, a, b, alpha, beta)
}
/// Matmul of `a` and `b` stored into the caller-provided output `c`.
///
/// Unwrapping counterpart of `matmul_with_output_f`: forwards to
/// `op_mutc_refa_refb_matmul` with `alpha = TC::one()` and
/// `beta = TC::zero()`, then unwraps the `Result` with `rstsr_unwrap`.
pub fn matmul_with_output<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
) where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TC: Zero + One,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let (alpha, beta) = (TC::one(), TC::zero());
    let outcome = op_mutc_refa_refb_matmul(c, a, b, alpha, beta);
    outcome.rstsr_unwrap()
}
/// Fallible matmul written into an existing output tensor `c`.
///
/// A direct alias for `op_mutc_refa_refb_matmul`: every argument is passed
/// through unchanged and the resulting `Result` is returned as is.
///
/// # Parameters
///
/// - `c`: mutable output view, updated in place.
/// - `a`, `b`: left and right operands.
/// - `alpha`, `beta`: scalars forwarded to the backend's matmul.
///
/// # Errors
///
/// Propagates any error reported by `op_mutc_refa_refb_matmul`.
pub fn matmul_from_f<TA, TB, TC, DA, DB, DC, B>(
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
    beta: TC,
) -> Result<()>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let outcome = op_mutc_refa_refb_matmul(c, a, b, alpha, beta);
    outcome
}
/// Fallible matrix multiplication of two tensor views, returning a freshly
/// created tensor.
///
/// Forwards `a` and `b` to `op_refa_refb_matmul` with a scaling factor of
/// `TC::one()` and returns its `Result` untouched.
///
/// # Errors
///
/// Propagates any error reported by `op_refa_refb_matmul`.
pub fn matmul_f<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
) -> Result<Tensor<TC, B, DC>>
where
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Zero + One,
    B: DeviceCreationAnyAPI<TC>,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    // A plain product uses a unit scaling factor.
    let unit_scale = TC::one();
    op_refa_refb_matmul(a, b, unit_scale)
}
// `%` operator as matrix multiplication.
//
// `duplicate_item` stamps out this impl for the four owned/borrowed
// combinations of the left (`TrA`) and right (`TrB`) operands, so `a % b`,
// `&a % b`, `a % &b` and `&a % &b` are all accepted.
#[duplicate_item(
TrA TrB ;
[ TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
[&TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
[ TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
[&TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
)]
impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<TrB> for TrA
where
    RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Zero + One,
    B: DeviceCreationAnyAPI<TC>,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    type Output = Tensor<TC, B, DC>;

    // Forwards to `op_refa_refb_matmul` with a unit scaling factor and
    // unwraps the result with `rstsr_unwrap`.
    fn rem(self, rhs: TrB) -> Self::Output {
        let unit_scale = TC::one();
        let product = op_refa_refb_matmul(self, rhs, unit_scale);
        product.rstsr_unwrap()
    }
}
impl<R, T, B, D> TensorAny<R, T, B, D>
where
R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
B: DeviceAPI<T>,
D: DimAPI,
{
pub fn matmul_f<TB, TC, DB, DC>(
&self,
rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
) -> Result<Tensor<TC, B, DC>>
where
DB: DimAPI,
DC: DimAPI,
T: Mul<TB, Output = TC>,
TC: Zero + One,
B: DeviceCreationAnyAPI<TC>,
LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
{
op_refa_refb_matmul(self.view(), rhs, TC::one())
}
pub fn matmul<TB, TC, DB, DC>(&self, rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>) -> Tensor<TC, B, DC>
where
DB: DimAPI,
DC: DimAPI,
T: Mul<TB, Output = TC>,
TC: Zero + One,
B: DeviceCreationAnyAPI<TC>,
LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
{
op_refa_refb_matmul(self.view(), rhs, TC::one()).rstsr_unwrap()
}
pub fn matmul_with_output_f<TB, TC, DB, DC>(
&self,
rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
) -> Result<()>
where
DB: DimAPI,
DC: DimAPI,
TC: Zero + One,
B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
{
op_mutc_refa_refb_matmul(c, self.view(), rhs, TC::one(), TC::zero())
}
pub fn matmul_with_output<TB, TC, DB, DC>(
&self,
rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
) where
DB: DimAPI,
DC: DimAPI,
TC: Zero + One,
B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
{
op_mutc_refa_refb_matmul(c, self.view(), rhs, TC::one(), TC::zero()).rstsr_unwrap()
}
pub fn matmul_from_f<TA, TB, DA, DB>(
&mut self,
a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
alpha: T,
beta: T,
) -> Result<()>
where
R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
DA: DimAPI,
DB: DimAPI,
B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
{
op_mutc_refa_refb_matmul(self.view_mut(), a, b, alpha, beta)
}
pub fn matmul_from<TA, TB, DA, DB>(
&mut self,
a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
alpha: T,
beta: T,
) where
R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
DA: DimAPI,
DB: DimAPI,
B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
{
op_mutc_refa_refb_matmul(self.view_mut(), a, b, alpha, beta).rstsr_unwrap()
}
}
#[cfg(test)]
mod test {
    use super::*;

    // Smoke test: runs the in-place driver and the `%` operator on several
    // operand shapes and prints the results (no numeric assertions).
    #[test]
    fn test_matmul() {
        // 2-D x 2-D through the in-place driver, then through `%`.
        let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
        let b = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
        let mut c: Tensor<f64> = zeros([3, 3]);
        op_mutc_refa_refb_matmul(&mut c, &a, &b, 1.0, 0.0).unwrap();
        println!("{c}");
        let d = &a % &b;
        println!("{d}");

        // 1-D x 1-D.
        let vec_lhs = linspace((0.0, 14.0, 15));
        let vec_rhs = linspace((0.0, 14.0, 15));
        println!("{:}", &vec_lhs % &vec_rhs);

        // Mixed-rank operands; only compiled without the `col_major` feature.
        #[cfg(not(feature = "col_major"))]
        {
            // 1-D x 3-D.
            let vec3 = linspace((0.0, 2.0, 3));
            let batch_rhs = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &vec3 % &batch_rhs);

            // 3-D x 1-D.
            let batch_lhs = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let vec5 = linspace((0.0, 4.0, 5));
            println!("{:}", &batch_lhs % &vec5);

            // 2-D x 3-D.
            let mat_lhs = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            let batch_rhs = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &mat_lhs % &batch_rhs);

            // 3-D x 2-D.
            let batch_lhs = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let mat_rhs = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            println!("{:}", &batch_lhs % &mat_rhs);
        }
    }

    // Checks `matmul_from` with alpha = 2.0 and beta = 1.5 against reference
    // values; the reference depends on whether `col_major` is enabled.
    #[test]
    fn test_matmul_from() {
        let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
        let b = linspace((0.0, 19.0, 20)).into_shape([5, 4]);
        let mut c = linspace((0.0, 11.0, 12)).into_shape([3, 4]);
        c.matmul_from(&a, &b, 2.0, 1.5);
        println!("{c}");

        // Exactly one of the two `c_ref` bindings below is compiled in.
        #[cfg(not(feature = "col_major"))]
        let c_ref = vec![240., 261.5, 283., 304.5, 646., 717.5, 789., 860.5, 1052., 1173.5, 1295., 1416.5];
        #[cfg(feature = "col_major")]
        let c_ref = vec![180.0, 201.5, 223.0, 484.5, 556.0, 627.5, 789.0, 910.5, 1032.0, 1093.5, 1265.0, 1436.5];

        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
    }
}