1use super::gaussian_mixture::GaussianMixture;
2use crate::clustering::{find_best_number_of_clusters, sort_by_cluster};
3use crate::errors::MoeError;
4use crate::errors::Result;
5use crate::parameters::{GpMixtureParams, GpMixtureValidParams};
6use crate::{GpMetrics, IaeAlphaPlotData, types::*};
7use crate::{GpType, expertise_macros::*};
8use crate::{NbClusters, surrogates::*};
9
10use egobox_gp::{GaussianProcess, SparseGaussianProcess, correlation_models::*, mean_models::*};
11use linfa::dataset::Records;
12use linfa::traits::{Fit, Predict, PredictInplace};
13use linfa::{Dataset, DatasetBase, Float, ParamGuard};
14use linfa_clustering::GaussianMixtureModel;
15use log::{debug, info, trace};
16use paste::paste;
17use std::cmp::Ordering;
18use std::ops::Sub;
19
20#[cfg(not(feature = "blas"))]
21use linfa_linalg::norm::*;
22use ndarray::{
23 Array1, Array2, Array3, ArrayBase, ArrayView2, Axis, Data, Ix1, Ix2, Zip, concatenate, s,
24};
25
26#[cfg(feature = "blas")]
27use ndarray_linalg::Norm;
28use ndarray_rand::rand::Rng;
29use ndarray_stats::QuantileExt;
30
31#[cfg(feature = "serializable")]
32use serde::{Deserialize, Serialize};
33#[cfg(feature = "persistent")]
34use std::fs;
35#[cfg(feature = "persistent")]
36use std::io::Write;
37
/// Push `stringify!($model)` into `$list` when the bitflag set `$spec`
/// contains the flag `[<$model_kind Spec>]::[<$model:upper>]`
/// (e.g. `RegressionSpec::CONSTANT`). Relies on `paste!` for identifier
/// concatenation and case conversion.
macro_rules! check_allowed {
    ($spec:ident, $model_kind:ident, $model:ident, $list:ident) => {
        paste! {
            // `[<... :upper>]` builds the SCREAMING_CASE flag name from the model ident.
            if $spec.contains([< $model_kind Spec>]::[< $model:upper >]) {
                $list.push(stringify!($model));
            }
        }
    };
}
47
impl<D: Data<Elem = f64>> Fit<ArrayBase<D, Ix2>, ArrayBase<D, Ix1>, MoeError>
    for GpMixtureValidParams<f64>
{
    type Object = GpMixture;

    /// Fit a GP mixture on `dataset` (records = inputs, targets = outputs)
    /// by delegating to [`GpMixtureValidParams::train`].
    fn fit(
        &self,
        dataset: &DatasetBase<ArrayBase<D, Ix2>, ArrayBase<D, Ix1>>,
    ) -> Result<Self::Object> {
        let x = dataset.records();
        let y = dataset.targets();
        self.train(x, y)
    }
}
69
impl GpMixtureValidParams<f64> {
    /// Train a mixture of GP experts on `(xt, yt)`.
    ///
    /// Steps:
    /// 1. Determine the number of clusters and the recombination mode
    ///    (fixed, or searched when `NbClusters::Auto`).
    /// 2. Fit a Gaussian mixture model on the joint `[x | y]` data,
    ///    unless a pre-trained `gmx` was supplied in the params.
    /// 3. Delegate per-cluster expert training to [`Self::train_on_clusters`].
    ///
    /// # Errors
    /// Propagates GMM fitting and expert training failures.
    pub fn train(
        &self,
        xt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        yt: &ArrayBase<impl Data<Elem = f64>, Ix1>,
    ) -> Result<GpMixture> {
        trace!("Moe training...");
        let nx = xt.ncols();
        // Joint data [x | y]: clustering is performed in the (input, output) space.
        let data = concatenate(
            Axis(1),
            &[xt.view(), yt.to_owned().insert_axis(Axis(1)).view()],
        )
        .unwrap();

        let (n_clusters, recomb) = match self.n_clusters() {
            NbClusters::Auto { max } => {
                // Heuristic upper bound: roughly one cluster per 10 samples.
                let max_nb_clusters = max.unwrap_or(xt.nrows() / 10 + 1);
                info!("Find best number of clusters up to {max_nb_clusters}");
                find_best_number_of_clusters(
                    xt,
                    yt,
                    max_nb_clusters,
                    self.kpls_dim(),
                    self.regression_spec(),
                    self.correlation_spec(),
                    self.rng(),
                )
            }
            NbClusters::Fixed { nb: nb_clusters } => (nb_clusters, self.recombination()),
        };
        if let NbClusters::Auto { max: _ } = self.n_clusters() {
            info!("Automatic settings {n_clusters} {recomb:?}");
        }

        // When the heaviside factor has to be optimized later (Smooth(None)
        // with several clusters), hold out every 5th sample for validation
        // and train on the rest (see `extract_part`).
        let training = if recomb == Recombination::Smooth(None) && self.n_clusters().is_multi() {
            let (_, training_data) = extract_part(&data, 5);
            training_data
        } else {
            data.to_owned()
        };
        let dataset = Dataset::from(training);

        // Reuse a user-provided Gaussian mixture if any, otherwise fit one.
        let gmx = if self.gmx().is_some() {
            self.gmx().unwrap().clone()
        } else {
            trace!("GMM training...");
            let gmm = GaussianMixtureModel::params(n_clusters)
                .n_runs(20)
                .with_rng(self.rng())
                .fit(&dataset)?;

            // Keep only the input-space (first nx) components of the GMM
            // parameters: cluster membership at prediction time is decided
            // from x alone.
            let weights = gmm.weights().to_owned();
            let means = gmm.means().slice(s![.., ..nx]).to_owned();
            let covariances = gmm.covariances().slice(s![.., ..nx, ..nx]).to_owned();
            let factor = match recomb {
                Recombination::Smooth(Some(f)) => f,
                Recombination::Smooth(_) => 1.,
                Recombination::Hard => 1.,
            };
            GaussianMixture::new(weights, means, covariances)?.heaviside_factor(factor)
        };

        trace!("Train on clusters...");
        let clustering = Clustering::new(gmx, recomb);
        self.train_on_clusters(&xt.view(), &yt.view(), &clustering)
    }

    /// Train one expert per cluster given an already-built `clustering`.
    ///
    /// When recombination is `Smooth(None)` with several clusters, the
    /// heaviside (smoothing) factor is optimized on a held-out fifth of the
    /// data and the whole mixture is retrained with the optimized factor.
    ///
    /// # Errors
    /// Fails when a cluster is too small (fewer than 3 points, or not enough
    /// points for the richest allowed regression model) or when expert
    /// training fails.
    pub fn train_on_clusters(
        &self,
        xt: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        yt: &ArrayBase<impl Data<Elem = f64>, Ix1>,
        clustering: &Clustering,
    ) -> Result<GpMixture> {
        let gmx = clustering.gmx();
        let recomb = clustering.recombination();
        let nx = xt.ncols();
        let data = concatenate(
            Axis(1),
            &[xt.view(), yt.to_owned().insert_axis(Axis(1)).view()],
        )
        .unwrap();

        // Assign each sample to its most likely cluster and split the data.
        let dataset_clustering = gmx.predict(xt);
        let clusters = sort_by_cluster(gmx.n_clusters(), &data, &dataset_clustering);

        check_number_of_points(&clusters, xt.ncols(), self.regression_spec())?;

        // Fit one expert (best mean/correlation combination) per cluster.
        let mut experts = Vec::new();
        let nb_clusters = clusters.len();
        for (nc, cluster) in clusters.iter().enumerate() {
            if nb_clusters > 1 && cluster.nrows() < 3 {
                return Err(MoeError::ClusteringError(format!(
                    "Not enough points in cluster, requires at least 3, got {}",
                    cluster.nrows()
                )));
            }
            debug!("nc={} theta_tuning={:?}", nc, self.theta_tunings());
            let expert = self.find_best_expert(nc, nx, cluster)?;
            experts.push(expert);
        }

        if recomb == Recombination::Smooth(None) && self.n_clusters().is_multi() {
            // Optimize the heaviside factor on the held-out part, then
            // retrain the whole mixture with the factor frozen
            // (Smooth(Some(factor))), so the recursion terminates.
            let (test, _) = extract_part(&data, 5);
            let xtest = test.slice(s![.., ..nx]).to_owned();
            let ytest = test.slice(s![.., nx..]).to_owned().remove_axis(Axis(1));
            let factor = self.optimize_heaviside_factor(&experts, gmx, &xtest, &ytest);
            info!("Retrain mixture with optimized heaviside factor={factor}");

            let moe = GpMixtureParams::from(self.clone())
                .n_clusters(NbClusters::fixed(gmx.n_clusters()))
                .recombination(Recombination::Smooth(Some(factor)))
                .check()?
                .train(xt, yt)?; Ok(moe)
        } else {
            Ok(GpMixture {
                gp_type: self.gp_type().clone(),
                recombination: recomb,
                experts,
                gmx: gmx.clone(),
                training_data: (xt.to_owned(), yt.to_owned()),
                params: self.clone(),
            })
        }
    }

    /// Select and train the best expert for cluster `nc`.
    ///
    /// Candidate (mean, correlation) combinations are restricted by the
    /// regression/correlation specs; when more than one combination remains,
    /// each is scored via `compute_errors!` and the lowest-error one wins.
    ///
    /// # Errors
    /// Fails when the selected combination is not supported for the current
    /// GP type or when expert training fails.
    fn find_best_expert(
        &self,
        nc: usize,
        nx: usize,
        data: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Box<dyn FullGpSurrogate>> {
        let xtrain = data.slice(s![.., ..nx]).to_owned();
        let ytrain = data.slice(s![.., nx..]).to_owned();
        // `mut` is required by the `compute_errors!` macro below.
        let mut dataset = Dataset::from((xtrain.clone(), ytrain.clone().remove_axis(Axis(1))));
        let regression_spec = self.regression_spec();
        let mut allowed_means = vec![];
        check_allowed!(regression_spec, Regression, Constant, allowed_means);
        check_allowed!(regression_spec, Regression, Linear, allowed_means);
        check_allowed!(regression_spec, Regression, Quadratic, allowed_means);
        let correlation_spec = self.correlation_spec();
        let mut allowed_corrs = vec![];
        check_allowed!(
            correlation_spec,
            Correlation,
            SquaredExponential,
            allowed_corrs
        );
        check_allowed!(
            correlation_spec,
            Correlation,
            AbsoluteExponential,
            allowed_corrs
        );
        check_allowed!(correlation_spec, Correlation, Matern32, allowed_corrs);
        check_allowed!(correlation_spec, Correlation, Matern52, allowed_corrs);

        debug!("Find best expert");
        // `best` is ("<Mean>_<Correlation>", Option<cv error>); the error is
        // None when there was a single candidate (no scoring needed).
        let best = if allowed_means.len() == 1 && allowed_corrs.len() == 1 {
            (format!("{}_{}", allowed_means[0], allowed_corrs[0]), None) } else {
            let mut map_error = Vec::new();
            compute_errors!(self, allowed_means, allowed_corrs, dataset, map_error);
            let errs: Vec<f64> = map_error.iter().map(|(_, err)| *err).collect();
            debug!("Accuracies {map_error:?}");
            // NaN-tolerant argmin: NaN comparisons are treated as Equal.
            let argmin = errs
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
                .map(|(index, _)| index)
                .unwrap();
            (map_error[argmin].0.clone(), Some(map_error[argmin].1))
        };
        debug!("after Find best expert");

        let expert = match self.gp_type() {
            GpType::FullGp => {
                let best_expert_params: std::result::Result<Box<dyn GpSurrogateParams>, MoeError> =
                    match best.0.as_str() {
                        "Constant_SquaredExponential" => {
                            Ok(make_surrogate_params!(Constant, SquaredExponential))
                        }
                        "Constant_AbsoluteExponential" => {
                            Ok(make_surrogate_params!(Constant, AbsoluteExponential))
                        }
                        "Constant_Matern32" => Ok(make_surrogate_params!(Constant, Matern32)),
                        "Constant_Matern52" => Ok(make_surrogate_params!(Constant, Matern52)),
                        "Linear_SquaredExponential" => {
                            Ok(make_surrogate_params!(Linear, SquaredExponential))
                        }
                        "Linear_AbsoluteExponential" => {
                            Ok(make_surrogate_params!(Linear, AbsoluteExponential))
                        }
                        "Linear_Matern32" => Ok(make_surrogate_params!(Linear, Matern32)),
                        "Linear_Matern52" => Ok(make_surrogate_params!(Linear, Matern52)),
                        "Quadratic_SquaredExponential" => {
                            Ok(make_surrogate_params!(Quadratic, SquaredExponential))
                        }
                        "Quadratic_AbsoluteExponential" => {
                            Ok(make_surrogate_params!(Quadratic, AbsoluteExponential))
                        }
                        "Quadratic_Matern32" => Ok(make_surrogate_params!(Quadratic, Matern32)),
                        "Quadratic_Matern52" => Ok(make_surrogate_params!(Quadratic, Matern52)),
                        _ => {
                            return Err(MoeError::ExpertError(format!(
                                "Unknown expert {}",
                                best.0
                            )));
                        }
                    };
                let mut expert_params = best_expert_params?;
                expert_params.n_start(self.n_start());
                expert_params.max_eval(self.max_eval());
                expert_params.kpls_dim(self.kpls_dim());
                // A single theta tuning is shared by all clusters; otherwise
                // use the per-cluster tuning at index `nc`.
                if nc > 0 && self.theta_tunings().len() == 1 {
                    expert_params.theta_tuning(self.theta_tunings()[0].clone());
                } else {
                    debug!("Training with theta_tuning = {:?}.", self.theta_tunings());
                    expert_params.theta_tuning(self.theta_tunings()[nc].clone());
                }
                debug!("Train best expert...");
                expert_params.train(&xtrain.view(), &ytrain.view())
            }
            GpType::SparseGp {
                inducings,
                sparse_method,
                ..
            } => {
                let inducings = inducings.to_owned();
                // Sparse GPs only support a constant mean, hence the reduced
                // set of accepted combinations.
                let best_expert_params: std::result::Result<Box<dyn SgpSurrogateParams>, MoeError> =
                    match best.0.as_str() {
                        "Constant_SquaredExponential" => {
                            Ok(make_sgp_surrogate_params!(SquaredExponential, inducings))
                        }
                        "Constant_AbsoluteExponential" => {
                            Ok(make_sgp_surrogate_params!(AbsoluteExponential, inducings))
                        }
                        "Constant_Matern32" => Ok(make_sgp_surrogate_params!(Matern32, inducings)),
                        "Constant_Matern52" => Ok(make_sgp_surrogate_params!(Matern52, inducings)),
                        _ => {
                            return Err(MoeError::ExpertError(format!(
                                "Unknown expert {}",
                                best.0
                            )));
                        }
                    };
                let mut expert_params = best_expert_params?;
                let seed = self.rng().r#gen();
                debug!("Theta tuning = {:?}", self.theta_tunings());
                expert_params.sparse_method(*sparse_method);
                expert_params.seed(seed);
                expert_params.n_start(self.n_start());
                expert_params.kpls_dim(self.kpls_dim());
                expert_params.theta_tuning(self.theta_tunings()[0].clone());
                debug!("Train best expert...");
                expert_params.train(&xtrain.view(), &ytrain.view())
            }
        };

        debug!("...after best expert training");
        if let Some(v) = best.1 {
            info!("Best expert {} accuracy={}", best.0, v);
        }
        expert
    }

    /// Grid-search the heaviside (smoothing) factor in [0.1, 2.1] that
    /// minimizes the prediction error of the smooth recombination on the
    /// held-out `(xtest, ytest)` data.
    ///
    /// Returns 1.0 when recombination is hard, there is a single cluster,
    /// or all errors are negligible (< 1e-6).
    fn optimize_heaviside_factor(
        &self,
        experts: &[Box<dyn FullGpSurrogate>],
        gmx: &GaussianMixture<f64>,
        xtest: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        ytest: &ArrayBase<impl Data<Elem = f64>, Ix1>,
    ) -> f64 {
        if self.recombination() == Recombination::Hard || self.n_clusters().is_mono() {
            1.
        } else {
            let scale_factors = Array1::linspace(0.1, 2.1, 20);
            let errors = scale_factors.map(move |&factor| {
                let gmx2 = gmx.clone();
                let gmx2 = gmx2.heaviside_factor(factor);
                let pred = predict_smooth(experts, &gmx2, xtest).unwrap();
                // NOTE(review): the residual norm is normalized by the norm of
                // `xtest` (inputs), not `ytest` — looks suspicious for a
                // relative error; confirm this is intended.
                pred.sub(ytest).mapv(|x| x * x).sum().sqrt() / xtest.mapv(|x| x * x).sum().sqrt()
            });

            let min_error_index = errors.argmin().unwrap();
            if *errors.max().unwrap() < 1e-6 {
                1.
            } else {
                scale_factors[min_error_index]
            }
        }
    }
}
380
381fn check_number_of_points<F>(
382 clusters: &[ArrayBase<impl Data<Elem = F>, Ix2>],
383 dim: usize,
384 regr: RegressionSpec,
385) -> Result<()> {
386 if clusters.len() > 1 {
387 let min_number_point = if regr.contains(RegressionSpec::QUADRATIC) {
388 (dim + 1) * (dim + 2) / 2
389 } else if regr.contains(RegressionSpec::LINEAR) {
390 dim + 1
391 } else {
392 1
393 };
394 for cluster in clusters {
395 if cluster.len() < min_number_point {
396 return Err(MoeError::ClusteringError(format!(
397 "Not enough points in training set. Need {} points, got {}",
398 min_number_point,
399 cluster.len()
400 )));
401 }
402 }
403 }
404 Ok(())
405}
406
407fn predict_smooth(
412 experts: &[Box<dyn FullGpSurrogate>],
413 gmx: &GaussianMixture<f64>,
414 points: &ArrayBase<impl Data<Elem = f64>, Ix2>,
415) -> Result<Array1<f64>> {
416 let probas = gmx.predict_probas(points);
417 let preds: Array1<f64> = experts
418 .iter()
419 .enumerate()
420 .map(|(i, gp)| gp.predict(&points.view()).unwrap() * probas.column(i))
421 .fold(Array1::zeros((points.nrows(),)), |acc, pred| acc + pred);
422 Ok(preds)
423}
424
/// Mixture-of-experts Gaussian-process surrogate: one GP expert per cluster
/// of a Gaussian mixture, recombined either hard (winner takes all) or
/// smooth (responsibility-weighted).
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub struct GpMixture {
    /// How expert predictions are recombined (hard or smooth with factor).
    recombination: Recombination<f64>,
    /// Trained GP experts, one per cluster.
    experts: Vec<Box<dyn FullGpSurrogate>>,
    /// Gaussian mixture used for clustering / responsibilities.
    gmx: GaussianMixture<f64>,
    /// Kind of underlying GPs (full or sparse).
    gp_type: GpType<f64>,
    /// Training data `(x, y)` kept for metrics computation.
    training_data: (Array2<f64>, Array1<f64>),
    /// Validated parameters used to train this mixture.
    params: GpMixtureValidParams<f64>,
}
444
445impl std::fmt::Display for GpMixture {
446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447 let recomb = match self.recombination() {
448 Recombination::Hard => "Hard".to_string(),
449 Recombination::Smooth(Some(f)) => format!("Smooth({f})"),
450 Recombination::Smooth(_) => "Smooth".to_string(),
451 };
452 let experts = self
453 .experts
454 .iter()
455 .map(|expert| expert.to_string())
456 .reduce(|acc, s| acc + ", " + &s)
457 .unwrap();
458 write!(f, "Mixture[{}]({})", &recomb, &experts)
459 }
460}
461
impl Clustered for GpMixture {
    /// Number of clusters in the underlying Gaussian mixture.
    fn n_clusters(&self) -> usize {
        self.gmx.n_clusters()
    }

    /// Recombination mode. Resolves to the inherent
    /// `GpMixture::recombination` (inherent methods take precedence over
    /// trait methods), so this does not recurse.
    fn recombination(&self) -> Recombination<f64> {
        self.recombination()
    }

    /// Snapshot of the clustering (recombination mode + Gaussian mixture).
    fn to_clustering(&self) -> Clustering {
        Clustering {
            recombination: self.recombination(),
            gmx: self.gmx.clone(),
        }
    }
}
481
482#[cfg_attr(feature = "serializable", typetag::serde)]
483impl GpSurrogate for GpMixture {
484 fn dims(&self) -> (usize, usize) {
485 self.experts[0].dims()
486 }
487
488 fn predict(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
489 match self.recombination {
490 Recombination::Hard => self.predict_hard(x),
491 Recombination::Smooth(_) => self.predict_smooth(x),
492 }
493 }
494
495 fn predict_var(&self, x: &ArrayView2<f64>) -> Result<Array1<f64>> {
496 match self.recombination {
497 Recombination::Hard => self.predict_var_hard(x),
498 Recombination::Smooth(_) => self.predict_var_smooth(x),
499 }
500 }
501
502 fn predict_valvar(&self, x: &ArrayView2<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
503 match self.recombination {
504 Recombination::Hard => self.predict_valvar_hard(x),
505 Recombination::Smooth(_) => self.predict_valvar_smooth(x),
506 }
507 }
508
509 #[cfg(feature = "persistent")]
511 fn save(&self, path: &str, format: GpFileFormat) -> Result<()> {
512 let mut file = fs::File::create(path).unwrap();
513
514 let bytes = match format {
515 GpFileFormat::Json => serde_json::to_vec(self).map_err(MoeError::SaveJsonError)?,
516 GpFileFormat::Binary => {
517 bincode::serde::encode_to_vec(self, bincode::config::standard())
518 .map_err(MoeError::SaveBinaryError)?
519 }
520 };
521 file.write_all(&bytes)?;
522
523 Ok(())
524 }
525}
526
#[cfg_attr(feature = "serializable", typetag::serde)]
impl GpSurrogateExt for GpMixture {
    /// Gradients of the mean prediction wrt inputs, dispatched on the
    /// recombination mode.
    fn predict_gradients(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
        match self.recombination {
            Recombination::Hard => self.predict_gradients_hard(x),
            Recombination::Smooth(_) => self.predict_gradients_smooth(x),
        }
    }

    /// Gradients of the variance prediction wrt inputs.
    fn predict_var_gradients(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>> {
        match self.recombination {
            Recombination::Hard => self.predict_var_gradients_hard(x),
            Recombination::Smooth(_) => self.predict_var_gradients_smooth(x),
        }
    }

    /// Gradients of mean and variance in one pass.
    fn predict_valvar_gradients(&self, x: &ArrayView2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        match self.recombination {
            Recombination::Hard => self.predict_valvar_gradients_hard(x),
            Recombination::Smooth(_) => self.predict_valvar_gradients_smooth(x),
        }
    }

    /// Sample `n_traj` trajectories at points `x`.
    ///
    /// Only supported for a single-cluster mixture; sampling across several
    /// experts is not defined here.
    fn sample(&self, x: &ArrayView2<f64>, n_traj: usize) -> Result<Array2<f64>> {
        if self.n_clusters() != 1 {
            return Err(MoeError::SampleError(format!(
                "Can not sample when several clusters {}",
                self.n_clusters()
            )));
        }
        self.sample_expert(0, x, n_traj)
    }
}
560
impl GpMetrics<MoeError, GpMixtureParams<f64>, Self> for GpMixture {
    /// Training data `(x, y)` kept at fit time (used by CV-based metrics).
    fn training_data(&self) -> &(Array2<f64>, Array1<f64>) {
        &self.training_data
    }

    /// Parameters able to rebuild an identical (untrained) mixture for
    /// cross-validation refits.
    fn params(&self) -> GpMixtureParams<f64> {
        GpMixtureParams::<f64>::from(self.params.clone())
    }
}
570
#[cfg_attr(feature = "serializable", typetag::serde)]
impl GpQualityAssurance for GpMixture {
    /// Training data, delegated to the `GpMetrics` implementation.
    fn training_data(&self) -> &(Array2<f64>, Array1<f64>) {
        (self as &dyn GpMetrics<_, _, _>).training_data()
    }

    /// Q2 score estimated by k-fold cross-validation.
    fn q2_k(&self, kfold: usize) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).q2_k_score(kfold)
    }
    /// Q2 score (leave-one-out style, see `GpMetrics::q2_score`).
    fn q2(&self) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).q2_score()
    }

    /// PVA (predictive variance adequacy) via k-fold cross-validation.
    fn pva_k(&self, kfold: usize) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).pva_k_score(kfold)
    }
    /// PVA score.
    fn pva(&self) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).pva_score()
    }

    /// IAE-alpha score via k-fold cross-validation (no plot data collected).
    fn iae_alpha_k(&self, kfold: usize) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).iae_alpha_k_score(kfold, None)
    }
    /// Same as `iae_alpha_k` but also fills `plot_data` for visualization.
    fn iae_alpha_k_score_with_plot(&self, kfold: usize, plot_data: &mut IaeAlphaPlotData) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).iae_alpha_k_score(kfold, Some(plot_data))
    }
    /// IAE-alpha score (no plot data collected).
    fn iae_alpha(&self) -> f64 {
        (self as &dyn GpMetrics<_, _, _>).iae_alpha_score(None)
    }
}
601
#[cfg_attr(feature = "serializable", typetag::serde)]
impl MixtureGpSurrogate for GpMixture {
    /// Trained experts, one per cluster.
    fn experts(&self) -> &Vec<Box<dyn FullGpSurrogate>> {
        &self.experts
    }
}
609
impl GpMixture {
    /// Default parameters builder for a GP mixture.
    pub fn params() -> GpMixtureParams<f64> {
        GpMixtureParams::new()
    }

    /// Kind of the underlying GPs (full or sparse).
    pub fn gp_type(&self) -> &GpType<f64> {
        &self.gp_type
    }

    /// Current recombination mode.
    pub fn recombination(&self) -> Recombination<f64> {
        self.recombination
    }

    /// Gaussian mixture used for clustering.
    pub fn gmx(&self) -> &GaussianMixture<f64> {
        &self.gmx
    }

    /// Set the recombination mode; `Smooth(None)` is normalized to a
    /// heaviside factor of 1.
    pub fn set_recombination(mut self, recombination: Recombination<f64>) -> Self {
        self.recombination = match recombination {
            Recombination::Hard => recombination,
            Recombination::Smooth(Some(_)) => recombination,
            Recombination::Smooth(_) => Recombination::Smooth(Some(1.)),
        };
        self
    }

    /// Replace the Gaussian mixture with the given parameters.
    ///
    /// # Panics
    /// Panics when the parameters are inconsistent (see `GaussianMixture::new`).
    pub fn set_gmx(
        mut self,
        weights: Array1<f64>,
        means: Array2<f64>,
        covariances: Array3<f64>,
    ) -> Self {
        self.gmx = GaussianMixture::new(weights, means, covariances).unwrap();
        self
    }

    /// Replace the experts list.
    pub fn set_experts(mut self, experts: Vec<Box<dyn FullGpSurrogate>>) -> Self {
        self.experts = experts;
        self
    }

    /// Smooth mean prediction: responsibility-weighted sum of expert means.
    pub fn predict_smooth(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
        predict_smooth(&self.experts, &self.gmx, x)
    }

    /// Smooth variance prediction: expert variances weighted by squared
    /// responsibilities, `sum_k p_k(x)^2 * v_k(x)`.
    pub fn predict_var_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array1<f64>> {
        let probas = self.gmx.predict_probas(x);
        let preds: Array1<f64> = self
            .experts
            .iter()
            .enumerate()
            .map(|(i, gp)| {
                let p = probas.column(i);
                gp.predict_var(&x.view()).unwrap() * p * p
            })
            .fold(Array1::zeros(x.nrows()), |acc, var| acc + var);
        Ok(preds)
    }

    /// Gradient of the smooth mean prediction wrt inputs.
    ///
    /// Product rule on `sum_k p_k(x) m_k(x)`:
    /// `sum_k (p_k dm_k + dp_k m_k)` (`term1 + term2` below), evaluated
    /// row by row.
    pub fn predict_gradients_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        let probas = self.gmx.predict_probas(x);
        let probas_drv = self.gmx.predict_probas_derivatives(x);
        let mut drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));

        Zip::from(drv.rows_mut())
            .and(x.rows())
            .and(probas.rows())
            .and(probas_drv.outer_iter())
            .for_each(|mut y, x, p, pprime| {
                // Shape the single point as a 1-row matrix for the experts.
                let x = x.insert_axis(Axis(0));
                let preds: Array1<f64> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict(&x).unwrap()[0])
                    .collect();
                let drvs: Vec<Array1<f64>> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict_gradients(&x).unwrap().row(0).to_owned())
                    .collect();

                let preds = preds.insert_axis(Axis(1));
                // One row per expert: gradient of that expert's mean.
                let mut preds_drv = Array2::zeros((self.experts.len(), x.len()));
                Zip::indexed(preds_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&drvs[i]));

                // term1 = sum_k p_k * dm_k
                let mut term1 = Array2::zeros((self.experts.len(), x.len()));
                Zip::from(term1.rows_mut())
                    .and(&p)
                    .and(preds_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p))));
                let term1 = term1.sum_axis(Axis(0));

                // term2 = sum_k dp_k * m_k
                let term2 = pprime.to_owned() * preds;
                let term2 = term2.sum_axis(Axis(0));

                y.assign(&(term1 + term2));
            });
        Ok(drv)
    }

    /// Gradient of the smooth variance prediction wrt inputs.
    ///
    /// Product rule on `sum_k p_k(x)^2 v_k(x)`:
    /// `sum_k (p_k^2 dv_k + 2 p_k dp_k v_k)`.
    pub fn predict_var_gradients_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        let probas = self.gmx.predict_probas(x);
        let probas_drv = self.gmx.predict_probas_derivatives(x);

        let mut drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));

        Zip::from(drv.rows_mut())
            .and(x.rows())
            .and(probas.rows())
            .and(probas_drv.outer_iter())
            .for_each(|mut y, xi, p, pprime| {
                let xii = xi.insert_axis(Axis(0));
                let preds: Array1<f64> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict_var(&xii).unwrap()[0])
                    .collect();
                let drvs: Vec<Array1<f64>> = self
                    .experts
                    .iter()
                    .map(|gp| gp.predict_var_gradients(&xii).unwrap().row(0).to_owned())
                    .collect();

                let preds = preds.insert_axis(Axis(1));
                let mut preds_drv = Array2::zeros((self.experts.len(), xi.len()));
                Zip::indexed(preds_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&drvs[i]));

                // term1 = sum_k p_k^2 * dv_k
                let mut term1 = Array2::zeros((self.experts.len(), xi.len()));
                Zip::from(term1.rows_mut())
                    .and(&p)
                    .and(preds_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p * p))));
                let term1 = term1.sum_axis(Axis(0));

                // term2 = sum_k 2 p_k dp_k * v_k
                let term2 = (p.to_owned() * pprime * preds).mapv(|v| 2. * v);
                let term2 = term2.sum_axis(Axis(0));

                y.assign(&(term1 + term2));
            });

        Ok(drv)
    }

    /// Smooth mean and variance computed in a single pass over the experts.
    pub fn predict_valvar_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let probas = self.gmx.predict_probas(x);
        let valvar: (Array1<f64>, Array1<f64>) = self
            .experts
            .iter()
            .enumerate()
            .map(|(i, gp)| {
                let p = probas.column(i);
                let (pred, var) = gp.predict_valvar(&x.view()).unwrap();
                // Mean weighted by p, variance by p^2.
                (pred * p, var * p * p)
            })
            .fold(
                (Array1::zeros((x.nrows(),)), Array1::zeros((x.nrows(),))),
                |acc, (pred, var)| (acc.0 + pred, acc.1 + var),
            );

        Ok(valvar)
    }

    /// Gradients of smooth mean and variance computed in a single pass
    /// (same product-rule terms as the individual gradient methods).
    fn predict_valvar_gradients_smooth(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let probas = self.gmx.predict_probas(x);
        let probas_drv = self.gmx.predict_probas_derivatives(x);

        let mut val_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
        let mut var_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));

        Zip::from(val_drv.rows_mut())
            .and(var_drv.rows_mut())
            .and(x.rows())
            .and(probas.rows())
            .and(probas_drv.outer_iter())
            .for_each(|mut val_y, mut var_y, xi, p, pprime| {
                let xii = xi.insert_axis(Axis(0));
                // One (mean, variance) pair per expert at this point.
                let (preds, vars): (Vec<f64>, Vec<f64>) = self
                    .experts
                    .iter()
                    .map(|gp| {
                        let (pred, var) = gp.predict_valvar(&xii).unwrap();
                        (pred[0], var[0])
                    })
                    .unzip();
                let preds: Array2<f64> = Array1::from(preds).insert_axis(Axis(1));
                let vars: Array2<f64> = Array1::from(vars).insert_axis(Axis(1));
                let (drvs, var_drvs): (Vec<Array1<f64>>, Vec<Array1<f64>>) = self
                    .experts
                    .iter()
                    .map(|gp| {
                        let (predg, varg) = gp.predict_valvar_gradients(&xii).unwrap();
                        (predg.row(0).to_owned(), varg.row(0).to_owned())
                    })
                    .unzip();

                let mut preds_drv = Array2::zeros((self.experts.len(), xi.len()));
                let mut vars_drv = Array2::zeros((self.experts.len(), xi.len()));
                Zip::indexed(preds_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&drvs[i]));
                Zip::indexed(vars_drv.rows_mut()).for_each(|i, mut jc| jc.assign(&var_drvs[i]));

                // Mean gradient: sum_k (p_k dm_k + dp_k m_k)
                let mut val_term1 = Array2::zeros((self.experts.len(), xi.len()));
                Zip::from(val_term1.rows_mut())
                    .and(&p)
                    .and(preds_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p))));
                let val_term1 = val_term1.sum_axis(Axis(0));
                let val_term2 = pprime.to_owned() * preds;
                let val_term2 = val_term2.sum_axis(Axis(0));
                val_y.assign(&(val_term1 + val_term2));

                // Variance gradient: sum_k (p_k^2 dv_k + 2 p_k dp_k v_k)
                let mut var_term1 = Array2::zeros((self.experts.len(), xi.len()));
                Zip::from(var_term1.rows_mut())
                    .and(&p)
                    .and(vars_drv.rows())
                    .for_each(|mut t, p, der| t.assign(&(der.to_owned().mapv(|v| v * p * p))));
                let var_term1 = var_term1.sum_axis(Axis(0));
                let var_term2 = (p.to_owned() * pprime * vars).mapv(|v| 2. * v);
                let var_term2 = var_term2.sum_axis(Axis(0));
                var_y.assign(&(var_term1 + var_term2));
            });
        Ok((val_drv, var_drv))
    }

    /// Hard mean prediction: each point is predicted by the expert of its
    /// most likely cluster.
    pub fn predict_hard(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
        let clustering = self.gmx.predict(x);
        trace!("Clustering {clustering:?}");
        let mut preds = Array1::zeros((x.nrows(),));
        Zip::from(&mut preds)
            .and(x.rows())
            .and(&clustering)
            .for_each(|y, x, &c| *y = self.experts[c].predict(&x.insert_axis(Axis(0))).unwrap()[0]);
        Ok(preds)
    }

    /// Hard variance prediction (expert of the most likely cluster).
    pub fn predict_var_hard(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array1<f64>> {
        let clustering = self.gmx.predict(x);
        trace!("Clustering {clustering:?}");
        let mut variances = Array1::zeros(x.nrows());
        Zip::from(&mut variances)
            .and(x.rows())
            .and(&clustering)
            .for_each(|y, x, &c| {
                *y = self.experts[c]
                    .predict_var(&x.insert_axis(Axis(0)))
                    .unwrap()[0];
            });
        Ok(variances)
    }

    /// Hard mean and variance in a single pass per point.
    pub fn predict_valvar_hard(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        let clustering = self.gmx.predict(x);
        trace!("Clustering {clustering:?}");
        let mut preds = Array1::zeros((x.nrows(),));
        let mut variances = Array1::zeros(x.nrows());
        Zip::from(&mut preds)
            .and(&mut variances)
            .and(x.rows())
            .and(&clustering)
            .for_each(|y, v, x, &c| {
                let (pred, var) = self.experts[c]
                    .predict_valvar(&x.insert_axis(Axis(0)))
                    .unwrap();
                *y = pred[0];
                *v = var[0];
            });
        Ok((preds, variances))
    }

    /// Hard mean-prediction gradients (expert of the most likely cluster).
    pub fn predict_gradients_hard(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        let mut drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
        let clustering = self.gmx.predict(x);
        Zip::from(drv.rows_mut())
            .and(x.rows())
            .and(&clustering)
            .for_each(|mut drv_i, xi, &c| {
                let x = xi.to_owned().insert_axis(Axis(0));
                let x_drv: ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>> =
                    self.experts[c].predict_gradients(&x.view()).unwrap();
                drv_i.assign(&x_drv.row(0))
            });
        Ok(drv)
    }

    /// Hard variance-prediction gradients.
    pub fn predict_var_gradients_hard(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        let mut vardrv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
        let clustering = self.gmx.predict(x);
        Zip::from(vardrv.rows_mut())
            .and(x.rows())
            .and(&clustering)
            .for_each(|mut vardrv_i, xi, &c| {
                let x = xi.to_owned().insert_axis(Axis(0));
                let x_vardrv: ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>> =
                    self.experts[c].predict_var_gradients(&x.view()).unwrap();
                vardrv_i.assign(&x_vardrv.row(0))
            });
        Ok(vardrv)
    }

    /// Hard mean and variance gradients in a single pass per point.
    pub fn predict_valvar_gradients_hard(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let mut val_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
        let mut var_drv = Array2::<f64>::zeros((x.nrows(), x.ncols()));
        let clustering = self.gmx.predict(x);
        Zip::from(val_drv.rows_mut())
            .and(var_drv.rows_mut())
            .and(x.rows())
            .and(&clustering)
            .for_each(|mut val_y, mut var_y, xi, &c| {
                let x = xi.to_owned().insert_axis(Axis(0));
                let (x_val_drv, x_var_drv) =
                    self.experts[c].predict_valvar_gradients(&x.view()).unwrap();
                val_y.assign(&x_val_drv.row(0));
                var_y.assign(&x_var_drv.row(0));
            });
        Ok((val_drv, var_drv))
    }

    /// Sample `n_traj` trajectories from the `ith` expert at points `x`.
    pub fn sample_expert(
        &self,
        ith: usize,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        n_traj: usize,
    ) -> Result<Array2<f64>> {
        self.experts[ith].sample(&x.view(), n_traj)
    }

    /// Convenience wrapper over the `GpSurrogate::predict` trait method.
    pub fn predict(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
        <GpMixture as GpSurrogate>::predict(self, &x.view())
    }

    /// Convenience wrapper over `GpSurrogate::predict_var`.
    pub fn predict_var(&self, x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Result<Array1<f64>> {
        <GpMixture as GpSurrogate>::predict_var(self, &x.view())
    }

    /// Convenience wrapper over `GpSurrogate::predict_valvar`.
    pub fn predict_valvar(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<(Array1<f64>, Array1<f64>)> {
        <GpMixture as GpSurrogate>::predict_valvar(self, &x.view())
    }

    /// Convenience wrapper over `GpSurrogateExt::predict_gradients`.
    pub fn predict_gradients(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        <GpMixture as GpSurrogateExt>::predict_gradients(self, &x.view())
    }

    /// Convenience wrapper over `GpSurrogateExt::predict_var_gradients`.
    pub fn predict_var_gradients(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<Array2<f64>> {
        <GpMixture as GpSurrogateExt>::predict_var_gradients(self, &x.view())
    }

    /// Convenience wrapper over `GpSurrogateExt::predict_valvar_gradients`.
    pub fn predict_valvar_gradients(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        <GpMixture as GpSurrogateExt>::predict_valvar_gradients(self, &x.view())
    }

    /// Convenience wrapper over `GpSurrogateExt::sample`.
    pub fn sample(
        &self,
        x: &ArrayBase<impl Data<Elem = f64>, Ix2>,
        n_traj: usize,
    ) -> Result<Array2<f64>> {
        <GpMixture as GpSurrogateExt>::sample(self, &x.view(), n_traj)
    }

    /// Load a mixture previously saved with `GpSurrogate::save`.
    ///
    /// # Errors
    /// Returns an error when the file cannot be read or deserialized.
    #[cfg(feature = "persistent")]
    pub fn load(path: &str, format: GpFileFormat) -> Result<Box<GpMixture>> {
        let data = fs::read(path)?;
        let moe = match format {
            GpFileFormat::Json => serde_json::from_slice(&data)?,
            GpFileFormat::Binary => {
                bincode::serde::decode_from_slice(&data, bincode::config::standard())
                    .map(|(surrogate, _)| surrogate)?
            }
        };
        Ok(Box::new(moe))
    }
}
1108
1109fn extract_part<F: Float>(
1112 data: &ArrayBase<impl Data<Elem = F>, Ix2>,
1113 quantile: usize,
1114) -> (Array2<F>, Array2<F>) {
1115 let nsamples = data.nrows();
1116 let indices = Array1::range(0., nsamples as f32, quantile as f32).mapv(|v| v as usize);
1117 let data_test = data.select(Axis(0), indices.as_slice().unwrap());
1118 let indices2: Vec<usize> = (0..nsamples).filter(|i| i % quantile != 0).collect();
1119 let data_train = data.select(Axis(0), &indices2);
1120 (data_test, data_train)
1121}
1122
1123impl<D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array1<f64>> for GpMixture {
1124 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<f64>) {
1125 assert_eq!(
1126 x.nrows(),
1127 y.len(),
1128 "The number of data points must match the number of output targets."
1129 );
1130
1131 let values = self.predict(x).expect("MoE prediction");
1132 *y = values;
1133 }
1134
1135 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<f64> {
1136 Array1::zeros(x.nrows())
1137 }
1138}
1139
1140#[allow(dead_code)]
1142pub struct MoeVariancePredictor<'a>(&'a GpMixture);
1143impl<D: Data<Elem = f64>> PredictInplace<ArrayBase<D, Ix2>, Array1<f64>>
1144 for MoeVariancePredictor<'_>
1145{
1146 fn predict_inplace(&self, x: &ArrayBase<D, Ix2>, y: &mut Array1<f64>) {
1147 assert_eq!(
1148 x.nrows(),
1149 y.len(),
1150 "The number of data points must match the number of output targets."
1151 );
1152
1153 let values = self.0.predict_var(x).expect("MoE variances prediction");
1154 *y = values;
1155 }
1156
1157 fn default_target(&self, x: &ArrayBase<D, Ix2>) -> Array1<f64> {
1158 Array1::zeros(x.nrows())
1159 }
1160}
1161
1162#[cfg(test)]
1163mod tests {
1164 use super::*;
1165 use approx::assert_abs_diff_eq;
1166 use argmin_testfunctions::rosenbrock;
1167 use egobox_doe::{Lhs, SamplingMethod};
1168 use ndarray::{Array, Array2, Zip, array};
1169 use ndarray_npy::write_npy;
1170 use ndarray_rand::RandomExt;
1171 use ndarray_rand::rand::SeedableRng;
1172 use ndarray_rand::rand_distr::Uniform;
1173 use rand_xoshiro::Xoshiro256Plus;
1174
1175 fn f_test_1d(x: &Array2<f64>) -> Array1<f64> {
1176 let mut y = Array1::zeros(x.len());
1177 let x = Array::from_iter(x.iter().cloned());
1178 Zip::from(&mut y).and(&x).for_each(|yi, xi| {
1179 if *xi < 0.4 {
1180 *yi = xi * xi;
1181 } else if (0.4..0.8).contains(xi) {
1182 *yi = 3. * xi + 1.;
1183 } else {
1184 *yi = f64::sin(10. * xi);
1185 }
1186 });
1187 y
1188 }
1189
1190 fn df_test_1d(x: &Array2<f64>) -> Array2<f64> {
1191 let mut y = Array2::zeros(x.dim());
1192 Zip::from(y.rows_mut())
1193 .and(x.rows())
1194 .for_each(|mut yi, xi| {
1195 if xi[0] < 0.4 {
1196 yi[0] = 2. * xi[0];
1197 } else if (0.4..0.8).contains(&xi[0]) {
1198 yi[0] = 3.;
1199 } else {
1200 yi[0] = 10. * f64::cos(10. * xi[0]);
1201 }
1202 });
1203 y
1204 }
1205
1206 #[test]
1207 fn test_moe_hard() {
1208 let mut rng = Xoshiro256Plus::seed_from_u64(0);
1209 let xt = Array2::random_using((50, 1), Uniform::new(0., 1.), &mut rng);
1210 let yt = f_test_1d(&xt.to_owned());
1211 let moe = GpMixture::params()
1212 .n_clusters(NbClusters::fixed(3))
1213 .regression_spec(RegressionSpec::CONSTANT)
1214 .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1215 .recombination(Recombination::Hard)
1216 .with_rng(rng)
1217 .fit(&Dataset::new(xt, yt))
1218 .expect("MOE fitted");
1219 let x = Array1::linspace(0., 1., 30).insert_axis(Axis(1));
1220 let preds = moe.predict(&x).expect("MOE prediction");
1221 let dpreds = moe.predict_gradients(&x).expect("MOE drv prediction");
1222 println!("dpred = {dpreds}");
1223 let test_dir = "target/tests";
1224 std::fs::create_dir_all(test_dir).ok();
1225 write_npy(format!("{test_dir}/x_hard.npy"), &x).expect("x saved");
1226 write_npy(format!("{test_dir}/preds_hard.npy"), &preds).expect("preds saved");
1227 write_npy(format!("{test_dir}/dpreds_hard.npy"), &dpreds).expect("dpreds saved");
1228 assert_abs_diff_eq!(
1229 0.39 * 0.39,
1230 moe.predict(&array![[0.39]]).unwrap()[0],
1231 epsilon = 1e-4
1232 );
1233 assert_abs_diff_eq!(
1234 f64::sin(10. * 0.82),
1235 moe.predict(&array![[0.82]]).unwrap()[0],
1236 epsilon = 1e-4
1237 );
1238 println!("LOOQ2 = {}", moe.q2_score());
1239 }
1240
1241 #[test]
1242 fn test_moe_smooth() {
1243 let test_dir = "target/tests";
1244 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1245 let xt = Array2::random_using((60, 1), Uniform::new(0., 1.), &mut rng);
1246 let yt = f_test_1d(&xt);
1247 let ds = Dataset::new(xt.to_owned(), yt.to_owned());
1248 let moe = GpMixture::params()
1249 .n_clusters(NbClusters::fixed(3))
1250 .recombination(Recombination::Smooth(Some(0.5)))
1251 .with_rng(rng.clone())
1252 .fit(&ds)
1253 .expect("MOE fitted");
1254 let x = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
1255 let preds = moe.predict(&x).expect("MOE prediction");
1256 write_npy(format!("{test_dir}/xt.npy"), &xt).expect("x saved");
1257 write_npy(format!("{test_dir}/yt.npy"), &yt).expect("preds saved");
1258 write_npy(format!("{test_dir}/x_smooth.npy"), &x).expect("x saved");
1259 write_npy(format!("{test_dir}/preds_smooth.npy"), &preds).expect("preds saved");
1260
1261 println!("Smooth moe {moe}");
1263 assert_abs_diff_eq!(
1264 0.2623, moe.predict(&array![[0.37]]).unwrap()[0],
1266 epsilon = 1e-3
1267 );
1268
1269 let moe = GpMixture::params()
1271 .n_clusters(NbClusters::fixed(3))
1272 .recombination(Recombination::Smooth(None))
1273 .with_rng(rng.clone())
1274 .fit(&ds)
1275 .expect("MOE fitted");
1276 println!("Smooth moe {moe}");
1277
1278 std::fs::create_dir_all(test_dir).ok();
1279 let x = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
1280 let preds = moe.predict(&x).expect("MOE prediction");
1281 write_npy(format!("{test_dir}/x_smooth2.npy"), &x).expect("x saved");
1282 write_npy(format!("{test_dir}/preds_smooth2.npy"), &preds).expect("preds saved");
1283 assert_abs_diff_eq!(
1284 0.37 * 0.37, moe.predict(&array![[0.37]]).unwrap()[0],
1286 epsilon = 1e-3
1287 );
1288 }
1289
1290 #[test]
1291 fn test_moe_auto() {
1292 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1293 let xt = Array2::random_using((60, 1), Uniform::new(0., 1.), &mut rng);
1294 let yt = f_test_1d(&xt);
1295 let ds = Dataset::new(xt, yt.to_owned());
1296 let moe = GpMixture::params()
1297 .n_clusters(NbClusters::auto())
1298 .with_rng(rng.clone())
1299 .fit(&ds)
1300 .expect("MOE fitted");
1301 println!(
1302 "Moe auto: nb clusters={}, recomb={:?}",
1303 moe.n_clusters(),
1304 moe.recombination()
1305 );
1306 assert_abs_diff_eq!(
1307 0.37 * 0.37, moe.predict(&array![[0.37]]).unwrap()[0],
1309 epsilon = 1e-3
1310 );
1311 }
1312
1313 #[test]
1314 fn test_moe_variances_smooth() {
1315 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1316 let xt = Array2::random_using((100, 1), Uniform::new(0., 1.), &mut rng);
1317 let yt = f_test_1d(&xt);
1318 let moe = GpMixture::params()
1319 .n_clusters(NbClusters::fixed(3))
1320 .recombination(Recombination::Smooth(None))
1321 .regression_spec(RegressionSpec::CONSTANT)
1322 .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1323 .with_rng(rng.clone())
1324 .fit(&Dataset::new(xt, yt))
1325 .expect("MOE fitted");
1326 let x = Array1::linspace(0., 1., 20).insert_axis(Axis(1));
1328 let variances = moe.predict_var(&x).expect("MOE variances prediction");
1329 assert_abs_diff_eq!(*variances.max().unwrap(), 0., epsilon = 1e-10);
1330 }
1331
1332 fn xsinx(x: &[f64]) -> f64 {
1333 (x[0] - 3.5) * f64::sin((x[0] - 3.5) / std::f64::consts::PI)
1334 }
1335
1336 #[test]
1337 fn test_find_best_expert() {
1338 let mut rng = Xoshiro256Plus::seed_from_u64(0);
1339 let xt = Array2::random_using((10, 1), Uniform::new(0., 1.), &mut rng);
1340 let yt = xt.mapv(|x| xsinx(&[x]));
1341 let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap();
1342 let moe = GpMixture::params().with_rng(rng).check_unwrap();
1343 let best_expert = &moe.find_best_expert(0, 1, &data).unwrap();
1344 println!("Best expert {best_expert}");
1345 }
1346
1347 #[test]
1348 fn test_find_best_heaviside_factor() {
1349 let mut rng = Xoshiro256Plus::seed_from_u64(0);
1350 let xt = Array2::random_using((50, 1), Uniform::new(0., 1.), &mut rng);
1351 let yt = f_test_1d(&xt);
1352 let _moe = GpMixture::params()
1353 .n_clusters(NbClusters::fixed(3))
1354 .with_rng(rng)
1355 .fit(&Dataset::new(xt, yt))
1356 .expect("MOE fitted");
1357 }
1358
1359 #[cfg(feature = "persistent")]
1360 #[test]
1361 fn test_save_load_moe() {
1362 let test_dir = "target/tests";
1363 std::fs::create_dir_all(test_dir).ok();
1364
1365 let mut rng = Xoshiro256Plus::seed_from_u64(0);
1366 let xt = Array2::random_using((50, 1), Uniform::new(0., 1.), &mut rng);
1367 let yt = f_test_1d(&xt);
1368 let ds = Dataset::new(xt, yt);
1369 let moe = GpMixture::params()
1370 .n_clusters(NbClusters::fixed(3))
1371 .with_rng(rng)
1372 .fit(&ds)
1373 .expect("MOE fitted");
1374 let xtest = array![[0.6]];
1375 let y_expected = moe.predict(&xtest).unwrap();
1376 let filename = format!("{test_dir}/saved_moe.json");
1377 moe.save(&filename, GpFileFormat::Json).expect("MoE saving");
1378 let new_moe = GpMixture::load(&filename, GpFileFormat::Json).expect("MoE loading");
1379 assert_abs_diff_eq!(y_expected, new_moe.predict(&xtest).unwrap(), epsilon = 1e-6);
1380 }
1381
1382 #[test]
1383 fn test_moe_drv_smooth() {
1384 let rng = Xoshiro256Plus::seed_from_u64(0);
1385 let xt = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
1390 let yt = f_test_1d(&xt);
1391
1392 let moe = GpMixture::params()
1393 .n_clusters(NbClusters::fixed(3))
1394 .regression_spec(RegressionSpec::CONSTANT)
1395 .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1396 .recombination(Recombination::Smooth(Some(0.5)))
1397 .with_rng(rng)
1398 .fit(&Dataset::new(xt, yt))
1399 .expect("MOE fitted");
1400 let x = Array1::linspace(0., 1., 50).insert_axis(Axis(1));
1401 let preds = moe.predict(&x).expect("MOE prediction");
1402 let dpreds = moe.predict_gradients(&x).expect("MOE drv prediction");
1403
1404 let test_dir = "target/tests";
1405 std::fs::create_dir_all(test_dir).ok();
1406 write_npy(format!("{test_dir}/x_moe_smooth.npy"), &x).expect("x saved");
1407 write_npy(format!("{test_dir}/preds_moe_smooth.npy"), &preds).expect("preds saved");
1408 write_npy(format!("{test_dir}/dpreds_moe_smooth.npy"), &dpreds).expect("dpreds saved");
1409
1410 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1411 for _ in 0..100 {
1412 let x1: f64 = rng.gen_range(0.1..0.9);
1413 let h = 1e-8;
1414 let xtest = array![[x1]];
1415
1416 let x = array![[x1], [x1 + h], [x1 - h]];
1417 let preds = moe.predict(&x).unwrap();
1418 let fdiff = (preds[1] - preds[2]) / (2. * h);
1419
1420 let drv = moe.predict_gradients(&xtest).unwrap();
1421 let df = df_test_1d(&xtest);
1422
1423 let err = if drv[[0, 0]] < 1e-2 {
1427 (drv[[0, 0]] - fdiff).abs()
1428 } else {
1429 (drv[[0, 0]] - fdiff).abs() / drv[[0, 0]] };
1431 println!(
1432 "Test predicted derivatives at {xtest}: drv {drv}, true df {df}, fdiff {fdiff}"
1433 );
1434 println!("preds(x, x+h, x-h)={preds}");
1435 assert_abs_diff_eq!(err, 0.0, epsilon = 1e-1);
1436 }
1437 }
1438
1439 fn norm1(x: &Array2<f64>) -> Array2<f64> {
1440 x.mapv(|v| v.abs())
1441 .sum_axis(Axis(1))
1442 .insert_axis(Axis(1))
1443 .to_owned()
1444 }
1445
1446 fn rosenb(x: &Array2<f64>) -> Array2<f64> {
1447 let mut y: Array2<f64> = Array2::zeros((x.nrows(), 1));
1448 Zip::from(y.rows_mut())
1449 .and(x.rows())
1450 .par_for_each(|mut yi, xi| yi.assign(&array![rosenbrock(&xi.to_vec())]));
1451 y
1452 }
1453
1454 #[allow(clippy::excessive_precision)]
1455 fn test_variance_derivatives(f: fn(&Array2<f64>) -> Array2<f64>) {
1456 let rng = Xoshiro256Plus::seed_from_u64(0);
1457 let xt = egobox_doe::FullFactorial::new(&array![[-1., 1.], [-1., 1.]]).sample(100);
1458 let yt = f(&xt);
1459
1460 let moe = GpMixture::params()
1461 .n_clusters(NbClusters::fixed(2))
1462 .regression_spec(RegressionSpec::CONSTANT)
1463 .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1464 .recombination(Recombination::Smooth(Some(1.)))
1465 .with_rng(rng)
1466 .fit(&Dataset::new(xt, yt.remove_axis(Axis(1))))
1467 .expect("MOE fitted");
1468
1469 for _ in 0..20 {
1470 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1471 let x = Array::random_using((2,), Uniform::new(0., 1.), &mut rng);
1472 let xa: f64 = x[0];
1473 let xb: f64 = x[1];
1474 let e = 1e-4;
1475
1476 println!("Test derivatives at [{xa}, {xb}]");
1477
1478 let x = array![
1479 [xa, xb],
1480 [xa + e, xb],
1481 [xa - e, xb],
1482 [xa, xb + e],
1483 [xa, xb - e]
1484 ];
1485 let y_pred = moe.predict(&x).unwrap();
1486 let y_deriv = moe.predict_gradients(&x).unwrap();
1487
1488 let diff_g = (y_pred[1] - y_pred[2]) / (2. * e);
1489 let diff_d = (y_pred[3] - y_pred[4]) / (2. * e);
1490
1491 assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g);
1492 assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d);
1493
1494 let y_pred = moe.predict_var(&x).unwrap();
1495 let y_deriv = moe.predict_var_gradients(&x).unwrap();
1496
1497 let diff_g = (y_pred[1] - y_pred[2]) / (2. * e);
1498 let diff_d = (y_pred[3] - y_pred[4]) / (2. * e);
1499
1500 assert_rel_or_abs_error(y_deriv[[0, 0]], diff_g);
1501 assert_rel_or_abs_error(y_deriv[[0, 1]], diff_d);
1502 }
1503 }
1504
1505 #[test]
1507 fn test_valvar_predictions() {
1508 let rng = Xoshiro256Plus::seed_from_u64(0);
1509 let xt = egobox_doe::FullFactorial::new(&array![[-1., 1.], [-1., 1.]]).sample(100);
1510 let yt = rosenb(&xt).remove_axis(Axis(1));
1511
1512 for corr in [
1513 CorrelationSpec::SQUAREDEXPONENTIAL,
1514 CorrelationSpec::MATERN32,
1515 CorrelationSpec::MATERN52,
1516 ] {
1517 println!("Test valvar derivatives with correlation {corr:?}");
1518 for recomb in [
1519 Recombination::Hard,
1520 Recombination::Smooth(Some(0.5)),
1521 Recombination::Smooth(None),
1522 ] {
1523 println!("Testing valvar derivatives with recomb={recomb:?}");
1524
1525 let moe = GpMixture::params()
1526 .n_clusters(NbClusters::fixed(2))
1527 .regression_spec(RegressionSpec::CONSTANT)
1528 .correlation_spec(corr)
1529 .recombination(recomb)
1530 .with_rng(rng.clone())
1531 .fit(&Dataset::new(xt.to_owned(), yt.to_owned()))
1532 .expect("MOE fitted");
1533
1534 for _ in 0..10 {
1535 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1536 let x = Array::random_using((2,), Uniform::new(0., 1.), &mut rng);
1537 let xa: f64 = x[0];
1538 let xb: f64 = x[1];
1539 let e = 1e-4;
1540
1541 let x = array![
1542 [xa, xb],
1543 [xa + e, xb],
1544 [xa - e, xb],
1545 [xa, xb + e],
1546 [xa, xb - e]
1547 ];
1548 let (y_pred, v_pred) = moe.predict_valvar(&x).unwrap();
1549 let (y_deriv, v_deriv) = moe.predict_valvar_gradients(&x).unwrap();
1550
1551 let pred = moe.predict(&x).unwrap();
1552 let var = moe.predict_var(&x).unwrap();
1553 assert_abs_diff_eq!(y_pred, pred, epsilon = 1e-12);
1554 assert_abs_diff_eq!(v_pred, var, epsilon = 1e-12);
1555
1556 let deriv = moe.predict_gradients(&x).unwrap();
1557 let vardrv = moe.predict_var_gradients(&x).unwrap();
1558 assert_abs_diff_eq!(y_deriv, deriv, epsilon = 1e-12);
1559 assert_abs_diff_eq!(v_deriv, vardrv, epsilon = 1e-12);
1560 }
1561 }
1562 }
1563 }
1564
1565 fn assert_rel_or_abs_error(y_deriv: f64, fdiff: f64) {
1566 println!("analytic deriv = {y_deriv}, fdiff = {fdiff}");
1567 if fdiff.abs() < 1e-2 {
1568 assert_abs_diff_eq!(y_deriv, 0.0, epsilon = 1e-1); } else {
1570 let drv_rel_error1 = (y_deriv - fdiff).abs() / fdiff; assert_abs_diff_eq!(drv_rel_error1, 0.0, epsilon = 1e-1);
1572 }
1573 }
1574
    // Gradient finite-difference checks on the L1-norm test function.
    #[test]
    fn test_moe_var_deriv_norm1() {
        test_variance_derivatives(norm1);
    }
    // Gradient finite-difference checks on the Rosenbrock test function.
    #[test]
    fn test_moe_var_deriv_rosenb() {
        test_variance_derivatives(rosenb);
    }
1583
1584 #[test]
1585 fn test_moe_display() {
1586 let rng = Xoshiro256Plus::seed_from_u64(0);
1587 let xt = Lhs::new(&array![[0., 1.]])
1588 .with_rng(rng.clone())
1589 .sample(100);
1590 let yt = f_test_1d(&xt);
1591
1592 let moe = GpMixture::params()
1593 .n_clusters(NbClusters::fixed(3))
1594 .regression_spec(RegressionSpec::CONSTANT)
1595 .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1596 .recombination(Recombination::Hard)
1597 .with_rng(rng)
1598 .fit(&Dataset::new(xt, yt))
1599 .expect("MOE fitted");
1600 println!("Display moe: {moe}");
1603 }
1604
1605 fn griewank(x: &Array2<f64>) -> Array1<f64> {
1606 let dim = x.ncols();
1607 let d = Array1::linspace(1., dim as f64, dim).mapv(|v| v.sqrt());
1608 let mut y = Array1::zeros((x.nrows(),));
1609 Zip::from(&mut y).and(x.rows()).for_each(|y, x| {
1610 let s = x.mapv(|v| v * v).sum() / 4000.;
1611 let p = (x.to_owned() / &d)
1612 .mapv(|v| v.cos())
1613 .fold(1., |acc, x| acc * x);
1614 *y = s - p + 1.;
1615 });
1616 y
1617 }
1618
1619 #[test]
1620 fn test_kpls_griewank() {
1621 let dims = [100];
1622 let nts = [100];
1623 let lim = array![[-600., 600.]];
1624
1625 let test_dir = "target/tests";
1626 std::fs::create_dir_all(test_dir).ok();
1627
1628 (0..1).for_each(|i| {
1629 let dim = dims[i];
1630 let nt = nts[i];
1631 let xlimits = lim.broadcast((dim, 2)).unwrap();
1632
1633 let prefix = "griewank";
1634 let xfilename = format!("{test_dir}/{prefix}_xt_{nt}x{dim}.npy");
1635 let yfilename = format!("{test_dir}/{prefix}_yt_{nt}x1.npy");
1636
1637 let rng = Xoshiro256Plus::seed_from_u64(42);
1638 let xt = Lhs::new(&xlimits).with_rng(rng).sample(nt);
1639 write_npy(xfilename, &xt).expect("cannot save xt");
1640 let yt = griewank(&xt);
1641 write_npy(yfilename, &yt).expect("cannot save yt");
1642
1643 let gp = GpMixture::params()
1644 .n_clusters(NbClusters::default())
1645 .regression_spec(RegressionSpec::CONSTANT)
1646 .correlation_spec(CorrelationSpec::SQUAREDEXPONENTIAL)
1647 .kpls_dim(Some(3))
1648 .fit(&Dataset::new(xt, yt))
1649 .expect("GP fit error");
1650
1651 let rng = Xoshiro256Plus::seed_from_u64(0);
1656 let xtest = Lhs::new(&xlimits).with_rng(rng).sample(100);
1657 let ytest = gp.predict(&xtest).expect("prediction error");
1658 let ytrue = griewank(&xtest);
1659
1660 let nrmse = (ytrue.to_owned() - &ytest).norm_l2() / ytrue.norm_l2();
1661 println!(
1662 "diff={} ytrue={} nrsme={}",
1663 (ytrue.to_owned() - &ytest).norm_l2(),
1664 ytrue.norm_l2(),
1665 nrmse
1666 );
1667 assert_abs_diff_eq!(nrmse, 0., epsilon = 1e-2);
1668 });
1669 }
1670
1671 fn sphere(x: &Array2<f64>) -> Array1<f64> {
1672 (x * x)
1673 .sum_axis(Axis(1))
1674 .into_shape_with_order((x.nrows(),))
1675 .expect("Cannot reshape sphere output")
1676 }
1677
1678 #[test]
1679 fn test_moe_smooth_vs_hard_one_cluster() {
1680 let mut rng = Xoshiro256Plus::seed_from_u64(42);
1681 let xt = Array2::random_using((50, 2), Uniform::new(0., 1.), &mut rng);
1682 let yt = sphere(&xt);
1683 let ds = Dataset::new(xt, yt.to_owned());
1684
1685 let moe_hard = GpMixture::params()
1687 .n_clusters(NbClusters::fixed(1))
1688 .recombination(Recombination::Hard)
1689 .with_rng(rng.clone())
1690 .fit(&ds)
1691 .expect("MOE hard fitted");
1692
1693 let moe_smooth = GpMixture::params()
1695 .n_clusters(NbClusters::fixed(1))
1696 .recombination(Recombination::Smooth(Some(1.0)))
1697 .with_rng(rng)
1698 .fit(&ds)
1699 .expect("MOE smooth fitted");
1700
1701 let mut rng = Xoshiro256Plus::seed_from_u64(43);
1703 let x = Array2::random_using((1, 2), Uniform::new(0., 1.), &mut rng);
1704 let preds_hard = moe_hard.predict(&x).expect("MOE hard prediction");
1705 let preds_smooth = moe_smooth.predict(&x).expect("MOE smooth prediction");
1706 println!("predict hard = {preds_hard} smooth = {preds_smooth}");
1707 assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);
1708
1709 let preds_hard = moe_hard.predict_var(&x).expect("MOE hard prediction");
1711 let preds_smooth = moe_smooth.predict_var(&x).expect("MOE smooth prediction");
1712 assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);
1713
1714 println!("Check pred gradients at x = {x}");
1716 let preds_smooth = moe_smooth
1717 .predict_gradients(&x)
1718 .expect("MOE smooth prediction");
1719 println!("smooth gradients = {preds_smooth}");
1720 let preds_hard = moe_hard.predict_gradients(&x).expect("MOE hard prediction");
1721 assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);
1722
1723 let preds_hard = moe_hard
1725 .predict_var_gradients(&x)
1726 .expect("MOE hard prediction");
1727 let preds_smooth = moe_smooth
1728 .predict_var_gradients(&x)
1729 .expect("MOE smooth prediction");
1730 assert_abs_diff_eq!(preds_hard, preds_smooth, epsilon = 1e-5);
1731 }
1732}