use super::na_action::NaAction;
use thiserror::Error;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SolverType {
#[default]
Qr,
Svd,
Cholesky,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LambdaScaling {
#[default]
Raw,
Glmnet,
}
#[derive(Debug, Clone)]
pub struct RegressionOptions {
pub with_intercept: bool,
pub compute_inference: bool,
pub confidence_level: f64,
pub lambda: f64,
pub alpha: f64,
pub lambda_scaling: LambdaScaling,
pub forgetting_factor: f64,
pub solver: SolverType,
pub max_iterations: usize,
pub tolerance: f64,
pub rank_tolerance: f64,
pub na_action: NaAction,
}
impl Default for RegressionOptions {
fn default() -> Self {
Self {
with_intercept: true,
compute_inference: true,
confidence_level: 0.95,
lambda: 0.0,
alpha: 0.0,
lambda_scaling: LambdaScaling::Raw,
forgetting_factor: 1.0,
solver: SolverType::Qr,
max_iterations: 1000,
tolerance: 1e-6,
rank_tolerance: 1e-10,
na_action: NaAction::Omit,
}
}
}
#[derive(Debug, Error)]
pub enum OptionsError {
#[error("lambda must be non-negative, got {0}")]
InvalidLambda(f64),
#[error("alpha must be in [0, 1], got {0}")]
InvalidAlpha(f64),
#[error("confidence_level must be in (0, 1), got {0}")]
InvalidConfidenceLevel(f64),
#[error("forgetting_factor must be in (0, 1], got {0}")]
InvalidForgettingFactor(f64),
#[error("tolerance must be positive, got {0}")]
InvalidTolerance(f64),
#[error("max_iterations must be at least 1, got {0}")]
InvalidMaxIterations(usize),
#[error("alpha > 0 requires lambda > 0 for elastic net")]
ElasticNetRequiresLambda,
}
impl RegressionOptions {
pub fn builder() -> RegressionOptionsBuilder {
RegressionOptionsBuilder::default()
}
pub fn ols() -> Self {
Self::default()
}
pub fn ridge(lambda: f64) -> Self {
Self {
lambda,
..Default::default()
}
}
pub fn lasso(lambda: f64) -> Self {
Self {
lambda,
alpha: 1.0,
..Default::default()
}
}
pub fn elastic_net(lambda: f64, alpha: f64) -> Self {
Self {
lambda,
alpha,
..Default::default()
}
}
pub fn rls(forgetting_factor: f64) -> Self {
Self {
forgetting_factor,
..Default::default()
}
}
pub fn validate(&self) -> Result<(), OptionsError> {
if self.lambda < 0.0 {
return Err(OptionsError::InvalidLambda(self.lambda));
}
if !(0.0..=1.0).contains(&self.alpha) {
return Err(OptionsError::InvalidAlpha(self.alpha));
}
if self.confidence_level <= 0.0 || self.confidence_level >= 1.0 {
return Err(OptionsError::InvalidConfidenceLevel(self.confidence_level));
}
if self.forgetting_factor <= 0.0 || self.forgetting_factor > 1.0 {
return Err(OptionsError::InvalidForgettingFactor(
self.forgetting_factor,
));
}
if self.tolerance <= 0.0 {
return Err(OptionsError::InvalidTolerance(self.tolerance));
}
if self.max_iterations < 1 {
return Err(OptionsError::InvalidMaxIterations(self.max_iterations));
}
if self.alpha > 0.0 && self.lambda == 0.0 {
return Err(OptionsError::ElasticNetRequiresLambda);
}
Ok(())
}
}
#[derive(Debug, Clone, Default)]
pub struct RegressionOptionsBuilder {
options: RegressionOptions,
}
impl RegressionOptionsBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn with_intercept(mut self, include: bool) -> Self {
self.options.with_intercept = include;
self
}
pub fn compute_inference(mut self, compute: bool) -> Self {
self.options.compute_inference = compute;
self
}
pub fn confidence_level(mut self, level: f64) -> Self {
self.options.confidence_level = level;
self
}
pub fn lambda(mut self, lambda: f64) -> Self {
self.options.lambda = lambda;
self
}
pub fn alpha(mut self, alpha: f64) -> Self {
self.options.alpha = alpha;
self
}
pub fn lambda_scaling(mut self, scaling: LambdaScaling) -> Self {
self.options.lambda_scaling = scaling;
self
}
pub fn forgetting_factor(mut self, factor: f64) -> Self {
self.options.forgetting_factor = factor;
self
}
pub fn solver(mut self, solver: SolverType) -> Self {
self.options.solver = solver;
self
}
pub fn max_iterations(mut self, max_iter: usize) -> Self {
self.options.max_iterations = max_iter;
self
}
pub fn tolerance(mut self, tol: f64) -> Self {
self.options.tolerance = tol;
self
}
pub fn rank_tolerance(mut self, tol: f64) -> Self {
self.options.rank_tolerance = tol;
self
}
pub fn na_action(mut self, action: NaAction) -> Self {
self.options.na_action = action;
self
}
pub fn build(self) -> Result<RegressionOptions, OptionsError> {
self.options.validate()?;
Ok(self.options)
}
pub fn build_unchecked(self) -> RegressionOptions {
self.options
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_options() {
let opts = RegressionOptions::default();
assert!(opts.with_intercept);
assert!(opts.compute_inference);
assert!((opts.confidence_level - 0.95).abs() < 1e-10);
assert!((opts.lambda - 0.0).abs() < 1e-10);
}
#[test]
fn test_builder() {
let opts = RegressionOptions::builder()
.with_intercept(false)
.lambda(0.5)
.build()
.unwrap();
assert!(!opts.with_intercept);
assert!((opts.lambda - 0.5).abs() < 1e-10);
}
#[test]
fn test_validation_invalid_lambda() {
let result = RegressionOptions::builder().lambda(-1.0).build();
assert!(matches!(result, Err(OptionsError::InvalidLambda(_))));
}
#[test]
fn test_validation_invalid_alpha() {
let result = RegressionOptions::builder().alpha(1.5).build();
assert!(matches!(result, Err(OptionsError::InvalidAlpha(_))));
}
#[test]
fn test_validation_elastic_net_requires_lambda() {
let result = RegressionOptions::builder().alpha(0.5).lambda(0.0).build();
assert!(matches!(
result,
Err(OptionsError::ElasticNetRequiresLambda)
));
}
#[test]
fn test_factory_methods() {
let ols = RegressionOptions::ols();
assert!((ols.lambda - 0.0).abs() < 1e-10);
let ridge = RegressionOptions::ridge(0.5);
assert!((ridge.lambda - 0.5).abs() < 1e-10);
let lasso = RegressionOptions::lasso(0.5);
assert!((lasso.alpha - 1.0).abs() < 1e-10);
let elastic = RegressionOptions::elastic_net(0.5, 0.5);
assert!((elastic.lambda - 0.5).abs() < 1e-10);
assert!((elastic.alpha - 0.5).abs() < 1e-10);
}
#[test]
fn test_rls_factory() {
let rls = RegressionOptions::rls(0.99);
assert!((rls.forgetting_factor - 0.99).abs() < 1e-10);
}
#[test]
fn test_builder_new() {
let builder = RegressionOptionsBuilder::new();
let opts = builder.build_unchecked();
assert!(opts.with_intercept); }
#[test]
fn test_builder_solver() {
let opts = RegressionOptions::builder()
.solver(SolverType::Svd)
.build_unchecked();
assert_eq!(opts.solver, SolverType::Svd);
}
#[test]
fn test_builder_rank_tolerance() {
let opts = RegressionOptions::builder()
.rank_tolerance(1e-8)
.build_unchecked();
assert!((opts.rank_tolerance - 1e-8).abs() < 1e-14);
}
#[test]
fn test_validation_invalid_confidence_level_zero() {
let result = RegressionOptions::builder().confidence_level(0.0).build();
assert!(matches!(
result,
Err(OptionsError::InvalidConfidenceLevel(_))
));
}
#[test]
fn test_validation_invalid_confidence_level_one() {
let result = RegressionOptions::builder().confidence_level(1.0).build();
assert!(matches!(
result,
Err(OptionsError::InvalidConfidenceLevel(_))
));
}
#[test]
fn test_validation_invalid_forgetting_factor_zero() {
let result = RegressionOptions::builder().forgetting_factor(0.0).build();
assert!(matches!(
result,
Err(OptionsError::InvalidForgettingFactor(_))
));
}
#[test]
fn test_validation_invalid_forgetting_factor_over_one() {
let result = RegressionOptions::builder().forgetting_factor(1.5).build();
assert!(matches!(
result,
Err(OptionsError::InvalidForgettingFactor(_))
));
}
#[test]
fn test_validation_invalid_tolerance() {
let result = RegressionOptions::builder().tolerance(0.0).build();
assert!(matches!(result, Err(OptionsError::InvalidTolerance(_))));
}
#[test]
fn test_validation_invalid_max_iterations() {
let result = RegressionOptions::builder().max_iterations(0).build();
assert!(matches!(result, Err(OptionsError::InvalidMaxIterations(_))));
}
#[test]
fn test_solver_type_default() {
let solver = SolverType::default();
assert_eq!(solver, SolverType::Qr);
}
#[test]
fn test_lambda_scaling_default() {
let scaling = LambdaScaling::default();
assert_eq!(scaling, LambdaScaling::Raw);
}
}