use super::features::{compute_mean, polynomial_features};
use crate::error::{Error, Result};
use crate::linalg::Matrix;
use crate::regularized::elastic_net::{elastic_net_fit, ElasticNetFit, ElasticNetOptions};
use crate::regularized::lasso::{lasso_fit, LassoFit, LassoFitOptions};
use crate::regularized::ridge::{ridge_fit, RidgeFit, RidgeFitOptions};
/// Builds the design matrix used by the polynomial regularized fits.
///
/// The backing buffer holds `x.len()` rows of `degree + 1` values each:
///
/// * column 0 is the constant `1.0`,
/// * column 1 is `x[i]`, or `x[i] - x_mean` when `center` is `true`,
/// * columns 2.. are copied, in order, from the columns returned by
///   `polynomial_features(x, degree, center, x_mean)`.
///
/// NOTE(review): the higher-order columns are presumably the powers
/// `2..=degree` of the (optionally centered) input; confirm against
/// `polynomial_features`.
///
/// Callers are expected to pass `degree >= 1`; column 1 is written
/// unconditionally.
///
/// # Errors
///
/// Propagates any error returned by `polynomial_features`.
fn build_poly_matrix(
    x: &[f64],
    degree: usize,
    center: bool,
    x_mean: f64,
) -> Result<Matrix> {
    let n_rows = x.len();
    let n_cols = degree + 1;
    let extra_columns = polynomial_features(x, degree, center, x_mean)?;

    // Start from zeros so that any cell not explicitly written stays 0.0.
    let mut buffer = vec![0.0f64; n_rows * n_cols];

    for (i, &xi) in x.iter().enumerate() {
        let start = i * n_cols;
        let row = &mut buffer[start..start + n_cols];

        row[0] = 1.0;
        row[1] = if center { xi - x_mean } else { xi };

        // Higher-order terms go after the constant and linear columns.
        for (k, column) in extra_columns.iter().enumerate() {
            row[k + 2] = column[i];
        }
    }

    Ok(Matrix::new(n_rows, n_cols, buffer))
}
/// Fits a ridge regression of `y` on a polynomial expansion of `x`.
///
/// The design matrix comes from `build_poly_matrix` and is handed to
/// `ridge_fit` with `intercept` always enabled.
///
/// # Arguments
///
/// * `y` - response values; must have the same length as `x`.
/// * `x` - predictor values.
/// * `degree` - polynomial degree; must be at least 1.
/// * `lambda` - forwarded as `RidgeFitOptions::lambda`.
/// * `center` - when `true`, the mean of `x` (via `compute_mean`) is passed
///   to the matrix builder; otherwise `0.0` is passed.
/// * `standardize` - forwarded as `RidgeFitOptions::standardize`.
///
/// All other options keep their `Default` values.
///
/// # Errors
///
/// * `Error::InvalidInput` if `degree` is 0.
/// * `Error::DimensionMismatch` if `y` and `x` differ in length.
/// * Any error propagated from `build_poly_matrix` or `ridge_fit`.
pub fn polynomial_ridge(
    y: &[f64],
    x: &[f64],
    degree: usize,
    lambda: f64,
    center: bool,
    standardize: bool,
) -> Result<RidgeFit> {
    if degree == 0 {
        return Err(Error::InvalidInput(
            "Polynomial degree must be at least 1".into(),
        ));
    }

    let (n_y, n_x) = (y.len(), x.len());
    if n_y != n_x {
        return Err(Error::DimensionMismatch(format!(
            "Length of y ({}) must match length of x ({})",
            n_y, n_x
        )));
    }

    // The mean is only computed when centering was requested.
    let centering_value = if center { compute_mean(x) } else { 0.0 };
    let design = build_poly_matrix(x, degree, center, centering_value)?;

    ridge_fit(
        &design,
        y,
        &RidgeFitOptions {
            lambda,
            intercept: true,
            standardize,
            ..Default::default()
        },
    )
}
/// Fits a lasso regression of `y` on a polynomial expansion of `x`.
///
/// The design matrix comes from `build_poly_matrix` and is handed to
/// `lasso_fit` with `intercept` always enabled.
///
/// # Arguments
///
/// * `y` - response values; must have the same length as `x`.
/// * `x` - predictor values.
/// * `degree` - polynomial degree; must be at least 1.
/// * `lambda` - forwarded as `LassoFitOptions::lambda`.
/// * `center` - when `true`, the mean of `x` (via `compute_mean`) is passed
///   to the matrix builder; otherwise `0.0` is passed.
/// * `standardize` - forwarded as `LassoFitOptions::standardize`.
///
/// All other options keep their `Default` values.
///
/// # Errors
///
/// * `Error::InvalidInput` if `degree` is 0.
/// * `Error::DimensionMismatch` if `y` and `x` differ in length.
/// * Any error propagated from `build_poly_matrix` or `lasso_fit`.
pub fn polynomial_lasso(
    y: &[f64],
    x: &[f64],
    degree: usize,
    lambda: f64,
    center: bool,
    standardize: bool,
) -> Result<LassoFit> {
    if degree == 0 {
        return Err(Error::InvalidInput(
            "Polynomial degree must be at least 1".into(),
        ));
    }

    let (n_y, n_x) = (y.len(), x.len());
    if n_y != n_x {
        return Err(Error::DimensionMismatch(format!(
            "Length of y ({}) must match length of x ({})",
            n_y, n_x
        )));
    }

    // The mean is only computed when centering was requested.
    let centering_value = if center { compute_mean(x) } else { 0.0 };
    let design = build_poly_matrix(x, degree, center, centering_value)?;

    lasso_fit(
        &design,
        y,
        &LassoFitOptions {
            lambda,
            intercept: true,
            standardize,
            ..Default::default()
        },
    )
}
/// Fits an elastic-net regression of `y` on a polynomial expansion of `x`.
///
/// The design matrix comes from `build_poly_matrix` and is handed to
/// `elastic_net_fit` with `intercept` always enabled.
///
/// # Arguments
///
/// * `y` - response values; must have the same length as `x`.
/// * `x` - predictor values.
/// * `degree` - polynomial degree; must be at least 1.
/// * `lambda` - forwarded as `ElasticNetOptions::lambda`.
/// * `alpha` - forwarded as `ElasticNetOptions::alpha`
///   (NOTE(review): presumably the L1/L2 mixing weight; confirm against
///   `elastic_net_fit`).
/// * `center` - when `true`, the mean of `x` (via `compute_mean`) is passed
///   to the matrix builder; otherwise `0.0` is passed.
/// * `standardize` - forwarded as `ElasticNetOptions::standardize`.
///
/// All other options keep their `Default` values.
///
/// # Errors
///
/// * `Error::InvalidInput` if `degree` is 0.
/// * `Error::DimensionMismatch` if `y` and `x` differ in length.
/// * Any error propagated from `build_poly_matrix` or `elastic_net_fit`.
pub fn polynomial_elastic_net(
    y: &[f64],
    x: &[f64],
    degree: usize,
    lambda: f64,
    alpha: f64,
    center: bool,
    standardize: bool,
) -> Result<ElasticNetFit> {
    if degree == 0 {
        return Err(Error::InvalidInput(
            "Polynomial degree must be at least 1".into(),
        ));
    }

    let (n_y, n_x) = (y.len(), x.len());
    if n_y != n_x {
        return Err(Error::DimensionMismatch(format!(
            "Length of y ({}) must match length of x ({})",
            n_y, n_x
        )));
    }

    // The mean is only computed when centering was requested.
    let centering_value = if center { compute_mean(x) } else { 0.0 };
    let design = build_poly_matrix(x, degree, center, centering_value)?;

    elastic_net_fit(
        &design,
        y,
        &ElasticNetOptions {
            lambda,
            alpha,
            intercept: true,
            standardize,
            ..Default::default()
        },
    )
}