use super::Fit;
use crate::{Array1, error::RegressionResult, glm::Glm, model::Model, num::Float};
/// Builder that pairs a borrowed [`Model`] with the [`FitOptions`] used to fit it.
///
/// Adjust the options with the chained setter methods, then call `fit` to run the
/// regression for the model's [`Glm`] family `M`.
pub struct FitConfig<'a, M: Glm, F: Float> {
    /// Model to be fit; borrowed for the lifetime `'a` of the resulting fit.
    pub(crate) model: &'a Model<M, F>,
    /// Options handed to the regression routine when fitting.
    pub options: FitOptions<F>,
}
impl<'a, M, F> FitConfig<'a, M, F>
where
    M: Glm,
    F: Float,
{
    /// Consumes the builder and fits the model with the configured options.
    ///
    /// Delegates to `M::regression`, passing the borrowed model and the accumulated
    /// [`FitOptions`].
    ///
    /// # Errors
    ///
    /// Returns whatever error `M::regression` reports.
    pub fn fit(self) -> RegressionResult<Fit<'a, M, F>, F> {
        M::regression(self.model, self.options)
    }

    /// Sets `options.max_iter`, the iteration limit passed to the fitting routine.
    // `#[must_use]`: the builder is consumed by value, so discarding the return value
    // silently drops the whole configuration.
    #[must_use]
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.options.max_iter = max_iter;
        self
    }

    /// Sets `options.tol`.
    ///
    /// NOTE(review): presumably the convergence tolerance used by `M::regression` — confirm.
    #[must_use]
    pub fn tol(mut self, tol: F) -> Self {
        self.options.tol = tol;
        self
    }

    /// Sets `options.l2`, the L2 regularization parameter.
    ///
    /// How the penalty is applied is decided by `M::regression`.
    #[must_use]
    pub fn l2_reg(mut self, l2: F) -> Self {
        self.options.l2 = l2;
        self
    }

    /// Sets `options.l1`, the L1 regularization parameter.
    ///
    /// How the penalty is applied is decided by `M::regression`.
    #[must_use]
    pub fn l1_reg(mut self, l1: F) -> Self {
        self.options.l1 = l1;
        self
    }
}
/// Options controlling how a GLM is fit.
///
/// Usually built through the `FitConfig` setters; the `Default` impl supplies the
/// baseline values.
// `Debug` lets callers log or inspect the exact configuration used for a fit; the
// derive adds the required `F: Debug` bound on its own.
#[derive(Clone, Debug)]
pub struct FitOptions<F>
where
    F: Float,
{
    /// Maximum number of iterations allowed for the fitting routine.
    pub max_iter: usize,
    /// Tolerance value.
    // NOTE(review): presumably the convergence tolerance of the solver — confirm.
    pub tol: F,
    /// L2 regularization parameter (zero disables it).
    pub l2: F,
    /// L1 regularization parameter (zero disables it).
    pub l1: F,
    /// Optional initial parameter vector.
    // NOTE(review): `None` presumably lets the solver choose its own starting point — confirm.
    pub init_guess: Option<Array1<F>>,
}
impl<F> Default for FitOptions<F>
where
F: Float,
{
fn default() -> Self {
Self {
max_iter: 128,
tol: F::epsilon(),
l2: F::zero(),
l1: F::zero(),
init_guess: None,
}
}
}