1use nalgebra::{Dynamic, Matrix, VecStorage, U1};
15use std::ops::{AddAssign, Mul, MulAssign};
16
17pub type Complex = nalgebra::Complex<f64>;
18pub type MatrixXc = Matrix<Complex, Dynamic, Dynamic, VecStorage<Complex, Dynamic, Dynamic>>;
19pub type MatrixX = Matrix<f64, Dynamic, Dynamic, VecStorage<f64, Dynamic, Dynamic>>;
20pub type VectorXc = Matrix<Complex, Dynamic, U1, VecStorage<Complex, Dynamic, U1>>;
21pub type VectorX = Matrix<f64, Dynamic, U1, VecStorage<f64, Dynamic, U1>>;
22
/// Operation applied to a matrix operand before it is multiplied, i.e. the
/// BLAS-style `op(X)` selector.
///
/// The discriminants are exactly the CBLAS `CBLAS_TRANSPOSE` values
/// (`CblasNoTrans = 111` … `CblasConjNoTrans = 114`). Presumably this is so that a
/// BLAS-backed implementation can pass `value as i32` straight through.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// `op(X) = X`
    NoTrans = 111,
    /// `op(X) = Xᵀ` (plain transpose)
    Trans = 112,
    /// `op(X) = Xᴴ` (conjugate transpose)
    ConjTrans = 113,
    /// `op(X) = conj(X)` (element-wise conjugate, no transpose)
    ConjNoTrans = 114,
}
29
/// Dense linear-algebra primitives over the `MatrixX` / `MatrixXc` / `VectorX` /
/// `VectorXc` aliases, abstracted behind a trait so that the numerical backend can
/// be swapped.
///
/// Several methods follow BLAS conventions: an output argument updated in place,
/// `alpha` / `beta` scale factors, and a [`Transpose`] operand selector. Each method
/// note below describes what `NalgebraBackend` does. Other implementations are
/// presumably expected to match — confirm against any additional backends.
pub trait Backend {
    /// Element-wise (Hadamard) product; `NalgebraBackend` overwrites `c` with `a ∘ b`.
    fn hadamard_product(a: &MatrixXc, b: &MatrixXc, c: &mut MatrixXc);
    /// Writes the real part of every element of `a` into `b`.
    fn real(a: &MatrixXc, b: &mut MatrixX);
    /// Writes the imaginary part of every element of `a` into `b`.
    fn imag(a: &VectorXc, b: &mut VectorX);
    /// Regularised pseudo-inverse via SVD, written to `result`.
    /// `NalgebraBackend` inverts each singular value `s` as `s / (s² + alpha²)`.
    fn pseudo_inverse_svd(matrix: MatrixXc, alpha: f64, result: &mut MatrixXc);
    /// Eigenvector belonging to the largest eigenvalue of `matrix`.
    /// `NalgebraBackend` uses a symmetric/Hermitian eigensolver, so a Hermitian input
    /// is presumably expected.
    fn max_eigen_vector(matrix: MatrixXc) -> VectorXc;
    /// In-place update `b = alpha * a + beta * b`.
    fn matrix_add(alpha: f64, a: &MatrixX, beta: f64, b: &mut MatrixX);
    /// GEMM-style update `c = alpha * op(a) * op(b) + beta * c`, where each `op` is
    /// selected by `trans_a` / `trans_b`.
    fn matrix_mul(
        trans_a: Transpose,
        trans_b: Transpose,
        alpha: Complex,
        a: &MatrixXc,
        b: &MatrixXc,
        beta: Complex,
        c: &mut MatrixXc,
    );
    /// GEMV-style update `c = alpha * op(a) * b + beta * c`.
    fn matrix_mul_vec(
        trans_a: Transpose,
        alpha: Complex,
        a: &MatrixXc,
        b: &VectorXc,
        beta: Complex,
        c: &mut VectorXc,
    );
    /// In-place update `b = alpha * a + b`.
    fn vector_add(alpha: f64, a: &VectorX, b: &mut VectorX);
    /// Solves `a * x = b`, leaving `x` in `b`; returns `false` when the system could
    /// not be solved.
    /// NOTE(review): `NalgebraBackend` solves this with a QR decomposition, not a
    /// Cholesky factorisation as the `_ch` suffix suggests.
    fn solve_ch(a: MatrixXc, b: &mut VectorXc) -> bool;
    /// Real counterpart of `solve_ch` (`NalgebraBackend` also uses QR).
    fn solve_g(a: MatrixX, b: &mut VectorX) -> bool;
    /// Dot product of two real vectors.
    fn dot(a: &VectorX, b: &VectorX) -> f64;
    /// Dot product of two complex vectors.
    /// NOTE(review): `NalgebraBackend` uses nalgebra's unconjugated `dot` (Σ aᵢ bᵢ),
    /// not the Hermitian Σ conj(aᵢ) bᵢ — confirm which one callers expect.
    fn dot_c(a: &VectorXc, b: &VectorXc) -> Complex;
    /// Largest coefficient of `a`; `NalgebraBackend` returns the largest absolute value.
    fn max_coefficient(a: &VectorX) -> f64;
    /// Largest element of `a` by absolute value, as a real number.
    /// NOTE(review): which complex norm is used (modulus vs. `|re| + |im|`) depends
    /// on the implementation — confirm.
    fn max_coefficient_c(a: &VectorXc) -> f64;
    /// Returns `a` with the rows of `b` appended below it.
    fn concat_row(a: MatrixXc, b: &MatrixXc) -> MatrixXc;
    /// Returns `a` with the columns of `b` appended to its right.
    fn concat_col(a: MatrixXc, b: &MatrixXc) -> MatrixXc;
}
64
/// `Backend` implementation built on nalgebra's dense matrices and decompositions.
///
/// The type carries no state. It only selects which implementation of the trait's
/// associated functions is used, so it is freely copyable and default-constructible.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NalgebraBackend {}
66
67impl Backend for NalgebraBackend {
68 fn hadamard_product(a: &MatrixXc, b: &MatrixXc, c: &mut MatrixXc) {
69 *c = a.component_mul(b);
70 }
71
72 fn real(a: &MatrixXc, b: &mut MatrixX) {
73 *b = a.map(|x| x.re);
74 }
75
76 fn imag(a: &VectorXc, b: &mut VectorX) {
77 *b = a.map(|x| x.im);
78 }
79
80 fn pseudo_inverse_svd(matrix: MatrixXc, alpha: f64, result: &mut MatrixXc) {
81 let svd = matrix.svd(true, true);
82 let s_inv = MatrixXc::from_diagonal(
83 &svd.singular_values
84 .map(|s| Complex::new(s / (s * s + alpha * alpha), 0.)),
85 );
86 *result = match (&svd.v_t, &svd.u) {
87 (Some(v_t), Some(u)) => v_t.adjoint() * s_inv * u.adjoint(),
88 _ => unreachable!(),
89 };
90 }
91
92 fn max_eigen_vector(matrix: MatrixXc) -> VectorXc {
93 let eig = nalgebra::SymmetricEigen::new(matrix);
94 eig.eigenvectors.column(eig.eigenvalues.imax()).into()
95 }
96
97 fn matrix_add(alpha: f64, a: &MatrixX, beta: f64, b: &mut MatrixX) {
98 b.mul_assign(beta);
99 b.add_assign(a.mul(alpha));
100 }
101
102 fn matrix_mul(
103 trans_a: Transpose,
104 trans_b: Transpose,
105 alpha: Complex,
106 a: &MatrixXc,
107 b: &MatrixXc,
108 beta: Complex,
109 c: &mut MatrixXc,
110 ) {
111 c.mul_assign(beta);
112 match (trans_a, trans_b) {
113 (Transpose::NoTrans, Transpose::NoTrans) => c.add_assign(a.mul(b).mul(alpha)),
114 (Transpose::NoTrans, Transpose::Trans) => c.add_assign(a.mul(b.transpose()).mul(alpha)),
115 (Transpose::NoTrans, Transpose::ConjTrans) => {
116 c.add_assign(a.mul(b.adjoint()).mul(alpha))
117 }
118 (Transpose::NoTrans, Transpose::ConjNoTrans) => {
119 c.add_assign(a.mul(b.conjugate()).mul(alpha))
120 }
121 (Transpose::Trans, Transpose::NoTrans) => c.add_assign(a.transpose().mul(b).mul(alpha)),
122 (Transpose::Trans, Transpose::Trans) => {
123 c.add_assign(a.transpose().mul(b.transpose()).mul(alpha))
124 }
125 (Transpose::Trans, Transpose::ConjTrans) => {
126 c.add_assign(a.transpose().mul(b.adjoint()).mul(alpha))
127 }
128 (Transpose::Trans, Transpose::ConjNoTrans) => {
129 c.add_assign(a.transpose().mul(b.conjugate()).mul(alpha))
130 }
131 (Transpose::ConjTrans, Transpose::NoTrans) => {
132 c.add_assign(a.adjoint().mul(b).mul(alpha))
133 }
134 (Transpose::ConjTrans, Transpose::Trans) => {
135 c.add_assign(a.adjoint().mul(b.transpose()).mul(alpha))
136 }
137 (Transpose::ConjTrans, Transpose::ConjTrans) => {
138 c.add_assign(a.adjoint().mul(b.adjoint()).mul(alpha))
139 }
140 (Transpose::ConjTrans, Transpose::ConjNoTrans) => {
141 c.add_assign(a.adjoint().mul(b.conjugate()).mul(alpha))
142 }
143 (Transpose::ConjNoTrans, Transpose::NoTrans) => {
144 c.add_assign(a.conjugate().mul(b).mul(alpha))
145 }
146 (Transpose::ConjNoTrans, Transpose::Trans) => {
147 c.add_assign(a.conjugate().mul(b.transpose()).mul(alpha))
148 }
149 (Transpose::ConjNoTrans, Transpose::ConjTrans) => {
150 c.add_assign(a.conjugate().mul(b.adjoint()).mul(alpha))
151 }
152 (Transpose::ConjNoTrans, Transpose::ConjNoTrans) => {
153 c.add_assign(a.conjugate().mul(b.conjugate()).mul(alpha))
154 }
155 };
156 }
157
158 fn matrix_mul_vec(
159 trans_a: Transpose,
160 alpha: Complex,
161 a: &MatrixXc,
162 b: &VectorXc,
163 beta: Complex,
164 c: &mut VectorXc,
165 ) {
166 c.mul_assign(beta);
167 match trans_a {
168 Transpose::NoTrans => c.add_assign(a.mul(b).mul(alpha)),
169 Transpose::Trans => c.add_assign(a.transpose().mul(b).mul(alpha)),
170 Transpose::ConjTrans => c.add_assign(a.adjoint().mul(b).mul(alpha)),
171 Transpose::ConjNoTrans => c.add_assign(a.conjugate().mul(b).mul(alpha)),
172 };
173 }
174
175 fn vector_add(alpha: f64, a: &VectorX, b: &mut VectorX) {
176 b.add_assign(a.mul(alpha));
177 }
178
179 fn solve_ch(a: MatrixXc, b: &mut VectorXc) -> bool {
180 a.qr().solve_mut(b)
181 }
182
183 fn solve_g(a: MatrixX, b: &mut VectorX) -> bool {
184 a.qr().solve_mut(b)
185 }
186
187 fn dot(a: &VectorX, b: &VectorX) -> f64 {
188 a.dot(b)
189 }
190
191 fn dot_c(a: &VectorXc, b: &VectorXc) -> Complex {
192 a.dot(b)
193 }
194
195 fn max_coefficient(a: &VectorX) -> f64 {
196 a.camax()
197 }
198
199 fn max_coefficient_c(a: &VectorXc) -> f64 {
200 a.camax()
201 }
202
203 fn concat_row(a: MatrixXc, b: &MatrixXc) -> MatrixXc {
204 let arows = a.nrows();
205 let acols = a.ncols();
206 let mut new_mat = a.resize(arows + b.nrows(), acols, Default::default());
207 new_mat
208 .slice_mut((arows, 0), (b.nrows(), b.ncols()))
209 .copy_from(b);
210
211 new_mat
212 }
213
214 fn concat_col(a: MatrixXc, b: &MatrixXc) -> MatrixXc {
215 let arows = a.nrows();
216 let acols = a.ncols();
217 let mut new_mat = a.resize(arows, acols + b.ncols(), Default::default());
218 new_mat
219 .slice_mut((0, acols), (b.nrows(), b.ncols()))
220 .copy_from(b);
221
222 new_mat
223 }
224}