// opensrdk_linear_algebra/matrix/kr/mod.rs

1use serde::Deserialize;
2use serde::Serialize;
3
4use crate::matrix::ge::*;
5use crate::matrix::MatrixError;
6use crate::number::Number;
7
8#[derive(Clone, Debug, Default, PartialEq, Hash, Serialize, Deserialize)]
9pub struct KroneckerMatrices<T = f64>
10where
11    T: Number,
12{
13    matrices: Vec<Matrix<T>>,
14    rows: usize,
15    cols: usize,
16}
17
18impl<T> KroneckerMatrices<T>
19where
20    T: Number,
21{
22    /// The code below means that `a = a_1 ⊗ a_2`
23    /// ```
24    /// use opensrdk_linear_algebra::*;
25    ///
26    /// let a_1 = Matrix::<f64>::new(2, 2);
27    /// let a_2 = Matrix::<f64>::new(3, 4);
28    /// let a = KroneckerMatrices::new(vec![a_1, a_2]);
29    /// ```
30    pub fn new(matrices: Vec<Matrix<T>>) -> Self {
31        let (rows, cols) = matrices
32            .iter()
33            .fold((1usize, 1usize), |v, m| (v.0 * m.rows(), v.1 * m.cols()));
34        Self {
35            matrices,
36            rows,
37            cols,
38        }
39    }
40
41    pub fn matrices(&self) -> &[Matrix<T>] {
42        &self.matrices
43    }
44
45    pub fn rows(&self) -> usize {
46        self.rows
47    }
48
49    pub fn cols(&self) -> usize {
50        self.cols
51    }
52
53    pub fn eject(self) -> Vec<Matrix<T>> {
54        self.matrices
55    }
56
57    pub fn prod(&self) -> Matrix<T> {
58        let bigp = self.matrices.len();
59        let rows = self.matrices[0].rows();
60        let cols = self.matrices[0].cols();
61        let elems_row = (0..rows.pow(bigp as u32))
62            .into_iter()
63            .map(|j| {
64                (0..cols.pow(bigp as u32))
65                    .into_iter()
66                    .map(|i| {
67                        let elem_a = (0..bigp - 1)
68                            .into_iter()
69                            .map(|p| {
70                                let k = bigp - 1 - p;
71                                let row =
72                                    ((j - (j % rows.pow(k as u32))) / rows.pow(k as u32)) % rows;
73                                let col =
74                                    ((i - (i % cols.pow(k as u32))) / cols.pow(k as u32)) % cols;
75                                self.matrices[p][(col, row)]
76                            })
77                            .product::<T>();
78                        let elem_b = self.matrices[bigp - 1][(i % cols, j % rows)];
79                        elem_a * elem_b
80                    })
81                    .collect::<Vec<T>>()
82            })
83            .collect::<Vec<Vec<T>>>()
84            .concat();
85        let result = Matrix::from(rows.pow(bigp as u32), elems_row).unwrap();
86        result
87    }
88}
89
90impl KroneckerMatrices {
91    pub fn vec_mul(&self, v: Vec<f64>) -> Result<Vec<f64>, MatrixError> {
92        let n = v.len();
93
94        if self.cols != n {
95            return Err(MatrixError::DimensionMismatch);
96        }
97
98        let u = self.prod().dot(&v.col_mat());
99        Ok(u.vec())
100    }
101}
102
#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn it_works() {
        let a = mat![
            1.0, 2.0;
            3.0, 4.0
        ];
        let b = mat![
            1.0, 2.0;
            3.0, 4.0
        ];
        let ab = KroneckerMatrices::new(vec![a, b]);
        let c = ab.prod();

        assert_eq!(c.rows(), 4);
        assert_eq!(c.cols(), 4);
        assert_eq!(c[(0, 0)], 1.0);
        assert_eq!(c[(0, 3)], 4.0);
        assert_eq!(c[(2, 1)], 6.0);

        // The structured mat-vec must agree with multiplying by the dense product on
        // every entry, not just the first two. The inputs are small integers, so
        // exact equality is safe.
        let ab1 = ab.vec_mul(vec![1.0; 4]).unwrap();
        let c1 = c.dot(&vec![1.0; 4].col_mat()).vec();
        assert_eq!(ab1, c1);
    }
}
133}