use crate::{glm::link::Link, LinearError, TweedieRegressor};
use linfa::{Float, ParamGuard};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct TweedieRegressorValidParams<F> {
alpha: F,
fit_intercept: bool,
power: F,
link: Option<Link>,
max_iter: usize,
tol: F,
}
impl<F: Float> TweedieRegressorValidParams<F> {
pub fn alpha(&self) -> F {
self.alpha
}
pub fn fit_intercept(&self) -> bool {
self.fit_intercept
}
pub fn power(&self) -> F {
self.power
}
pub fn link(&self) -> Link {
match self.link {
Some(x) => x,
None if self.power <= F::zero() => Link::Identity,
None => Link::Log,
}
}
pub fn max_iter(&self) -> usize {
self.max_iter
}
pub fn tol(&self) -> F {
self.tol
}
}
#[derive(Debug, Clone, PartialEq)]
pub struct TweedieRegressorParams<F>(TweedieRegressorValidParams<F>);
impl<F: Float> Default for TweedieRegressorParams<F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Float> TweedieRegressor<F> {
pub fn params() -> TweedieRegressorParams<F> {
TweedieRegressorParams::new()
}
}
impl<F: Float> TweedieRegressorParams<F> {
pub fn new() -> Self {
Self(TweedieRegressorValidParams {
alpha: F::one(),
fit_intercept: true,
power: F::one(),
link: None,
max_iter: 100,
tol: F::cast(1e-4),
})
}
pub fn alpha(mut self, alpha: F) -> Self {
self.0.alpha = alpha;
self
}
pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
self.0.fit_intercept = fit_intercept;
self
}
pub fn power(mut self, power: F) -> Self {
self.0.power = power;
self
}
pub fn link(mut self, link: Link) -> Self {
self.0.link = Some(link);
self
}
pub fn max_iter(mut self, max_iter: usize) -> Self {
self.0.max_iter = max_iter;
self
}
pub fn tol(mut self, tol: F) -> Self {
self.0.tol = tol;
self
}
}
impl<F: Float> ParamGuard for TweedieRegressorParams<F> {
type Checked = TweedieRegressorValidParams<F>;
type Error = LinearError<F>;
fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
if self.0.alpha.is_sign_negative() {
Err(LinearError::InvalidPenalty(self.0.alpha))
} else if self.0.power > F::zero() && self.0.power < F::one() {
Err(LinearError::InvalidTweediePower(self.0.power))
} else {
Ok(&self.0)
}
}
fn check(self) -> Result<Self::Checked, Self::Error> {
self.check_ref()?;
Ok(self.0)
}
}