/// Linear model with Elastic Net–style hyperparameters (`alpha`, `l1_ratio`).
///
/// Holds the hyperparameters, the iteration settings and, once fitted or loaded,
/// the learned coefficients and intercept. The fitting routine itself is not part
/// of this view.
///
/// NOTE: this struct is written to disk with `bincode` via serde derive (see `save`/`load`).
/// bincode encodes fields positionally, in declaration order, so reordering, inserting or
/// removing fields breaks previously saved files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElasticNet {
/// Value passed as `alpha` to `new`. Presumably the overall regularization strength used by
/// `fit()` (not shown here); confirm against the fitting code.
alpha: f32,
/// Mixing ratio passed to `new`, which clamps it into `[0.0, 1.0]`.
l1_ratio: f32,
/// Learned weights. `None` until the model is fitted or loaded; `is_fitted()` reports this.
coefficients: Option<Vector<f32>>,
/// Learned intercept term; `0.0` on a freshly constructed model.
intercept: f32,
/// Whether an intercept should be fitted; defaults to `true`, set via `with_intercept`.
fit_intercept: bool,
/// Iteration cap for fitting; defaults to `1000`, set via `with_max_iter`.
max_iter: usize,
/// Convergence tolerance for fitting; defaults to `1e-4`, set via `with_tol`.
tol: f32,
}
impl ElasticNet {
    /// Creates an unfitted model.
    ///
    /// `l1_ratio` is clamped into `[0.0, 1.0]`. Other settings start at their defaults:
    /// intercept fitting enabled, `max_iter = 1000`, `tol = 1e-4`.
    #[must_use]
    pub fn new(alpha: f32, l1_ratio: f32) -> Self {
        Self {
            alpha,
            l1_ratio: l1_ratio.clamp(0.0, 1.0),
            coefficients: None,
            intercept: 0.0,
            fit_intercept: true,
            max_iter: 1000,
            tol: 1e-4,
        }
    }

    /// Builder: enables or disables intercept fitting.
    #[must_use]
    pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }

    /// Builder: sets the maximum number of fitting iterations.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Builder: sets the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: f32) -> Self {
        self.tol = tol;
        self
    }

    /// Returns the `alpha` hyperparameter.
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Returns the (clamped) `l1_ratio` hyperparameter.
    #[must_use]
    pub fn l1_ratio(&self) -> f32 {
        self.l1_ratio
    }

    /// Returns the learned coefficients.
    ///
    /// # Panics
    ///
    /// Panics if the model has not been fitted or loaded; check `is_fitted()` first.
    #[must_use]
    pub fn coefficients(&self) -> &Vector<f32> {
        self.coefficients
            .as_ref()
            .expect("Model not fitted. Call fit() first.")
    }

    /// Returns the learned intercept (`0.0` before fitting).
    #[must_use]
    pub fn intercept(&self) -> f32 {
        self.intercept
    }

    /// Returns `true` once coefficients are available.
    #[must_use]
    pub fn is_fitted(&self) -> bool {
        self.coefficients.is_some()
    }

    /// Serializes the whole model with `bincode` and writes it to `path`.
    ///
    /// # Errors
    ///
    /// Returns a message if serialization or the file write fails.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
        let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
        fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
        Ok(())
    }

    /// Reads a model previously written by [`ElasticNet::save`].
    ///
    /// `l1_ratio` is clamped into `[0.0, 1.0]`, the same invariant `new` enforces.
    ///
    /// # Errors
    ///
    /// Returns a message if the file cannot be read or deserialized.
    pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
        let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
        let mut model: Self =
            bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
        model.l1_ratio = model.l1_ratio.clamp(0.0, 1.0);
        Ok(model)
    }

    /// Writes the fitted model to `path` in SafeTensors format.
    ///
    /// Tensors written:
    /// - `coefficients`, shape `[n]`;
    /// - the 1-element f32 scalars `intercept`, `alpha`, `l1_ratio`, `max_iter`, `tol`;
    /// - `fit_intercept`, encoded as `1.0` / `0.0`.
    ///
    /// `max_iter` is stored as `f32` and is therefore exact only up to 2^24.
    ///
    /// # Errors
    ///
    /// Returns a message if the model is unfitted or the SafeTensors writer fails.
    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
        use crate::serialization::safetensors;
        use std::collections::BTreeMap;

        let coefficients = self
            .coefficients
            .as_ref()
            .ok_or("Cannot save unfitted model. Call fit() first.")?;

        let mut tensors = BTreeMap::new();
        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
        tensors.insert(
            "coefficients".to_string(),
            (coef_data, vec![coefficients.len()]),
        );

        let scalars = [
            ("intercept", self.intercept),
            ("alpha", self.alpha),
            ("l1_ratio", self.l1_ratio),
            ("max_iter", self.max_iter as f32),
            ("tol", self.tol),
            // Persisted so a model built with `with_intercept(false)` round-trips intact.
            ("fit_intercept", if self.fit_intercept { 1.0 } else { 0.0 }),
        ];
        for (name, value) in scalars {
            tensors.insert(name.to_string(), (vec![value], vec![1]));
        }

        safetensors::save_safetensors(path, &tensors)?;
        Ok(())
    }

    /// Loads a model written by [`ElasticNet::save_safetensors`].
    ///
    /// Files without a `fit_intercept` tensor load with `fit_intercept = true`.
    /// `l1_ratio` is clamped into `[0.0, 1.0]`, the same invariant `new` enforces.
    ///
    /// # Errors
    ///
    /// Returns a message if:
    /// - the file cannot be read;
    /// - a required tensor is missing;
    /// - a scalar tensor does not hold exactly one element;
    /// - `max_iter` is not a non-negative integer value.
    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
        use crate::serialization::safetensors;

        let (metadata, raw_data) = safetensors::load_safetensors(path)?;

        // Looks up a named tensor and decodes its f32 payload.
        let read_tensor = |name: &str| -> std::result::Result<Vec<f32>, String> {
            let meta = metadata
                .get(name)
                .ok_or_else(|| format!("Missing '{name}' tensor in SafeTensors file"))?;
            Ok(safetensors::extract_tensor(&raw_data, meta)?)
        };

        // Decodes a tensor that must contain exactly one value.
        let read_scalar = |name: &str| -> std::result::Result<f32, String> {
            let data = read_tensor(name)?;
            match data.as_slice() {
                [value] => Ok(*value),
                _ => Err(format!(
                    "Expected {name} tensor to have 1 element, got {}",
                    data.len()
                )),
            }
        };

        let coef_data = read_tensor("coefficients")?;
        let intercept = read_scalar("intercept")?;
        let alpha = read_scalar("alpha")?;
        let l1_ratio = read_scalar("l1_ratio")?;
        let max_iter = read_scalar("max_iter")?;
        let tol = read_scalar("tol")?;

        // Older files may lack this tensor; keep the historical default of `true` for them.
        let fit_intercept = match metadata.get("fit_intercept") {
            Some(_) => read_scalar("fit_intercept")? != 0.0,
            None => true,
        };

        // `f32 as usize` saturates silently (NaN and negatives become 0, fractions truncate),
        // so reject values that cannot be an iteration count instead of inventing one.
        if !(max_iter.is_finite() && max_iter >= 0.0 && max_iter.fract() == 0.0) {
            return Err(format!(
                "Expected max_iter tensor to hold a non-negative integer, got {max_iter}"
            ));
        }

        Ok(Self {
            alpha,
            l1_ratio: l1_ratio.clamp(0.0, 1.0),
            coefficients: Some(Vector::from_vec(coef_data)),
            intercept,
            fit_intercept,
            max_iter: max_iter as usize,
            tol,
        })
    }
}