1use crate::matmul_impl::*;
2use crate::prelude_dev::*;
3use crate::threading::with_num_threads;
4use core::any::TypeId;
5use core::ops::{Add, Mul};
6use core::slice::{from_raw_parts, from_raw_parts_mut};
7use num::{Complex, Zero};
8use rayon::prelude::*;
9
/// Returns `true` when the generic parameters `A` and `B` name the same concrete type.
///
/// Used to pick a type-specialized (BLAS) kernel from generic code at runtime.
fn same_type<A: 'static, B: 'static>() -> bool {
    let (lhs, rhs) = (TypeId::of::<A>(), TypeId::of::<B>());
    lhs == rhs
}
14
15#[allow(clippy::too_many_arguments)]
16pub fn gemm_blas_ix2_no_conj_dispatch<TA, TB, TC>(
17 c: &mut [TC],
18 lc: &Layout<Ix2>,
19 a: &[TA],
20 la: &Layout<Ix2>,
21 b: &[TB],
22 lb: &Layout<Ix2>,
23 alpha: TC,
24 beta: TC,
25 pool: Option<&ThreadPool>,
26) -> Result<()>
27where
28 TA: Clone + Send + Sync + 'static,
29 TB: Clone + Send + Sync + 'static,
30 TC: Clone + Send + Sync + 'static,
31 TA: Mul<TB, Output = TC>,
32 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
33{
34 let able_syrk = beta == TC::zero()
36 && same_type::<TB, TC>()
37 && same_type::<TA, TC>()
38 && unsafe {
39 let a_ptr = a.as_ptr().add(la.offset()) as *const TC;
40 let b_ptr = b.as_ptr().add(lb.offset()) as *const TC;
41 let equal_ptr = core::ptr::eq(a_ptr, b_ptr);
42 let equal_shape = la.shape() == lb.reverse_axes().shape();
43 let equal_stride = la.stride() == lb.reverse_axes().stride();
44 equal_ptr && equal_shape && equal_stride
45 };
46
47 macro_rules! impl_gemm_dispatch {
49 ($ty: ty, $fn_gemm_name: ident, $fn_syrk_name: ident) => {
50 if (same_type::<TA, $ty>() && same_type::<TB, $ty>() && same_type::<TC, $ty>()) {
51 let a_slice = unsafe { from_raw_parts(a.as_ptr() as *const $ty, a.len()) };
52 let b_slice = unsafe { from_raw_parts(b.as_ptr() as *const $ty, b.len()) };
53 let c_slice = unsafe { from_raw_parts_mut(c.as_mut_ptr() as *mut $ty, c.len()) };
54 let alpha = unsafe { *(&alpha as *const TC as *const $ty) };
55 let beta = unsafe { *(&beta as *const TC as *const $ty) };
56 if able_syrk {
57 $fn_syrk_name(c_slice, lc, a_slice, la, alpha, pool)?;
58 } else {
59 $fn_gemm_name(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool)?;
60 }
61 return Ok(());
62 }
63 };
64 }
65
66 impl_gemm_dispatch!(f32, gemm_blas_no_conj_f32, syrk_blas_no_conj_f32);
67 impl_gemm_dispatch!(f64, gemm_blas_no_conj_f64, syrk_blas_no_conj_f64);
68 impl_gemm_dispatch!(Complex<f32>, gemm_blas_no_conj_c32, syrk_blas_no_conj_c32);
69 impl_gemm_dispatch!(Complex<f64>, gemm_blas_no_conj_c64, syrk_blas_no_conj_c64);
70
71 let c_slice = c;
74 let a_slice = a;
75 let b_slice = b;
76 return gemm_ix2_naive_cpu_rayon(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool);
77}
78
79#[allow(clippy::too_many_arguments)]
80pub fn matmul_row_major_blas<TA, TB, TC, DA, DB, DC>(
81 c: &mut [TC],
82 lc: &Layout<DC>,
83 a: &[TA],
84 la: &Layout<DA>,
85 b: &[TB],
86 lb: &Layout<DB>,
87 alpha: TC,
88 beta: TC,
89 pool: Option<&ThreadPool>,
90) -> Result<()>
91where
92 TA: Clone + Send + Sync + 'static,
93 TB: Clone + Send + Sync + 'static,
94 TC: Clone + Send + Sync + 'static,
95 DA: DimAPI,
96 DB: DimAPI,
97 DC: DimAPI,
98 TA: Mul<TB, Output = TC>,
99 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
100{
101 if lc.size() == 0 {
108 return Ok(());
109 }
110
111 let nthreads = match pool {
112 Some(pool) => pool.current_num_threads(),
113 None => 1,
114 };
115
116 match (la.ndim(), lb.ndim(), lc.ndim()) {
118 (1, 1, 0) => {
119 let la = &la.clone().into_dim::<Ix1>().unwrap();
121 let lb = &lb.clone().into_dim::<Ix1>().unwrap();
122 let lc = &lc.clone().into_dim::<Ix0>().unwrap();
123 let c_num = &mut c[lc.offset()];
124 return with_num_threads(nthreads, || inner_dot_naive_cpu_rayon(c_num, a, la, b, lb, alpha, beta, pool));
125 },
126 (2, 2, 2) => {
127 let la = &la.clone().into_dim::<Ix2>().unwrap();
129 let lb = &lb.clone().into_dim::<Ix2>().unwrap();
130 let lc = &lc.clone().into_dim::<Ix2>().unwrap();
131 return with_num_threads(nthreads, || {
132 gemm_blas_ix2_no_conj_dispatch(c, lc, a, la, b, lb, alpha, beta, pool)
133 });
134 },
135 _ => (),
136 };
137
138 let la_matmul;
141 let lb_matmul;
142 let lc_matmul;
143 let la_rest;
144 let lb_rest;
145 let lc_rest;
146
147 match (la.ndim(), lb.ndim(), lc.ndim()) {
148 (1, 1, 0) | (2, 2, 2) => unreachable!(),
150 (1, 2.., _) => {
151 rstsr_assert_eq!(lb.ndim(), lc.ndim() + 1, InvalidLayout)?;
153 let (la_r, la_m) = la.dim_split_at(-1)?;
154 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
155 let (lc_r, lc_m) = lc.dim_split_at(-1)?;
156 la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
157 lb_rest = lb_r;
158 lc_rest = lc_r;
159 la_matmul = la_m.dim_insert(0)?.into_dim::<Ix2>()?;
160 lb_matmul = lb_m.into_dim::<Ix2>()?;
161 lc_matmul = lc_m.dim_insert(0)?.into_dim::<Ix2>()?;
162 },
163 (2.., 1, _) => {
164 rstsr_assert_eq!(la.ndim(), lc.ndim() + 1, InvalidLayout)?;
166 let (la_r, la_m) = la.dim_split_at(-2)?;
167 let (lb_r, lb_m) = lb.dim_split_at(-1)?;
168 let (lc_r, lc_m) = lc.dim_split_at(-1)?;
169 la_rest = la_r;
170 lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
171 lc_rest = lc_r;
172 la_matmul = la_m.into_dim::<Ix2>()?;
173 lb_matmul = lb_m.dim_insert(1)?.into_dim::<Ix2>()?;
174 lc_matmul = lc_m.dim_insert(1)?.into_dim::<Ix2>()?;
175 },
176 (2, 3.., _) => {
177 rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
179 let (la_r, la_m) = la.dim_split_at(-2)?;
180 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
181 let (lc_r, lc_m) = lc.dim_split_at(-2)?;
182 la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
183 lb_rest = lb_r;
184 lc_rest = lc_r;
185 la_matmul = la_m.into_dim::<Ix2>()?;
186 lb_matmul = lb_m.into_dim::<Ix2>()?;
187 lc_matmul = lc_m.into_dim::<Ix2>()?;
188 },
189 (3.., 2, _) => {
190 rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
192 let (la_r, la_m) = la.dim_split_at(-2)?;
193 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
194 let (lc_r, lc_m) = lc.dim_split_at(-2)?;
195 la_rest = la_r;
196 lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
197 lc_rest = lc_r;
198 la_matmul = la_m.into_dim::<Ix2>()?;
199 lb_matmul = lb_m.into_dim::<Ix2>()?;
200 lc_matmul = lc_m.into_dim::<Ix2>()?;
201 },
202 (3.., 3.., _) => {
203 rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
205 rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
206 let (la_r, la_m) = la.dim_split_at(-2)?;
207 let (lb_r, lb_m) = lb.dim_split_at(-2)?;
208 let (lc_r, lc_m) = lc.dim_split_at(-2)?;
209 la_rest = la_r;
210 lb_rest = lb_r;
211 lc_rest = lc_r;
212 la_matmul = la_m.into_dim::<Ix2>()?;
213 lb_matmul = lb_m.into_dim::<Ix2>()?;
214 lc_matmul = lc_m.into_dim::<Ix2>()?;
215 },
216 _ => {
217 rstsr_raise!(InvalidLayout, "This is not valid layout for matmul broadcasting.")?;
218 unreachable!()
219 },
220 }
221 rstsr_assert_eq!(la_rest.shape(), lb_rest.shape(), InvalidLayout)?;
226 rstsr_assert_eq!(lb_rest.shape(), lc_rest.shape(), InvalidLayout)?;
227 let n_task = la_rest.size();
228 let ita_rest = IterLayoutColMajor::new(&la_rest)?;
229 let itb_rest = IterLayoutColMajor::new(&lb_rest)?;
230 let itc_rest = IterLayoutColMajor::new(&lc_rest)?;
231 if n_task >= 4 * nthreads {
232 with_num_threads(1, || {
234 let task = || {
235 ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each(
236 |((ia_rest, ib_rest), ic_rest)| -> Result<()> {
237 let mut la_m = la_matmul.clone();
239 let mut lb_m = lb_matmul.clone();
240 let mut lc_m = lc_matmul.clone();
241 unsafe {
242 la_m.set_offset(ia_rest);
243 lb_m.set_offset(ib_rest);
244 lc_m.set_offset(ic_rest);
245 }
246 let c = unsafe {
248 let c_ptr = c.as_ptr() as *mut TC;
249 let c_len = c.len();
250 from_raw_parts_mut(c_ptr, c_len)
251 };
252 let alpha = alpha.clone();
254 let beta = beta.clone();
255 gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None)
256 },
257 )
258 };
259 match pool {
260 Some(pool) => pool.install(task),
261 None => task(),
262 }
263 })
264 } else {
265 with_num_threads(nthreads, || -> Result<()> {
267 izip!(ita_rest, itb_rest, itc_rest).try_for_each(|(ia_rest, ib_rest, ic_rest)| {
268 let mut la_m = la_matmul.clone();
270 let mut lb_m = lb_matmul.clone();
271 let mut lc_m = lc_matmul.clone();
272 unsafe {
273 la_m.set_offset(ia_rest);
274 lb_m.set_offset(ib_rest);
275 lc_m.set_offset(ic_rest);
276 }
277 let alpha = alpha.clone();
279 let beta = beta.clone();
280 gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, pool)
281 })
282 })
283 }
284}
285
286#[allow(clippy::too_many_arguments)]
287impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceBLAS
288where
289 TA: Clone + Send + Sync + 'static,
290 TB: Clone + Send + Sync + 'static,
291 TC: Clone + Send + Sync + 'static,
292 DA: DimAPI,
293 DB: DimAPI,
294 DC: DimAPI,
295 TA: Mul<TB, Output = TC>,
296 TB: Mul<TA, Output = TC>,
297 TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
298{
299 fn matmul(
300 &self,
301 c: &mut Vec<TC>,
302 lc: &Layout<DC>,
303 a: &Vec<TA>,
304 la: &Layout<DA>,
305 b: &Vec<TB>,
306 lb: &Layout<DB>,
307 alpha: TC,
308 beta: TC,
309 ) -> Result<()> {
310 let default_order = self.default_order();
311 let pool = self.get_current_pool();
312 match default_order {
313 RowMajor => matmul_row_major_blas(c, lc, a, la, b, lb, alpha, beta, pool),
314 ColMajor => {
315 let la = la.reverse_axes();
316 let lb = lb.reverse_axes();
317 let lc = lc.reverse_axes();
318 matmul_row_major_blas(c, &lc, b, &lb, a, &la, alpha, beta, pool)
319 },
320 }
321 }
322}
323
#[cfg(test)]
mod test {
    use super::*;

    /// Smoke test: prints one product for each supported rank combination.
    #[test]
    fn test_matmul() {
        let device = DeviceBLAS::default();

        // 2-D @ 2-D
        let lhs = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]);
        let rhs = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
        println!("{:}", &lhs % &rhs);

        // 1-D @ 1-D (inner product)
        let lhs = linspace((0.0, 14.0, 15, &device));
        let rhs = linspace((0.0, 14.0, 15, &device));
        println!("{:}", &lhs % &rhs);

        // 1-D @ 3-D
        let lhs = linspace((0.0, 2.0, 3, &device));
        let rhs = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        println!("{:}", &lhs % &rhs);

        // 3-D @ 1-D
        let lhs = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        let rhs = linspace((0.0, 4.0, 5, &device));
        println!("{:}", &lhs % &rhs);

        // 2-D @ 3-D
        let lhs = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
        let rhs = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        println!("{:}", &lhs % &rhs);

        // 3-D @ 2-D
        let lhs = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
        let rhs = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
        println!("{:}", &lhs % &rhs);
    }

    /// Timing: one large 2-D product on the default device.
    #[test]
    #[ignore]
    fn parallel_test_full() {
        let device = DeviceBLAS::default();
        let lhs = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        let rhs = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        for _ in 0..10 {
            let t0 = std::time::Instant::now();
            let _ = &lhs % &rhs;
            println!("time: {:?}", t0.elapsed());
        }
    }

    /// Timing: many small 2-D products on a device constructed with `DeviceBLAS::new(1)`.
    #[test]
    #[ignore]
    fn parallel_test_full_512() {
        let device = DeviceBLAS::new(1);
        let lhs = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        let rhs = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        for _ in 0..1000 {
            let t0 = std::time::Instant::now();
            let prod = &lhs % &rhs;
            println!("{:?}", prod.device());
            println!("time: {:?}", t0.elapsed());
        }
    }

    /// Timing: 3-D @ 3-D batched product.
    #[test]
    #[ignore]
    fn parallel_test_par_rule7() {
        let device = DeviceBLAS::default();
        let lhs = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        let rhs = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        for round in 0..10 {
            let t0 = std::time::Instant::now();
            let c = &lhs % &rhs;
            println!("{:?}", c.layout());
            println!("time: {:?}", t0.elapsed());
            if round == 0 {
                println!("{c:?}");
            }
        }
    }

    /// Timing: 3-D @ 2-D product (the 2-D operand is broadcast over the batch axis).
    #[test]
    #[ignore]
    fn parallel_test_par_rule6() {
        let device = DeviceBLAS::default();
        let lhs = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        let rhs = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        for round in 0..10 {
            let t0 = std::time::Instant::now();
            let c = &lhs % &rhs;
            println!("{:?}", c.layout());
            println!("time: {:?}", t0.elapsed());
            if round == 0 {
                println!("{c:?}");
            }
        }
    }

    /// Timing: same as `parallel_test_par_rule6`, but the 3-D operand is an axis-reversed view.
    #[test]
    #[ignore]
    fn parallel_test_par_rule6_fprefer() {
        let device = DeviceBLAS::default();
        let lhs = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([512, 512, 256]).into_reverse_axes();
        let rhs = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        for round in 0..10 {
            let t0 = std::time::Instant::now();
            let c = &lhs % &rhs;
            println!("{:?}", c.layout());
            println!("time: {:?}", t0.elapsed());
            if round == 0 {
                println!("{c:?}");
            }
        }
    }

    /// `x @ x^T` (eligible for the syrk path) must match `x @ y^T` with `y` holding equal data
    /// (always the gemm path), for 2-D and batched 3-D inputs.
    #[test]
    fn syrk_correctness() {
        let device = DeviceBLAS::default();
        let x = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        let y = linspace((0.0, 1.0, 512 * 512, &device)).into_shape([512, 512]);
        let via_syrk = &x % &x.t();
        let via_gemm = &x % &y.t();
        assert!(allclose_f64(&via_syrk, &via_gemm));

        let device = DeviceBLAS::default();
        let x = linspace((0.0, 1.0, 1024 * 1024, &device)).into_shape([4, 512, 512]);
        let y = linspace((0.0, 1.0, 1024 * 1024, &device)).into_shape([4, 512, 512]);
        let via_syrk = &x % &x.swapaxes(-1, -2);
        let via_gemm = &x % &y.swapaxes(-1, -2);
        assert!(allclose_f64(&via_syrk, &via_gemm));
    }

    /// Timing comparison of the syrk path against the gemm path, batched and 2-D.
    #[test]
    #[ignore]
    fn syrk_efficiency() {
        use std::hint::black_box;
        use std::time::Instant;

        let device = DeviceBLAS::default();
        let x = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        let y = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([256, 512, 512]);
        for _ in 0..10 {
            let t0 = Instant::now();
            black_box(&x % &x.swapaxes(-1, -2));
            println!("syrk time: {:?}", t0.elapsed());
            let t0 = Instant::now();
            black_box(&x % &y.swapaxes(-1, -2));
            println!("gemm time: {:?}", t0.elapsed());
        }

        println!("---------------------");
        let x = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        let y = linspace((0.0, 1.0, 8192 * 8192, &device)).into_shape([8192, 8192]);
        for _ in 0..10 {
            let t0 = Instant::now();
            black_box(&x % &x.swapaxes(-1, -2));
            println!("syrk time: {:?}", t0.elapsed());
            let t0 = Instant::now();
            black_box(&x % &y.swapaxes(-1, -2));
            println!("gemm time: {:?}", t0.elapsed());
        }
    }
}