1use std::fmt::Debug;
4use std::iter::repeat_n;
5
6use faer::{Col, ColRef, Mat, MatRef};
7use nuts_derive::Storable;
8use serde::{Deserialize, Serialize};
9
10use crate::transform::{DiagMassMatrix, Transformation};
11use crate::{Math, sampler_stats::SamplerStats};
12
13pub fn mat_all_finite(mat: &MatRef<f64>) -> bool {
14 let mut ok = true;
15 faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
16 ok
17}
18
19fn col_all_finite(mat: &ColRef<f64>) -> bool {
20 let mut ok = true;
21 faer::zip!(mat).for_each(|faer::unzip!(val)| ok &= val.is_finite());
22 ok
23}
24
/// Low-rank correction stored in the math backend's native eigen-decomposition format.
///
/// Built by `InnerMatrix::new` from a column of eigenvalues and a matrix whose columns are
/// the corresponding eigenvectors.
struct InnerMatrix<M: Math> {
    /// Eigenvectors, built from the columns of the input matrix (presumably one column per
    /// eigenvalue — TODO confirm callers guarantee matching sizes).
    vecs: M::EigVectors,
    /// Element-wise square roots of the eigenvalues.
    vals_sqrt: M::EigValues,
    /// Element-wise reciprocals of `vals_sqrt`, i.e. `1 / sqrt(eigenvalue)`.
    vals_sqrt_inv: M::EigValues,
    /// `sum(-0.5 * ln(eigenvalue))`, cached as this correction's log-determinant term.
    logdet_contribution: f64,
    /// Number of eigenvalues in this correction (reported through `MatrixStats`).
    num_eigenvalues: u64,
}
40
41impl<M: Math> Debug for InnerMatrix<M> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("InnerMatrix")
44 .field("vecs", &"<eig vectors>")
45 .field("vals_sqrt", &"<sqrt eig values>")
46 .field("vals_sqrt_inv", &"<inv sqrt eig values>")
47 .field("logdet_contribution", &self.logdet_contribution)
48 .field("num_eigenvalues", &self.num_eigenvalues)
49 .finish()
50 }
51}
52
53impl<M: Math> InnerMatrix<M> {
54 fn new(math: &mut M, mut vals: Col<f64>, vecs: Mat<f64>) -> Self {
55 let logdet_contribution: f64 = vals.iter().map(|&v| -0.5 * v.ln()).sum();
57 let num_eigenvalues = vals.nrows() as u64;
58
59 let vecs = math.new_eig_vectors(
60 vecs.col_iter()
61 .map(|col| col.try_as_col_major().unwrap().as_slice()),
62 );
63
64 vals.iter_mut().for_each(|x| *x = x.sqrt());
66 let vals_sqrt = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
67
68 vals.iter_mut().for_each(|x| *x = x.recip());
70 let vals_sqrt_inv = math.new_eig_values(vals.try_as_col_major().unwrap().as_slice());
71
72 Self {
73 vecs,
74 vals_sqrt,
75 vals_sqrt_inv,
76 logdet_contribution,
77 num_eigenvalues,
78 }
79 }
80
81 fn logdet(&self) -> f64 {
82 self.logdet_contribution
83 }
84}
85
/// Mass-matrix adaptation expressed as a `Transformation`: a diagonal affine scaling
/// (`diag`), optionally combined with a low-rank correction (`inner`).
pub struct LowRankMassMatrix<M: Math> {
    /// Diagonal part, holding a per-parameter mean and standard deviation.
    diag: DiagMassMatrix<M>,
    /// Low-rank correction. `None` after `new` and after `update_from_grad`; `Some` after a
    /// successful `update`.
    inner: Option<InnerMatrix<M>>,
    /// Settings given to `new`; `store_mass_matrix` controls which stats are emitted.
    settings: LowRankSettings,
    /// Cached log-determinant returned by the `Transformation` methods: `diag.logdet()`
    /// plus the low-rank contribution when `inner` is `Some`.
    logdet: f64,
    /// Update counter: `-1` after `new`, incremented on every applied update. Compared
    /// against the caller's last seen id in `extract_stats`.
    id: i64,
}
110
111impl<M: Math> Debug for LowRankMassMatrix<M> {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("LowRankMassMatrix")
114 .field("diag", &self.diag)
115 .field("inner", &self.inner)
116 .field("settings", &self.settings)
117 .field("id", &self.id)
118 .finish()
119 }
120}
121
122impl<M: Math> LowRankMassMatrix<M> {
123 pub fn new(math: &mut M, settings: LowRankSettings) -> Self {
124 Self {
125 diag: DiagMassMatrix::new(math, settings.store_mass_matrix),
126 settings,
127 logdet: 0f64,
128 inner: None,
129 id: -1,
130 }
131 }
132
133 pub fn update_from_grad(
135 &mut self,
136 math: &mut M,
137 pos: &M::Vector,
138 grad: &M::Vector,
139 fill_invalid: f64,
140 clamp: (f64, f64),
141 ) {
142 self.inner = None;
143 self.diag
144 .update_diag_grad(math, pos, grad, fill_invalid, clamp);
145 self.logdet = self.diag.logdet();
146 self.id += 1;
147 }
148
149 pub fn update(
156 &mut self,
157 math: &mut M,
158 stds: Col<f64>,
159 mean: Col<f64>,
160 vals: Col<f64>,
161 vecs: Mat<f64>,
162 ) {
163 if (!col_all_finite(&stds.as_ref())) | (!col_all_finite(&mean.as_ref())) {
164 return;
165 }
166 if (!col_all_finite(&vals.as_ref())) | (!mat_all_finite(&vecs.as_ref())) {
167 return;
168 }
169
170 let mut stds_array = math.new_array();
171 math.read_from_slice(&mut stds_array, stds.try_as_col_major().unwrap().as_slice());
172 let mut mean_array = math.new_array();
173 math.read_from_slice(&mut mean_array, mean.try_as_col_major().unwrap().as_slice());
174 self.diag.set_transform(math, &stds_array, &mean_array);
175
176 let inner = InnerMatrix::new(math, vals, vecs);
177 self.logdet = inner.logdet() + self.diag.logdet();
178 self.inner = Some(inner);
179 self.id += 1;
180 }
181}
182
/// Configuration for the low-rank mass matrix.
#[derive(Clone, Debug, Copy, Serialize, Deserialize)]
pub struct LowRankSettings {
    /// When `true`, `extract_stats` also emits the per-parameter stds and eigenvalue
    /// arrays. The flag is also forwarded to `DiagMassMatrix::new`. Default `false`.
    pub store_mass_matrix: bool,
    /// Not read in this part of the file; presumably a regularisation strength applied
    /// where the eigen-decomposition is computed — TODO confirm. Default `1e-5`.
    pub gamma: f64,
    /// Not read in this part of the file; presumably the threshold that decides which
    /// eigenvalues are kept — TODO confirm. Default `2.0`.
    pub eigval_cutoff: f64,
}
189
190impl Default for LowRankSettings {
191 fn default() -> Self {
192 Self {
193 store_mass_matrix: false,
194 gamma: 1e-5,
195 eigval_cutoff: 2f64,
196 }
197 }
198}
199
/// Statistics describing the low-rank mass matrix, produced by `extract_stats`.
///
/// For `LowRankMassMatrix`, every field is `None` when the transformation has not changed
/// since the last report.
#[derive(Debug, Storable)]
pub struct MatrixStats {
    // Update counter (`id`) of the transformation; `Some` only when it changed.
    #[storable(event = "transformation_update")]
    pub transformation_update_id: Option<i64>,
    // Low-rank values padded with NaN to one entry per unconstrained parameter. Only set
    // when `store_mass_matrix` is on and a low-rank correction exists.
    // NOTE(review): `extract_stats` fills this from `vals_sqrt`, i.e. the square roots of
    // the eigenvalues — confirm this is the intended scale.
    #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
    pub mass_matrix_eigvals: Option<Vec<f64>>,
    // Per-parameter standard deviations of the diagonal part; only set when
    // `store_mass_matrix` is on.
    #[storable(event = "transformation_update", dims("unconstrained_parameter"))]
    pub mass_matrix_stds: Option<Vec<f64>>,
    // Number of low-rank eigenvalues in use (0 when only the diagonal is active).
    #[storable(event = "transformation_update")]
    pub num_eigenvalues: Option<u64>,
}
213
214impl<M: Math> SamplerStats<M> for LowRankMassMatrix<M> {
215 type Stats = MatrixStats;
216 type StatsOptions = i64;
217
218 fn extract_stats(&self, math: &mut M, last_id: Self::StatsOptions) -> Self::Stats {
219 if self.id != last_id {
220 let num_eigenvalues = Some(
221 self.inner
222 .as_ref()
223 .map(|inner| inner.num_eigenvalues)
224 .unwrap_or(0),
225 );
226 if self.settings.store_mass_matrix {
227 let stds = Some(math.box_array(self.diag.stds()));
228 let eigvals = self
229 .inner
230 .as_ref()
231 .map(|inner| math.eigs_as_array(&inner.vals_sqrt));
232 let mut eigvals = eigvals.map(|x| x.into_vec());
233 if let Some(ref mut eigvals) = eigvals {
234 eigvals.extend(repeat_n(
235 f64::NAN,
236 stds.as_ref().unwrap().len() - eigvals.len(),
237 ));
238 }
239 MatrixStats {
240 transformation_update_id: Some(self.id),
241 mass_matrix_eigvals: eigvals,
242 mass_matrix_stds: stds.map(|x| x.into_vec()),
243 num_eigenvalues,
244 }
245 } else {
246 MatrixStats {
247 transformation_update_id: Some(self.id),
248 mass_matrix_eigvals: None,
249 mass_matrix_stds: None,
250 num_eigenvalues,
251 }
252 }
253 } else {
254 MatrixStats {
255 transformation_update_id: None,
256 mass_matrix_eigvals: None,
257 mass_matrix_stds: None,
258 num_eigenvalues: None,
259 }
260 }
261 }
262}
263
264impl<M: Math> Transformation<M> for LowRankMassMatrix<M> {
265 fn init_from_untransformed_position(
266 &self,
267 math: &mut M,
268 untransformed_position: &M::Vector,
269 untransformed_gradient: &mut M::Vector,
270 transformed_position: &mut M::Vector,
271 transformed_gradient: &mut M::Vector,
272 ) -> Result<(f64, f64), M::LogpErr> {
273 let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
274 self.compute_transformed_position(math, untransformed_position, transformed_position);
275 self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
276 Ok((logp, self.logdet(math)))
277 }
278
279 fn init_from_transformed_position(
280 &self,
281 math: &mut M,
282 untransformed_position: &mut M::Vector,
283 untransformed_gradient: &mut M::Vector,
284 transformed_position: &M::Vector,
285 transformed_gradient: &mut M::Vector,
286 ) -> Result<(f64, f64), M::LogpErr> {
287 self.compute_untransformed_position(math, transformed_position, untransformed_position);
288 let logp = math.logp_array(untransformed_position, untransformed_gradient)?;
289 self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
290 Ok((logp, self.logdet(math)))
291 }
292
293 fn inv_transform_normalize(
294 &self,
295 math: &mut M,
296 untransformed_position: &M::Vector,
297 untransformed_gradient: &M::Vector,
298 transformed_position: &mut M::Vector,
299 transformed_gradient: &mut M::Vector,
300 ) -> Result<f64, M::LogpErr> {
301 self.compute_transformed_position(math, untransformed_position, transformed_position);
302 self.compute_transformed_gradient(math, untransformed_gradient, transformed_gradient);
303 Ok(self.logdet(math))
304 }
305
306 fn transformation_id(&self, _math: &mut M) -> i64 {
307 self.id
308 }
309
310 fn next_stats_options(&self, _math: &mut M, _current: i64) -> i64 {
311 self.id
312 }
313}
314
315impl<M: Math> LowRankMassMatrix<M> {
316 fn compute_transformed_position(
317 &self,
318 math: &mut M,
319 untransformed_position: &M::Vector,
320 transformed_position: &mut M::Vector,
321 ) {
322 math.axpy_out(
323 &self.diag.mean(),
324 &untransformed_position,
325 -1.0,
326 transformed_position,
327 );
328 math.array_mult_inplace(transformed_position, self.diag.inv_stds());
329
330 if let Some(inner) = &self.inner {
331 math.apply_lowrank_transform_inplace(
332 &inner.vecs,
333 &inner.vals_sqrt_inv,
334 transformed_position,
335 );
336 }
337 }
338
339 fn compute_untransformed_position(
340 &self,
341 math: &mut M,
342 transformed_position: &M::Vector,
343 untransformed_position: &mut M::Vector,
344 ) {
345 match &self.inner {
346 None => {
347 math.array_mult(
348 transformed_position,
349 &self.diag.stds(),
350 untransformed_position,
351 );
352 }
353 Some(inner) => {
354 math.apply_lowrank_transform(
355 &inner.vecs,
356 &inner.vals_sqrt,
357 transformed_position,
358 untransformed_position,
359 );
360 math.array_mult_inplace(untransformed_position, &self.diag.stds());
361 }
362 }
363 math.axpy(&self.diag.mean(), untransformed_position, 1.0);
364 }
365
366 fn compute_transformed_gradient(
367 &self,
368 math: &mut M,
369 untransformed_gradient: &M::Vector,
370 transformed_gradient: &mut M::Vector,
371 ) {
372 math.array_mult(
373 untransformed_gradient,
374 self.diag.stds(),
375 transformed_gradient,
376 );
377
378 if let Some(inner) = &self.inner {
379 math.apply_lowrank_transform_inplace(
380 &inner.vecs,
381 &inner.vals_sqrt,
382 transformed_gradient,
383 );
384 }
385 }
386
387 fn logdet(&self, _math: &mut M) -> f64 {
389 self.logdet
390 }
391}