use super::matmul_impl::*;
use crate::prelude_dev::*;
use core::any::TypeId;
use core::ops::{Add, Mul};
use core::slice::{from_raw_parts, from_raw_parts_mut};
use num::{Complex, Zero};
use rayon::prelude::*;

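/// Returns `true` if `A` and `B` are the same concrete type (compared by `TypeId`).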
fn same_type<A: 'static, B: 'static>() -> bool {
    TypeId::of::<A>() == TypeId::of::<B>()
}

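/// 2-D GEMM dispatcher with scaling factors `alpha` and `beta`: routes `f32`, `f64`,
/// `Complex<f32>` and `Complex<f64>` to the faer kernels, and every other element type
/// to the naive rayon CPU implementation.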
#[allow(clippy::too_many_arguments)]
pub fn gemm_faer_ix2_dispatch<TA, TB, TC>(
    c: &mut [TC],
    lc: &Layout<Ix2>,
    a: &[TA],
    la: &Layout<Ix2>,
    b: &[TB],
    lb: &Layout<Ix2>,
    alpha: TC,
    beta: TC,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
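    // A SYRK-style kernel can be used when `beta` is zero and `b` is exactly the transposed
    // view of `a` (same data pointer at the offsets, transposed shape and stride).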
    let able_syrk = beta == TC::zero()
        && same_type::<TA, TC>()
        && same_type::<TB, TC>()
        && unsafe {
            let a_ptr = a.as_ptr().add(la.offset()) as *const TC;
            let b_ptr = b.as_ptr().add(lb.offset()) as *const TC;
            let equal_ptr = core::ptr::eq(a_ptr, b_ptr);
            let equal_shape = la.shape() == lb.reverse_axes().shape();
            let equal_stride = la.stride() == lb.reverse_axes().stride();
            equal_ptr && equal_shape && equal_stride
        };

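    // If `TA`, `TB` and `TC` all equal `$ty`, reinterpret the slices as `$ty` and run the
    // faer kernel (SYRK-style when applicable).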
    macro_rules! impl_gemm_dispatch {
        ($ty: ty) => {
            if (same_type::<TA, $ty>() && same_type::<TB, $ty>() && same_type::<TC, $ty>()) {
                let a_slice = unsafe { from_raw_parts(a.as_ptr() as *const $ty, a.len()) };
                let b_slice = unsafe { from_raw_parts(b.as_ptr() as *const $ty, b.len()) };
                let c_slice = unsafe { from_raw_parts_mut(c.as_mut_ptr() as *mut $ty, c.len()) };
                let alpha = unsafe { *(&alpha as *const TC as *const $ty) };
                let beta = unsafe { *(&beta as *const TC as *const $ty) };
                if able_syrk {
                    gemm_with_syrk_faer(c_slice, lc, a_slice, la, alpha, beta, pool)?;
                } else {
                    gemm_faer(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool)?;
                }
                return Ok(());
            }
        };
    }

    impl_gemm_dispatch!(f32);
    impl_gemm_dispatch!(f64);
    impl_gemm_dispatch!(Complex<f32>);
    impl_gemm_dispatch!(Complex<f64>);

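    // No faer-supported element type matched; fall back to the naive rayon CPU kernel.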
    let c_slice = c;
    let a_slice = a;
    let b_slice = b;
    return gemm_ix2_naive_cpu_rayon(c_slice, lc, a_slice, la, b_slice, lb, alpha, beta, pool);
}

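/// Row-major matmul driver: the 1-D inner product and the plain 2-D GEMM are handled
/// directly, while every other (batched / broadcasted) combination is reduced to a
/// sequence of 2-D GEMMs over the leading "rest" dimensions.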
#[allow(clippy::too_many_arguments)]
pub fn matmul_row_major_faer<TA, TB, TC, DA, DB, DC>(
    c: &mut [TC],
    lc: &Layout<DC>,
    a: &[TA],
    la: &Layout<DA>,
    b: &[TB],
    lb: &Layout<DB>,
    alpha: TC,
    beta: TC,
    pool: Option<&ThreadPool>,
) -> Result<()>
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
    let nthreads = match pool {
        Some(pool) => pool.current_num_threads(),
        None => 1,
    };

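    // Fast paths: 1-D inner (dot) product and plain 2-D GEMM.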
    match (la.ndim(), lb.ndim(), lc.ndim()) {
        (1, 1, 0) => {
            let la = &la.clone().into_dim::<Ix1>().unwrap();
            let lb = &lb.clone().into_dim::<Ix1>().unwrap();
            let lc = &lc.clone().into_dim::<Ix0>().unwrap();
            let c_num = &mut c[lc.offset()];
            return inner_dot_naive_cpu_rayon(c_num, a, la, b, lb, alpha, beta, pool);
        },
        (2, 2, 2) => {
            let la = &la.clone().into_dim::<Ix2>().unwrap();
            let lb = &lb.clone().into_dim::<Ix2>().unwrap();
            let lc = &lc.clone().into_dim::<Ix2>().unwrap();
            return gemm_faer_ix2_dispatch(c, lc, a, la, b, lb, alpha, beta, pool);
        },
        _ => (),
    }

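    // For every other case, split each layout into the batched ("rest") leading dimensions
    // and the trailing 2-D matmul dimensions, broadcasting 1-D and 2-D operands against the
    // output as needed.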
    let la_matmul;
    let lb_matmul;
    let lc_matmul;
    let la_rest;
    let lb_rest;
    let lc_rest;

    match (la.ndim(), lb.ndim(), lc.ndim()) {
        (1, 1, 0) | (2, 2, 2) => unreachable!(),
        (1, 2.., _) => {
            rstsr_assert_eq!(lb.ndim(), lc.ndim() + 1, InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-1)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
            la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.dim_insert(0)?.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.dim_insert(0)?.into_dim::<Ix2>()?;
        },
        (2.., 1, _) => {
            rstsr_assert_eq!(la.ndim(), lc.ndim() + 1, InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-1)?;
            let (lc_r, lc_m) = lc.dim_split_at(-1)?;
            la_rest = la_r;
            lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.dim_insert(1)?.into_dim::<Ix2>()?;
            lc_matmul = lc_m.dim_insert(1)?.into_dim::<Ix2>()?;
        },
        (2, 3.., _) => {
            rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = broadcast_layout_to_first(&lc_r, &la_r, RowMajor)?.1;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        (3.., 2, _) => {
            rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = la_r;
            lb_rest = broadcast_layout_to_first(&lc_r, &lb_r, RowMajor)?.1;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        (3.., 3.., _) => {
            rstsr_assert_eq!(la.ndim(), lc.ndim(), InvalidLayout)?;
            rstsr_assert_eq!(lb.ndim(), lc.ndim(), InvalidLayout)?;
            let (la_r, la_m) = la.dim_split_at(-2)?;
            let (lb_r, lb_m) = lb.dim_split_at(-2)?;
            let (lc_r, lc_m) = lc.dim_split_at(-2)?;
            la_rest = la_r;
            lb_rest = lb_r;
            lc_rest = lc_r;
            la_matmul = la_m.into_dim::<Ix2>()?;
            lb_matmul = lb_m.into_dim::<Ix2>()?;
            lc_matmul = lc_m.into_dim::<Ix2>()?;
        },
        _ => todo!(),
    }
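    // The batched ("rest") shapes must agree; iterate them in col-major order and run one
    // 2-D GEMM per batch index.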
    rstsr_assert_eq!(la_rest.shape(), lb_rest.shape(), InvalidLayout)?;
    rstsr_assert_eq!(lb_rest.shape(), lc_rest.shape(), InvalidLayout)?;
    let n_task = la_rest.size();
    let ita_rest = IterLayoutColMajor::new(&la_rest)?;
    let itb_rest = IterLayoutColMajor::new(&lb_rest)?;
    let itc_rest = IterLayoutColMajor::new(&lc_rest)?;
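    // Parallelize over batch indices only when there are enough tasks to keep the pool busy;
    // otherwise run the batches serially and let each inner GEMM use the pool itself.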
    if n_task > 4 * nthreads {
        let task = || {
            ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each(
                |((ia_rest, ib_rest), ic_rest)| -> Result<()> {
                    let mut la_m = la_matmul.clone();
                    let mut lb_m = lb_matmul.clone();
                    let mut lc_m = lc_matmul.clone();
                    unsafe {
                        la_m.set_offset(ia_rest);
                        lb_m.set_offset(ib_rest);
                        lc_m.set_offset(ic_rest);
                    }
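                    // SAFETY: each task writes only the `c` block addressed by its own offset
                    // `ic_rest`; the batched output blocks are assumed to be non-overlapping,
                    // so re-creating the mutable slice inside every task does not alias writes.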
                    let c = unsafe {
                        let c_ptr = c.as_ptr() as *mut TC;
                        let c_len = c.len();
                        from_raw_parts_mut(c_ptr, c_len)
                    };
                    let alpha = alpha.clone();
                    let beta = beta.clone();
                    gemm_faer_ix2_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None)
                },
            )
        };
        match pool {
            Some(pool) => pool.install(task)?,
            None => task()?,
        };
    } else {
        for (ia_rest, ib_rest, ic_rest) in izip!(ita_rest, itb_rest, itc_rest) {
            let mut la_m = la_matmul.clone();
            let mut lb_m = lb_matmul.clone();
            let mut lc_m = lc_matmul.clone();
            unsafe {
                la_m.set_offset(ia_rest);
                lb_m.set_offset(ib_rest);
                lc_m.set_offset(ic_rest);
            }
            let alpha = alpha.clone();
            let beta = beta.clone();
            gemm_faer_ix2_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, pool)?;
        }
    }
    return Ok(());
}

#[allow(clippy::too_many_arguments)]
impl<TA, TB, TC, DA, DB, DC> DeviceMatMulAPI<TA, TB, TC, DA, DB, DC> for DeviceFaer
where
    TA: Clone + Send + Sync + 'static,
    TB: Clone + Send + Sync + 'static,
    TC: Clone + Send + Sync + 'static,
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    TA: Mul<TB, Output = TC>,
    TB: Mul<TA, Output = TC>,
    TC: Mul<TC, Output = TC> + Add<TC, Output = TC> + Zero + PartialEq,
{
    fn matmul(
        &self,
        c: &mut Vec<TC>,
        lc: &Layout<DC>,
        a: &Vec<TA>,
        la: &Layout<DA>,
        b: &Vec<TB>,
        lb: &Layout<DB>,
        alpha: TC,
        beta: TC,
    ) -> Result<()> {
        let default_order = self.default_order();
        let pool = self.get_current_pool();
        match default_order {
            RowMajor => matmul_row_major_faer(c, lc, a, la, b, lb, alpha, beta, pool),
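            // In col-major order, C = A B is evaluated as the row-major product C^T = B^T A^T.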
            ColMajor => {
                let la = la.reverse_axes();
                let lb = lb.reverse_axes();
                let lc = lc.reverse_axes();
                matmul_row_major_faer(c, &lc, b, &lb, a, &la, alpha, beta, pool)
            },
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_matmul() {
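        // `%` is the matmul operator in these smoke tests; the shapes below exercise the
        // plain 2-D, the 1-D inner-product, and the batched/broadcasted dispatch paths.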
        let device = DeviceFaer::default();
        let a = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([3, 5]);
        let b = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([5, 3]);

        let d = &a % &b;
        println!("{d}");

        let a = linspace((0.0, 14.0, 15, &device));
        let b = linspace((0.0, 14.0, 15, &device));
        println!("{:}", &a % &b);

        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((0.0, 2.0, 3, &device));
            let b = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            let b = linspace((0.0, 4.0, 5, &device));
            println!("{:}", &a % &b);

            let a = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([5, 3]);
            let b = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            println!("{:}", &a % &b);

            let a = linspace((0.0, 29.0, 30, &device)).into_shape_assume_contig([2, 3, 5]);
            let b = linspace((0.0, 14.0, 15, &device)).into_shape_assume_contig([5, 3]);
            println!("{:}", &a % &b);
        }
    }
}