egobox_moe/
types.rs

1use crate::gaussian_mixture::GaussianMixture;
2use crate::{FullGpSurrogate, GpSurrogate, GpSurrogateExt};
3use bitflags::bitflags;
4#[allow(unused_imports)]
5use egobox_gp::correlation_models::{
6    AbsoluteExponentialCorr, Matern32Corr, Matern52Corr, SquaredExponentialCorr,
7};
8#[allow(unused_imports)]
9use egobox_gp::mean_models::{ConstantMean, LinearMean, QuadraticMean};
10use linfa::Float;
11use std::fmt::Display;
12
13#[cfg(feature = "serializable")]
14use serde::{Deserialize, Serialize};
15
16/// Enumeration of recombination modes handled by the mixture
17#[derive(Clone, Copy, PartialEq, Eq, Debug)]
18#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
19pub enum Recombination<F: Float> {
20    /// prediction is taken from the expert with highest responsability
21    /// resulting in a model with discontinuities
22    Hard,
23    /// Prediction is a combination experts prediction wrt their responsabilities,
24    /// an optional heaviside factor might be used control steepness of the change between
25    /// experts regions.
26    Smooth(Option<F>),
27}
28
29impl<F: Float> Display for Recombination<F> {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        let recomb = match self {
32            Recombination::Hard => "Hard".to_string(),
33            Recombination::Smooth(Some(f)) => format!("Smooth({f})"),
34            Recombination::Smooth(None) => "Smooth".to_string(),
35        };
36        write!(f, "Mixture[{}]", &recomb)
37    }
38}
39
40bitflags! {
41    /// Flags to specify tested regression models during experts selection (see [`regression_spec()`](egobox_moe::GpMixtureParams::regression_spec)).
42    ///
43    /// Flags can be combine with bit-wise `or` operator to select two or more models.
44    /// ```ignore
45    /// let spec = RegressionSpec::CONSTANT | RegressionSpec::LINEAR;
46    /// ```
47    ///
48    /// See [bitflags::bitflags]
49    #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone, Copy)]
50    #[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
51    pub struct RegressionSpec: u8 {
52        /// Constant regression
53        const CONSTANT = 0x01;
54        /// Linear regression
55        const LINEAR = 0x02;
56        /// 2-degree polynomial regression
57        const QUADRATIC = 0x04;
58        /// All regression models available
59        const ALL = RegressionSpec::CONSTANT.bits()
60                    | RegressionSpec::LINEAR.bits()
61                    | RegressionSpec::QUADRATIC.bits();
62    }
63}
64
65bitflags! {
66    /// Flags to specify tested correlation models during experts selection (see [`correlation_spec()`](egobox_moe::GpMixtureParams::correlation_spec)).
67    ///
68    /// Flags can be combine with bit-wise `or` operator to select two or more models.
69    /// ```ignore
70    /// let spec = CorrelationSpec::MATERN32 | CorrelationSpec::Matern52;
71    /// ```
72    ///
73    /// See [bitflags::bitflags]
74    #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone, Copy)]
75    #[cfg_attr(feature = "serializable", derive(Serialize, Deserialize), serde(transparent))]
76    pub struct CorrelationSpec: u8 {
77        /// Squared exponential correlation model
78        const SQUAREDEXPONENTIAL = 0x01;
79        /// Absolute exponential correlation model
80        const ABSOLUTEEXPONENTIAL = 0x02;
81        /// Matern 3/2 correlation model
82        const MATERN32 = 0x04;
83        /// Matern 5/2 correlation model
84        const MATERN52 = 0x08;
85        /// All correlation models available
86        const ALL = CorrelationSpec::SQUAREDEXPONENTIAL.bits()
87                    | CorrelationSpec::ABSOLUTEEXPONENTIAL.bits()
88                    | CorrelationSpec::MATERN32.bits()
89                    | CorrelationSpec::MATERN52.bits();
90    }
91}
92
93/// A trait to represent clustered structure
94pub trait Clustered {
95    fn n_clusters(&self) -> usize;
96    fn recombination(&self) -> Recombination<f64>;
97
98    fn to_clustering(&self) -> Clustering;
99}
100
101/// A structure for clustering
102#[derive(Clone, Debug)]
103#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
104pub struct Clustering {
105    /// Recombination between the clusters
106    pub(crate) recombination: Recombination<f64>,
107    /// Clusters
108    pub(crate) gmx: GaussianMixture<f64>,
109}
110
111impl Clustering {
112    pub fn new(gmx: GaussianMixture<f64>, recombination: Recombination<f64>) -> Self {
113        Clustering { gmx, recombination }
114    }
115
116    pub fn recombination(&self) -> Recombination<f64> {
117        self.recombination
118    }
119    pub fn gmx(&self) -> &GaussianMixture<f64> {
120        &self.gmx
121    }
122}
123
124/// A trait for Mixture of GP surrogates with derivatives using clustering
125pub trait MixtureGpSurrogate: Clustered + GpSurrogate + GpSurrogateExt {
126    fn experts(&self) -> &Vec<Box<dyn FullGpSurrogate>>;
127}
128
/// File formats available for Gpx models.
///
/// Unit-only enum, so it is cheaply `Copy`able and comparable; `Json` is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GpFileFormat {
    /// Human readable format
    #[default]
    Json,
    /// Binary format
    Binary,
}