autd3_holo_gain/
backend.rs

/*
 * File: backend.rs
 * Project: src
 * Created Date: 28/05/2021
 * Author: Shun Suzuki
 * -----
 * Last Modified: 30/05/2021
 * Modified By: Shun Suzuki (suzuki@hapis.k.u-tokyo.ac.jp)
 * -----
 * Copyright (c) 2021 Hapis Lab. All rights reserved.
 *
 */
13
14use nalgebra::{Dynamic, Matrix, VecStorage, U1};
15use std::ops::{AddAssign, Mul, MulAssign};
16
17pub type Complex = nalgebra::Complex<f64>;
18pub type MatrixXc = Matrix<Complex, Dynamic, Dynamic, VecStorage<Complex, Dynamic, Dynamic>>;
19pub type MatrixX = Matrix<f64, Dynamic, Dynamic, VecStorage<f64, Dynamic, Dynamic>>;
20pub type VectorXc = Matrix<Complex, Dynamic, U1, VecStorage<Complex, Dynamic, U1>>;
21pub type VectorX = Matrix<f64, Dynamic, U1, VecStorage<f64, Dynamic, U1>>;
22
/// Operation applied to a matrix operand before it is multiplied.
///
/// The explicit discriminants equal the CBLAS `CBLAS_TRANSPOSE` values
/// (`CblasNoTrans = 111`, `CblasTrans = 112`, `CblasConjTrans = 113`, and the
/// `CblasConjNoTrans = 114` extension), presumably so a variant can be cast with
/// `as` and handed to a CBLAS-based backend unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transpose {
    /// Use the operand as is: `op(A) = A`.
    NoTrans = 111,
    /// Transpose without conjugation: `op(A) = Aᵀ`.
    Trans = 112,
    /// Conjugate transpose (adjoint): `op(A) = Aᴴ`.
    ConjTrans = 113,
    /// Element-wise conjugate without transposing: `op(A) = conj(A)`.
    ConjNoTrans = 114,
}
29
30pub trait Backend {
31    fn hadamard_product(a: &MatrixXc, b: &MatrixXc, c: &mut MatrixXc);
32    fn real(a: &MatrixXc, b: &mut MatrixX);
33    fn imag(a: &VectorXc, b: &mut VectorX);
34    fn pseudo_inverse_svd(matrix: MatrixXc, alpha: f64, result: &mut MatrixXc);
35    fn max_eigen_vector(matrix: MatrixXc) -> VectorXc;
36    fn matrix_add(alpha: f64, a: &MatrixX, beta: f64, b: &mut MatrixX);
37    fn matrix_mul(
38        trans_a: Transpose,
39        trans_b: Transpose,
40        alpha: Complex,
41        a: &MatrixXc,
42        b: &MatrixXc,
43        beta: Complex,
44        c: &mut MatrixXc,
45    );
46    fn matrix_mul_vec(
47        trans_a: Transpose,
48        alpha: Complex,
49        a: &MatrixXc,
50        b: &VectorXc,
51        beta: Complex,
52        c: &mut VectorXc,
53    );
54    fn vector_add(alpha: f64, a: &VectorX, b: &mut VectorX);
55    fn solve_ch(a: MatrixXc, b: &mut VectorXc) -> bool;
56    fn solve_g(a: MatrixX, b: &mut VectorX) -> bool;
57    fn dot(a: &VectorX, b: &VectorX) -> f64;
58    fn dot_c(a: &VectorXc, b: &VectorXc) -> Complex;
59    fn max_coefficient(a: &VectorX) -> f64;
60    fn max_coefficient_c(a: &VectorXc) -> f64;
61    fn concat_row(a: MatrixXc, b: &MatrixXc) -> MatrixXc;
62    fn concat_col(a: MatrixXc, b: &MatrixXc) -> MatrixXc;
63}
64
/// Stateless [`Backend`] implementation built on nalgebra's pure-Rust routines.
///
/// All backend operations are associated functions, so this is a zero-sized unit
/// struct. Braced construction (`NalgebraBackend {}`) remains valid for existing callers.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NalgebraBackend;
66
67impl Backend for NalgebraBackend {
68    fn hadamard_product(a: &MatrixXc, b: &MatrixXc, c: &mut MatrixXc) {
69        *c = a.component_mul(b);
70    }
71
72    fn real(a: &MatrixXc, b: &mut MatrixX) {
73        *b = a.map(|x| x.re);
74    }
75
76    fn imag(a: &VectorXc, b: &mut VectorX) {
77        *b = a.map(|x| x.im);
78    }
79
80    fn pseudo_inverse_svd(matrix: MatrixXc, alpha: f64, result: &mut MatrixXc) {
81        let svd = matrix.svd(true, true);
82        let s_inv = MatrixXc::from_diagonal(
83            &svd.singular_values
84                .map(|s| Complex::new(s / (s * s + alpha * alpha), 0.)),
85        );
86        *result = match (&svd.v_t, &svd.u) {
87            (Some(v_t), Some(u)) => v_t.adjoint() * s_inv * u.adjoint(),
88            _ => unreachable!(),
89        };
90    }
91
92    fn max_eigen_vector(matrix: MatrixXc) -> VectorXc {
93        let eig = nalgebra::SymmetricEigen::new(matrix);
94        eig.eigenvectors.column(eig.eigenvalues.imax()).into()
95    }
96
97    fn matrix_add(alpha: f64, a: &MatrixX, beta: f64, b: &mut MatrixX) {
98        b.mul_assign(beta);
99        b.add_assign(a.mul(alpha));
100    }
101
102    fn matrix_mul(
103        trans_a: Transpose,
104        trans_b: Transpose,
105        alpha: Complex,
106        a: &MatrixXc,
107        b: &MatrixXc,
108        beta: Complex,
109        c: &mut MatrixXc,
110    ) {
111        c.mul_assign(beta);
112        match (trans_a, trans_b) {
113            (Transpose::NoTrans, Transpose::NoTrans) => c.add_assign(a.mul(b).mul(alpha)),
114            (Transpose::NoTrans, Transpose::Trans) => c.add_assign(a.mul(b.transpose()).mul(alpha)),
115            (Transpose::NoTrans, Transpose::ConjTrans) => {
116                c.add_assign(a.mul(b.adjoint()).mul(alpha))
117            }
118            (Transpose::NoTrans, Transpose::ConjNoTrans) => {
119                c.add_assign(a.mul(b.conjugate()).mul(alpha))
120            }
121            (Transpose::Trans, Transpose::NoTrans) => c.add_assign(a.transpose().mul(b).mul(alpha)),
122            (Transpose::Trans, Transpose::Trans) => {
123                c.add_assign(a.transpose().mul(b.transpose()).mul(alpha))
124            }
125            (Transpose::Trans, Transpose::ConjTrans) => {
126                c.add_assign(a.transpose().mul(b.adjoint()).mul(alpha))
127            }
128            (Transpose::Trans, Transpose::ConjNoTrans) => {
129                c.add_assign(a.transpose().mul(b.conjugate()).mul(alpha))
130            }
131            (Transpose::ConjTrans, Transpose::NoTrans) => {
132                c.add_assign(a.adjoint().mul(b).mul(alpha))
133            }
134            (Transpose::ConjTrans, Transpose::Trans) => {
135                c.add_assign(a.adjoint().mul(b.transpose()).mul(alpha))
136            }
137            (Transpose::ConjTrans, Transpose::ConjTrans) => {
138                c.add_assign(a.adjoint().mul(b.adjoint()).mul(alpha))
139            }
140            (Transpose::ConjTrans, Transpose::ConjNoTrans) => {
141                c.add_assign(a.adjoint().mul(b.conjugate()).mul(alpha))
142            }
143            (Transpose::ConjNoTrans, Transpose::NoTrans) => {
144                c.add_assign(a.conjugate().mul(b).mul(alpha))
145            }
146            (Transpose::ConjNoTrans, Transpose::Trans) => {
147                c.add_assign(a.conjugate().mul(b.transpose()).mul(alpha))
148            }
149            (Transpose::ConjNoTrans, Transpose::ConjTrans) => {
150                c.add_assign(a.conjugate().mul(b.adjoint()).mul(alpha))
151            }
152            (Transpose::ConjNoTrans, Transpose::ConjNoTrans) => {
153                c.add_assign(a.conjugate().mul(b.conjugate()).mul(alpha))
154            }
155        };
156    }
157
158    fn matrix_mul_vec(
159        trans_a: Transpose,
160        alpha: Complex,
161        a: &MatrixXc,
162        b: &VectorXc,
163        beta: Complex,
164        c: &mut VectorXc,
165    ) {
166        c.mul_assign(beta);
167        match trans_a {
168            Transpose::NoTrans => c.add_assign(a.mul(b).mul(alpha)),
169            Transpose::Trans => c.add_assign(a.transpose().mul(b).mul(alpha)),
170            Transpose::ConjTrans => c.add_assign(a.adjoint().mul(b).mul(alpha)),
171            Transpose::ConjNoTrans => c.add_assign(a.conjugate().mul(b).mul(alpha)),
172        };
173    }
174
175    fn vector_add(alpha: f64, a: &VectorX, b: &mut VectorX) {
176        b.add_assign(a.mul(alpha));
177    }
178
179    fn solve_ch(a: MatrixXc, b: &mut VectorXc) -> bool {
180        a.qr().solve_mut(b)
181    }
182
183    fn solve_g(a: MatrixX, b: &mut VectorX) -> bool {
184        a.qr().solve_mut(b)
185    }
186
187    fn dot(a: &VectorX, b: &VectorX) -> f64 {
188        a.dot(b)
189    }
190
191    fn dot_c(a: &VectorXc, b: &VectorXc) -> Complex {
192        a.dot(b)
193    }
194
195    fn max_coefficient(a: &VectorX) -> f64 {
196        a.camax()
197    }
198
199    fn max_coefficient_c(a: &VectorXc) -> f64 {
200        a.camax()
201    }
202
203    fn concat_row(a: MatrixXc, b: &MatrixXc) -> MatrixXc {
204        let arows = a.nrows();
205        let acols = a.ncols();
206        let mut new_mat = a.resize(arows + b.nrows(), acols, Default::default());
207        new_mat
208            .slice_mut((arows, 0), (b.nrows(), b.ncols()))
209            .copy_from(b);
210
211        new_mat
212    }
213
214    fn concat_col(a: MatrixXc, b: &MatrixXc) -> MatrixXc {
215        let arows = a.nrows();
216        let acols = a.ncols();
217        let mut new_mat = a.resize(arows, acols + b.ncols(), Default::default());
218        new_mat
219            .slice_mut((0, acols), (b.nrows(), b.ncols()))
220            .copy_from(b);
221
222        new_mat
223    }
224}