use crate::defaults::linear as linear_defaults;
use crate::learner::WeakLearner;
use crate::{Result, TreeBoostError};
use rkyv::{Archive, Deserialize, Serialize};
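/// Configuration for the elastic-net regularized linear weak learner.
///
/// `lambda` is split into L1 and L2 penalties by `l1_ratio`; see
/// [`LinearConfig::l1_penalty`] and [`LinearConfig::l2_penalty`].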
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct LinearConfig {
    /// Overall regularization strength; split between L1 and L2 by `l1_ratio`.
    pub lambda: f32,
    /// Mix between L2 (0.0) and L1 (1.0) penalties.
    pub l1_ratio: f32,
    /// Shrinkage multiplier for this learner's contribution. Stored in the
    /// config but not applied inside `LinearBooster` itself; presumably
    /// consumed by the surrounding boosting loop.
    pub shrinkage_factor: f32,
    /// Maximum number of coordinate-descent sweeps.
    pub max_iter: usize,
    /// Convergence tolerance on the largest per-sweep weight change.
    pub tol: f32,
    /// Absolute clamp on individual weights and the bias.
    pub max_weight: f32,
    /// Blend factor pulling predictions toward the target mean (0.0 = off).
    pub extrapolation_damping: f32,
}
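/// Named presets that override a subset of [`LinearConfig`] fields.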
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinearPreset {
    /// Default lambda with the default L1 ratio.
    Ridge,
    /// Default lambda with the lasso L1 ratio; encourages sparse weights.
    Lasso,
    /// Default lambda with the elastic-net L1 ratio.
    ElasticNet,
    /// Overrides only the shrinkage factor with the aggressive default.
    Aggressive,
    /// Overrides only the shrinkage factor with the conservative default.
    Conservative,
    /// Ridge defaults plus the safe extrapolation-damping default.
    SafeRidge,
}
impl Default for LinearConfig {
fn default() -> Self {
Self {
            lambda: linear_defaults::DEFAULT_LAMBDA,
            l1_ratio: linear_defaults::DEFAULT_L1_RATIO,
            shrinkage_factor: linear_defaults::DEFAULT_SHRINKAGE_FACTOR,
            max_iter: linear_defaults::DEFAULT_MAX_ITER,
            tol: linear_defaults::DEFAULT_TOL,
            max_weight: linear_defaults::DEFAULT_MAX_WEIGHT,
            extrapolation_damping: linear_defaults::DEFAULT_EXTRAPOLATION_DAMPING,
        }
}
}
impl LinearConfig {
pub fn new() -> Self {
Self::default()
}
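    /// Applies a preset on top of the current configuration. Only the fields
    /// covered by the preset are overwritten; everything else is left as-is.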
pub fn with_preset(mut self, preset: LinearPreset) -> Self {
match preset {
LinearPreset::Ridge => {
self.lambda = linear_defaults::DEFAULT_LAMBDA;
self.l1_ratio = linear_defaults::DEFAULT_L1_RATIO;
}
LinearPreset::Lasso => {
self.lambda = linear_defaults::DEFAULT_LAMBDA;
self.l1_ratio = linear_defaults::LASSO_L1_RATIO;
}
LinearPreset::ElasticNet => {
self.lambda = linear_defaults::DEFAULT_LAMBDA;
self.l1_ratio = linear_defaults::ELASTIC_NET_L1_RATIO;
}
LinearPreset::Aggressive => {
self.shrinkage_factor = linear_defaults::AGGRESSIVE_SHRINKAGE;
}
LinearPreset::Conservative => {
self.shrinkage_factor = linear_defaults::CONSERVATIVE_SHRINKAGE;
}
LinearPreset::SafeRidge => {
self.lambda = linear_defaults::DEFAULT_LAMBDA;
self.l1_ratio = linear_defaults::DEFAULT_L1_RATIO;
self.extrapolation_damping = linear_defaults::SAFE_EXTRAPOLATION_DAMPING;
}
}
self
}
pub fn with_lambda(mut self, lambda: f32) -> Self {
self.lambda = lambda.max(1e-6);
self
}
pub fn with_l1_ratio(mut self, l1_ratio: f32) -> Self {
self.l1_ratio = l1_ratio.clamp(0.0, 1.0);
self
}
pub fn with_shrinkage_factor(mut self, shrinkage: f32) -> Self {
self.shrinkage_factor = shrinkage.clamp(1e-6, 1.0);
self
}
pub fn with_max_iter(mut self, max_iter: usize) -> Self {
self.max_iter = max_iter.max(1);
self
}
pub fn with_tol(mut self, tol: f32) -> Self {
self.tol = tol.max(1e-10);
self
}
pub fn with_max_weight(mut self, max_weight: f32) -> Self {
self.max_weight = max_weight.max(1.0);
self
}
pub fn with_extrapolation_damping(mut self, damping: f32) -> Self {
self.extrapolation_damping = damping.clamp(0.0, 1.0);
self
}
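    /// Effective L2 penalty: `lambda * (1 - l1_ratio)`.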
#[inline]
pub fn l2_penalty(&self) -> f32 {
self.lambda * (1.0 - self.l1_ratio)
}
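    /// Effective L1 penalty: `lambda * l1_ratio`.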
#[inline]
pub fn l1_penalty(&self) -> f32 {
self.lambda * self.l1_ratio
}
}
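/// A linear weak learner fit against Newton-style gradient/hessian targets.
///
/// Features are standardized by a scaler that is fit once (on the first call
/// to `fit_on_gradients`) and then frozen, so later warm fits see a consistent
/// feature space. A minimal usage sketch (row-major features; for plain least
/// squares the gradients are the negated targets and the hessians are 1.0;
/// the data bindings here are illustrative):
///
/// ```ignore
/// let config = LinearConfig::default().with_preset(LinearPreset::Ridge);
/// let mut booster = LinearBooster::new(2, config);
/// booster.fit_on_gradients(&features, 2, &gradients, &hessians)?;
/// let predictions = booster.predict_batch(&features, 2);
/// ```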
#[derive(Debug, Clone, Archive, Serialize, Deserialize, serde::Serialize, serde::Deserialize)]
pub struct LinearBooster {
    /// Per-feature weights, learned in standardized feature space.
    weights: Vec<f32>,
    /// Intercept term.
    bias: f32,
    /// Per-feature means used for standardization.
    means: Vec<f32>,
    /// Per-feature standard deviations used for standardization.
    stds: Vec<f32>,
    config: LinearConfig,
    num_features: usize,
    /// Whether `fit_scaler` has run; the scaler is fit once and then frozen.
    scaler_fitted: bool,
    /// Mean prediction level used as the anchor for extrapolation damping.
    target_mean: f32,
    /// Cumulative coordinate-descent sweeps across fits.
    #[serde(default)]
    iterations_completed: usize,
}
impl LinearBooster {
pub fn new(num_features: usize, config: LinearConfig) -> Self {
Self {
weights: vec![0.0; num_features],
bias: 0.0,
means: vec![0.0; num_features],
stds: vec![1.0; num_features],
config,
num_features,
scaler_fitted: false,
target_mean: 0.0,
iterations_completed: 0,
}
}
pub fn weights(&self) -> &[f32] {
&self.weights
}
pub fn bias(&self) -> f32 {
self.bias
}
pub fn config(&self) -> &LinearConfig {
&self.config
}
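    /// Computes per-feature means and standard deviations, skipping non-finite
    /// values. Near-constant features fall back to a std of 1.0 so that
    /// standardization never divides by (almost) zero.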
fn fit_scaler(&mut self, features: &[f32], num_features: usize) {
let num_rows = features.len() / num_features;
if num_rows == 0 {
return;
}
for j in 0..num_features {
let mut sum = 0.0f64;
let mut sum_sq = 0.0f64;
let mut count = 0usize;
for i in 0..num_rows {
let val = features[i * num_features + j] as f64;
if val.is_finite() {
sum += val;
sum_sq += val * val;
count += 1;
}
}
if count > 0 {
let mean = sum / count as f64;
let variance = (sum_sq / count as f64) - mean * mean;
let std = variance.max(0.0).sqrt();
self.means[j] = mean as f32;
self.stds[j] = if std > 1e-10 { std as f32 } else { 1.0 };
}
}
self.scaler_fitted = true;
}
#[inline]
fn standardize(&self, value: f32, feature_idx: usize) -> f32 {
(value - self.means[feature_idx]) / self.stds[feature_idx]
}
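    /// The soft-thresholding operator `S(x, t) = sign(x) * max(|x| - t, 0)`,
    /// the proximal step for the L1 penalty.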
#[inline]
fn soft_threshold(x: f32, threshold: f32) -> f32 {
if x > threshold {
x - threshold
} else if x < -threshold {
x + threshold
} else {
0.0
}
}
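    /// Runs cyclic coordinate descent on the hessian-weighted elastic-net
    /// problem. Rows are fit against their Newton targets `-g_i / h_i`;
    /// `update_bias` selects whether the intercept is re-estimated (fresh
    /// fits) or kept (warm fits). Returns the number of sweeps performed.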
fn coordinate_descent(
&mut self,
features: &[f32],
num_features: usize,
gradients: &[f32],
hessians: &[f32],
update_bias: bool,
) -> usize {
let num_rows = gradients.len();
if num_rows == 0 {
return 0;
}
let l2_penalty = self.config.l2_penalty();
let l1_penalty = self.config.l1_penalty();
        // Per-row Newton targets: the quadratic approximation of the loss is
        // minimized at -g_i / h_i.
        let mut residuals = vec![0.0f32; num_rows];
        for i in 0..num_rows {
            residuals[i] = -gradients[i] / hessians[i].max(1e-10);
        }
        // Subtract the contribution of any existing weights so warm starts
        // resume from the current solution instead of double-counting it
        // (for a fresh fit all weights are zero and this is a no-op).
        for j in 0..num_features {
            if self.weights[j] != 0.0 {
                for i in 0..num_rows {
                    let x_ij = self.standardize(features[i * num_features + j], j);
                    residuals[i] -= self.weights[j] * x_ij;
                }
            }
        }
        // Hessian-weighted squared column norms: the per-coordinate curvature.
        let mut x_sq_sums = vec![0.0f32; num_features];
        for j in 0..num_features {
            for i in 0..num_rows {
                let x_ij = self.standardize(features[i * num_features + j], j);
                x_sq_sums[j] += hessians[i] * x_ij * x_ij;
            }
        }
        if update_bias {
            // Hessian-weighted mean residual, i.e. the Newton step -G / H.
            let weighted_residual_sum: f32 = residuals
                .iter()
                .zip(hessians.iter())
                .map(|(&r, &h)| r * h)
                .sum();
            let sum_hessians: f32 = hessians.iter().sum();
            self.bias = (weighted_residual_sum / sum_hessians.max(1e-10))
                .clamp(-self.config.max_weight, self.config.max_weight);
            self.target_mean = self.bias;
        }
        // Remove the (possibly just re-estimated) bias from the residuals.
        for r in residuals.iter_mut() {
            *r -= self.bias;
        }
let mut actual_iterations = 0;
for _iter in 0..self.config.max_iter {
actual_iterations += 1;
let mut max_change = 0.0f32;
for j in 0..num_features {
let mut rho = 0.0f32;
for i in 0..num_rows {
let x_ij = self.standardize(features[i * num_features + j], j);
rho += hessians[i] * residuals[i] * x_ij;
}
rho += x_sq_sums[j] * self.weights[j];
let denominator = (x_sq_sums[j] + l2_penalty).max(1e-10);
let raw_weight = rho / denominator;
let l1_threshold = l1_penalty / denominator;
let new_weight = Self::soft_threshold(raw_weight, l1_threshold);
let new_weight = new_weight.clamp(-self.config.max_weight, self.config.max_weight);
let old_weight = self.weights[j];
let weight_change = new_weight - old_weight;
self.weights[j] = new_weight;
for i in 0..num_rows {
let x_ij = self.standardize(features[i * num_features + j], j);
residuals[i] -= weight_change * x_ij;
}
max_change = max_change.max(weight_change.abs());
}
if max_change < self.config.tol {
break;
}
}
actual_iterations
}
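    /// Number of weights with magnitude above the sparsity threshold (1e-10).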
pub fn num_nonzero_weights(&self) -> usize {
self.weights.iter().filter(|&&w| w.abs() > 1e-10).count()
}
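    /// Indices of the features with non-negligible weights.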
pub fn selected_features(&self) -> Vec<usize> {
self.weights
.iter()
.enumerate()
.filter(|(_, &w)| w.abs() > 1e-10)
.map(|(i, _)| i)
.collect()
}
pub fn fit_direct(
&mut self,
features: &[f32],
num_features: usize,
targets: &[f32],
) -> Result<Vec<f32>> {
let num_rows = targets.len();
if features.len() != num_rows * num_features {
return Err(TreeBoostError::Data(format!(
"Feature matrix size mismatch: expected {}, got {}",
num_rows * num_features,
features.len()
)));
}
if !self.scaler_fitted {
self.fit_scaler(features, num_features);
}
let lambda = self.config.lambda as f64;
let mut xtx = vec![0.0f64; num_features * num_features];
let mut xty = vec![0.0f64; num_features];
let y_mean: f64 = targets.iter().map(|&y| y as f64).sum::<f64>() / num_rows as f64;
let mut x_means = vec![0.0f64; num_features];
for i in 0..num_rows {
let y = targets[i] as f64;
for j in 0..num_features {
let xj = self.standardize(features[i * num_features + j], j) as f64;
x_means[j] += xj;
xty[j] += xj * y;
for k in 0..num_features {
let xk = self.standardize(features[i * num_features + k], k) as f64;
xtx[j * num_features + k] += xj * xk;
}
}
}
for x_mean in x_means.iter_mut() {
*x_mean /= num_rows as f64;
}
for j in 0..num_features {
xtx[j * num_features + j] += lambda;
}
let mut aug = vec![0.0f64; num_features * (num_features + 1)];
for i in 0..num_features {
for j in 0..num_features {
aug[i * (num_features + 1) + j] = xtx[i * num_features + j];
}
aug[i * (num_features + 1) + num_features] = xty[i];
}
for col in 0..num_features {
let mut max_row = col;
for row in (col + 1)..num_features {
if aug[row * (num_features + 1) + col].abs()
> aug[max_row * (num_features + 1) + col].abs()
{
max_row = row;
}
}
for k in 0..=num_features {
aug.swap(
col * (num_features + 1) + k,
max_row * (num_features + 1) + k,
);
}
let pivot = aug[col * (num_features + 1) + col];
if pivot.abs() < 1e-12 {
continue;
}
for row in 0..num_features {
if row != col {
let factor = aug[row * (num_features + 1) + col] / pivot;
for k in 0..=num_features {
aug[row * (num_features + 1) + k] -=
factor * aug[col * (num_features + 1) + k];
}
}
}
}
for i in 0..num_features {
let diag = aug[i * (num_features + 1) + i];
if diag.abs() > 1e-12 {
self.weights[i] = (aug[i * (num_features + 1) + num_features] / diag) as f32;
} else {
self.weights[i] = 0.0;
}
}
let weights_dot_xmean: f64 = self
.weights
.iter()
.zip(x_means.iter())
.map(|(&w, &xm)| w as f64 * xm)
.sum();
self.bias = (y_mean - weights_dot_xmean) as f32;
self.target_mean = y_mean as f32;
Ok(self.predict_batch(features, num_features))
}
}
impl WeakLearner for LinearBooster {
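    /// Validates dimensions, fits the feature scaler on first use, then runs
    /// coordinate descent with bias re-estimation.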
fn fit_on_gradients(
&mut self,
features: &[f32],
num_features: usize,
gradients: &[f32],
hessians: &[f32],
) -> Result<()> {
if num_features != self.num_features {
return Err(TreeBoostError::Config(format!(
"Feature count mismatch: expected {}, got {}",
self.num_features, num_features
)));
}
let num_rows = gradients.len();
if features.len() != num_rows * num_features {
return Err(TreeBoostError::Data(format!(
"Feature matrix size mismatch: expected {}, got {}",
num_rows * num_features,
features.len()
)));
}
if hessians.len() != num_rows {
return Err(TreeBoostError::Data(format!(
"Hessian size mismatch: expected {}, got {}",
num_rows,
hessians.len()
)));
}
if !self.scaler_fitted {
self.fit_scaler(features, num_features);
}
let iters = self.coordinate_descent(features, num_features, gradients, hessians, true);
self.iterations_completed += iters;
Ok(())
}
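    /// Predicts a full row-major batch. When `extrapolation_damping > 0`,
    /// each prediction is blended toward the fitted target mean:
    /// `(1 - d) * pred + d * target_mean`.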
fn predict_batch(&self, features: &[f32], num_features: usize) -> Vec<f32> {
let num_rows = features.len() / num_features;
let mut predictions = vec![self.bias; num_rows];
for i in 0..num_rows {
for j in 0..num_features {
let x_ij = self.standardize(features[i * num_features + j], j);
predictions[i] += self.weights[j] * x_ij;
}
}
let damping = self.config.extrapolation_damping;
if damping > 0.0 {
let scale = 1.0 - damping;
let offset = damping * self.target_mean;
for pred in predictions.iter_mut() {
*pred = scale * *pred + offset;
}
}
predictions
}
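    /// Predicts a single row; mirrors `predict_batch` including damping.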
fn predict_row(&self, features: &[f32], num_features: usize, row_idx: usize) -> f32 {
let mut pred = self.bias;
let start = row_idx * num_features;
for j in 0..num_features {
let x_ij = self.standardize(features[start + j], j);
pred += self.weights[j] * x_ij;
}
let damping = self.config.extrapolation_damping;
if damping > 0.0 {
pred = (1.0 - damping) * pred + damping * self.target_mean;
}
pred
}
fn num_params(&self) -> usize {
        // One weight per feature plus the bias.
        self.num_features + 1
    }
fn reset(&mut self) {
self.weights.fill(0.0);
self.bias = 0.0;
self.target_mean = 0.0;
self.iterations_completed = 0;
}
}
impl crate::learner::incremental::IncrementalLearner for LinearBooster {
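    /// Continues coordinate descent from the current weights on fresh
    /// gradients, keeping both the frozen scaler and the existing bias.
    /// Errors if the scaler has never been fit.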
fn warm_fit(
&mut self,
features: &[f32],
num_features: usize,
gradients: &[f32],
hessians: &[f32],
) -> Result<()> {
if num_features != self.num_features {
return Err(TreeBoostError::Config(format!(
"Feature count mismatch: expected {}, got {}",
self.num_features, num_features
)));
}
let num_rows = gradients.len();
if features.len() != num_rows * num_features {
return Err(TreeBoostError::Data(format!(
"Feature matrix size mismatch: expected {}, got {}",
num_rows * num_features,
features.len()
)));
}
if hessians.len() != num_rows {
return Err(TreeBoostError::Data(format!(
"Hessian size mismatch: expected {}, got {}",
num_rows,
hessians.len()
)));
}
if !self.scaler_fitted {
return Err(TreeBoostError::Config(
"Cannot warm_fit on unfitted LinearBooster. \
Call fit_on_gradients first to initialize the scaler."
.to_string(),
));
}
let iters = self.coordinate_descent(features, num_features, gradients, hessians, false);
self.iterations_completed += iters;
Ok(())
}
fn iterations_completed(&self) -> usize {
self.iterations_completed
}
fn reset_iterations(&mut self) {
self.iterations_completed = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linear_config_lambda_minimum() {
let config = LinearConfig::new().with_lambda(0.0);
assert!(config.lambda >= 1e-6, "Lambda should never be 0");
let config = LinearConfig::new().with_lambda(-1.0);
assert!(config.lambda >= 1e-6, "Lambda should never be negative");
}
#[test]
fn test_linear_booster_creation() {
let config = LinearConfig::default();
let booster = LinearBooster::new(5, config);
assert_eq!(booster.weights().len(), 5);
assert_eq!(booster.bias(), 0.0);
assert_eq!(booster.num_params(), 6);
}
#[test]
fn test_linear_booster_simple_fit() {
        // y = 2x + 1
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let targets = vec![3.0, 5.0, 7.0, 9.0, 11.0];
let gradients: Vec<f32> = targets.iter().map(|&t| -t).collect();
let hessians = vec![1.0; 5];
let config = LinearConfig::default()
.with_lambda(0.01)
.with_shrinkage_factor(0.5)
.with_max_iter(100);
let mut booster = LinearBooster::new(1, config);
booster
.fit_on_gradients(&features, 1, &gradients, &hessians)
.unwrap();
let predictions = booster.predict_batch(&features, 1);
for (pred, &target) in predictions.iter().zip(targets.iter()) {
let error = (pred - target).abs();
assert!(
error < 2.0,
"Prediction {} too far from target {}",
pred,
target
);
}
}
#[test]
fn test_linear_booster_multivariate() {
        // Four rows of two features; targets follow y = x0 + 2 * x1.
        let features = vec![
            1.0, 1.0, // row 0
            2.0, 1.0, // row 1
            1.0, 2.0, // row 2
            2.0, 2.0, // row 3
        ];
let targets = vec![3.0, 4.0, 5.0, 6.0];
let gradients: Vec<f32> = targets.iter().map(|&t| -t).collect();
let hessians = vec![1.0; 4];
let config = LinearConfig::default()
.with_lambda(0.001)
.with_shrinkage_factor(0.5)
.with_max_iter(200);
let mut booster = LinearBooster::new(2, config);
booster
.fit_on_gradients(&features, 2, &gradients, &hessians)
.unwrap();
let predictions = booster.predict_batch(&features, 2);
for (i, (pred, &target)) in predictions.iter().zip(targets.iter()).enumerate() {
let error = (pred - target).abs();
assert!(
error < 1.5,
"Row {}: pred {} too far from target {}",
i,
pred,
target
);
}
}
#[test]
fn test_linear_booster_no_nan() {
        let features = vec![
            1.0, 1.0, // row 0
            2.0, 2.0, // row 1
            3.0, 3.0, // row 2
            4.0, 4.0, // row 3
        ];
let gradients = vec![-1.0, -2.0, -3.0, -4.0];
let hessians = vec![1.0; 4];
let config = LinearConfig::default();
let mut booster = LinearBooster::new(2, config);
booster
.fit_on_gradients(&features, 2, &gradients, &hessians)
.unwrap();
let predictions = booster.predict_batch(&features, 2);
for pred in &predictions {
assert!(
pred.is_finite(),
"Prediction should be finite, got {}",
pred
);
}
}
#[test]
fn test_linear_booster_constant_feature() {
        // The second feature is constant (5.0) to exercise the std-dev guard.
        let features = vec![
            1.0, 5.0, // row 0
            2.0, 5.0, // row 1
            3.0, 5.0, // row 2
        ];
let gradients = vec![-1.0, -2.0, -3.0];
let hessians = vec![1.0; 3];
let config = LinearConfig::default();
let mut booster = LinearBooster::new(2, config);
booster
.fit_on_gradients(&features, 2, &gradients, &hessians)
.unwrap();
let predictions = booster.predict_batch(&features, 2);
for pred in &predictions {
assert!(
pred.is_finite(),
"Prediction should be finite, got {}",
pred
);
}
}
#[test]
fn test_linear_booster_reset() {
let config = LinearConfig::default();
let mut booster = LinearBooster::new(3, config);
let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let gradients = vec![-1.0, -2.0];
let hessians = vec![1.0, 1.0];
booster
.fit_on_gradients(&features, 3, &gradients, &hessians)
.unwrap();
let has_nonzero = booster.weights().iter().any(|&w| w.abs() > 1e-10);
assert!(has_nonzero, "Weights should be non-zero after fitting");
booster.reset();
for &w in booster.weights() {
assert!((w.abs()) < 1e-10, "Weights should be zero after reset");
}
assert!(
(booster.bias().abs()) < 1e-10,
"Bias should be zero after reset"
);
}
#[test]
fn test_linear_booster_single_row_prediction() {
let config = LinearConfig::default();
let mut booster = LinearBooster::new(2, config);
let features = vec![1.0, 2.0, 3.0, 4.0];
let gradients = vec![-5.0, -10.0];
let hessians = vec![1.0, 1.0];
booster
.fit_on_gradients(&features, 2, &gradients, &hessians)
.unwrap();
let batch_preds = booster.predict_batch(&features, 2);
let single_pred_0 = booster.predict_row(&features, 2, 0);
let single_pred_1 = booster.predict_row(&features, 2, 1);
assert!((batch_preds[0] - single_pred_0).abs() < 1e-6);
assert!((batch_preds[1] - single_pred_1).abs() < 1e-6);
}
#[test]
fn test_soft_threshold() {
assert!((LinearBooster::soft_threshold(5.0, 2.0) - 3.0).abs() < 1e-6);
assert!((LinearBooster::soft_threshold(-5.0, 2.0) - (-3.0)).abs() < 1e-6);
assert!((LinearBooster::soft_threshold(1.5, 2.0) - 0.0).abs() < 1e-6);
assert!((LinearBooster::soft_threshold(-1.5, 2.0) - 0.0).abs() < 1e-6);
assert!((LinearBooster::soft_threshold(2.0, 2.0) - 0.0).abs() < 1e-6);
}
#[test]
fn test_lasso_sparsity() {
let n_samples = 100;
let n_features = 4;
let mut features = Vec::with_capacity(n_samples * n_features);
let mut targets = Vec::with_capacity(n_samples);
for i in 0..n_samples {
let x0 = (i as f32) / 10.0;
            // Only the first feature carries signal; the rest are constant.
            features.push(x0);
            features.push(0.5);
            features.push(0.3);
            features.push(0.1);
            targets.push(3.0 * x0);
        }
let gradients: Vec<f32> = targets.iter().map(|&t| -t).collect();
let hessians = vec![1.0; n_samples];
let config = LinearConfig::default()
.with_preset(LinearPreset::Lasso)
.with_lambda(2.0)
.with_shrinkage_factor(0.5)
.with_max_iter(200);
let mut booster = LinearBooster::new(n_features, config);
booster
.fit_on_gradients(&features, n_features, &gradients, &hessians)
.unwrap();
assert!(
booster.weights()[0].abs() > 0.1,
"Feature 0 should be selected"
);
let selected = booster.selected_features();
println!("Selected features: {:?}", selected);
println!("Weights: {:?}", booster.weights());
println!("Num nonzero: {}", booster.num_nonzero_weights());
assert!(selected.contains(&0), "Feature 0 must be selected");
}
#[test]
fn test_elastic_net_config() {
let config = LinearConfig::default()
.with_preset(LinearPreset::ElasticNet)
.with_lambda(1.0)
.with_l1_ratio(0.5);
assert!((config.lambda - 1.0).abs() < 1e-6);
assert!((config.l1_ratio - 0.5).abs() < 1e-6);
assert!((config.l1_penalty() - 0.5).abs() < 1e-6);
assert!((config.l2_penalty() - 0.5).abs() < 1e-6);
}
#[test]
fn test_ridge_vs_lasso_sparsity() {
let n_samples = 50;
let n_features = 10;
let mut features = Vec::with_capacity(n_samples * n_features);
let mut targets = Vec::with_capacity(n_samples);
for i in 0..n_samples {
let x = (i as f32) / 10.0;
for _ in 0..n_features {
features.push(x);
}
            targets.push(x);
        }
let gradients: Vec<f32> = targets.iter().map(|&t| -t).collect();
let hessians = vec![1.0; n_samples];
let ridge_config = LinearConfig::default()
.with_preset(LinearPreset::Ridge)
.with_lambda(0.1)
.with_shrinkage_factor(0.5)
.with_max_iter(100);
let mut ridge_booster = LinearBooster::new(n_features, ridge_config);
ridge_booster
.fit_on_gradients(&features, n_features, &gradients, &hessians)
.unwrap();
let lasso_config = LinearConfig::default()
.with_preset(LinearPreset::Lasso)
.with_lambda(0.5)
.with_shrinkage_factor(0.5)
.with_max_iter(100);
let mut lasso_booster = LinearBooster::new(n_features, lasso_config);
lasso_booster
.fit_on_gradients(&features, n_features, &gradients, &hessians)
.unwrap();
println!("Ridge nonzero: {}", ridge_booster.num_nonzero_weights());
println!("LASSO nonzero: {}", lasso_booster.num_nonzero_weights());
let ridge_preds = ridge_booster.predict_batch(&features, n_features);
let lasso_preds = lasso_booster.predict_batch(&features, n_features);
for pred in ridge_preds.iter().chain(lasso_preds.iter()) {
assert!(pred.is_finite(), "Predictions must be finite");
}
}
#[test]
fn test_elastic_net_stability() {
        let features = vec![
            1.0, 1.0, // row 0
            2.0, 2.0, // row 1
            3.0, 3.0, // row 2
            4.0, 4.0, // row 3
        ];
let gradients = vec![-1.0, -2.0, -3.0, -4.0];
let hessians = vec![1.0; 4];
let config = LinearConfig::default()
.with_preset(LinearPreset::ElasticNet)
.with_lambda(0.5)
            .with_l1_ratio(0.5)
            .with_shrinkage_factor(0.5)
.with_max_iter(100);
let mut booster = LinearBooster::new(2, config);
booster
.fit_on_gradients(&features, 2, &gradients, &hessians)
.unwrap();
let predictions = booster.predict_batch(&features, 2);
for pred in &predictions {
assert!(
pred.is_finite(),
"Elastic Net prediction should be finite, got {}",
pred
);
}
}
#[test]
fn test_shrinkage_factor_clamping() {
let config = LinearConfig::new().with_shrinkage_factor(-1.0);
assert!(
config.shrinkage_factor >= 1e-6,
"shrinkage_factor should be clamped to minimum 1e-6, got {}",
config.shrinkage_factor
);
let config = LinearConfig::new().with_shrinkage_factor(2.0);
assert!(
config.shrinkage_factor <= 1.0,
"shrinkage_factor should be clamped to maximum 1.0, got {}",
config.shrinkage_factor
);
let config = LinearConfig::new().with_shrinkage_factor(0.5);
assert_eq!(config.shrinkage_factor, 0.5);
}
#[test]
fn test_shrinkage_factor_near_zero_contribution() {
        let features = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![-1.0, -2.0];
let hessians = vec![1.0, 1.0];
let config = LinearConfig::default().with_shrinkage_factor(0.0);
let mut booster = LinearBooster::new(2, config);
booster
.fit_on_gradients(&features, 2, &gradients, &hessians)
.unwrap();
assert_eq!(booster.config().shrinkage_factor, 1e-6);
}
#[test]
fn test_shrinkage_factor_full_contribution() {
        let features = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![-1.0, -2.0];
let hessians = vec![1.0, 1.0];
let config = LinearConfig::default().with_shrinkage_factor(1.0);
let mut booster = LinearBooster::new(2, config);
booster
.fit_on_gradients(&features, 2, &gradients, &hessians)
.unwrap();
assert_eq!(booster.config().shrinkage_factor, 1.0);
}
#[test]
fn test_shrinkage_factor_vs_extrapolation_damping() {
let config = LinearConfig::default()
.with_shrinkage_factor(0.3)
.with_extrapolation_damping(0.1);
assert_eq!(config.shrinkage_factor, 0.3);
assert_eq!(config.extrapolation_damping, 0.1);
let config2 = LinearConfig::default()
.with_shrinkage_factor(0.5)
.with_extrapolation_damping(0.0);
assert_eq!(config2.shrinkage_factor, 0.5);
assert_eq!(config2.extrapolation_damping, 0.0);
}
#[test]
fn test_linear_warm_start() {
use crate::learner::incremental::IncrementalLearner;
        // First task: y = 2x.
        let features_a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let targets_a: Vec<f32> = features_a.iter().map(|&x| 2.0 * x).collect();
let gradients_a: Vec<f32> = targets_a.iter().map(|&t| -t).collect();
let hessians_a = vec![1.0; 5];
let config = LinearConfig::default().with_lambda(0.01).with_max_iter(100);
let mut booster = LinearBooster::new(1, config);
booster
.fit_on_gradients(&features_a, 1, &gradients_a, &hessians_a)
.unwrap();
let initial_weight = booster.weights()[0];
let initial_iters = booster.iterations_completed();
assert!(
initial_weight > 1.0,
"Initial weight {} should be > 1.0",
initial_weight
);
let features_b = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let targets_b: Vec<f32> = features_b.iter().map(|&x| 3.0 * x).collect();
let gradients_b: Vec<f32> = targets_b.iter().map(|&t| -t).collect();
let hessians_b = vec![1.0; 5];
booster
.warm_fit(&features_b, 1, &gradients_b, &hessians_b)
.unwrap();
let new_weight = booster.weights()[0];
let total_iters = booster.iterations_completed();
assert!(
new_weight > initial_weight,
"Warm start weight {} should be > initial {}",
new_weight,
initial_weight
);
assert!(
total_iters > initial_iters,
"Total iterations {} should be > initial {}",
total_iters,
initial_iters
);
}
#[test]
fn test_linear_scaler_preserved_on_warm_fit() {
use crate::learner::incremental::IncrementalLearner;
let features_a = vec![1.0, 2.0, 3.0, 4.0];
let gradients_a = vec![-1.0, -2.0, -3.0, -4.0];
let hessians_a = vec![1.0; 4];
let config = LinearConfig::default();
let mut booster = LinearBooster::new(1, config);
booster
.fit_on_gradients(&features_a, 1, &gradients_a, &hessians_a)
.unwrap();
let mean_after_first = booster.means[0];
let std_after_first = booster.stds[0];
        // Wildly different feature scale: the scaler must not be refit.
        let features_b = vec![100.0, 200.0, 300.0, 400.0];
        let gradients_b = vec![-1.0, -2.0, -3.0, -4.0];
let hessians_b = vec![1.0; 4];
booster
.warm_fit(&features_b, 1, &gradients_b, &hessians_b)
.unwrap();
assert_eq!(
booster.means[0], mean_after_first,
"Mean should be preserved after warm_fit"
);
assert_eq!(
booster.stds[0], std_after_first,
"Std should be preserved after warm_fit"
);
}
#[test]
fn test_warm_fit_requires_prior_fit() {
use crate::learner::incremental::IncrementalLearner;
let features = vec![1.0, 2.0, 3.0, 4.0];
let gradients = vec![-1.0, -2.0, -3.0, -4.0];
let hessians = vec![1.0; 4];
let config = LinearConfig::default();
let mut booster = LinearBooster::new(1, config);
let result = booster.warm_fit(&features, 1, &gradients, &hessians);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("unfitted"));
}
#[test]
fn test_iterations_tracking() {
use crate::learner::incremental::IncrementalLearner;
let features = vec![1.0, 2.0, 3.0, 4.0];
let gradients = vec![-1.0, -2.0, -3.0, -4.0];
let hessians = vec![1.0; 4];
let config = LinearConfig::default().with_max_iter(10);
let mut booster = LinearBooster::new(1, config);
assert_eq!(booster.iterations_completed(), 0);
booster
.fit_on_gradients(&features, 1, &gradients, &hessians)
.unwrap();
let iters_after_first = booster.iterations_completed();
assert!(iters_after_first > 0);
booster
.fit_on_gradients(&features, 1, &gradients, &hessians)
.unwrap();
let iters_after_second = booster.iterations_completed();
assert!(iters_after_second > iters_after_first);
booster.reset_iterations();
assert_eq!(booster.iterations_completed(), 0);
}
#[test]
fn test_reset_clears_iterations() {
use crate::learner::incremental::IncrementalLearner;
let features = vec![1.0, 2.0, 3.0, 4.0];
let gradients = vec![-1.0, -2.0, -3.0, -4.0];
let hessians = vec![1.0; 4];
let config = LinearConfig::default();
let mut booster = LinearBooster::new(1, config);
booster
.fit_on_gradients(&features, 1, &gradients, &hessians)
.unwrap();
assert!(booster.iterations_completed() > 0);
booster.reset();
assert_eq!(booster.iterations_completed(), 0);
}
}