Skip to main content

lumen_core/tensor/
matmul.rs

1use std::ops::Deref;
2
3use num_traits::Zero;
4use crate::{AutogradMetaT, Error, Layout, NumDType, Result, Shape, Storage};
5use super::Tensor;
6
7impl<T: NumDType> Tensor<T> {
8    /// Returns the matrix-multiplication of the input tensor with the other provided tensor.
9    ///
10    /// # Arguments
11    ///
12    /// * `self` - A tensor with dimensions `b1, b2, ..., bi, m, k`.
13    /// * `rhs` - A tensor with dimensions `b1, b2, ..., bi, k, n`.
14    ///
15    /// The resulting tensor has dimensions `b1, b2, ..., bi, m, n`.
16    pub fn matmul(&self, rhs: &Self) -> Result<Self> {
17        let a_dims = self.shape().dims();
18        let b_dims = rhs.shape().dims();
19
20        let dim = a_dims.len();
21
22        if dim < 2 || b_dims.len() != dim {
23            Err(Error::ShapeMismatchBinaryOp {
24                lhs: self.shape().clone(),
25                rhs: rhs.shape().clone(),
26                op: "matmul",
27            })?
28        }
29
30        let m = a_dims[dim - 2];
31        let k = a_dims[dim - 1];
32        let k2 = b_dims[dim - 2];
33        let n = b_dims[dim - 1];
34
35        let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
36        if c_shape.element_count() == 0 || k == 0 {
37            return Self::zeros(c_shape);
38        }
39        let batching: usize = a_dims[..dim - 2].iter().product();
40        let batching_b: usize = b_dims[..dim - 2].iter().product();
41        if k != k2 || batching != batching_b {
42            Err(Error::ShapeMismatchBinaryOp {
43                lhs: self.shape().clone(),
44                rhs: rhs.shape().clone(),
45                op: "matmul",
46            })?
47        }
48
49        // (..., m, k) @ (..., k, n)
50        let c_storage = Self::do_matmul(
51            self.storage_read()?.deref(),
52            self.layout(),
53            rhs.storage_read()?.deref(),
54            rhs.layout(),
55            (batching, m, n, k),
56        );
57
58        let meta = T::AutogradMeta::on_matmul_op(self, rhs);
59        Ok(Self::from_storage(c_storage, c_shape, meta))
60    }
61
62    fn do_matmul(
63        lhs: &Storage<T>, lhs_layout: &Layout, 
64        rhs: &Storage<T>, rhs_layout: &Layout, 
65        bmnk: (usize, usize, usize, usize)
66    ) -> Storage<T> 
67        where T: num_traits::Num + Copy + Zero
68    {
69        let lhs_data = lhs.data();
70        let rhs_data = rhs.data();
71    
72        let lhs_rank = lhs_layout.shape().rank();
73        let rhs_rank = rhs_layout.shape().rank();
74        let (bs, ms, ns, ks) = bmnk;
75        
76        let mut dst = vec![T::zero(); bs * ms * ns];
77    
78        // 获取最后两个维度的 stride
79        let l_stride_m = lhs_layout.stride()[lhs_rank - 2];
80        let l_stride_k = lhs_layout.stride()[lhs_rank - 1];
81        let r_stride_k = rhs_layout.stride()[rhs_rank - 2];
82        let r_stride_n = rhs_layout.stride()[rhs_rank - 1];
83    
84        // 获取用于 Batch 迭代的维度和步幅(排除最后两个维度)
85        let batch_dims = &lhs_layout.dims()[..lhs_rank - 2];
86        let l_batch_strides = &lhs_layout.stride()[..lhs_rank - 2];
87        let r_batch_strides = &rhs_layout.stride()[..rhs_rank - 2];
88    
89        for b in 0..bs {
90            // 计算当前 batch 在 lhs 和 rhs 中的起始偏移量
91            // 这里我们需要将平面的索引 b 还原为多维索引
92            let mut l_batch_offset = lhs_layout.start_offset();
93            let mut r_batch_offset = rhs_layout.start_offset();
94            let mut temp_b = b;
95            
96            // 逆向计算每个 batch 维度的索引并乘以该维度的 stride
97            for i in (0..batch_dims.len()).rev() {
98                let idx = temp_b % batch_dims[i];
99                l_batch_offset += idx * l_batch_strides[i];
100                r_batch_offset += idx * r_batch_strides[i];
101                temp_b /= batch_dims[i];
102            }
103    
104            let dst_batch_offset = b * ms * ns;
105    
106            for m in 0..ms {
107                for n in 0..ns {
108                    let mut v = T::zero();
109                    for k in 0..ks {
110                        // 使用真实的 stride 进行寻址,无论是否转置都能正确工作
111                        let l_idx = l_batch_offset + m * l_stride_m + k * l_stride_k;
112                        let r_idx = r_batch_offset + k * r_stride_k + n * r_stride_n;
113                        
114                        v = v + lhs_data[l_idx] * rhs_data[r_idx];
115                    }
116                    dst[dst_batch_offset + m * ns + n] = v;
117                }
118            }
119        }
120        
121        Storage::new(dst)
122    }
123}
124
#[cfg(test)]
#[allow(unused)]
mod tests {
    use crate::{s, DType, IndexOp, Slice};

    use super::*;

    /// 2-D product checked against hand-expanded dot products.
    #[test]
    fn test_matmul_2d() {
        // lhs is (2, 3): [[0,1,2],[3,4,5]]
        // rhs is (3, 2): [[0,1],[2,3],[4,5]]
        let lhs = Tensor::arange(0, 6).unwrap().reshape((2, 3)).unwrap();
        let rhs = Tensor::arange(0, 6).unwrap().reshape((3, 2)).unwrap();
        let product = lhs.matmul(&rhs).unwrap();

        let row0 = [0*0 + 1*2 + 2*4, 0*1 + 1*3 + 2*5]; // [10, 13]
        let row1 = [3*0 + 4*2 + 5*4, 3*1 + 4*3 + 5*5]; // [28, 40]
        let expected = Tensor::new(&[row0, row1]).unwrap();

        assert!(product.allclose(&expected, 1e-5, 1e-8).unwrap());
    }

    /// Each batch of a 3-D product must equal the 2-D product of the matching slices.
    #[test]
    fn test_matmul_batch() {
        // lhs is (2, 2, 3), rhs is (2, 3, 2)
        let lhs = Tensor::arange(0., 12.).unwrap().reshape((2, 2, 3)).unwrap();
        let rhs = Tensor::arange(0., 12.).unwrap().reshape((2, 3, 2)).unwrap();
        let batched = lhs.matmul(&rhs).unwrap();

        assert_eq!(batched.dims(), &[2, 2, 2]);

        // Reference result for batch 0.
        let lhs0 = Tensor::new(&[[0.,1.,2.],[3.,4.,5.]]).unwrap();
        let rhs0 = Tensor::new(&[[0.,1.],[2.,3.],[4.,5.]]).unwrap();
        let expected0 = lhs0.matmul(&rhs0).unwrap();

        // Reference result for batch 1.
        let lhs1 = Tensor::new(&[[6.,7.,8.],[9.,10.,11.]]).unwrap();
        let rhs1 = Tensor::new(&[[6.,7.],[8.,9.],[10.,11.]]).unwrap();
        let expected1 = lhs1.matmul(&rhs1).unwrap();

        let got0 = batched.index(0).unwrap();
        assert!(expected0.allclose(&got0, 1e-5, 1e-8).unwrap());
        let got1 = batched.index(1).unwrap();
        assert!(expected1.allclose(&got1, 1e-5, 1e-8).unwrap());
    }

    /// Multiplying a sliced view must match multiplying a tensor rebuilt from the same
    /// values with `from_vec`.
    #[test]
    fn test_matmul_not_continues() {
        let cube = Tensor::arange(0., 125.).unwrap().reshape((5, 5, 5)).unwrap();

        // Take cube[1:3, 3:5, 2]; the test expects element (i, j) to hold i*25 + j*5 + 2.
        let view = cube.index((s!(1:3), s!(3:5), 2)).unwrap();
        let dense_vals: Vec<f64> = (1..3)
            .flat_map(|i| (3..5).map(move |j| (i * 25 + j * 5 + 2) as f64))
            .collect();
        let dense = Tensor::from_vec(dense_vals, (2, 2)).unwrap();
        assert!(view.allclose(&dense, 0.0, 0.0).unwrap());

        let rhs = Tensor::randn(0.0, 1.0, (2, 5)).unwrap();

        let from_view = view.matmul(&rhs).unwrap();
        let from_dense = dense.matmul(&rhs).unwrap();
        assert!(from_view.allclose(&from_dense, 1e-5, 1e-8).unwrap());
    }
}