mod adapt;
mod diagonal;
mod external;
mod low_rank;
mod transformation;
pub use adapt::DiagAdaptExpSettings;
pub(crate) use adapt::DiagAdaptStrategy;
pub use adapt::LowRankMassMatrixStrategy;
pub(crate) use adapt::MassMatrixAdaptStrategy;
pub(crate) use diagonal::DiagMassMatrix;
pub use external::ExternalTransformation;
pub(crate) use low_rank::LowRankMassMatrix;
pub use low_rank::LowRankSettings;
pub use transformation::Transformation;
#[cfg(test)]
mod tests {
use std::{collections::HashMap, error::Error, fmt::Display};
use faer::{Col, Mat};
use nuts_storable::{HasDims, Storable};
use crate::{
Math,
math::{CpuLogpFunc, CpuMath, CpuMathError, LogpError},
transform::{DiagMassMatrix, LowRankMassMatrix, LowRankSettings, Transformation},
};
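    /// Test target: a zero-mean multivariate normal parameterized by its
    /// precision matrix. All tests below construct diagonal precisions.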
struct MvNormal {
precision: Mat<f64>,
dim: usize,
}
impl MvNormal {
fn new(precision: Mat<f64>) -> Self {
let dim = precision.nrows();
Self { precision, dim }
}
}
#[derive(Debug)]
struct NeverError;
impl Display for NeverError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "never")
}
}
impl Error for NeverError {}
impl LogpError for NeverError {
fn is_recoverable(&self) -> bool {
false
}
}
struct EmptyExpanded;
impl Storable<MvNormal> for EmptyExpanded {
fn names(_: &MvNormal) -> Vec<&str> {
vec![]
}
fn item_type(_: &MvNormal, _: &str) -> nuts_storable::ItemType {
unimplemented!()
}
fn dims<'a>(_: &'a MvNormal, _: &str) -> Vec<&'a str> {
vec![]
}
fn get_all<'a>(
&'a mut self,
_: &'a MvNormal,
) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
vec![]
}
}
impl HasDims for MvNormal {
fn dim_sizes(&self) -> HashMap<String, u64> {
let mut m = HashMap::new();
m.insert("unconstrained_parameter".into(), self.dim as u64);
m
}
fn coords(&self) -> HashMap<String, nuts_storable::Value> {
HashMap::new()
}
}
impl CpuLogpFunc for MvNormal {
type LogpError = NeverError;
type FlowParameters = ();
type ExpandedVector = EmptyExpanded;
fn dim(&self) -> usize {
self.dim
}
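        // Log density of N(0, P^-1): -0.5 * x' P x - (d / 2) * ln(2*pi) + 0.5 * ln|P|,
        // with gradient -P x. The determinant is taken from the diagonal entries
        // only, which is exact here because every test precision is diagonal.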
fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, NeverError> {
let x = Col::from_fn(self.dim, |i| position[i]);
let px = &self.precision * &x;
        let logp = -0.5 * x.transpose() * &px;
for i in 0..self.dim {
gradient[i] = -px[i];
}
let log_det_p: f64 = (0..self.dim).map(|i| self.precision[(i, i)].ln()).sum();
let norm = -0.5 * (self.dim as f64 * std::f64::consts::TAU.ln() - log_det_p);
Ok(logp + norm)
}
fn expand_vector<R: rand::Rng + ?Sized>(
&mut self,
_rng: &mut R,
_array: &[f64],
) -> Result<EmptyExpanded, CpuMathError> {
Ok(EmptyExpanded)
}
}
fn make_math(precision: Mat<f64>) -> CpuMath<MvNormal> {
CpuMath::new(MvNormal::new(precision))
}
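    // Copies a backend-owned vector into a plain Vec<f64> so it can be compared
    // against expected values.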
fn read_vec(math: &mut CpuMath<MvNormal>, v: &Col<f64>) -> Vec<f64> {
let mut out = vec![0f64; math.dim()];
math.write_to_slice(v, &mut out);
out
}
fn assert_close(a: &[f64], b: &[f64], tol: f64) {
assert_eq!(a.len(), b.len(), "length mismatch");
for (i, (ai, bi)) in a.iter().zip(b.iter()).enumerate() {
assert!(
(ai - bi).abs() <= tol,
"index {i}: {ai} vs {bi} (tol {tol})"
);
}
}
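    // Log density of a d-dimensional standard normal evaluated at z.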
fn standard_normal_logp(z: &[f64]) -> f64 {
let d = z.len() as f64;
-0.5 * (d * std::f64::consts::TAU.ln() + z.iter().map(|v| v * v).sum::<f64>())
}
#[test]
fn test_diag_transform_position_and_gradient() {
let sigma2 = [1f64, 4., 9.];
let mut precision = Mat::zeros(3, 3);
for i in 0..3 {
precision[(i, i)] = 1.0 / sigma2[i];
}
let mut math = make_math(precision);
let mut mass = DiagMassMatrix::new(&mut math, false);
let mut draw_var = math.new_array();
let mut grad_var = math.new_array();
let mut draw_mean = math.new_array();
let mut grad_mean = math.new_array();
math.read_from_slice(&mut draw_var, &sigma2);
math.read_from_slice(&mut grad_var, &sigma2.map(|v| 1.0 / v));
math.fill_array(&mut draw_mean, 0f64);
math.fill_array(&mut grad_mean, 0f64);
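        // Adapt the diagonal from moment estimates: zero means, draw variance
        // sigma^2 and gradient variance 1 / sigma^2, so the fitted scale should
        // come out as sigma (checked by the assertions below).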
mass.update_diag_draw_grad(
&mut math,
&draw_mean,
&grad_mean,
&draw_var,
&grad_var,
None,
(1e-20, 1e20),
);
let x = [1f64, 2., 3.];
let mut untransformed_pos = math.new_array();
let mut untransformed_grad = math.new_array();
let mut transformed_pos = math.new_array();
let mut transformed_grad = math.new_array();
math.read_from_slice(&mut untransformed_pos, &x);
let (logp, logdet) = mass
.init_from_untransformed_position(
&mut math,
&untransformed_pos,
&mut untransformed_grad,
&mut transformed_pos,
&mut transformed_grad,
)
.unwrap();
let z = read_vec(&mut math, &transformed_pos);
assert_close(&z, &[1.0, 1.0, 1.0], 1e-12);
let beta = read_vec(&mut math, &transformed_grad);
assert_close(&beta, &[-1.0, -1.0, -1.0], 1e-12);
let expected_logdet: f64 = sigma2.iter().map(|&s| -(0.5 * s.ln())).sum();
assert!(
(logdet - expected_logdet).abs() < 1e-12,
"logdet: {logdet} vs {expected_logdet}"
);
let logp_adapted = logp - logdet;
let expected_logp_adapted = standard_normal_logp(&z);
assert!(
(logp_adapted - expected_logp_adapted).abs() < 1e-12,
"adapted logp: {logp_adapted} vs {expected_logp_adapted}, {logp}, {logdet}"
);
}
#[test]
fn test_diag_round_trip() {
let sigma2 = [2f64, 0.5, 3.];
let mut precision = Mat::zeros(3, 3);
for i in 0..3 {
precision[(i, i)] = 1.0 / sigma2[i];
}
let mut math = make_math(precision);
let mut mass = DiagMassMatrix::new(&mut math, false);
let mut draw_var = math.new_array();
let mut grad_var = math.new_array();
let mut draw_mean = math.new_array();
let mut grad_mean = math.new_array();
math.read_from_slice(&mut draw_var, &sigma2);
math.read_from_slice(&mut grad_var, &sigma2.map(|v| 1.0 / v));
math.fill_array(&mut draw_mean, 0f64);
math.fill_array(&mut grad_mean, 0f64);
mass.update_diag_draw_grad(
&mut math,
&draw_mean,
&grad_mean,
&draw_var,
&grad_var,
None,
(1e-20, 1e20),
);
let x_orig = [1.5f64, -0.3, 2.1];
let mut untransformed_pos = math.new_array();
let mut untransformed_grad = math.new_array();
let mut transformed_pos = math.new_array();
let mut transformed_grad = math.new_array();
math.read_from_slice(&mut untransformed_pos, &x_orig);
let (logp_fwd, logdet_fwd) = mass
.init_from_untransformed_position(
&mut math,
&untransformed_pos,
&mut untransformed_grad,
&mut transformed_pos,
&mut transformed_grad,
)
.unwrap();
let mut recovered_pos = math.new_array();
let mut recovered_grad = math.new_array();
let mut recovered_transformed_grad = math.new_array();
let (logp_inv, logdet_inv) = mass
.init_from_transformed_position(
&mut math,
&mut recovered_pos,
&mut recovered_grad,
&transformed_pos,
&mut recovered_transformed_grad,
)
.unwrap();
let x_recovered = read_vec(&mut math, &recovered_pos);
assert_close(&x_recovered, &x_orig, 1e-12);
assert!((logp_fwd - logp_inv).abs() < 1e-12, "logp mismatch");
assert!((logdet_fwd - logdet_inv).abs() < 1e-12, "logdet mismatch");
}
#[test]
fn test_diag_nonzero_mean() {
let sigma2 = [4f64, 1., 9.];
let mu = [3f64, -1., 2.];
let mut precision = Mat::zeros(3, 3);
for i in 0..3 {
precision[(i, i)] = 1.0 / sigma2[i];
}
let mut math = make_math(precision);
let mut mass = DiagMassMatrix::new(&mut math, false);
let mut draw_var = math.new_array();
let mut grad_var = math.new_array();
let mut draw_mean = math.new_array();
let mut grad_mean = math.new_array();
math.read_from_slice(&mut draw_var, &sigma2);
math.read_from_slice(&mut grad_var, &sigma2.map(|v| 1.0 / v));
math.read_from_slice(&mut draw_mean, &mu);
math.fill_array(&mut grad_mean, 0f64);
mass.update_diag_draw_grad(
&mut math,
&draw_mean,
&grad_mean,
&draw_var,
&grad_var,
None,
(1e-20, 1e20),
);
let x: Vec<f64> = mu
.iter()
.zip(sigma2.iter())
.map(|(&m, &s)| m + s.sqrt())
.collect();
let mut untransformed_pos = math.new_array();
let mut transformed_pos = math.new_array();
let mut untransformed_grad = math.new_array();
let mut transformed_grad = math.new_array();
math.read_from_slice(&mut untransformed_pos, &x);
mass.init_from_untransformed_position(
&mut math,
&untransformed_pos,
&mut untransformed_grad,
&mut transformed_pos,
&mut transformed_grad,
)
.unwrap();
let z = read_vec(&mut math, &transformed_pos);
assert_close(&z, &[1.0, 1.0, 1.0], 1e-12);
}
#[test]
fn test_lowrank_transform_position_and_gradient() {
let sigma2 = [1f64, 4., 9.];
let mut precision = Mat::zeros(3, 3);
for i in 0..3 {
precision[(i, i)] = 1.0 / sigma2[i];
}
let mut math = make_math(precision);
let stds = Col::from_fn(3, |i| sigma2[i].sqrt());
let mean = Col::zeros(3);
let vals = Col::zeros(0);
let vecs = Mat::zeros(3, 0);
let settings = LowRankSettings::default();
let mut mass = LowRankMassMatrix::new(&mut math, settings);
mass.update(&mut math, stds, mean, vals, vecs);
let x = [1f64, 2., 3.];
let mut untransformed_pos = math.new_array();
let mut untransformed_grad = math.new_array();
let mut transformed_pos = math.new_array();
let mut transformed_grad = math.new_array();
math.read_from_slice(&mut untransformed_pos, &x);
let (logp, logdet) = mass
.init_from_untransformed_position(
&mut math,
&untransformed_pos,
&mut untransformed_grad,
&mut transformed_pos,
&mut transformed_grad,
)
.unwrap();
let z = read_vec(&mut math, &transformed_pos);
assert_close(&z, &[1.0, 1.0, 1.0], 1e-12);
let beta = read_vec(&mut math, &transformed_grad);
assert_close(&beta, &[-1.0, -1.0, -1.0], 1e-12);
let expected_logdet: f64 = sigma2.iter().map(|&s| -(0.5 * s.ln())).sum();
assert!(
(logdet - expected_logdet).abs() < 1e-12,
"logdet: {logdet} vs {expected_logdet}"
);
let logp_adapted = logp - logdet;
let expected_logp_adapted = standard_normal_logp(&z);
assert!(
(logp_adapted - expected_logp_adapted).abs() < 1e-12,
"adapted logp: {logp_adapted} vs {expected_logp_adapted}"
);
}
#[test]
fn test_lowrank_round_trip() {
let sigma2 = [2f64, 0.5, 3.];
let mut precision = Mat::zeros(3, 3);
for i in 0..3 {
precision[(i, i)] = 1.0 / sigma2[i];
}
let mut math = make_math(precision);
let stds = Col::from_fn(3, |i| sigma2[i].sqrt());
let mean = Col::zeros(3);
let vals = Col::zeros(0);
let vecs = Mat::zeros(3, 0);
let mut mass = LowRankMassMatrix::new(&mut math, LowRankSettings::default());
mass.update(&mut math, stds, mean, vals, vecs);
let x_orig = [0.7f64, -1.2, 3.3];
let mut untransformed_pos = math.new_array();
let mut untransformed_grad = math.new_array();
let mut transformed_pos = math.new_array();
let mut transformed_grad = math.new_array();
math.read_from_slice(&mut untransformed_pos, &x_orig);
let (logp_fwd, logdet_fwd) = mass
.init_from_untransformed_position(
&mut math,
&untransformed_pos,
&mut untransformed_grad,
&mut transformed_pos,
&mut transformed_grad,
)
.unwrap();
let mut recovered_pos = math.new_array();
let mut recovered_grad = math.new_array();
let mut recovered_transformed_grad = math.new_array();
let (logp_inv, logdet_inv) = mass
.init_from_transformed_position(
&mut math,
&mut recovered_pos,
&mut recovered_grad,
&transformed_pos,
&mut recovered_transformed_grad,
)
.unwrap();
let x_recovered = read_vec(&mut math, &recovered_pos);
assert_close(&x_recovered, &x_orig, 1e-12);
assert!((logp_fwd - logp_inv).abs() < 1e-12, "logp mismatch");
assert!((logdet_fwd - logdet_inv).abs() < 1e-12, "logdet mismatch");
}
#[test]
fn test_lowrank_with_rank1_correction() {
let mut precision = Mat::zeros(3, 3);
precision[(0, 0)] = 0.25;
precision[(1, 1)] = 1.0;
precision[(2, 2)] = 1.0;
let mut math = make_math(precision);
let stds = Col::full(3, 1.0f64);
let mean = Col::zeros(3);
let vals = faer::col![4.0f64];
let mut vecs = Mat::zeros(3, 1);
vecs[(0, 0)] = 1.0;
let mut mass = LowRankMassMatrix::new(&mut math, LowRankSettings::default());
mass.update(&mut math, stds, mean, vals, vecs);
let x = [2f64, 1., 1.];
let mut untransformed_pos = math.new_array();
let mut untransformed_grad = math.new_array();
let mut transformed_pos = math.new_array();
let mut transformed_grad = math.new_array();
math.read_from_slice(&mut untransformed_pos, &x);
let (logp, logdet) = mass
.init_from_untransformed_position(
&mut math,
&untransformed_pos,
&mut untransformed_grad,
&mut transformed_pos,
&mut transformed_grad,
)
.unwrap();
let z = read_vec(&mut math, &transformed_pos);
assert_close(&z, &[1.0, 1.0, 1.0], 1e-12);
let beta = read_vec(&mut math, &transformed_grad);
assert_close(&beta, &[-1.0, -1.0, -1.0], 1e-12);
let expected_logdet = -0.5f64 * 4f64.ln();
assert!(
(logdet - expected_logdet).abs() < 1e-12,
"logdet: {logdet} vs {expected_logdet}"
);
let logp_adapted = logp - logdet;
let expected_logp_adapted = standard_normal_logp(&z);
assert!(
(logp_adapted - expected_logp_adapted).abs() < 1e-12,
"adapted logp: {logp_adapted} vs {expected_logp_adapted}"
);
let mut recovered_pos = math.new_array();
let mut recovered_grad = math.new_array();
let mut recovered_tgrad = math.new_array();
mass.init_from_transformed_position(
&mut math,
&mut recovered_pos,
&mut recovered_grad,
&transformed_pos,
&mut recovered_tgrad,
)
.unwrap();
let x_rec = read_vec(&mut math, &recovered_pos);
assert_close(&x_rec, &x, 1e-12);
}
#[test]
fn test_lowrank_nonzero_mean() {
let sigma2 = [4f64, 1., 9.];
let mu = [2f64, -1., 3.];
let mut precision = Mat::zeros(3, 3);
for i in 0..3 {
precision[(i, i)] = 1.0 / sigma2[i];
}
let mut math = make_math(precision);
let stds = Col::from_fn(3, |i| sigma2[i].sqrt());
let mean = Col::from_fn(3, |i| mu[i]);
let vals = Col::zeros(0);
let vecs = Mat::zeros(3, 0);
let mut mass = LowRankMassMatrix::new(&mut math, LowRankSettings::default());
mass.update(&mut math, stds, mean, vals, vecs);
let x: Vec<f64> = mu
.iter()
.zip(sigma2.iter())
.map(|(&m, &s)| m + s.sqrt())
.collect();
let mut untransformed_pos = math.new_array();
let mut transformed_pos = math.new_array();
let mut untransformed_grad = math.new_array();
let mut transformed_grad = math.new_array();
math.read_from_slice(&mut untransformed_pos, &x);
mass.init_from_untransformed_position(
&mut math,
&untransformed_pos,
&mut untransformed_grad,
&mut transformed_pos,
&mut transformed_grad,
)
.unwrap();
let z = read_vec(&mut math, &transformed_pos);
assert_close(&z, &[1.0, 1.0, 1.0], 1e-12);
let mut recovered_pos = math.new_array();
let mut recovered_grad = math.new_array();
let mut recovered_tgrad = math.new_array();
mass.init_from_transformed_position(
&mut math,
&mut recovered_pos,
&mut recovered_grad,
&transformed_pos,
&mut recovered_tgrad,
)
.unwrap();
let x_rec = read_vec(&mut math, &recovered_pos);
assert_close(&x_rec, &x, 1e-12);
}
}