22_matrix_multiplication/
22_matrix_multiplication.rs1use matten::Tensor;
9
10fn main() {
11 let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
13 let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);
14 let c = a.matmul(&b);
15 println!("A = {a:?}");
16 println!("B = {b:?}");
17 println!("A × B = {c:?}"); assert_eq!(c.as_slice(), &[19.0, 22.0, 43.0, 50.0]);
19
20 let x = Tensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
22 let y = Tensor::new((1..=12).map(|v| v as f64).collect(), &[3, 4]);
23 let z = x.matmul(&y);
24 println!("X × Y shape = {:?}", z.shape()); println!("Matrix multiplication: OK");
27}