use core::ops::{Add, Mul};
use crate::prelude_dev::*;
/// Naive matrix multiplication for the serial CPU device, operating on `Vec`-backed storage.
///
/// The actual kernel (`matmul_naive_cpu_serial`) is always invoked in row-major form; the
/// column-major case is mapped onto it by reversing layouts and swapping the operands.
impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceCpuSerial
where
    TA: Clone + Mul<TB, Output = TC>,
    // `TB * TA` is required because the column-major path passes `b` as the left operand.
    TB: Clone + Mul<TA, Output = TC>,
    TC: Clone + Mul<TC, Output = TC> + Add<TC, Output = TC>,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    Self: DeviceAPI<TA, Raw = Vec<TA>> + DeviceAPI<TB, Raw = Vec<TB>> + DeviceAPI<TC, Raw = Vec<TC>>,
{
    /// Computes the matmul of `a` (layout `la`) and `b` (layout `lb`) into `c` (layout `lc`),
    /// scaled by `alpha` and `beta`, dispatching on the device's default order.
    ///
    /// Returns whatever `matmul_naive_cpu_serial` returns.
    fn matmul(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<DC>,
        a: &Vec<TA>,
        la: &Layout<DA>,
        b: &Vec<TB>,
        lb: &Layout<DB>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        match self.default_order() {
            RowMajor => matmul_naive_cpu_serial(c, lc, a, la, b, lb, alpha, beta),
            ColMajor => {
                // Reverse every layout's axes and swap the operands so the row-major kernel
                // computes the transposed problem, using (A·B)ᵀ = Bᵀ·Aᵀ.
                let (la_rev, lb_rev, lc_rev) =
                    (la.reverse_axes(), lb.reverse_axes(), lc.reverse_axes());
                matmul_naive_cpu_serial(c, &lc_rev, b, &lb_rev, a, &la_rev, alpha, beta)
            },
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// Expected flattened `c` after `c = 1.5 * (a · b) + 2.0 * c` for the inputs built in the
    /// tests below, e.g. the first entry is 1.5 * (1·1 + 2·6 + 3·11 + 4·16) + 2.0 * 1 = 167.
    const C_REF: [f64; 30] = [
        167., 184., 201., 218., 235., 381., 422., 463., 504., 545., 595., 660., 725., 790., 855., 809., 898.,
        987., 1076., 1165., 1023., 1136., 1249., 1362., 1475., 1237., 1374., 1511., 1648., 1785.,
    ];

    #[test]
    fn test_row_major() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);
        let a = rt::linspace((1.0, 24.0, 24, &device)).into_shape((2, 3, 4));
        let b = rt::linspace((1.0, 20.0, 20, &device)).into_shape((4, 5));
        let mut c = rt::linspace((1.0, 30.0, 30, &device)).into_shape((2, 3, 5));
        let (alpha, beta) = (1.5, 2.0);
        // `c`'s layout is cloned up front because `c.raw_mut()` borrows `c` mutably below.
        let c_layout = c.layout().clone();
        device.matmul(c.raw_mut(), &c_layout, a.raw(), a.layout(), b.raw(), b.layout(), alpha, beta).unwrap();
        println!("Result c: {c:?}");
        let expected = rt::asarray((C_REF.to_vec(), &device));
        assert!((&c.reshape(-1) - expected).l2_norm() < 1e-10);
    }

    #[test]
    fn test_col_major() {
        // Same data as `test_row_major`, but with every shape reversed and the operands swapped
        // for column-major order; the test expects the identical flattened result.
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(ColMajor);
        let a = rt::linspace((1.0, 20.0, 20, &device)).into_shape((5, 4));
        let b = rt::linspace((1.0, 24.0, 24, &device)).into_shape((4, 3, 2));
        let mut c = rt::linspace((1.0, 30.0, 30, &device)).into_shape((5, 3, 2));
        let (alpha, beta) = (1.5, 2.0);
        // `c`'s layout is cloned up front because `c.raw_mut()` borrows `c` mutably below.
        let c_layout = c.layout().clone();
        device.matmul(c.raw_mut(), &c_layout, a.raw(), a.layout(), b.raw(), b.layout(), alpha, beta).unwrap();
        println!("Result c: {c:?}");
        let expected = rt::asarray((C_REF.to_vec(), &device));
        assert!((&c.reshape(-1) - expected).l2_norm() < 1e-10);
    }
}