qudit_core/accel/kron.rs
1//! Functions to efficiently perform the Kronecker product and fused Kronecker-add operations.
2
3use std::ops::AddAssign;
4use std::ops::Mul;
5
6use faer::reborrow::ReborrowMut;
7use faer_traits::ComplexField;
8
9use super::cartesian_match;
10use crate::ComplexScalar;
11use faer::MatMut;
12use faer::MatRef;
13
14// TODO: Add proper documentation to raw methods and add higher level
15// functions that call them with the cartesian_match for loop unrolling.
/// Writes the Kronecker product `lhs ⊗ rhs` into `dst` through raw strided pointers.
///
/// Every matrix is described by a base pointer plus a row stride (`*_rs`) and a
/// column stride (`*_cs`), both counted in elements: element `(row, col)` lives at
/// `ptr.offset(row * rs + col * cs)`. Because strides are `isize`, they may be negative.
///
/// Entry `(i, j)` of `lhs` multiplied by entry `(k, l)` of `rhs` is stored at
/// position `(i * rhs_nrows + k, j * rhs_ncols + l)` of `dst`, overwriting the
/// previous value. The destination therefore has
/// `lhs_nrows * rhs_nrows` rows and `lhs_ncols * rhs_ncols` columns.
///
/// # Safety
///
/// * `lhs` must be valid for reads at every offset addressed by
///   `lhs_nrows`, `lhs_ncols`, `lhs_rs`, `lhs_cs` (likewise for `rhs`).
/// * `dst` must be valid for writes at every offset addressed by the result
///   dimensions above together with `dst_rs` and `dst_cs`.
/// * All computed offsets must fit in `isize`.
/// * `dst` should not overlap `lhs` or `rhs`; otherwise inputs may be read after
///   they have already been overwritten.
pub unsafe fn kron_kernel_raw<C: Mul<Output = C> + Copy>(
    dst: *mut C,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const C,
    lhs_nrows: usize,
    lhs_ncols: usize,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const C,
    rhs_nrows: usize,
    rhs_ncols: usize,
    rhs_rs: isize,
    rhs_cs: isize,
) {
    // Element offset of `(row, col)` for a matrix with strides `(rs, cs)`.
    let at = |row: usize, col: usize, rs: isize, cs: isize| row as isize * rs + col as isize * cs;

    unsafe {
        for a_col in 0..lhs_ncols {
            for a_row in 0..lhs_nrows {
                let a = *lhs.offset(at(a_row, a_col, lhs_rs, lhs_cs));

                // Top-left corner of the `a * rhs` block inside `dst`.
                let (row0, col0) = (a_row * rhs_nrows, a_col * rhs_ncols);

                for b_col in 0..rhs_ncols {
                    for b_row in 0..rhs_nrows {
                        let b = *rhs.offset(at(b_row, b_col, rhs_rs, rhs_cs));
                        let slot = dst.offset(at(row0 + b_row, col0 + b_col, dst_rs, dst_cs));
                        *slot = a * b;
                    }
                }
            }
        }
    }
}
62
/// Accumulates the Kronecker product `lhs ⊗ rhs` into `dst` through raw strided pointers.
///
/// Identical addressing to [`kron_kernel_raw`]: element `(row, col)` of each matrix
/// lives at `ptr.offset(row * rs + col * cs)` with strides counted in elements
/// (possibly negative). Instead of overwriting, each product
/// `lhs[(i, j)] * rhs[(k, l)]` is added (`+=`) to position
/// `(i * rhs_nrows + k, j * rhs_ncols + l)` of `dst`.
///
/// # Safety
///
/// * `lhs` must be valid for reads at every offset addressed by
///   `lhs_nrows`, `lhs_ncols`, `lhs_rs`, `lhs_cs` (likewise for `rhs`).
/// * `dst` must be valid for reads and writes at every offset addressed by the
///   `(lhs_nrows * rhs_nrows) x (lhs_ncols * rhs_ncols)` result together with
///   `dst_rs` and `dst_cs`.
/// * All computed offsets must fit in `isize`.
/// * `dst` should not overlap `lhs` or `rhs`; otherwise inputs may be read after
///   they have already been modified.
pub unsafe fn kron_kernel_add_raw<C: Mul<Output = C> + Copy + AddAssign>(
    dst: *mut C,
    dst_rs: isize,
    dst_cs: isize,
    lhs: *const C,
    lhs_nrows: usize,
    lhs_ncols: usize,
    lhs_rs: isize,
    lhs_cs: isize,
    rhs: *const C,
    rhs_nrows: usize,
    rhs_ncols: usize,
    rhs_rs: isize,
    rhs_cs: isize,
) {
    // Element offset of `(row, col)` for a matrix with strides `(rs, cs)`.
    let at = |row: usize, col: usize, rs: isize, cs: isize| row as isize * rs + col as isize * cs;

    unsafe {
        for a_col in 0..lhs_ncols {
            for a_row in 0..lhs_nrows {
                let a = *lhs.offset(at(a_row, a_col, lhs_rs, lhs_cs));

                // Top-left corner of the `a * rhs` block inside `dst`.
                let (row0, col0) = (a_row * rhs_nrows, a_col * rhs_ncols);

                for b_col in 0..rhs_ncols {
                    for b_row in 0..rhs_nrows {
                        let b = *rhs.offset(at(b_row, b_col, rhs_rs, rhs_cs));
                        let slot = dst.offset(at(row0 + b_row, col0 + b_col, dst_rs, dst_cs));
                        *slot += a * b;
                    }
                }
            }
        }
    }
}
109
110/// The inner kernel that performs the Kronecker product of two matrices
111/// without checking assumptions.
112///
113/// # Safety
114///
115/// * The dimensions of `dst` must be at least `lhs_rows * rhs_rows` by `lhs_cols * rhs_cols`.
116///
117/// # See also
118///
119/// * [`kron`] for a safe version of this function.
120///
121unsafe fn kron_kernel<C: ComplexField>(
122 mut dst: MatMut<C>,
123 lhs: MatRef<C>,
124 rhs: MatRef<C>,
125 lhs_rows: usize,
126 lhs_cols: usize,
127 rhs_rows: usize,
128 rhs_cols: usize,
129) {
130 unsafe {
131 for lhs_j in 0..lhs_cols {
132 for lhs_i in 0..lhs_rows {
133 let lhs_val = lhs.get_unchecked(lhs_i, lhs_j);
134
135 for rhs_j in 0..rhs_cols {
136 for rhs_i in 0..rhs_rows {
137 let rhs_val = rhs.get_unchecked(rhs_i, rhs_j);
138
139 *(dst.rb_mut().get_mut_unchecked(
140 lhs_i * rhs_rows + rhs_i,
141 lhs_j * rhs_cols + rhs_j,
142 )) = lhs_val.mul_by_ref(rhs_val);
143 }
144 }
145 }
146 }
147 }
148}
149
150/// Performs the Kronecker product of two matrices and adds this to the destination
151/// without checking assumptions.
152///
153/// More efficient that performing the Kronecker product followed by addition;
154/// we only look up each element of `dst` once, rather than twice.
155///
156/// # Safety
157///
158/// * The dimensions of `dst` must be at least `lhs_rows * rhs_rows` by `lhs_cols * rhs_cols`.
159///
160/// # See also
161///
162/// * [`kron_add`] for a safe version of this function.
163///
164unsafe fn kron_kernel_add<C: ComplexScalar>(
165 mut dst: MatMut<C>,
166 lhs: MatRef<C>,
167 rhs: MatRef<C>,
168 lhs_rows: usize,
169 lhs_cols: usize,
170 rhs_rows: usize,
171 rhs_cols: usize,
172) {
173 unsafe {
174 for lhs_j in 0..lhs_cols {
175 for lhs_i in 0..lhs_rows {
176 let lhs_val = lhs.get_unchecked(lhs_i, lhs_j);
177
178 for rhs_j in 0..rhs_cols {
179 for rhs_i in 0..rhs_rows {
180 let rhs_val = rhs.get_unchecked(rhs_i, rhs_j);
181
182 // Notice that each element of `dst` is only looked up once throughout the loops.
183 *(dst.rb_mut().get_mut_unchecked(
184 lhs_i * rhs_rows + rhs_i,
185 lhs_j * rhs_cols + rhs_j,
186 )) += lhs_val.mul_by_ref(rhs_val);
187 }
188 }
189 }
190 }
191 }
192}
193
/// Performs the Kronecker product of two matrices without checking assumptions.
///
/// Writes `lhs ⊗ rhs` into the leading `lhs.nrows() * rhs.nrows()` by
/// `lhs.ncols() * rhs.ncols()` part of `dst`, overwriting it. Each of the four
/// operand dimensions is dispatched through `cartesian_match!` against the
/// candidates `2`, `3`, `4`, or anything else (`_`). Presumably this lets the inner
/// kernel be instantiated with constant sizes for loop unrolling; confirm against
/// the macro definition.
///
/// # Safety
///
/// * The dimensions of `dst` must be at least `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
///
/// # See also
///
/// * [`kron`] for a safe version of this function.
///
pub unsafe fn kron_unchecked<C: ComplexField>(dst: MatMut<C>, lhs: MatRef<C>, rhs: MatRef<C>) {
    unsafe {
        let lhs_rows = lhs.nrows();
        let lhs_cols = lhs.ncols();
        let rhs_rows = rhs.nrows();
        let rhs_cols = rhs.ncols();

        // SAFETY: the sizes passed to `kron_kernel` are the true dimensions of `lhs`
        // and `rhs`, and the caller guarantees `dst` is large enough for the result.
        cartesian_match!(
            { kron_kernel(dst, lhs, rhs, lhs_rows, lhs_cols, rhs_rows, rhs_cols) },
            (lhs_rows, (lhs_cols, (rhs_rows, (rhs_cols, ())))),
            (
                (2, 3, 4, _),
                ((2, 3, 4, _), ((2, 3, 4, _), ((2, 3, 4, _), ())))
            )
        );
    }
}
221
/// Performs the Kronecker product of two square matrices without checking assumptions.
///
/// Writes `lhs ⊗ rhs` into the leading `(lhs.nrows() * rhs.nrows())`-square part
/// of `dst`, overwriting it. Only `nrows()` of each operand is consulted, which is
/// why both operands must be square. The two dimensions are dispatched through
/// `cartesian_match!` against a list of common sizes (2, 3, 4, 6, 8, 9, 16, 27,
/// 32, 64, 81) or anything else (`_`). Presumably this lets the inner kernel be
/// instantiated with constant sizes for loop unrolling; confirm against the macro
/// definition.
///
/// # Safety
///
/// * The dimensions of `dst` must be at least `lhs.nrows() * rhs.nrows()` by `lhs.ncols() * rhs.ncols()`.
/// * The matrices must be square.
///
/// # See also
///
/// * [`kron`] for a safe version of this function.
///
pub unsafe fn kron_sq_unchecked<C: ComplexField>(dst: MatMut<C>, lhs: MatRef<C>, rhs: MatRef<C>) {
    unsafe {
        let lhs_dim = lhs.nrows();
        let rhs_dim = rhs.nrows();

        // SAFETY: the caller guarantees both operands are square, so `lhs_dim` and
        // `rhs_dim` cover every row and column, and that `dst` is large enough.
        cartesian_match!(
            { kron_kernel(dst, lhs, rhs, lhs_dim, lhs_dim, rhs_dim, rhs_dim) },
            (lhs_dim, (rhs_dim, ())),
            (
                (2, 3, 4, 6, 8, 9, 16, 27, 32, 64, 81, _),
                ((2, 3, 4, 6, 8, 9, 16, 27, 32, 64, 81, _), ())
            )
        );
    }
}
248
249/// Kronecker product of two matrices.
250///
251/// The Kronecker product of two matrices `A` and `B` is a block matrix
252/// `C` with the following structure:
253///
254/// ```text
255/// C = [ a00 * B, a01 * B, ..., a0n * B ]
256/// [ a10 * B, a11 * B, ..., a1n * B ]
257/// [ ... , ... , ..., ... ]
258/// [ am0 * B, am1 * B, ..., amn * B ]
259/// ```
260/// where `a_ij` is the element at position `(i, j)` of `A`.
261///
262/// # Panics
263///
264/// * If `dst` does not have the correct dimensions. The dimensions
265/// of `dst` must be `nrows(A) * nrows(B)` by `ncols(A) * ncols(B)`.
266///
267/// # Example
268/// ```
269/// use faer::mat;
270/// use faer::Mat;
271/// use qudit_core::accel::kron;
272///
273/// let a = mat![
274/// [1.0, 2.0],
275/// [3.0, 4.0],
276/// ];
277/// let b = mat![
278/// [0.0, 5.0],
279/// [6.0, 7.0],
280/// ];
281/// let c = mat![
282/// [0.0 , 5.0 , 0.0 , 10.0],
283/// [6.0 , 7.0 , 12.0, 14.0],
284/// [0.0 , 15.0, 0.0 , 20.0],
285/// [18.0, 21.0, 24.0, 28.0],
286/// ];
287///
288/// let mut dst = Mat::new();
289/// dst.resize_with(4, 4, |_, _| 0f64);
290///
291/// kron(a.as_ref(), b.as_ref(), dst.as_mut());
292///
293/// assert_eq!(dst, c);
294/// ```
295///
296pub fn kron<C: ComplexField>(lhs: MatRef<C>, rhs: MatRef<C>, dst: MatMut<C>) {
297 let mut lhs = lhs;
298 let mut rhs = rhs;
299 let mut dst = dst;
300
301 // Ensures that `dst` is in column-major order.
302 if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() {
303 dst = dst.transpose_mut();
304 lhs = lhs.transpose();
305 rhs = rhs.transpose();
306 }
307
308 // Checks that the dimensions of `dst` matches the expected dimensions of the Kronecker product of
309 // `lhs` and `rhs`. Also checks that no overflow occurs during the multiplication.
310 assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows()));
311 assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols()));
312
313 // Uses a specialized kernel for square matrices if both `lhs` and `rhs` are square.
314 if lhs.nrows() == lhs.ncols() && rhs.nrows() == rhs.ncols() {
315 // Safety: The dimensions have been checked.
316 unsafe { kron_sq_unchecked(dst, lhs, rhs) }
317 } else {
318 // Safety: The dimensions have been checked.
319 unsafe { kron_unchecked(dst, lhs, rhs) }
320 }
321}
322
323/// Computes the Kronecker product of two matrices and adds the result to a destination matrix.
324///
325/// For `A` ∈ M(R_a, C_a), `B` ∈ M(R_b, C_b), `C` ∈ M(R_a * R_b, C_a * C_b), this function mutates `C`
326/// such C_{i * R_b + k , j * C_b + l} -> C_{i * R_b + k , j * C_b + l} + A_{i, j} * B_{k, l}.
327///
328/// # Arguments
329///
330/// * `lhs` - The left hand-side matrix for the kronecker product. `A` in the description above.
331/// * `rhs` - The right hand-side matrix for the kronecker product. `B` in the description above.
332/// * `dst` - The matrix to be summed (mutated) by the kronercker product of `lhs` and `rhs`.
333/// `C` in the description above.
334///
335/// # Panics
336///
337/// * If `dst.nrows()` doesn't match `lhs.nrows()` times `rhs.nrows()`
338/// * If `dst.ncols()` doesn't match `lhs.ncols()` times `rhs.ncols()`
339/// * If an overflow occurs when calculating the expected dimensions.
340///
341/// # Example
342/// ```
343/// use faer::{mat, Mat};
344/// use qudit_core::accel::kron_add;
345/// use qudit_core::c64;
346///
347/// let mut dst = Mat::<c64>::zeros(4, 4);
348///
349/// let lhs = Mat::<c64>::from_fn(2, 2, |i, j| -> c64 {c64::new((2*i+1+j) as f64, 0.0)});
350/// let rhs = Mat::<c64>::from_fn(2, 2, |i, j| -> c64 {c64::new((2*i+5+j) as f64, 0.0)});
351///
352/// kron_add(lhs.as_ref(), rhs.as_ref(), dst.as_mut());
353///
354/// let expected_data = [
355/// [c64::new(5.0, 0.0), c64::new(6.0, 0.0), c64::new(10.0, 0.0), c64::new(12.0, 0.0)],
356/// [c64::new(7.0, 0.0), c64::new(8.0, 0.0), c64::new(14.0, 0.0), c64::new(16.0, 0.0)],
357/// [c64::new(15.0, 0.0), c64::new(18.0, 0.0), c64::new(20.0, 0.0), c64::new(24.0, 0.0)],
358/// [c64::new(21.0, 0.0), c64::new(24.0, 0.0), c64::new(28.0, 0.0), c64::new(32.0, 0.0)]
359/// ];
360/// let expected = Mat::from_fn(4, 4, |i, j| -> c64 {expected_data[i][j]});
361/// assert_eq!(dst, expected);
362/// ```
363///
364pub fn kron_add<C: ComplexScalar>(lhs: MatRef<C>, rhs: MatRef<C>, dst: MatMut<C>) {
365 let mut lhs = lhs;
366 let mut rhs = rhs;
367 let mut dst = dst;
368
369 // Makes `dst` is in column-major order. To maintain the same computation, we transpose `lhs` and `rhs` as well.
370 // This is allowed because the transpose is distributive over the Kronecker product and addition. Notice we are
371 // transposing input views, so we need not re-transpose our matrices after mutating the underlying data of dst.
372 if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() {
373 dst = dst.transpose_mut();
374 lhs = lhs.transpose();
375 rhs = rhs.transpose();
376 }
377 // Makes sure the dimesion of the Kronecker product between lhs, rhs matches that of dst.
378 // Recall (F_r, F_c) = (D_r * E_r, D_c * E_c) where F = D (x) E.
379 // Also makes sure overflows do not occur during the multiplications.
380 assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows()));
381 assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols()));
382
383 // Performs the actual Kronecker product followed by sum.
384 unsafe {
385 kron_kernel_add(
386 dst,
387 lhs,
388 rhs,
389 lhs.nrows(),
390 lhs.ncols(),
391 rhs.nrows(),
392 rhs.ncols(),
393 );
394 }
395}
396
#[cfg(test)]
mod kron_tests {
    use super::*;
    use crate::c64;
    use faer::Mat;
    use faer::mat;

    /// Builds a complex matrix from row-major `(re, im)` pairs.
    fn complex_mat<const NR: usize, const NC: usize>(rows: [[(f64, f64); NC]; NR]) -> Mat<c64> {
        Mat::from_fn(NR, NC, |i, j| {
            let (re, im) = rows[i][j];
            c64::new(re, im)
        })
    }

    #[test]
    fn kron_add_test() {
        let mut dst = complex_mat([
            [(1.0, -8.0), (2.0, 67.0), (3.0, 0.0), (4.0, 0.0)],
            [(5.0, 0.0), (6.0, 0.0), (7.0, 0.0), (8.0, 0.0)],
            [(9.0, 0.0), (10.0, 0.0), (11.0, 0.0), (12.0, 0.0)],
            [(13.0, 0.0), (14.0, 0.0), (15.0, 0.0), (16.0, 0.0)],
        ]);
        let lhs = complex_mat([[(1.0, 9.0), (2.0, 0.0)], [(3.0, 0.0), (4.0, 0.0)]]);
        let rhs = complex_mat([[(5.0, -8.0), (6.0, 0.0)], [(7.0, 0.0), (8.0, 0.0)]]);

        kron_add(lhs.as_ref(), rhs.as_ref(), dst.as_mut());

        let expected = complex_mat([
            [(78.0, 29.0), (8.0, 121.0), (13.0, -16.0), (16.0, 0.0)],
            [(12.0, 63.0), (14.0, 72.0), (21.0, 0.0), (24.0, 0.0)],
            [(24.0, -24.0), (28.0, 0.0), (31.0, -32.0), (36.0, 0.0)],
            [(34.0, 0.0), (38.0, 0.0), (43.0, 0.0), (48.0, 0.0)],
        ]);
        assert_eq!(dst, expected);
    }

    #[test]
    fn kron_add_row_major_dst() {
        let lhs = complex_mat([[(1.0, 0.0), (2.0, 0.0)], [(3.0, 0.0), (4.0, 0.0)]]);
        let rhs = complex_mat([[(0.0, 0.0), (5.0, 0.0)], [(6.0, 0.0), (7.0, 0.0)]]);
        let mut storage = Mat::<c64>::zeros(4, 4);

        // A transposed view of column-major storage is row-major, which exercises the
        // transpose path; the storage therefore ends up holding (lhs ⊗ rhs)^T.
        kron_add(lhs.as_ref(), rhs.as_ref(), storage.as_mut().transpose_mut());

        let expected = complex_mat([
            [(0.0, 0.0), (6.0, 0.0), (0.0, 0.0), (18.0, 0.0)],
            [(5.0, 0.0), (7.0, 0.0), (15.0, 0.0), (21.0, 0.0)],
            [(0.0, 0.0), (12.0, 0.0), (0.0, 0.0), (24.0, 0.0)],
            [(10.0, 0.0), (14.0, 0.0), (20.0, 0.0), (28.0, 0.0)],
        ]);
        assert_eq!(storage, expected);
    }

    #[test]
    fn kron_test() {
        let a = mat![[1.0, 2.0], [3.0, 4.0],];
        let b = mat![[0.0, 5.0], [6.0, 7.0],];
        let c = mat![
            [0.0, 5.0, 0.0, 10.0],
            [6.0, 7.0, 12.0, 14.0],
            [0.0, 15.0, 0.0, 20.0],
            [18.0, 21.0, 24.0, 28.0],
        ];
        let mut dst = Mat::new();
        dst.resize_with(4, 4, |_, _| 0f64);
        kron(a.as_ref(), b.as_ref(), dst.as_mut());
        assert_eq!(dst, c);
    }

    #[test]
    fn kron_row_major_dst() {
        let a = mat![[1.0, 2.0], [3.0, 4.0],];
        let b = mat![[0.0, 5.0], [6.0, 7.0],];
        let c = mat![
            [0.0, 5.0, 0.0, 10.0],
            [6.0, 7.0, 12.0, 14.0],
            [0.0, 15.0, 0.0, 20.0],
            [18.0, 21.0, 24.0, 28.0],
        ];
        let mut storage = Mat::<f64>::zeros(4, 4);

        // Writing through a row-major (transposed) view leaves c^T in the storage.
        kron(a.as_ref(), b.as_ref(), storage.as_mut().transpose_mut());

        let expected = Mat::from_fn(4, 4, |i, j| c[(j, i)]);
        assert_eq!(storage, expected);
    }

    #[test]
    fn kron_non_square() {
        let a = mat![[1.0, 2.0]];
        let b = mat![[3.0], [4.0]];
        let mut dst = Mat::<f64>::zeros(2, 2);
        kron(a.as_ref(), b.as_ref(), dst.as_mut());
        assert_eq!(dst, mat![[3.0, 6.0], [4.0, 8.0]]);
    }

    #[test]
    #[should_panic]
    fn kron_rejects_wrong_dst_shape() {
        let a = mat![[1.0, 2.0], [3.0, 4.0]];
        let mut dst = Mat::<f64>::zeros(3, 4);
        kron(a.as_ref(), a.as_ref(), dst.as_mut());
    }
}