use crate::correlation_models::{CorrelationModel, SquaredExponentialCorr};
use crate::errors::{GpError, Result};
use crate::mean_models::ConstantMean;
use crate::parameters::GpValidParams;
use crate::{ParamTuning, ThetaTuning};
use linfa::{Float, ParamGuard};
use ndarray::Array2;
#[cfg(feature = "serializable")]
use serde::{Deserialize, Serialize};
/// How a scalar hyperparameter (such as the noise variance) is specified.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ParamEstimation<F: Float> {
    /// A value given directly by the user.
    Fixed(F),
    /// A value to be estimated, starting from `initial_guess` and constrained
    /// to `bounds`, given as `(lower, upper)`.
    Estimated { initial_guess: F, bounds: (F, F) },
}

impl<F: Float> Default for ParamEstimation<F> {
    /// Estimated, starting at `1e-2`, within `[100 * F::epsilon(), 1e10]`.
    fn default() -> Self {
        let start = F::cast(1e-2);
        let lower = F::cast(100.0) * F::epsilon();
        let upper = F::cast(1e10);
        Self::Estimated {
            initial_guess: start,
            bounds: (lower, upper),
        }
    }
}
/// Specification of the inducing points of the sparse GP.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum Inducings<F: Float> {
    /// Number of inducing points to choose at random
    /// (presumably drawn from the training inputs — confirm in the fitting code).
    Randomized(usize),
    /// Explicit inducing point locations
    /// (assumed one point per row — confirm in the fitting code).
    Located(Array2<F>),
}

impl<F: Float> Default for Inducings<F> {
    /// Ten randomly chosen inducing points.
    fn default() -> Self {
        const DEFAULT_INDUCING_COUNT: usize = 10;
        Inducings::Randomized(DEFAULT_INDUCING_COUNT)
    }
}
/// Sparse approximation used when building the GP.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub enum SparseMethod {
    /// FITC approximation (presumably Fully Independent Training Conditional);
    /// this is the default variant.
    #[default]
    Fitc,
    /// VFE approximation (presumably Variational Free Energy).
    Vfe,
}
/// Validated configuration of a sparse Gaussian process.
///
/// Instances are obtained from [`SgpParams`] through [`ParamGuard`], or via
/// `Default` when using the squared exponential correlation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SgpValidParams<F: Float, Corr: CorrelationModel<F>> {
    /// Settings shared with the full GP; the mean is fixed to `ConstantMean`.
    gp_params: GpValidParams<F, ConstantMean, Corr>,
    /// How the noise variance is obtained.
    noise: ParamEstimation<F>,
    /// Inducing points specification.
    z: Inducings<F>,
    /// Sparse approximation method.
    method: SparseMethod,
    /// Optional random seed.
    seed: Option<u64>,
}

impl<F: Float> Default for SgpValidParams<F, SquaredExponentialCorr> {
    /// Every component takes its own `Default`, and no seed is set.
    fn default() -> Self {
        Self {
            gp_params: Default::default(),
            noise: Default::default(),
            z: Default::default(),
            method: Default::default(),
            seed: None,
        }
    }
}
impl<F: Float, Corr: CorrelationModel<F>> SgpValidParams<F, Corr> {
    /// Correlation model.
    pub fn corr(&self) -> &Corr {
        let gp = &self.gp_params;
        &gp.corr
    }

    /// Requested KPLS reduced dimension, if any.
    pub fn kpls_dim(&self) -> Option<&usize> {
        Option::as_ref(&self.gp_params.kpls_dim)
    }

    /// Theta hyperparameter tuning.
    pub fn theta_tuning(&self) -> &ThetaTuning<F> {
        let gp = &self.gp_params;
        &gp.theta_tuning
    }

    /// Number of starts (presumably for multistart hyperparameter optimization).
    pub fn n_start(&self) -> usize {
        let gp = &self.gp_params;
        gp.n_start
    }

    /// Nugget value.
    pub fn nugget(&self) -> F {
        let gp = &self.gp_params;
        gp.nugget
    }

    /// Sparse approximation method.
    pub fn method(&self) -> SparseMethod {
        let chosen = self.method;
        chosen
    }

    /// Inducing points specification.
    pub fn inducings(&self) -> &Inducings<F> {
        let z = &self.z;
        z
    }

    /// Noise variance configuration.
    pub fn noise_variance(&self) -> &ParamEstimation<F> {
        let noise = &self.noise;
        noise
    }

    /// Random seed, if one was set.
    pub fn seed(&self) -> Option<&u64> {
        Option::as_ref(&self.seed)
    }
}
/// Builder for [`SgpValidParams`]. Validation happens through [`ParamGuard`].
#[derive(Clone, Debug)]
pub struct SgpParams<F: Float, Corr: CorrelationModel<F>>(SgpValidParams<F, Corr>);

impl<F: Float, Corr: CorrelationModel<F>> SgpParams<F, Corr> {
    /// Creates a builder with the given correlation model and inducing points.
    ///
    /// The other settings start as follows:
    /// - constant mean and default theta tuning;
    /// - no KPLS reduction;
    /// - `n_start = 10`;
    /// - nugget `1000 * F::epsilon()`;
    /// - default noise estimation and default sparse method;
    /// - no seed.
    pub fn new(corr: Corr, inducings: Inducings<F>) -> SgpParams<F, Corr> {
        Self(SgpValidParams {
            gp_params: GpValidParams {
                mean: ConstantMean::default(),
                corr,
                theta_tuning: ThetaTuning::default(),
                kpls_dim: None,
                n_start: 10,
                nugget: F::cast(1000.0) * F::epsilon(),
            },
            noise: ParamEstimation::default(),
            z: inducings,
            method: SparseMethod::default(),
            seed: None,
        })
    }

    /// Sets the correlation model.
    pub fn corr(mut self, corr: Corr) -> Self {
        self.0.gp_params.corr = corr;
        self
    }

    /// Sets the initial theta values and keeps the current theta bounds.
    ///
    /// # Panics
    ///
    /// Panics if the resulting `ParamTuning` is rejected by its conversion into
    /// `ThetaTuning`. This presumably happens when `theta_init` and the current
    /// bounds have incompatible lengths; see that `TryFrom` impl.
    pub fn theta_init(mut self, theta_init: Vec<F>) -> Self {
        self.0.gp_params.theta_tuning = ParamTuning {
            init: theta_init,
            ..self.0.gp_params.theta_tuning.clone().into()
        }
        .try_into()
        .expect("invalid theta tuning: `theta_init` is incompatible with current theta bounds");
        self
    }

    /// Sets the theta bounds and keeps the current initial theta values.
    ///
    /// # Panics
    ///
    /// Panics if the resulting `ParamTuning` is rejected by its conversion into
    /// `ThetaTuning`. This presumably happens when `theta_bounds` and the current
    /// initial values have incompatible lengths; see that `TryFrom` impl.
    pub fn theta_bounds(mut self, theta_bounds: Vec<(F, F)>) -> Self {
        self.0.gp_params.theta_tuning = ParamTuning {
            bounds: theta_bounds,
            ..self.0.gp_params.theta_tuning.clone().into()
        }
        .try_into()
        .expect("invalid theta tuning: `theta_bounds` is incompatible with current theta init");
        self
    }

    /// Replaces the whole theta tuning.
    pub fn theta_tuning(mut self, theta_tuning: ThetaTuning<F>) -> Self {
        self.0.gp_params.theta_tuning = theta_tuning;
        self
    }

    /// Sets the KPLS reduced dimension. `None` disables the reduction.
    /// `Some(0)` is rejected later by [`ParamGuard::check_ref`].
    pub fn kpls_dim(mut self, kpls_dim: Option<usize>) -> Self {
        self.0.gp_params.kpls_dim = kpls_dim;
        self
    }

    /// Sets the number of starts (presumably for multistart optimization).
    pub fn n_start(mut self, n_start: usize) -> Self {
        self.0.gp_params.n_start = n_start;
        self
    }

    /// Sets the nugget value.
    pub fn nugget(mut self, nugget: F) -> Self {
        self.0.gp_params.nugget = nugget;
        self
    }

    /// Sets the sparse approximation method.
    pub fn sparse_method(mut self, method: SparseMethod) -> Self {
        self.0.method = method;
        self
    }

    /// Uses the given inducing point locations (`Inducings::Located`).
    pub fn inducings(mut self, z: Array2<F>) -> Self {
        self.0.z = Inducings::Located(z);
        self
    }

    /// Uses `nz` randomly chosen inducing points (`Inducings::Randomized`).
    pub fn n_inducings(mut self, nz: usize) -> Self {
        self.0.z = Inducings::Randomized(nz);
        self
    }

    /// Sets how the noise variance is obtained (fixed or estimated).
    pub fn noise_variance(mut self, config: ParamEstimation<F>) -> Self {
        self.0.noise = config;
        self
    }

    /// Sets the random seed. `None` leaves it unset.
    pub fn seed(mut self, seed: Option<u64>) -> Self {
        self.0.seed = seed;
        self
    }
}
impl<F: Float, Corr: CorrelationModel<F>> ParamGuard for SgpParams<F, Corr> {
    type Checked = SgpValidParams<F, Corr>;
    type Error = GpError;

    /// Validates the parameters without consuming the builder.
    ///
    /// # Errors
    ///
    /// Returns `GpError::InvalidValueError` in two cases:
    /// - `kpls_dim` is `Some(0)`;
    /// - `kpls_dim` exceeds the length of the initial theta, when that theta
    ///   has more than one component.
    fn check_ref(&self) -> Result<&Self::Checked> {
        if let Some(d) = self.0.gp_params.kpls_dim {
            if d == 0 {
                return Err(GpError::InvalidValueError(
                    "`kpls_dim` cannot be 0!".to_string(),
                ));
            }
            // The input dimension is only inferred from a theta with more than one
            // component; a single-valued theta gives no bound on `d`.
            let theta = self.0.theta_tuning().theta0();
            if theta.len() > 1 && d > theta.len() {
                // The `\` continuation keeps the message on one line, so no
                // embedded newline or indentation ends up in the message.
                return Err(GpError::InvalidValueError(format!(
                    "Dimension reduction ({}) should be smaller than expected \
                     training input size inferred from given initial theta length ({})",
                    d,
                    theta.len()
                )));
            }
        }
        Ok(&self.0)
    }

    /// Validates the parameters and returns them on success.
    ///
    /// # Errors
    ///
    /// Same conditions as [`ParamGuard::check_ref`].
    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}