//! Mixture of Gaussian process experts — `egobox_moe/algorithm.rs`.

1use super::gaussian_mixture::GaussianMixture;
2use crate::clustering::{find_best_number_of_clusters, sort_by_cluster};
3use crate::errors::MoeError;
4use crate::errors::Result;
5use crate::parameters::{GpMixtureParams, GpMixtureValidParams};
6use crate::{GpMetrics, IaeAlphaPlotData, types::*};
7use crate::{GpType, expertise_macros::*};
8use crate::{NbClusters, surrogates::*};
9
10use egobox_gp::{GaussianProcess, SparseGaussianProcess, correlation_models::*, mean_models::*};
11use linfa::dataset::Records;
12use linfa::traits::{Fit, Predict, PredictInplace};
13use linfa::{Dataset, DatasetBase, Float, ParamGuard};
14use linfa_clustering::GaussianMixtureModel;
15use log::{debug, info, trace};
16use paste::paste;
17use std::cmp::Ordering;
18use std::ops::Sub;
19
20#[cfg(not(feature = "blas"))]
21use linfa_linalg::norm::*;
22use ndarray::{
23    Array1, Array2, Array3, ArrayBase, ArrayView2, Axis, Data, Ix1, Ix2, Zip, concatenate, s,
24};
25
26#[cfg(feature = "blas")]
27use ndarray_linalg::Norm;
28use ndarray_rand::rand::Rng;
29use ndarray_stats::QuantileExt;
30
31#[cfg(feature = "serializable")]
32use serde::{Deserialize, Serialize};
33#[cfg(feature = "persistent")]
34use std::fs;
35#[cfg(feature = "persistent")]
36use std::io::Write;
37
/// Push the stringified model name into `$list` when the corresponding flag
/// is set in the bitflags value `$spec`.
///
/// `paste!` builds both the spec type name (`[<$model_kind Spec>]`, e.g.
/// `RegressionSpec`) and the upper-cased flag name (`[<$model:upper>]`,
/// e.g. `CONSTANT`) at compile time.
macro_rules! check_allowed {
    ($spec:ident, $model_kind:ident, $model:ident, $list:ident) => {
        paste! {
            if $spec.contains([< $model_kind Spec>]::[< $model:upper >]) {
                $list.push(stringify!($model));
            }
        }
    };
}
47
impl<D: Data<Elem = f64>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix1>, MoeError>
    for GpMixtureValidParams<f64>
{
    type Object = GpMixture;

    /// Fit Moe parameters using maximum likelihood
    ///
    /// # Errors
    ///
    /// * [MoeError::ClusteringError]: if there is not enough points regarding the clusters,
    /// * [MoeError::GpError]: if gaussian process fitting fails
    ///
    fn fit(
        &self,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix1>>,
    ) -> Result<Self::Object> {
        // Thin linfa adapter: delegate to `train` with the dataset records
        // as inputs and its targets as outputs.
        let x = dataset.records();
        let y = dataset.targets();
        self.train(x, y)
    }
}
69
impl GpMixtureValidParams<f64> {
    /// Train a Mixture of Experts model on the given training data (xt, yt)
    ///
    /// Steps:
    /// 1. determine the number of clusters (automatic search or fixed),
    /// 2. fit a Gaussian mixture model on the joint `[x | y]` data
    ///    (unless a pre-trained one was supplied via `gmx()`),
    /// 3. delegate expert selection/training to [`Self::train_on_clusters`].
    pub fn train(
        &self,
        xt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        yt: &ArrayBase<impl Data<Elem = f64>, Ix1>,
    ) -> Result<GpMixture> {
        trace!("Moe training...");
        let nx = xt.ncols();
        // Clustering is performed in the joint input/output space: [x | y]
        let data = concatenate(
            Axis(1),
            &[xt.view(), yt.to_owned().insert_axis(Axis(1)).view()],
        )
        .unwrap();

        let (n_clusters, recomb) = match self.n_clusters() {
            NbClusters::Auto { max } => {
                // automatic mode
                // heuristic upper bound: roughly one cluster per 10 training points
                let max_nb_clusters = max.unwrap_or(xt.nrows() / 10 + 1);
                info!("Find best number of clusters up to {max_nb_clusters}");
                find_best_number_of_clusters(
                    xt,
                    yt,
                    max_nb_clusters,
                    self.kpls_dim(),
                    self.regression_spec(),
                    self.correlation_spec(),
                    self.rng(),
                )
            }
            NbClusters::Fixed { nb: nb_clusters } => (nb_clusters, self.recombination()),
        };
        if let NbClusters::Auto { max: _ } = self.n_clusters() {
            info!("Automatic settings {n_clusters} {recomb:?}");
        }

        // When the heaviside factor will be optimized later (Smooth(None),
        // multi-cluster), hold out part of the data so that the factor is
        // selected on points the GMM was not trained on.
        let training = if recomb == Recombination::Smooth(None) && self.n_clusters().is_multi() {
            // Extract 5% of data for validation to find best heaviside factor
            // TODO: Better use cross-validation... but performances impact?
            let (_, training_data) = extract_part(&data, 5);
            training_data
        } else {
            data.to_owned()
        };
        let dataset = Dataset::from(training);

        // Reuse a user-provided Gaussian mixture when available, otherwise fit one.
        let gmx = if self.gmx().is_some() {
            self.gmx().unwrap().clone()
        } else {
            trace!("GMM training...");
            let gmm = GaussianMixtureModel::params(n_clusters)
                .n_runs(20)
                .with_rng(self.rng())
                .fit(&dataset)?;

            // GMX for prediction: restrict the mixture fitted in [x | y] space
            // to the input space (first nx columns) only.
            let weights = gmm.weights().to_owned();
            let means = gmm.means().slice(s![.., ..nx]).to_owned();
            let covariances = gmm.covariances().slice(s![.., ..nx, ..nx]).to_owned();
            let factor = match recomb {
                Recombination::Smooth(Some(f)) => f,
                Recombination::Smooth(_) => 1.,
                Recombination::Hard => 1.,
            };
            GaussianMixture::new(weights, means, covariances)?.heaviside_factor(factor)
        };

        trace!("Train on clusters...");
        let clustering = Clustering::new(gmx, recomb);
        self.train_on_clusters(&xt.view(), &yt.view(), &clustering)
    }

    /// Using the current state of the clustering, select and train the experts
    /// Returns the fitted mixture of experts model
    pub fn train_on_clusters(
        &self,
        xt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        yt: &ArrayBase<impl Data<Elem = f64>, Ix1>,
        clustering: &Clustering,
    ) -> Result<GpMixture> {
        let gmx = clustering.gmx();
        let recomb = clustering.recombination();
        let nx = xt.ncols();
        // Joint [x | y] data used to dispatch points to their clusters
        let data = concatenate(
            Axis(1),
            &[xt.view(), yt.to_owned().insert_axis(Axis(1)).view()],
        )
        .unwrap();

        // Assign each training point to a cluster, then group the rows per cluster
        let dataset_clustering = gmx.predict(xt);
        let clusters = sort_by_cluster(gmx.n_clusters(), &data, &dataset_clustering);

        check_number_of_points(&clusters, xt.ncols(), self.regression_spec())?;

        // Fit GPs on clustered data
        let mut experts = Vec::new();
        let nb_clusters = clusters.len();
        for (nc, cluster) in clusters.iter().enumerate() {
            // A GP needs a minimum of points; only enforced for multi-cluster mixtures
            if nb_clusters > 1 && cluster.nrows() < 3 {
                return Err(MoeError::ClusteringError(format!(
                    "Not enough points in cluster, requires at least 3, got {}",
                    cluster.nrows()
                )));
            }
            debug!("nc={} theta_tuning={:?}", nc, self.theta_tunings());
            let expert = self.find_best_expert(nc, nx, cluster)?;
            experts.push(expert);
        }

        if recomb == Recombination::Smooth(None) && self.n_clusters().is_multi() {
            // Extract 5% of data for validation to find best heaviside factor
            // TODO: Better use cross-validation... but performances impact?
            let (test, _) = extract_part(&data, 5);
            let xtest = test.slice(s![.., ..nx]).to_owned();
            let ytest = test.slice(s![.., nx..]).to_owned().remove_axis(Axis(1));
            let factor = self.optimize_heaviside_factor(&experts, gmx, &xtest, &ytest);
            info!("Retrain mixture with optimized heaviside factor={factor}");

            // Retrain with the factor now fixed (Smooth(Some(factor))), so this
            // branch is not re-entered on the recursive `train` call.
            let moe = GpMixtureParams::from(self.clone())
                .n_clusters(NbClusters::fixed(gmx.n_clusters()))
                .recombination(Recombination::Smooth(Some(factor)))
                .check()?
                .train(xt, yt)?; // needs to train the gaussian mixture on all data (xt, yt) as it was
            // previously trained on data excluding test data (see train method)
            Ok(moe)
        } else {
            Ok(GpMixture {
                gp_type: self.gp_type().clone(),
                recombination: recomb,
                experts,
                gmx: gmx.clone(),
                training_data: (xt.to_owned(), yt.to_owned()),
                params: self.clone(),
            })
        }
    }

    /// Select the surrogate which gives the smallest prediction error on the given data
    /// The error is computed using cross-validation
    ///
    /// `nc` is the cluster index, `nx` the input dimension, `data` the cluster
    /// rows in joint `[x | y]` form.
    fn find_best_expert(
        &self,
        nc: usize,
        nx: usize,
        data: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Box<dyn FullGpSurrogate>> {
        // Split the joint cluster data back into inputs (first nx columns) and output
        let xtrain = data.slice(s![.., ..nx]).to_owned();
        let ytrain = data.slice(s![.., nx..]).to_owned();
        let mut dataset = Dataset::from((xtrain.clone(), ytrain.clone().remove_axis(Axis(1))));
        // Collect the names of the allowed trend (regression) models...
        let regression_spec = self.regression_spec();
        let mut allowed_means = vec![];
        check_allowed!(regression_spec, Regression, Constant, allowed_means);
        check_allowed!(regression_spec, Regression, Linear, allowed_means);
        check_allowed!(regression_spec, Regression, Quadratic, allowed_means);
        // ...and of the allowed correlation (kernel) models
        let correlation_spec = self.correlation_spec();
        let mut allowed_corrs = vec![];
        check_allowed!(
            correlation_spec,
            Correlation,
            SquaredExponential,
            allowed_corrs
        );
        check_allowed!(
            correlation_spec,
            Correlation,
            AbsoluteExponential,
            allowed_corrs
        );
        check_allowed!(correlation_spec, Correlation, Matern32, allowed_corrs);
        check_allowed!(correlation_spec, Correlation, Matern52, allowed_corrs);

        debug!("Find best expert");
        // `best` is ("<Mean>_<Correlation>", Option<cross-validation error>)
        let best = if allowed_means.len() == 1 && allowed_corrs.len() == 1 {
            (format!("{}_{}", allowed_means[0], allowed_corrs[0]), None) // shortcut
        } else {
            // Cross-validate every allowed mean/correlation combination and
            // keep the one with the smallest error
            let mut map_error = Vec::new();
            compute_errors!(self, allowed_means, allowed_corrs, dataset, map_error);
            let errs: Vec<f64> = map_error.iter().map(|(_, err)| *err).collect();
            debug!("Accuracies {map_error:?}");
            let argmin = errs
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
                .map(|(index, _)| index)
                .unwrap();
            (map_error[argmin].0.clone(), Some(map_error[argmin].1))
        };
        debug!("after Find best expert");

        let expert = match self.gp_type() {
            GpType::FullGp => {
                // Map the winning "<Mean>_<Correlation>" name to concrete GP parameters
                let best_expert_params: std::result::Result<Box<dyn GpSurrogateParams>, MoeError> =
                    match best.0.as_str() {
                        "Constant_SquaredExponential" => {
                            Ok(make_surrogate_params!(Constant, SquaredExponential))
                        }
                        "Constant_AbsoluteExponential" => {
                            Ok(make_surrogate_params!(Constant, AbsoluteExponential))
                        }
                        "Constant_Matern32" => Ok(make_surrogate_params!(Constant, Matern32)),
                        "Constant_Matern52" => Ok(make_surrogate_params!(Constant, Matern52)),
                        "Linear_SquaredExponential" => {
                            Ok(make_surrogate_params!(Linear, SquaredExponential))
                        }
                        "Linear_AbsoluteExponential" => {
                            Ok(make_surrogate_params!(Linear, AbsoluteExponential))
                        }
                        "Linear_Matern32" => Ok(make_surrogate_params!(Linear, Matern32)),
                        "Linear_Matern52" => Ok(make_surrogate_params!(Linear, Matern52)),
                        "Quadratic_SquaredExponential" => {
                            Ok(make_surrogate_params!(Quadratic, SquaredExponential))
                        }
                        "Quadratic_AbsoluteExponential" => {
                            Ok(make_surrogate_params!(Quadratic, AbsoluteExponential))
                        }
                        "Quadratic_Matern32" => Ok(make_surrogate_params!(Quadratic, Matern32)),
                        "Quadratic_Matern52" => Ok(make_surrogate_params!(Quadratic, Matern52)),
                        _ => {
                            return Err(MoeError::ExpertError(format!(
                                "Unknown expert {}",
                                best.0
                            )));
                        }
                    };
                let mut expert_params = best_expert_params?;
                expert_params.n_start(self.n_start());
                expert_params.max_eval(self.max_eval());
                expert_params.kpls_dim(self.kpls_dim());
                // A single theta tuning is broadcast to all clusters; otherwise
                // use the tuning dedicated to this cluster index
                if nc > 0 && self.theta_tunings().len() == 1 {
                    expert_params.theta_tuning(self.theta_tunings()[0].clone());
                } else {
                    debug!("Training with theta_tuning = {:?}.", self.theta_tunings());
                    expert_params.theta_tuning(self.theta_tunings()[nc].clone());
                }
                debug!("Train best expert...");
                expert_params.train(&xtrain.view(), &ytrain.view())
            }
            GpType::SparseGp {
                inducings,
                sparse_method,
                ..
            } => {
                let inducings = inducings.to_owned();
                // Only Constant trend names are handled for sparse GPs
                let best_expert_params: std::result::Result<Box<dyn SgpSurrogateParams>, MoeError> =
                    match best.0.as_str() {
                        "Constant_SquaredExponential" => {
                            Ok(make_sgp_surrogate_params!(SquaredExponential, inducings))
                        }
                        "Constant_AbsoluteExponential" => {
                            Ok(make_sgp_surrogate_params!(AbsoluteExponential, inducings))
                        }
                        "Constant_Matern32" => Ok(make_sgp_surrogate_params!(Matern32, inducings)),
                        "Constant_Matern52" => Ok(make_sgp_surrogate_params!(Matern52, inducings)),
                        _ => {
                            return Err(MoeError::ExpertError(format!(
                                "Unknown expert {}",
                                best.0
                            )));
                        }
                    };
                let mut expert_params = best_expert_params?;
                // `gen` is a reserved keyword in Rust 2024, hence the raw identifier
                let seed = self.rng().r#gen();
                debug!("Theta tuning = {:?}", self.theta_tunings());
                expert_params.sparse_method(*sparse_method);
                expert_params.seed(seed);
                expert_params.n_start(self.n_start());
                expert_params.kpls_dim(self.kpls_dim());
                expert_params.theta_tuning(self.theta_tunings()[0].clone());
                debug!("Train best expert...");
                expert_params.train(&xtrain.view(), &ytrain.view())
            }
        };

        debug!("...after best expert training");
        if let Some(v) = best.1 {
            info!("Best expert {} accuracy={}", best.0, v);
        }
        expert
    }

    /// Take the best heaviside factor among 20 values evenly spaced in [0.1, 2.1].
    /// Mixture (`gmx` and `experts`) is already trained, only the continuous recombination
    /// is changed and the factor giving the smallest prediction error on the given
    /// test data is selected.
    /// Used only in case of smooth recombination
    fn optimize_heaviside_factor(
        &self,
        experts: &[Box<dyn FullGpSurrogate>],
        gmx: &GaussianMixture<f64>,
        xtest: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        ytest: &ArrayBase<impl Data<Elem = f64>, Ix1>,
    ) -> f64 {
        if self.recombination() == Recombination::Hard || self.n_clusters().is_mono() {
            1.
        } else {
            let scale_factors = Array1::linspace(0.1, 2.1, 20);
            // Relative L2 prediction error on the held-out test set for each factor
            let errors = scale_factors.map(move |&factor| {
                let gmx2 = gmx.clone();
                let gmx2 = gmx2.heaviside_factor(factor);
                let pred = predict_smooth(experts, &gmx2, xtest).unwrap();
                pred.sub(ytest).mapv(|x| x * x).sum().sqrt() / xtest.mapv(|x| x * x).sum().sqrt()
            });

            let min_error_index = errors.argmin().unwrap();
            // When every factor performs (nearly) perfectly, fall back to the
            // neutral factor 1.
            if *errors.max().unwrap() < 1e-6 {
                1.
            } else {
                scale_factors[min_error_index]
            }
        }
    }
}
380
381fn check_number_of_points<F>(
382    clusters: &[ArrayBase<impl Data<Elem = F>, Ix2>],
383    dim: usize,
384    regr: RegressionSpec,
385) -> Result<()> {
386    if clusters.len() > 1 {
387        let min_number_point = if regr.contains(RegressionSpec::QUADRATIC) {
388            (dim + 1) * (dim + 2) / 2
389        } else if regr.contains(RegressionSpec::LINEAR) {
390            dim + 1
391        } else {
392            1
393        };
394        for cluster in clusters {
395            if cluster.len() < min_number_point {
396                return Err(MoeError::ClusteringError(format!(
397                    "Not enough points in training set. Need {} points, got {}",
398                    min_number_point,
399                    cluster.len()
400                )));
401            }
402        }
403    }
404    Ok(())
405}
406
407/// Predict outputs at given points with `experts` and gaussian mixture `gmx`.
408/// `gmx` is used to get the probability of x to belongs to one cluster
409/// or another (ie responsabilities). Those responsabilities are used to combine
410/// output values predict by each cluster experts.
411fn predict_smooth(
412    experts: &[Box<dyn FullGpSurrogate>],
413    gmx: &GaussianMixture<f64>,
414    points: &ArrayBase<impl Data<Elem = f64>, Ix2>,
415) -> Result<Array1<f64>> {
416    let probas = gmx.predict_probas(points);
417    let preds: Array1<f64> = experts
418        .iter()
419        .enumerate()
420        .map(|(i, gp)| gp.predict(&points.view()).unwrap() * probas.column(i))
421        .fold(Array1::zeros((points.nrows(),)), |acc, pred| acc + pred);
422    Ok(preds)
423}
424
/// Mixture of gaussian process experts
/// Implementation note: the structure is not generic over 'F: Float' to be able to
/// implement use serde easily as deserialization of generic impls is not supported yet
/// See <https://github.com/dtolnay/typetag/issues/1>
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub struct GpMixture {
    /// The mode of recombination to get the output prediction from experts prediction
    recombination: Recombination<f64>,
    /// The list of the best experts trained on each cluster
    experts: Vec<Box<dyn FullGpSurrogate>>,
    /// The gaussian mixture allowing to predict cluster responsabilities for a given point
    gmx: GaussianMixture<f64>,
    /// Gp type (full or sparse gaussian process experts)
    gp_type: GpType<f64>,
    /// Training inputs (xt, yt) kept for quality-assurance metrics
    training_data: (Array2<f64>, Array1<f64>),
    /// Params used to fit this model
    params: GpMixtureValidParams<f64>,
}
444
445impl std::fmt::Display for GpMixture {
446    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447        let recomb = match self.recombination() {
448            Recombination::Hard => "Hard".to_string(),
449            Recombination::Smooth(Some(f)) => format!("Smooth({f})"),
450            Recombination::Smooth(_) => "Smooth".to_string(),
451        };
452        let experts = self
453            .experts
454            .iter()
455            .map(|expert| expert.to_string())
456            .reduce(|acc, s| acc + ", " + &s)
457            .unwrap();
458        write!(f, "Mixture[{}]({})", &recomb, &experts)
459    }
460}
461
impl Clustered for GpMixture {
    /// Number of clusters
    fn n_clusters(&self) -> usize {
        self.gmx.n_clusters()
    }

    /// Clustering Recombination
    fn recombination(&self) -> Recombination<f64> {
        // Delegates to the inherent `GpMixture::recombination` accessor
        // (inherent methods take precedence over this trait method, so this
        // is not a recursive call).
        self.recombination()
    }

    /// Convert to clustering
    fn to_clustering(&self) -> Clustering {
        Clustering {
            recombination: self.recombination(),
            gmx: self.gmx.clone(),
        }
    }
}
481
#[cfg_attr(feature = "serializable", typetag::serde)]
impl GpSurrogate for GpMixture {
    /// Dimensions taken from the first expert (all experts share them).
    fn dims(&self) -> (usize, usize) {
        self.experts[0].dims()
    }

    /// Predict output values at `x`, dispatching on the recombination mode.
    fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
        match self.recombination {
            Recombination::Hard => self.predict_hard(x),
            Recombination::Smooth(_) => self.predict_smooth(x),
        }
    }

    /// Predict output variances at `x`, dispatching on the recombination mode.
    fn predict_var(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
        match self.recombination {
            Recombination::Hard => self.predict_var_hard(x),
            Recombination::Smooth(_) => self.predict_var_smooth(x),
        }
    }

    /// Predict both values and variances at `x` in one call.
    fn predict_valvar(&self, x: &ArrayView2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
        match self.recombination {
            Recombination::Hard => self.predict_valvar_hard(x),
            Recombination::Smooth(_) => self.predict_valvar_smooth(x),
        }
    }

    /// Save Moe model in given file.
    #[cfg(feature = "persistent")]
    fn save(&self, path: &str, format: GpFileFormat) -> Result<()> {
        let mut file = fs::File::create(path).unwrap();

        // Serialize according to the requested on-disk format
        let bytes = match format {
            GpFileFormat::Json => serde_json::to_vec(self).map_err(MoeError::SaveJsonError)?,
            GpFileFormat::Binary => {
                bincode::serde::encode_to_vec(self, bincode::config::standard())
                    .map_err(MoeError::SaveBinaryError)?
            }
        };
        file.write_all(&bytes)?;

        Ok(())
    }
}
526
#[cfg_attr(feature = "serializable", typetag::serde)]
impl GpSurrogateExt for GpMixture {
    /// Predict output derivatives at `x`, dispatching on the recombination mode.
    fn predict_gradients(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
        match self.recombination {
            Recombination::Hard => self.predict_gradients_hard(x),
            Recombination::Smooth(_) => self.predict_gradients_smooth(x),
        }
    }

    /// Predict variance derivatives at `x`, dispatching on the recombination mode.
    fn predict_var_gradients(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
        match self.recombination {
            Recombination::Hard => self.predict_var_gradients_hard(x),
            Recombination::Smooth(_) => self.predict_var_gradients_smooth(x),
        }
    }

    /// Predict both value and variance derivatives at `x` in one call.
    fn predict_valvar_gradients(&self, x: &ArrayView2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        match self.recombination {
            Recombination::Hard => self.predict_valvar_gradients_hard(x),
            Recombination::Smooth(_) => self.predict_valvar_gradients_smooth(x),
        }
    }

    /// Sample `n_traj` trajectories at `x`.
    ///
    /// # Errors
    ///
    /// [MoeError::SampleError] when the mixture has more than one cluster:
    /// sampling is only supported for a single expert.
    fn sample(&self, x: &ArrayView2<f64>, n_traj: usize) -> Result<Array2<f64>> {
        if self.n_clusters() != 1 {
            return Err(MoeError::SampleError(format!(
                "Can not sample when several clusters {}",
                self.n_clusters()
            )));
        }
        self.sample_expert(0, x, n_traj)
    }
}
560
impl GpMetrics<MoeError, GpMixtureParams<f64>, Self> for GpMixture {
    /// Training data (xt, yt) stored at fit time, used by the metric computations.
    fn training_data(&self) -> &(Array2<f64>, Array1<f64>) {
        &self.training_data
    }

    /// Parameters used to fit this model, rebuilt as (unchecked) `GpMixtureParams`.
    fn params(&self) -> GpMixtureParams<f64> {
        GpMixtureParams::<f64>::from(self.params.clone())
    }
}
570
#[cfg_attr(feature = "serializable", typetag::serde)]
impl GpQualityAssurance for GpMixture {
    // All methods below forward to the `GpMetrics` implementation; the
    // `(self as &dyn GpMetrics<_, _, _>)` upcast disambiguates the same-named
    // methods of the two traits.

    /// Training data used by the quality metrics.
    fn training_data(&self) -> &(Array2<f64>, Array1<f64>) {
        (self as &dyn GpMetrics<_, _, _>).training_data()
    }

    /// Q2 predictivity coefficient computed by k-fold cross-validation.
    fn q2_k(&self, kfold: usize) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).q2_k_score(kfold)
    }
    /// Q2 predictivity coefficient.
    fn q2(&self) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).q2_score()
    }

    /// Predictive variance adequacy computed by k-fold cross-validation.
    fn pva_k(&self, kfold: usize) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).pva_k_score(kfold)
    }
    /// Predictive variance adequacy.
    fn pva(&self) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).pva_score()
    }

    /// IAE-alpha score computed by k-fold cross-validation (no plot data).
    fn iae_alpha_k(&self, kfold: usize) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).iae_alpha_k_score(kfold, None)
    }
    /// Same as [`Self::iae_alpha_k`] but also fills `plot_data` for plotting.
    fn iae_alpha_k_score_with_plot(&self, kfold: usize, plot_data: &mut IaeAlphaPlotData) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).iae_alpha_k_score(kfold, Some(plot_data))
    }
    /// IAE-alpha score (no plot data).
    fn iae_alpha(&self) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).iae_alpha_score(None)
    }
}
601
#[cfg_attr(feature = "serializable", typetag::serde)]
impl MixtureGpSurrogate for GpMixture {
    /// Selected experts in the mixture
    fn experts(&self) -> &Vec<Box<dyn FullGpSurrogate>> {
        &self.experts
    }
}
609
610impl GpMixture {
    /// Constructor of mixture of experts parameters
    pub fn params() -> GpMixtureParams<f64> {
        GpMixtureParams::new()
    }

    /// Gp type (full or sparse gaussian process) used for the experts
    pub fn gp_type(&self) -> &GpType<f64> {
        &self.gp_type
    }

    /// Recombination mode
    pub fn recombination(&self) -> Recombination<f64> {
        self.recombination
    }

    /// Gaussian mixture
    pub fn gmx(&self) -> &GaussianMixture<f64> {
        &self.gmx
    }

    /// Sets recombination mode.
    /// `Smooth(None)` is normalized to `Smooth(Some(1.))` so the heaviside
    /// factor is always defined afterwards.
    pub fn set_recombination(mut self, recombination: Recombination<f64>) -> Self {
        self.recombination = match recombination {
            Recombination::Hard => recombination,
            Recombination::Smooth(Some(_)) => recombination,
            Recombination::Smooth(_) => Recombination::Smooth(Some(1.)),
        };
        self
    }

    /// Set the gaussian mixture to use given weights, means and covariances
    ///
    /// # Panics
    ///
    /// Panics if the gaussian mixture can not be built from the given parameters.
    pub fn set_gmx(
        mut self,
        weights: Array1<f64>,
        means: Array2<f64>,
        covariances: Array3<f64>,
    ) -> Self {
        self.gmx = GaussianMixture::new(weights, means, covariances).unwrap();
        self
    }

    /// Set the model experts to use in the mixture
    pub fn set_experts(mut self, experts: Vec<Box<dyn FullGpSurrogate>>) -> Self {
        self.experts = experts;
        self
    }
657
    /// Predict outputs at a set of points `x` specified as (n, nx) matrix.
    /// Gaussian Mixture is used to get the probability of the point to belongs to one cluster
    /// or another (ie responsabilities).
    /// The smooth recombination of each cluster expert responsabilty is used to get the result.
    pub fn predict_smooth(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
        // Delegate to the free function shared with the heaviside factor optimization
        predict_smooth(&self.experts, &self.gmx, x)
    }
665
    /// Predict variances at a set of points `x` specified as (n, nx) matrix.
    /// Gaussian Mixture is used to get the probability of the point to belongs to one cluster
    /// or another (ie responsabilities).
    /// The smooth recombination of each cluster expert responsabilty is used to get the result.
    pub fn predict_var_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array1<f64>> {
        let probas = self.gmx.predict_probas(x);
        // Variance of the mixture: sum over experts of p_i^2 * var_i
        // (responsabilities are squared, unlike the mean prediction)
        let preds: Array1<f64> = self
            .experts
            .iter()
            .enumerate()
            .map(|(i, gp)| {
                let p = probas.column(i);
                gp.predict_var(&x.view()).unwrap() * p * p
            })
            .fold(Array1::zeros(x.nrows()), |acc, var| acc + var);
        Ok(preds)
    }
686
    /// Predict derivatives of the output at a set of points `x` specified as (n, nx) matrix.
    /// Return derivatives as a (n, nx) matrix where the ith row contain the partial derivatives of
    /// of the output wrt the nx components of `x` valued at the ith x point.
    /// The smooth recombination of each cluster expert responsability is used to get the result.
    pub fn predict_gradients_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        let probas = self.gmx.predict_probas(x);
        let probas_drv = self.gmx.predict_probas_derivatives(x);
        let mut drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));

        // Product rule on y(x) = sum_i p_i(x) * f_i(x):
        //   dy/dx = sum_i p_i * f_i'   (term1)
        //         + sum_i p_i' * f_i   (term2)
        Zip::from(drv.rows_mut())
            .and(x.rows())
            .and(probas.rows())
            .and(probas_drv.outer_iter())
            .for_each(|mut y, x, p, pprime| {
                // Each row is processed as a (1, nx) point
                let x = x.insert_axis(Axis(0));
                // f_i: value of each expert at this point
                let preds: Array1<f64> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict(&x).unwrap()[0])
                    .collect();
                // f_i': gradient of each expert at this point
                let drvs: Vec<Array1<f64>> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict_gradients(&x).unwrap().row(0).to_owned())
                    .collect();

                let preds = preds.insert_axis(Axis(1));
                let mut preds_drv = Array2::zeros((self.experts.len(), x.len()));
                Zip::indexed(preds_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&drvs[i]));

                // term1 = sum_i p_i * f_i'
                let mut term1 = Array2::zeros((self.experts.len(), x.len()));
                Zip::from(term1.rows_mut())
                    .and(&p)
                    .and(preds_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p))));
                let term1 = term1.sum_axis(Axis(0));

                // term2 = sum_i p_i' * f_i
                let term2 = pprime.to_owned() * preds;
                let term2 = term2.sum_axis(Axis(0));

                y.assign(&(term1 + term2));
            });
        Ok(drv)
    }
734
    /// Predict derivatives of the variance at a set of points `x` specified as (n, nx) matrix.
    /// Return derivatives as a (n, nx) matrix where the ith row contain the partial derivatives of
    /// of the variance wrt the nx components of `x` valued at the ith x point.
    /// The smooth recombination of each cluster expert responsability is used to get the result.
    pub fn predict_var_gradients_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        let probas = self.gmx.predict_probas(x);
        let probas_drv = self.gmx.predict_probas_derivatives(x);

        let mut drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));

        // Product rule on var(x) = sum_i p_i(x)^2 * v_i(x):
        //   dvar/dx = sum_i p_i^2 * v_i'       (term1)
        //           + 2 * sum_i p_i p_i' v_i   (term2)
        Zip::from(drv.rows_mut())
            .and(x.rows())
            .and(probas.rows())
            .and(probas_drv.outer_iter())
            .for_each(|mut y, xi, p, pprime| {
                // Each row is processed as a (1, nx) point
                let xii = xi.insert_axis(Axis(0));
                // v_i: variance of each expert at this point
                let preds: Array1<f64> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict_var(&xii).unwrap()[0])
                    .collect();
                // v_i': variance gradient of each expert at this point
                let drvs: Vec<Array1<f64>> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict_var_gradients(&xii).unwrap().row(0).to_owned())
                    .collect();

                let preds = preds.insert_axis(Axis(1));
                let mut preds_drv = Array2::zeros((self.experts.len(), xi.len()));
                Zip::indexed(preds_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&drvs[i]));

                // term1 = sum_i p_i^2 * v_i'
                let mut term1 = Array2::zeros((self.experts.len(), xi.len()));
                Zip::from(term1.rows_mut())
                    .and(&p)
                    .and(preds_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p * p))));
                let term1 = term1.sum_axis(Axis(0));

                // term2 = 2 * sum_i p_i * p_i' * v_i
                let term2 = (p.to_owned() * pprime * preds).mapv(|v| 2. * v);
                let term2 = term2.sum_axis(Axis(0));

                y.assign(&(term1 + term2));
            });

        Ok(drv)
    }
784
785    /// Predict outputs and variances at a set of points `x` specified as (n, nx) matrix.
786    /// Gaussian Mixture is used to get the probability of the point to belongs to one cluster
787    /// or another (ie responsabilities).
788    /// The smooth recombination of each cluster expert responsabilty is used to get the result.
789    pub fn predict_valvar_smooth(
790        &self,
791        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
792    ) -> Result<(Array1<f64>, Array1<f64>)> {
793        let probas = self.gmx.predict_probas(x);
794        let valvar: (Array1<f64>, Array1<f64>) = self
795            .experts
796            .iter()
797            .enumerate()
798            .map(|(i, gp)| {
799                let p = probas.column(i);
800                let (pred, var) = gp.predict_valvar(&x.view()).unwrap();
801                (pred * p, var * p * p)
802            })
803            .fold(
804                (Array1::zeros((x.nrows(),)), Array1::zeros((x.nrows(),))),
805                |acc, (pred, var)| (acc.0 + pred, acc.1 + var),
806            );
807
808        Ok(valvar)
809    }
810
    /// Predict both output and variance gradients at a set of points `x` specified as (n, nx)
    /// matrix with smooth recombination, in one pass over the experts.
    ///
    /// Value gradient: d/dx sum_i p_i * pred_i = sum_i (p_i * dpred_i + dp_i * pred_i).
    /// Variance gradient: d/dx sum_i p_i^2 * var_i = sum_i (p_i^2 * dvar_i + 2 p_i dp_i var_i).
    fn predict_valvar_gradients_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        // Cluster responsibilities p and their derivatives dp at each point
        let probas = self.gmx.predict_probas(x);
        let probas_drv = self.gmx.predict_probas_derivatives(x);

        let mut val_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
        let mut var_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));

        // One output row of each result per input point
        Zip::from(val_drv.rows_mut())
            .and(var_drv.rows_mut())
            .and(x.rows())
            .and(probas.rows())
            .and(probas_drv.outer_iter())
            .for_each(|mut val_y, mut var_y, xi, p, pprime| {
                // Experts expect a (1, nx) matrix: promote the row view
                let xii = xi.insert_axis(Axis(0));
                // Per-expert prediction and variance at this point
                let (preds, vars): (Vec<f64>, Vec<f64>) = self
                    .experts
                    .iter()
                    .map(|gp| {
                        let (pred, var) = gp.predict_valvar(&xii).unwrap();
                        (pred[0], var[0])
                    })
                    .unzip();
                let preds: Array2<f64> = Array1::from(preds).insert_axis(Axis(1));
                let vars: Array2<f64> = Array1::from(vars).insert_axis(Axis(1));
                // Per-expert gradients (nx,) of prediction and variance
                let (drvs, var_drvs): (Vec<Array1<f64>>, Vec<Array1<f64>>) = self
                    .experts
                    .iter()
                    .map(|gp| {
                        let (predg, varg) = gp.predict_valvar_gradients(&xii).unwrap();
                        (predg.row(0).to_owned(), varg.row(0).to_owned())
                    })
                    .unzip();

                // Stack per-expert gradients into (n_experts, nx) matrices
                let mut preds_drv = Array2::zeros((self.experts.len(), xi.len()));
                let mut vars_drv = Array2::zeros((self.experts.len(), xi.len()));
                Zip::indexed(preds_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&drvs[i]));
                Zip::indexed(vars_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&var_drvs[i]));

                // Value gradient, term1: p_i * dpred_i for each expert
                let mut val_term1 = Array2::zeros((self.experts.len(), xi.len()));
                Zip::from(val_term1.rows_mut())
                    .and(&p)
                    .and(preds_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p))));
                let val_term1 = val_term1.sum_axis(Axis(0));
                // Value gradient, term2: dp_i * pred_i for each expert
                let val_term2 = pprime.to_owned() * preds;
                let val_term2 = val_term2.sum_axis(Axis(0));
                val_y.assign(&(val_term1 + val_term2));

                // Variance gradient, term1: p_i^2 * dvar_i for each expert
                let mut var_term1 = Array2::zeros((self.experts.len(), xi.len()));
                Zip::from(var_term1.rows_mut())
                    .and(&p)
                    .and(vars_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p * p))));
                let var_term1 = var_term1.sum_axis(Axis(0));
                // Variance gradient, term2: 2 * p_i * dp_i * var_i for each expert
                let var_term2 = (p.to_owned() * pprime * vars).mapv(|v| 2. * v);
                let var_term2 = var_term2.sum_axis(Axis(0));
                var_y.assign(&(var_term1 + var_term2));
            });
        Ok((val_drv, var_drv))
    }
874
875    /// Predict outputs at a set of points `x` specified as (n, nx) matrix.
876    /// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
877    /// Then the expert of the cluster is used to predict the output value.
878    /// Returns the ouputs as a (n, 1) column vector
879    pub fn predict_hard(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
880        let clustering = self.gmx.predict(x);
881        trace!("Clustering {clustering:?}");
882        let mut preds = Array1::zeros((x.nrows(),));
883        Zip::from(&mut preds)
884            .and(x.rows())
885            .and(&clustering)
886            .for_each(|y, x, &c| *y = self.experts[c].predict(&x.insert_axis(Axis(0))).unwrap()[0]);
887        Ok(preds)
888    }
889
890    /// Predict variance at a set of points `x` specified as (n, nx) matrix.
891    /// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
892    /// The expert of the cluster is used to predict variance value.
893    /// Returns the variances as a (n,) vector
894    pub fn predict_var_hard(
895        &self,
896        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
897    ) -> Result<Array1<f64>> {
898        let clustering = self.gmx.predict(x);
899        trace!("Clustering {clustering:?}");
900        let mut variances = Array1::zeros(x.nrows());
901        Zip::from(&mut variances)
902            .and(x.rows())
903            .and(&clustering)
904            .for_each(|y, x, &c| {
905                *y = self.experts[c]
906                    .predict_var(&x.insert_axis(Axis(0)))
907                    .unwrap()[0];
908            });
909        Ok(variances)
910    }
911
912    /// Predict outputs and variances at a set of points `x` specified as (n, nx) matrix.
913    /// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
914    /// The expert of the cluster is used to predict variance value.
915    pub fn predict_valvar_hard(
916        &self,
917        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
918    ) -> Result<(Array1<f64>, Array1<f64>)> {
919        let clustering = self.gmx.predict(x);
920        trace!("Clustering {clustering:?}");
921        let mut preds = Array1::zeros((x.nrows(),));
922        let mut variances = Array1::zeros(x.nrows());
923        Zip::from(&mut preds)
924            .and(&mut variances)
925            .and(x.rows())
926            .and(&clustering)
927            .for_each(|y, v, x, &c| {
928                let (pred, var) = self.experts[c]
929                    .predict_valvar(&x.insert_axis(Axis(0)))
930                    .unwrap();
931                *y = pred[0];
932                *v = var[0];
933            });
934        Ok((preds, variances))
935    }
936
937    /// Predict derivatives of the output at a set of points `x` specified as (n, nx) matrix.
938    /// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
939    /// The expert of the cluster is used to predict variance value.
940    /// Returns derivatives as a (n, nx) matrix where the ith row contain the partial derivatives of
941    /// of the output wrt the nx components of `x` valued at the ith x point.
942    pub fn predict_gradients_hard(
943        &self,
944        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
945    ) -> Result<Array2<f64>> {
946        let mut drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
947        let clustering = self.gmx.predict(x);
948        Zip::from(drv.rows_mut())
949            .and(x.rows())
950            .and(&clustering)
951            .for_each(|mut drv_i, xi, &c| {
952                let x = xi.to_owned().insert_axis(Axis(0));
953                let x_drv: ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>> =
954                    self.experts[c].predict_gradients(&x.view()).unwrap();
955                drv_i.assign(&x_drv.row(0))
956            });
957        Ok(drv)
958    }
959
960    /// Predict derivatives of the variances at a set of points `x` specified as (n, nx) matrix.
961    /// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
962    /// The expert of the cluster is used to predict variance value.
963    /// Returns derivatives as a (n, nx) matrix where the ith row contain the partial derivatives of
964    /// of the output wrt the nx components of `x` valued at the ith x point.
965    pub fn predict_var_gradients_hard(
966        &self,
967        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
968    ) -> Result<Array2<f64>> {
969        let mut vardrv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
970        let clustering = self.gmx.predict(x);
971        Zip::from(vardrv.rows_mut())
972            .and(x.rows())
973            .and(&clustering)
974            .for_each(|mut vardrv_i, xi, &c| {
975                let x = xi.to_owned().insert_axis(Axis(0));
976                let x_vardrv: ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>> =
977                    self.experts[c].predict_var_gradients(&x.view()).unwrap();
978                vardrv_i.assign(&x_vardrv.row(0))
979            });
980        Ok(vardrv)
981    }
982
983    /// Predict derivatives of the outputs and variances at a set of points `x` specified as (n, nx) matrix.
984    /// Gaussian Mixture is used to get the cluster where the point belongs (highest responsability)
985    /// The expert of the cluster is used to predict variance value.
986    pub fn predict_valvar_gradients_hard(
987        &self,
988        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
989    ) -> Result<(Array2<f64>, Array2<f64>)> {
990        let mut val_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
991        let mut var_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
992        let clustering = self.gmx.predict(x);
993        Zip::from(val_drv.rows_mut())
994            .and(var_drv.rows_mut())
995            .and(x.rows())
996            .and(&clustering)
997            .for_each(|mut val_y, mut var_y, xi, &c| {
998                let x = xi.to_owned().insert_axis(Axis(0));
999                let (x_val_drv, x_var_drv) =
1000                    self.experts[c].predict_valvar_gradients(&x.view()).unwrap();
1001                val_y.assign(&x_val_drv.row(0));
1002                var_y.assign(&x_var_drv.row(0));
1003            });
1004        Ok((val_drv, var_drv))
1005    }
1006
1007    /// Sample `n_traj` trajectories at a set of points `x` specified as (n, nx) matrix.
1008    /// using the expert `ith` of the mixture.
1009    /// Returns the samples as a (n, n_traj) matrix where the ith row
1010    pub fn sample_expert(
1011        &self,
1012        ith: usize,
1013        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
1014        n_traj: usize,
1015    ) -> Result<Array2<f64>> {
1016        self.experts[ith].sample(&x.view(), n_traj)
1017    }
1018
1019    /// Predict outputs at a set of points `x` specified as (n, nx) matrix.
1020    pub fn predict(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
1021        <GpMixture as GpSurrogate>::predict(self, &x.view())
1022    }
1023
1024    /// Predict variances at a set of points `x` specified as (n, nx) matrix.
1025    pub fn predict_var(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
1026        <GpMixture as GpSurrogate>::predict_var(self, &x.view())
1027    }
1028
1029    /// Predict outputs and variances at a set of points `x` specified as (n, nx) matrix.
1030    pub fn predict_valvar(
1031        &self,
1032        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
1033    ) -> Result<(Array1<f64>, Array1<f64>)> {
1034        <GpMixture as GpSurrogate>::predict_valvar(self, &x.view())
1035    }
1036
1037    /// Predict derivatives of the output at a set of points `x` specified as (n, nx) matrix.
1038    pub fn predict_gradients(
1039        &self,
1040        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
1041    ) -> Result<Array2<f64>> {
1042        <GpMixture as GpSurrogateExt>::predict_gradients(self, &x.view())
1043    }
1044
1045    /// Predict derivatives of the variance at a set of points `x` specified as (n, nx) matrix.
1046    pub fn predict_var_gradients(
1047        &self,
1048        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
1049    ) -> Result<Array2<f64>> {
1050        <GpMixture as GpSurrogateExt>::predict_var_gradients(self, &x.view())
1051    }
1052
1053    /// Predict derivatives of the outputs and variances at a set of points `x` specified as (n, nx) matrix.
1054    pub fn predict_valvar_gradients(
1055        &self,
1056        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
1057    ) -> Result<(Array2<f64>, Array2<f64>)> {
1058        <GpMixture as GpSurrogateExt>::predict_valvar_gradients(self, &x.view())
1059    }
1060
1061    /// Sample `n_traj` trajectories at a set of points `x` specified as (n, nx) matrix.
1062    /// Returns the samples as a (n, n_traj) matrix where the ith row
1063    /// contain the samples of the output at the ith point.
1064    pub fn sample(
1065        &self,
1066        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
1067        n_traj: usize,
1068    ) -> Result<Array2<f64>> {
1069        <GpMixture as GpSurrogateExt>::sample(self, &x.view(), n_traj)
1070    }
1071
1072    // pub fn cv_quality(&self) -> f64 {
1073    //     let dataset = Dataset::new(self.xtrain.to_owned(), self.ytrain.to_owned());
1074    //     let mut error = 0.;
1075    //     for (train, valid) in dataset.fold(self.xtrain.nrows()).into_iter() {
1076    //         if let Ok(mixture) = GpMixtureParams::default()
1077    //             .kpls_dim(self.kpls_dim)
1078    //             .gmx(
1079    //                 self.gmx.weights().to_owned(),
1080    //                 self.gmx.means().to_owned(),
1081    //                 self.gmx.covariances().to_owned(),
1082    //             )
1083    //             .fit(&train)
1084    //         {
1085    //             let pred = mixture.predict(valid.records()).unwrap();
1086    //             error += (valid.targets() - pred).norm_l2();
1087    //         } else {
1088    //             error += f64::INFINITY;
1089    //         }
1090    //     }
1091    //     error / self.ytrain.std(1.)
1092    // }
1093
1094    /// Load Moe from the given file.
1095    #[cfg(feature = "persistent")]
1096    pub fn load(path: &str, format: GpFileFormat) -> Result<Box<GpMixture>> {
1097        let data = fs::read(path)?;
1098        let moe = match format {
1099            GpFileFormat::Json => serde_json::from_slice(&data)?,
1100            GpFileFormat::Binary => {
1101                bincode::serde::decode_from_slice(&data, bincode::config::standard())
1102                    .map(|(surrogate, _)| surrogate)?
1103            }
1104        };
1105        Ok(Box::new(moe))
1106    }
1107}
1108
1109/// Take one out of `quantile` in a set of data rows
1110/// Returns the selected part and the remaining data.
1111fn extract_part<F: Float>(
1112    data: &ArrayBase<impl Data<Elem = F>, Ix2>,
1113    quantile: usize,
1114) -> (Array2<F>, Array2<F>) {
1115    let nsamples = data.nrows();
1116    let indices = Array1::range(0., nsamples as f32, quantile as f32).mapv(|v| v as usize);
1117    let data_test = data.select(Axis(0), indices.as_slice().unwrap());
1118    let indices2: Vec<usize> = (0..nsamples).filter(|i| i % quantile != 0).collect();
1119    let data_train = data.select(Axis(0), &indices2);
1120    (data_test, data_train)
1121}
1122
1123impl<D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array1<f64>> for GpMixture {
1124    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<f64>) {
1125        assert_eq!(
1126            x.nrows(),
1127            y.len(),
1128            "The number of data points must match the number of output targets."
1129        );
1130
1131        let values = self.predict(x).expect("MoE prediction");
1132        *y = values;
1133    }
1134
1135    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<f64> {
1136        Array1::zeros(x.nrows())
1137    }
1138}
1139
1140/// Adaptator to implement `linfa::Predict` for variance prediction
1141#[allow(dead_code)]
1142pub struct MoeVariancePredictor<'a>(&'a GpMixture);
1143impl<D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array1<f64>>
1144    for MoeVariancePredictor<'_>
1145{
1146    fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<f64>) {
1147        assert_eq!(
1148            x.nrows(),
1149            y.len(),
1150            "The number of data points must match the number of output targets."
1151        );
1152
1153        let values = self.0.predict_var(x).expect("MoE variances prediction");
1154        *y = values;
1155    }
1156
1157    fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<f64> {
1158        Array1::zeros(x.nrows())
1159    }
1160}
1161
1162#[cfg(test)]
1163mod tests {
1164    use super::*;
1165    use approx::assert_abs_diff_eq;
1166    use argmin_testfunctions::rosenbrock;
1167    use egobox_doe::{Lhs, SamplingMethod};
1168    use ndarray::{Array, Array2, Zip, array};
1169    use ndarray_npy::write_npy;
1170    use ndarray_rand::RandomExt;
1171    use ndarray_rand::rand::SeedableRng;
1172    use ndarray_rand::rand_distr::Uniform;
1173    use rand_xoshiro::Xoshiro256Plus;
1174
1175    fn f_test_1d(x: &Array2<f64>) -> Array1<f64> {
1176        let mut y = Array1::zeros(x.len());
1177        let x = Array::from_iter(x.iter().cloned());
1178        Zip::from(&mut y).and(&x).for_each(|yi, xi| {
1179            if *xi < 0.4 {
1180                *yi = xi * xi;
1181            } else if (0.4..0.8).contains(xi) {
1182                *yi = 3. * xi + 1.;
1183            } else {
1184                *yi = f64::sin(10. * xi);
1185            }
1186        });
1187        y
1188    }
1189
1190    fn df_test_1d(x: &Array2<f64>) -> Array2<f64> {
1191        let mut y = Array2::zeros(x.dim());
1192        Zip::from(y.rows_mut())
1193            .and(x.rows())
1194            .for_each(|mut yi, xi| {
1195                if xi[0] < 0.4 {
1196                    yi[0] = 2. * xi[0];
1197                } else if (0.4..0.8).contains(&xi[0]) {
1198                    yi[0] = 3.;
1199                } else {
1200                    yi[0] = 10. * f64::cos(10. * xi[0]);
1201                }
1202            });
1203        y
1204    }
1205
    // Hard recombination: each point is predicted by the single expert of its cluster.
    // 0.39 falls in the quadratic branch of f_test_1d, 0.82 in the sine branch.
    #[test]
    fn test_moe_hard() {
        let mut rng = Xoshiro256Plus::seed_from_u64(0);
        let xt = Array2::random_using((50, 1), Uniform::new(0., 1.), &mut rng);
        let yt = f_test_1d(&xt.to_owned());
        let moe = GpMixture::params()
            .n_clusters(NbClusters::fixed(3))
            .regression_spec(RegressionSpec::CONSTANT)
            .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
            .recombination(Recombination::Hard)
            .with_rng(rng)
            .fit(&Dataset::new(xt, yt))
            .expect("MOE fitted");
        let x = Array1::linspace(0., 1., 30).insert_axis(Axis(1));
        let preds = moe.predict(&x).expect("MOE prediction");
        let dpreds = moe.predict_gradients(&x).expect("MOE drv prediction");
        println!("dpred = {dpreds}");
        // Dump arrays for offline plotting/debugging
        let test_dir = "target/tests";
        std::fs::create_dir_all(test_dir).ok();
        write_npy(format!("{test_dir}/x_hard.npy"), &x).expect("x saved");
        write_npy(format!("{test_dir}/preds_hard.npy"), &preds).expect("preds saved");
        write_npy(format!("{test_dir}/dpreds_hard.npy"), &dpreds).expect("dpreds saved");
        assert_abs_diff_eq!(
            0.39 * 0.39,
            moe.predict(&array![[0.39]]).unwrap()[0],
            epsilon = 1e-4
        );
        assert_abs_diff_eq!(
            f64::sin(10. * 0.82),
            moe.predict(&array![[0.82]]).unwrap()[0],
            epsilon = 1e-4
        );
        println!("LOOQ2 = {}", moe.q2_score());
    }
1240
    // Smooth recombination: a fixed heaviside factor (0.5) gives a poor fit at the
    // discontinuity, while the automatically adjusted factor recovers the true value.
    // NOTE(review): the first write_npy batch runs before any create_dir_all in this
    // test — it relies on target/tests already existing; TODO confirm ordering.
    #[test]
    fn test_moe_smooth() {
        let test_dir = "target/tests";
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((60, 1), Uniform::new(0., 1.), &mut rng);
        let yt = f_test_1d(&xt);
        let ds = Dataset::new(xt.to_owned(), yt.to_owned());
        let moe = GpMixture::params()
            .n_clusters(NbClusters::fixed(3))
            .recombination(Recombination::Smooth(Some(0.5)))
            .with_rng(rng.clone())
            .fit(&ds)
            .expect("MOE fitted");
        let x = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
        let preds = moe.predict(&x).expect("MOE prediction");
        write_npy(format!("{test_dir}/xt.npy"), &xt).expect("x saved");
        write_npy(format!("{test_dir}/yt.npy"), &yt).expect("preds saved");
        write_npy(format!("{test_dir}/x_smooth.npy"), &x).expect("x saved");
        write_npy(format!("{test_dir}/preds_smooth.npy"), &preds).expect("preds saved");

        // Predict with smooth 0.5 which is not good
        println!("Smooth moe {moe}");
        assert_abs_diff_eq!(
            0.2623, // test we are not good as the true value = 0.37*0.37 = 0.1369
            moe.predict(&array![[0.37]]).unwrap()[0],
            epsilon = 1e-3
        );

        // Predict with smooth adjusted automatically which is better
        let moe = GpMixture::params()
            .n_clusters(NbClusters::fixed(3))
            .recombination(Recombination::Smooth(None))
            .with_rng(rng.clone())
            .fit(&ds)
            .expect("MOE fitted");
        println!("Smooth moe {moe}");

        std::fs::create_dir_all(test_dir).ok();
        let x = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
        let preds = moe.predict(&x).expect("MOE prediction");
        write_npy(format!("{test_dir}/x_smooth2.npy"), &x).expect("x saved");
        write_npy(format!("{test_dir}/preds_smooth2.npy"), &preds).expect("preds saved");
        assert_abs_diff_eq!(
            0.37 * 0.37, // true value of the function
            moe.predict(&array![[0.37]]).unwrap()[0],
            epsilon = 1e-3
        );
    }
1289
    // Automatic clustering: the number of clusters and the recombination mode are
    // chosen by the algorithm; prediction should still match the true function.
    #[test]
    fn test_moe_auto() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((60, 1), Uniform::new(0., 1.), &mut rng);
        let yt = f_test_1d(&xt);
        let ds = Dataset::new(xt, yt.to_owned());
        let moe = GpMixture::params()
            .n_clusters(NbClusters::auto())
            .with_rng(rng.clone())
            .fit(&ds)
            .expect("MOE fitted");
        println!(
            "Moe auto: nb clusters={}, recomb={:?}",
            moe.n_clusters(),
            moe.recombination()
        );
        assert_abs_diff_eq!(
            0.37 * 0.37, // true value of the function
            moe.predict(&array![[0.37]]).unwrap()[0],
            epsilon = 1e-3
        );
    }
1312
    // With a dense training set the fit is nearly interpolating, so the predicted
    // variance should be essentially zero everywhere on the test grid.
    #[test]
    fn test_moe_variances_smooth() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((100, 1), Uniform::new(0., 1.), &mut rng);
        let yt = f_test_1d(&xt);
        let moe = GpMixture::params()
            .n_clusters(NbClusters::fixed(3))
            .recombination(Recombination::Smooth(None))
            .regression_spec(RegressionSpec::CONSTANT)
            .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
            .with_rng(rng.clone())
            .fit(&Dataset::new(xt, yt))
            .expect("MOE fitted");
        // Smoke test: prediction is pretty good hence variance is very low
        let x = Array1::linspace(0., 1., 20).insert_axis(Axis(1));
        let variances = moe.predict_var(&x).expect("MOE variances prediction");
        assert_abs_diff_eq!(*variances.max().unwrap(), 0., epsilon = 1e-10);
    }
1331
1332    fn xsinx(x: &[f64]) -> f64 {
1333        (x[0] - 3.5) * f64::sin((x[0] - 3.5) / std::f64::consts::PI)
1334    }
1335
    // Smoke test: find_best_expert runs on a tiny xsinx dataset without panicking.
    #[test]
    fn test_find_best_expert() {
        let mut rng = Xoshiro256Plus::seed_from_u64(0);
        let xt = Array2::random_using((10, 1), Uniform::new(0., 1.), &mut rng);
        let yt = xt.mapv(|x| xsinx(&[x]));
        // find_best_expert takes x and y concatenated column-wise
        let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
        let moe = GpMixture::params().with_rng(rng).check_unwrap();
        let best_expert = &moe.find_best_expert(0, 1, &data).unwrap();
        println!("Best expert {best_expert}");
    }
1346
    // Smoke test: fitting with 3 clusters (which triggers heaviside factor search
    // internally) completes without error.
    #[test]
    fn test_find_best_heaviside_factor() {
        let mut rng = Xoshiro256Plus::seed_from_u64(0);
        let xt = Array2::random_using((50, 1), Uniform::new(0., 1.), &mut rng);
        let yt = f_test_1d(&xt);
        let _moe = GpMixture::params()
            .n_clusters(NbClusters::fixed(3))
            .with_rng(rng)
            .fit(&Dataset::new(xt, yt))
            .expect("MOE fitted");
    }
1358
1359    #[cfg(feature = "persistent")]
1360    #[test]
1361    fn test_save_load_moe() {
1362        let test_dir = "target/tests";
1363        std::fs::create_dir_all(test_dir).ok();
1364
1365        let mut rng = Xoshiro256Plus::seed_from_u64(0);
1366        let xt = Array2::random_using((50, 1), Uniform::new(0., 1.), &mut rng);
1367        let yt = f_test_1d(&xt);
1368        let ds = Dataset::new(xt, yt);
1369        let moe = GpMixture::params()
1370            .n_clusters(NbClusters::fixed(3))
1371            .with_rng(rng)
1372            .fit(&ds)
1373            .expect("MOE fitted");
1374        let xtest = array![[0.6]];
1375        let y_expected = moe.predict(&xtest).unwrap();
1376        let filename = format!("{test_dir}/saved_moe.json");
1377        moe.save(&filename, GpFileFormat::Json).expect("MoE saving");
1378        let new_moe = GpMixture::load(&filename, GpFileFormat::Json).expect("MoE loading");
1379        assert_abs_diff_eq!(y_expected, new_moe.predict(&xtest).unwrap(), epsilon = 1e-6);
1380    }
1381
    // Check smooth-recombination gradients against central finite differences of the
    // model's own predictions, at 100 random points.
    #[test]
    fn test_moe_drv_smooth() {
        let rng = Xoshiro256Plus::seed_from_u64(0);
        // Use regular evenly spaced data to avoid numerical issue
        // and getting a smooth surrogate modeling
        // Otherwise with Lhs and bad luck this test fails from time to time
        // when surrogate modeling happens to be wrong
        let xt = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
        let yt = f_test_1d(&xt);

        let moe = GpMixture::params()
            .n_clusters(NbClusters::fixed(3))
            .regression_spec(RegressionSpec::CONSTANT)
            .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
            .recombination(Recombination::Smooth(Some(0.5)))
            .with_rng(rng)
            .fit(&Dataset::new(xt, yt))
            .expect("MOE fitted");
        let x = Array1::linspace(0., 1., 50).insert_axis(Axis(1));
        let preds = moe.predict(&x).expect("MOE prediction");
        let dpreds = moe.predict_gradients(&x).expect("MOE drv prediction");

        // Dump arrays for offline plotting/debugging
        let test_dir = "target/tests";
        std::fs::create_dir_all(test_dir).ok();
        write_npy(format!("{test_dir}/x_moe_smooth.npy"), &x).expect("x saved");
        write_npy(format!("{test_dir}/preds_moe_smooth.npy"), &preds).expect("preds saved");
        write_npy(format!("{test_dir}/dpreds_moe_smooth.npy"), &dpreds).expect("dpreds saved");

        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        for _ in 0..100 {
            let x1: f64 = rng.gen_range(0.1..0.9);
            let h = 1e-8;
            let xtest = array![[x1]];

            // Central finite difference of the model prediction at x1
            let x = array![[x1], [x1 + h], [x1 - h]];
            let preds = moe.predict(&x).unwrap();
            let fdiff = (preds[1] - preds[2]) / (2. * h);

            let drv = moe.predict_gradients(&xtest).unwrap();
            let df = df_test_1d(&xtest);

            // Check only computed derivatives against fdiff of computed prediction
            // and fdiff can be wrong wrt to true derivatives due to bad surrogate modeling
            // specially at discontinuities hence no check against true derivative here
            let err = if drv[[0, 0]] < 1e-2 {
                (drv[[0, 0]] - fdiff).abs()
            } else {
                (drv[[0, 0]] - fdiff).abs() / drv[[0, 0]] // check relative error
            };
            println!(
                "Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
            );
            println!("preds(x, x+h, x-h)={preds}");
            assert_abs_diff_eq!(err, 0.0, epsilon = 1e-1);
        }
    }
1438
1439    fn norm1(x: &Array2<f64>) -> Array2<f64> {
1440        x.mapv(|v| v.abs())
1441            .sum_axis(Axis(1))
1442            .insert_axis(Axis(1))
1443            .to_owned()
1444    }
1445
1446    fn rosenb(x: &Array2<f64>) -> Array2<f64> {
1447        let mut y: Array2<f64> = Array2::zeros((x.nrows(), 1));
1448        Zip::from(y.rows_mut())
1449            .and(x.rows())
1450            .par_for_each(|mut yi, xi| yi.assign(&array![rosenbrock(&xi.to_vec())]));
1451        y
1452    }
1453
1454    #[allow(clippy::excessive_precision)]
1455    fn test_variance_derivatives(f: fn(&Array2<f64>) -> Array2<f64>) {
1456        let rng = Xoshiro256Plus::seed_from_u64(0);
1457        let xt = egobox_doe::FullFactorial::new(&array![[-1., 1.], [-1., 1.]]).sample(100);
1458        let yt = f(&xt);
1459
1460        let moe = GpMixture::params()
1461            .n_clusters(NbClusters::fixed(2))
1462            .regression_spec(RegressionSpec::CONSTANT)
1463            .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1464            .recombination(Recombination::Smooth(Some(1.)))
1465            .with_rng(rng)
1466            .fit(&Dataset::new(xt, yt.remove_axis(Axis(1))))
1467            .expect("MOE fitted");
1468
1469        for _ in 0..20 {
1470            let mut rng = Xoshiro256Plus::seed_from_u64(42);
1471            let x = Array::random_using((2,), Uniform::new(0., 1.), &mut rng);
1472            let xa: f64 = x[0];
1473            let xb: f64 = x[1];
1474            let e = 1e-4;
1475
1476            println!("Test derivatives at [{xa}, {xb}]");
1477
1478            let x = array![
1479                [xa, xb],
1480                [xa + e, xb],
1481                [xa - e, xb],
1482                [xa, xb + e],
1483                [xa, xb - e]
1484            ];
1485            let y_pred = moe.predict(&x).unwrap();
1486            let y_deriv = moe.predict_gradients(&x).unwrap();
1487
1488            let diff_g = (y_pred[1] - y_pred[2]) / (2. * e);
1489            let diff_d = (y_pred[3] - y_pred[4]) / (2. * e);
1490
1491            assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g);
1492            assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d);
1493
1494            let y_pred = moe.predict_var(&x).unwrap();
1495            let y_deriv = moe.predict_var_gradients(&x).unwrap();
1496
1497            let diff_g = (y_pred[1] - y_pred[2]) / (2. * e);
1498            let diff_d = (y_pred[3] - y_pred[4]) / (2. * e);
1499
1500            assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g);
1501            assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d);
1502        }
1503    }
1504
1505    /// Test prediction valvar derivatives against derivatives and variance derivatives
1506    #[test]
1507    fn test_valvar_predictions() {
1508        let rng = Xoshiro256Plus::seed_from_u64(0);
1509        let xt = egobox_doe::FullFactorial::new(&array![[-1., 1.], [-1., 1.]]).sample(100);
1510        let yt = rosenb(&xt).remove_axis(Axis(1));
1511
1512        for corr in [
1513            CorrelationSpec::SQUAREDEXPONENTIAL,
1514            CorrelationSpec::MATERN32,
1515            CorrelationSpec::MATERN52,
1516        ] {
1517            println!("Test valvar derivatives with correlation {corr:?}");
1518            for recomb in [
1519                Recombination::Hard,
1520                Recombination::Smooth(Some(0.5)),
1521                Recombination::Smooth(None),
1522            ] {
1523                println!("Testing valvar derivatives with recomb={recomb:?}");
1524
1525                let moe = GpMixture::params()
1526                    .n_clusters(NbClusters::fixed(2))
1527                    .regression_spec(RegressionSpec::CONSTANT)
1528                    .correlation_spec(corr)
1529                    .recombination(recomb)
1530                    .with_rng(rng.clone())
1531                    .fit(&Dataset::new(xt.to_owned(), yt.to_owned()))
1532                    .expect("MOE fitted");
1533
1534                for _ in 0..10 {
1535                    let mut rng = Xoshiro256Plus::seed_from_u64(42);
1536                    let x = Array::random_using((2,), Uniform::new(0., 1.), &mut rng);
1537                    let xa: f64 = x[0];
1538                    let xb: f64 = x[1];
1539                    let e = 1e-4;
1540
1541                    let x = array![
1542                        [xa, xb],
1543                        [xa + e, xb],
1544                        [xa - e, xb],
1545                        [xa, xb + e],
1546                        [xa, xb - e]
1547                    ];
1548                    let (y_pred, v_pred) = moe.predict_valvar(&x).unwrap();
1549                    let (y_deriv, v_deriv) = moe.predict_valvar_gradients(&x).unwrap();
1550
1551                    let pred = moe.predict(&x).unwrap();
1552                    let var = moe.predict_var(&x).unwrap();
1553                    assert_abs_diff_eq!(y_pred, pred, epsilon = 1e-12);
1554                    assert_abs_diff_eq!(v_pred, var, epsilon = 1e-12);
1555
1556                    let deriv = moe.predict_gradients(&x).unwrap();
1557                    let vardrv = moe.predict_var_gradients(&x).unwrap();
1558                    assert_abs_diff_eq!(y_deriv, deriv, epsilon = 1e-12);
1559                    assert_abs_diff_eq!(v_deriv, vardrv, epsilon = 1e-12);
1560                }
1561            }
1562        }
1563    }
1564
1565    fn assert_rel_or_abs_error(y_deriv: f64, fdiff: f64) {
1566        println!("analytic deriv = {y_deriv}, fdiff = {fdiff}");
1567        if fdiff.abs() < 1e-2 {
1568            assert_abs_diff_eq!(y_deriv, 0.0, epsilon = 1e-1); // check absolute when close to zero
1569        } else {
1570            let drv_rel_error1 = (y_deriv - fdiff).abs() / fdiff; // check relative
1571            assert_abs_diff_eq!(drv_rel_error1, 0.0, epsilon = 1e-1);
1572        }
1573    }
1574
    /// Run the gradient/variance-gradient finite-difference checks on the
    /// non-smooth L1-norm test function.
    #[test]
    fn test_moe_var_deriv_norm1() {
        test_variance_derivatives(norm1);
    }
    /// Run the gradient/variance-gradient finite-difference checks on the
    /// Rosenbrock test function.
    #[test]
    fn test_moe_var_deriv_rosenb() {
        test_variance_derivatives(rosenb);
    }
1583
1584    #[test]
1585    fn test_moe_display() {
1586        let rng = Xoshiro256Plus::seed_from_u64(0);
1587        let xt = Lhs::new(&array![[0., 1.]])
1588            .with_rng(rng.clone())
1589            .sample(100);
1590        let yt = f_test_1d(&xt);
1591
1592        let moe = GpMixture::params()
1593            .n_clusters(NbClusters::fixed(3))
1594            .regression_spec(RegressionSpec::CONSTANT)
1595            .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1596            .recombination(Recombination::Hard)
1597            .with_rng(rng)
1598            .fit(&Dataset::new(xt, yt))
1599            .expect("MOE fitted");
1600        // Values may vary depending on the platforms and linalg backends
1601        // assert_eq!("Mixture[Hard](Constant_SquaredExponentialGP(mean=ConstantMean, corr=SquaredExponential, theta=[0.03871601282054056], variance=[0.276011431746834], likelihood=454.17113736397033), Constant_SquaredExponentialGP(mean=ConstantMean, corr=SquaredExponential, theta=[0.07903503494417609], variance=[0.0077182164672893756], likelihood=436.39615700140183), Constant_SquaredExponentialGP(mean=ConstantMean, corr=SquaredExponential, theta=[0.050821466014058826], variance=[0.32824998062969973], likelihood=193.19339252734846))", moe.to_string());
1602        println!("Display moe: {moe}");
1603    }
1604
1605    fn griewank(x: &Array2<f64>) -> Array1<f64> {
1606        let dim = x.ncols();
1607        let d = Array1::linspace(1., dim as f64, dim).mapv(|v| v.sqrt());
1608        let mut y = Array1::zeros((x.nrows(),));
1609        Zip::from(&mut y).and(x.rows()).for_each(|y, x| {
1610            let s = x.mapv(|v| v * v).sum() / 4000.;
1611            let p = (x.to_owned() / &d)
1612                .mapv(|v| v.cos())
1613                .fold(1., |acc, x| acc * x);
1614            *y = s - p + 1.;
1615        });
1616        y
1617    }
1618
    /// Check KPLS dimension reduction on the 100D Griewank function:
    /// fit a mixture with `kpls_dim = 3` on 100 LHS samples and require
    /// NRMSE below 1e-2 on an independent 100-point LHS test set.
    /// Training data is also dumped as `.npy` files under `target/tests`.
    #[test]
    fn test_kpls_griewank() {
        // Arrays sized for possible extension to several (dim, nt) cases;
        // only one case is exercised by the `(0..1)` loop below.
        let dims = [100];
        let nts = [100];
        let lim = array![[-600., 600.]];

        let test_dir = "target/tests";
        std::fs::create_dir_all(test_dir).ok();

        (0..1).for_each(|i| {
            let dim = dims[i];
            let nt = nts[i];
            // Same [-600, 600] bounds broadcast to every dimension
            let xlimits = lim.broadcast((dim, 2)).unwrap();

            let prefix = "griewank";
            let xfilename = format!("{test_dir}/{prefix}_xt_{nt}x{dim}.npy");
            let yfilename = format!("{test_dir}/{prefix}_yt_{nt}x1.npy");

            let rng = Xoshiro256Plus::seed_from_u64(42);
            let xt = Lhs::new(&xlimits).with_rng(rng).sample(nt);
            write_npy(xfilename, &xt).expect("cannot save xt");
            let yt = griewank(&xt);
            write_npy(yfilename, &yt).expect("cannot save yt");

            // KPLS projects the 100D input onto 3 latent components
            let gp = GpMixture::params()
                .n_clusters(NbClusters::default())
                .regression_spec(RegressionSpec::CONSTANT)
                .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
                .kpls_dim(Some(3))
                .fit(&Dataset::new(xt, yt))
                .expect("GP fit error");

            // To see file size : 100D => json ~ 1.2Mo, bin ~ 0.6Mo
            // gp.save("griewank.json", GpFileFormat::Json).unwrap();
            // gp.save("griewank.bin", GpFileFormat::Binary).unwrap();

            // Independent test sample (different seed from training)
            let rng = Xoshiro256Plus::seed_from_u64(0);
            let xtest = Lhs::new(&xlimits).with_rng(rng).sample(100);
            let ytest = gp.predict(&xtest).expect("prediction error");
            let ytrue = griewank(&xtest);

            // Normalized RMSE: prediction error relative to true signal norm
            let nrmse = (ytrue.to_owned() - &ytest).norm_l2() / ytrue.norm_l2();
            println!(
                "diff={}  ytrue={} nrsme={}",
                (ytrue.to_owned() - &ytest).norm_l2(),
                ytrue.norm_l2(),
                nrmse
            );
            assert_abs_diff_eq!(nrmse, 0., epsilon = 1e-2);
        });
    }
1670
1671    fn sphere(x: &Array2<f64>) -> Array1<f64> {
1672        (x * x)
1673            .sum_axis(Axis(1))
1674            .into_shape_with_order((x.nrows(),))
1675            .expect("Cannot reshape sphere output")
1676    }
1677
    /// With a single cluster the recombination mode must be irrelevant:
    /// Smooth and Hard mixtures fitted on the same data must agree on
    /// predictions, variances and both kinds of gradients.
    #[test]
    fn test_moe_smooth_vs_hard_one_cluster() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let xt = Array2::random_using((50, 2), Uniform::new(0., 1.), &mut rng);
        let yt = sphere(&xt);
        let ds = Dataset::new(xt, yt.to_owned());

        // Fit hard (rng is cloned so the smooth fit below starts from the
        // same generator state)
        let moe_hard = GpMixture::params()
            .n_clusters(NbClusters::fixed(1))
            .recombination(Recombination::Hard)
            .with_rng(rng.clone())
            .fit(&ds)
            .expect("MOE hard fitted");

        // Fit smooth
        let moe_smooth = GpMixture::params()
            .n_clusters(NbClusters::fixed(1))
            .recombination(Recombination::Smooth(Some(1.0)))
            .with_rng(rng)
            .fit(&ds)
            .expect("MOE smooth fitted");

        // Predict at a fresh random point (independent seed)
        let mut rng = Xoshiro256Plus::seed_from_u64(43);
        let x = Array2::random_using((1, 2), Uniform::new(0., 1.), &mut rng);
        let preds_hard = moe_hard.predict(&x).expect("MOE hard prediction");
        let preds_smooth = moe_smooth.predict(&x).expect("MOE smooth prediction");
        println!("predict hard = {preds_hard} smooth = {preds_smooth}");
        assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);

        // Predict var
        let preds_hard = moe_hard.predict_var(&x).expect("MOE hard prediction");
        let preds_smooth = moe_smooth.predict_var(&x).expect("MOE smooth prediction");
        assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);

        // Predict gradients
        println!("Check pred gradients at x = {x}");
        let preds_smooth = moe_smooth
            .predict_gradients(&x)
            .expect("MOE smooth prediction");
        println!("smooth gradients = {preds_smooth}");
        let preds_hard = moe_hard.predict_gradients(&x).expect("MOE hard prediction");
        assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);

        // Predict var gradients
        let preds_hard = moe_hard
            .predict_var_gradients(&x)
            .expect("MOE hard prediction");
        let preds_smooth = moe_smooth
            .predict_var_gradients(&x)
            .expect("MOE smooth prediction");
        assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);
    }
1732}