use std::fmt;
use crate::learner::StreamingLearner;
/// Regularization penalty applied on every streaming SGD update.
///
/// Marked `#[non_exhaustive]` so additional penalty kinds can be introduced
/// without breaking downstream `match` statements.
#[non_exhaustive]
pub enum Regularization {
    /// No penalty: plain least-squares SGD.
    None,
    /// L2 penalty with strength `lambda`, applied as multiplicative weight decay.
    Ridge(f64),
    /// L1 penalty with strength `lambda`, applied via soft-thresholding.
    Lasso(f64),
    /// Combined L1 + L2 penalty with independent strengths.
    ElasticNet {
        /// L1 (sparsity-inducing) strength.
        l1: f64,
        /// L2 (shrinkage) strength.
        l2: f64,
    },
}
impl Clone for Regularization {
    /// Field-by-field copy; every payload is a `Copy` `f64`, so matching on
    /// `*self` binds by copy and no variant data is moved out.
    fn clone(&self) -> Self {
        match *self {
            Self::None => Self::None,
            Self::Ridge(lambda) => Self::Ridge(lambda),
            Self::Lasso(lambda) => Self::Lasso(lambda),
            Self::ElasticNet { l1, l2 } => Self::ElasticNet { l1, l2 },
        }
    }
}
impl fmt::Debug for Regularization {
    /// Renders the variant with its fully qualified path, e.g.
    /// `Regularization::Ridge(0.1)`. The shared `Regularization::` prefix is
    /// written once; each arm appends only the variant-specific tail.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Regularization::")?;
        match self {
            Self::None => f.write_str("None"),
            Self::Ridge(lambda) => write!(f, "Ridge({})", lambda),
            Self::Lasso(lambda) => write!(f, "Lasso({})", lambda),
            Self::ElasticNet { l1, l2 } => {
                write!(f, "ElasticNet {{ l1: {}, l2: {} }}", l1, l2)
            }
        }
    }
}
/// Linear regression model trained online, one sample at a time, with
/// optional L1/L2 regularization. The feature dimensionality is fixed
/// lazily by the first training sample.
pub struct StreamingLinearModel {
    // Learned coefficients; empty until `lazy_init` runs on the first sample.
    weights: Vec<f64>,
    // Intercept term, updated alongside the weights on every step.
    bias: f64,
    // SGD step size; adjustable at runtime via `Tunable::adjust_config`.
    learning_rate: f64,
    // Penalty applied inside every `train_one` call.
    regularization: Regularization,
    // `Some(n)` once the first sample has fixed the feature count.
    n_features: Option<usize>,
    // Total samples consumed since construction or the last `reset`.
    samples_seen: u64,
}
impl StreamingLinearModel {
#[inline]
pub fn new(learning_rate: f64) -> Self {
Self {
weights: Vec::new(),
bias: 0.0,
learning_rate,
regularization: Regularization::None,
n_features: None,
samples_seen: 0,
}
}
#[inline]
pub fn ridge(learning_rate: f64, lambda: f64) -> Self {
Self {
weights: Vec::new(),
bias: 0.0,
learning_rate,
regularization: Regularization::Ridge(lambda),
n_features: None,
samples_seen: 0,
}
}
#[inline]
pub fn lasso(learning_rate: f64, lambda: f64) -> Self {
Self {
weights: Vec::new(),
bias: 0.0,
learning_rate,
regularization: Regularization::Lasso(lambda),
n_features: None,
samples_seen: 0,
}
}
#[inline]
pub fn elastic_net(learning_rate: f64, l1: f64, l2: f64) -> Self {
Self {
weights: Vec::new(),
bias: 0.0,
learning_rate,
regularization: Regularization::ElasticNet { l1, l2 },
n_features: None,
samples_seen: 0,
}
}
#[inline]
pub fn weights(&self) -> &[f64] {
&self.weights
}
#[inline]
pub fn bias(&self) -> f64 {
self.bias
}
#[inline]
fn lazy_init(&mut self, n_features: usize) {
if self.n_features.is_none() {
self.n_features = Some(n_features);
self.weights = vec![0.0; n_features];
}
}
#[inline]
fn apply_l1_proximal(&mut self, threshold: f64) {
for w in &mut self.weights {
let abs_w = w.abs();
if abs_w <= threshold {
*w = 0.0;
} else {
*w = w.signum() * (abs_w - threshold);
}
}
}
#[inline]
fn apply_l2_decay(&mut self, factor: f64) {
let scale = 1.0 - factor;
for w in &mut self.weights {
*w *= scale;
}
}
}
impl StreamingLearner for StreamingLinearModel {
    /// One SGD step on a single sample, scaled by the per-sample `weight`.
    ///
    /// Update order: L2 decay (ridge / elastic-net) first, then the squared-
    /// error gradient step on weights and bias, then the L1 proximal shrink
    /// (lasso / elastic-net), and finally the sample counter.
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        self.lazy_init(features.len());
        // Dot product accumulated in slice order, then the bias — the
        // floating-point evaluation order matches `predict` exactly.
        let mut dot = 0.0;
        for (w, x) in self.weights.iter().zip(features) {
            dot += w * x;
        }
        let pred = dot + self.bias;
        let error = pred - target;
        let lr = self.learning_rate;
        // Extract the L2 strength, if any, before mutating the weights.
        let l2_strength = match &self.regularization {
            Regularization::Ridge(lambda) => Some(*lambda),
            Regularization::ElasticNet { l2, .. } => Some(*l2),
            _ => None,
        };
        if let Some(l2) = l2_strength {
            self.apply_l2_decay(lr * l2);
        }
        // d/dw (pred - target)^2 = 2 * error * x, scaled by sample weight and lr.
        let step = 2.0 * error * weight * lr;
        for (w, x) in self.weights.iter_mut().zip(features) {
            *w -= step * x;
        }
        self.bias -= step;
        // L1 shrink runs after the gradient step (proximal-gradient ordering).
        let l1_strength = match &self.regularization {
            Regularization::Lasso(lambda) => Some(*lambda),
            Regularization::ElasticNet { l1, .. } => Some(*l1),
            _ => None,
        };
        if let Some(l1) = l1_strength {
            self.apply_l1_proximal(lr * l1);
        }
        self.samples_seen += 1;
    }

    /// Linear prediction: `weights . features + bias`. Zips over the shorter
    /// slice, so extra trailing features (or weights) are ignored.
    #[inline]
    fn predict(&self, features: &[f64]) -> f64 {
        let dot = self
            .weights
            .iter()
            .zip(features)
            .fold(0.0, |acc, (w, x)| acc + w * x);
        dot + self.bias
    }

    /// Number of samples consumed since construction or the last `reset`.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Returns the model to its freshly-constructed state; the learning rate
    /// and regularization settings are kept.
    fn reset(&mut self) {
        self.weights.clear();
        self.bias = 0.0;
        self.n_features = None;
        self.samples_seen = 0;
    }

    // Deprecated trait hook: forward to the `Tunable` implementation.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        crate::learner::Tunable::diagnostics_array(self)
    }

    // Deprecated trait hook: forward to the `Tunable` implementation.
    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        crate::learner::Tunable::adjust_config(self, lr_multiplier, lambda_delta);
    }
}
impl crate::learner::Tunable for StreamingLinearModel {
    /// Fixed-shape diagnostics: slots 0-2 are unused (0.0), slot 3 is the
    /// weight count, slot 4 shrinks toward 0 as more samples are seen.
    fn diagnostics_array(&self) -> [f64; 5] {
        let dof = self.weights.len() as f64;
        let uncertainty = 1.0 / (1.0 + self.samples_seen as f64);
        [0.0, 0.0, 0.0, dof, uncertainty]
    }

    /// Scales the learning rate, clamped to a 1e-12 floor (a non-positive or
    /// NaN product also lands on the floor). `_lambda_delta` is ignored.
    fn adjust_config(&mut self, lr_multiplier: f64, _lambda_delta: f64) {
        let proposed = self.learning_rate * lr_multiplier;
        self.learning_rate = if proposed > 1e-12 { proposed } else { 1e-12 };
    }
}
impl Clone for StreamingLinearModel {
fn clone(&self) -> Self {
Self {
weights: self.weights.clone(),
bias: self.bias,
learning_rate: self.learning_rate,
regularization: self.regularization.clone(),
n_features: self.n_features,
samples_seen: self.samples_seen,
}
}
}
impl fmt::Debug for StreamingLinearModel {
    /// Summarizes the model without dumping the full weight vector — only
    /// its length is shown, as `n_weights`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct("StreamingLinearModel");
        s.field("n_features", &self.n_features);
        s.field("bias", &self.bias);
        s.field("learning_rate", &self.learning_rate);
        s.field("regularization", &self.regularization);
        s.field("samples_seen", &self.samples_seen);
        s.field("n_weights", &self.weights.len());
        s.finish()
    }
}
impl crate::automl::DiagnosticSource for StreamingLinearModel {
    /// Reports the weight count as effective degrees of freedom and an
    /// uncertainty that decays toward 0 as samples accumulate; every other
    /// field keeps its default. Always `Some` for this model.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        let mut diag = crate::automl::ConfigDiagnostics::default();
        diag.effective_dof = self.weights.len() as f64;
        diag.uncertainty = 1.0 / (1.0 + self.samples_seen as f64);
        Some(diag)
    }
}
// Unit tests. Synthetic targets are built from deterministic sin/cos feature
// streams so each run is reproducible. `model.train(features, target)` is not
// defined in this file — presumably a `StreamingLearner` provided method
// wrapping `train_one` with weight 1.0; TODO confirm in `crate::learner`.
#[cfg(test)]
mod tests {
use super::*;
// Fresh model starts with no samples, zero bias, and no weights.
#[test]
fn test_creation() {
let model = StreamingLinearModel::new(0.01);
assert_eq!(model.n_samples_seen(), 0);
assert_eq!(model.bias(), 0.0);
assert!(model.weights().is_empty());
}
// The first training sample fixes the feature count and allocates weights.
#[test]
fn test_lazy_init() {
let mut model = StreamingLinearModel::new(0.01);
assert!(model.n_features.is_none());
assert!(model.weights().is_empty());
model.train(&[1.0, 2.0, 3.0], 5.0);
assert_eq!(model.n_features, Some(3));
assert_eq!(model.weights().len(), 3);
assert_eq!(model.n_samples_seen(), 1);
}
// Plain SGD recovers y = 2*x1 + 3*x2 + 1 to within a loose tolerance.
#[test]
fn test_no_regularization() {
let mut model = StreamingLinearModel::new(0.001);
for i in 0..5000 {
let x1 = (i as f64 * 0.37).sin();
let x2 = (i as f64 * 0.73).cos();
let y = 2.0 * x1 + 3.0 * x2 + 1.0;
model.train(&[x1, x2], y);
}
let w = model.weights();
assert!(
(w[0] - 2.0).abs() < 0.3,
"w[0] should be near 2.0, got {}",
w[0],
);
assert!(
(w[1] - 3.0).abs() < 0.3,
"w[1] should be near 3.0, got {}",
w[1],
);
assert!(
(model.bias() - 1.0).abs() < 0.3,
"bias should be near 1.0, got {}",
model.bias(),
);
}
// Ridge decay should shrink the weight vector's L2 norm relative to
// an unregularized model trained on the identical stream.
#[test]
fn test_ridge() {
let mut plain = StreamingLinearModel::new(0.001);
let mut ridge = StreamingLinearModel::ridge(0.001, 0.1);
for i in 0..2000 {
let x1 = (i as f64 * 0.37).sin();
let x2 = (i as f64 * 0.73).cos();
let y = 2.0 * x1 + 3.0 * x2 + 1.0;
plain.train(&[x1, x2], y);
ridge.train(&[x1, x2], y);
}
let plain_l2: f64 = plain.weights().iter().map(|w| w * w).sum();
let ridge_l2: f64 = ridge.weights().iter().map(|w| w * w).sum();
assert!(
ridge_l2 < plain_l2,
"Ridge L2 norm ({}) should be less than plain ({})",
ridge_l2,
plain_l2,
);
}
// Lasso should drive the weights of near-irrelevant features (x2, x3 are
// scaled by 0.01 and do not enter the target) to approximately zero,
// while keeping the informative weight w[0] large.
#[test]
fn test_lasso() {
let mut model = StreamingLinearModel::lasso(0.001, 0.01);
for i in 0..5000 {
let x1 = (i as f64 * 0.31).sin();
let x2 = (i as f64 * 0.97).cos() * 0.01; let x3 = (i as f64 * 0.53).sin() * 0.01; let y = 3.0 * x1;
model.train(&[x1, x2, x3], y);
}
let w = model.weights();
assert!(
w[1].abs() < 0.05,
"w[1] should be near zero (sparse), got {}",
w[1],
);
assert!(
w[2].abs() < 0.05,
"w[2] should be near zero (sparse), got {}",
w[2],
);
assert!(w[0].abs() > 0.5, "w[0] should be non-trivial, got {}", w[0],);
}
// Elastic net (L1 half + L2 half) must produce finite, non-trivial weights
// that differ measurably from a pure-ridge model with the same total budget.
#[test]
fn test_elastic_net() {
let mut ridge_only = StreamingLinearModel::ridge(0.001, 0.05);
let mut elastic = StreamingLinearModel::elastic_net(0.001, 0.025, 0.025);
for i in 0..3000 {
let x1 = (i as f64 * 0.37).sin();
let x2 = (i as f64 * 0.73).cos();
let y = 2.0 * x1 + 3.0 * x2;
ridge_only.train(&[x1, x2], y);
elastic.train(&[x1, x2], y);
}
let ew = elastic.weights();
assert!(ew[0].is_finite() && ew[0].abs() > 0.01);
assert!(ew[1].is_finite() && ew[1].abs() > 0.01);
let ridge_l2: f64 = ridge_only.weights().iter().map(|w| w * w).sum();
let elastic_l2: f64 = ew.iter().map(|w| w * w).sum();
assert!(
(ridge_l2 - elastic_l2).abs() > 1e-6,
"ElasticNet and Ridge should produce different weight norms",
);
}
// predict_batch (presumably a StreamingLearner provided method — TODO
// confirm) must agree bit-for-bit with per-row predict.
#[test]
fn test_predict_batch() {
let mut model = StreamingLinearModel::new(0.01);
for i in 0..100 {
let x = i as f64 * 0.1;
model.train(&[x, x * 2.0], x * 3.0);
}
let rows: Vec<&[f64]> = vec![&[1.0, 2.0], &[3.0, 6.0], &[0.5, 1.0]];
let batch = model.predict_batch(&rows);
assert_eq!(batch.len(), rows.len());
for (i, row) in rows.iter().enumerate() {
let individual = model.predict(row);
assert!(
(batch[i] - individual).abs() < 1e-12,
"batch[{}]={} != individual={}",
i,
batch[i],
individual,
);
}
}
// reset() must return the model to its pristine state: zero counters,
// empty weights, and zero predictions.
#[test]
fn test_reset() {
let mut model = StreamingLinearModel::new(0.01);
for i in 0..50 {
model.train(&[i as f64, (i as f64) * 2.0], i as f64);
}
assert_eq!(model.n_samples_seen(), 50);
assert!(!model.weights().is_empty());
model.reset();
assert_eq!(model.n_samples_seen(), 0);
assert!(model.weights().is_empty());
assert_eq!(model.bias(), 0.0);
let pred = model.predict(&[1.0, 2.0]);
assert!(
pred.abs() < 1e-12,
"prediction after reset should be zero, got {}",
pred,
);
}
// The sample weight scales the gradient linearly, so on a single step from
// zero state a weight of 2.0 must shift the bias exactly twice as far.
#[test]
fn test_weighted_samples() {
let mut model_w1 = StreamingLinearModel::new(0.1);
let mut model_w2 = StreamingLinearModel::new(0.1);
let features = [1.0, 0.0];
let target = 10.0;
model_w1.train_one(&features, target, 1.0);
model_w2.train_one(&features, target, 2.0);
let shift_w1 = model_w1.bias().abs();
let shift_w2 = model_w2.bias().abs();
assert!(
shift_w2 > shift_w1,
"weight=2 shift ({}) should exceed weight=1 shift ({})",
shift_w2,
shift_w1,
);
assert!(
(shift_w2 - 2.0 * shift_w1).abs() < 1e-12,
"weight=2 shift should be exactly 2x weight=1: {} vs {}",
shift_w2,
shift_w1,
);
}
// The model must be usable through Box<dyn StreamingLearner>.
#[test]
fn test_trait_object() {
let model = StreamingLinearModel::ridge(0.01, 0.001);
let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
boxed.train(&[1.0, 2.0], 5.0);
assert_eq!(boxed.n_samples_seen(), 1);
let pred = boxed.predict(&[1.0, 2.0]);
assert!(pred.is_finite());
boxed.reset();
assert_eq!(boxed.n_samples_seen(), 0);
}
}