rstsr-core 0.7.2

An n-dimensional Rust tensor toolkit
Documentation
//! Matrix-multiplication for tensor.

use crate::prelude_dev::*;
use core::ops::{Mul, Rem};
use num::{One, Zero};

/* #region matmul by function */

/// Matrix product `a · b`, returned as a newly allocated tensor.
///
/// Infallible counterpart of `matmul_f`: the `Result` produced by
/// `op_refa_refb_matmul` (called with `alpha = 1`) is unwrapped with `rstsr_unwrap`.
///
/// Note: `TA: Mul<TB, Output = TC>` is not required by `op_refa_refb_matmul`; it is
/// presumably here so that the output element type `TC` can be inferred from the
/// operand element types.
pub fn matmul<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
) -> Tensor<TC, B, DC>
where
    // element types
    TA: Mul<TB, Output = TC>,
    TC: Zero + One,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    // backend capabilities: allocate the output, run the kernel
    B: DeviceCreationAnyAPI<TC> + DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let unit_scale = TC::one();
    let product = op_refa_refb_matmul(a, b, unit_scale);
    product.rstsr_unwrap()
}

/// In-place scaled matrix product into an existing output: `c = alpha * (a · b) + beta * c`.
///
/// Infallible counterpart of `matmul_from_f`; the `Result` of
/// `op_mutc_refa_refb_matmul` is unwrapped with `rstsr_unwrap`.
pub fn matmul_from<TA, TB, TC, DA, DB, DC, B>(
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
    beta: TC,
) where
    // the backend provides the kernel for this type/dimension combination
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
{
    let outcome = op_mutc_refa_refb_matmul(c, a, b, alpha, beta);
    outcome.rstsr_unwrap()
}

/// Core in-place matmul: `c = alpha * (a · b) + beta * c`, dispatched to the device kernel.
///
/// All operands are first reduced to (mutable) views. The device of `c` is then
/// checked against those of `a` and `b`, and the raw storages and layouts are
/// passed to `DeviceMatMulAPI::matmul`.
///
/// # Errors
///
/// - `DeviceMismatch` if `c` does not share a device with `a` or with `b`;
/// - any error reported by the device kernel.
pub fn op_mutc_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
    mut c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
    beta: TC,
) -> Result<()>
where
    // the backend provides the kernel for this type/dimension combination
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
{
    let (a, b) = (a.view(), b.view());
    let mut c = c.view_mut();

    // Every operand must live on the output's device.
    rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
    rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;

    // `c` is mutably borrowed for its storage below, so its device handle and
    // layout are detached (cloned) beforehand.
    let device = c.device().clone();
    let lc = c.layout().clone();
    let (sa, la) = (a.raw(), a.layout());
    let (sb, lb) = (b.raw(), b.layout());
    device.matmul(c.raw_mut(), &lc, sa, la, sb, lb, alpha, beta)
}

/// Allocates a new output tensor and computes `c = alpha * (a · b)` into it.
///
/// The output layout is computed by `LayoutMatMulConfig::<DA, DB>::layout_matmul` from
/// the layouts of `a` and `b` and the default memory order of `a`'s device. The
/// product itself is delegated to `op_mutc_refa_refb_matmul` with `beta = 0`.
///
/// # Errors
///
/// - `DeviceMismatch` if `a` and `b` are not on the same device;
/// - any error from the output-layout computation, from converting the output to
///   dimension `DC`, or from the device kernel.
pub fn op_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
) -> Result<Tensor<TC, B, DC>>
where
    // dimension
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    // operation specific
    TC: Zero,
    B: DeviceCreationAnyAPI<TC>,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let (a, b) = (a.view(), b.view());
    rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
    let default_order = a.device().default_order();
    let cfg = LayoutMatMulConfig::<DA, DB>::layout_matmul(a.layout(), b.layout(), default_order)?;
    let lc = cfg.lc;
    // SAFETY: `c` is allocated uninitialized and is only handed to the matmul kernel
    // with `beta = 0`, i.e. as a pure output buffer.
    // NOTE(review): soundness relies on every `DeviceMatMulAPI` implementation not
    // reading `c` when `beta == 0` (BLAS GEMM semantics) — confirm for all backends.
    let mut c: Tensor<TC, B, _> = unsafe { empty((lc, a.device())) }.into_dim_f()?;
    op_mutc_refa_refb_matmul(&mut c, &a, &b, alpha, TC::zero())?;
    Ok(c)
}

/// Fallible matrix product `a · b` written into the caller-provided output `c`.
///
/// Equivalent to `op_mutc_refa_refb_matmul(c, a, b, 1, 0)`; note the argument order
/// here is `(a, b, c)`, not `(c, a, b)`.
///
/// # Errors
///
/// Propagates every error of `op_mutc_refa_refb_matmul` (device mismatch, kernel errors).
pub fn matmul_with_output_f<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
) -> Result<()>
where
    // scalars 1 and 0 are needed for alpha / beta
    TC: Zero + One,
    // kernel availability
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
{
    let (alpha, beta) = (TC::one(), TC::zero());
    op_mutc_refa_refb_matmul(c, a, b, alpha, beta)
}

/// Matrix product `a · b` written into the caller-provided output `c`.
///
/// Infallible counterpart of `matmul_with_output_f` (`alpha = 1`, `beta = 0`); the
/// `Result` is unwrapped with `rstsr_unwrap`.
pub fn matmul_with_output<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
) where
    // scalars 1 and 0 are needed for alpha / beta
    TC: Zero + One,
    // kernel availability
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
{
    let alpha = TC::one();
    let beta = TC::zero();
    let outcome = op_mutc_refa_refb_matmul(c, a, b, alpha, beta);
    outcome.rstsr_unwrap()
}

/// Fallible in-place scaled matrix product: `c = alpha * (a · b) + beta * c`.
///
/// Direct alias of `op_mutc_refa_refb_matmul` with the same argument order.
///
/// # Errors
///
/// Propagates every error of `op_mutc_refa_refb_matmul` (device mismatch, kernel errors).
pub fn matmul_from_f<TA, TB, TC, DA, DB, DC, B>(
    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    alpha: TC,
    beta: TC,
) -> Result<()>
where
    // kernel availability
    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
{
    let result = op_mutc_refa_refb_matmul(c, a, b, alpha, beta);
    result
}

/// Fallible matrix product `a · b`, returned as a newly allocated tensor.
///
/// Calls `op_refa_refb_matmul` with `alpha = 1`.
///
/// # Errors
///
/// Propagates every error of `op_refa_refb_matmul` (device mismatch, layout or
/// dimension incompatibility, kernel errors).
pub fn matmul_f<TA, TB, TC, DA, DB, DC, B>(
    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
) -> Result<Tensor<TC, B, DC>>
where
    // element types (the `Mul` bound pins down the output element type)
    TA: Mul<TB, Output = TC>,
    TC: Zero + One,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    // backend capabilities: allocate the output, run the kernel
    B: DeviceCreationAnyAPI<TC> + DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    let unit_scale = TC::one();
    op_refa_refb_matmul(a, b, unit_scale)
}

/* #endregion */

/* #region matmul implementation to core ops */

// `a % b` is matrix multiplication (alpha = 1) into a freshly allocated tensor.
// It is implemented for every owned/borrowed combination of the two operands;
// failures are unwrapped with `rstsr_unwrap`.
#[duplicate_item(
     TrA                         TrB                       ;
    [ TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
    [&TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
    [ TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
    [&TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
)]
impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<TrB> for TrA
where
    // storage of each operand matches its backend's raw container
    RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
    // element types
    TA: Mul<TB, Output = TC>,
    TC: Zero + One,
    // dimensions
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
    // backend capabilities: allocate the output, run the kernel
    B: DeviceCreationAnyAPI<TC> + DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
{
    type Output = Tensor<TC, B, DC>;

    fn rem(self, rhs: TrB) -> Self::Output {
        let unit_scale = TC::one();
        let product = op_refa_refb_matmul(self, rhs, unit_scale);
        product.rstsr_unwrap()
    }
}

/* #endregion */

/* #region matmul tensor trait */

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI,
{
    /// Fallible matrix product `self · rhs` into a freshly allocated tensor.
    ///
    /// `self` is only viewed, never consumed; `alpha` is 1.
    ///
    /// # Errors
    ///
    /// Propagates every error of `op_refa_refb_matmul`.
    pub fn matmul_f<TB, TC, DB, DC>(
        &self,
        rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    ) -> Result<Tensor<TC, B, DC>>
    where
        // element types
        T: Mul<TB, Output = TC>,
        TC: Zero + One,
        // dimensions
        DB: DimAPI,
        DC: DimAPI,
        LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
        // backend capabilities: allocate the output, run the kernel
        B: DeviceCreationAnyAPI<TC> + DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
    {
        let lhs = self.view();
        let unit_scale = TC::one();
        op_refa_refb_matmul(lhs, rhs, unit_scale)
    }

    /// Matrix product `self · rhs` into a freshly allocated tensor.
    ///
    /// Infallible counterpart of [`Self::matmul_f`]; errors are unwrapped with `rstsr_unwrap`.
    pub fn matmul<TB, TC, DB, DC>(&self, rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>) -> Tensor<TC, B, DC>
    where
        // element types
        T: Mul<TB, Output = TC>,
        TC: Zero + One,
        // dimensions
        DB: DimAPI,
        DC: DimAPI,
        LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
        // backend capabilities: allocate the output, run the kernel
        B: DeviceCreationAnyAPI<TC> + DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
    {
        self.matmul_f(rhs).rstsr_unwrap()
    }

    /// Fallible matrix product `self · rhs` written into the caller-provided output `c`
    /// (`alpha = 1`, `beta = 0`).
    ///
    /// # Errors
    ///
    /// Propagates every error of `op_mutc_refa_refb_matmul`.
    pub fn matmul_with_output_f<TB, TC, DB, DC>(
        &self,
        rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
        c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    ) -> Result<()>
    where
        // scalars 1 and 0 are needed for alpha / beta
        TC: Zero + One,
        // dimensions
        DB: DimAPI,
        DC: DimAPI,
        // kernel availability
        B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
    {
        let (alpha, beta) = (TC::one(), TC::zero());
        op_mutc_refa_refb_matmul(c, self.view(), rhs, alpha, beta)
    }

    /// Matrix product `self · rhs` written into the caller-provided output `c`.
    ///
    /// Infallible counterpart of [`Self::matmul_with_output_f`]; errors are unwrapped
    /// with `rstsr_unwrap`.
    pub fn matmul_with_output<TB, TC, DB, DC>(
        &self,
        rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
        c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    ) where
        // scalars 1 and 0 are needed for alpha / beta
        TC: Zero + One,
        // dimensions
        DB: DimAPI,
        DC: DimAPI,
        // kernel availability
        B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
    {
        self.matmul_with_output_f(rhs, c).rstsr_unwrap()
    }

    /// Fallible in-place update of `self`: `self = alpha * (a · b) + beta * self`.
    ///
    /// Requires mutable storage (`R: DataMutAPI`).
    ///
    /// # Errors
    ///
    /// Propagates every error of `op_mutc_refa_refb_matmul`.
    pub fn matmul_from_f<TA, TB, DA, DB>(
        &mut self,
        a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
        alpha: T,
        beta: T,
    ) -> Result<()>
    where
        // storage must be writable
        R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
        // dimensions
        DA: DimAPI,
        DB: DimAPI,
        // kernel availability
        B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
    {
        let target = self.view_mut();
        op_mutc_refa_refb_matmul(target, a, b, alpha, beta)
    }

    /// In-place update of `self`: `self = alpha * (a · b) + beta * self`.
    ///
    /// Infallible counterpart of [`Self::matmul_from_f`]; errors are unwrapped with
    /// `rstsr_unwrap`.
    pub fn matmul_from<TA, TB, DA, DB>(
        &mut self,
        a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
        alpha: T,
        beta: T,
    ) where
        // storage must be writable
        R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
        // dimensions
        DA: DimAPI,
        DB: DimAPI,
        // kernel availability
        B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
    {
        self.matmul_from_f(a, b, alpha, beta).rstsr_unwrap()
    }
}

/* #endregion */

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_matmul() {
        let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
        let b = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
        let mut c: Tensor<f64> = zeros([3, 3]);

        // Reference for the 3x5 · 5x3 product. For these inputs the raw buffer is the
        // same in row-major and column-major mode (the col-major product is the
        // transpose of the row-major one, stored transposed), so a single reference
        // covers both feature configurations.
        let c_ref = vec![90., 100., 110., 240., 275., 310., 390., 450., 510.];

        op_mutc_refa_refb_matmul(&mut c, &a, &b, 1.0, 0.0).unwrap();
        println!("{c}");
        assert!(allclose_f64(&c.raw().into(), &c_ref.clone().into()));

        let d = &a % &b;
        println!("{d}");
        assert!(allclose_f64(&d.raw().into(), &c_ref.into()));

        let a = linspace((0.0, 14.0, 15));
        let b = linspace((0.0, 14.0, 15));
        println!("{:}", &a % &b);

        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((0.0, 2.0, 3));
            let b = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let b = linspace((0.0, 4.0, 5));
            println!("{:}", &a % &b);

            let a = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            let b = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let b = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            println!("{:}", &a % &b);
        }
    }

    #[test]
    fn test_matmul_from() {
        // c = 2 * (a · b) + 1.5 * c; only the expected raw buffer depends on memory order.
        let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
        let b = linspace((0.0, 19.0, 20)).into_shape([5, 4]);
        let mut c = linspace((0.0, 11.0, 12)).into_shape([3, 4]);
        c.matmul_from(&a, &b, 2.0, 1.5);
        println!("{c}");

        #[cfg(not(feature = "col_major"))]
        let c_ref = vec![240., 261.5, 283., 304.5, 646., 717.5, 789., 860.5, 1052., 1173.5, 1295., 1416.5];
        #[cfg(feature = "col_major")]
        let c_ref = vec![180.0, 201.5, 223.0, 484.5, 556.0, 627.5, 789.0, 910.5, 1032.0, 1093.5, 1265.0, 1436.5];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
    }
}