use std::fmt;
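/// A model that learns from one labelled example at a time.
///
/// `update` performs a single online step and reports what happened;
/// `predict` scores an example without mutating the model; `n_updates` and
/// `weights` expose the training step count and the current weight vector.
/// Labels are `f64` so the same trait covers {-1, +1} classification and
/// regression targets.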
pub trait OnlineLearner {
fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError>;
fn predict(&self, features: &[f64]) -> Result<f64, OnlineError>;
fn n_updates(&self) -> usize;
fn weights(&self) -> &[f64];
}
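/// Outcome of a single `update` call: the instantaneous loss, the L2 norm of
/// the parameter change, and whether the pre-update prediction disagreed with
/// the label's sign.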
#[derive(Debug, Clone)]
pub struct OnlineUpdateResult {
pub loss: f64,
pub weight_delta_norm: f64,
pub was_mistake: bool,
}
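/// Errors shared by every learner in this module.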
#[derive(Debug)]
pub enum OnlineError {
DimensionMismatch { expected: usize, got: usize },
InvalidHyperparameter(String),
NotFitted,
}
impl fmt::Display for OnlineError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
OnlineError::DimensionMismatch { expected, got } => write!(
f,
"dimension mismatch: expected {expected} features, got {got}"
),
OnlineError::InvalidHyperparameter(msg) => {
write!(f, "invalid hyperparameter: {msg}")
}
OnlineError::NotFitted => write!(f, "model has not been fitted yet"),
}
}
}
impl std::error::Error for OnlineError {}
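/// Running aggregates over a stream of `OnlineUpdateResult`s: update counts,
/// mistake counts, and cumulative/mean loss.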
#[derive(Debug, Clone, Default)]
pub struct OnlineStats {
pub n_updates: usize,
pub n_mistakes: usize,
pub cumulative_loss: f64,
pub mean_loss: f64,
pub last_weight_norm: f64,
}
impl OnlineStats {
pub fn mistake_rate(&self) -> f64 {
if self.n_updates == 0 {
0.0
} else {
self.n_mistakes as f64 / self.n_updates as f64
}
}
pub fn update(&mut self, result: &OnlineUpdateResult) {
self.n_updates += 1;
if result.was_mistake {
self.n_mistakes += 1;
}
self.cumulative_loss += result.loss;
self.mean_loss = self.cumulative_loss / self.n_updates as f64;
}
}
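// Dense-vector helpers shared by all learners below.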
#[inline]
fn l2_norm_sq(v: &[f64]) -> f64 {
v.iter().map(|x| x * x).sum()
}
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
l2_norm_sq(v).sqrt()
}
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
}
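/// Sign function with the convention `sign(0.0) == 0.0`. Because a zero
/// score never matches a nonzero label's sign, a freshly zero-initialised
/// model always updates on its first example.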
#[inline]
fn sign(x: f64) -> f64 {
if x > 0.0 {
1.0
} else if x < 0.0 {
-1.0
} else {
0.0
}
}
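/// Classic Rosenblatt perceptron for labels in {-1, +1}.
///
/// On each mistake the weights move by `eta * y * x` and the bias by
/// `eta * y`; correct predictions leave the model untouched.
///
/// A minimal usage sketch (an `ignore`d doctest, since the enclosing crate
/// path is not known here):
///
/// ```ignore
/// let mut p = Perceptron::new(2).with_learning_rate(1.0);
/// let r = p.update(&[1.0, 0.5], 1.0).unwrap(); // first example is always a mistake
/// assert!(r.was_mistake);
/// assert_eq!(p.predict(&[1.0, 0.5]).unwrap(), 1.0);
/// ```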
#[derive(Debug, Clone)]
pub struct Perceptron {
weights: Vec<f64>,
bias: f64,
n_updates: usize,
stats: OnlineStats,
learning_rate: f64,
}
impl Perceptron {
pub fn new(n_features: usize) -> Self {
Self {
weights: vec![0.0; n_features],
bias: 0.0,
n_updates: 0,
stats: OnlineStats::default(),
learning_rate: 1.0,
}
}
pub fn with_learning_rate(mut self, lr: f64) -> Self {
self.learning_rate = lr;
self
}
pub fn bias(&self) -> f64 {
self.bias
}
pub fn stats(&self) -> &OnlineStats {
&self.stats
}
fn score(&self, features: &[f64]) -> f64 {
dot(&self.weights, features) + self.bias
}
}
impl OnlineLearner for Perceptron {
fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
let score = self.score(features);
let predicted_sign = sign(score);
let true_sign = sign(label);
let margin = true_sign * score;
let loss = if margin <= 0.0 { -margin } else { 0.0 };
let was_mistake = predicted_sign != true_sign;
let mut delta_sq = 0.0_f64;
if was_mistake {
let eta_y = self.learning_rate * true_sign;
for (w, x) in self.weights.iter_mut().zip(features.iter()) {
let delta = eta_y * x;
delta_sq += delta * delta;
*w += delta;
}
let bias_delta = self.learning_rate * true_sign;
delta_sq += bias_delta * bias_delta;
self.bias += bias_delta;
}
self.n_updates += 1;
let weight_delta_norm = delta_sq.sqrt();
let result = OnlineUpdateResult {
loss,
weight_delta_norm,
was_mistake,
};
self.stats.update(&result);
self.stats.last_weight_norm = l2_norm(&self.weights);
Ok(result)
}
fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
Ok(sign(self.score(features)))
}
fn n_updates(&self) -> usize {
self.n_updates
}
fn weights(&self) -> &[f64] {
&self.weights
}
}
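/// Update-rule variants from Crammer et al. (2006): `PA` takes the full
/// unconstrained step, `PAI` clamps the step size at C, and `PAII` softens
/// it with a quadratic slack penalty controlled by C.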
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PAVariant {
PA,
PAI,
PAII,
}
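/// Passive-aggressive classifier for labels in {-1, +1}: passive when the
/// hinge loss is zero, otherwise it takes the smallest step that restores a
/// unit margin, subject to the variant's constraint on tau.
///
/// A minimal usage sketch (an `ignore`d doctest, since the enclosing crate
/// path is not known here):
///
/// ```ignore
/// let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
///     .with_aggressiveness(0.5)
///     .unwrap();
/// pa.update(&[1.0, 0.0], 1.0).unwrap(); // hinge loss 1.0; tau clamped to C = 0.5
/// assert_eq!(pa.weights()[0], 0.5);
/// ```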
#[derive(Debug, Clone)]
pub struct PassiveAggressive {
weights: Vec<f64>,
bias: f64,
n_updates: usize,
stats: OnlineStats,
aggressiveness: f64,
variant: PAVariant,
}
impl PassiveAggressive {
pub fn new(n_features: usize, variant: PAVariant) -> Self {
Self {
weights: vec![0.0; n_features],
bias: 0.0,
n_updates: 0,
stats: OnlineStats::default(),
aggressiveness: 1.0,
variant,
}
}
pub fn with_aggressiveness(mut self, c: f64) -> Result<Self, OnlineError> {
if c <= 0.0 {
return Err(OnlineError::InvalidHyperparameter(format!(
"aggressiveness C must be > 0, got {c}"
)));
}
self.aggressiveness = c;
Ok(self)
}
pub fn stats(&self) -> &OnlineStats {
&self.stats
}
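    /// Closed-form step size tau from Crammer et al. (2006), "Online
    /// Passive-Aggressive Algorithms":
    ///
    ///   PA:    tau = loss / ||x||^2
    ///   PA-I:  tau = min(C, loss / ||x||^2)
    ///   PA-II: tau = loss / (||x||^2 + 1 / (2C))
    ///
    /// A zero feature vector yields tau = 0 to avoid division by zero.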
fn compute_tau(&self, loss: f64, x_norm_sq: f64) -> f64 {
match self.variant {
PAVariant::PA => {
if x_norm_sq == 0.0 {
0.0
} else {
loss / x_norm_sq
}
}
PAVariant::PAI => {
let tau_unconstrained = if x_norm_sq == 0.0 {
0.0
} else {
loss / x_norm_sq
};
tau_unconstrained.min(self.aggressiveness)
}
PAVariant::PAII => {
let denom = x_norm_sq + 1.0 / (2.0 * self.aggressiveness);
if denom == 0.0 {
0.0
} else {
loss / denom
}
}
}
}
}
impl OnlineLearner for PassiveAggressive {
fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
let score = dot(&self.weights, features) + self.bias;
let y = sign(label);
let margin = y * score;
let loss = (1.0 - margin).max(0.0);
let was_mistake = sign(score) != y;
let x_norm_sq = l2_norm_sq(features);
let tau = self.compute_tau(loss, x_norm_sq);
let mut delta_sq = 0.0_f64;
if tau > 0.0 {
let tau_y = tau * y;
for (w, x) in self.weights.iter_mut().zip(features.iter()) {
let delta = tau_y * x;
delta_sq += delta * delta;
*w += delta;
}
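            // The bias is treated as a constant feature of 1 updated with the
            // same tau; note that this implicit feature is not included in
            // x_norm_sq when tau is computed.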
let bias_delta = tau * y;
delta_sq += bias_delta * bias_delta;
self.bias += bias_delta;
}
self.n_updates += 1;
let result = OnlineUpdateResult {
loss,
weight_delta_norm: delta_sq.sqrt(),
was_mistake,
};
self.stats.update(&result);
self.stats.last_weight_norm = l2_norm(&self.weights);
Ok(result)
}
fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
Ok(sign(dot(&self.weights, features) + self.bias))
}
fn n_updates(&self) -> usize {
self.n_updates
}
fn weights(&self) -> &[f64] {
&self.weights
}
}
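/// Loss functions for `OnlineGradientDescent`. `Squared` treats the task as
/// regression on raw labels; `Hinge` and `Logistic` treat labels as {-1, +1}.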
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OGDLoss {
Squared,
Hinge,
Logistic,
}
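/// Plain online (stochastic) gradient descent with optional L2
/// regularisation and an optional 1/sqrt(t) learning-rate schedule.
///
/// A usage sketch (an `ignore`d doctest, since the enclosing crate path is
/// not known here):
///
/// ```ignore
/// let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.3);
/// for _ in 0..200 {
///     ogd.update(&[1.0], 2.0).unwrap();
/// }
/// // The prediction approaches the constant target 2.0.
/// ```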
#[derive(Debug, Clone)]
pub struct OnlineGradientDescent {
weights: Vec<f64>,
bias: f64,
n_updates: usize,
stats: OnlineStats,
initial_lr: f64,
lr_decay: f64,
l2_reg: f64,
loss: OGDLoss,
}
impl OnlineGradientDescent {
pub fn new(n_features: usize, loss: OGDLoss) -> Self {
Self {
weights: vec![0.0; n_features],
bias: 0.0,
n_updates: 0,
stats: OnlineStats::default(),
initial_lr: 0.1,
lr_decay: 0.0,
l2_reg: 0.0,
loss,
}
}
pub fn with_lr(mut self, lr: f64) -> Self {
self.initial_lr = lr;
self
}
pub fn with_l2(mut self, lambda: f64) -> Self {
self.l2_reg = lambda;
self
}
pub fn with_lr_decay(mut self, decay: f64) -> Self {
self.lr_decay = decay;
self
}
pub fn stats(&self) -> &OnlineStats {
&self.stats
}
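    /// Learning rate for the upcoming update. Any positive `lr_decay`
    /// switches on an `initial_lr / sqrt(t + 1)` schedule; note that the
    /// magnitude of `lr_decay` is otherwise unused and only acts as a flag.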
fn current_lr(&self) -> f64 {
if self.lr_decay > 0.0 {
self.initial_lr / ((self.n_updates as f64 + 1.0).sqrt())
} else {
self.initial_lr
}
}
fn compute_loss_and_grad(&self, features: &[f64], label: f64) -> (f64, f64, f64) {
let score = dot(&self.weights, features) + self.bias;
match self.loss {
OGDLoss::Squared => {
let diff = score - label;
let loss = 0.5 * diff * diff;
(loss, diff, diff)
}
OGDLoss::Hinge => {
let y = sign(label);
let margin = y * score;
if margin < 1.0 {
let loss = 1.0 - margin;
(loss, -y, -y)
} else {
(0.0, 0.0, 0.0)
}
}
OGDLoss::Logistic => {
let y = sign(label);
let ys = y * score;
                // sigma(-y*s) = 1 / (1 + e^{y*s}); log loss = ln(1 + e^{-y*s}).
                let sigma_neg = 1.0 / (1.0 + ys.exp());
                let loss = (1.0 + (-ys).exp()).ln();
let grad_coeff = -y * sigma_neg;
(loss, grad_coeff, grad_coeff)
}
}
}
}
impl OnlineLearner for OnlineGradientDescent {
fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
let (loss, grad_coeff, bias_grad) = self.compute_loss_and_grad(features, label);
let eta = self.current_lr();
let was_mistake = match self.loss {
            OGDLoss::Squared => false,
            OGDLoss::Hinge | OGDLoss::Logistic => {
                // Pre-update prediction; the weights have not been modified yet.
let score = dot(&self.weights, features) + self.bias;
sign(score) != sign(label)
}
};
let mut delta_sq = 0.0_f64;
for (w, x) in self.weights.iter_mut().zip(features.iter()) {
let grad = grad_coeff * x + self.l2_reg * (*w);
let delta = -eta * grad;
delta_sq += delta * delta;
*w += delta;
}
let bias_delta = -eta * bias_grad;
delta_sq += bias_delta * bias_delta;
self.bias += bias_delta;
self.n_updates += 1;
let result = OnlineUpdateResult {
loss,
weight_delta_norm: delta_sq.sqrt(),
was_mistake,
};
self.stats.update(&result);
self.stats.last_weight_norm = l2_norm(&self.weights);
Ok(result)
}
fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
let score = dot(&self.weights, features) + self.bias;
let prediction = match self.loss {
OGDLoss::Squared => score,
OGDLoss::Hinge | OGDLoss::Logistic => sign(score),
};
Ok(prediction)
}
fn n_updates(&self) -> usize {
self.n_updates
}
fn weights(&self) -> &[f64] {
&self.weights
}
}
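/// FTRL-Proximal (McMahan et al., 2013) with per-coordinate adaptive
/// learning rates and L1/L2 regularisation; the L1 term produces exact
/// zeros, so the weight vector stays sparse.
///
/// A usage sketch (an `ignore`d doctest, since the enclosing crate path is
/// not known here):
///
/// ```ignore
/// let mut ftrl = Ftrl::new(2).with_alpha(0.1).with_l1_l2(1.0, 0.0);
/// ftrl.update(&[1.0, 0.0], 1.0).unwrap();
/// // Weights stay exactly 0.0 until |z_i| exceeds the L1 threshold.
/// ```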
#[derive(Debug, Clone)]
pub struct Ftrl {
weights: Vec<f64>,
z: Vec<f64>,
n_vec: Vec<f64>,
n_updates: usize,
stats: OnlineStats,
alpha: f64,
beta: f64,
l1: f64,
l2: f64,
}
impl Ftrl {
pub fn new(n_features: usize) -> Self {
Self {
weights: vec![0.0; n_features],
z: vec![0.0; n_features],
n_vec: vec![0.0; n_features],
n_updates: 0,
stats: OnlineStats::default(),
alpha: 0.1,
beta: 1.0,
l1: 0.0,
l2: 0.0,
}
}
pub fn with_alpha(mut self, alpha: f64) -> Self {
self.alpha = alpha;
self
}
pub fn with_l1_l2(mut self, l1: f64, l2: f64) -> Self {
self.l1 = l1;
self.l2 = l2;
self
}
pub fn stats(&self) -> &OnlineStats {
&self.stats
}
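    /// Closed-form FTRL-Proximal weight for coordinate `i`:
    /// `w_i = 0` when `|z_i| <= l1`, otherwise
    /// `w_i = -(z_i - sign(z_i) * l1) / ((beta + sqrt(n_i)) / alpha + l2)`.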
#[inline]
fn compute_weight(&self, i: usize) -> f64 {
let z_i = self.z[i];
let n_i = self.n_vec[i];
if z_i.abs() <= self.l1 {
0.0
} else {
let numerator = -(z_i - sign(z_i) * self.l1);
let denominator = (self.beta + n_i.sqrt()) / self.alpha + self.l2;
if denominator == 0.0 {
0.0
} else {
numerator / denominator
}
}
}
fn score(&self, features: &[f64]) -> f64 {
features
.iter()
.enumerate()
.map(|(i, x)| self.compute_weight(i) * x)
.sum::<f64>()
}
#[inline]
fn sigmoid(s: f64) -> f64 {
1.0 / (1.0 + (-s).exp())
}
}
impl OnlineLearner for Ftrl {
fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
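        // Refresh the cached weights from the (z, n) accumulators before
        // scoring this example.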
for i in 0..n {
self.weights[i] = self.compute_weight(i);
}
let score = dot(&self.weights, features);
let p = Self::sigmoid(score);
        // Map the label to {0, 1}: both {0, 1} and {-1, +1} conventions work.
        let y_01 = if label > 0.0 { 1.0_f64 } else { 0.0_f64 };
        // Gradient of the log loss with respect to the score is p - y.
        let grad_scale = p - y_01;
        // Log loss, clipped so a saturated sigmoid (p == 0.0 or 1.0) stays finite.
        let loss = if y_01 > 0.0 {
            (-p.ln()).min(1e15)
        } else {
            (-(1.0 - p).ln()).min(1e15)
        };
        // `label - 0.5` is positive for the positive class in either convention.
        let was_mistake = sign(score) != sign(label - 0.5);
let old_weights = self.weights.clone();
        // FTRL-Proximal accumulator update: with g_i the loss gradient for
        // coordinate i, sigma_i = (sqrt(n_i + g_i^2) - sqrt(n_i)) / alpha and
        // z_i += g_i - sigma_i * w_i (using the pre-update weight w_i).
        for (i, &feat_i) in features.iter().enumerate() {
let g_i = grad_scale * feat_i;
let n_i_old = self.n_vec[i];
let n_i_new = n_i_old + g_i * g_i;
let sigma_i = (n_i_new.sqrt() - n_i_old.sqrt()) / self.alpha;
self.z[i] += g_i - sigma_i * self.weights[i];
self.n_vec[i] = n_i_new;
self.weights[i] = self.compute_weight(i);
}
let delta_norm = {
let sq: f64 = self
.weights
.iter()
.zip(old_weights.iter())
.map(|(w_new, w_old)| {
let d = w_new - w_old;
d * d
})
.sum();
sq.sqrt()
};
self.n_updates += 1;
let result = OnlineUpdateResult {
loss,
weight_delta_norm: delta_norm,
was_mistake,
};
self.stats.update(&result);
self.stats.last_weight_norm = l2_norm(&self.weights);
Ok(result)
}
fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
let n = self.weights.len();
if features.len() != n {
return Err(OnlineError::DimensionMismatch {
expected: n,
got: features.len(),
});
}
let score = self.score(features);
Ok(sign(score))
}
fn n_updates(&self) -> usize {
self.n_updates
}
fn weights(&self) -> &[f64] {
&self.weights
}
}
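/// Progressive validation over `data`: each example is predicted *before* it
/// is (optionally) used for training, so the returned predictions and stats
/// estimate online generalisation rather than training fit.
///
/// With `train == false` the learner is never mutated and each pseudo-result
/// records a loss of 0.0, so only the mistake rate of the returned
/// `OnlineStats` is meaningful in that mode.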
pub fn online_evaluate(
learner: &mut dyn OnlineLearner,
data: &[(Vec<f64>, f64)],
train: bool,
) -> Result<(Vec<f64>, OnlineStats), OnlineError> {
let mut predictions = Vec::with_capacity(data.len());
let mut stats = OnlineStats::default();
for (features, label) in data {
let pred = learner.predict(features)?;
predictions.push(pred);
if train {
let result = learner.update(features, *label)?;
stats.update(&result);
} else {
let was_mistake = sign(pred) != sign(*label);
let pseudo_result = OnlineUpdateResult {
loss: 0.0,
weight_delta_norm: 0.0,
was_mistake,
};
stats.update(&pseudo_result);
}
}
Ok((predictions, stats))
}
#[cfg(test)]
mod tests {
use super::*;
fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
(a - b).abs() < tol
}
#[test]
fn test_perceptron_zero_init() {
let p = Perceptron::new(4);
assert_eq!(p.weights(), &[0.0_f64; 4]);
assert_eq!(p.bias(), 0.0);
assert_eq!(p.n_updates(), 0);
}
#[test]
fn test_perceptron_update_on_mistake_positive() {
let mut p = Perceptron::new(2).with_learning_rate(1.0);
let x = vec![1.0, 0.5];
let result = p.update(&x, 1.0).expect("update failed");
assert!(result.was_mistake);
assert!(approx_eq(p.weights()[0], 1.0, 1e-10));
assert!(approx_eq(p.weights()[1], 0.5, 1e-10));
assert!(approx_eq(p.bias(), 1.0, 1e-10));
}
#[test]
fn test_perceptron_no_update_on_correct() {
let mut p = Perceptron::new(2);
let x = vec![1.0, 0.0];
p.update(&x, 1.0).expect("update");
let w_after_first = p.weights().to_vec();
p.update(&x, 1.0).expect("update");
assert_eq!(p.weights(), w_after_first.as_slice());
}
#[test]
fn test_perceptron_linearly_separable_2d() {
let data: Vec<(Vec<f64>, f64)> = vec![
(vec![1.0, 0.2], 1.0),
(vec![-1.0, 0.3], -1.0),
(vec![2.0, -0.5], 1.0),
(vec![-2.0, 0.1], -1.0),
(vec![0.5, 0.5], 1.0),
(vec![-0.5, -0.5], -1.0),
(vec![1.5, -0.1], 1.0),
(vec![-1.5, 0.4], -1.0),
(vec![0.8, 0.0], 1.0),
(vec![-0.8, 0.2], -1.0),
];
let mut p = Perceptron::new(2);
for _ in 0..20 {
for (x, y) in &data {
p.update(x, *y).expect("update");
}
}
for (x, y) in &data {
let pred = p.predict(x).expect("predict");
assert_eq!(pred, *y, "misclassified {:?} (label {})", x, y);
}
}
#[test]
fn test_perceptron_n_updates_increments() {
let mut p = Perceptron::new(2);
for i in 0..5 {
p.update(&[1.0, -1.0], 1.0).expect("update");
assert_eq!(p.n_updates(), i + 1);
}
}
#[test]
fn test_perceptron_dimension_mismatch() {
let mut p = Perceptron::new(3);
let err = p.update(&[1.0, 2.0], 1.0);
assert!(matches!(
err,
Err(OnlineError::DimensionMismatch {
expected: 3,
got: 2
})
));
}
#[test]
fn test_pa_tau_basic() {
let mut pa = PassiveAggressive::new(2, PAVariant::PA);
let result = pa.update(&[1.0, 0.0], 1.0).expect("update");
assert!(approx_eq(result.loss, 1.0, 1e-10));
assert!(approx_eq(pa.weights()[0], 1.0, 1e-10));
}
#[test]
fn test_pa1_tau_clamped() {
let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
.with_aggressiveness(0.3)
.expect("valid C");
let _r = pa.update(&[1.0, 0.0], 1.0).expect("update");
assert!(approx_eq(pa.weights()[0], 0.3, 1e-10));
}
#[test]
fn test_pa2_tau_formula() {
let mut pa = PassiveAggressive::new(2, PAVariant::PAII)
.with_aggressiveness(1.0)
.expect("valid C");
let _r = pa.update(&[1.0, 0.0], 1.0).expect("update");
let expected_tau = 1.0 / 1.5;
assert!(
approx_eq(pa.weights()[0], expected_tau, 1e-10),
"expected {expected_tau}, got {}",
pa.weights()[0]
);
}
#[test]
fn test_pa_negative_c_returns_err() {
let res = PassiveAggressive::new(2, PAVariant::PA).with_aggressiveness(-1.0);
assert!(res.is_err());
}
#[test]
fn test_pa_dimension_mismatch() {
let mut pa = PassiveAggressive::new(3, PAVariant::PA);
let err = pa.update(&[1.0], 1.0);
assert!(matches!(
err,
Err(OnlineError::DimensionMismatch {
expected: 3,
got: 1
})
));
}
#[test]
fn test_ogd_squared_loss_gradient() {
let mut ogd = OnlineGradientDescent::new(2, OGDLoss::Squared).with_lr(0.1);
let result = ogd.update(&[2.0, 0.0], 3.0).expect("update");
assert!(approx_eq(result.loss, 4.5, 1e-10));
assert!(approx_eq(ogd.weights()[0], 0.6, 1e-10));
}
#[test]
fn test_ogd_hinge_no_update_when_margin_ok() {
let mut ogd = OnlineGradientDescent::new(2, OGDLoss::Hinge).with_lr(1.0);
for _ in 0..20 {
ogd.update(&[10.0, 0.0], 1.0).expect("update");
}
let w_before = ogd.weights().to_vec();
let result = ogd.update(&[10.0, 0.0], 1.0).expect("update");
assert_eq!(result.loss, 0.0, "expected zero hinge loss");
assert_eq!(result.weight_delta_norm, 0.0);
assert_eq!(ogd.weights(), w_before.as_slice());
}
#[test]
fn test_ogd_lr_decay_reduces_lr() {
let mut ogd_decay = OnlineGradientDescent::new(1, OGDLoss::Squared)
.with_lr(1.0)
.with_lr_decay(1.0);
let mut ogd_nodecay = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(1.0);
for _ in 0..5 {
ogd_decay.update(&[0.0], 1.0).expect("update");
ogd_nodecay.update(&[0.0], 1.0).expect("update");
}
assert!(
ogd_decay.bias.abs() <= ogd_nodecay.bias.abs() + 1e-9,
"decaying lr should not exceed constant lr convergence; decay_bias={}, nodecay_bias={}",
ogd_decay.bias,
ogd_nodecay.bias
);
let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared)
.with_lr(1.0)
.with_lr_decay(1.0);
for _ in 0..9 {
            ogd.update(&[0.0], 0.0).expect("update");
        }
let lr_at_t9 = ogd.current_lr();
assert!(
lr_at_t9 < 0.5,
"lr at t=9 should be 1/√10 ≈ 0.316, got {lr_at_t9}"
);
assert!(
approx_eq(lr_at_t9, 1.0 / 10_f64.sqrt(), 1e-10),
"expected 1/√10, got {lr_at_t9}"
);
}
#[test]
fn test_ogd_l2_penalises_large_weights() {
let mut ogd_no_reg = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.5);
let mut ogd_l2 = OnlineGradientDescent::new(1, OGDLoss::Squared)
.with_lr(0.5)
.with_l2(0.5);
for _ in 0..30 {
ogd_no_reg.update(&[1.0], 1.0).expect("update");
ogd_l2.update(&[1.0], 1.0).expect("update");
}
assert!(
ogd_l2.weights()[0].abs() < ogd_no_reg.weights()[0].abs(),
"l2 reg should shrink weights; no_reg={}, l2={}",
ogd_no_reg.weights()[0],
ogd_l2.weights()[0]
);
}
#[test]
fn test_ogd_dimension_mismatch() {
let mut ogd = OnlineGradientDescent::new(3, OGDLoss::Squared);
let err = ogd.update(&[1.0, 2.0], 0.0);
assert!(matches!(
err,
Err(OnlineError::DimensionMismatch {
expected: 3,
got: 2
})
));
}
#[test]
fn test_ftrl_l1_sparsity() {
let mut ftrl = Ftrl::new(2).with_alpha(0.1).with_l1_l2(10.0, 0.0);
ftrl.update(&[1.0, 0.0], 1.0).expect("update");
assert_eq!(ftrl.weights()[0], 0.0, "weight should be zero due to L1");
}
#[test]
fn test_ftrl_adaptive_per_feature() {
let mut ftrl = Ftrl::new(2).with_alpha(0.1);
for _ in 0..50 {
ftrl.update(&[1.0, 0.0], 1.0).expect("update");
}
assert!(ftrl.n_vec[0] > ftrl.n_vec[1]);
}
#[test]
fn test_ftrl_l1_zero_l2_zero_adagrad_like() {
let mut ftrl = Ftrl::new(1).with_alpha(1.0).with_l1_l2(0.0, 0.0);
for _ in 0..10 {
ftrl.update(&[1.0], 1.0).expect("update");
}
assert!(
ftrl.weights()[0] > 0.0,
"weight should be positive; got {}",
ftrl.weights()[0]
);
}
#[test]
fn test_ftrl_dimension_mismatch() {
let mut ftrl = Ftrl::new(3);
let err = ftrl.update(&[1.0, 2.0], 1.0);
assert!(matches!(
err,
Err(OnlineError::DimensionMismatch {
expected: 3,
got: 2
})
));
}
#[test]
fn test_ftrl_predict_dimension_mismatch() {
let ftrl = Ftrl::new(3);
let err = ftrl.predict(&[1.0]);
assert!(matches!(
err,
Err(OnlineError::DimensionMismatch {
expected: 3,
got: 1
})
));
}
#[test]
fn test_online_stats_mistake_rate_zero_updates() {
let stats = OnlineStats::default();
assert_eq!(stats.mistake_rate(), 0.0);
}
#[test]
fn test_online_stats_mistake_rate_computation() {
let mut stats = OnlineStats::default();
let mistake = OnlineUpdateResult {
loss: 1.0,
weight_delta_norm: 0.5,
was_mistake: true,
};
let correct = OnlineUpdateResult {
loss: 0.0,
weight_delta_norm: 0.0,
was_mistake: false,
};
stats.update(&mistake);
stats.update(&correct);
stats.update(&mistake);
assert!(approx_eq(stats.mistake_rate(), 2.0 / 3.0, 1e-10));
}
#[test]
fn test_online_stats_cumulative_loss() {
let mut stats = OnlineStats::default();
for loss_val in [0.5, 1.0, 1.5] {
let r = OnlineUpdateResult {
loss: loss_val,
weight_delta_norm: 0.0,
was_mistake: false,
};
stats.update(&r);
}
assert!(approx_eq(stats.cumulative_loss, 3.0, 1e-10));
assert!(approx_eq(stats.mean_loss, 1.0, 1e-10));
}
#[test]
fn test_online_evaluate_train_true_updates_model() {
let mut p = Perceptron::new(2);
let data = vec![(vec![1.0, 0.0], 1.0), (vec![-1.0, 0.0], -1.0)];
let (preds, _stats) = online_evaluate(&mut p, &data, true).expect("evaluate");
assert_eq!(preds.len(), 2);
assert_eq!(p.n_updates(), 2);
}
#[test]
fn test_online_evaluate_train_false_no_update() {
let mut p = Perceptron::new(2);
let data = vec![(vec![1.0, 0.0], 1.0), (vec![-1.0, 0.0], -1.0)];
let (preds, _stats) = online_evaluate(&mut p, &data, false).expect("evaluate");
assert_eq!(preds.len(), 2);
assert_eq!(p.n_updates(), 0);
}
#[test]
fn test_perceptron_converges_linearly_separable_10_samples() {
let data: Vec<(Vec<f64>, f64)> = vec![
(vec![2.0, 1.0], 1.0),
(vec![1.5, 0.8], 1.0),
(vec![1.0, 0.5], 1.0),
(vec![0.5, 0.2], 1.0),
(vec![0.2, 0.1], 1.0),
(vec![-0.2, -0.1], -1.0),
(vec![-0.5, -0.3], -1.0),
(vec![-1.0, -0.5], -1.0),
(vec![-1.5, -0.7], -1.0),
(vec![-2.0, -1.0], -1.0),
];
let mut p = Perceptron::new(2);
for _ in 0..50 {
for (x, y) in &data {
p.update(x, *y).expect("update");
}
}
let mut correct = 0;
for (x, y) in &data {
let pred = p.predict(x).expect("predict");
if pred == *y {
correct += 1;
}
}
assert_eq!(
correct, 10,
"Perceptron should converge on linearly separable data"
);
}
#[test]
fn test_pa_converges_linearly_separable() {
let data: Vec<(Vec<f64>, f64)> = vec![
(vec![1.0, 0.5], 1.0),
(vec![-1.0, -0.5], -1.0),
(vec![2.0, 1.0], 1.0),
(vec![-2.0, -1.0], -1.0),
];
let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
.with_aggressiveness(1.0)
.expect("valid C");
for _ in 0..30 {
for (x, y) in &data {
pa.update(x, *y).expect("update");
}
}
for (x, y) in &data {
let pred = pa.predict(x).expect("predict");
assert_eq!(pred, *y);
}
}
#[test]
fn test_ogd_squared_converges_to_constant() {
let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.3);
let x = vec![1.0];
for _ in 0..200 {
ogd.update(&x, 2.0).expect("update");
}
let pred = ogd.predict(&x).expect("predict");
assert!(
approx_eq(pred, 2.0, 0.1),
"OGD should converge near 2.0, got {pred}"
);
}
#[test]
fn test_ftrl_n_updates_increments() {
let mut ftrl = Ftrl::new(2);
for i in 0..7 {
ftrl.update(&[1.0, 0.5], 1.0).expect("update");
assert_eq!(ftrl.n_updates(), i + 1);
}
}
#[test]
fn test_online_error_display() {
let e = OnlineError::DimensionMismatch {
expected: 5,
got: 3,
};
let s = e.to_string();
assert!(s.contains("5") && s.contains("3"));
let e2 = OnlineError::InvalidHyperparameter("C must be positive".to_string());
assert!(e2.to_string().contains("C must be positive"));
let e3 = OnlineError::NotFitted;
assert!(e3.to_string().contains("fitted"));
}
}