qudit_core/accel/
kron.rs

1//! Functions to efficiently perform the Kronecker product and fused Kronecker-add operations.
2
3use std::ops::AddAssign;
4use std::ops::Mul;
5
6use faer::reborrow::ReborrowMut;
7use faer_traits::ComplexField;
8
9use super::cartesian_match;
10use crate::ComplexScalar;
11use faer::MatMut;
12use faer::MatRef;
13
// TODO: Add higher level functions that call the raw kernels through
// `cartesian_match!` for loop unrolling.
/// Writes the Kronecker product of two strided matrix buffers into `dst`.
///
/// For every element `lhs[(i, j)]` and `rhs[(k, l)]` this stores
/// `dst[(i * rhs_nrows + k, j * rhs_ncols + l)] = lhs[(i, j)] * rhs[(k, l)]`,
/// overwriting whatever `dst` held at that position. The product is always
/// formed as `lhs * rhs`, so a non-commutative `C` is multiplied in Kronecker
/// order. Strides are counted in elements (not bytes) and may be negative.
///
/// # Safety
///
/// Caller must ensure that:
/// * `lhs` points to a readable `lhs_nrows x lhs_ncols` matrix laid out with
///   strides `(lhs_rs, lhs_cs)`, and `rhs` points to a readable
///   `rhs_nrows x rhs_ncols` matrix laid out with strides `(rhs_rs, rhs_cs)`;
/// * `dst` points to a writable matrix with strides `(dst_rs, dst_cs)`
///   spanning at least `lhs_nrows * rhs_nrows` rows and
///   `lhs_ncols * rhs_ncols` columns;
/// * every element offset computed from these dimensions and strides fits in
///   an `isize` and stays inside the respective allocation.
///
/// `dst` is written while `lhs` and `rhs` are still being read, so `dst`
/// should not overlap either input.
pub unsafe fn kron_kernel_raw<C: Mul<Output = C> + Copy>(
    dst: *mut C,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const C,
    lhs_nrows: usize,
    lhs_ncols: usize,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const C,
    rhs_nrows: usize,
    rhs_ncols: usize,
    rhs_rs: isize,
    rhs_cs: isize,
) {
    // Element offset of `(row, col)` in a buffer with strides `(rs, cs)`.
    let offset_of =
        |row: usize, col: usize, rs: isize, cs: isize| row as isize * rs + col as isize * cs;

    // SAFETY: the caller guarantees (see `# Safety`) that every offset produced
    // below addresses a valid element of the corresponding buffer.
    unsafe {
        for a_col in 0..lhs_ncols {
            for a_row in 0..lhs_nrows {
                let scale = *lhs.offset(offset_of(a_row, a_col, lhs_rs, lhs_cs));
                // Top-left corner of the block of `dst` that receives `scale * rhs`.
                let (block_row, block_col) = (a_row * rhs_nrows, a_col * rhs_ncols);

                for b_col in 0..rhs_ncols {
                    for b_row in 0..rhs_nrows {
                        let b_val = *rhs.offset(offset_of(b_row, b_col, rhs_rs, rhs_cs));
                        let out = offset_of(block_row + b_row, block_col + b_col, dst_rs, dst_cs);
                        *dst.offset(out) = scale * b_val;
                    }
                }
            }
        }
    }
}
62
/// Accumulates the Kronecker product of two strided matrix buffers into `dst`.
///
/// For every element `lhs[(i, j)]` and `rhs[(k, l)]` this performs
/// `dst[(i * rhs_nrows + k, j * rhs_ncols + l)] += lhs[(i, j)] * rhs[(k, l)]`,
/// so existing values in `dst` are kept and added to rather than overwritten.
/// The product is always formed as `lhs * rhs`. Strides are counted in
/// elements (not bytes) and may be negative.
///
/// # Safety
///
/// Caller must ensure that:
/// * `lhs` points to a readable `lhs_nrows x lhs_ncols` matrix laid out with
///   strides `(lhs_rs, lhs_cs)`, and `rhs` points to a readable
///   `rhs_nrows x rhs_ncols` matrix laid out with strides `(rhs_rs, rhs_cs)`;
/// * `dst` points to a readable and writable matrix with strides
///   `(dst_rs, dst_cs)` spanning at least `lhs_nrows * rhs_nrows` rows and
///   `lhs_ncols * rhs_ncols` columns, whose elements are initialised (they are
///   read by `+=`);
/// * every element offset computed from these dimensions and strides fits in
///   an `isize` and stays inside the respective allocation.
///
/// `dst` is written while `lhs` and `rhs` are still being read, so `dst`
/// should not overlap either input.
pub unsafe fn kron_kernel_add_raw<C: Mul<Output = C> + Copy + AddAssign>(
    dst: *mut C,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const C,
    lhs_nrows: usize,
    lhs_ncols: usize,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const C,
    rhs_nrows: usize,
    rhs_ncols: usize,
    rhs_rs: isize,
    rhs_cs: isize,
) {
    // Element offset of `(row, col)` in a buffer with strides `(rs, cs)`.
    let offset_of =
        |row: usize, col: usize, rs: isize, cs: isize| row as isize * rs + col as isize * cs;

    // SAFETY: the caller guarantees (see `# Safety`) that every offset produced
    // below addresses an initialised element of the corresponding buffer.
    unsafe {
        for a_col in 0..lhs_ncols {
            for a_row in 0..lhs_nrows {
                let scale = *lhs.offset(offset_of(a_row, a_col, lhs_rs, lhs_cs));
                // Top-left corner of the block of `dst` that receives `scale * rhs`.
                let (block_row, block_col) = (a_row * rhs_nrows, a_col * rhs_ncols);

                for b_col in 0..rhs_ncols {
                    for b_row in 0..rhs_nrows {
                        let b_val = *rhs.offset(offset_of(b_row, b_col, rhs_rs, rhs_cs));
                        let out = offset_of(block_row + b_row, block_col + b_col, dst_rs, dst_cs);
                        *dst.offset(out) += scale * b_val;
                    }
                }
            }
        }
    }
}
109
110/// The inner kernel that performs the Kronecker product of two matrices
111/// without checking assumptions.
112///
113/// # Safety
114///
115/// * The dimensions of `dst` must be at least `lhs_rows * rhs_rows` by `lhs_cols * rhs_cols`.
116///
117/// # See also
118///
119/// * [`kron`] for a safe version of this function.
120///
121unsafe fn kron_kernel<C: ComplexField>(
122    mut dst: MatMut<C>,
123    lhs: MatRef<C>,
124    rhs: MatRef<C>,
125    lhs_rows: usize,
126    lhs_cols: usize,
127    rhs_rows: usize,
128    rhs_cols: usize,
129) {
130    unsafe {
131        for lhs_j in 0..lhs_cols {
132            for lhs_i in 0..lhs_rows {
133                let lhs_val = lhs.get_unchecked(lhs_i, lhs_j);
134
135                for rhs_j in 0..rhs_cols {
136                    for rhs_i in 0..rhs_rows {
137                        let rhs_val = rhs.get_unchecked(rhs_i, rhs_j);
138
139                        *(dst.rb_mut().get_mut_unchecked(
140                            lhs_i * rhs_rows + rhs_i,
141                            lhs_j * rhs_cols + rhs_j,
142                        )) = lhs_val.mul_by_ref(rhs_val);
143                    }
144                }
145            }
146        }
147    }
148}
149
150/// Performs the Kronecker product of two matrices and adds this to the destination
151/// without checking assumptions.
152///
153/// More efficient that performing the Kronecker product followed by addition;
154/// we only look up each element of `dst` once, rather than twice.
155///
156/// # Safety
157///
158/// * The dimensions of `dst` must be at least `lhs_rows * rhs_rows` by `lhs_cols * rhs_cols`.
159///
160/// # See also
161///
162/// * [`kron_add`] for a safe version of this function.
163///
164unsafe fn kron_kernel_add<C: ComplexScalar>(
165    mut dst: MatMut<C>,
166    lhs: MatRef<C>,
167    rhs: MatRef<C>,
168    lhs_rows: usize,
169    lhs_cols: usize,
170    rhs_rows: usize,
171    rhs_cols: usize,
172) {
173    unsafe {
174        for lhs_j in 0..lhs_cols {
175            for lhs_i in 0..lhs_rows {
176                let lhs_val = lhs.get_unchecked(lhs_i, lhs_j);
177
178                for rhs_j in 0..rhs_cols {
179                    for rhs_i in 0..rhs_rows {
180                        let rhs_val = rhs.get_unchecked(rhs_i, rhs_j);
181
182                        // Notice that each element of `dst` is only looked up once throughout the loops.
183                        *(dst.rb_mut().get_mut_unchecked(
184                            lhs_i * rhs_rows + rhs_i,
185                            lhs_j * rhs_cols + rhs_j,
186                        )) += lhs_val.mul_by_ref(rhs_val);
187                    }
188                }
189            }
190        }
191    }
192}
193
194/// Performs the Kronecker product of two matrices without checking assumptions.
195///
196/// # Safety
197///
198/// * The dimensions of `dst` must be at least `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
199///
200/// # See also
201///
202/// * [`kron`] for a safe version of this function.
203///
204pub unsafe fn kron_unchecked<C: ComplexField>(dst: MatMut<C>, lhs: MatRef<C>, rhs: MatRef<C>) {
205    unsafe {
206        let lhs_rows = lhs.nrows();
207        let lhs_cols = lhs.ncols();
208        let rhs_rows = rhs.nrows();
209        let rhs_cols = rhs.ncols();
210
211        cartesian_match!(
212            { kron_kernel(dst, lhs, rhs, lhs_rows, lhs_cols, rhs_rows, rhs_cols) },
213            (lhs_rows, (lhs_cols, (rhs_rows, (rhs_cols, ())))),
214            (
215                (2, 3, 4, _),
216                ((2, 3, 4, _), ((2, 3, 4, _), ((2, 3, 4, _), ())))
217            )
218        );
219    }
220}
221
222/// Performs the Kronecker product of two square matrices without checking assumptions.
223///
224/// # Safety
225///
226/// * The dimensions of `dst` must be at least `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
227/// * The matrices must be square.
228///
229/// # See also
230///
231/// * [`kron`] for a safe version of this function.
232///
233pub unsafe fn kron_sq_unchecked<C: ComplexField>(dst: MatMut<C>, lhs: MatRef<C>, rhs: MatRef<C>) {
234    unsafe {
235        let lhs_dim = lhs.nrows();
236        let rhs_dim = rhs.nrows();
237
238        cartesian_match!(
239            { kron_kernel(dst, lhs, rhs, lhs_dim, lhs_dim, rhs_dim, rhs_dim) },
240            (lhs_dim, (rhs_dim, ())),
241            (
242                (2, 3, 4, 6, 8, 9, 16, 27, 32, 64, 81, _),
243                ((2, 3, 4, 6, 8, 9, 16, 27, 32, 64, 81, _), ())
244            )
245        );
246    }
247}
248
249/// Kronecker product of two matrices.
250///
251/// The Kronecker product of two matrices `A` and `B` is a block matrix
252/// `C` with the following structure:
253///
254/// ```text
255/// C = [ a00 * B, a01 * B, ..., a0n * B ]
256///     [ a10 * B, a11 * B, ..., a1n * B ]
257///     [ ...    , ...    , ..., ...     ]
258///     [ am0 * B, am1 * B, ..., amn * B ]
259/// ```
260/// where `a_ij` is the element at position `(i, j)` of `A`.
261///
262/// # Panics
263///
264/// * If `dst` does not have the correct dimensions. The dimensions
265///   of `dst` must be `nrows(A) * nrows(B)` by `ncols(A) * ncols(B)`.
266///
267/// # Example
268/// ```
269/// use faer::mat;
270/// use faer::Mat;
271/// use qudit_core::accel::kron;
272///
273/// let a = mat![
274///     [1.0, 2.0],
275///     [3.0, 4.0],
276/// ];
277/// let b = mat![
278///     [0.0, 5.0],
279///     [6.0, 7.0],
280/// ];
281/// let c = mat![
282///     [0.0 , 5.0 , 0.0 , 10.0],
283///     [6.0 , 7.0 , 12.0, 14.0],
284///     [0.0 , 15.0, 0.0 , 20.0],
285///     [18.0, 21.0, 24.0, 28.0],
286/// ];
287///
288/// let mut dst = Mat::new();
289/// dst.resize_with(4, 4, |_, _| 0f64);
290///
291/// kron(a.as_ref(), b.as_ref(), dst.as_mut());
292///
293/// assert_eq!(dst, c);
294/// ```
295///
296pub fn kron<C: ComplexField>(lhs: MatRef<C>, rhs: MatRef<C>, dst: MatMut<C>) {
297    let mut lhs = lhs;
298    let mut rhs = rhs;
299    let mut dst = dst;
300
301    // Ensures that `dst` is in column-major order.
302    if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() {
303        dst = dst.transpose_mut();
304        lhs = lhs.transpose();
305        rhs = rhs.transpose();
306    }
307
308    // Checks that the dimensions of `dst` matches the expected dimensions of the Kronecker product of
309    // `lhs` and `rhs`. Also checks that no overflow occurs during the multiplication.
310    assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows()));
311    assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols()));
312
313    // Uses a specialized kernel for square matrices if both `lhs` and `rhs` are square.
314    if lhs.nrows() == lhs.ncols() && rhs.nrows() == rhs.ncols() {
315        // Safety: The dimensions have been checked.
316        unsafe { kron_sq_unchecked(dst, lhs, rhs) }
317    } else {
318        // Safety: The dimensions have been checked.
319        unsafe { kron_unchecked(dst, lhs, rhs) }
320    }
321}
322
323/// Computes the Kronecker product of two matrices and adds the result to a destination matrix.
324///
325/// For `A` ∈ M(R_a, C_a), `B` ∈ M(R_b, C_b), `C` ∈ M(R_a * R_b, C_a * C_b), this function mutates `C`
326/// such C_{i * R_b + k , j * C_b + l} -> C_{i * R_b + k , j * C_b + l} + A_{i, j} * B_{k, l}.
327///
328/// # Arguments
329///
330/// * `lhs` -  The left hand-side matrix for the kronecker product. `A` in the description above.
331/// * `rhs` - The right hand-side matrix for the kronecker product. `B` in the description above.
332/// * `dst` - The matrix to be summed (mutated) by the kronercker product of `lhs` and `rhs`.
333///   `C` in the description above.
334///
335/// # Panics
336///
337/// * If `dst.nrows()` doesn't match `lhs.nrows()` times `rhs.nrows()`
338/// * If `dst.ncols()` doesn't match `lhs.ncols()` times `rhs.ncols()`
339/// * If an overflow occurs when calculating the expected dimensions.
340///
341/// # Example
342/// ```
343/// use faer::{mat, Mat};
344/// use qudit_core::accel::kron_add;
345/// use qudit_core::c64;
346///
347/// let mut dst = Mat::<c64>::zeros(4, 4);
348///
349/// let lhs = Mat::<c64>::from_fn(2, 2, |i, j| -> c64 {c64::new((2*i+1+j) as f64, 0.0)});
350/// let rhs = Mat::<c64>::from_fn(2, 2, |i, j| -> c64 {c64::new((2*i+5+j) as f64, 0.0)});
351///
352/// kron_add(lhs.as_ref(), rhs.as_ref(), dst.as_mut());
353///
354/// let expected_data = [
355///      [c64::new(5.0, 0.0), c64::new(6.0, 0.0), c64::new(10.0, 0.0), c64::new(12.0, 0.0)],
356///      [c64::new(7.0, 0.0), c64::new(8.0, 0.0), c64::new(14.0, 0.0), c64::new(16.0, 0.0)],
357///      [c64::new(15.0, 0.0), c64::new(18.0, 0.0), c64::new(20.0, 0.0), c64::new(24.0, 0.0)],
358///      [c64::new(21.0, 0.0), c64::new(24.0, 0.0), c64::new(28.0, 0.0), c64::new(32.0, 0.0)]
359/// ];
360/// let expected = Mat::from_fn(4, 4, |i, j| -> c64 {expected_data[i][j]});
361/// assert_eq!(dst, expected);
362/// ```
363///
364pub fn kron_add<C: ComplexScalar>(lhs: MatRef<C>, rhs: MatRef<C>, dst: MatMut<C>) {
365    let mut lhs = lhs;
366    let mut rhs = rhs;
367    let mut dst = dst;
368
369    // Makes `dst` is in column-major order. To maintain the same computation, we transpose `lhs` and `rhs` as well.
370    // This is allowed because the transpose is distributive over the Kronecker product and addition. Notice we are
371    // transposing input views, so we need not re-transpose our matrices after mutating the underlying data of dst.
372    if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() {
373        dst = dst.transpose_mut();
374        lhs = lhs.transpose();
375        rhs = rhs.transpose();
376    }
377    // Makes sure the dimesion of the Kronecker product between lhs, rhs matches that of dst.
378    // Recall (F_r, F_c) = (D_r * E_r, D_c * E_c) where F = D (x) E.
379    // Also makes sure overflows do not occur during the multiplications.
380    assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows()));
381    assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols()));
382
383    // Performs the actual Kronecker product followed by sum.
384    unsafe {
385        kron_kernel_add(
386            dst,
387            lhs,
388            rhs,
389            lhs.nrows(),
390            lhs.ncols(),
391            rhs.nrows(),
392            rhs.ncols(),
393        );
394    }
395}
396
#[cfg(test)]
mod kron_tests {
    use super::*;
    use crate::c64;
    use faer::Mat;
    use faer::mat;

    /// Builds a `c64` matrix from row-major `(re, im)` pairs.
    fn cmat<const R: usize, const N: usize>(rows: [[(f64, f64); N]; R]) -> Mat<c64> {
        Mat::from_fn(R, N, |i, j| {
            let (re, im) = rows[i][j];
            c64::new(re, im)
        })
    }

    #[test]
    fn kron_add_test() {
        let mut dst = cmat([
            [(1.0, -8.0), (2.0, 67.0), (3.0, 0.0), (4.0, 0.0)],
            [(5.0, 0.0), (6.0, 0.0), (7.0, 0.0), (8.0, 0.0)],
            [(9.0, 0.0), (10.0, 0.0), (11.0, 0.0), (12.0, 0.0)],
            [(13.0, 0.0), (14.0, 0.0), (15.0, 0.0), (16.0, 0.0)],
        ]);
        let lhs = cmat([[(1.0, 9.0), (2.0, 0.0)], [(3.0, 0.0), (4.0, 0.0)]]);
        let rhs = cmat([[(5.0, -8.0), (6.0, 0.0)], [(7.0, 0.0), (8.0, 0.0)]]);

        kron_add(lhs.as_ref(), rhs.as_ref(), dst.as_mut());

        let expected = cmat([
            [(78.0, 29.0), (8.0, 121.0), (13.0, -16.0), (16.0, 0.0)],
            [(12.0, 63.0), (14.0, 72.0), (21.0, 0.0), (24.0, 0.0)],
            [(24.0, -24.0), (28.0, 0.0), (31.0, -32.0), (36.0, 0.0)],
            [(34.0, 0.0), (38.0, 0.0), (43.0, 0.0), (48.0, 0.0)],
        ]);

        assert_eq!(dst, expected);
    }

    #[test]
    fn kron_test() {
        let a = mat![[1.0, 2.0], [3.0, 4.0],];
        let b = mat![[0.0, 5.0], [6.0, 7.0],];
        let c = mat![
            [0.0, 5.0, 0.0, 10.0],
            [6.0, 7.0, 12.0, 14.0],
            [0.0, 15.0, 0.0, 20.0],
            [18.0, 21.0, 24.0, 28.0],
        ];
        let mut dst = Mat::new();
        dst.resize_with(4, 4, |_, _| 0f64);
        kron(a.as_ref(), b.as_ref(), dst.as_mut());
        assert_eq!(dst, c);
    }

    #[test]
    fn kron_non_square_test() {
        // (1x2) ⊗ (2x1) = 2x2, exercising the non-square dispatch path.
        let a = mat![[1.0, 2.0]];
        let b = mat![[3.0], [4.0]];
        let expected = mat![[3.0, 6.0], [4.0, 8.0]];
        let mut dst = Mat::<f64>::zeros(2, 2);
        kron(a.as_ref(), b.as_ref(), dst.as_mut());
        assert_eq!(dst, expected);
    }

    #[test]
    #[should_panic]
    fn kron_wrong_dst_shape_panics() {
        let a = mat![[1.0, 2.0], [3.0, 4.0]];
        let mut dst = Mat::<f64>::zeros(3, 4);
        kron(a.as_ref(), a.as_ref(), dst.as_mut());
    }
}