1use crate::correlation_models::CorrelationModel;
2use crate::errors::{GpError, Result};
3use crate::mean_models::ConstantMean;
4use crate::parameters::GpValidParams;
5use crate::{ThetaTuning, GP_COBYLA_MIN_EVAL};
6use linfa::{Float, ParamGuard};
7use ndarray::{array, Array1, Array2};
8#[cfg(feature = "serializable")]
9use serde::{Deserialize, Serialize};
10
11#[derive(Clone, Debug, PartialEq, Eq)]
13#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
14pub enum ParamTuning<F: Float> {
15 Fixed(F),
17 Optimized { init: F, bounds: (F, F) },
19}
20impl<F: Float> Default for ParamTuning<F> {
21 fn default() -> ParamTuning<F> {
22 Self::Optimized {
23 init: F::cast(1e-2),
24 bounds: (F::cast(100.0) * F::epsilon(), F::cast(1e10)),
25 }
26 }
27}
28
29#[derive(Clone, Debug, PartialEq, Eq)]
31#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
32#[non_exhaustive]
33pub enum Inducings<F: Float> {
34 Randomized(usize),
36 Located(Array2<F>),
38}
39impl<F: Float> Default for Inducings<F> {
40 fn default() -> Inducings<F> {
41 Self::Randomized(10)
42 }
43}
44
/// Sparse approximation used when training a sparse GP.
///
/// `Fitc` is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub enum SparseMethod {
    /// FITC sparse approximation (the default variant).
    #[default]
    Fitc,
    /// VFE sparse approximation.
    Vfe,
}
55
56#[derive(Clone, Debug, PartialEq, Eq)]
58#[cfg_attr(
59 feature = "serializable",
60 derive(Serialize, Deserialize),
61 serde(bound(
62 serialize = "F: Serialize, Corr: Serialize",
63 deserialize = "F: Deserialize<'de>, Corr: Deserialize<'de>"
64 ))
65)]
66pub struct SgpValidParams<F: Float, Corr: CorrelationModel<F>> {
67 gp_params: GpValidParams<F, ConstantMean, Corr>,
69 noise: ParamTuning<F>,
71 z: Inducings<F>,
73 method: SparseMethod,
75 seed: Option<u64>,
77}
78
79impl<F: Float, Corr: CorrelationModel<F>> Default for SgpValidParams<F, Corr> {
80 fn default() -> SgpValidParams<F, Corr> {
81 SgpValidParams {
82 gp_params: GpValidParams::default(),
83 noise: ParamTuning::default(),
84 z: Inducings::default(),
85 method: SparseMethod::default(),
86 seed: None,
87 }
88 }
89}
90
91impl<F: Float, Corr: CorrelationModel<F>> SgpValidParams<F, Corr> {
92 pub fn corr(&self) -> &Corr {
94 &self.gp_params.corr
95 }
96
97 pub fn kpls_dim(&self) -> Option<&usize> {
99 self.gp_params.kpls_dim.as_ref()
100 }
101
102 pub fn theta_tuning(&self) -> &ThetaTuning<F> {
104 &self.gp_params.theta_tuning
105 }
106
107 pub fn n_start(&self) -> usize {
109 self.gp_params.n_start
110 }
111
112 pub fn max_eval(&self) -> usize {
114 self.gp_params.max_eval
115 }
116
117 pub fn nugget(&self) -> F {
119 self.gp_params.nugget
120 }
121
122 pub fn method(&self) -> SparseMethod {
124 self.method
125 }
126
127 pub fn inducings(&self) -> &Inducings<F> {
129 &self.z
130 }
131
132 pub fn noise_variance(&self) -> &ParamTuning<F> {
134 &self.noise
135 }
136
137 pub fn seed(&self) -> Option<&u64> {
139 self.seed.as_ref()
140 }
141}
142
143#[derive(Clone, Debug)]
144pub struct SgpParams<F: Float, Corr: CorrelationModel<F>>(SgpValidParams<F, Corr>);
147
148impl<F: Float, Corr: CorrelationModel<F>> SgpParams<F, Corr> {
149 pub fn new(corr: Corr, inducings: Inducings<F>) -> SgpParams<F, Corr> {
151 Self(SgpValidParams {
152 gp_params: GpValidParams {
153 mean: ConstantMean::default(),
154 corr,
155 theta_tuning: ThetaTuning::Full {
156 init: array![F::cast(ThetaTuning::<F>::DEFAULT_INIT)],
157 bounds: array![(F::cast(ThetaTuning::<F>::DEFAULT_BOUNDS.0), F::cast(1e2),)],
158 }, ..Default::default()
160 },
161
162 z: inducings,
163 ..Default::default()
164 })
165 }
166
167 pub fn new_from_valid(params: &SgpValidParams<F, Corr>) -> Self {
168 Self(params.clone())
169 }
170
171 pub fn corr(mut self, corr: Corr) -> Self {
173 self.0.gp_params.corr = corr;
174 self
175 }
176
177 pub fn theta_init(mut self, theta_init: Array1<F>) -> Self {
182 self.0.gp_params.theta_tuning = match self.0.gp_params.theta_tuning {
183 ThetaTuning::Full { init: _, bounds } => ThetaTuning::Full {
184 init: theta_init,
185 bounds,
186 },
187 ThetaTuning::Partial {
188 init: _,
189 active: _,
190 bounds,
191 } => ThetaTuning::Full {
192 init: theta_init,
193 bounds,
194 },
195 ThetaTuning::Fixed(_) => ThetaTuning::Fixed(theta_init),
196 };
197 self
198 }
199
200 pub fn theta_bounds(mut self, theta_bounds: Array1<(F, F)>) -> Self {
204 self.0.gp_params.theta_tuning = match self.0.gp_params.theta_tuning {
205 ThetaTuning::Full { init, bounds: _ } => ThetaTuning::Full {
206 init,
207 bounds: theta_bounds,
208 },
209 ThetaTuning::Partial {
210 init,
211 active: _,
212 bounds: _,
213 } => ThetaTuning::Full {
214 init,
215 bounds: theta_bounds,
216 },
217 ThetaTuning::Fixed(f) => ThetaTuning::Fixed(f),
218 };
219 self
220 }
221
222 pub fn theta_tuning(mut self, theta_tuning: ThetaTuning<F>) -> Self {
224 self.0.gp_params.theta_tuning = theta_tuning;
225 self
226 }
227
228 pub fn kpls_dim(mut self, kpls_dim: Option<usize>) -> Self {
232 self.0.gp_params.kpls_dim = kpls_dim;
233 self
234 }
235
236 pub fn n_start(mut self, n_start: usize) -> Self {
238 self.0.gp_params.n_start = n_start;
239 self
240 }
241
242 pub fn max_eval(mut self, max_eval: usize) -> Self {
246 self.0.gp_params.max_eval = GP_COBYLA_MIN_EVAL.max(max_eval);
247 self
248 }
249
250 pub fn nugget(mut self, nugget: F) -> Self {
254 self.0.gp_params.nugget = nugget;
255 self
256 }
257
258 pub fn sparse_method(mut self, method: SparseMethod) -> Self {
260 self.0.method = method;
261 self
262 }
263
264 pub fn inducings(mut self, z: Array2<F>) -> Self {
266 self.0.z = Inducings::Located(z);
267 self
268 }
269
270 pub fn n_inducings(mut self, nz: usize) -> Self {
272 self.0.z = Inducings::Randomized(nz);
273 self
274 }
275
276 pub fn noise_variance(mut self, config: ParamTuning<F>) -> Self {
278 self.0.noise = config;
279 self
280 }
281
282 pub fn seed(mut self, seed: Option<u64>) -> Self {
284 self.0.seed = seed;
285 self
286 }
287}
288
289impl<F: Float, Corr: CorrelationModel<F>> From<SgpValidParams<F, Corr>> for SgpParams<F, Corr> {
290 fn from(valid: SgpValidParams<F, Corr>) -> Self {
291 SgpParams(valid.clone())
292 }
293}
294
295impl<F: Float, Corr: CorrelationModel<F>> ParamGuard for SgpParams<F, Corr> {
296 type Checked = SgpValidParams<F, Corr>;
297 type Error = GpError;
298
299 fn check_ref(&self) -> Result<&Self::Checked> {
300 if let Some(d) = self.0.gp_params.kpls_dim {
301 if d == 0 {
302 return Err(GpError::InvalidValueError(
303 "`kpls_dim` canot be 0!".to_string(),
304 ));
305 }
306 let theta = self.0.theta_tuning().init();
307 if theta.len() > 1 && d > theta.len() {
308 return Err(GpError::InvalidValueError(format!(
309 "Dimension reduction ({}) should be smaller than expected
310 training input size infered from given initial theta length ({})",
311 d,
312 theta.len()
313 )));
314 };
315 }
316 Ok(&self.0)
317 }
318
319 fn check(self) -> Result<Self::Checked> {
320 self.check_ref()?;
321 Ok(self.0)
322 }
323}