array_matrix/matrix/kronecker.rs
1use std::ops::Mul;
2
3use crate::{matrix_init, Matrix};
4
5pub trait KroneckerMul<Rhs>: Matrix
6where
7 Rhs: Matrix,
8 Self::Output: Matrix
9{
10 type Output;
11
12 /// Returns the kronecker product of the two matrices
13 ///
14 /// A ⊗ₖᵣₒₙ B
15 ///
16 /// # Arguments
17 ///
18 /// * `rhs` - A matrix of any size
19 ///
20 /// # Examples
21 ///
22 /// ```rust
23 /// let a = [
24 /// [1.0, 2.0],
25 /// [3.0, 4.0]
26 /// ];
27 /// let b = [
28 /// [1.0, 2.0],
29 /// [3.0, 4.0]
30 /// ];
31 /// let ab = [
32 /// [1.0, 2.0, 2.0, 4.0],
33 /// [3.0, 4.0, 6.0, 8.0],
34 /// [3.0, 6.0, 4.0, 8.0],
35 /// [9.0, 12.0, 12.0, 16.0]
36 /// ];
37 /// assert_eq!(a.kronecker(b), ab);
38 /// ```
39 fn kronecker_mul(self, rhs: Rhs) -> Self::Output;
40}
41
42impl<F, const L1: usize, const H1: usize, const L2: usize, const H2: usize>
43 KroneckerMul<[[F; L2]; H2]>
44for
45 [[F; L1]; H1]
46where
47 Self: Matrix,
48 [[F; L2]; H2]: Matrix,
49 F: Clone + Mul<F>,
50 [[<F as Mul<F>>::Output; L1*L2]; H1*H2]: Matrix
51{
52 type Output = [[<F as Mul<F>>::Output; L1*L2]; H1*H2];
53
54 fn kronecker_mul(self, rhs: [[F; L2]; H2]) -> Self::Output
55 {
56 matrix_init(|r, c| self[r/H1][c/L1].clone()*rhs[r%H2][c%L2].clone())
57 }
58}