#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use linfa::{Float, ParamGuard};
use crate::error::ElasticNetError;
use super::Result;
/// Hyperparameters for (multi-task) elastic net regression after validation.
///
/// Instances handed out by the [`ParamGuard`] implementation on `ElasticNetParamsBase`
/// (`check` / `check_ref`) have passed its range checks. Fields are private and are
/// read through the accessor methods of this type.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` builds this
/// type directly and therefore bypasses those range checks — confirm whether
/// deserialized values should be re-validated.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ElasticNetValidParamsBase<F, const MULTI_TASK: bool> {
/// Regularization strength; `check_ref` rejects negative values.
penalty: F,
/// L1/L2 mixing ratio; `check_ref` requires it to lie in `[0, 1]`.
/// NOTE(review): presumably 1 means a pure L1 (lasso) penalty and 0 a pure L2
/// (ridge) penalty — confirm against the solver.
l1_ratio: F,
/// Whether an intercept term is fitted (presumably; confirm against the solver).
with_intercept: bool,
/// Iteration cap, presumably enforced by the solver (not visible here).
max_iterations: u32,
/// Convergence tolerance; `check_ref` rejects negative values.
tolerance: F,
}
/// Validated hyperparameters for single-task elastic net (`MULTI_TASK = false`).
pub type ElasticNetValidParams<F> = ElasticNetValidParamsBase<F, false>;
/// Validated hyperparameters for multi-task elastic net (`MULTI_TASK = true`).
pub type MultiTaskElasticNetValidParams<F> = ElasticNetValidParamsBase<F, true>;
/// Read-only accessors for validated elastic net hyperparameters.
impl<F: Float, const MULTI_TASK: bool> ElasticNetValidParamsBase<F, MULTI_TASK> {
    // --- floating-point hyperparameters ---

    /// Returns the regularization strength.
    pub fn penalty(&self) -> F {
        self.penalty
    }

    /// Returns the L1/L2 mixing ratio.
    pub fn l1_ratio(&self) -> F {
        self.l1_ratio
    }

    /// Returns the convergence tolerance.
    pub fn tolerance(&self) -> F {
        self.tolerance
    }

    // --- solver settings ---

    /// Returns the configured iteration cap.
    pub fn max_iterations(&self) -> u32 {
        self.max_iterations
    }

    /// Returns whether an intercept term is requested.
    pub fn with_intercept(&self) -> bool {
        self.with_intercept
    }
}
/// Unchecked hyperparameter builder for (multi-task) elastic net regression.
///
/// Wraps an `ElasticNetValidParamsBase` whose values have not been validated yet.
/// The inner field is private; build values with `new` / `Default`, adjust them with
/// the setter methods, and validate them through [`ParamGuard`] (`check` /
/// `check_ref`) to obtain the checked parameters.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ElasticNetParamsBase<F, const MULTI_TASK: bool>(
ElasticNetValidParamsBase<F, MULTI_TASK>,
);
/// Unchecked hyperparameters for single-task elastic net (`MULTI_TASK = false`).
pub type ElasticNetParams<F> = ElasticNetParamsBase<F, false>;
/// Unchecked hyperparameters for multi-task elastic net (`MULTI_TASK = true`).
pub type MultiTaskElasticNetParams<F> = ElasticNetParamsBase<F, true>;
impl<F: Float, const MULTI_TASK: bool> Default for ElasticNetParamsBase<F, MULTI_TASK> {
fn default() -> Self {
Self::new()
}
}
/// Builder methods for unchecked elastic net hyperparameters.
///
/// Every setter consumes `self` and returns the updated builder, so calls are meant
/// to be chained; nothing is validated until the builder goes through `ParamGuard`.
impl<F: Float, const MULTI_TASK: bool> ElasticNetParamsBase<F, MULTI_TASK> {
    /// Creates a builder holding the default hyperparameters:
    ///
    /// * `penalty`: 1.0
    /// * `l1_ratio`: 0.5
    /// * `with_intercept`: `true`
    /// * `max_iterations`: 1000
    /// * `tolerance`: 1e-4
    #[must_use]
    pub fn new() -> Self {
        Self(ElasticNetValidParamsBase {
            penalty: F::one(),
            l1_ratio: F::cast(0.5),
            with_intercept: true,
            max_iterations: 1000,
            tolerance: F::cast(1e-4),
        })
    }

    /// Sets the regularization strength.
    ///
    /// Must be non-negative; otherwise the `ParamGuard` checks report
    /// `ElasticNetError::InvalidPenalty`.
    #[must_use = "builder methods consume `self` and return the updated parameters"]
    pub fn penalty(mut self, penalty: F) -> Self {
        self.0.penalty = penalty;
        self
    }

    /// Sets the L1/L2 mixing ratio.
    ///
    /// Must lie in `[0, 1]`; otherwise the `ParamGuard` checks report
    /// `ElasticNetError::InvalidL1Ratio`.
    #[must_use = "builder methods consume `self` and return the updated parameters"]
    pub fn l1_ratio(mut self, l1_ratio: F) -> Self {
        self.0.l1_ratio = l1_ratio;
        self
    }

    /// Sets whether an intercept term is requested.
    #[must_use = "builder methods consume `self` and return the updated parameters"]
    pub fn with_intercept(mut self, with_intercept: bool) -> Self {
        self.0.with_intercept = with_intercept;
        self
    }

    /// Sets the convergence tolerance.
    ///
    /// Must be non-negative; otherwise the `ParamGuard` checks report
    /// `ElasticNetError::InvalidTolerance`.
    #[must_use = "builder methods consume `self` and return the updated parameters"]
    pub fn tolerance(mut self, tolerance: F) -> Self {
        self.0.tolerance = tolerance;
        self
    }

    /// Sets the iteration cap.
    #[must_use = "builder methods consume `self` and return the updated parameters"]
    pub fn max_iterations(mut self, max_iterations: u32) -> Self {
        self.0.max_iterations = max_iterations;
        self
    }
}
/// Validation of elastic net hyperparameters.
impl<F: Float, const MULTI_TASK: bool> ParamGuard for ElasticNetParamsBase<F, MULTI_TASK> {
    type Checked = ElasticNetValidParamsBase<F, MULTI_TASK>;
    type Error = ElasticNetError;

    /// Validates the hyperparameters without consuming the builder.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    ///
    /// * `ElasticNetError::InvalidPenalty` if `penalty` is negative or NaN.
    /// * `ElasticNetError::InvalidL1Ratio` if `l1_ratio` is outside `[0, 1]` or NaN.
    /// * `ElasticNetError::InvalidTolerance` if `tolerance` is negative or NaN.
    fn check_ref(&self) -> Result<&Self::Checked> {
        let params = &self.0;
        // Range membership rather than a sign test: every comparison with NaN is
        // false, so `contains` rejects NaN, while `-0.0` compares equal to zero and
        // is accepted. A sign-based `is_negative()` let positive NaN through and
        // rejected `-0.0`.
        if !(F::zero()..).contains(&params.penalty) {
            Err(ElasticNetError::InvalidPenalty(
                params.penalty.to_f32().unwrap(),
            ))
        } else if !(F::zero()..=F::one()).contains(&params.l1_ratio) {
            Err(ElasticNetError::InvalidL1Ratio(
                params.l1_ratio.to_f32().unwrap(),
            ))
        } else if !(F::zero()..).contains(&params.tolerance) {
            Err(ElasticNetError::InvalidTolerance(
                params.tolerance.to_f32().unwrap(),
            ))
        } else {
            Ok(params)
        }
    }

    /// Validates the hyperparameters and, on success, unwraps them into the checked type.
    ///
    /// # Errors
    ///
    /// Same conditions as `check_ref`.
    fn check(self) -> Result<Self::Checked> {
        self.check_ref()?;
        Ok(self.0)
    }
}