1use std::fmt::Debug;
4use std::iter::repeat_n;
5
6use faer::{Col, ColRef, Mat, MatRef};
7use nuts_derive::Storable;
8use serde::{Deserialize, Serialize};
9
10use crate::transform::{DiagMassMatrix, Transformation};
11use crate::{Math, sampler_stats::SamplerStats};
12
13pub fn mat_all_finite(mat: &MatRef<f64>) -> bool {
14 let mut ok = true;
15 faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
16 ok
17}
18
19fn col_all_finite(mat: &ColRef<f64>) -> bool {
20 let mut ok = true;
21 faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
22 ok
23}
24
25struct InnerMatrix<M: Math> {
30 vecs: M::EigVectors,
31 vals_sqrt: M::EigValues,
33 vals_sqrt_inv: M::EigValues,
35 logdet_contribution: f64,
38 mu: M::Vector,
39 num_eigenvalues: u64,
40}
41
42impl<M: Math> Debug for InnerMatrix<M> {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.debug_struct("InnerMatrix")
45 .field("vecs", &"<eig vectors>")
46 .field("vals_sqrt", &"<sqrt eig values>")
47 .field("vals_sqrt_inv", &"<inv sqrt eig values>")
48 .field("logdet_contribution", &self.logdet_contribution)
49 .field("num_eigenvalues", &self.num_eigenvalues)
50 .field("mu", &self.mu)
51 .finish()
52 }
53}
54
55impl<M: Math> InnerMatrix<M> {
56 fn new(math: &mut M, mut vals: Col<f64>, vecs: Mat<f64>, mu: Col<f64>) -> Self {
57 let logdet_contribution: f64 = vals.iter().map(|&v| -0.5 * v.ln()).sum();
59 let num_eigenvalues = vals.nrows() as u64;
60
61 let vecs = math.new_eig_vectors(
62 vecs.col_iter()
63 .map(|col| col.try_as_col_major().unwrap().as_slice()),
64 );
65
66 vals.iter_mut().for_each(|x| *x = x.sqrt());
68 let vals_sqrt = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
69
70 vals.iter_mut().for_each(|x| *x = x.recip());
72 let vals_sqrt_inv = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
73
74 let mu = {
75 let mut array = math.new_array();
76 math.read_from_slice(&mut array, mu.try_as_col_major().unwrap().as_slice());
77 array
78 };
79
80 Self {
81 vecs,
82 vals_sqrt,
83 vals_sqrt_inv,
84 logdet_contribution,
85 mu,
86 num_eigenvalues,
87 }
88 }
89
90 fn logdet(&self) -> f64 {
91 self.logdet_contribution
92 }
93}
94
95pub struct LowRankMassMatrix<M: Math> {
112 diag: DiagMassMatrix<M>,
113 inner: Option<InnerMatrix<M>>,
114 settings: LowRankSettings,
115 logdet: f64,
116 id: i64,
118}
119
120impl<M: Math> Debug for LowRankMassMatrix<M> {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("LowRankMassMatrix")
123 .field("diag", &self.diag)
124 .field("inner", &self.inner)
125 .field("settings", &self.settings)
126 .field("id", &self.id)
127 .finish()
128 }
129}
130
131impl<M: Math> LowRankMassMatrix<M> {
132 pub fn new(math: &mut M, settings: LowRankSettings) -> Self {
133 Self {
134 diag: DiagMassMatrix::new(math, settings.store_mass_matrix),
135 settings,
136 logdet: 0f64,
137 inner: None,
138 id: -1,
139 }
140 }
141
142 pub fn update_from_grad(
144 &mut self,
145 math: &mut M,
146 pos: &M::Vector,
147 grad: &M::Vector,
148 fill_invalid: f64,
149 clamp: (f64, f64),
150 ) {
151 self.inner = None;
152 self.diag
153 .update_diag_grad(math, pos, grad, fill_invalid, clamp);
154 self.logdet = self.diag.logdet();
155 self.id += 1;
156 }
157
158 pub fn update(
165 &mut self,
166 math: &mut M,
167 stds: Col<f64>,
168 mean: Col<f64>,
169 vals: Col<f64>,
170 vecs: Mat<f64>,
171 mean_low_rank: Col<f64>,
172 ) {
173 if (!col_all_finite(&stds.as_ref())) | (!col_all_finite(&mean.as_ref())) {
174 return;
175 }
176 if (!col_all_finite(&vals.as_ref())) | (!mat_all_finite(&vecs.as_ref())) {
177 return;
178 }
179
180 let mut stds_array = math.new_array();
181 math.read_from_slice(&mut stds_array, stds.try_as_col_major().unwrap().as_slice());
182 let mut mean_array = math.new_array();
183 math.read_from_slice(&mut mean_array, mean.try_as_col_major().unwrap().as_slice());
184 self.diag.set_transform(math, &stds_array, &mean_array);
185
186 let inner = InnerMatrix::new(math, vals, vecs, mean_low_rank);
187 self.logdet = inner.logdet() + self.diag.logdet();
188 self.inner = Some(inner);
189 self.id += 1;
190 }
191}
192
193#[derive(Clone, Debug, Copy, Serialize, Deserialize)]
194pub struct LowRankSettings {
195 pub store_mass_matrix: bool,
196 pub gamma: f64,
197 pub eigval_cutoff: f64,
198}
199
200impl Default for LowRankSettings {
201 fn default() -> Self {
202 Self {
203 store_mass_matrix: false,
204 gamma: 1e-5,
205 eigval_cutoff: 2f64,
206 }
207 }
208}
209
210#[derive(Debug, Storable)]
211pub struct MatrixStats {
212 #[storable(event = "transformation_update")]
215 pub transformation_update_id: Option<i64>,
216 #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
217 pub mass_matrix_eigvals: Option<Vec<f64>>,
218 #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
219 pub mass_matrix_stds: Option<Vec<f64>>,
220 #[storable(event = "transformation_update")]
221 pub num_eigenvalues: Option<u64>,
222}
223
224impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
225 type Stats = MatrixStats;
226 type StatsOptions = i64;
227
228 fn extract_stats(&self, math: &mut M, last_id: Self::StatsOptions) -> Self::Stats {
229 if self.id != last_id {
230 let num_eigenvalues = Some(
231 self.inner
232 .as_ref()
233 .map(|inner| inner.num_eigenvalues)
234 .unwrap_or(0),
235 );
236 if self.settings.store_mass_matrix {
237 let stds = Some(math.box_array(self.diag.stds()));
238 let eigvals = self
239 .inner
240 .as_ref()
241 .map(|inner| math.eigs_as_array(&inner.vals_sqrt));
242 let mut eigvals = eigvals.map(|x| x.into_vec());
243 if let Some(ref mut eigvals) = eigvals {
244 eigvals.extend(repeat_n(
245 f64::NAN,
246 stds.as_ref().unwrap().len() - eigvals.len(),
247 ));
248 }
249 MatrixStats {
250 transformation_update_id: Some(self.id),
251 mass_matrix_eigvals: eigvals,
252 mass_matrix_stds: stds.map(|x| x.into_vec()),
253 num_eigenvalues,
254 }
255 } else {
256 MatrixStats {
257 transformation_update_id: Some(self.id),
258 mass_matrix_eigvals: None,
259 mass_matrix_stds: None,
260 num_eigenvalues,
261 }
262 }
263 } else {
264 MatrixStats {
265 transformation_update_id: None,
266 mass_matrix_eigvals: None,
267 mass_matrix_stds: None,
268 num_eigenvalues: None,
269 }
270 }
271 }
272}
273
274impl<M: Math> Transformation<M> for LowRankMassMatrix<M> {
275 fn init_from_untransformed_position(
276 &self,
277 math: &mut M,
278 untransformed_position: &M::Vector,
279 untransformed_gradient: &mut M::Vector,
280 transformed_position: &mut M::Vector,
281 transformed_gradient: &mut M::Vector,
282 ) -> Result<(f64, f64), M::LogpErr> {
283 let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
284 self.compute_transformed_position(math, untransformed_position, transformed_position);
285 self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
286 Ok((logp, self.logdet(math)))
287 }
288
289 fn init_from_transformed_position(
290 &self,
291 math: &mut M,
292 untransformed_position: &mut M::Vector,
293 untransformed_gradient: &mut M::Vector,
294 transformed_position: &M::Vector,
295 transformed_gradient: &mut M::Vector,
296 ) -> Result<(f64, f64), M::LogpErr> {
297 self.compute_untransformed_position(math, transformed_position, untransformed_position);
298 let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
299 self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
300 Ok((logp, self.logdet(math)))
301 }
302
303 fn inv_transform_normalize(
304 &self,
305 math: &mut M,
306 untransformed_position: &M::Vector,
307 untransformed_gradient: &M::Vector,
308 transformed_position: &mut M::Vector,
309 transformed_gradient: &mut M::Vector,
310 ) -> Result<f64, M::LogpErr> {
311 self.compute_transformed_position(math, untransformed_position, transformed_position);
312 self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
313 Ok(self.logdet(math))
314 }
315
316 fn transformation_id(&self, _math: &mut M) -> i64 {
317 self.id
318 }
319
320 fn next_stats_options(&self, _math: &mut M, _current: i64) -> i64 {
321 self.id
322 }
323}
324
325impl<M: Math> LowRankMassMatrix<M> {
326 fn compute_transformed_position(
327 &self,
328 math: &mut M,
329 untransformed_position: &M::Vector,
330 transformed_position: &mut M::Vector,
331 ) {
332 math.axpy_out(
333 &self.diag.mean(),
334 &untransformed_position,
335 -1.0,
336 transformed_position,
337 );
338 math.array_mult_inplace(transformed_position, self.diag.inv_stds());
339
340 if let Some(inner) = &self.inner {
341 math.axpy(&inner.mu, transformed_position, -1.0);
342 math.apply_lowrank_transform_inplace(
343 &inner.vecs,
344 &inner.vals_sqrt_inv,
345 transformed_position,
346 );
347 }
348 }
349
350 fn compute_untransformed_position(
351 &self,
352 math: &mut M,
353 transformed_position: &M::Vector,
354 untransformed_position: &mut M::Vector,
355 ) {
356 match &self.inner {
357 None => {
358 math.array_mult(
359 transformed_position,
360 &self.diag.stds(),
361 untransformed_position,
362 );
363 }
364 Some(inner) => {
365 math.apply_lowrank_transform(
366 &inner.vecs,
367 &inner.vals_sqrt,
368 transformed_position,
369 untransformed_position,
370 );
371
372 math.axpy(&inner.mu, untransformed_position, 1.0);
373 math.array_mult_inplace(untransformed_position, &self.diag.stds());
374 }
375 }
376 math.axpy(&self.diag.mean(), untransformed_position, 1.0);
377 }
378
379 fn compute_transformed_gradient(
380 &self,
381 math: &mut M,
382 untransformed_gradient: &M::Vector,
383 transformed_gradient: &mut M::Vector,
384 ) {
385 math.array_mult(
386 untransformed_gradient,
387 self.diag.stds(),
388 transformed_gradient,
389 );
390
391 if let Some(inner) = &self.inner {
392 math.apply_lowrank_transform_inplace(
393 &inner.vecs,
394 &inner.vals_sqrt,
395 transformed_gradient,
396 );
397 }
398 }
399
400 fn logdet(&self, _math: &mut M) -> f64 {
402 self.logdet
403 }
404}
405
#[cfg(test)]
mod tests {
    use faer::{Col, Mat};

    use crate::Math;
    use crate::math::CpuMath;
    use crate::math::test_logps::NormalLogp;

    use super::{LowRankMassMatrix, LowRankSettings};

    type TestMath = CpuMath<NormalLogp>;

    fn make_math(dim: usize) -> TestMath {
        CpuMath::new(NormalLogp::new(dim, 0.0))
    }

    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(a.len(), b.len());
        for (i, (ai, bi)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (ai - bi).abs() <= tol,
                "index {i}: {ai} vs {bi} (tol {tol})"
            );
        }
    }

    fn read_vec(math: &mut TestMath, v: &Col<f64>) -> Vec<f64> {
        let mut out = vec![0f64; math.dim()];
        math.write_to_slice(v, &mut out);
        out
    }

    /// Three-dimensional mass matrix with a diagonal part only (zero eigenvectors).
    fn diagonal_mass(math: &mut TestMath) -> LowRankMassMatrix<TestMath> {
        let stds = Col::from_fn(3, |i| [1.0f64, 2.0, 3.0][i]);
        let mean = Col::from_fn(3, |i| [0.5f64, -1.0, 2.0][i]);
        let mut mass = LowRankMassMatrix::new(math, LowRankSettings::default());
        mass.update(math, stds, mean, Col::zeros(0), Mat::zeros(3, 0), Col::zeros(3));
        mass
    }

    /// Three-dimensional mass matrix with the given diagonal standard deviations
    /// and one eigenvector along the first axis with eigenvalue 4.
    fn lowrank_mass(math: &mut TestMath, stds: Col<f64>) -> LowRankMassMatrix<TestMath> {
        let mean = Col::from_fn(3, |i| [1.0f64, -0.5, 0.0][i]);
        let vals = faer::col![4.0f64];
        let mut vecs = Mat::zeros(3, 1);
        vecs[(0, 0)] = 1.0;
        let mu = Col::from_fn(3, |i| [0.2f64, -0.1, 0.0][i]);
        let mut mass = LowRankMassMatrix::new(math, LowRankSettings::default());
        mass.update(math, stds, mean, vals, vecs, mu);
        mass
    }

    /// Applies the forward (untransformed -> transformed) position map to `x`.
    fn to_transformed(
        math: &mut TestMath,
        mass: &LowRankMassMatrix<TestMath>,
        x: &[f64],
    ) -> Vec<f64> {
        let mut untransformed = math.new_array();
        let mut transformed = math.new_array();
        math.read_from_slice(&mut untransformed, x);
        mass.compute_transformed_position(math, &untransformed, &mut transformed);
        read_vec(math, &transformed)
    }

    /// Applies the inverse (transformed -> untransformed) position map to `z`.
    fn to_untransformed(
        math: &mut TestMath,
        mass: &LowRankMassMatrix<TestMath>,
        z: &[f64],
    ) -> Vec<f64> {
        let mut transformed = math.new_array();
        let mut untransformed = math.new_array();
        math.read_from_slice(&mut transformed, z);
        mass.compute_untransformed_position(math, &transformed, &mut untransformed);
        read_vec(math, &untransformed)
    }

    #[test]
    fn test_diagonal_round_trip() {
        let mut math = make_math(3);
        let mass = diagonal_mass(&mut math);

        let x_orig = [1.5f64, -0.3, 4.2];
        let transformed = to_transformed(&mut math, &mass, &x_orig);
        let recovered = to_untransformed(&mut math, &mass, &transformed);

        assert_close(&recovered, &x_orig, 1e-12);
    }

    #[test]
    fn test_diagonal_round_trip_reverse() {
        let mut math = make_math(3);
        let mass = diagonal_mass(&mut math);

        let z_orig = [0.7f64, -1.1, 0.3];
        let untransformed = to_untransformed(&mut math, &mass, &z_orig);
        let recovered = to_transformed(&mut math, &mass, &untransformed);

        assert_close(&recovered, &z_orig, 1e-12);
    }

    #[test]
    fn test_lowrank_round_trip() {
        let mut math = make_math(3);
        let mass = lowrank_mass(&mut math, Col::full(3, 1.0f64));

        let x_orig = [2.0f64, 0.5, -1.3];
        let transformed = to_transformed(&mut math, &mass, &x_orig);
        let recovered = to_untransformed(&mut math, &mass, &transformed);

        assert_close(&recovered, &x_orig, 1e-12);
    }

    #[test]
    fn test_lowrank_round_trip_reverse() {
        let mut math = make_math(3);
        let mass = lowrank_mass(&mut math, Col::full(3, 1.0f64));

        let z_orig = [1.0f64, -0.3, 0.8];
        let untransformed = to_untransformed(&mut math, &mass, &z_orig);
        let recovered = to_transformed(&mut math, &mass, &untransformed);

        assert_close(&recovered, &z_orig, 1e-12);
    }

    /// The transformed gradient must be `J^T g`, where `J` is the Jacobian of the
    /// transformed -> untransformed position map. That map is affine, so a unit
    /// step along each transformed axis measures the corresponding entry of
    /// `J^T g` exactly (up to rounding) for the linear objective `f(x) = g . x`.
    #[test]
    fn test_lowrank_gradient_matches_jacobian() {
        let mut math = make_math(3);
        let stds = Col::from_fn(3, |i| [1.0f64, 2.0, 3.0][i]);
        let mass = lowrank_mass(&mut math, stds);

        let g = [0.3f64, -1.2, 0.8];
        let mut untransformed_grad = math.new_array();
        let mut transformed_grad = math.new_array();
        math.read_from_slice(&mut untransformed_grad, &g);
        mass.compute_transformed_gradient(&mut math, &untransformed_grad, &mut transformed_grad);
        let transformed_grad = read_vec(&mut math, &transformed_grad);

        let objective =
            |x: Vec<f64>| -> f64 { x.iter().zip(&g).map(|(a, b)| a * b).sum() };

        let z0 = [0.1f64, -0.4, 0.25];
        let f0 = objective(to_untransformed(&mut math, &mass, &z0));
        let mut expected = Vec::with_capacity(z0.len());
        for axis in 0..z0.len() {
            let mut z1 = z0;
            z1[axis] += 1.0;
            expected.push(objective(to_untransformed(&mut math, &mass, &z1)) - f0);
        }

        assert_close(&transformed_grad, &expected, 1e-10);
    }
}