use crate::errors::YakfError;
use crate::linalg::allocator::Allocator;
use crate::linalg::{DefaultAllocator, DimName, OMatrix, OVector};
use num_traits::float::FloatCore;
/// Common interface for sigma-point sampling strategies used by unscented filters.
///
/// `T` is the state dimension and `T2` the number of sigma points
/// (strategy-dependent: e.g. n + 2 or 2n + 1 for the two impls in this file).
pub trait SamplingMethod<T, T2>
where
    T: DimName,
    T2: DimName,
    DefaultAllocator:
        Allocator<f64, T2> + Allocator<f64, T> + Allocator<f64, T, T2> + Allocator<f64, T, T>,
{
    /// Weights used when recombining sigma points into a covariance estimate.
    fn weights_c(&self) -> &OVector<f64, T2>;
    /// Weights used when recombining sigma points into a mean estimate.
    fn weights_m(&self) -> &OVector<f64, T2>;
    /// Precomputed direction vectors (one column per sigma point), if the
    /// strategy uses them.
    fn bases(&self) -> Option<&OMatrix<f64, T, T2>>;
    /// Whether this strategy exposes precomputed bases via `bases()`.
    fn has_bases(&self) -> bool;
    /// The scaling term (λ + n) for strategies that use one; `None` otherwise.
    fn get_lamda_plus_n(&self) -> Option<f64>;
    /// Draws the sigma points around `state` using covariance `p`.
    ///
    /// Implementations in this file return `YakfError::CholeskyErr` when `p`
    /// has no Cholesky decomposition.
    fn sampling_states(
        &self,
        p: &OMatrix<f64, T, T>,
        state: &OVector<f64, T>,
    ) -> Result<OMatrix<f64, T, T2>, YakfError>;
}
/// Minimal skew simplex sampling: n + 2 sigma points for an n-dimensional
/// state (so `T2` must equal `T` + 2 — enforced by `empty`).
#[derive(Debug)]
pub struct MinimalSkewSimplexSampling<T, T2>
where
    T: DimName,
    T2: DimName,
    DefaultAllocator:
        Allocator<f64, T2> + Allocator<f64, T> + Allocator<f64, T, T2> + Allocator<f64, T, T>,
{
    /// Shared weight vector, used for both mean and covariance recombination.
    pub weights: OVector<f64, T2>,
    /// Simplex base directions, one column per sigma point; `None` until
    /// populated by `expand_bases` (called from `build`).
    pub u_bases: Option<OMatrix<f64, T, T2>>,
}
impl<T, T2> SamplingMethod<T, T2> for MinimalSkewSimplexSampling<T, T2>
where
    T: DimName,
    T2: DimName,
    DefaultAllocator:
        Allocator<f64, T2> + Allocator<f64, T> + Allocator<f64, T, T2> + Allocator<f64, T, T>,
{
    /// This scheme uses a single weight vector for both covariance…
    fn weights_c(&self) -> &OVector<f64, T2> {
        &self.weights
    }

    /// …and mean recombination.
    fn weights_m(&self) -> &OVector<f64, T2> {
        &self.weights
    }

    fn bases(&self) -> Option<&OMatrix<f64, T, T2>> {
        self.u_bases.as_ref()
    }

    /// Fixed: report whether the bases have actually been computed instead of
    /// unconditionally returning `true` — `u_bases` is `None` until
    /// `expand_bases` runs (i.e. unless the sampler came from `build`).
    fn has_bases(&self) -> bool {
        self.u_bases.is_some()
    }

    /// The simplex scheme does not use a (λ + n) scaling term.
    fn get_lamda_plus_n(&self) -> Option<f64> {
        None
    }

    /// Builds one sigma point per base direction: χᵢ = state + L·uᵢ, where L
    /// is the lower Cholesky factor of `p`.
    ///
    /// # Errors
    /// Returns `YakfError::CholeskyErr` when `p` has no Cholesky decomposition.
    ///
    /// # Panics
    /// Panics if the bases were never initialized (sampler not built via
    /// `build`) — the invariant is now stated in the `expect` message.
    fn sampling_states(
        &self,
        p: &OMatrix<f64, T, T>,
        state: &OVector<f64, T>,
    ) -> Result<OMatrix<f64, T, T2>, YakfError> {
        match p.clone_owned().cholesky() {
            Some(cholesky) => {
                let cho = cholesky.unpack();
                let mut samples: OMatrix<f64, T, T2> = OMatrix::<f64, T, T2>::zeros();
                let bases = self
                    .bases()
                    .expect("u_bases must be initialized (construct via `build`) before sampling");
                for (i, mut col) in samples.column_iter_mut().enumerate() {
                    let u_i = bases.column(i);
                    let chi = state + &cho * u_i;
                    col.copy_from(&chi);
                }
                Ok(samples)
            }
            None => Err(YakfError::CholeskyErr),
        }
    }
}
impl<T, T2> MinimalSkewSimplexSampling<T, T2>
where
T: DimName,
T2: DimName,
DefaultAllocator:
Allocator<f64, T2> + Allocator<f64, T, T2> + Allocator<f64, T> + Allocator<f64, T, T>,
{
#[allow(dead_code)]
pub fn build(w0: f64) -> Result<Self, YakfError> {
let mut sampling = Self::empty()?;
sampling.set_weights(w0);
sampling.expand_bases(T::dim());
Ok(sampling)
}
#[allow(dead_code)]
fn empty() -> Result<Self, YakfError> {
if T2::dim() != T::dim() + 2 {
error!("Weights dimention should be set as (n + 2), where n is the state dimention.");
Err(YakfError::DimensionMismatchErr)
} else {
Ok(Self {
weights: OVector::<f64, T2>::zeros(),
u_bases: None,
})
}
}
fn set_weights(&mut self, w0: f64) {
let n = T::dim();
self.weights[0] = w0;
for i in 1..3 {
self.weights[i] = (1.0 - self.weights[0]) / 2_f64.powi(n as i32);
}
for i in 3..n + 2 {
self.weights[i] = 2_f64.powi(i as i32 - 2) * self.weights[1];
}
}
#[allow(dead_code)]
fn scale_weights(&mut self, _a: f64) {}
fn expand_bases(&mut self, n: usize) {
let mut cols: OMatrix<f64, T, T2> = OMatrix::<f64, T, T2>::zeros();
let col0 = OVector::<f64, T>::zeros();
let mut w_iter =
self.weights
.iter()
.enumerate()
.map(|(i, w)| if i == 0 { 0.0 } else { 1.0 / libm::sqrt(2.0 * w) });
let mut col1 = OVector::<f64, T>::zeros();
w_iter.next();
for k in 0..n {
col1[k] = -w_iter.next().unwrap(); }
for i in 0..T2::dim() {
if i == 0 {
cols.set_column(i, &col0);
} else if i == 1 {
cols.set_column(i, &col1);
} else {
let rev_idx = i - 2;
let mut new_col = cols.column(i - 1).clone_owned();
for k in 0..rev_idx {
new_col[k] = 0.0;
}
new_col[rev_idx] = -new_col[rev_idx];
cols.set_column(i, &new_col);
}
}
self.u_bases = Some(OMatrix::<f64, T, T2>::from(cols));
}
}
/// Symmetric sigma-point sampling: 2n + 1 points for an n-dimensional state
/// (so `T2` must equal 2·`T` + 1 — enforced by `empty`).
#[derive(Debug)]
pub struct SymmetricallyDistributedSampling<T, T2>
where
    T: DimName,
    T2: DimName,
    DefaultAllocator:
        Allocator<f64, T2> + Allocator<f64, T> + Allocator<f64, T, T2> + Allocator<f64, T, T>,
{
    /// Covariance-recombination weights.
    pub weights_c: OVector<f64, T2>,
    /// Mean-recombination weights.
    pub weights_m: OVector<f64, T2>,
    /// NOTE(review): allocated in `empty` but never read by the impls in this
    /// file — possibly kept for interface symmetry; confirm before relying on it.
    pub u_bases: OMatrix<f64, T, T2>,
    /// Holds (λ + n) after `set_weights`, despite the name `k`.
    pub k: f64,
}
impl<T, T2> SamplingMethod<T, T2> for SymmetricallyDistributedSampling<T, T2>
where
    T: DimName,
    T2: DimName,
    DefaultAllocator:
        Allocator<f64, T2> + Allocator<f64, T> + Allocator<f64, T, T2> + Allocator<f64, T, T>,
{
    fn weights_c(&self) -> &OVector<f64, T2> {
        &self.weights_c
    }

    fn weights_m(&self) -> &OVector<f64, T2> {
        &self.weights_m
    }

    /// Directions are derived from the Cholesky factor at sampling time, so
    /// no precomputed bases are exposed.
    fn bases(&self) -> Option<&OMatrix<f64, T, T2>> {
        None
    }

    fn has_bases(&self) -> bool {
        false
    }

    /// `self.k` stores (λ + n), set by `set_weights`.
    fn get_lamda_plus_n(&self) -> Option<f64> {
        Some(self.k)
    }

    /// Builds the classic symmetric sigma-point set:
    ///   χ₀ = state,
    ///   χᵢ   = state + √(λ+n) · Lᵢ₋₁       for i = 1..=n,
    ///   χₙ₊ᵢ = state − √(λ+n) · Lᵢ₋₁       for i = 1..=n,
    /// where L is the lower Cholesky factor of `p`.
    ///
    /// # Errors
    /// Returns `YakfError::CholeskyErr` when `p` has no Cholesky decomposition.
    fn sampling_states(
        &self,
        p: &OMatrix<f64, T, T>,
        state: &OVector<f64, T>,
    ) -> Result<OMatrix<f64, T, T2>, YakfError> {
        let n = T::dim();
        // `clone_owned` for consistency with the simplex impl above.
        match p.clone_owned().cholesky() {
            Some(cholesky) => {
                let cho = cholesky.unpack();
                let mut samples: OMatrix<f64, T, T2> = OMatrix::<f64, T, T2>::zeros();
                let sqrt_lamda_plus_n = libm::sqrt(self.get_lamda_plus_n().unwrap());
                for (i, mut col) in samples.column_iter_mut().enumerate() {
                    if i == 0 {
                        col.copy_from(state);
                    } else if i <= n {
                        // Fixed off-by-one: sigma point i maps to factor
                        // column i - 1; the previous `cho.column(i)` panicked
                        // out of bounds at i == n.
                        let chi = state + sqrt_lamda_plus_n * &cho.column(i - 1);
                        col.copy_from(&chi);
                    } else {
                        // Fixed: the mirrored point i uses column i - n - 1;
                        // the previous `cho.column(i)` was always out of
                        // bounds for i > n.
                        let chi = state - sqrt_lamda_plus_n * &cho.column(i - n - 1);
                        col.copy_from(&chi);
                    }
                }
                Ok(samples)
            }
            None => Err(YakfError::CholeskyErr),
        }
    }
}
impl<T, T2> SymmetricallyDistributedSampling<T, T2>
where
T: DimName,
T2: DimName,
DefaultAllocator:
Allocator<f64, T2> + Allocator<f64, T, T2> + Allocator<f64, T> + Allocator<f64, T, T>,
{
#[allow(dead_code)]
pub fn build(a: f64, b: Option<f64>, k: Option<f64>) -> Result<Self, YakfError> {
let b = match b {
Some(v) => v,
None => 2_f64,
};
let k = match k {
Some(v) => v,
None => 0_f64,
};
let mut sampling = Self::empty()?;
sampling.set_weights(a, b, k);
Ok(sampling)
}
#[allow(dead_code)]
fn empty() -> Result<Self, YakfError> {
if T2::dim() != 2 * T::dim() + 1 {
error!("Weights dimention should be set as (n + 2), where n is the state dimention.");
Err(YakfError::DimensionMismatchErr)
} else {
Ok(Self {
weights_c: OVector::<f64, T2>::zeros(),
weights_m: OVector::<f64, T2>::zeros(),
u_bases: OMatrix::<f64, T, T2>::zeros(),
k: 0_f64,
})
}
}
fn set_weights(&mut self, a: f64, b: f64, k: f64) -> Option<f64> {
let n = T::dim();
let lamda = (a as f64).powi(2) * (n as f64 + k) - n as f64;
let lamda_plus_n = lamda + n as f64;
self.weights_m[0] = lamda / lamda_plus_n;
self.weights_c[0] = self.weights_m[0] + (1.0 - (a as f64).powi(2) + b);
for i in 1..2 * n + 1 {
self.weights_m[i] = 0.5 / lamda_plus_n;
self.weights_c[i] = self.weights_c[i];
}
self.k = lamda_plus_n;
Some(lamda_plus_n)
}
}