1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
use std::error::Error;
use crate::{
prelude::Records,
traits::{Fit, FitWith, Transformer},
};
/// A set of hyperparameters whose values have not been checked for validity. A reference to the
/// checked hyperparameters can only be obtained after checking has completed. If the
/// `Transformer`, `Fit`, or `FitWith` traits have been implemented on the checked
/// hyperparameters, they will also be implemented on the unchecked hyperparameters with the
/// checking step done automatically.
///
/// The hyperparameter validation done in `check_ref()` and `check()` should be identical.
pub trait ParamGuard {
    /// The checked hyperparameters
    type Checked;
    /// Error type resulting from failed hyperparameter checking
    type Error: Error;

    /// Checks the hyperparameters and returns a reference to the checked hyperparameters if
    /// successful.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if checking fails.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error>;

    /// Checks the hyperparameters and returns the checked hyperparameters if successful.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if checking fails.
    fn check(self) -> Result<Self::Checked, Self::Error>;

    /// Calls `check()` and unwraps the result.
    ///
    /// Marked `#[track_caller]` so that a failed check reports the location of the
    /// `check_unwrap()` call site instead of a line inside this trait.
    ///
    /// # Panics
    ///
    /// Panics (via `Result::unwrap`) if `check()` returns an error.
    #[track_caller]
    fn check_unwrap(self) -> Self::Checked
    where
        Self: Sized,
    {
        self.check().unwrap()
    }
}
/// Marker trait: implement it on unchecked hyperparameters to opt into a blanket `Transformer`
/// impl.
///
/// That blanket impl first calls `check_ref()`. If checking fails, the checking error is
/// returned as `Err`. If it succeeds, `transform` is called on the checked hyperparameters and
/// its output is wrapped in `Ok`. If the checked hyperparameters' `transform` already returns a
/// `Result`, the unchecked `transform` would therefore return nested `Result`s. Avoid this trait
/// in that case.
pub trait TransformGuard: ParamGuard {}
/// Performs the checking step and calls `transform` on the checked hyperparameters. Returns error
/// if checking was unsuccessful.
impl<R: Records, T, P: TransformGuard> Transformer<R, Result<T, P::Error>> for P
where
P::Checked: Transformer<R, T>,
{
fn transform(&self, x: R) -> Result<T, P::Error> {
self.check_ref().map(|p| p.transform(x))
}
}
/// Blanket `Fit` impl for unchecked hyperparameters.
///
/// `check_ref()` validates the hyperparameters first. If checking fails, the checking error is
/// converted into the caller's error type `E` through `From<P::Error>` and returned without
/// fitting. If checking succeeds, `fit` is delegated to the checked hyperparameters and its
/// result is returned unchanged.
impl<R, T, E, P> Fit<R, T, E> for P
where
    R: Records,
    P: ParamGuard,
    P::Checked: Fit<R, T, E>,
    E: Error + From<crate::error::Error> + From<P::Error>,
{
    type Object = <P::Checked as Fit<R, T, E>>::Object;

    fn fit(&self, dataset: &crate::DatasetBase<R, T>) -> Result<Self::Object, E> {
        match self.check_ref() {
            Ok(checked_params) => checked_params.fit(dataset),
            Err(check_err) => Err(E::from(check_err)),
        }
    }
}
/// Blanket `FitWith` impl for unchecked hyperparameters.
///
/// `check_ref()` validates the hyperparameters first. If checking fails, the checking error is
/// converted into the caller's error type `E` through `From<P::Error>` and returned. The incoming
/// `model` is then dropped without being passed to `fit_with`. If checking succeeds, `model`
/// and `dataset` are forwarded to the checked hyperparameters' `fit_with`, and its result is
/// returned unchanged.
impl<'a, R, T, E, P> FitWith<'a, R, T, E> for P
where
    R: Records,
    P: ParamGuard,
    P::Checked: FitWith<'a, R, T, E>,
    E: Error + From<crate::error::Error> + From<P::Error>,
{
    type ObjectIn = <P::Checked as FitWith<'a, R, T, E>>::ObjectIn;
    type ObjectOut = <P::Checked as FitWith<'a, R, T, E>>::ObjectOut;

    fn fit_with(
        &self,
        model: Self::ObjectIn,
        dataset: &'a crate::DatasetBase<R, T>,
    ) -> Result<Self::ObjectOut, E> {
        match self.check_ref() {
            Ok(checked_params) => checked_params.fit_with(model, dataset),
            Err(check_err) => Err(E::from(check_err)),
        }
    }
}