#![doc = include_str!("../README.md")]
mod algorithm;
mod error;
mod hyperparams;
use crate::hyperparams::FtrlValidParams;
pub use algorithm::Result;
pub use error::FtrlError;
pub use hyperparams::FtrlParams;
use linfa::Float;
use ndarray::Array1;
use ndarray_rand::RandomExt;
use rand::{distributions::Uniform, Rng};
use rand_xoshiro::{rand_core::SeedableRng, Xoshiro256Plus};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
/// Model state for the FTRL learner: four scalar hyperparameters plus two
/// per-feature vectors.
///
/// Instances are built with [`Ftrl::new`], which copies the scalars from a
/// validated parameter set and allocates both vectors with one entry per
/// feature. With the `serde` feature enabled the whole state can be
/// serialized and deserialized.
#[derive(Debug, Clone)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct Ftrl<F: Float> {
/// Hyperparameter `alpha`, copied from the validated parameters.
// NOTE(review): presumably the learning-rate scale of the update rule — confirm in `algorithm`.
alpha: F,
/// Hyperparameter `beta`, copied from the validated parameters.
// NOTE(review): presumably the learning-rate smoothing term — confirm in `algorithm`.
beta: F,
/// Hyperparameter `l1_ratio`, copied from the validated parameters.
// NOTE(review): presumably the L1 regularization strength — confirm in `algorithm`.
l1_ratio: F,
/// Hyperparameter `l2_ratio`, copied from the validated parameters.
// NOTE(review): presumably the L2 regularization strength — confirm in `algorithm`.
l2_ratio: F,
/// Per-feature vector; `Ftrl::new` fills it with uniform random values drawn from `[0, 1)`.
// NOTE(review): looks like the accumulated (adjusted) gradient term of FTRL — confirm in `algorithm`.
z: Array1<F>,
/// Per-feature vector; `Ftrl::new` initializes it to all zeros.
// NOTE(review): looks like the accumulated squared-gradient term of FTRL — confirm in `algorithm`.
n: Array1<F>,
}
impl<F: Float> Ftrl<F> {
    /// Returns the default hyperparameter set, backed by a deterministic RNG.
    ///
    /// The generator is a [`Xoshiro256Plus`] seeded with the fixed value `42`,
    /// so models built from these parameters start from the same random state
    /// on every run. Use [`Ftrl::params_with_rng`] to supply a different
    /// generator.
    pub fn params() -> FtrlParams<F, Xoshiro256Plus> {
        FtrlParams::default_with_rng(Xoshiro256Plus::seed_from_u64(42))
    }

    /// Returns the default hyperparameter set using the caller-supplied
    /// random number generator `rng`.
    pub fn params_with_rng<R: Rng>(rng: R) -> FtrlParams<F, R> {
        FtrlParams::default_with_rng(rng)
    }

    /// Builds a fresh model from validated hyperparameters.
    ///
    /// # Arguments
    ///
    /// * `params` - validated hyperparameters; the scalar settings are copied
    ///   into the model and the contained RNG seeds the initial state.
    /// * `nfeatures` - number of features; both internal vectors get exactly
    ///   this many entries.
    ///
    /// # Returns
    ///
    /// A model whose `n` vector is all zeros and whose `z` vector holds
    /// values sampled uniformly from `[0, 1)`.
    pub fn new<R: Rng + Clone>(params: FtrlValidParams<F, R>, nfeatures: usize) -> Ftrl<F> {
        // `params` is owned here, so the generator can be moved out directly;
        // cloning it first would only produce an identical stream at the cost
        // of an extra copy. The `Clone` bound stays for API compatibility.
        let mut rng = params.rng;
        Self {
            alpha: params.alpha,
            beta: params.beta,
            l1_ratio: params.l1_ratio,
            l2_ratio: params.l2_ratio,
            n: Array1::zeros(nfeatures),
            z: Array1::random_using(nfeatures, Uniform::new(F::zero(), F::one()), &mut rng),
        }
    }
}