1use crate::prelude_dev::*;
4use core::ops::{Add, Mul, Rem};
5use num::{One, Zero};
6
7pub fn op_mutc_refa_refb_matmul<RA, RB, RC, TA, TB, TC, DA, DB, DC, B>(
10 c: &mut TensorAny<RC, TC, B, DC>,
11 a: &TensorAny<RA, TA, B, DA>,
12 b: &TensorAny<RB, TB, B, DB>,
13 alpha: TC,
14 beta: TC,
15) -> Result<()>
16where
17 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
19 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
20 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
21 DA: DimAPI,
23 DB: DimAPI,
24 DC: DimAPI,
25 TA: Mul<TB, Output = TC>,
27 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
28 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
29{
30 rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
31 rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;
32 let device = c.device().clone();
33 let la = a.layout();
34 let lb = b.layout();
35 let lc = c.layout().clone();
36 let sa = a.raw();
37 let sb = b.raw();
38 let sc = c.raw_mut();
39 device.matmul(sc, &lc, sa, la, sb, lb, alpha, beta)
40}
41
42pub fn op_refa_refb_matmul<RA, RB, TA, TB, TC, DA, DB, DC, B>(
43 a: &TensorAny<RA, TA, B, DA>,
44 b: &TensorAny<RB, TB, B, DB>,
45 alpha: TC,
46) -> Result<Tensor<TC, B, DC>>
47where
48 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
50 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
51 DA: DimAPI,
53 DB: DimAPI,
54 DC: DimAPI,
55 TA: Mul<TB, Output = TC>,
57 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero,
58 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
59 B: DeviceCreationAnyAPI<TC>,
60 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
61 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
62{
63 rstsr_assert!(b.device().same_device(b.device()), DeviceMismatch)?;
64 let default_order = a.device().default_order();
65 let cfg = LayoutMatMulConfig::<DA, DB>::layout_matmul(a.layout(), b.layout(), default_order)?;
66 let lc = cfg.lc;
67 let mut c: Tensor<TC, B, _> = unsafe { empty((lc, a.device())) }.into_dim_f()?;
68 op_mutc_refa_refb_matmul(&mut c, a, b, alpha, TC::zero())?;
69 return Ok(c);
70}
71
72pub fn matmul_with_output_f<RA, RB, RC, TA, TB, TC, DA, DB, DC, B>(
73 a: &TensorAny<RA, TA, B, DA>,
74 b: &TensorAny<RB, TB, B, DB>,
75 c: &mut TensorAny<RC, TC, B, DC>,
76) -> Result<()>
77where
78 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
80 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
81 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
82 DA: DimAPI,
84 DB: DimAPI,
85 DC: DimAPI,
86 TA: Mul<TB, Output = TC>,
88 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
89 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
90{
91 op_mutc_refa_refb_matmul(c, a, b, TC::one(), TC::zero())
92}
93
94pub fn matmul_with_output<RA, RB, RC, TA, TB, TC, DA, DB, DC, B>(
95 a: &TensorAny<RA, TA, B, DA>,
96 b: &TensorAny<RB, TB, B, DB>,
97 c: &mut TensorAny<RC, TC, B, DC>,
98) where
99 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
101 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
102 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
103 DA: DimAPI,
105 DB: DimAPI,
106 DC: DimAPI,
107 TA: Mul<TB, Output = TC>,
109 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
110 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
111{
112 op_mutc_refa_refb_matmul(c, a, b, TC::one(), TC::zero()).rstsr_unwrap()
113}
114
115pub fn matmul_from_f<RA, RB, RC, TA, TB, TC, DA, DB, DC, B>(
116 c: &mut TensorAny<RC, TC, B, DC>,
117 a: &TensorAny<RA, TA, B, DA>,
118 b: &TensorAny<RB, TB, B, DB>,
119 alpha: TC,
120 beta: TC,
121) -> Result<()>
122where
123 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
125 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
126 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
127 DA: DimAPI,
129 DB: DimAPI,
130 DC: DimAPI,
131 TA: Mul<TB, Output = TC>,
133 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
134 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
135{
136 op_mutc_refa_refb_matmul(c, a, b, alpha, beta)
137}
138
139pub fn matmul_from<RA, RB, RC, TA, TB, TC, DA, DB, DC, B>(
140 c: &mut TensorAny<RC, TC, B, DC>,
141 a: &TensorAny<RA, TA, B, DA>,
142 b: &TensorAny<RB, TB, B, DB>,
143 alpha: TC,
144 beta: TC,
145) where
146 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
148 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
149 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
150 DA: DimAPI,
152 DB: DimAPI,
153 DC: DimAPI,
154 TA: Mul<TB, Output = TC>,
156 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
157 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
158{
159 op_mutc_refa_refb_matmul(c, a, b, alpha, beta).rstsr_unwrap()
160}
161
162pub fn matmul_f<RA, RB, TA, TB, TC, DA, DB, DC, B>(
163 a: &TensorAny<RA, TA, B, DA>,
164 b: &TensorAny<RB, TB, B, DB>,
165) -> Result<Tensor<TC, B, DC>>
166where
167 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
169 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
170 DA: DimAPI,
172 DB: DimAPI,
173 DC: DimAPI,
174 TA: Mul<TB, Output = TC>,
176 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
177 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
178 B: DeviceCreationAnyAPI<TC>,
179 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
180 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
181{
182 op_refa_refb_matmul(a, b, TC::one())
183}
184
185pub fn matmul<RA, RB, TA, TB, TC, DA, DB, DC, B>(
186 a: &TensorAny<RA, TA, B, DA>,
187 b: &TensorAny<RB, TB, B, DB>,
188) -> Tensor<TC, B, DC>
189where
190 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
192 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
193 DA: DimAPI,
195 DB: DimAPI,
196 DC: DimAPI,
197 TA: Mul<TB, Output = TC>,
199 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
200 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
201 B: DeviceCreationAnyAPI<TC>,
202 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
203 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
204{
205 op_refa_refb_matmul(a, b, TC::one()).rstsr_unwrap()
206}
207
208impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<&TensorAny<RB, TB, B, DB>> for &TensorAny<RA, TA, B, DA>
213where
214 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
216 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
217 DA: DimAPI,
219 DB: DimAPI,
220 DC: DimAPI,
221 TA: Mul<TB, Output = TC>,
223 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
224 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
225 B: DeviceCreationAnyAPI<TC>,
226 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
227 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
228{
229 type Output = Tensor<TC, B, DC>;
230 fn rem(self, rhs: &TensorAny<RB, TB, B, DB>) -> Self::Output {
231 op_refa_refb_matmul(self, rhs, TC::one()).rstsr_unwrap()
232 }
233}
234
235impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<&TensorAny<RB, TB, B, DB>> for TensorAny<RA, TA, B, DA>
236where
237 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
239 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
240 DA: DimAPI,
242 DB: DimAPI,
243 DC: DimAPI,
244 TA: Mul<TB, Output = TC>,
246 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
247 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
248 B: DeviceCreationAnyAPI<TC>,
249 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
250 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
251{
252 type Output = Tensor<TC, B, DC>;
253 fn rem(self, rhs: &TensorAny<RB, TB, B, DB>) -> Self::Output {
254 op_refa_refb_matmul(&self, rhs, TC::one()).rstsr_unwrap()
255 }
256}
257
258impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<TensorAny<RB, TB, B, DB>> for &TensorAny<RA, TA, B, DA>
259where
260 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
262 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
263 DA: DimAPI,
265 DB: DimAPI,
266 DC: DimAPI,
267 TA: Mul<TB, Output = TC>,
269 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
270 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
271 B: DeviceCreationAnyAPI<TC>,
272 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
273 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
274{
275 type Output = Tensor<TC, B, DC>;
276 fn rem(self, rhs: TensorAny<RB, TB, B, DB>) -> Self::Output {
277 op_refa_refb_matmul(self, &rhs, TC::one()).rstsr_unwrap()
278 }
279}
280
281impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<TensorAny<RB, TB, B, DB>> for TensorAny<RA, TA, B, DA>
282where
283 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
285 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
286 DA: DimAPI,
288 DB: DimAPI,
289 DC: DimAPI,
290 TA: Mul<TB, Output = TC>,
292 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
293 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
294 B: DeviceCreationAnyAPI<TC>,
295 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
296 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
297{
298 type Output = Tensor<TC, B, DC>;
299 fn rem(self, rhs: TensorAny<RB, TB, B, DB>) -> Self::Output {
300 op_refa_refb_matmul(&self, &rhs, TC::one()).rstsr_unwrap()
301 }
302}
303
304impl<R, T, B, D> TensorAny<R, T, B, D>
312where
313 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
314 B: DeviceAPI<T>,
315 D: DimAPI,
316{
317 pub fn matmul_f<RB, TB, TC, DB, DC>(&self, rhs: &TensorAny<RB, TB, B, DB>) -> Result<Tensor<TC, B, DC>>
318 where
319 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
321 DB: DimAPI,
323 DC: DimAPI,
324 T: Mul<TB, Output = TC>,
326 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
327 B: DeviceAPI<TB> + DeviceAPI<TC>,
328 B: DeviceCreationAnyAPI<TC>,
329 LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
330 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
331 {
332 op_refa_refb_matmul(self, rhs, TC::one())
333 }
334
335 pub fn matmul<RB, TB, TC, DB, DC>(&self, rhs: &TensorAny<RB, TB, B, DB>) -> Tensor<TC, B, DC>
336 where
337 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
339 DB: DimAPI,
341 DC: DimAPI,
342 T: Mul<TB, Output = TC>,
344 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
345 B: DeviceAPI<TB> + DeviceAPI<TC>,
346 B: DeviceCreationAnyAPI<TC>,
347 LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
348 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
349 {
350 op_refa_refb_matmul(self, rhs, TC::one()).rstsr_unwrap()
351 }
352
353 pub fn matmul_with_output_f<RB, RC, TB, TC, DB, DC>(
354 &self,
355 rhs: &TensorAny<RB, TB, B, DB>,
356 c: &mut TensorAny<RC, TC, B, DC>,
357 ) -> Result<()>
358 where
359 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
361 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
362 DB: DimAPI,
364 DC: DimAPI,
365 T: Mul<TB, Output = TC>,
367 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
368 B: DeviceAPI<TB> + DeviceAPI<TC>,
369 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
370 {
371 op_mutc_refa_refb_matmul(c, self, rhs, TC::one(), TC::zero())
372 }
373
374 pub fn matmul_with_output<RB, RC, TB, TC, DB, DC>(
375 &self,
376 rhs: &TensorAny<RB, TB, B, DB>,
377 c: &mut TensorAny<RC, TC, B, DC>,
378 ) where
379 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
381 RC: DataMutAPI<Data = <B as DeviceRawAPI<TC>>::Raw>,
382 DB: DimAPI,
384 DC: DimAPI,
385 T: Mul<TB, Output = TC>,
387 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + One,
388 B: DeviceAPI<TB> + DeviceAPI<TC>,
389 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
390 {
391 op_mutc_refa_refb_matmul(c, self, rhs, TC::one(), TC::zero()).rstsr_unwrap()
392 }
393
394 pub fn matmul_from_f<RA, RB, TA, TB, DA, DB>(
395 &mut self,
396 a: &TensorAny<RA, TA, B, DA>,
397 b: &TensorAny<RB, TB, B, DB>,
398 alpha: T,
399 beta: T,
400 ) -> Result<()>
401 where
402 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
404 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
405 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
406 DA: DimAPI,
408 DB: DimAPI,
409 TA: Mul<TB, Output = T>,
411 T: Mul<T, Output = T> + Add<T, Output = T> + Zero + One,
412 B: DeviceAPI<TA> + DeviceAPI<TB>,
413 B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
414 {
415 op_mutc_refa_refb_matmul(self, a, b, alpha, beta)
416 }
417
418 pub fn matmul_from<RA, RB, TA, TB, DA, DB>(
419 &mut self,
420 a: &TensorAny<RA, TA, B, DA>,
421 b: &TensorAny<RB, TB, B, DB>,
422 alpha: T,
423 beta: T,
424 ) where
425 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
427 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
428 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
429 DA: DimAPI,
431 DB: DimAPI,
432 TA: Mul<TB, Output = T>,
434 T: Mul<T, Output = T> + Add<T, Output = T> + Zero + One,
435 B: DeviceAPI<TA> + DeviceAPI<TB>,
436 B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
437 {
438 op_mutc_refa_refb_matmul(self, a, b, alpha, beta).rstsr_unwrap()
439 }
440}
441
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_matmul() {
        let a = linspace((0.0, 14.0, 15)).into_shape_assume_contig([3, 5]);
        let b = linspace((0.0, 14.0, 15)).into_shape_assume_contig([5, 3]);
        let mut c: Tensor<f64> = zeros([3, 3]);

        op_mutc_refa_refb_matmul(&mut c, &a, &b, 1.0, 0.0).unwrap();
        println!("{c}");

        let d = &a % &b;
        println!("{d}");

        // Flat buffer of `a @ b`. Row-major: c[i][j] = 150i + 25ij + 90 + 10j.
        // Column-major: c[i][j] = 10i + 25ij + 90 + 150j. Read in each mode's own
        // memory order, both give the same buffer, so one reference serves both
        // features.
        let c_ref = vec![90., 100., 110., 240., 275., 310., 390., 450., 510.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.clone().into()));
        assert!(allclose_f64(&d.raw().into(), &c_ref.into()));

        let a = linspace((0.0, 14.0, 15));
        let b = linspace((0.0, 14.0, 15));
        println!("{:}", &a % &b);

        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((0.0, 2.0, 3));
            let b = linspace((0.0, 29.0, 30)).into_shape_assume_contig([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30)).into_shape_assume_contig([2, 3, 5]);
            let b = linspace((0.0, 4.0, 5));
            println!("{:}", &a % &b);

            let a = linspace((0.0, 14.0, 15)).into_shape_assume_contig([5, 3]);
            let b = linspace((0.0, 29.0, 30)).into_shape_assume_contig([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30)).into_shape_assume_contig([2, 3, 5]);
            let b = linspace((0.0, 14.0, 15)).into_shape_assume_contig([5, 3]);
            println!("{:}", &a % &b);
        }
    }

    #[test]
    fn test_matmul_from() {
        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
            let b = linspace((0.0, 19.0, 20)).into_shape([5, 4]);
            let mut c = linspace((0.0, 11.0, 12)).into_shape([3, 4]);
            c.matmul_from(&a, &b, 2.0, 1.5);
            println!("{c}");

            let c_ref = vec![240., 261.5, 283., 304.5, 646., 717.5, 789., 860.5, 1052., 1173.5, 1295., 1416.5];
            assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        }
        #[cfg(feature = "col_major")]
        {
            let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
            let b = linspace((0.0, 19.0, 20)).into_shape([5, 4]);
            let mut c = linspace((0.0, 11.0, 12)).into_shape([3, 4]);
            c.matmul_from(&a, &b, 2.0, 1.5);
            println!("{c}");

            let c_ref = vec![180.0, 201.5, 223.0, 484.5, 556.0, 627.5, 789.0, 910.5, 1032.0, 1093.5, 1265.0, 1436.5];
            assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        }
    }
}