qudit_core/accel/
matmul.rs

1//! Functions and structs for efficient generalized matrix multiplication (GEMM).
2
3use coe::is_same;
4use nano_gemm::Plan;
5use num_traits::One;
6use num_traits::Zero;
7
8use crate::ComplexScalar;
9use crate::c32;
10use crate::c64;
11use faer::MatMut;
12use faer::MatRef;
13
14/// Stores a plan for a generalized matrix multiplication (GEMM). Based on the dimensions and underlying
15/// field of the matrices, the plan will select the appropriate mili/micro-kernels for performance.
16pub struct MatMulPlan<C: ComplexScalar> {
17    m: usize,
18    n: usize,
19    k: usize,
20    plan: Plan<C>,
21}
22
23impl<C: ComplexScalar> MatMulPlan<C> {
24    /// Creates a new GEMM plan for column-major matrices.
25    ///
26    /// # Arguments
27    ///
28    /// * `m`: Number of rows in the left-hand side matrix.
29    /// * `n`: Number of columns in the right-hand side matrix.
30    /// * `k`: Number of columns in the left-hand side matrix.
31    ///   This should equal the number of rows in the right-hand side matrix.
32    ///
33    /// # Returns
34    ///
35    /// * A `MatMulPlan` instance.
36    ///
37    pub fn new(m: usize, n: usize, k: usize) -> Self {
38        if is_same::<C, c32>() {
39            let plan = Plan::new_colmajor_lhs_and_dst_c32(m, n, k);
40            // Safety: This is safe because C is c32.
41            Self {
42                m,
43                n,
44                k,
45                plan: unsafe { std::mem::transmute::<Plan<c32>, Plan<C>>(plan) },
46            }
47        } else {
48            let plan = Plan::new_colmajor_lhs_and_dst_c64(m, n, k);
49            Self {
50                m,
51                n,
52                k,
53                plan: unsafe { std::mem::transmute::<Plan<c64>, Plan<C>>(plan) },
54            }
55        }
56    }
57
58    /// Executes the milikernel of the plan, for matrix multiplication. (`alpha = 0`, `beta = 1`)
59    /// We do not perform comprehensive checks.
60    ///
61    /// # Arguments
62    ///
63    /// * `lhs`: The left-hand side matrix to multiply.
64    /// * `rhs`: The right-hand side matrix to multiply.
65    /// * `out`: The output matrix where the result will be stored.
66    ///
67    /// # Safety
68    ///
69    /// * The matrices must be column-major.
70    /// * The dimensions of `out` must be `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
71    ///
72    /// # Examples
73    /// ```
74    /// use qudit_core::accel::MatMulPlan;
75    /// use faer::{mat, Mat};
76    /// use qudit_core::c64;
77    ///
78    /// let mut out = Mat::<c64>::zeros(2, 2);
79    ///
80    /// let lhs = mat![
81    ///     [c64::new(1.0, 0.0), c64::new(2.0, 0.0)],
82    ///     [c64::new(3.0, 0.0), c64::new(4.0, 0.0)]
83    /// ];
84    /// let rhs = mat![
85    ///     [c64::new(5.0, 0.0), c64::new(6.0, 0.0)],
86    ///     [c64::new(7.0, 0.0), c64::new(8.0, 0.0)]
87    /// ];
88    ///
89    /// let test_plan = MatMulPlan::new(lhs.nrows(), rhs.ncols(), lhs.ncols());
90    /// test_plan.execute_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());
91    ///
92    /// let expected = mat![
93    ///     [c64::new(19.0, 0.0), c64::new(22.0, 0.0)],
94    ///     [c64::new(43.0, 0.0), c64::new(50.0, 0.0)]
95    /// ];
96    ///
97    /// assert_eq!(expected, out);
98    /// ```
99    ///
100    #[inline(always)]
101    pub fn execute_unchecked(&self, lhs: MatRef<C>, rhs: MatRef<C>, out: MatMut<C>) {
102        let m = lhs.nrows();
103        let n = rhs.ncols();
104        let k = lhs.ncols();
105        let out_col_stride = out.col_stride();
106
107        unsafe {
108            self.plan.execute_unchecked(
109                m,
110                n,
111                k,
112                out.as_ptr_mut() as _,
113                1,
114                out_col_stride,
115                lhs.as_ptr() as _,
116                1,
117                lhs.col_stride(),
118                rhs.as_ptr() as _,
119                1,
120                rhs.col_stride(),
121                C::zero(),
122                C::one(),
123                false,
124                false,
125            );
126        }
127    }
128
129    #[inline(always)]
130    #[allow(clippy::too_many_arguments)]
131    /// Perform the matrix multiplication given by the plan without checking bounds.
132    ///
133    /// # Safety
134    ///
135    /// The multiplication defined here must be valid. The pointers must point
136    /// to adequately sized and proper buffers of memory that describe matrices
137    /// with the dimensions and strides given.
138    pub unsafe fn execute_raw_unchecked(
139        &self,
140        lhs: *const C,
141        rhs: *const C,
142        out: *mut C,
143        dst_rs: isize,
144        dst_cs: isize,
145        lhs_rs: isize,
146        lhs_cs: isize,
147        rhs_rs: isize,
148        rhs_cs: isize,
149    ) {
150        unsafe {
151            self.plan.execute_unchecked(
152                self.m,
153                self.n,
154                self.k,
155                out,
156                dst_rs,
157                dst_cs,
158                lhs,
159                lhs_rs,
160                lhs_cs,
161                rhs,
162                rhs_rs,
163                rhs_cs,
164                C::zero(),
165                C::one(),
166                false,
167                false,
168            );
169        }
170    }
171
172    /// Executes the milikernel of the plan, for matrix multiplication followed by addition.
173    /// (`alpha = 1`, `beta = 1`) We do not perform comprehensive checks.
174    ///
175    /// # Arguments
176    ///
177    /// * `lhs`: The left-hand side matrix to add.
178    /// * `rhs`: The right-hand side matrix to add.
179    /// * `out`: The output matrix where the result will be stored.
180    ///
181    /// # Safety
182    ///
183    /// * The matrices must be column-major.
184    /// * The dimensions of `out` must be `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
185    ///
186    /// # Examples
187    /// ```
188    /// use qudit_core::accel::MatMulPlan;
189    /// use faer::{mat, Mat};
190    /// use qudit_core::c64;
191    ///
192    /// let mut out = Mat::<c64>::ones(2, 2);
193    ///
194    /// let lhs = mat![
195    ///     [c64::new(1.0, 0.0), c64::new(2.0, 0.0)],
196    ///     [c64::new(3.0, 0.0), c64::new(4.0, 0.0)]
197    /// ];
198    /// let rhs = mat![
199    ///     [c64::new(5.0, 0.0), c64::new(6.0, 0.0)],
200    ///     [c64::new(7.0, 0.0), c64::new(8.0, 0.0)]
201    /// ];
202    ///
203    /// let test_plan = MatMulPlan::new(lhs.nrows(), rhs.ncols(), lhs.ncols());
204    /// test_plan.execute_add_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());
205    ///
206    /// let expected = mat![
207    ///     [c64::new(20.0, 0.0), c64::new(23.0, 0.0)],
208    ///     [c64::new(44.0, 0.0), c64::new(51.0, 0.0)]
209    /// ];
210    ///
211    /// assert_eq!(expected, out);
212    /// ```
213    ///
214    pub fn execute_add_unchecked(&self, lhs: MatRef<C>, rhs: MatRef<C>, out: MatMut<C>) {
215        let m = lhs.nrows();
216        let n = rhs.ncols();
217        let k = lhs.ncols();
218        let out_col_stride = out.col_stride();
219
220        unsafe {
221            self.plan.execute_unchecked(
222                m,
223                n,
224                k,
225                out.as_ptr_mut() as _,
226                1,
227                out_col_stride,
228                lhs.as_ptr() as _,
229                1,
230                lhs.col_stride(),
231                rhs.as_ptr() as _,
232                1,
233                rhs.col_stride(),
234                C::one(),
235                C::one(),
236                false,
237                false,
238            );
239        }
240    }
241
242    #[inline(always)]
243    #[allow(clippy::too_many_arguments)]
244    /// Perform the additive matrix multiplication given by the plan without checking bounds.
245    ///
246    /// # Safety
247    ///
248    /// The multiplication defined here must be valid. The pointers must point
249    /// to adequately sized and proper buffers of memory that describe matrices
250    /// with the dimensions and strides given.
251    pub unsafe fn execute_add_raw_unchecked(
252        &self,
253        lhs: *const C,
254        rhs: *const C,
255        out: *mut C,
256        dst_rs: isize,
257        dst_cs: isize,
258        lhs_rs: isize,
259        lhs_cs: isize,
260        rhs_rs: isize,
261        rhs_cs: isize,
262    ) {
263        unsafe {
264            self.plan.execute_unchecked(
265                self.m,
266                self.n,
267                self.k,
268                out,
269                dst_rs,
270                dst_cs,
271                lhs,
272                lhs_rs,
273                lhs_cs,
274                rhs,
275                rhs_rs,
276                rhs_cs,
277                C::one(),
278                C::one(),
279                false,
280                false,
281            );
282        }
283    }
284}
285
286/// Performs matrix-matrix multiplication. (`alpha = 0`, `beta = 1`)
287///
288/// # Arguments
289///
290/// * `lhs`: The left-hand side matrix to multiply.
291/// * `rhs`: The right-hand side matrix to multiply.
292/// * `out`: The output matrix where the result will be stored.
293///
294/// # Safety
295///
296/// * The matrices must be column-major.
297/// * The dimensions of `out` must be `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
298///
299/// # Examples
300/// ```
301/// use qudit_core::accel::matmul_unchecked;
302/// use faer::{mat, Mat};
303/// use qudit_core::c64;
304///
305/// let mut out = Mat::<c64>::zeros(2, 2);
306///
307/// let lhs = mat![
308///     [c64::new(1.0, 0.0), c64::new(2.0, 0.0)],
309///     [c64::new(3.0, 0.0), c64::new(4.0, 0.0)]
310/// ];
311/// let rhs = mat![
312///     [c64::new(5.0, 0.0), c64::new(6.0, 0.0)],
313///     [c64::new(7.0, 0.0), c64::new(8.0, 0.0)]
314/// ];
315///
316/// matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());
317///
318/// let expected = mat![
319///     [c64::new(19.0, 0.0), c64::new(22.0, 0.0)],
320///     [c64::new(43.0, 0.0), c64::new(50.0, 0.0)]
321/// ];
322///
323/// assert_eq!(expected, out);
324/// ```
325///
326#[inline(always)]
327pub fn matmul_unchecked<C: ComplexScalar>(lhs: MatRef<C>, rhs: MatRef<C>, out: MatMut<C>) {
328    let m = lhs.nrows();
329    let n = rhs.ncols();
330    let k = lhs.ncols();
331
332    // After the runtime check of C, we explicitly transmute our inputs.
333    // This allows type-specific optimizations.
334    if is_same::<C, c32>() {
335        let plan = Plan::new_colmajor_lhs_and_dst_c32(m, n, k);
336        let out: MatMut<c32> = unsafe { std::mem::transmute(out) };
337        let rhs: MatRef<c32> = unsafe { std::mem::transmute(rhs) };
338        let lhs: MatRef<c32> = unsafe { std::mem::transmute(lhs) };
339        let out_col_stride = out.col_stride();
340
341        unsafe {
342            plan.execute_unchecked(
343                m,
344                n,
345                k,
346                out.as_ptr_mut() as _,
347                1,
348                out_col_stride,
349                lhs.as_ptr() as _,
350                1,
351                lhs.col_stride(),
352                rhs.as_ptr() as _,
353                1,
354                rhs.col_stride(),
355                c32::zero(),
356                c32::one(), // TODO: Figure if I can create custom kernels for one/zero alpha/beta
357                false,
358                false,
359            );
360        }
361    } else {
362        let plan = Plan::new_colmajor_lhs_and_dst_c64(m, n, k);
363        let out: MatMut<c64> = unsafe { std::mem::transmute(out) };
364        let rhs: MatRef<c64> = unsafe { std::mem::transmute(rhs) };
365        let lhs: MatRef<c64> = unsafe { std::mem::transmute(lhs) };
366        let out_col_stride = out.col_stride();
367
368        unsafe {
369            plan.execute_unchecked(
370                m,
371                n,
372                k,
373                out.as_ptr_mut() as _,
374                1,
375                out_col_stride,
376                lhs.as_ptr() as _,
377                1,
378                lhs.col_stride(),
379                rhs.as_ptr() as _,
380                1,
381                rhs.col_stride(),
382                c64::zero(),
383                c64::one(), // TODO: Figure if I can create custom kernels for one/zero alpha/beta
384                false,
385                false,
386            );
387        }
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{c32, c64};
    use faer::Mat;
    use faer::mat;
    use num_traits::Zero;

    /// Checks `matmul_unchecked` on 2x2 `c32` operands against a naive inner-product sum.
    #[test]
    fn test_matmul_unchecked() {
        let (m, n, k) = (2, 2, 2);

        // Both operands hold `(row + col) + (row + col)i` at every position.
        let entry = |row: usize, col: usize| {
            let value = (row + col) as f32;
            c32::new(value, value)
        };

        let mut lhs = Mat::<c32>::zeros(m, k);
        let mut rhs = Mat::<c32>::zeros(k, n);
        let mut out = Mat::<c32>::zeros(m, n);

        for row in 0..m {
            for col in 0..k {
                lhs[(row, col)] = entry(row, col);
            }
        }

        for row in 0..k {
            for col in 0..n {
                rhs[(row, col)] = entry(row, col);
            }
        }

        matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());

        for row in 0..m {
            for col in 0..n {
                // Reference value: accumulate the inner products in index order.
                let expected = (0..k).fold(c32::zero(), |acc, inner| {
                    acc + lhs[(row, inner)] * rhs[(inner, col)]
                });
                assert_eq!(out[(row, col)], expected);
            }
        }
    }

    /// Checks `matmul_unchecked` on a fixed 2x2 `c64` product with a known result.
    #[test]
    fn matmul_unchecked2() {
        // All entries are purely real.
        let real = |value: f64| c64::new(value, 0.0);

        let mut out = Mat::<c64>::zeros(2, 2);

        let lhs = mat![[real(1.0), real(2.0)], [real(3.0), real(4.0)]];
        let rhs = mat![[real(5.0), real(6.0)], [real(7.0), real(8.0)]];

        matmul_unchecked(lhs.as_ref(), rhs.as_ref(), out.as_mut());

        let expected = mat![[real(19.0), real(22.0)], [real(43.0), real(50.0)]];

        assert_eq!(out, expected);
    }
}
456}