1use crate::prelude_dev::*;
4use core::ops::{Mul, Rem};
5use num::{One, Zero};
6
7pub fn matmul<TA, TB, TC, DA, DB, DC, B>(
10 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
11 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
12) -> Tensor<TC, B, DC>
13where
14 DA: DimAPI,
16 DB: DimAPI,
17 DC: DimAPI,
18 TA: Mul<TB, Output = TC>,
20 TC: Zero + One,
21 B: DeviceCreationAnyAPI<TC>,
22 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
23 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
24{
25 op_refa_refb_matmul(a, b, TC::one()).rstsr_unwrap()
26}
27
28pub fn matmul_from<TA, TB, TC, DA, DB, DC, B>(
29 c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
30 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
31 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
32 alpha: TC,
33 beta: TC,
34) where
35 DA: DimAPI,
37 DB: DimAPI,
38 DC: DimAPI,
39 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
41{
42 op_mutc_refa_refb_matmul(c, a, b, alpha, beta).rstsr_unwrap()
43}
44
45pub fn op_mutc_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
46 mut c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
47 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
48 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
49 alpha: TC,
50 beta: TC,
51) -> Result<()>
52where
53 DA: DimAPI,
55 DB: DimAPI,
56 DC: DimAPI,
57 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
59{
60 let (a, b, mut c) = (a.view(), b.view(), c.view_mut());
61 rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
62 rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;
63 let device = c.device().clone();
64 let la = a.layout();
65 let lb = b.layout();
66 let lc = c.layout().clone();
67 let sa = a.raw();
68 let sb = b.raw();
69 let sc = c.raw_mut();
70 device.matmul(sc, &lc, sa, la, sb, lb, alpha, beta)
71}
72
73pub fn op_refa_refb_matmul<TA, TB, TC, DA, DB, DC, B>(
74 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
75 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
76 alpha: TC,
77) -> Result<Tensor<TC, B, DC>>
78where
79 DA: DimAPI,
81 DB: DimAPI,
82 DC: DimAPI,
83 TC: Zero,
85 B: DeviceCreationAnyAPI<TC>,
86 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
87 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
88{
89 let (a, b) = (a.view(), b.view());
90 rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
91 let default_order = a.device().default_order();
92 let cfg = LayoutMatMulConfig::<DA, DB>::layout_matmul(a.layout(), b.layout(), default_order)?;
93 let lc = cfg.lc;
94 let mut c: Tensor<TC, B, _> = unsafe { empty((lc, a.device())) }.into_dim_f()?;
95 op_mutc_refa_refb_matmul(&mut c, &a, &b, alpha, TC::zero())?;
96 return Ok(c);
97}
98
99pub fn matmul_with_output_f<TA, TB, TC, DA, DB, DC, B>(
100 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
101 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
102 c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
103) -> Result<()>
104where
105 DA: DimAPI,
107 DB: DimAPI,
108 DC: DimAPI,
109 TC: Zero + One,
111 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
112{
113 op_mutc_refa_refb_matmul(c, a, b, TC::one(), TC::zero())
114}
115
116pub fn matmul_with_output<TA, TB, TC, DA, DB, DC, B>(
117 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
118 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
119 c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
120) where
121 DA: DimAPI,
123 DB: DimAPI,
124 DC: DimAPI,
125 TC: Zero + One,
127 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
128{
129 op_mutc_refa_refb_matmul(c, a, b, TC::one(), TC::zero()).rstsr_unwrap()
130}
131
132pub fn matmul_from_f<TA, TB, TC, DA, DB, DC, B>(
133 c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
134 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
135 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
136 alpha: TC,
137 beta: TC,
138) -> Result<()>
139where
140 DA: DimAPI,
142 DB: DimAPI,
143 DC: DimAPI,
144 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
146{
147 op_mutc_refa_refb_matmul(c, a, b, alpha, beta)
148}
149
150pub fn matmul_f<TA, TB, TC, DA, DB, DC, B>(
151 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
152 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
153) -> Result<Tensor<TC, B, DC>>
154where
155 DA: DimAPI,
157 DB: DimAPI,
158 DC: DimAPI,
159 TA: Mul<TB, Output = TC>,
161 TC: Zero + One,
162 B: DeviceCreationAnyAPI<TC>,
163 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
164 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
165{
166 op_refa_refb_matmul(a, b, TC::one())
167}
168
169#[duplicate_item(
174 TrA TrB ;
175 [ TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
176 [&TensorAny<RA, TA, B, DA>] [ TensorAny<RB, TB, B, DB>];
177 [ TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
178 [&TensorAny<RA, TA, B, DA>] [&TensorAny<RB, TB, B, DB>];
179)]
180impl<RA, RB, TA, TB, TC, DA, DB, DC, B> Rem<TrB> for TrA
181where
182 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
184 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
185 DA: DimAPI,
187 DB: DimAPI,
188 DC: DimAPI,
189 TA: Mul<TB, Output = TC>,
191 TC: Zero + One,
192 B: DeviceCreationAnyAPI<TC>,
193 LayoutMatMulConfig<DA, DB>: LayoutMatMulAPI<DA, DB, DC = DC>,
194 B: DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>,
195{
196 type Output = Tensor<TC, B, DC>;
197 fn rem(self, rhs: TrB) -> Self::Output {
198 op_refa_refb_matmul(self, rhs, TC::one()).rstsr_unwrap()
199 }
200}
201
202impl<R, T, B, D> TensorAny<R, T, B, D>
207where
208 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
209 B: DeviceAPI<T>,
210 D: DimAPI,
211{
212 pub fn matmul_f<TB, TC, DB, DC>(
213 &self,
214 rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
215 ) -> Result<Tensor<TC, B, DC>>
216 where
217 DB: DimAPI,
219 DC: DimAPI,
220 T: Mul<TB, Output = TC>,
222 TC: Zero + One,
223 B: DeviceCreationAnyAPI<TC>,
224 LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
225 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
226 {
227 op_refa_refb_matmul(self.view(), rhs, TC::one())
228 }
229
230 pub fn matmul<TB, TC, DB, DC>(&self, rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>) -> Tensor<TC, B, DC>
231 where
232 DB: DimAPI,
234 DC: DimAPI,
235 T: Mul<TB, Output = TC>,
237 TC: Zero + One,
238 B: DeviceCreationAnyAPI<TC>,
239 LayoutMatMulConfig<D, DB>: LayoutMatMulAPI<D, DB, DC = DC>,
240 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
241 {
242 op_refa_refb_matmul(self.view(), rhs, TC::one()).rstsr_unwrap()
243 }
244
245 pub fn matmul_with_output_f<TB, TC, DB, DC>(
246 &self,
247 rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
248 c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
249 ) -> Result<()>
250 where
251 DB: DimAPI,
253 DC: DimAPI,
254 TC: Zero + One,
256 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
257 {
258 op_mutc_refa_refb_matmul(c, self.view(), rhs, TC::one(), TC::zero())
259 }
260
261 pub fn matmul_with_output<TB, TC, DB, DC>(
262 &self,
263 rhs: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
264 c: impl TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
265 ) where
266 DB: DimAPI,
268 DC: DimAPI,
269 TC: Zero + One,
271 B: DeviceMatMulAPI<T, TB, TC, D, DB, DC>,
272 {
273 op_mutc_refa_refb_matmul(c, self.view(), rhs, TC::one(), TC::zero()).rstsr_unwrap()
274 }
275
276 pub fn matmul_from_f<TA, TB, DA, DB>(
277 &mut self,
278 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
279 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
280 alpha: T,
281 beta: T,
282 ) -> Result<()>
283 where
284 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
286 DA: DimAPI,
288 DB: DimAPI,
289 B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
291 {
292 op_mutc_refa_refb_matmul(self.view_mut(), a, b, alpha, beta)
293 }
294
295 pub fn matmul_from<TA, TB, DA, DB>(
296 &mut self,
297 a: impl TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
298 b: impl TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
299 alpha: T,
300 beta: T,
301 ) where
302 R: DataMutAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
304 DA: DimAPI,
306 DB: DimAPI,
307 B: DeviceMatMulAPI<TA, TB, T, DA, DB, D>,
309 {
310 op_mutc_refa_refb_matmul(self.view_mut(), a, b, alpha, beta).rstsr_unwrap()
311 }
312}
313
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_matmul() {
        // 2-D @ 2-D through the operator-level API, then through `%`.
        let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
        let b = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
        let mut c: Tensor<f64> = zeros([3, 3]);
        op_mutc_refa_refb_matmul(&mut c, &a, &b, 1.0, 0.0).unwrap();
        println!("{c}");
        let d = &a % &b;
        println!("{d}");

        // 1-D @ 1-D
        let a = linspace((0.0, 14.0, 15));
        let b = linspace((0.0, 14.0, 15));
        println!("{:}", &a % &b);

        // mixed-rank cases; only exercised with row-major default order
        #[cfg(not(feature = "col_major"))]
        {
            // 1-D @ 3-D
            let a = linspace((0.0, 2.0, 3));
            let b = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &a % &b);

            // 3-D @ 1-D
            let a = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let b = linspace((0.0, 4.0, 5));
            println!("{:}", &a % &b);

            // 2-D @ 3-D
            let a = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            let b = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            println!("{:}", &a % &b);

            // 3-D @ 2-D
            let a = linspace((0.0, 29.0, 30)).into_shape([2, 3, 5]);
            let b = linspace((0.0, 14.0, 15)).into_shape([5, 3]);
            println!("{:}", &a % &b);
        }
    }

    #[test]
    fn test_matmul_from() {
        // c = 2.0 * (a @ b) + 1.5 * c; the setup is identical for both memory
        // orders, and only the expected raw buffer differs.
        let a = linspace((0.0, 14.0, 15)).into_shape([3, 5]);
        let b = linspace((0.0, 19.0, 20)).into_shape([5, 4]);
        let mut c = linspace((0.0, 11.0, 12)).into_shape([3, 4]);
        c.matmul_from(&a, &b, 2.0, 1.5);
        println!("{c}");

        #[cfg(not(feature = "col_major"))]
        let c_ref = vec![240., 261.5, 283., 304.5, 646., 717.5, 789., 860.5, 1052., 1173.5, 1295., 1416.5];
        #[cfg(feature = "col_major")]
        let c_ref = vec![180.0, 201.5, 223.0, 484.5, 556.0, 627.5, 789.0, 910.5, 1032.0, 1093.5, 1265.0, 1436.5];

        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
    }
}