rstsr_core/device_cpu_serial/matmul.rs

//! Matrix multiplication for CPU backend.
//!
//! **This implementation is not optimized!**
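//!
//! The operation performed is a GEMM-style update, `c = alpha * (a matmul b) + beta * c`,
//! delegated to a naive serial kernel (see the NumPy reference in the tests below).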

use core::ops::{Add, Mul};

use crate::prelude_dev::*;

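// Trait bounds: products in both operand orders (`TA * TB` and `TB * TA`) must
// produce `TC`, because the column-major branch below swaps the operands before
// calling the row-major kernel; `TC` itself must support `Mul` and `Add` for the
// `alpha`/`beta` scaling and accumulation.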
impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceCpuSerial
where
    TA: Clone,
    TB: Clone,
    TC: Clone,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TB: Mul<TA, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
    Self: DeviceAPI<TA, Raw = Vec<TA>> + DeviceAPI<TB, Raw = Vec<TB>> + DeviceAPI<TC, Raw = Vec<TC>>,
{
    fn matmul(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<DC>,
        a: &Vec<TA>,
        la: &Layout<DA>,
        b: &Vec<TB>,
        lb: &Layout<DB>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let default_order = self.default_order();
        match default_order {
            RowMajor => matmul_naive_cpu_serial(c, lc, a, la, b, lb, alpha, beta),
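            // Column-major default order: reuse the row-major kernel through the
            // identity (A B)^T = B^T A^T, i.e. reverse the axes of every layout
            // (a transposed view over the same buffers) and swap the operands.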
            ColMajor => {
                let la = la.reverse_axes();
                let lb = lb.reverse_axes();
                let lc = lc.reverse_axes();
                matmul_naive_cpu_serial(c, &lc, b, &lb, a, &la, alpha, beta)
            },
        }
    }
}

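// For reference, a minimal standalone sketch of what a naive row-major matmul
// kernel computes for the 2-D contiguous case. This is an illustration only:
// the function name and signature are hypothetical, not the actual
// `matmul_naive_cpu_serial` dispatched above, which works on `Layout`-described
// (possibly strided) buffers and higher-dimensional shapes.
#[allow(dead_code)]
fn matmul_naive_2d_sketch(
    c: &mut [f64],
    a: &[f64],
    b: &[f64],
    m: usize,
    k: usize,
    n: usize,
    alpha: f64,
    beta: f64,
) {
    // a is (m, k), b is (k, n), c is (m, n), all contiguous row-major:
    // c[i, j] = alpha * sum_p a[i, p] * b[p, j] + beta * c[i, j]
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
    }
}
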
#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    #[test]
    fn test_row_major() {
        /* Python reference code
            import numpy as np

            a = np.linspace(1, 24, 24).reshape(2, 3, 4)
            b = np.linspace(1, 20, 20).reshape(4, 5)
            c = np.linspace(1, 30, 30).reshape(2, 3, 5)
            (1.5 * a @ b + 2.0 * c).reshape(-1)
        */
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);
        let a = rt::linspace((1.0, 24.0, 24, &device)).into_shape((2, 3, 4));
        let b = rt::linspace((1.0, 20.0, 20, &device)).into_shape((4, 5));
        let mut c = rt::linspace((1.0, 30.0, 30, &device)).into_shape((2, 3, 5));
        let alpha = 1.5;
        let beta = 2.0;
        let la = a.layout();
        let lb = b.layout();
        let lc = c.layout().clone();
        device.matmul(c.raw_mut(), &lc, a.raw(), la, b.raw(), lb, alpha, beta).unwrap();
        println!("Result c: {c:?}");

        let c_ref = rt::asarray((
            vec![
                167., 184., 201., 218., 235., 381., 422., 463., 504., 545., 595., 660., 725., 790., 855., 809., 898.,
                987., 1076., 1165., 1023., 1136., 1249., 1362., 1475., 1237., 1374., 1511., 1648., 1785.,
            ],
            &device,
        ));
        assert!((&c.reshape(-1) - c_ref).l2_norm() < 1e-10);
    }

    #[test]
    fn test_col_major() {
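        // Same computation as `test_row_major`, expressed in column-major order:
        // `a` and `b` here are the axis-reversed counterparts of the row-major
        // test's `b` and `a` over the same flat data, so the flattened result
        // matches the same reference values.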
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(ColMajor);
        let a = rt::linspace((1.0, 20.0, 20, &device)).into_shape((5, 4));
        let b = rt::linspace((1.0, 24.0, 24, &device)).into_shape((4, 3, 2));
        let mut c = rt::linspace((1.0, 30.0, 30, &device)).into_shape((5, 3, 2));
        let alpha = 1.5;
        let beta = 2.0;
        let la = a.layout();
        let lb = b.layout();
        let lc = c.layout().clone();
        device.matmul(c.raw_mut(), &lc, a.raw(), la, b.raw(), lb, alpha, beta).unwrap();
        println!("Result c: {c:?}");

        let c_ref = rt::asarray((
            vec![
                167., 184., 201., 218., 235., 381., 422., 463., 504., 545., 595., 660., 725., 790., 855., 809., 898.,
                987., 1076., 1165., 1023., 1136., 1249., 1362., 1475., 1237., 1374., 1511., 1648., 1785.,
            ],
            &device,
        ));
        assert!((&c.reshape(-1) - c_ref).l2_norm() < 1e-10);
    }
}