use crate::error::FtrlError;
use linfa::{Float, ParamGuard};
use rand::Rng;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// Unchecked hyperparameters for FTRL.
///
/// This is a thin wrapper around [`FtrlValidParams`]. The wrapped values carry no
/// guarantees until they are validated through the [`ParamGuard`] implementation
/// (`check` / `check_ref`), which yields the inner [`FtrlValidParams`] on success.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
pub struct FtrlParams<F: Float, R: Rng>(pub(crate) FtrlValidParams<F, R>);
/// FTRL hyperparameters as produced by validating an `FtrlParams` via `ParamGuard`.
///
/// Values obtained through `check` / `check_ref` satisfy the constraints documented on
/// each field. Constructing this struct directly inside the crate bypasses those checks.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize), serde(crate = "serde_crate"))]
pub struct FtrlValidParams<F: Float, R: Rng> {
    /// Finite and non-negative once validated.
    /// NOTE(review): presumably the FTRL learning-rate parameter — confirm against the model code.
    pub(crate) alpha: F,
    /// Finite and non-negative once validated.
    /// NOTE(review): presumably the FTRL smoothing parameter — confirm against the model code.
    pub(crate) beta: F,
    /// Lies in the closed interval `[0, 1]` once validated.
    /// NOTE(review): presumably the L1 regularization strength — confirm.
    pub(crate) l1_ratio: F,
    /// Lies in the closed interval `[0, 1]` once validated.
    /// NOTE(review): presumably the L2 regularization strength — confirm.
    pub(crate) l2_ratio: F,
    /// Random number generator stored alongside the hyperparameters. How it is consumed
    /// is not visible in this module.
    pub(crate) rng: R,
}
impl<F: Float, R: Rng> ParamGuard for FtrlParams<F, R> {
    type Checked = FtrlValidParams<F, R>;
    type Error = FtrlError;

    /// Validates the hyperparameters without consuming them.
    ///
    /// Checks run in this order and the first failure is reported:
    /// 1. `l1_ratio` must lie in `[0, 1]`.
    /// 2. `l2_ratio` must lie in `[0, 1]`.
    /// 3. `alpha` must be finite and not negative.
    /// 4. `beta` must be finite and not negative.
    ///
    /// # Errors
    ///
    /// Returns `InvalidL1Ratio`, `InvalidL2Ratio`, `InvalidAlpha` or `InvalidBeta`
    /// carrying the offending value converted to `f32`. NaN fails every check, because
    /// range containment and `is_finite` are both false for NaN.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let params = &self.0;
        let unit_interval = F::zero()..=F::one();
        // The error payload is only diagnostic. If `F` cannot be represented as `f32`,
        // report NaN rather than panicking while building a validation error.
        let reported = |value: F| value.to_f32().unwrap_or(f32::NAN);

        if !unit_interval.contains(&params.l1_ratio) {
            Err(FtrlError::InvalidL1Ratio(reported(params.l1_ratio)))
        } else if !unit_interval.contains(&params.l2_ratio) {
            Err(FtrlError::InvalidL2Ratio(reported(params.l2_ratio)))
        } else if !params.alpha.is_finite() || params.alpha.is_negative() {
            Err(FtrlError::InvalidAlpha(reported(params.alpha)))
        } else if !params.beta.is_finite() || params.beta.is_negative() {
            Err(FtrlError::InvalidBeta(reported(params.beta)))
        } else {
            Ok(params)
        }
    }

    /// Validates the hyperparameters and, on success, unwraps them into
    /// [`FtrlValidParams`].
    ///
    /// # Errors
    ///
    /// Same as [`ParamGuard::check_ref`].
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}
impl<F: Float, R: Rng> FtrlValidParams<F, R> {
    /// Returns the stored `alpha` value.
    pub fn alpha(&self) -> F {
        let Self { alpha, .. } = self;
        *alpha
    }

    /// Returns the stored `beta` value.
    pub fn beta(&self) -> F {
        let Self { beta, .. } = self;
        *beta
    }

    /// Returns the stored `l1_ratio` value.
    pub fn l1_ratio(&self) -> F {
        let Self { l1_ratio, .. } = self;
        *l1_ratio
    }

    /// Returns the stored `l2_ratio` value.
    pub fn l2_ratio(&self) -> F {
        let Self { l2_ratio, .. } = self;
        *l2_ratio
    }

    /// Borrows the stored random number generator.
    pub fn rng(&self) -> &R {
        let Self { rng, .. } = self;
        rng
    }
}
impl<F: Float, R: Rng> FtrlParams<F, R> {
    /// Builds an unchecked parameter set from explicit values.
    ///
    /// Nothing is validated here. Validation happens through the [`ParamGuard`]
    /// implementation.
    pub fn new(alpha: F, beta: F, l1_ratio: F, l2_ratio: F, rng: R) -> Self {
        let inner = FtrlValidParams {
            alpha,
            beta,
            l1_ratio,
            l2_ratio,
            rng,
        };
        Self(inner)
    }

    /// Builds a parameter set using the module's default values and the given RNG.
    ///
    /// The defaults are `alpha = 0.005`, `beta = 0.0`, `l1_ratio = 0.5` and
    /// `l2_ratio = 0.5`.
    pub fn default_with_rng(rng: R) -> Self {
        Self::new(
            F::cast(0.005),
            F::cast(0.0),
            F::cast(0.5),
            F::cast(0.5),
            rng,
        )
    }

    /// Replaces `alpha` and keeps every other setting.
    pub fn alpha(self, alpha: F) -> Self {
        Self(FtrlValidParams { alpha, ..self.0 })
    }

    /// Replaces `beta` and keeps every other setting.
    pub fn beta(self, beta: F) -> Self {
        Self(FtrlValidParams { beta, ..self.0 })
    }

    /// Replaces `l1_ratio` and keeps every other setting.
    pub fn l1_ratio(self, l1_ratio: F) -> Self {
        Self(FtrlValidParams { l1_ratio, ..self.0 })
    }

    /// Replaces `l2_ratio` and keeps every other setting.
    pub fn l2_ratio(self, l2_ratio: F) -> Self {
        Self(FtrlValidParams { l2_ratio, ..self.0 })
    }

    /// Replaces the random number generator and keeps every other setting.
    pub fn rng(self, rng: R) -> Self {
        Self(FtrlValidParams { rng, ..self.0 })
    }
}