1use crate::utils::{eigenvalues_2x2, index_of_largest, swap_columns, swap_rows};
9use crate::Decomposition;
10
11pub trait ModCholeskySE99<L, E, P>
22where
23 Self: Sized,
24{
25 fn mod_cholesky_se99(&self) -> Decomposition<L, E, P> {
27 panic!("Not implemented!")
28 }
29}
30
31impl ModCholeskySE99<ndarray::Array2<f64>, ndarray::Array1<f64>, ndarray::Array1<usize>>
32 for ndarray::Array2<f64>
33{
34 fn mod_cholesky_se99(
36 &self,
37 ) -> Decomposition<ndarray::Array2<f64>, ndarray::Array1<f64>, ndarray::Array1<usize>> {
38 assert!(self.is_square());
39 use ndarray::s;
40
41 let n = self.raw_dim()[0];
42
43 let mut l = self.clone();
44 let mut e = ndarray::Array1::zeros(n);
45 let mut p = ndarray::Array::from_iter(0..n);
46
47 let tau = std::f64::EPSILON.cbrt();
49 let tau_bar = std::f64::EPSILON.cbrt();
50 let mu = 0.1_f64;
51
52 let mut phaseone = true;
53
54 let gamma = l
55 .diag()
56 .fold(0.0, |acc, x| if x.abs() > acc { x.abs() } else { acc });
57
58 let mut j = 0;
59
60 while j < n && phaseone {
62 let aii_max =
63 l.diag().slice(s![j..]).fold(
64 std::f64::NEG_INFINITY,
65 |acc, &x| if x > acc { x } else { acc },
66 );
67 let aii_min =
68 l.diag()
69 .slice(s![j..])
70 .fold(std::f64::INFINITY, |acc, &x| if x < acc { x } else { acc });
71 if aii_max < tau_bar * gamma || aii_min < -mu * aii_max {
72 phaseone = false;
73 break;
74 } else {
75 let max_idx = index_of_largest(&l.diag().slice(s![j..]));
77 if max_idx != 0 {
78 swap_rows(&mut l, j, j + max_idx);
79 swap_columns(&mut l, j, j + max_idx);
80 p.swap(j, j + max_idx);
81 }
82 let tmp = ((j + 1)..n).fold(std::f64::INFINITY, |acc, i| {
83 let nv = l[(i, i)] - l[(i, j)].powi(2) / l[(j, j)];
84 if nv < acc {
85 nv
86 } else {
87 acc
88 }
89 });
90 if tmp < -mu * gamma {
91 phaseone = false;
93 break;
94 } else {
95 l[(j, j)] = l[(j, j)].sqrt();
97 for i in (j + 1)..n {
98 l[(i, j)] /= l[(j, j)];
99 l[(j, i)] /= l[(j, j)];
100 for k in (j + 1)..=i {
101 l[(i, k)] -= l[(i, j)] * l[(k, j)];
102 l[(k, i)] = l[(i, k)];
104 }
105 }
106 j += 1;
107 }
108 }
109 }
110
111 let mut delta_prev = 0.0;
113
114 if !phaseone && j == n - 1 {
116 e[j] = -l[(j, j)] + (tau_bar * gamma).max(tau * (-l[(j, j)]) / (1.0 - tau));
117 l[(j, j)] += e[j];
118 l[(j, j)] = l[(j, j)].sqrt();
119 }
120
121 if !phaseone && j < n - 1 {
122 let k = j;
123
124 let mut g = ndarray::Array::zeros(n);
126 for i in k..n {
127 g[i] = l[(i, i)]
128 - l.slice(s![i, k..i]).map(|x| x.abs()).sum()
129 - l.slice(s![(i + 1).., i]).map(|x| x.abs()).sum();
130 }
131
132 for j in k..(n - 2) {
134 let max_idx = index_of_largest(&g.slice(s![j..]));
136 if max_idx != 0 {
137 swap_rows(&mut l, j, j + max_idx);
138 swap_columns(&mut l, j, j + max_idx);
139 p.swap(j, j + max_idx);
140 g.swap(j, j + max_idx);
141 }
142
143 let normj = l.slice(s![(j + 1).., j]).map(|x| x.abs()).sum();
145 e[j] = 0.0f64
146 .max(delta_prev)
147 .max(-l[(j, j)] + normj.max(tau_bar * gamma));
148 if e[j] > 0.0 {
149 l[(j, j)] += e[j];
150 delta_prev = e[j];
151 }
152
153 if (l[(j, j)] - normj).abs() > 1.0 * std::f64::EPSILON {
155 let tmp = 1.0 - normj / l[(j, j)];
156 for i in (j + 1)..n {
157 g[i] += l[(i, j)].abs() * tmp;
158 }
159 }
160
161 l[(j, j)] = l[(j, j)].sqrt();
163 for i in (j + 1)..n {
164 l[(i, j)] /= l[(j, j)];
165 l[(j, i)] /= l[(j, j)];
166 for k in (j + 1)..=i {
167 l[(i, k)] -= l[(i, j)] * l[(k, j)];
168 l[(k, i)] = l[(i, k)];
170 }
171 }
172 }
173
174 let (lhi, llo) = eigenvalues_2x2(&l.slice(s![(n - 2).., (n - 2)..]));
176 e[n - 2] = 0.0f64
177 .max(-llo + (tau_bar * gamma).max(tau * (lhi - llo) / (1.0 - tau)))
178 .max(delta_prev);
179 e[n - 1] = e[n - 2];
180 if e[n - 2] > 0.0 {
181 l[(n - 2, n - 2)] += e[n - 2];
182 l[(n - 1, n - 1)] += e[n - 1];
183 }
184 l[(n - 2, n - 2)] = l[(n - 2, n - 2)].sqrt();
185 l[(n - 1, n - 2)] /= l[(n - 2, n - 2)];
186 l[(n - 1, n - 1)] = (l[(n - 1, n - 1)] - l[(n - 1, n - 2)].powi(2)).sqrt();
187 }
188
189 for i in 0..(n - 1) {
191 l.slice_mut(s![i, (i + 1)..]).fill(0.0);
192 }
193
194 let ec = e.clone();
196 for i in 0..n {
197 e[p[i]] = ec[i];
198 }
199
200 Decomposition::new(l, e, p)
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::*;

    /// Factorizes `a` and verifies the result in two ways:
    ///
    /// 1. `P A P^T + P E P^T` must equal `L L^T` up to `1e-12`;
    /// 2. `L L^T` must equal `reference_upper^T * reference_upper` up to `tol`,
    ///    where `reference_upper` is an upper-triangular reference factor.
    fn check_against_reference(
        a: &ndarray::Array2<f64>,
        reference_upper: &ndarray::Array2<f64>,
        tol: f64,
    ) {
        let expected = reference_upper.t().dot(reference_upper);

        let decomp = a.mod_cholesky_se99();
        let l = decomp.l;
        let e = diag_mat_from_arr(decomp.e.as_slice().unwrap());
        let p = index_to_permutation_mat(decomp.p.as_slice().unwrap());

        let llt = l.dot(&l.t());
        let paptpept = p.dot(&a.dot(&p.t())) + p.dot(&e.dot(&p.t()));

        assert!(paptpept.abs_diff_eq(&llt, 1e-12));
        assert!(llt.abs_diff_eq(&expected, tol));
    }

    #[test]
    fn test_modchol_se99_3x3() {
        let a: ndarray::Array2<f64> =
            ndarray::arr2(&[[1.0, 1.0, 2.0], [1.0, 1.0, 3.0], [2.0, 3.0, 1.0]]);
        let res_l_up: ndarray::Array2<f64> = ndarray::arr2(&[
            [1.732050807568877, 0.5773502691896257, 1.154700538379251],
            [0.0, 1.698920954907997, 1.37342077428181],
            [0.0, 0.0, 0.006912871809428971],
        ]);

        check_against_reference(&a, &res_l_up, 1e-12);
    }

    #[test]
    fn test_modchol_se99_4x4() {
        let a: ndarray::Array2<f64> = ndarray::arr2(&[
            [1890.3, -1705.6, -315.8, 3000.3],
            [-1705.6, 1538.3, 284.9, -2706.6],
            [-315.8, 284.9, 52.5, -501.2],
            [3000.3, -2706.6, -501.2, 4760.8],
        ]);
        let res_l_up: ndarray::Array2<f64> = ndarray::arr2(&[
            [
                68.99855070941707,
                -7.263920688867382,
                -39.22691088684848,
                43.48352203273905,
            ],
            [
                0.0,
                0.3194133212151726,
                -0.1288911532532789,
                0.1905221679618937,
            ],
            [0.0, 0.0, 0.4447317171993393, 0.3345847412304742],
            [0.0, 0.0, 0.0, 0.001713817545399892],
        ]);

        // Loose tolerance on the comparison with the reference factor.
        check_against_reference(&a, &res_l_up, 1e-1);
    }

    #[test]
    fn test_modchol_se99_6x6() {
        let a: ndarray::Array2<f64> = ndarray::arr2(&[
            [14.8253, -6.4243, 7.8746, -1.2498, 10.2733, 10.2733],
            [-6.4243, 15.1024, -1.1155, -0.2761, -8.2117, -8.2117],
            [7.8746, -1.1155, 51.8519, -23.3482, 12.5902, 12.5902],
            [-1.2498, -0.2761, -23.3482, 22.7967, -9.8958, -9.8958],
            [10.2733, -8.2117, 12.5902, -9.8958, 21.0656, 21.0656],
            [10.2733, -8.2117, 12.5902, -9.8958, 21.0656, 21.0656],
        ]);
        let res_l_up: ndarray::Array2<f64> = ndarray::arr2(&[
            [
                7.200826341469429,
                1.748438221248757,
                -0.1549127762706699,
                -3.242433422611255,
                1.093568935922023,
                1.748438221248757,
            ],
            [
                0.0,
                4.243649819020943,
                -1.871229936413708,
                -0.9959835646917692,
                1.970299772942301,
                4.243649819020943,
            ],
            [
                0.0,
                0.0,
                3.402484468269805,
                -0.7765233465239986,
                -0.7547450415137518,
                0.0,
            ],
            [0.0, 0.0, 0.0, 3.269304777945995, 1.123276587271259, 0.0],
            [0.0, 0.0, 0.0, 0.0, 2.813527220044002, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 4.360427593036232e-05],
        ]);

        // Loose tolerance on the comparison with the reference factor.
        check_against_reference(&a, &res_l_up, 1e-3);
    }
}