use std::error::Error;
use crate::{
prelude::Records,
traits::{Fit, FitWith, Transformer},
};
/// A value that must be validated before its checked form can be used.
///
/// Implementors expose their validated representation as [`ParamGuard::Checked`] and report
/// validation failures as [`ParamGuard::Error`]. The checked form can be obtained either by
/// reference ([`ParamGuard::check_ref`]) or by value ([`ParamGuard::check`]).
///
/// NOTE(review): nothing in this trait forces `check_ref` and `check` to apply the same
/// validation; implementors presumably should keep the two consistent.
pub trait ParamGuard {
    /// The validated form produced by a successful check.
    type Checked;

    /// The error reported when validation fails.
    type Error: Error;

    /// Validates `self` without consuming it.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when validation fails.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error>;

    /// Validates `self`, consuming it and yielding the checked form by value.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when validation fails.
    fn check(self) -> Result<Self::Checked, Self::Error>;

    /// Runs [`ParamGuard::check`] and unwraps the outcome.
    ///
    /// # Panics
    ///
    /// Panics if `check` returns an error.
    fn check_unwrap(self) -> Self::Checked
    where
        Self: Sized,
    {
        let outcome = self.check();
        outcome.unwrap()
    }
}
/// Marker trait with no methods of its own.
///
/// Implementing it on a [`ParamGuard`] type opts that type into the blanket `Transformer`
/// implementation in this module, which is bounded on `P: TransformGuard`.
pub trait TransformGuard: ParamGuard {}
/// Blanket `Transformer` implementation for every [`TransformGuard`] whose checked form is
/// itself a `Transformer<R, T>`.
///
/// The output type is `Result<T, P::Error>`: the guard is validated first, and only on success
/// is the transformation delegated to the checked form.
impl<R: Records, T, P: TransformGuard> Transformer<R, Result<T, P::Error>> for P
where
    P::Checked: Transformer<R, T>,
{
    /// Validates `self` via `check_ref`, then transforms `x` with the checked form.
    ///
    /// # Errors
    ///
    /// Returns the validation error unchanged when the check fails; `x` is not transformed in
    /// that case.
    fn transform(&self, x: R) -> Result<T, P::Error> {
        let validated = self.check_ref()?;
        Ok(validated.transform(x))
    }
}
/// Blanket `Fit` implementation for every [`ParamGuard`] whose checked form implements
/// `Fit<R, T, E>`.
///
/// The fitted object type is taken directly from the checked form's `Fit` implementation.
/// `E` must be constructible from `P::Error` so that a validation failure can be reported
/// through the same error type as the fit itself.
impl<R: Records, T, E, P: ParamGuard> Fit<R, T, E> for P
where
    P::Checked: Fit<R, T, E>,
    E: Error + From<crate::error::Error> + From<P::Error>,
{
    type Object = <P::Checked as Fit<R, T, E>>::Object;

    /// Validates `self` via `check_ref`, then delegates to the checked form's `fit`.
    ///
    /// # Errors
    ///
    /// Returns the validation error converted into `E` when the check fails, otherwise
    /// whatever the checked form's `fit` returns.
    fn fit(&self, dataset: &crate::DatasetBase<R, T>) -> Result<Self::Object, E> {
        self.check_ref()?.fit(dataset)
    }
}
/// Blanket `FitWith` implementation for every [`ParamGuard`] whose checked form implements
/// `FitWith<'a, R, T, E>`.
///
/// Both the input and output model types are taken directly from the checked form's
/// `FitWith` implementation. `E` must be constructible from `P::Error` so that a validation
/// failure can be reported through the same error type as the fit itself.
impl<'a, R: Records, T, E, P: ParamGuard> FitWith<'a, R, T, E> for P
where
    P::Checked: FitWith<'a, R, T, E>,
    E: Error + From<crate::error::Error> + From<P::Error>,
{
    type ObjectIn = <P::Checked as FitWith<'a, R, T, E>>::ObjectIn;
    type ObjectOut = <P::Checked as FitWith<'a, R, T, E>>::ObjectOut;

    /// Validates `self` via `check_ref`, then delegates to the checked form's `fit_with`,
    /// forwarding `model` and `dataset` untouched.
    ///
    /// # Errors
    ///
    /// Returns the validation error converted into `E` when the check fails, otherwise
    /// whatever the checked form's `fit_with` returns.
    fn fit_with(
        &self,
        model: Self::ObjectIn,
        dataset: &'a crate::DatasetBase<R, T>,
    ) -> Result<Self::ObjectOut, E> {
        self.check_ref()?.fit_with(model, dataset)
    }
}