use std::ops::Deref;

use num_traits::Zero;
use crate::{AutogradMetaT, Error, Layout, NumDType, Result, Shape, Storage};
use super::Tensor;

impl<T: NumDType> Tensor<T> {
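    /// Batched matrix multiplication: contracts the last two dimensions as
    /// `(.., m, k) x (.., k, n) -> (.., m, n)`. Both operands must have the
    /// same rank (at least 2) and matching batch dimensions.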
    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
        let a_dims = self.shape().dims();
        let b_dims = rhs.shape().dims();

        let dim = a_dims.len();

        if dim < 2 || b_dims.len() != dim {
            Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
                rhs: rhs.shape().clone(),
                op: "matmul",
            })?
        }

        let m = a_dims[dim - 2];
        let k = a_dims[dim - 1];
        let k2 = b_dims[dim - 2];
        let n = b_dims[dim - 1];

        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
        // An empty output, or a zero-length contraction, yields all zeros.
        if c_shape.element_count() == 0 || k == 0 {
            return Self::zeros(c_shape);
        }
        let batching: usize = a_dims[..dim - 2].iter().product();
        // The contraction lengths must agree, and the batch dimensions must
        // match exactly: `do_matmul` decomposes a single flat batch index and
        // applies the same per-dimension indices to both operands, so a mere
        // product match is not enough.
        if k != k2 || a_dims[..dim - 2] != b_dims[..dim - 2] {
            Err(Error::ShapeMismatchBinaryOp {
                lhs: self.shape().clone(),
                rhs: rhs.shape().clone(),
                op: "matmul",
            })?
        }

        let c_storage = Self::do_matmul(
            self.storage_read()?.deref(),
            self.layout(),
            rhs.storage_read()?.deref(),
            rhs.layout(),
            (batching, m, n, k),
        );

        let meta = T::AutogradMeta::on_matmul_op(self, rhs);
        Ok(Self::from_storage(c_storage, c_shape, meta))
    }

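    /// Naive strided triple-loop kernel. `bmnk` packs the flattened batch
    /// count with the `m`, `n`, `k` extents; element offsets are computed
    /// from each layout's strides, so transposed or sliced (non-contiguous)
    /// inputs are handled without materializing a copy.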
    fn do_matmul(
        lhs: &Storage<T>, lhs_layout: &Layout,
        rhs: &Storage<T>, rhs_layout: &Layout,
        bmnk: (usize, usize, usize, usize),
    ) -> Storage<T>
    where T: num_traits::Num + Copy + Zero
    {
        let lhs_data = lhs.data();
        let rhs_data = rhs.data();

        let lhs_rank = lhs_layout.shape().rank();
        let rhs_rank = rhs_layout.shape().rank();
        let (bs, ms, ns, ks) = bmnk;

        let mut dst = vec![T::zero(); bs * ms * ns];

        // Strides over the two trailing (matrix) dimensions of each operand.
        let l_stride_m = lhs_layout.stride()[lhs_rank - 2];
        let l_stride_k = lhs_layout.stride()[lhs_rank - 1];
        let r_stride_k = rhs_layout.stride()[rhs_rank - 2];
        let r_stride_n = rhs_layout.stride()[rhs_rank - 1];

        // Leading (batch) dimensions and their strides in each operand.
        let batch_dims = &lhs_layout.dims()[..lhs_rank - 2];
        let l_batch_strides = &lhs_layout.stride()[..lhs_rank - 2];
        let r_batch_strides = &rhs_layout.stride()[..rhs_rank - 2];

        for b in 0..bs {
            // Decompose the flat batch index `b` into per-dimension indices
            // (mixed-radix, last batch dimension fastest) and accumulate the
            // starting offset of this batch in each operand.
            let mut l_batch_offset = lhs_layout.start_offset();
            let mut r_batch_offset = rhs_layout.start_offset();
            let mut temp_b = b;

            for i in (0..batch_dims.len()).rev() {
                let idx = temp_b % batch_dims[i];
                l_batch_offset += idx * l_batch_strides[i];
                r_batch_offset += idx * r_batch_strides[i];
                temp_b /= batch_dims[i];
            }

            let dst_batch_offset = b * ms * ns;

            for m in 0..ms {
                for n in 0..ns {
                    // Dot product of row `m` of lhs with column `n` of rhs.
                    let mut v = T::zero();
                    for k in 0..ks {
                        let l_idx = l_batch_offset + m * l_stride_m + k * l_stride_k;
                        let r_idx = r_batch_offset + k * r_stride_k + n * r_stride_n;

                        v = v + lhs_data[l_idx] * rhs_data[r_idx];
                    }
                    // The output is always written contiguously, row-major.
                    dst[dst_batch_offset + m * ns + n] = v;
                }
            }
        }

        Storage::new(dst)
    }
}

#[cfg(test)]
#[allow(unused)]
mod tests {
    use crate::{s, DType, IndexOp, Slice};

    use super::*;

    #[test]
    fn test_matmul_2d() {
        let a = Tensor::arange(0, 6).unwrap().reshape((2, 3)).unwrap();
        let b = Tensor::arange(0, 6).unwrap().reshape((3, 2)).unwrap();
        let c = a.matmul(&b).unwrap();

        let expected = Tensor::new(&[
            [0 * 0 + 1 * 2 + 2 * 4, 0 * 1 + 1 * 3 + 2 * 5],
            [3 * 0 + 4 * 2 + 5 * 4, 3 * 1 + 4 * 3 + 5 * 5],
        ]).unwrap();

        assert!(c.allclose(&expected, 1e-5, 1e-8).unwrap());
    }
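
    // A small sketch of the error path: the inner dimensions disagree
    // ((2, 3) x (2, 2)), so `matmul` should return the shape-mismatch error
    // rather than panicking. Uses only APIs exercised in the tests above.
    #[test]
    fn test_matmul_shape_mismatch() {
        let a = Tensor::arange(0., 6.).unwrap().reshape((2, 3)).unwrap();
        let b = Tensor::arange(0., 4.).unwrap().reshape((2, 2)).unwrap();
        assert!(a.matmul(&b).is_err());
    }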

    #[test]
    fn test_matmul_batch() {
        let a = Tensor::arange(0., 12.).unwrap().reshape((2, 2, 3)).unwrap();
        let b = Tensor::arange(0., 12.).unwrap().reshape((2, 3, 2)).unwrap();
        let c = a.matmul(&b).unwrap();

        assert_eq!(c.dims(), &[2, 2, 2]);

        // Check each batch against an independent 2-D product.
        let a0 = Tensor::new(&[[0., 1., 2.], [3., 4., 5.]]).unwrap();
        let b0 = Tensor::new(&[[0., 1.], [2., 3.], [4., 5.]]).unwrap();
        let c0 = a0.matmul(&b0).unwrap();

        let a1 = Tensor::new(&[[6., 7., 8.], [9., 10., 11.]]).unwrap();
        let b1 = Tensor::new(&[[6., 7.], [8., 9.], [10., 11.]]).unwrap();
        let c1 = a1.matmul(&b1).unwrap();

        assert!(c0.allclose(&c.index(0).unwrap(), 1e-5, 1e-8).unwrap());
        assert!(c1.allclose(&c.index(1).unwrap(), 1e-5, 1e-8).unwrap());
    }
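
    // A hedged sketch of the k == 0 early return in `matmul`: a (2, 0) x
    // (0, 2) product is defined as the (2, 2) zero matrix. Assumes (not
    // shown elsewhere in this file) that an empty `arange` can be reshaped
    // to a zero-sized shape.
    #[test]
    fn test_matmul_zero_k() {
        let a = Tensor::arange(0., 0.).unwrap().reshape((2, 0)).unwrap();
        let b = Tensor::arange(0., 0.).unwrap().reshape((0, 2)).unwrap();
        let c = a.matmul(&b).unwrap();

        let expected = Tensor::new(&[[0., 0.], [0., 0.]]).unwrap();
        assert!(c.allclose(&expected, 0.0, 0.0).unwrap());
    }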

    #[test]
    fn test_matmul_not_contiguous() {
        let a = Tensor::arange(0., 125.).unwrap().reshape((5, 5, 5)).unwrap();

        // Slicing yields a non-contiguous (2, 2) view: rows 1..3, cols 3..5,
        // with the last axis fixed at index 2.
        let sub_a = a.index((s!(1:3), s!(3:5), 2)).unwrap();
        let mut vals = Vec::new();
        for i in 1..3 {
            for j in 3..5 {
                vals.push((i * 25 + j * 5 + 2) as f64);
            }
        }
        let expected = Tensor::from_vec(vals, (2, 2)).unwrap();
        assert!(sub_a.allclose(&expected, 0.0, 0.0).unwrap());

        let b = Tensor::randn(0.0, 1.0, (2, 5)).unwrap();

        // The strided view must multiply the same as its contiguous copy.
        let res = sub_a.matmul(&b).unwrap();
        let res_expected = expected.matmul(&b).unwrap();
        assert!(res.allclose(&res_expected, 1e-5, 1e-8).unwrap());
    }
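
    // Sketch of the batch-dimension check: the inner dimension agrees
    // (k == 3) but the batch dimensions differ (2 vs. 3), so `matmul`
    // should reject the pair with a shape-mismatch error.
    #[test]
    fn test_matmul_batch_mismatch() {
        let a = Tensor::arange(0., 12.).unwrap().reshape((2, 2, 3)).unwrap();
        let b = Tensor::arange(0., 18.).unwrap().reshape((3, 3, 2)).unwrap();
        assert!(a.matmul(&b).is_err());
    }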
}