egobox_moe/
lib.rs

//! This library implements the Mixture of Experts (MoE) method using [GP models](egobox_gp).
//!
//! The MoE method aims at increasing the accuracy of a function approximation by replacing
//! a single global model with a weighted sum of local GP regression models (experts).
//! It is based on a partition of the problem domain into several subdomains
//! via clustering algorithms, followed by training a local expert on each subdomain.
//!
//! The recombination between the GP models can be either:
//! * `hard`: a single GP model is responsible for providing the predicted value
//!   at a given point. The expert is selected as the one whose cluster has the largest
//!   probability of containing the given point.
//!   In hard mode, transitions between models lead to discontinuities.
//! * `smooth`: all GP models contribute, and their predicted values at a given point are
//!   weighted by their responsibility (the probability of the given point belonging
//!   to the cluster corresponding to the expert GP). In this case the MoE model is continuous.
//!   The smoothness is adjusted automatically through the so-called heaviside factor,
//!   which can also be set manually (see the snippet below).
//!
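//! A minimal sketch of selecting the recombination mode when building a model
//! (this assumes the `Recombination::Smooth` variant carries the optional heaviside
//! factor; passing `None` lets it be adjusted automatically):
//!
//! ```no_run
//! use egobox_moe::{GpMixture, Recombination, NbClusters};
//!
//! // Smooth recombination with a manually set heaviside factor of 0.5;
//! // `Recombination::Smooth(None)` would trigger automatic adjustment,
//! // `Recombination::Hard` would select a single expert per point.
//! let params = GpMixture::params()
//!     .n_clusters(NbClusters::fixed(3))
//!     .recombination(Recombination::Smooth(Some(0.5)));
//! ```
//!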
//! # Implementation
//!
//! * Clusters are defined by clustering the training data with the
//!   [linfa-clustering](https://docs.rs/linfa-clustering/latest/linfa_clustering/)
//!   Gaussian mixture model.
//! * This library is a port of the
//!   [SMT MoE method](https://smt.readthedocs.io/en/latest/_src_docs/applications/moe.html)
//!   using egobox GP models as experts.
//! * It leverages the egobox GP PLS dimension reduction feature to handle high-dimensional problems.
//! * A trained MoE model can be saved to disk and reloaded; see the `persistent` feature below.
//!
//! # Features
//!
//! ## serializable
//!
//! The `serializable` feature enables serialization based on the [serde crate](https://serde.rs/).
//!
//! ## persistent
//!
//! The `persistent` feature enables `save()`/`load()` methods for a MoE model
//! to/from a JSON file using the [serde and serde_json crates](https://serde.rs/).
//!
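//! A minimal sketch of the save/load round trip with the `persistent` feature enabled
//! (the exact signatures are assumptions here: a `save` method taking a file path and
//! an associated `GpMixture::load` returning the reloaded model; see the API docs):
//!
//! ```ignore
//! use egobox_moe::{GpMixture, NbClusters};
//! use linfa::traits::Fit;
//!
//! // `ds` is a linfa Dataset built as in the example below.
//! let moe = GpMixture::params()
//!     .n_clusters(NbClusters::fixed(3))
//!     .fit(&ds)
//!     .expect("MoE model training");
//!
//! // Persist the trained model to JSON, then reload it later.
//! moe.save("moe_model.json").expect("MoE model saved");
//! let reloaded = GpMixture::load("moe_model.json").expect("MoE model reloaded");
//! ```
//!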
//! # Example
//!
//! ```no_run
//! use ndarray::{Array1, Zip, Axis};
//! use egobox_moe::{GpMixture, Recombination, NbClusters};
//! use ndarray_rand::{RandomExt, rand::SeedableRng, rand_distr::Uniform};
//! use rand_xoshiro::Xoshiro256Plus;
//! use linfa::{traits::Fit, ParamGuard, Dataset};
//!
//! // One-dimensional test function with 3 modes
//! fn f3modes(x: &Array1<f64>) -> Array1<f64> {
//!     let mut y = Array1::zeros(x.len());
//!     Zip::from(&mut y).and(x).for_each(|yi, &xi| {
//!         if xi < 0.4 {
//!             *yi = xi * xi;
//!         } else if (0.4..0.8).contains(&xi) {
//!             *yi = 3. * xi + 1.;
//!         } else {
//!             *yi = f64::sin(10. * xi);
//!         }
//!     });
//!     y
//! }
//!
//! // Training data
//! let mut rng = Xoshiro256Plus::from_entropy();
//! let xt = Array1::random_using((50,), Uniform::new(0., 1.), &mut rng);
//! let yt = f3modes(&xt);
//! let ds = Dataset::new(xt.insert_axis(Axis(1)), yt);
//!
//! // Fit a 3-cluster MoE model and predict on 100 points over [0, 1]
//! let observations = Array1::linspace(0., 1., 100).insert_axis(Axis(1));
//! let predictions = GpMixture::params()
//!                     .n_clusters(NbClusters::fixed(3))
//!                     .recombination(Recombination::Hard)
//!                     .fit(&ds)
//!                     .expect("MoE model training")
//!                     .predict(&observations)
//!                     .expect("MoE predictions");
//! ```
//!
//! # Reference
//!
//! Bettebghor, Dimitri, et al. [Surrogate modeling approximation using a mixture of
//! experts based on EM joint estimation](https://hal.archives-ouvertes.fr/hal-01852300/document).
//! Structural and Multidisciplinary Optimization 43.2 (2011): 243-259.
//!
#![warn(missing_docs)]
#![warn(rustdoc::broken_intra_doc_links)]
mod clustering;
mod errors;
mod expertise_macros;
mod gaussian_mixture;
mod surrogates;
mod types;

mod algorithm;
mod metrics;
mod parameters;

pub use clustering::*;
pub use errors::*;
pub use gaussian_mixture::*;
pub use metrics::*;
pub use surrogates::*;
pub use types::*;

pub use algorithm::*;
pub use parameters::*;