1use core::ops::{Add, Mul};
6
7use crate::prelude_dev::*;
8
9impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceCpuSerial
10where
11 TA: Clone,
12 TB: Clone,
13 TC: Clone,
14 DA: DimAPI,
15 DB: DimAPI,
16 DC: DimAPI,
17 TA: Mul<TB, Output = TC>,
18 TB: Mul<TA, Output = TC>,
19 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
20 Self: DeviceAPI<TA, Raw = Vec<TA>> + DeviceAPI<TB, Raw = Vec<TB>> + DeviceAPI<TC, Raw = Vec<TC>>,
21{
22 fn matmul(
23 &self,
24 c: &mut Vec<TC>,
25 lc: &Layout<DC>,
26 a: &Vec<TA>,
27 la: &Layout<DA>,
28 b: &Vec<TB>,
29 lb: &Layout<DB>,
30 alpha: TC,
31 beta: TC,
32 ) -> Result<()> {
33 let default_order = self.default_order();
34 match default_order {
35 RowMajor => matmul_naive_cpu_serial(c, lc, a, la, b, lb, alpha, beta),
36 ColMajor => {
37 let la = la.reverse_axes();
38 let lb = lb.reverse_axes();
39 let lc = lc.reverse_axes();
40 matmul_naive_cpu_serial(c, &lc, b, &lb, a, &la, alpha, beta)
41 },
42 }
43 }
44}
45
#[cfg(test)]
mod test {
    use super::*;
    use crate::prelude::*;

    /// Expected result of `c = 1.5 * (a @ b) + 2.0 * c`, flattened with `reshape(-1)`.
    ///
    /// The column-major test uses the axis-reversed shapes of the row-major test, so both tests
    /// share this one reference sequence (e.g. the first entry is 1.5 * 110 + 2.0 * 1 = 167).
    const C_REF: [f64; 30] = [
        167., 184., 201., 218., 235., 381., 422., 463., 504., 545., 595., 660., 725., 790., 855., 809., 898.,
        987., 1076., 1165., 1023., 1136., 1249., 1362., 1475., 1237., 1374., 1511., 1648., 1785.,
    ];

    #[test]
    fn test_row_major() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(RowMajor);
        let a = rt::linspace((1.0, 24.0, 24, &device)).into_shape((2, 3, 4));
        let b = rt::linspace((1.0, 20.0, 20, &device)).into_shape((4, 5));
        let mut c = rt::linspace((1.0, 30.0, 30, &device)).into_shape((2, 3, 5));
        let alpha = 1.5;
        let beta = 2.0;
        let la = a.layout();
        let lb = b.layout();
        // `c` is borrowed mutably by `raw_mut()` below, so its layout is cloned out first.
        let lc = c.layout().clone();
        device.matmul(c.raw_mut(), &lc, a.raw(), la, b.raw(), lb, alpha, beta).unwrap();

        let c_ref = rt::asarray((C_REF.to_vec(), &device));
        assert!((&c.reshape(-1) - c_ref).l2_norm() < 1e-10, "unexpected result c: {c:?}");
    }

    #[test]
    fn test_col_major() {
        let mut device = DeviceCpuSerial::default();
        device.set_default_order(ColMajor);
        let a = rt::linspace((1.0, 20.0, 20, &device)).into_shape((5, 4));
        let b = rt::linspace((1.0, 24.0, 24, &device)).into_shape((4, 3, 2));
        let mut c = rt::linspace((1.0, 30.0, 30, &device)).into_shape((5, 3, 2));
        let alpha = 1.5;
        let beta = 2.0;
        let la = a.layout();
        let lb = b.layout();
        // `c` is borrowed mutably by `raw_mut()` below, so its layout is cloned out first.
        let lc = c.layout().clone();
        device.matmul(c.raw_mut(), &lc, a.raw(), la, b.raw(), lb, alpha, beta).unwrap();

        let c_ref = rt::asarray((C_REF.to_vec(), &device));
        assert!((&c.reshape(-1) - c_ref).l2_norm() < 1e-10, "unexpected result c: {c:?}");
    }
}