use std::{collections::HashMap, error::Error, fmt::Debug, mem::replace};
use faer::linalg::matmul::matmul;
use faer::{Accum, Col, Mat, Par};
use itertools::{Itertools, izip};
use nuts_storable::{HasDims, Storable, Value};
use rand::RngExt;
use thiserror::Error;
use crate::math::util::multiply_inplace;
use super::{
math::{LogpError, Math},
util::{
axpy, axpy_out, multiply, scalar_prods2, scalar_prods3, std_norm_flow, std_norm_grad_flow,
std_norm_grad_flow_inplace, vector_dot,
},
};
/// CPU-backed implementation of the sampler math interface.
///
/// Wraps a user-supplied [`CpuLogpFunc`] and performs all vector algebra on
/// the host, using `pulp` for SIMD dispatch and `faer` for linear algebra.
#[derive(Debug)]
pub struct CpuMath<F: CpuLogpFunc> {
    // User-provided log-density (and optional transform) implementation.
    logp_func: F,
    // SIMD dispatch token detected once at construction and reused for
    // all vectorized kernels (see the `self.arch.dispatch(..)` call sites).
    arch: pulp::Arch,
    // Reusable scratch buffer for the low-rank transforms; starts empty
    // and is resized lazily to the eigenvector count on first use.
    lowrank_scratch: Col<f64>,
}
impl<F: CpuLogpFunc> CpuMath<F> {
pub fn new(logp_func: F) -> Self {
let arch = pulp::Arch::new();
Self {
logp_func,
arch,
lowrank_scratch: Col::zeros(0),
}
}
}
/// Errors produced by the CPU math backend itself (as opposed to errors
/// coming from the user's log-density function, which use `F::LogpError`).
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum CpuMathError {
    /// A low-level array operation failed.
    #[error("Error during array operation")]
    ArrayError(),
    /// Expanding a draw into derived quantities failed; the payload holds
    /// a human-readable description of the cause.
    #[error("Error during point expansion: {0}")]
    ExpandError(String),
}
impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
    /// Dimension names and sizes, forwarded from the wrapped logp function.
    fn dim_sizes(&self) -> HashMap<String, u64> {
        let inner = &self.logp_func;
        inner.dim_sizes()
    }
    /// Coordinate values, forwarded from the wrapped logp function.
    fn coords(&self) -> HashMap<String, nuts_storable::Value> {
        let inner = &self.logp_func;
        inner.coords()
    }
}
/// Newtype around the logp function's expanded-vector type, so `Storable`
/// calls can be forwarded with the inner `logp_func` as the parent.
pub struct ExpandedVectorWrapper<F: CpuLogpFunc>(F::ExpandedVector);
/// Forwards all `Storable` queries to the user's expanded-vector type,
/// unwrapping the `CpuMath` parent to the inner `logp_func`.
impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
    // Names of the stored items, as reported by the inner type.
    fn names(parent: &CpuMath<F>) -> Vec<&str> {
        F::ExpandedVector::names(&parent.logp_func)
    }
    // Storage type of a single named item.
    fn item_type(parent: &CpuMath<F>, item: &str) -> nuts_storable::ItemType {
        F::ExpandedVector::item_type(&parent.logp_func, item)
    }
    // Dimension names associated with a single named item.
    fn dims<'a>(parent: &'a CpuMath<F>, item: &str) -> Vec<&'a str> {
        F::ExpandedVector::dims(&parent.logp_func, item)
    }
    // Extract all item values from the wrapped expanded vector.
    fn get_all<'a>(
        &'a mut self,
        parent: &'a CpuMath<F>,
    ) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
        self.0.get_all(&parent.logp_func)
    }
}
/// CPU implementation of the `Math` backend.
///
/// All vector operations run on the host; SIMD kernels from `util` are
/// dispatched through `self.arch`, and log-density / transform work is
/// delegated to the wrapped [`CpuLogpFunc`].
impl<F: CpuLogpFunc> Math for CpuMath<F> {
    type Vector = Col<f64>;
    type EigVectors = Mat<f64>;
    type EigValues = Col<f64>;
    type LogpErr = F::LogpError;
    type Err = CpuMathError;
    type FlowParameters = F::FlowParameters;
    type ExpandedVector = ExpandedVectorWrapper<F>;

    /// Allocate a zero-initialized vector of length `self.dim()`.
    fn new_array(&mut self) -> Self::Vector {
        Col::zeros(self.dim())
    }

    /// Collect the given slices as the columns of a new `dim x nvecs` matrix.
    ///
    /// Panics if the iterator length and column count disagree (`zip_eq`)
    /// or if any slice length differs from `dim` (`copy_from_slice`).
    fn new_eig_vectors<'a>(
        &'a mut self,
        vals: impl ExactSizeIterator<Item = &'a [f64]>,
    ) -> Self::EigVectors {
        let ndim = self.dim();
        let nvecs = vals.len();
        let mut vectors: Mat<f64> = Mat::zeros(ndim, nvecs);
        vectors.col_iter_mut().zip_eq(vals).for_each(|(col, vals)| {
            col.try_as_col_major_mut()
                .expect("Array is not contiguous")
                .as_slice_mut()
                .copy_from_slice(vals)
        });
        vectors
    }

    /// Copy the given eigenvalues into a newly allocated column vector.
    fn new_eig_values(&mut self, vals: &[f64]) -> Self::EigValues {
        let mut values: Col<f64> = Col::zeros(vals.len());
        values
            .try_as_col_major_mut()
            .expect("Array is not contiguous")
            .as_slice_mut()
            .copy_from_slice(vals);
        values
    }

    /// Evaluate the log density at `position`, writing the gradient into
    /// `gradient` and returning the log-density value.
    fn logp_array(
        &mut self,
        position: &Self::Vector,
        gradient: &mut Self::Vector,
    ) -> Result<f64, Self::LogpErr> {
        self.logp_func.logp(
            position
                .try_as_col_major()
                .expect("Array is not contiguous")
                .as_slice(),
            gradient
                .try_as_col_major_mut()
                .expect("Array is not contiguous")
                .as_slice_mut(),
        )
    }

    /// Slice-based variant of [`Self::logp_array`].
    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpErr> {
        self.logp_func.logp(position, gradient)
    }

    /// Number of dimensions of the parameter space.
    fn dim(&self) -> usize {
        self.logp_func.dim()
    }

    /// Expand a draw into user-defined derived quantities.
    fn expand_vector<R: rand::Rng + ?Sized>(
        &mut self,
        rng: &mut R,
        array: &Self::Vector,
    ) -> Result<Self::ExpandedVector, Self::Err> {
        Ok(ExpandedVectorWrapper(
            self.logp_func.expand_vector(
                rng,
                array
                    .try_as_col_major()
                    .ok_or_else(|| {
                        CpuMathError::ExpandError("Internal vector was not col major".into())
                    })?
                    .as_slice(),
            )?,
        ))
    }

    /// Optional coordinate labels for the parameter vector.
    fn vector_coord(&self) -> Option<Value> {
        self.logp_func.vector_coord()
    }

    /// Fill `position` with uniform draws in `(-1, 1)` and evaluate the
    /// log density and gradient at that point.
    fn init_position<R: rand::Rng + ?Sized>(
        &mut self,
        rng: &mut R,
        position: &mut Self::Vector,
        gradient: &mut Self::Vector,
    ) -> Result<f64, Self::LogpErr> {
        let pos = position
            .try_as_col_major_mut()
            .expect("Array is not contiguous")
            .as_slice_mut();
        pos.iter_mut().for_each(|x| {
            let val: f64 = rng.random();
            // Map a uniform draw in [0, 1) to (-1, 1).
            *x = val * 2f64 - 1f64
        });
        self.logp_func.logp(
            position
                .try_as_col_major()
                .expect("Array is not contiguous")
                .as_slice(),
            gradient
                .try_as_col_major_mut()
                .expect("Array is not contiguous")
                .as_slice_mut(),
        )
    }

    /// Delegates to the SIMD kernel `util::scalar_prods3` for the pair of
    /// scalar products over `(positive1 - negative1 + positive2)`.
    fn scalar_prods3(
        &mut self,
        positive1: &Self::Vector,
        negative1: &Self::Vector,
        positive2: &Self::Vector,
        x: &Self::Vector,
        y: &Self::Vector,
    ) -> (f64, f64) {
        scalar_prods3(
            self.arch,
            positive1.try_as_col_major().unwrap().as_slice(),
            negative1.try_as_col_major().unwrap().as_slice(),
            positive2.try_as_col_major().unwrap().as_slice(),
            x.try_as_col_major().unwrap().as_slice(),
            y.try_as_col_major().unwrap().as_slice(),
        )
    }

    /// Delegates to the SIMD kernel `util::scalar_prods2` for the pair of
    /// scalar products over `(positive1 + positive2)`.
    fn scalar_prods2(
        &mut self,
        positive1: &Self::Vector,
        positive2: &Self::Vector,
        x: &Self::Vector,
        y: &Self::Vector,
    ) -> (f64, f64) {
        scalar_prods2(
            self.arch,
            positive1.try_as_col_major().unwrap().as_slice(),
            positive2.try_as_col_major().unwrap().as_slice(),
            x.try_as_col_major().unwrap().as_slice(),
            y.try_as_col_major().unwrap().as_slice(),
        )
    }

    /// Compute `sum_i (x_i + y_i)^2`.
    fn sq_norm_sum(&mut self, x: &Self::Vector, y: &Self::Vector) -> f64 {
        x.try_as_col_major()
            .unwrap()
            .as_slice()
            .iter()
            .zip(y.try_as_col_major().unwrap().as_slice())
            .map(|(&x, &y)| (x + y) * (x + y))
            .sum()
    }

    /// Copy `source` into `dest`; panics if the lengths differ.
    fn read_from_slice(&mut self, dest: &mut Self::Vector, source: &[f64]) {
        dest.try_as_col_major_mut()
            .unwrap()
            .as_slice_mut()
            .copy_from_slice(source);
    }

    /// Copy `source` into `dest`; panics if the lengths differ.
    fn write_to_slice(&mut self, source: &Self::Vector, dest: &mut [f64]) {
        dest.copy_from_slice(source.try_as_col_major().unwrap().as_slice())
    }

    /// Copy `array` into `dest`, reusing `dest`'s allocation when possible.
    fn copy_into(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
        dest.clone_from(array)
    }

    /// Axpy-style combination of `x` and `y` with scale `a`, written to
    /// `out` (see `util::axpy_out` for the exact kernel).
    fn axpy_out(&mut self, x: &Self::Vector, y: &Self::Vector, a: f64, out: &mut Self::Vector) {
        axpy_out(
            self.arch,
            x.try_as_col_major().unwrap().as_slice(),
            y.try_as_col_major().unwrap().as_slice(),
            a,
            out.try_as_col_major_mut().unwrap().as_slice_mut(),
        );
    }

    /// In-place axpy update of `y` from `x` with scale `a`
    /// (see `util::axpy` for the exact kernel).
    fn axpy(&mut self, x: &Self::Vector, y: &mut Self::Vector, a: f64) {
        axpy(
            self.arch,
            x.try_as_col_major().unwrap().as_slice(),
            y.try_as_col_major_mut().unwrap().as_slice_mut(),
            a,
        );
    }

    /// Set every element of `array` to `val`.
    fn fill_array(&mut self, array: &mut Self::Vector, val: f64) {
        faer::zip!(array).for_each(|faer::unzip!(pos)| *pos = val);
    }

    /// Return `true` if every element is finite.
    ///
    /// Deliberately visits all elements without short-circuiting, which
    /// keeps the loop branch-free.
    fn array_all_finite(&mut self, array: &Self::Vector) -> bool {
        let mut ok = true;
        faer::zip!(array).for_each(|faer::unzip!(val)| ok &= val.is_finite());
        ok
    }

    /// Return `true` if every element is finite and nonzero.
    fn array_all_finite_and_nonzero(&mut self, array: &Self::Vector) -> bool {
        self.arch.dispatch(|| {
            array
                .try_as_col_major()
                .unwrap()
                .as_slice()
                .iter()
                .all(|&x| x.is_finite() & (x != 0f64))
        })
    }

    /// Sum of the natural logarithms of all elements.
    fn array_sum_ln(&mut self, array: &Self::Vector) -> f64 {
        let mut sum = 0f64;
        faer::zip!(array).for_each(|faer::unzip!(val)| sum += val.ln());
        sum
    }

    /// Element-wise product: `dest = array1 * array2`.
    fn array_mult(
        &mut self,
        array1: &Self::Vector,
        array2: &Self::Vector,
        dest: &mut Self::Vector,
    ) {
        multiply(
            self.arch,
            array1.try_as_col_major().unwrap().as_slice(),
            array2.try_as_col_major().unwrap().as_slice(),
            dest.try_as_col_major_mut().unwrap().as_slice_mut(),
        )
    }

    /// In-place element-wise product: `array1 *= array2`.
    fn array_mult_inplace(&mut self, array1: &mut Self::Vector, array2: &Self::Vector) {
        multiply_inplace(
            self.arch,
            array1.try_as_col_major_mut().unwrap().as_slice_mut(),
            array2.try_as_col_major().unwrap().as_slice(),
        )
    }

    /// Element-wise reciprocal: `dest = 1 / array`.
    fn array_recip(&mut self, array: &Self::Vector, dest: &mut Self::Vector) {
        faer::zip!(array, dest).for_each(|faer::unzip!(val, dest)| *dest = val.recip())
    }

    /// Apply the low-rank update `dest = (I + V (L - I) V^T) rhs`, where
    /// `V` is `vecs` and `L = diag(vals)`.
    ///
    /// With zero eigenvectors the transform is the identity and `rhs` is
    /// copied straight through. The intermediate `V^T rhs` is kept in
    /// `self.lowrank_scratch`, which is grown/shrunk to the rank as needed.
    fn apply_lowrank_transform(
        &mut self,
        vecs: &Self::EigVectors,
        vals: &Self::EigValues,
        rhs: &Self::Vector,
        dest: &mut Self::Vector,
    ) {
        if vecs.ncols() == 0 {
            self.copy_into(rhs, dest);
            return;
        }
        let rank = vecs.ncols();
        if self.lowrank_scratch.nrows() != rank {
            self.lowrank_scratch.resize_with(rank, |_| 0.0);
        }
        // scratch = V^T rhs
        matmul(
            self.lowrank_scratch.as_mut(),
            Accum::Replace,
            vecs.transpose(),
            rhs.as_ref(),
            1.0,
            Par::Seq,
        );
        // scratch *= (vals - 1), element-wise
        self.lowrank_scratch
            .iter_mut()
            .zip(vals.iter())
            .for_each(|(s, &v)| *s *= v - 1.0);
        // dest = rhs + V scratch
        dest.copy_from(rhs);
        matmul(
            dest.as_mut(),
            Accum::Add,
            vecs.as_ref(),
            self.lowrank_scratch.as_ref(),
            1.0,
            Par::Seq,
        );
    }

    /// In-place variant of [`Self::apply_lowrank_transform`]:
    /// `rhs_and_dest = (I + V (L - I) V^T) rhs_and_dest`.
    fn apply_lowrank_transform_inplace(
        &mut self,
        vecs: &Self::EigVectors,
        vals: &Self::EigValues,
        rhs_and_dest: &mut Self::Vector,
    ) {
        if vecs.ncols() == 0 {
            return;
        }
        let rank = vecs.ncols();
        if self.lowrank_scratch.nrows() != rank {
            self.lowrank_scratch.resize_with(rank, |_| 0.0);
        }
        // scratch = V^T rhs
        matmul(
            self.lowrank_scratch.as_mut(),
            Accum::Replace,
            vecs.transpose(),
            rhs_and_dest.as_ref(),
            1.0,
            Par::Seq,
        );
        // scratch *= (vals - 1), element-wise
        self.lowrank_scratch
            .iter_mut()
            .zip(vals.iter())
            .for_each(|(s, &v)| *s *= v - 1.0);
        // rhs_and_dest += V scratch
        matmul(
            rhs_and_dest.as_mut(),
            Accum::Add,
            vecs.as_ref(),
            self.lowrank_scratch.as_ref(),
            1.0,
            Par::Seq,
        );
    }

    /// Compute `dest = S (I + V (L - I) V^T) S rhs`, where `S = diag(stds)`,
    /// `V` is `vecs` and `L = diag(vals)`, using `faer` operator expressions.
    fn array_mult_eigs(
        &mut self,
        stds: &Self::Vector,
        rhs: &Self::Vector,
        dest: &mut Self::Vector,
        vecs: &Self::EigVectors,
        vals: &Self::EigValues,
    ) {
        let rhs = stds.as_diagonal() * rhs;
        let trafo = vecs.transpose() * (&rhs);
        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + rhs;
        let scaled = stds.as_diagonal() * inner_prod;
        let _ = replace(dest, scaled);
    }

    /// Standard-normal flow step; see `util::std_norm_flow` for the kernel.
    fn std_norm_flow(
        &mut self,
        pos: &Self::Vector,
        pos_out: &mut Self::Vector,
        vel: &mut Self::Vector,
        epsilon: f64,
    ) {
        std_norm_flow(
            self.arch,
            pos.try_as_col_major().unwrap().as_slice(),
            pos_out.try_as_col_major_mut().unwrap().as_slice_mut(),
            vel.try_as_col_major_mut().unwrap().as_slice_mut(),
            epsilon,
        );
    }

    /// Gradient half of the standard-normal flow; see
    /// `util::std_norm_grad_flow` for the kernel.
    fn std_norm_grad_flow(
        &mut self,
        pos: &Self::Vector,
        grad: &Self::Vector,
        vel: &Self::Vector,
        vel_out: &mut Self::Vector,
        epsilon: f64,
    ) {
        std_norm_grad_flow(
            self.arch,
            pos.try_as_col_major().unwrap().as_slice(),
            grad.try_as_col_major().unwrap().as_slice(),
            vel.try_as_col_major().unwrap().as_slice(),
            vel_out.try_as_col_major_mut().unwrap().as_slice_mut(),
            epsilon,
        );
    }

    /// In-place variant of [`Self::std_norm_grad_flow`].
    fn std_norm_grad_flow_inplace(
        &mut self,
        pos: &Self::Vector,
        grad: &Self::Vector,
        vel: &mut Self::Vector,
        epsilon: f64,
    ) {
        std_norm_grad_flow_inplace(
            self.arch,
            pos.try_as_col_major().unwrap().as_slice(),
            grad.try_as_col_major().unwrap().as_slice(),
            vel.try_as_col_major_mut().unwrap().as_slice_mut(),
            epsilon,
        );
    }

    /// Scale `v` to unit Euclidean norm in place.
    ///
    /// NOTE(review): a zero vector yields non-finite entries (division by
    /// zero); callers appear to be expected to pass a nonzero vector.
    fn array_normalize(&mut self, v: &mut Self::Vector) {
        let v = v.try_as_col_major_mut().unwrap().as_slice_mut();
        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        let inv = 1.0 / norm;
        for x in v.iter_mut() {
            *x *= inv;
        }
    }

    /// One ESH-style momentum update driven by the gradient.
    ///
    /// The momentum is mixed with the unit gradient direction using weights
    /// derived from `zeta = exp(-step_size * |g| / (n - 1))`, then
    /// renormalized to unit length. Returns the implied kinetic-energy
    /// change `(delta - ln 2 + ln(1 + arg)) * (n - 1)`.
    ///
    /// NOTE(review): a zero gradient norm makes `inv_grad_norm` non-finite;
    /// callers appear to be expected to guarantee a nonzero gradient.
    fn esh_momentum_update(
        &mut self,
        gradient: &Self::Vector,
        momentum: &mut Self::Vector,
        step_size: f64,
    ) -> f64 {
        let gradient = gradient.try_as_col_major().unwrap().as_slice();
        let momentum = momentum.try_as_col_major_mut().unwrap().as_slice_mut();
        let n = gradient.len();
        assert!(n >= 2, "ESH dynamics requires at least 2 dimensions");
        let grad_norm: f64 = gradient.iter().map(|g| g * g).sum::<f64>().sqrt();
        let inv_grad_norm = 1.0 / grad_norm;
        // Projection of the momentum onto the unit gradient direction.
        let momentum_proj: f64 = momentum
            .iter()
            .zip(gradient.iter())
            .map(|(p, g)| p * g * inv_grad_norm)
            .sum();
        let dims_m1 = (n - 1) as f64;
        let delta = step_size * grad_norm / dims_m1;
        let zeta = (-delta).exp();
        let coeff_g = (1.0 - zeta) * (1.0 + zeta + momentum_proj * (1.0 - zeta));
        let coeff_p = 2.0 * zeta;
        for (p, g) in momentum.iter_mut().zip(gradient.iter()) {
            *p = coeff_g * (g * inv_grad_norm) + coeff_p * *p;
        }
        // Renormalize the updated momentum to unit length.
        let raw_norm: f64 = momentum.iter().map(|p| p * p).sum::<f64>().sqrt();
        let inv = 1.0 / raw_norm;
        for p in momentum.iter_mut() {
            *p *= inv;
        }
        let arg = momentum_proj + (1.0 - momentum_proj) * zeta * zeta;
        (delta - std::f64::consts::LN_2 + arg.ln_1p()) * dims_m1
    }

    /// Dot product of the two vectors via the SIMD kernel.
    fn array_vector_dot(&mut self, array1: &Self::Vector, array2: &Self::Vector) -> f64 {
        vector_dot(
            self.arch,
            array1.try_as_col_major().unwrap().as_slice(),
            array2.try_as_col_major().unwrap().as_slice(),
        )
    }

    /// Fill `dest` with independent normal draws scaled element-wise by `stds`.
    fn array_gaussian<R: rand::Rng + ?Sized>(
        &mut self,
        rng: &mut R,
        dest: &mut Self::Vector,
        stds: &Self::Vector,
    ) {
        let dist = rand_distr::StandardNormal;
        dest.try_as_col_major_mut()
            .unwrap()
            .as_slice_mut()
            .iter_mut()
            .zip(stds.try_as_col_major().unwrap().as_slice().iter())
            .for_each(|(p, &s)| {
                let norm: f64 = rng.sample(dist);
                *p = s * norm;
            });
    }

    /// Draw a standard-normal vector `z` and set
    /// `dest = S (I + V (L - I) V^T) z`, where `S = diag(scale)`,
    /// `V` is `vecs` and `L = diag(vals)`.
    fn array_gaussian_eigs<R: rand::Rng + ?Sized>(
        &mut self,
        rng: &mut R,
        dest: &mut Self::Vector,
        scale: &Self::Vector,
        vals: &Self::EigValues,
        vecs: &Self::EigVectors,
    ) {
        let mut draw: Col<f64> = Col::zeros(self.dim());
        let dist = rand_distr::StandardNormal;
        draw.try_as_col_major_mut()
            .unwrap()
            .as_slice_mut()
            .iter_mut()
            .for_each(|p| {
                *p = rng.sample(dist);
            });
        let trafo = vecs.transpose() * (&draw);
        let inner_prod = vecs * (vals.as_diagonal() * (&trafo) - (&trafo)) + draw;
        let scaled = scale.as_diagonal() * inner_prod;
        let _ = replace(dest, scaled);
    }

    /// Streaming update of a running mean and a running sum of squared
    /// deviations, given one new observation `value`.
    ///
    /// Per element: `diff = x - mean; mean += diff * diff_scale;
    /// var += diff * diff` (the pre-update `diff` is squared).
    fn array_update_variance(
        &mut self,
        mean: &mut Self::Vector,
        variance: &mut Self::Vector,
        value: &Self::Vector,
        diff_scale: f64,
    ) {
        self.arch.dispatch(|| {
            izip!(
                mean.try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                variance
                    .try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                value.try_as_col_major().unwrap().as_slice()
            )
            .for_each(|(mean, var, x)| {
                let diff = x - *mean;
                *mean += diff * diff_scale;
                *var += diff * diff;
            });
        })
    }

    /// Derive `std = sqrt(v)` and `inv_std = 1/sqrt(v)` from a scaled draw
    /// variance `v = draw_var * scale`, clamped to `clamp`.
    ///
    /// Non-finite or zero variances leave the outputs untouched unless
    /// `fill_invalid` provides a replacement variance.
    fn array_update_var_inv_std_draw(
        &mut self,
        inv_std: &mut Self::Vector,
        std: &mut Self::Vector,
        draw_var: &Self::Vector,
        scale: f64,
        fill_invalid: Option<f64>,
        clamp: (f64, f64),
    ) {
        self.arch.dispatch(|| {
            izip!(
                std.try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                inv_std
                    .try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                draw_var.try_as_col_major().unwrap().as_slice().iter(),
            )
            .for_each(|(std_out, inv_std_out, &draw_var)| {
                let draw_var = draw_var * scale;
                if (!draw_var.is_finite()) | (draw_var == 0f64) {
                    if let Some(fill_val) = fill_invalid {
                        *std_out = fill_val.sqrt();
                        *inv_std_out = fill_val.recip().sqrt();
                    }
                } else {
                    let val = draw_var.clamp(clamp.0, clamp.1);
                    *std_out = val.sqrt();
                    *inv_std_out = val.recip().sqrt();
                }
            });
        });
    }

    /// Like [`Self::array_update_var_inv_std_draw`], but the variance comes
    /// from the ratio `sqrt(draw_var / grad_var)` per element.
    fn array_update_var_inv_std_draw_grad(
        &mut self,
        inv_std: &mut Self::Vector,
        std: &mut Self::Vector,
        draw_var: &Self::Vector,
        grad_var: &Self::Vector,
        fill_invalid: Option<f64>,
        clamp: (f64, f64),
    ) {
        self.arch.dispatch(|| {
            izip!(
                std.try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                inv_std
                    .try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                draw_var.try_as_col_major().unwrap().as_slice().iter(),
                grad_var.try_as_col_major().unwrap().as_slice().iter(),
            )
            .for_each(|(std_out, inv_std_out, &draw_var, &grad_var)| {
                let val = (draw_var / grad_var).sqrt();
                if (!val.is_finite()) | (val == 0f64) {
                    if let Some(fill_val) = fill_invalid {
                        *std_out = fill_val.sqrt();
                        *inv_std_out = fill_val.recip().sqrt();
                    }
                } else {
                    let val = val.clamp(clamp.0, clamp.1);
                    *std_out = val.sqrt();
                    *inv_std_out = val.recip().sqrt();
                }
            });
        });
    }

    /// Derive `std`/`inv_std` from the gradient magnitude:
    /// `v = 1 / clamp(|g|)`, falling back to `fill_invalid` when the
    /// result is not finite.
    fn array_update_var_inv_std_grad(
        &mut self,
        inv_std: &mut Self::Vector,
        std: &mut Self::Vector,
        gradient: &Self::Vector,
        fill_invalid: f64,
        clamp: (f64, f64),
    ) {
        self.arch.dispatch(|| {
            izip!(
                std.try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                inv_std
                    .try_as_col_major_mut()
                    .unwrap()
                    .as_slice_mut()
                    .iter_mut(),
                gradient.try_as_col_major().unwrap().as_slice().iter(),
            )
            .for_each(|(std_out, inv_std_out, &grad_var)| {
                let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
                let val = if val.is_finite() { val } else { fill_invalid };
                *std_out = val.sqrt();
                *inv_std_out = val.recip().sqrt();
            });
        });
    }

    /// Copy the eigenvalues out into an owned boxed slice.
    fn eigs_as_array(&mut self, source: &Self::EigValues) -> Box<[f64]> {
        source
            .try_as_col_major()
            .unwrap()
            .as_slice()
            .to_vec()
            .into()
    }

    /// Map an untransformed position/gradient pair into transformed space;
    /// delegates to the logp function's flow implementation.
    fn inv_transform_normalize(
        &mut self,
        params: &Self::FlowParameters,
        untransformed_position: &Self::Vector,
        untransformed_gradient: &Self::Vector,
        transformed_position: &mut Self::Vector,
        transformed_gradient: &mut Self::Vector,
    ) -> Result<f64, Self::LogpErr> {
        self.logp_func.inv_transform_normalize(
            params,
            untransformed_position
                .try_as_col_major()
                .unwrap()
                .as_slice(),
            untransformed_gradient
                .try_as_col_major()
                .unwrap()
                .as_slice(),
            transformed_position
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
            transformed_gradient
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
        )
    }

    /// Initialize transformed state from an untransformed position;
    /// delegates to the logp function.
    fn init_from_untransformed_position(
        &mut self,
        params: &Self::FlowParameters,
        untransformed_position: &Self::Vector,
        untransformed_gradient: &mut Self::Vector,
        transformed_position: &mut Self::Vector,
        transformed_gradient: &mut Self::Vector,
    ) -> Result<(f64, f64), Self::LogpErr> {
        self.logp_func.init_from_untransformed_position(
            params,
            untransformed_position
                .try_as_col_major()
                .unwrap()
                .as_slice(),
            untransformed_gradient
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
            transformed_position
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
            transformed_gradient
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
        )
    }

    /// Initialize untransformed state from a transformed position;
    /// delegates to the logp function.
    fn init_from_transformed_position(
        &mut self,
        params: &Self::FlowParameters,
        untransformed_position: &mut Self::Vector,
        untransformed_gradient: &mut Self::Vector,
        transformed_position: &Self::Vector,
        transformed_gradient: &mut Self::Vector,
    ) -> Result<(f64, f64), Self::LogpErr> {
        self.logp_func.init_from_transformed_position(
            params,
            untransformed_position
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
            untransformed_gradient
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
            transformed_position.try_as_col_major().unwrap().as_slice(),
            transformed_gradient
                .try_as_col_major_mut()
                .unwrap()
                .as_slice_mut(),
        )
    }

    /// Refit the transformation parameters from a batch of draws;
    /// delegates to the logp function.
    fn update_transformation<'a, R: rand::Rng + ?Sized>(
        &'a mut self,
        rng: &mut R,
        untransformed_positions: impl ExactSizeIterator<Item = &'a Self::Vector>,
        untransformed_gradients: impl ExactSizeIterator<Item = &'a Self::Vector>,
        untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
        params: &'a mut Self::FlowParameters,
    ) -> Result<(), Self::LogpErr> {
        self.logp_func.update_transformation(
            rng,
            untransformed_positions.map(|x| x.try_as_col_major().unwrap().as_slice()),
            untransformed_gradients.map(|x| x.try_as_col_major().unwrap().as_slice()),
            untransformed_logp,
            params,
        )
    }

    /// Build initial transformation parameters from one position/gradient
    /// pair; delegates to the logp function.
    fn init_transformation<R: rand::Rng + ?Sized>(
        &mut self,
        rng: &mut R,
        untransformed_position: &Self::Vector,
        untransformed_gradient: &Self::Vector,
        chain: u64,
    ) -> Result<Self::FlowParameters, Self::LogpErr> {
        self.logp_func.init_transformation(
            rng,
            untransformed_position
                .try_as_col_major()
                .unwrap()
                .as_slice(),
            untransformed_gradient
                .try_as_col_major()
                .unwrap()
                .as_slice(),
            chain,
        )
    }

    /// Build fresh transformation parameters; delegates to the logp function.
    fn new_transformation<R: rand::Rng + ?Sized>(
        &mut self,
        rng: &mut R,
        dim: usize,
        chain: u64,
    ) -> Result<Self::FlowParameters, Self::LogpErr> {
        self.logp_func.new_transformation(rng, dim, chain)
    }

    /// Identifier of the given transformation; delegates to the logp function.
    fn transformation_id(&self, params: &Self::FlowParameters) -> Result<i64, Self::LogpErr> {
        self.logp_func.transformation_id(params)
    }
}
/// A log-density function (with optional transformation support) that can
/// be evaluated on the CPU by [`CpuMath`].
///
/// Only `dim`, `logp` and `expand_vector` are required. The transformation
/// methods have defaults that panic via `unimplemented!()`, so they only
/// need to be provided when transform-based adaptation is actually used.
pub trait CpuLogpFunc: HasDims {
    /// Error type returned by log-density evaluations.
    type LogpError: Debug + Send + Sync + Error + LogpError + 'static;
    /// Parameters of the optional transformation / normalizing flow.
    type FlowParameters;
    /// Type holding user-facing derived quantities for a single draw.
    type ExpandedVector: Storable<Self>;
    /// Number of dimensions of the parameter space.
    fn dim(&self) -> usize;
    /// Evaluate the log density at `position`, writing the gradient into
    /// `gradient` and returning the log-density value.
    fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result<f64, Self::LogpError>;
    /// Expand a draw into user-defined derived quantities.
    fn expand_vector<R>(
        &mut self,
        rng: &mut R,
        array: &[f64],
    ) -> Result<Self::ExpandedVector, CpuMathError>
    where
        R: rand::Rng + ?Sized;
    /// Optional coordinate labels for the parameter vector.
    /// Defaults to no labels.
    fn vector_coord(&self) -> Option<Value> {
        None
    }
    /// Map an untransformed position/gradient pair into transformed space.
    /// Default panics; implement when transforms are used.
    fn inv_transform_normalize(
        &mut self,
        _params: &Self::FlowParameters,
        _untransformed_position: &[f64],
        _untransformed_gradient: &[f64],
        _transformed_position: &mut [f64],
        _transformed_gradient: &mut [f64],
    ) -> Result<f64, Self::LogpError> {
        unimplemented!()
    }
    /// Initialize transformed state from an untransformed position.
    /// Default panics; implement when transforms are used.
    fn init_from_untransformed_position(
        &mut self,
        _params: &Self::FlowParameters,
        _untransformed_position: &[f64],
        _untransformed_gradient: &mut [f64],
        _transformed_position: &mut [f64],
        _transformed_gradient: &mut [f64],
    ) -> Result<(f64, f64), Self::LogpError> {
        unimplemented!()
    }
    /// Initialize untransformed state from a transformed position.
    /// Default panics; implement when transforms are used.
    fn init_from_transformed_position(
        &mut self,
        _params: &Self::FlowParameters,
        _untransformed_position: &mut [f64],
        _untransformed_gradient: &mut [f64],
        _transformed_position: &[f64],
        _transformed_gradient: &mut [f64],
    ) -> Result<(f64, f64), Self::LogpError> {
        unimplemented!()
    }
    /// Refit the transformation parameters from a batch of draws.
    /// Default panics; implement when transforms are used.
    fn update_transformation<'a, R: rand::Rng + ?Sized>(
        &'a mut self,
        _rng: &mut R,
        _untransformed_positions: impl ExactSizeIterator<Item = &'a [f64]>,
        _untransformed_gradients: impl ExactSizeIterator<Item = &'a [f64]>,
        _untransformed_logp: impl ExactSizeIterator<Item = &'a f64>,
        _params: &'a mut Self::FlowParameters,
    ) -> Result<(), Self::LogpError> {
        unimplemented!()
    }
    /// Build initial transformation parameters from one position/gradient
    /// pair. Default panics; implement when transforms are used.
    fn init_transformation<R: rand::Rng + ?Sized>(
        &mut self,
        _rng: &mut R,
        _untransformed_position: &[f64],
        _untransformed_gradient: &[f64],
        _chain: u64,
    ) -> Result<Self::FlowParameters, Self::LogpError> {
        unimplemented!()
    }
    /// Build fresh transformation parameters.
    /// Default panics; implement when transforms are used.
    fn new_transformation<R: rand::Rng + ?Sized>(
        &mut self,
        _rng: &mut R,
        _dim: usize,
        _chain: u64,
    ) -> Result<Self::FlowParameters, Self::LogpError> {
        unimplemented!()
    }
    /// Identifier of the given transformation parameters.
    /// Default panics; implement when transforms are used.
    fn transformation_id(&self, _params: &Self::FlowParameters) -> Result<i64, Self::LogpError> {
        unimplemented!()
    }
}
impl<M: CpuLogpFunc + Clone> Clone for CpuMath<M> {
fn clone(&self) -> Self {
Self {
logp_func: self.logp_func.clone(),
arch: self.arch,
lowrank_scratch: Col::zeros(self.lowrank_scratch.nrows()),
}
}
}