Skip to main content

rstsr_core/tensor/linalg/
matmul.rs

//! Matrix multiplication for tensors.
2
3use crate::prelude_dev::*;
4use core::ops::{Mul, Rem};
5use num::{One, Zero};
6
7/* #region matmul by function */
8
9pub fn matmul<TA, TB, TC, DA, DB, DC, B>(
10    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
11    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
12) -> Tensor<TC, B, DC>
13where
14    // dimension
15    DA: DimAPI,
16    DB: DimAPI,
17    DC: DimAPI,
18    // operation specific
19    TA: Mul<TB, Output = TC>,
20    TC: Zero + One,
21    B: DeviceCreationAnyAPI<TC>,
22    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
23    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
24{
25    op_refa_refb_matmul(a, b, TC::one()).rstsr_unwrap()
26}
27
28pub fn matmul_from<TA, TB, TC, DA, DB, DC, B>(
29    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
30    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
31    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
32    alpha: TC,
33    beta: TC,
34) where
35    // dimension
36    DA: DimAPI,
37    DB: DimAPI,
38    DC: DimAPI,
39    // operation specific
40    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
41{
42    op_mutc_refa_refb_matmul(c, a, b, alpha, beta).rstsr_unwrap()
43}
44
45pub fn op_mutc_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
46    mut c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
47    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
48    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
49    alpha: TC,
50    beta: TC,
51) -> Result<()>
52where
53    // dimension
54    DA: DimAPI,
55    DB: DimAPI,
56    DC: DimAPI,
57    // operation specific
58    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
59{
60    let (a, b, mut c) = (a.view(), b.view(), c.view_mut());
61    rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
62    rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;
63    let device = c.device().clone();
64    let la = a.layout();
65    let lb = b.layout();
66    let lc = c.layout().clone();
67    let sa = a.raw();
68    let sb = b.raw();
69    let sc = c.raw_mut();
70    device.matmul(sc, &lc, sa, la, sb, lb, alpha, beta)
71}
72
73pub fn op_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
74    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
75    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
76    alpha: TC,
77) -> Result<Tensor<TC, B, DC>>
78where
79    // dimension
80    DA: DimAPI,
81    DB: DimAPI,
82    DC: DimAPI,
83    // operation specific
84    TC: Zero,
85    B: DeviceCreationAnyAPI<TC>,
86    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
87    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
88{
89    let (a, b) = (a.view(), b.view());
90    rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
91    let default_order = a.device().default_order();
92    let cfg = LayoutMatMulConfig::<DA, DB>::layout_matmul(a.layout(), b.layout(), default_order)?;
93    let lc = cfg.lc;
94    let mut c: Tensor<TC, B, _> = unsafe { empty((lc, a.device())) }.into_dim_f()?;
95    op_mutc_refa_refb_matmul(&mut c, &a, &b, alpha, TC::zero())?;
96    return Ok(c);
97}
98
99pub fn matmul_with_output_f<TA, TB, TC, DA, DB, DC, B>(
100    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
101    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
102    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
103) -> Result<()>
104where
105    // dimension
106    DA: DimAPI,
107    DB: DimAPI,
108    DC: DimAPI,
109    // operation specific
110    TC: Zero + One,
111    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
112{
113    op_mutc_refa_refb_matmul(c, a, b, TC::one(), TC::zero())
114}
115
116pub fn matmul_with_output<TA, TB, TC, DA, DB, DC, B>(
117    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
118    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
119    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
120) where
121    // dimension
122    DA: DimAPI,
123    DB: DimAPI,
124    DC: DimAPI,
125    // operation specific
126    TC: Zero + One,
127    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
128{
129    op_mutc_refa_refb_matmul(c, a, b, TC::one(), TC::zero()).rstsr_unwrap()
130}
131
132pub fn matmul_from_f<TA, TB, TC, DA, DB, DC, B>(
133    c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
134    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
135    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
136    alpha: TC,
137    beta: TC,
138) -> Result<()>
139where
140    // dimension
141    DA: DimAPI,
142    DB: DimAPI,
143    DC: DimAPI,
144    // operation specific
145    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
146{
147    op_mutc_refa_refb_matmul(c, a, b, alpha, beta)
148}
149
150pub fn matmul_f<TA, TB, TC, DA, DB, DC, B>(
151    a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
152    b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
153) -> Result<Tensor<TC, B, DC>>
154where
155    // dimension
156    DA: DimAPI,
157    DB: DimAPI,
158    DC: DimAPI,
159    // operation specific
160    TA: Mul<TB, Output = TC>,
161    TC: Zero + One,
162    B: DeviceCreationAnyAPI<TC>,
163    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
164    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
165{
166    op_refa_refb_matmul(a, b, TC::one())
167}
168
169/* #endregion */
170
171/* #region matmul implementation to core ops */
172
173#[duplicate_item(
174     TrA                         TrB                       ;
175    [ TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
176    [&TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
177    [ TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
178    [&TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
179)]
180impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<TrB> for TrA
181where
182    // storage
183    RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
184    RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
185    // dimension
186    DA: DimAPI,
187    DB: DimAPI,
188    DC: DimAPI,
189    // operation specific
190    TA: Mul<TB, Output = TC>,
191    TC: Zero + One,
192    B: DeviceCreationAnyAPI<TC>,
193    LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
194    B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
195{
196    type Output = Tensor<TC, B, DC>;
197    fn rem(self, rhs: TrB) -> Self::Output {
198        op_refa_refb_matmul(self, rhs, TC::one()).rstsr_unwrap()
199    }
200}
201
202/* #endregion */
203
204/* #region matmul tensor trait */
205
206impl<R, T, B, D> TensorAny<R, T, B, D>
207where
208    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
209    B: DeviceAPI<T>,
210    D: DimAPI,
211{
212    pub fn matmul_f<TB, TC, DB, DC>(
213        &self,
214        rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
215    ) -> Result<Tensor<TC, B, DC>>
216    where
217        // dimension
218        DB: DimAPI,
219        DC: DimAPI,
220        // operation specific
221        T: Mul<TB, Output = TC>,
222        TC: Zero + One,
223        B: DeviceCreationAnyAPI<TC>,
224        LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
225        B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
226    {
227        op_refa_refb_matmul(self.view(), rhs, TC::one())
228    }
229
230    pub fn matmul<TB, TC, DB, DC>(&self, rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>) -> Tensor<TC, B, DC>
231    where
232        // dimension
233        DB: DimAPI,
234        DC: DimAPI,
235        // operation specific
236        T: Mul<TB, Output = TC>,
237        TC: Zero + One,
238        B: DeviceCreationAnyAPI<TC>,
239        LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
240        B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
241    {
242        op_refa_refb_matmul(self.view(), rhs, TC::one()).rstsr_unwrap()
243    }
244
245    pub fn matmul_with_output_f<TB, TC, DB, DC>(
246        &self,
247        rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
248        c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
249    ) -> Result<()>
250    where
251        // dimension
252        DB: DimAPI,
253        DC: DimAPI,
254        // operation specific
255        TC: Zero + One,
256        B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
257    {
258        op_mutc_refa_refb_matmul(c, self.view(), rhs, TC::one(), TC::zero())
259    }
260
261    pub fn matmul_with_output<TB, TC, DB, DC>(
262        &self,
263        rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
264        c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
265    ) where
266        // dimension
267        DB: DimAPI,
268        DC: DimAPI,
269        // operation specific
270        TC: Zero + One,
271        B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
272    {
273        op_mutc_refa_refb_matmul(c, self.view(), rhs, TC::one(), TC::zero()).rstsr_unwrap()
274    }
275
276    pub fn matmul_from_f<TA, TB, DA, DB>(
277        &mut self,
278        a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
279        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
280        alpha: T,
281        beta: T,
282    ) -> Result<()>
283    where
284        // storage
285        R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
286        // dimension
287        DA: DimAPI,
288        DB: DimAPI,
289        // operation specific
290        B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
291    {
292        op_mutc_refa_refb_matmul(self.view_mut(), a, b, alpha, beta)
293    }
294
295    pub fn matmul_from<TA, TB, DA, DB>(
296        &mut self,
297        a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
298        b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
299        alpha: T,
300        beta: T,
301    ) where
302        // storage
303        R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
304        // dimension
305        DA: DimAPI,
306        DB: DimAPI,
307        // operation specific
308        B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
309    {
310        op_mutc_refa_refb_matmul(self.view_mut(), a, b, alpha, beta).rstsr_unwrap()
311    }
312}
313
314/* #endregion */
315
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises the core in-place routine and the `%` operator.
    ///
    /// The 2-D x 2-D case is checked against hand-computed values; the
    /// remaining cases only verify that the call succeeds and print the
    /// result.
    #[test]
    fn test_matmul() {
        let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
        let b = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
        let mut c: Tensor<f64> = zeros([3, 3]);

        op_mutc_refa_refb_matmul(&mut c, &a, &b, 1.0, 0.0).unwrap();
        println!("{c}");

        let d = &a % &b;
        println!("{d}");

        // Reference values assume the row-major reading of `into_shape`, so
        // they are only pinned for the default (non `col_major`) build.
        #[cfg(not(feature = "col_major"))]
        {
            let c_ref = vec![90., 100., 110., 240., 275., 310., 390., 450., 510.];
            assert!(allclose_f64(&c.raw().into(), &c_ref.clone().into()));
            assert!(allclose_f64(&d.raw().into(), &c_ref.into()));
        }

        // 1-D x 1-D
        let a = linspace((0.0, 14.0, 15));
        let b = linspace((0.0, 14.0, 15));
        println!("{:}", &a % &b);

        #[cfg(not(feature = "col_major"))]
        {
            // 1-D x 3-D
            let a = linspace((0.0, 2.0, 3));
            let b = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &a % &b);

            // 3-D x 1-D
            let a = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let b = linspace((0.0, 4.0, 5));
            println!("{:}", &a % &b);

            // 2-D x 3-D
            let a = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            let b = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &a % &b);

            // 3-D x 2-D
            let a = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let b = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            println!("{:}", &a % &b);
        }
    }

    /// Checks `c.matmul_from(a, b, alpha, beta)` against reference values for
    /// both the row-major and the `col_major` builds.
    #[test]
    fn test_matmul_from() {
        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
            let b = linspace((0.0, 19.0, 20)).into_shape([5, 4]);
            let mut c = linspace((0.0, 11.0, 12)).into_shape([3, 4]);
            c.matmul_from(&a, &b, 2.0, 1.5);
            println!("{c}");

            let c_ref = vec![240., 261.5, 283., 304.5, 646., 717.5, 789., 860.5, 1052., 1173.5, 1295., 1416.5];
            assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        }
        #[cfg(feature = "col_major")]
        {
            let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
            let b = linspace((0.0, 19.0, 20)).into_shape([5, 4]);
            let mut c = linspace((0.0, 11.0, 12)).into_shape([3, 4]);
            c.matmul_from(&a, &b, 2.0, 1.5);
            println!("{c}");

            let c_ref = vec![180.0, 201.5, 223.0, 484.5, 556.0, 627.5, 789.0, 910.5, 1032.0, 1093.5, 1265.0, 1436.5];
            assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        }
    }
}