1use crate::gaussian_mixture::GaussianMixture;
2use crate::{FullGpSurrogate, GpSurrogate, GpSurrogateExt};
3use bitflags::bitflags;
4#[allow(unused_imports)]
5use egobox_gp::correlation_models::{
6 AbsoluteExponentialCorr, Matern32Corr, Matern52Corr, SquaredExponentialCorr,
7};
8#[allow(unused_imports)]
9use egobox_gp::mean_models::{ConstantMean, LinearMean, QuadraticMean};
10use linfa::Float;
11use std::fmt::Display;
12
13#[cfg(feature = "serializable")]
14use serde::{Deserialize, Serialize};
15
/// Strategy used to recombine the outputs of the mixture's experts.
///
/// Rendered through `Display` as `Mixture[Hard]`, `Mixture[Smooth(<factor>)]`
/// or `Mixture[Smooth]`.
///
/// Note: the derived `Eq` only applies when `F: Eq`, so `Recombination<f64>`
/// is `PartialEq` but not `Eq`.
/// NOTE(review): avoid reordering variants — derived serde representations in
/// non-self-describing (binary) formats may encode the variant index.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub enum Recombination<F: Float> {
    /// Hard recombination.
    /// NOTE(review): presumably each point is handled by a single expert — confirm.
    Hard,
    /// Smooth recombination with an optional smoothing factor.
    /// NOTE(review): `None` presumably means the factor is determined elsewhere
    /// (e.g. optimized during training) — confirm against the fitting code.
    Smooth(Option<F>),
}
28
29impl<F: Float> Display for Recombination<F> {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 let recomb = match self {
32 Recombination::Hard => "Hard".to_string(),
33 Recombination::Smooth(Some(f)) => format!("Smooth({f})"),
34 Recombination::Smooth(None) => "Smooth".to_string(),
35 };
36 write!(f, "Mixture[{}]", &recomb)
37 }
38}
39
40bitflags! {
41 #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone, Copy)]
50 #[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
51 pub struct RegressionSpec: u8 {
52 const CONSTANT = 0x01;
54 const LINEAR = 0x02;
56 const QUADRATIC = 0x04;
58 const ALL = RegressionSpec::CONSTANT.bits()
60 | RegressionSpec::LINEAR.bits()
61 | RegressionSpec::QUADRATIC.bits();
62 }
63}
64
65bitflags! {
66 #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone, Copy)]
75 #[cfg_attr(feature = "serializable", derive(Serialize, Deserialize), serde(transparent))]
76 pub struct CorrelationSpec: u8 {
77 const SQUAREDEXPONENTIAL = 0x01;
79 const ABSOLUTEEXPONENTIAL = 0x02;
81 const MATERN32 = 0x04;
83 const MATERN52 = 0x08;
85 const ALL = CorrelationSpec::SQUAREDEXPONENTIAL.bits()
87 | CorrelationSpec::ABSOLUTEEXPONENTIAL.bits()
88 | CorrelationSpec::MATERN32.bits()
89 | CorrelationSpec::MATERN52.bits();
90 }
91}
92
/// Exposes the clustering configuration of a mixture model.
pub trait Clustered {
    /// Number of clusters.
    /// NOTE(review): presumably equal to the number of experts — confirm.
    fn n_clusters(&self) -> usize;
    /// Recombination mode used to combine the experts' outputs.
    fn recombination(&self) -> Recombination<f64>;

    /// Returns the clustering as an owned, standalone [`Clustering`] value.
    fn to_clustering(&self) -> Clustering;
}
100
/// Clustering state of a mixture: a recombination mode paired with a Gaussian mixture.
///
/// NOTE(review): keep the field order stable — derived serde output in
/// non-self-describing (binary) formats may depend on it.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub struct Clustering {
    /// How the experts' outputs are recombined.
    pub(crate) recombination: Recombination<f64>,
    /// The Gaussian mixture model.
    /// NOTE(review): presumably used to assign points to clusters — confirm.
    pub(crate) gmx: GaussianMixture<f64>,
}
110
111impl Clustering {
112 pub fn new(gmx: GaussianMixture<f64>, recombination: Recombination<f64>) -> Self {
113 Clustering { gmx, recombination }
114 }
115
116 pub fn recombination(&self) -> Recombination<f64> {
117 self.recombination
118 }
119 pub fn gmx(&self) -> &GaussianMixture<f64> {
120 &self.gmx
121 }
122}
123
/// A clustered mixture of GP experts that can itself be used as a GP surrogate.
pub trait MixtureGpSurrogate: Clustered + GpSurrogate + GpSurrogateExt {
    /// The experts making up the mixture.
    ///
    /// NOTE(review): `&[Box<dyn FullGpSurrogate>]` would be the idiomatic return
    /// type; it is kept as `&Vec<..>` because changing it breaks every implementor.
    fn experts(&self) -> &Vec<Box<dyn FullGpSurrogate>>;
}
128
/// File format selector for GP model files.
///
/// `Json` is the default. The type is a plain fieldless enum, so it derives
/// `Clone`, `Copy`, `PartialEq`, `Eq` and `Hash`. Callers can then pass it by
/// value and compare it (e.g. `fmt == GpFileFormat::Binary`).
/// NOTE(review): presumably consumed by the save/load routines — confirm.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GpFileFormat {
    /// JSON text format (the default).
    #[default]
    Json,
    /// Binary format; the concrete encoding is chosen by the code that uses this value.
    Binary,
}