linfa/
param_guard.rs

1use std::error::Error;
2
3use crate::{
4    prelude::Records,
5    traits::{Fit, FitWith, Transformer},
6};
7
/// A set of hyperparameters whose values have not been checked for validity. A reference to the
/// checked hyperparameters can only be obtained after checking has completed.
///
/// If the `Fit` or `FitWith` traits have been implemented on the checked hyperparameters, they are
/// also implemented on the unchecked hyperparameters, with the checking step done automatically.
/// `Transformer` is forwarded the same way, but only when the unchecked type additionally opts in
/// by implementing `TransformGuard`.
///
/// The hyperparameter validation done in `check_ref()` and `check()` should be identical.
pub trait ParamGuard {
    /// The checked hyperparameters
    type Checked;
    /// Error type resulting from failed hyperparameter checking
    type Error: Error;

    /// Checks the hyperparameters and returns a reference to the checked hyperparameters if
    /// successful.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the hyperparameter check fails.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error>;

    /// Checks the hyperparameters and returns the checked hyperparameters if successful.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the hyperparameter check fails.
    fn check(self) -> Result<Self::Checked, Self::Error>;

    /// Calls `check()` and unwraps the result.
    ///
    /// `#[track_caller]` makes a failed check report the caller's source location instead of a
    /// location inside this default method, which is far more useful when debugging.
    ///
    /// # Panics
    ///
    /// Panics if `check()` returns an error.
    #[track_caller]
    fn check_unwrap(self) -> Self::Checked
    where
        Self: Sized,
    {
        self.check().unwrap()
    }
}
36
37/// Implement this trait to opt into a blanket `Transformer` impl that wraps the output of the
38/// unchecked `transform` call in a `Result`. If the unchecked `transform` call returns a `Result`,
39/// the blanket impl will return double `Result`s, so this trait should be avoided in that case.
40pub trait TransformGuard: ParamGuard {}
41
42/// Performs the checking step and calls `transform` on the checked hyperparameters. Returns error
43/// if checking was unsuccessful.
44impl<R: Records, T, P: TransformGuard> Transformer<R, Result<T, P::Error>> for P
45where
46    P::Checked: Transformer<R, T>,
47{
48    fn transform(&self, x: R) -> Result<T, P::Error> {
49        self.check_ref().map(|p| p.transform(x))
50    }
51}
52
53/// Performs checking step and calls `fit` on the checked hyperparameters. If checking failed, the
54/// checking error is converted to the original error type of `Fit` and returned.
55impl<R: Records, T, E, P: ParamGuard> Fit<R, T, E> for P
56where
57    P::Checked: Fit<R, T, E>,
58    E: Error + From<crate::error::Error> + From<P::Error>,
59{
60    type Object = <<P as ParamGuard>::Checked as Fit<R, T, E>>::Object;
61
62    fn fit(&self, dataset: &crate::DatasetBase<R, T>) -> Result<Self::Object, E> {
63        let checked = self.check_ref()?;
64        checked.fit(dataset)
65    }
66}
67
68/// Performs checking step and calls `fit_with` on the checked hyperparameters. If checking failed,
69/// the checking error is converted to the original error type of `FitWith` and returned.
70impl<'a, R: Records, T, E, P: ParamGuard> FitWith<'a, R, T, E> for P
71where
72    P::Checked: FitWith<'a, R, T, E>,
73    E: Error + From<crate::error::Error> + From<P::Error>,
74{
75    type ObjectIn = <<P as ParamGuard>::Checked as FitWith<'a, R, T, E>>::ObjectIn;
76    type ObjectOut = <<P as ParamGuard>::Checked as FitWith<'a, R, T, E>>::ObjectOut;
77
78    fn fit_with(
79        &self,
80        model: Self::ObjectIn,
81        dataset: &'a crate::DatasetBase<R, T>,
82    ) -> Result<Self::ObjectOut, E> {
83        let checked = self.check_ref()?;
84        checked.fit_with(model, dataset)
85    }
86}