qudit_core/accel/matmul.rs
1//! Functions and structs for efficient generalized matrix multiplication (GEMM).
2
3use coe::is_same;
4use nano_gemm::Plan;
5use num_traits::One;
6use num_traits::Zero;
7
8use crate::ComplexScalar;
9use crate::c32;
10use crate::c64;
11use faer::MatMut;
12use faer::MatRef;
13
14/// Stores a plan for a generalized matrix multiplication (GEMM). Based on the dimensions and underlying
15/// field of the matrices, the plan will select the appropriate mili/micro-kernels for performance.
16pub struct MatMulPlan<C: ComplexScalar> {
17 m: usize,
18 n: usize,
19 k: usize,
20 plan: Plan<C>,
21}
22
23impl<C: ComplexScalar> MatMulPlan<C> {
24 /// Creates a new GEMM plan for column-major matrices.
25 ///
26 /// # Arguments
27 ///
28 /// * `m`: Number of rows in the left-hand side matrix.
29 /// * `n`: Number of columns in the right-hand side matrix.
30 /// * `k`: Number of columns in the left-hand side matrix.
31 /// This should equal the number of rows in the right-hand side matrix.
32 ///
33 /// # Returns
34 ///
35 /// * A `MatMulPlan` instance.
36 ///
37 pub fn new(m: usize, n: usize, k: usize) -> Self {
38 if is_same::<C, c32>() {
39 let plan = Plan::new_colmajor_lhs_and_dst_c32(m, n, k);
40 // Safety: This is safe because C is c32.
41 Self {
42 m,
43 n,
44 k,
45 plan: unsafe { std::mem::transmute::<Plan<c32>, Plan<C>>(plan) },
46 }
47 } else {
48 let plan = Plan::new_colmajor_lhs_and_dst_c64(m, n, k);
49 Self {
50 m,
51 n,
52 k,
53 plan: unsafe { std::mem::transmute::<Plan<c64>, Plan<C>>(plan) },
54 }
55 }
56 }
57
58 /// Executes the milikernel of the plan, for matrix multiplication. (`alpha = 0`, `beta = 1`)
59 /// We do not perform comprehensive checks.
60 ///
61 /// # Arguments
62 ///
63 /// * `lhs`: The left-hand side matrix to multiply.
64 /// * `rhs`: The right-hand side matrix to multiply.
65 /// * `out`: The output matrix where the result will be stored.
66 ///
67 /// # Safety
68 ///
69 /// * The matrices must be column-major.
70 /// * The dimensions of `out` must be `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
71 ///
72 /// # Examples
73 /// ```
74 /// use qudit_core::accel::MatMulPlan;
75 /// use faer::{mat, Mat};
76 /// use qudit_core::c64;
77 ///
78 /// let mut out = Mat::<c64>::zeros(2, 2);
79 ///
80 /// let lhs = mat![
81 /// [c64::new(1.0, 0.0), c64::new(2.0, 0.0)],
82 /// [c64::new(3.0, 0.0), c64::new(4.0, 0.0)]
83 /// ];
84 /// let rhs = mat![
85 /// [c64::new(5.0, 0.0), c64::new(6.0, 0.0)],
86 /// [c64::new(7.0, 0.0), c64::new(8.0, 0.0)]
87 /// ];
88 ///
89 /// let test_plan = MatMulPlan::new(lhs.nrows(), rhs.ncols(), lhs.ncols());
90 /// test_plan.execute_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());
91 ///
92 /// let expected = mat![
93 /// [c64::new(19.0, 0.0), c64::new(22.0, 0.0)],
94 /// [c64::new(43.0, 0.0), c64::new(50.0, 0.0)]
95 /// ];
96 ///
97 /// assert_eq!(expected, out);
98 /// ```
99 ///
100 #[inline(always)]
101 pub fn execute_unchecked(&self, lhs: MatRef<C>, rhs: MatRef<C>, out: MatMut<C>) {
102 let m = lhs.nrows();
103 let n = rhs.ncols();
104 let k = lhs.ncols();
105 let out_col_stride = out.col_stride();
106
107 unsafe {
108 self.plan.execute_unchecked(
109 m,
110 n,
111 k,
112 out.as_ptr_mut() as _,
113 1,
114 out_col_stride,
115 lhs.as_ptr() as _,
116 1,
117 lhs.col_stride(),
118 rhs.as_ptr() as _,
119 1,
120 rhs.col_stride(),
121 C::zero(),
122 C::one(),
123 false,
124 false,
125 );
126 }
127 }
128
129 #[inline(always)]
130 #[allow(clippy::too_many_arguments)]
131 /// Perform the matrix multiplication given by the plan without checking bounds.
132 ///
133 /// # Safety
134 ///
135 /// The multiplication defined here must be valid. The pointers must point
136 /// to adequately sized and proper buffers of memory that describe matrices
137 /// with the dimensions and strides given.
138 pub unsafe fn execute_raw_unchecked(
139 &self,
140 lhs: *const C,
141 rhs: *const C,
142 out: *mut C,
143 dst_rs: isize,
144 dst_cs: isize,
145 lhs_rs: isize,
146 lhs_cs: isize,
147 rhs_rs: isize,
148 rhs_cs: isize,
149 ) {
150 unsafe {
151 self.plan.execute_unchecked(
152 self.m,
153 self.n,
154 self.k,
155 out,
156 dst_rs,
157 dst_cs,
158 lhs,
159 lhs_rs,
160 lhs_cs,
161 rhs,
162 rhs_rs,
163 rhs_cs,
164 C::zero(),
165 C::one(),
166 false,
167 false,
168 );
169 }
170 }
171
172 /// Executes the milikernel of the plan, for matrix multiplication followed by addition.
173 /// (`alpha = 1`, `beta = 1`) We do not perform comprehensive checks.
174 ///
175 /// # Arguments
176 ///
177 /// * `lhs`: The left-hand side matrix to add.
178 /// * `rhs`: The right-hand side matrix to add.
179 /// * `out`: The output matrix where the result will be stored.
180 ///
181 /// # Safety
182 ///
183 /// * The matrices must be column-major.
184 /// * The dimensions of `out` must be `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
185 ///
186 /// # Examples
187 /// ```
188 /// use qudit_core::accel::MatMulPlan;
189 /// use faer::{mat, Mat};
190 /// use qudit_core::c64;
191 ///
192 /// let mut out = Mat::<c64>::ones(2, 2);
193 ///
194 /// let lhs = mat![
195 /// [c64::new(1.0, 0.0), c64::new(2.0, 0.0)],
196 /// [c64::new(3.0, 0.0), c64::new(4.0, 0.0)]
197 /// ];
198 /// let rhs = mat![
199 /// [c64::new(5.0, 0.0), c64::new(6.0, 0.0)],
200 /// [c64::new(7.0, 0.0), c64::new(8.0, 0.0)]
201 /// ];
202 ///
203 /// let test_plan = MatMulPlan::new(lhs.nrows(), rhs.ncols(), lhs.ncols());
204 /// test_plan.execute_add_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());
205 ///
206 /// let expected = mat![
207 /// [c64::new(20.0, 0.0), c64::new(23.0, 0.0)],
208 /// [c64::new(44.0, 0.0), c64::new(51.0, 0.0)]
209 /// ];
210 ///
211 /// assert_eq!(expected, out);
212 /// ```
213 ///
214 pub fn execute_add_unchecked(&self, lhs: MatRef<C>, rhs: MatRef<C>, out: MatMut<C>) {
215 let m = lhs.nrows();
216 let n = rhs.ncols();
217 let k = lhs.ncols();
218 let out_col_stride = out.col_stride();
219
220 unsafe {
221 self.plan.execute_unchecked(
222 m,
223 n,
224 k,
225 out.as_ptr_mut() as _,
226 1,
227 out_col_stride,
228 lhs.as_ptr() as _,
229 1,
230 lhs.col_stride(),
231 rhs.as_ptr() as _,
232 1,
233 rhs.col_stride(),
234 C::one(),
235 C::one(),
236 false,
237 false,
238 );
239 }
240 }
241
242 #[inline(always)]
243 #[allow(clippy::too_many_arguments)]
244 /// Perform the additive matrix multiplication given by the plan without checking bounds.
245 ///
246 /// # Safety
247 ///
248 /// The multiplication defined here must be valid. The pointers must point
249 /// to adequately sized and proper buffers of memory that describe matrices
250 /// with the dimensions and strides given.
251 pub unsafe fn execute_add_raw_unchecked(
252 &self,
253 lhs: *const C,
254 rhs: *const C,
255 out: *mut C,
256 dst_rs: isize,
257 dst_cs: isize,
258 lhs_rs: isize,
259 lhs_cs: isize,
260 rhs_rs: isize,
261 rhs_cs: isize,
262 ) {
263 unsafe {
264 self.plan.execute_unchecked(
265 self.m,
266 self.n,
267 self.k,
268 out,
269 dst_rs,
270 dst_cs,
271 lhs,
272 lhs_rs,
273 lhs_cs,
274 rhs,
275 rhs_rs,
276 rhs_cs,
277 C::one(),
278 C::one(),
279 false,
280 false,
281 );
282 }
283 }
284}
285
286/// Performs matrix-matrix multiplication. (`alpha = 0`, `beta = 1`)
287///
288/// # Arguments
289///
290/// * `lhs`: The left-hand side matrix to multiply.
291/// * `rhs`: The right-hand side matrix to multiply.
292/// * `out`: The output matrix where the result will be stored.
293///
294/// # Safety
295///
296/// * The matrices must be column-major.
297/// * The dimensions of `out` must be `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
298///
299/// # Examples
300/// ```
301/// use qudit_core::accel::matmul_unchecked;
302/// use faer::{mat, Mat};
303/// use qudit_core::c64;
304///
305/// let mut out = Mat::<c64>::zeros(2, 2);
306///
307/// let lhs = mat![
308/// [c64::new(1.0, 0.0), c64::new(2.0, 0.0)],
309/// [c64::new(3.0, 0.0), c64::new(4.0, 0.0)]
310/// ];
311/// let rhs = mat![
312/// [c64::new(5.0, 0.0), c64::new(6.0, 0.0)],
313/// [c64::new(7.0, 0.0), c64::new(8.0, 0.0)]
314/// ];
315///
316/// matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());
317///
318/// let expected = mat![
319/// [c64::new(19.0, 0.0), c64::new(22.0, 0.0)],
320/// [c64::new(43.0, 0.0), c64::new(50.0, 0.0)]
321/// ];
322///
323/// assert_eq!(expected, out);
324/// ```
325///
326#[inline(always)]
327pub fn matmul_unchecked<C: ComplexScalar>(lhs: MatRef<C>, rhs: MatRef<C>, out: MatMut<C>) {
328 let m = lhs.nrows();
329 let n = rhs.ncols();
330 let k = lhs.ncols();
331
332 // After the runtime check of C, we explicitly transmute our inputs.
333 // This allows type-specific optimizations.
334 if is_same::<C, c32>() {
335 let plan = Plan::new_colmajor_lhs_and_dst_c32(m, n, k);
336 let out: MatMut<c32> = unsafe { std::mem::transmute(out) };
337 let rhs: MatRef<c32> = unsafe { std::mem::transmute(rhs) };
338 let lhs: MatRef<c32> = unsafe { std::mem::transmute(lhs) };
339 let out_col_stride = out.col_stride();
340
341 unsafe {
342 plan.execute_unchecked(
343 m,
344 n,
345 k,
346 out.as_ptr_mut() as _,
347 1,
348 out_col_stride,
349 lhs.as_ptr() as _,
350 1,
351 lhs.col_stride(),
352 rhs.as_ptr() as _,
353 1,
354 rhs.col_stride(),
355 c32::zero(),
356 c32::one(), // TODO: Figure if I can create custom kernels for one/zero alpha/beta
357 false,
358 false,
359 );
360 }
361 } else {
362 let plan = Plan::new_colmajor_lhs_and_dst_c64(m, n, k);
363 let out: MatMut<c64> = unsafe { std::mem::transmute(out) };
364 let rhs: MatRef<c64> = unsafe { std::mem::transmute(rhs) };
365 let lhs: MatRef<c64> = unsafe { std::mem::transmute(lhs) };
366 let out_col_stride = out.col_stride();
367
368 unsafe {
369 plan.execute_unchecked(
370 m,
371 n,
372 k,
373 out.as_ptr_mut() as _,
374 1,
375 out_col_stride,
376 lhs.as_ptr() as _,
377 1,
378 lhs.col_stride(),
379 rhs.as_ptr() as _,
380 1,
381 rhs.col_stride(),
382 c64::zero(),
383 c64::one(), // TODO: Figure if I can create custom kernels for one/zero alpha/beta
384 false,
385 false,
386 );
387 }
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{c32, c64};
    use faer::Mat;
    use faer::mat;
    use num_traits::Zero;

    /// Builds an `nrows x ncols` matrix whose `(i, j)` entry is `f(i, j)`.
    fn filled_c64(nrows: usize, ncols: usize, f: impl Fn(usize, usize) -> c64) -> Mat<c64> {
        let mut out = Mat::<c64>::zeros(nrows, ncols);
        for i in 0..nrows {
            for j in 0..ncols {
                out[(i, j)] = f(i, j);
            }
        }
        out
    }

    /// Reference triple-loop product used to check the accelerated kernels.
    fn naive_matmul_c64(lhs: &Mat<c64>, rhs: &Mat<c64>) -> Mat<c64> {
        filled_c64(lhs.nrows(), rhs.ncols(), |i, j| {
            let mut sum = c64::zero();
            for l in 0..lhs.ncols() {
                sum += lhs[(i, l)] * rhs[(l, j)];
            }
            sum
        })
    }

    #[test]
    fn test_matmul_unchecked() {
        let m = 2;
        let n = 2;
        let k = 2;

        let mut lhs = Mat::<c32>::zeros(m, k);
        let mut rhs = Mat::<c32>::zeros(k, n);
        let mut out = Mat::<c32>::zeros(m, n);

        for i in 0..m {
            for j in 0..k {
                lhs[(i, j)] = c32::new((i + j) as f32, (i + j) as f32);
            }
        }

        for i in 0..k {
            for j in 0..n {
                rhs[(i, j)] = c32::new((i + j) as f32, (i + j) as f32);
            }
        }

        matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());

        for i in 0..m {
            for j in 0..n {
                let mut sum = c32::zero();
                for l in 0..k {
                    sum += lhs[(i, l)] * rhs[(l, j)];
                }
                assert_eq!(out[(i, j)], sum);
            }
        }
    }

    #[test]
    fn matmul_unchecked2() {
        let mut out = Mat::<c64>::zeros(2, 2);

        let lhs = mat![
            [c64::new(1.0, 0.0), c64::new(2.0, 0.0)],
            [c64::new(3.0, 0.0), c64::new(4.0, 0.0)]
        ];
        let rhs = mat![
            [c64::new(5.0, 0.0), c64::new(6.0, 0.0)],
            [c64::new(7.0, 0.0), c64::new(8.0, 0.0)]
        ];

        matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());

        let expected = mat![
            [c64::new(19.0, 0.0), c64::new(22.0, 0.0)],
            [c64::new(43.0, 0.0), c64::new(50.0, 0.0)]
        ];

        assert_eq!(out, expected);
    }

    #[test]
    fn matmul_unchecked_rectangular() {
        // m = 3, k = 4, n = 2: all three dimensions differ, so a mix-up of m/n/k shows up.
        // Entries are small integers, so the products are exact in f64.
        let lhs = filled_c64(3, 4, |i, j| c64::new((i + 2 * j) as f64, i as f64 - j as f64));
        let rhs = filled_c64(4, 2, |i, j| c64::new((3 * i + j) as f64, 1.0));
        let mut out = Mat::<c64>::zeros(3, 2);

        matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());

        assert_eq!(out, naive_matmul_c64(&lhs, &rhs));
    }

    #[test]
    fn matmul_unchecked_outer_product() {
        // k = 1: a column vector times a row vector.
        let lhs = filled_c64(3, 1, |i, _| c64::new(i as f64 + 1.0, 0.0));
        let rhs = filled_c64(1, 3, |_, j| c64::new(0.0, j as f64 + 1.0));
        let mut out = Mat::<c64>::zeros(3, 3);

        matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());

        assert_eq!(out, naive_matmul_c64(&lhs, &rhs));
    }

    #[test]
    fn plan_execute_is_reusable() {
        let plan = MatMulPlan::<c64>::new(3, 2, 4);
        let lhs = filled_c64(3, 4, |i, j| c64::new(i as f64, j as f64));
        let rhs = filled_c64(4, 2, |i, j| c64::new((i + j) as f64, -(j as f64)));
        let expected = naive_matmul_c64(&lhs, &rhs);

        // Running the same plan twice must give the same, correct result.
        for _ in 0..2 {
            let mut out = Mat::<c64>::zeros(3, 2);
            plan.execute_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());
            assert_eq!(out, expected);
        }
    }

    #[test]
    fn plan_execute_add_accumulates() {
        let plan = MatMulPlan::<c64>::new(3, 2, 4);
        let lhs = filled_c64(3, 4, |i, j| c64::new((i * j) as f64, 1.0));
        let rhs = filled_c64(4, 2, |i, j| c64::new(2.0, (i + j) as f64));
        let initial = filled_c64(3, 2, |i, j| c64::new(i as f64, j as f64));

        let mut out = initial.clone();
        plan.execute_add_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());

        let product = naive_matmul_c64(&lhs, &rhs);
        let expected = filled_c64(3, 2, |i, j| initial[(i, j)] + product[(i, j)]);
        assert_eq!(out, expected);
    }
}