#![doc = include_str!("../README.md")]
use linfa::Float;
use ndarray::{Array1, Array2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
mod algorithm;
mod error;
mod hyperparams;
pub use error::{ElasticNetError, Result};
pub use hyperparams::{
ElasticNetParams, ElasticNetParamsBase, ElasticNetValidParams, ElasticNetValidParamsBase,
MultiTaskElasticNetParams, MultiTaskElasticNetValidParams,
};
/// A single-task elastic net linear model holding the learned parameters.
///
/// Hyperparameters for this model are obtained from [`ElasticNet::params`],
/// [`ElasticNet::ridge`] or [`ElasticNet::lasso`]. The fitting logic that fills
/// these fields presumably lives in the private `algorithm` module — TODO confirm.
///
/// `Serialize`/`Deserialize` are only derived when the `serde` feature is enabled.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct ElasticNet<F> {
    /// Learned coefficient vector (presumably one entry per input feature — TODO confirm).
    hyperplane: Array1<F>,
    /// Learned intercept (scalar offset) of the linear model.
    intercept: F,
    /// Duality gap recorded by the solver; presumably a convergence diagnostic — confirm in `algorithm`.
    duality_gap: F,
    /// Step/iteration count recorded by the solver (exact meaning to be confirmed in `algorithm`).
    n_steps: u32,
    /// Variance estimates, kept as a crate `Result`. NOTE(review): looks like computing them
    /// may fail independently of the fit, so the error is stored rather than propagated — confirm.
    variance: Result<Array1<F>>,
}
impl<F: Float> ElasticNet<F> {
    /// Creates a hyperparameter set for a single-task elastic net.
    ///
    /// The returned value comes straight from `ElasticNetParams::new()`, so it
    /// carries whatever defaults that constructor defines.
    pub fn params() -> ElasticNetParams<F> {
        ElasticNetParams::new()
    }

    /// Creates hyperparameters configured as ridge regression.
    ///
    /// This starts from [`ElasticNet::params`] and sets `l1_ratio` to zero,
    /// which leaves no L1 component in the penalty.
    pub fn ridge() -> ElasticNetParams<F> {
        let pure_l2 = F::zero();
        Self::params().l1_ratio(pure_l2)
    }

    /// Creates hyperparameters configured as lasso regression.
    ///
    /// This starts from [`ElasticNet::params`] and sets `l1_ratio` to one,
    /// which makes the penalty purely L1.
    pub fn lasso() -> ElasticNetParams<F> {
        let pure_l1 = F::one();
        Self::params().l1_ratio(pure_l1)
    }
}
/// A multi-task elastic net linear model holding the learned parameters.
///
/// Unlike [`ElasticNet`], the coefficients are stored as a 2-D array and the
/// intercept as a 1-D array. This presumably means one set of parameters per task — TODO confirm.
/// Hyperparameters come from [`MultiTaskElasticNet::params`],
/// [`MultiTaskElasticNet::ridge`] or [`MultiTaskElasticNet::lasso`].
///
/// `Serialize`/`Deserialize` are only derived when the `serde` feature is enabled.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct MultiTaskElasticNet<F> {
    /// Learned coefficient matrix. NOTE(review): assumed shape (n_features, n_tasks) — confirm in `algorithm`.
    hyperplane: Array2<F>,
    /// Learned intercepts (presumably one per task — TODO confirm).
    intercept: Array1<F>,
    /// Duality gap recorded by the solver; presumably a convergence diagnostic — confirm in `algorithm`.
    duality_gap: F,
    /// Step/iteration count recorded by the solver (exact meaning to be confirmed in `algorithm`).
    n_steps: u32,
    /// Variance estimates, kept as a crate `Result`. NOTE(review): looks like computing them
    /// may fail independently of the fit, so the error is stored rather than propagated — confirm.
    variance: Result<Array1<F>>,
}
impl<F: Float> MultiTaskElasticNet<F> {
    /// Creates a hyperparameter set for a multi-task elastic net.
    ///
    /// The returned value comes straight from `MultiTaskElasticNetParams::new()`,
    /// so it carries whatever defaults that constructor defines.
    pub fn params() -> MultiTaskElasticNetParams<F> {
        MultiTaskElasticNetParams::new()
    }

    /// Creates multi-task hyperparameters configured as ridge regression.
    ///
    /// This starts from [`MultiTaskElasticNet::params`] and sets `l1_ratio` to zero,
    /// which leaves no L1 component in the penalty.
    pub fn ridge() -> MultiTaskElasticNetParams<F> {
        let pure_l2 = F::zero();
        Self::params().l1_ratio(pure_l2)
    }

    /// Creates multi-task hyperparameters configured as lasso regression.
    ///
    /// This starts from [`MultiTaskElasticNet::params`] and sets `l1_ratio` to one,
    /// which makes the penalty purely L1.
    pub fn lasso() -> MultiTaskElasticNetParams<F> {
        let pure_l1 = F::one();
        Self::params().l1_ratio(pure_l1)
    }
}