linfa/param_guard.rs
1use std::error::Error;
2
3use crate::{
4 prelude::Records,
5 traits::{Fit, FitWith, Transformer},
6};
7
/// Hyperparameters that have not yet been validated.
///
/// The validated ("checked") hyperparameters are only reachable through a successful check:
/// by reference via `check_ref()`, or by value via `check()`. When the checked type implements
/// `Transformer`, `Fit`, or `FitWith`, the unchecked type gets those traits too, with the
/// validation step performed automatically before delegating.
///
/// Implementors must make `check_ref()` and `check()` apply exactly the same validation.
pub trait ParamGuard {
    /// Error produced when validation of the hyperparameters fails
    type Error: Error;
    /// Type of the validated hyperparameters
    type Checked;

    /// Validates the hyperparameters, borrowing the checked hyperparameters on success
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error>;

    /// Validates the hyperparameters, consuming `self` and returning the checked
    /// hyperparameters on success
    fn check(self) -> Result<Self::Checked, Self::Error>;

    /// Validates the hyperparameters via `check()` and unwraps the outcome.
    ///
    /// # Panics
    ///
    /// Panics if validation fails; the panic message contains the validation error's `Debug`
    /// output.
    fn check_unwrap(self) -> Self::Checked
    where
        Self: Sized,
    {
        Result::unwrap(self.check())
    }
}
36
37/// Implement this trait to opt into a blanket `Transformer` impl that wraps the output of the
38/// unchecked `transform` call in a `Result`. If the unchecked `transform` call returns a `Result`,
39/// the blanket impl will return double `Result`s, so this trait should be avoided in that case.
40pub trait TransformGuard: ParamGuard {}
41
42/// Performs the checking step and calls `transform` on the checked hyperparameters. Returns error
43/// if checking was unsuccessful.
44impl<R: Records, T, P: TransformGuard> Transformer<R, Result<T, P::Error>> for P
45where
46 P::Checked: Transformer<R, T>,
47{
48 fn transform(&self, x: R) -> Result<T, P::Error> {
49 self.check_ref().map(|p| p.transform(x))
50 }
51}
52
53/// Performs checking step and calls `fit` on the checked hyperparameters. If checking failed, the
54/// checking error is converted to the original error type of `Fit` and returned.
55impl<R: Records, T, E, P: ParamGuard> Fit<R, T, E> for P
56where
57 P::Checked: Fit<R, T, E>,
58 E: Error + From<crate::error::Error> + From<P::Error>,
59{
60 type Object = <<P as ParamGuard>::Checked as Fit<R, T, E>>::Object;
61
62 fn fit(&self, dataset: &crate::DatasetBase<R, T>) -> Result<Self::Object, E> {
63 let checked = self.check_ref()?;
64 checked.fit(dataset)
65 }
66}
67
68/// Performs checking step and calls `fit_with` on the checked hyperparameters. If checking failed,
69/// the checking error is converted to the original error type of `FitWith` and returned.
70impl<'a, R: Records, T, E, P: ParamGuard> FitWith<'a, R, T, E> for P
71where
72 P::Checked: FitWith<'a, R, T, E>,
73 E: Error + From<crate::error::Error> + From<P::Error>,
74{
75 type ObjectIn = <<P as ParamGuard>::Checked as FitWith<'a, R, T, E>>::ObjectIn;
76 type ObjectOut = <<P as ParamGuard>::Checked as FitWith<'a, R, T, E>>::ObjectOut;
77
78 fn fit_with(
79 &self,
80 model: Self::ObjectIn,
81 dataset: &'a crate::DatasetBase<R, T>,
82 ) -> Result<Self::ObjectOut, E> {
83 let checked = self.check_ref()?;
84 checked.fit_with(model, dataset)
85 }
86}