1use linfa::dataset::Dataset;
10use linfa::{
11 Float, ParamGuard,
12 traits::{Fit, Predict, PredictInplace},
13};
14use ndarray::{Array1, Array2};
15
16use crate::{
17 GaussianProcess, GpError, GpParams, SgpParams, SparseGaussianProcess, correlation_models,
18 mean_models,
19};
20
21pub trait PredictScore<F, ER, P, O>
23where
24 F: Float,
25 ER: std::error::Error + From<linfa::error::Error>,
26 P: Fit<Array2<F>, Array1<F>, ER, Object = O> + ParamGuard,
27 O: PredictInplace<Array2<F>, Array1<F>>,
28{
29 fn training_data(&self) -> &(Array2<F>, Array1<F>);
31
32 fn params(&self) -> P;
34
35 fn q2_score(&self, kfold: usize) -> F {
37 let (xt, yt) = self.training_data();
38 let dataset = Dataset::new(xt.to_owned(), yt.to_owned());
39 let yt_mean = yt.mean().unwrap();
40 let mut press = F::zero();
42 let mut tss = F::zero();
44 for (train, valid) in dataset.fold(kfold).into_iter() {
45 let params = self.params();
46 let model: O = params
47 .fit(&train)
48 .expect("cross-validation: sub model fitted");
49 let pred = model.predict(valid.records());
50 press += (valid.targets() - pred).mapv(|v| v * v).sum();
51 tss += (valid.targets() - yt_mean).mapv(|v| v * v).sum();
52 }
53 F::one() - press / tss
54 }
55
56 fn looq2_score(&self) -> F {
58 self.q2_score(self.training_data().0.nrows())
59 }
60}
61
62impl<F, Mean, Corr> PredictScore<F, GpError, GpParams<F, Mean, Corr>, Self>
63 for GaussianProcess<F, Mean, Corr>
64where
65 F: Float,
66 Mean: mean_models::RegressionModel<F>,
67 Corr: correlation_models::CorrelationModel<F>,
68{
69 fn training_data(&self) -> &(Array2<F>, Array1<F>) {
70 &self.training_data
71 }
72
73 fn params(&self) -> GpParams<F, Mean, Corr> {
74 GpParams::from(self.params.clone())
75 }
76}
77
78impl<F, Corr> PredictScore<F, GpError, SgpParams<F, Corr>, Self> for SparseGaussianProcess<F, Corr>
79where
80 F: Float,
81 Corr: correlation_models::CorrelationModel<F>,
82{
83 fn training_data(&self) -> &(Array2<F>, Array1<F>) {
84 &self.training_data
85 }
86
87 fn params(&self) -> SgpParams<F, Corr> {
88 SgpParams::from(self.params.clone())
89 }
90}
91
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Inducings, SparseKriging};
    use approx::assert_abs_diff_eq;
    use egobox_doe::{Lhs, SamplingMethod};
    use ndarray::{Array, Array1, ArrayBase, Axis, Data, Ix2, Zip, array};
    use ndarray_rand::RandomExt;
    use ndarray_rand::rand::SeedableRng;
    use ndarray_rand::rand_distr::{Normal, Uniform};
    use rand_xoshiro::Xoshiro256Plus;

    /// Griewank function evaluated on each row `x` of the input:
    /// `1 + Σ x_i² / 4000 - Π cos(x_i / sqrt(i))`, with `i` starting at 1.
    fn griewank(x: &Array2<f64>) -> Array1<f64> {
        let dim = x.ncols();
        let d = Array1::linspace(1., dim as f64, dim).mapv(|v| v.sqrt());
        let mut y = Array1::zeros((x.nrows(),));
        Zip::from(&mut y).and(x.rows()).for_each(|y, x| {
            let s = x.mapv(|v| v * v).sum() / 4000.;
            let p = (x.to_owned() / &d)
                .mapv(|v| v.cos())
                .fold(1., |acc, x| acc * x);
            *y = s - p + 1.;
        });
        y
    }

    #[test]
    fn test_q2_gp_griewank() {
        let dims = [5];
        let nts = [100];
        let lim = array![[-600., 600.]];

        for (&dim, &nt) in dims.iter().zip(nts.iter()) {
            let xlimits = lim.broadcast((dim, 2)).unwrap();

            let rng = Xoshiro256Plus::seed_from_u64(42);
            let xt = Lhs::new(&xlimits).with_rng(rng).sample(nt);
            let yt = griewank(&xt);

            let gp = GaussianProcess::<
                f64,
                mean_models::ConstantMean,
                correlation_models::SquaredExponentialCorr,
            >::params(
                mean_models::ConstantMean::default(),
                correlation_models::SquaredExponentialCorr::default(),
            )
            .kpls_dim(Some(3))
            .fit(&Dataset::new(xt, yt))
            .expect("GP fit error");

            assert_abs_diff_eq!(gp.looq2_score(), 1., epsilon = 1e-2);
            assert_abs_diff_eq!(gp.q2_score(10), 1., epsilon = 1e-2);
        }
    }

    const PI: f64 = std::f64::consts::PI;

    /// 1D oscillating test function applied element-wise.
    fn f_obj(x: &ArrayBase<impl Data<Elem = f64>, Ix2>) -> Array2<f64> {
        x.mapv(|v| (3. * PI * v).sin() + 0.3 * (9. * PI * v).cos() + 0.5 * (7. * PI * v).sin())
    }

    /// Draws `nt` uniform inputs in [-1, 1) and noisy outputs of `f_obj`
    /// (Gaussian noise with variance `eta2`).
    fn make_test_data(
        nt: usize,
        eta2: f64,
        rng: &mut Xoshiro256Plus,
    ) -> (Array2<f64>, Array1<f64>) {
        let normal = Normal::new(0., eta2.sqrt()).unwrap();
        let gaussian_noise = Array::<f64, _>::random_using((nt, 1), normal, rng);
        let xt = 2. * Array::<f64, _>::random_using((nt, 1), Uniform::new(0., 1.), rng) - 1.;
        let yt = (f_obj(&xt) + gaussian_noise).remove_axis(Axis(1));
        (xt, yt)
    }

    #[test]
    fn test_q2_sgp() {
        let mut rng = Xoshiro256Plus::seed_from_u64(42);
        let nt = 200;
        let eta2: f64 = 0.01;
        let (xt, yt) = make_test_data(nt, eta2, &mut rng);
        let n_inducings = 30;
        // xt/yt are not needed afterwards: move them into the dataset.
        let sgp = SparseKriging::params(Inducings::Randomized(n_inducings))
            .seed(Some(42))
            .fit(&Dataset::new(xt, yt))
            .expect("GP fitted");

        assert_abs_diff_eq!(sgp.looq2_score(), 1., epsilon = 2e-2);
        assert_abs_diff_eq!(sgp.q2_score(10), 1., epsilon = 2e-2);
    }
}