1use crate::correlation_models::CorrelationModel;
2use crate::errors::{GpError, Result};
3use crate::mean_models::ConstantMean;
4use crate::parameters::GpValidParams;
5use crate::{GP_COBYLA_MIN_EVAL, ThetaTuning};
6use linfa::{Float, ParamGuard};
7use ndarray::{Array1, Array2, array};
8#[cfg(feature = "serializable")]
9use serde::{Deserialize, Serialize};
10
11#[derive(Clone, Debug, PartialEq, Eq)]
13#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
14pub enum ParamTuning<F: Float> {
15 Fixed(F),
17 Optimized {
19 init: F,
21 bounds: (F, F),
23 },
24}
25impl<F: Float> Default for ParamTuning<F> {
26 fn default() -> ParamTuning<F> {
27 Self::Optimized {
28 init: F::cast(1e-2),
29 bounds: (F::cast(100.0) * F::epsilon(), F::cast(1e10)),
30 }
31 }
32}
33
34#[derive(Clone, Debug, PartialEq, Eq)]
36#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
37#[non_exhaustive]
38pub enum Inducings<F: Float> {
39 Randomized(usize),
41 Located(Array2<F>),
43}
44impl<F: Float> Default for Inducings<F> {
45 fn default() -> Inducings<F> {
46 Self::Randomized(10)
47 }
48}
49
/// Approximation method used by the sparse GP.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serializable", derive(Serialize, Deserialize))]
pub enum SparseMethod {
    /// FITC approximation (presumably Fully Independent Training Conditional); the default.
    Fitc,
    /// VFE approximation (presumably Variational Free Energy).
    Vfe,
}

impl Default for SparseMethod {
    /// Returns [`SparseMethod::Fitc`].
    fn default() -> Self {
        SparseMethod::Fitc
    }
}
60
61#[derive(Clone, Debug, PartialEq, Eq)]
63#[cfg_attr(
64 feature = "serializable",
65 derive(Serialize, Deserialize),
66 serde(bound(
67 serialize = "F: Serialize, Corr: Serialize",
68 deserialize = "F: Deserialize<'de>, Corr: Deserialize<'de>"
69 ))
70)]
71pub struct SgpValidParams<F: Float, Corr: CorrelationModel<F>> {
72 gp_params: GpValidParams<F, ConstantMean, Corr>,
74 noise: ParamTuning<F>,
76 z: Inducings<F>,
78 method: SparseMethod,
80 seed: Option<u64>,
82}
83
84impl<F: Float, Corr: CorrelationModel<F>> Default for SgpValidParams<F, Corr> {
85 fn default() -> SgpValidParams<F, Corr> {
86 SgpValidParams {
87 gp_params: GpValidParams::default(),
88 noise: ParamTuning::default(),
89 z: Inducings::default(),
90 method: SparseMethod::default(),
91 seed: None,
92 }
93 }
94}
95
96impl<F: Float, Corr: CorrelationModel<F>> SgpValidParams<F, Corr> {
97 pub fn corr(&self) -> &Corr {
99 &self.gp_params.corr
100 }
101
102 pub fn kpls_dim(&self) -> Option<&usize> {
104 self.gp_params.kpls_dim.as_ref()
105 }
106
107 pub fn theta_tuning(&self) -> &ThetaTuning<F> {
109 &self.gp_params.theta_tuning
110 }
111
112 pub fn n_start(&self) -> usize {
114 self.gp_params.n_start
115 }
116
117 pub fn max_eval(&self) -> usize {
119 self.gp_params.max_eval
120 }
121
122 pub fn nugget(&self) -> F {
124 self.gp_params.nugget
125 }
126
127 pub fn method(&self) -> SparseMethod {
129 self.method
130 }
131
132 pub fn inducings(&self) -> &Inducings<F> {
134 &self.z
135 }
136
137 pub fn noise_variance(&self) -> &ParamTuning<F> {
139 &self.noise
140 }
141
142 pub fn seed(&self) -> Option<&u64> {
144 self.seed.as_ref()
145 }
146}
147
148#[derive(Clone, Debug)]
149pub struct SgpParams<F: Float, Corr: CorrelationModel<F>>(SgpValidParams<F, Corr>);
152
153impl<F: Float, Corr: CorrelationModel<F>> SgpParams<F, Corr> {
154 pub fn new(corr: Corr, inducings: Inducings<F>) -> SgpParams<F, Corr> {
156 Self(SgpValidParams {
157 gp_params: GpValidParams {
158 mean: ConstantMean::default(),
159 corr,
160 theta_tuning: ThetaTuning::Full {
161 init: array![F::cast(ThetaTuning::<F>::DEFAULT_INIT)],
162 bounds: array![(F::cast(ThetaTuning::<F>::DEFAULT_BOUNDS.0), F::cast(1e2),)],
163 }, ..Default::default()
165 },
166
167 z: inducings,
168 ..Default::default()
169 })
170 }
171
172 pub fn new_from_valid(params: &SgpValidParams<F, Corr>) -> Self {
174 Self(params.clone())
175 }
176
177 pub fn corr(mut self, corr: Corr) -> Self {
179 self.0.gp_params.corr = corr;
180 self
181 }
182
183 pub fn theta_init(mut self, theta_init: Array1<F>) -> Self {
188 self.0.gp_params.theta_tuning = match self.0.gp_params.theta_tuning {
189 ThetaTuning::Full { init: _, bounds } => ThetaTuning::Full {
190 init: theta_init,
191 bounds,
192 },
193 ThetaTuning::Partial {
194 init: _,
195 active: _,
196 bounds,
197 } => ThetaTuning::Full {
198 init: theta_init,
199 bounds,
200 },
201 ThetaTuning::Fixed(_) => ThetaTuning::Fixed(theta_init),
202 };
203 self
204 }
205
206 pub fn theta_bounds(mut self, theta_bounds: Array1<(F, F)>) -> Self {
210 self.0.gp_params.theta_tuning = match self.0.gp_params.theta_tuning {
211 ThetaTuning::Full { init, bounds: _ } => ThetaTuning::Full {
212 init,
213 bounds: theta_bounds,
214 },
215 ThetaTuning::Partial {
216 init,
217 active: _,
218 bounds: _,
219 } => ThetaTuning::Full {
220 init,
221 bounds: theta_bounds,
222 },
223 ThetaTuning::Fixed(f) => ThetaTuning::Fixed(f),
224 };
225 self
226 }
227
228 pub fn theta_tuning(mut self, theta_tuning: ThetaTuning<F>) -> Self {
230 self.0.gp_params.theta_tuning = theta_tuning;
231 self
232 }
233
234 pub fn kpls_dim(mut self, kpls_dim: Option<usize>) -> Self {
238 self.0.gp_params.kpls_dim = kpls_dim;
239 self
240 }
241
242 pub fn n_start(mut self, n_start: usize) -> Self {
244 self.0.gp_params.n_start = n_start;
245 self
246 }
247
248 pub fn max_eval(mut self, max_eval: usize) -> Self {
252 self.0.gp_params.max_eval = GP_COBYLA_MIN_EVAL.max(max_eval);
253 self
254 }
255
256 pub fn nugget(mut self, nugget: F) -> Self {
260 self.0.gp_params.nugget = nugget;
261 self
262 }
263
264 pub fn sparse_method(mut self, method: SparseMethod) -> Self {
266 self.0.method = method;
267 self
268 }
269
270 pub fn inducings(mut self, z: Array2<F>) -> Self {
272 self.0.z = Inducings::Located(z);
273 self
274 }
275
276 pub fn n_inducings(mut self, nz: usize) -> Self {
278 self.0.z = Inducings::Randomized(nz);
279 self
280 }
281
282 pub fn noise_variance(mut self, config: ParamTuning<F>) -> Self {
284 self.0.noise = config;
285 self
286 }
287
288 pub fn seed(mut self, seed: Option<u64>) -> Self {
290 self.0.seed = seed;
291 self
292 }
293}
294
295impl<F: Float, Corr: CorrelationModel<F>> From<SgpValidParams<F, Corr>> for SgpParams<F, Corr> {
296 fn from(valid: SgpValidParams<F, Corr>) -> Self {
297 SgpParams(valid.clone())
298 }
299}
300
301impl<F: Float, Corr: CorrelationModel<F>> ParamGuard for SgpParams<F, Corr> {
302 type Checked = SgpValidParams<F, Corr>;
303 type Error = GpError;
304
305 fn check_ref(&self) -> Result<&Self::Checked> {
306 if let Some(d) = self.0.gp_params.kpls_dim {
307 if d == 0 {
308 return Err(GpError::InvalidValueError(
309 "`kpls_dim` canot be 0!".to_string(),
310 ));
311 }
312 let theta = self.0.theta_tuning().init();
313 if theta.len() > 1 && d > theta.len() {
314 return Err(GpError::InvalidValueError(format!(
315 "Dimension reduction ({}) should be smaller than expected
316 training input size infered from given initial theta length ({})",
317 d,
318 theta.len()
319 )));
320 };
321 }
322 Ok(&self.0)
323 }
324
325 fn check(self) -> Result<Self::Checked> {
326 self.check_ref()?;
327 Ok(self.0)
328 }
329}