opensrdk_linear_algebra/matrix/kr/
mod.rs1use serde::Deserialize;
2use serde::Serialize;
3
4use crate::matrix::ge::*;
5use crate::matrix::MatrixError;
6use crate::number::Number;
7
8#[derive(Clone, Debug, Default, PartialEq, Hash, Serialize, Deserialize)]
9pub struct KroneckerMatrices<T = f64>
10where
11 T: Number,
12{
13 matrices: Vec<Matrix<T>>,
14 rows: usize,
15 cols: usize,
16}
17
18impl<T> KroneckerMatrices<T>
19where
20 T: Number,
21{
22 pub fn new(matrices: Vec<Matrix<T>>) -> Self {
31 let (rows, cols) = matrices
32 .iter()
33 .fold((1usize, 1usize), |v, m| (v.0 * m.rows(), v.1 * m.cols()));
34 Self {
35 matrices,
36 rows,
37 cols,
38 }
39 }
40
41 pub fn matrices(&self) -> &[Matrix<T>] {
42 &self.matrices
43 }
44
45 pub fn rows(&self) -> usize {
46 self.rows
47 }
48
49 pub fn cols(&self) -> usize {
50 self.cols
51 }
52
53 pub fn eject(self) -> Vec<Matrix<T>> {
54 self.matrices
55 }
56
57 pub fn prod(&self) -> Matrix<T> {
58 let bigp = self.matrices.len();
59 let rows = self.matrices[0].rows();
60 let cols = self.matrices[0].cols();
61 let elems_row = (0..rows.pow(bigp as u32))
62 .into_iter()
63 .map(|j| {
64 (0..cols.pow(bigp as u32))
65 .into_iter()
66 .map(|i| {
67 let elem_a = (0..bigp - 1)
68 .into_iter()
69 .map(|p| {
70 let k = bigp - 1 - p;
71 let row =
72 ((j - (j % rows.pow(k as u32))) / rows.pow(k as u32)) % rows;
73 let col =
74 ((i - (i % cols.pow(k as u32))) / cols.pow(k as u32)) % cols;
75 self.matrices[p][(col, row)]
76 })
77 .product::<T>();
78 let elem_b = self.matrices[bigp - 1][(i % cols, j % rows)];
79 elem_a * elem_b
80 })
81 .collect::<Vec<T>>()
82 })
83 .collect::<Vec<Vec<T>>>()
84 .concat();
85 let result = Matrix::from(rows.pow(bigp as u32), elems_row).unwrap();
86 result
87 }
88}
89
90impl KroneckerMatrices {
91 pub fn vec_mul(&self, v: Vec<f64>) -> Result<Vec<f64>, MatrixError> {
92 let n = v.len();
93
94 if self.cols != n {
95 return Err(MatrixError::DimensionMismatch);
96 }
97
98 let u = self.prod().dot(&v.col_mat());
99 Ok(u.vec())
100 }
101}
102
#[cfg(test)]
mod tests {
    use crate::*;

    #[test]
    fn it_works() {
        let a = mat![
            1.0, 2.0;
            3.0, 4.0
        ];
        let b = mat![
            1.0, 2.0;
            3.0, 4.0
        ];
        let ab = KroneckerMatrices::new(vec![a, b]);
        let c = ab.prod();

        // Every entry of a ⊗ b, listed row by row.
        let expected = [
            [1.0, 2.0, 2.0, 4.0],
            [3.0, 4.0, 6.0, 8.0],
            [3.0, 6.0, 4.0, 8.0],
            [9.0, 12.0, 12.0, 16.0],
        ];
        for (i, row) in expected.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(c[(i, j)], value, "entry ({}, {})", i, j);
            }
        }

        // Multiplying by the all-ones vector yields the row sums of a ⊗ b.
        let ones = vec![1.0; 4];
        let ab1 = ab.vec_mul(ones.clone()).unwrap();
        let c1 = c.dot(&ones.col_mat()).vec();
        assert_eq!(ab1, c1);
        assert_eq!(ab1, vec![9.0, 21.0, 21.0, 49.0]);
    }
}