egobox_moe/lib.rs
//! This library implements the Mixture of Experts (MoE) method using [GP models](egobox_gp).
//!
//! The MoE method aims at increasing the accuracy of a function approximation by replacing
//! a single global model with a weighted sum of local GP regression models (experts).
//! It is based on a partition of the problem domain into several subdomains
//! via clustering, followed by the training of a local expert on each subdomain.
//!
//! The recombination of the GP models can be either:
//! * `hard`: a single GP model is responsible for providing the predicted value
//!   at a given point. The expert is selected as the one whose cluster has the largest
//!   probability of containing the given point.
//!   In hard mode, transitions between models lead to discontinuities.
//! * `smooth`: all GP models are used and their predicted values at a given point are
//!   weighted by their responsibility (the probability of the given point belonging
//!   to the cluster of the expert GP). In this case the MoE model is continuous.
//!   The smoothness is automatically adjusted using a factor, the heaviside factor,
//!   which can also be set manually.
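//!
//! The two recombination rules can be sketched in plain Rust as follows. This is a
//! minimal illustration, not the crate's implementation. It assumes that the expert
//! predictions and the cluster responsibilities at a single point are already
//! available as slices, and the function names are hypothetical.
//!
//! ```rust
//! // Hypothetical helper: hard recombination keeps only the prediction of the
//! // expert whose cluster has the largest responsibility for the point.
//! fn hard_recombination(preds: &[f64], resp: &[f64]) -> f64 {
//!     let best = resp
//!         .iter()
//!         .enumerate()
//!         .max_by(|a, b| a.1.total_cmp(b.1))
//!         .map(|(k, _)| k)
//!         .expect("at least one expert");
//!     preds[best]
//! }
//!
//! // Hypothetical helper: smooth recombination weights every expert prediction
//! // by its responsibility (assumed to sum to 1 over the experts).
//! fn smooth_recombination(preds: &[f64], resp: &[f64]) -> f64 {
//!     preds.iter().zip(resp).map(|(p, r)| p * r).sum()
//! }
//! ```
//!
//! Take responsibilities `[0.2, 0.7, 0.1]` and expert predictions `[1.0, 5.0, -2.0]`.
//! The hard rule returns `5.0`. The smooth rule returns the weighted sum, which is
//! `3.5` up to floating-point rounding and varies continuously as the
//! responsibilities change with the point.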
//!
//! # Implementation
//!
//! * Clusters are defined by clustering the training data with the
//!   [linfa-clustering](https://docs.rs/linfa-clustering/latest/linfa_clustering/)
//!   Gaussian mixture model.
//! * This library is a port of the
//!   [SMT MoE method](https://smt.readthedocs.io/en/latest/_src_docs/applications/moe.html)
//!   using egobox GP models as experts.
//! * It leverages the egobox GP PLS reduction feature to handle high-dimensional problems.
//! * A trained MoE model can be saved to disk and reloaded (see the `persistent` feature below).
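//!
//! Both recombination modes rely on the probability of a point belonging to a cluster,
//! which for a Gaussian mixture is the cluster responsibility. Below is a minimal 1-D
//! sketch. It assumes that the mixture weights, means and variances have already been
//! estimated, for example by the Gaussian mixture clustering mentioned above. The
//! function names are hypothetical, and the crate itself works on multi-dimensional inputs.
//!
//! ```rust
//! use std::f64::consts::PI;
//!
//! // Density of the 1-D normal distribution N(mean, var) evaluated at x.
//! fn normal_pdf(x: f64, mean: f64, var: f64) -> f64 {
//!     let d = x - mean;
//!     (-0.5 * d * d / var).exp() / (2.0 * PI * var).sqrt()
//! }
//!
//! // Hypothetical helper returning r_k(x) = w_k N(x; m_k, v_k) / sum_j w_j N(x; m_j, v_j),
//! // i.e. the probability that x belongs to cluster k of the mixture.
//! fn responsibilities(x: f64, weights: &[f64], means: &[f64], vars: &[f64]) -> Vec<f64> {
//!     let dens: Vec<f64> = weights
//!         .iter()
//!         .zip(means)
//!         .zip(vars)
//!         .map(|((w, m), v)| w * normal_pdf(x, *m, *v))
//!         .collect();
//!     let total: f64 = dens.iter().sum();
//!     dens.iter().map(|d| d / total).collect()
//! }
//! ```
//!
//! Consider two equally weighted clusters centered at `0.0` and `1.0` with equal variances.
//! A point at `0.5` gets responsibilities `[0.5, 0.5]`. Widening the variances makes the
//! responsibilities vary more gradually between the two centers.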
//!
//! # Features
//!
//! ## serializable
//!
//! The `serializable` feature enables serialization based on the [serde crate](https://serde.rs/).
//!
//! ## persistent
//!
//! The `persistent` feature enables `save()`/`load()` methods for a MoE model
//! to/from a JSON file using the [serde and serde_json crates](https://serde.rs/).
//!
//! # Example
//!
//! ```no_run
//! use ndarray::{Array2, Array1, Zip, Axis};
//! use egobox_moe::{GpMixture, Recombination, NbClusters};
//! use ndarray_rand::{RandomExt, rand::SeedableRng, rand_distr::Uniform};
//! use rand_xoshiro::Xoshiro256Plus;
//! use linfa::{traits::Fit, ParamGuard, Dataset};
//!
//! // One-dimensional test function with 3 modes
//! fn f3modes(x: &Array1<f64>) -> Array1<f64> {
//!     let mut y = Array1::zeros(x.len());
//!     Zip::from(&mut y).and(x).for_each(|yi, &xi| {
//!         if xi < 0.4 {
//!             *yi = xi * xi;
//!         } else if (0.4..0.8).contains(&xi) {
//!             *yi = 3. * xi + 1.;
//!         } else {
//!             *yi = f64::sin(10. * xi);
//!         }
//!     });
//!     y
//! }
//!
//! // Training data
//! let mut rng = Xoshiro256Plus::from_entropy();
//! let xt = Array1::random_using((50, ), Uniform::new(0., 1.), &mut rng);
//! let yt = f3modes(&xt);
//! let ds = Dataset::new(xt.insert_axis(Axis(1)), yt);
//!
//! // Predictions
//! let observations = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
//! let predictions = GpMixture::params()
//!     .n_clusters(NbClusters::fixed(3))
//!     .recombination(Recombination::Hard)
//!     .fit(&ds)
//!     .expect("MoE model training")
//!     .predict(&observations)
//!     .expect("MoE predictions");
//! ```
//!
//! # Reference
//!
//! Bettebghor, Dimitri, et al. [Surrogate modeling approximation using a mixture of
//! experts based on EM joint estimation](https://hal.archives-ouvertes.fr/hal-01852300/document).
//! Structural and Multidisciplinary Optimization 43.2 (2011): 243-259.
//!
#![warn(missing_docs)]
#![warn(rustdoc::broken_intra_doc_links)]
mod clustering;
mod errors;
mod expertise_macros;
mod gaussian_mixture;
mod surrogates;
mod types;

mod algorithm;
mod metrics;
mod parameters;

pub use clustering::*;
pub use errors::*;
pub use gaussian_mixture::*;
pub use metrics::*;
pub use surrogates::*;
pub use types::*;

pub use algorithm::*;
pub use parameters::*;