use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use crate::math;
use crate::rng::xorshift64;
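/// A per-leaf predictor updated online with first- and second-order loss
/// statistics (gradient and hessian), as in gradient boosting.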
pub trait LeafModel: Send + Sync {
/// Predict the leaf's output for a single feature vector.
fn predict(&self, features: &[f64]) -> f64;
/// Fold in one observation's gradient/hessian pair; `lambda` is the L2
/// regularization strength.
fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64);
/// Clone with freshly initialized (untrained) state.
fn clone_fresh(&self) -> Box<dyn LeafModel>;
/// Clone keeping learned parameters; defaults to `clone_fresh` for models
/// with nothing worth carrying over.
fn clone_warm(&self) -> Box<dyn LeafModel> {
self.clone_fresh()
}
}
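/// Standard gradient-boosting leaf: accumulates gradient and hessian sums and
/// sets its weight to the Newton step -G / (H + lambda).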
#[derive(Default)]
pub struct ClosedFormLeaf {
/// Accumulated gradient sum G.
grad_sum: f64,
/// Accumulated hessian sum H.
hess_sum: f64,
/// Cached optimal weight -G / (H + lambda).
weight: f64,
}
impl ClosedFormLeaf {
pub fn new() -> Self {
Self::default()
}
}
impl LeafModel for ClosedFormLeaf {
fn predict(&self, _features: &[f64]) -> f64 {
self.weight
}
fn update(&mut self, _features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
self.grad_sum += gradient;
self.hess_sum += hessian;
// Newton step on the accumulated statistics: w* = -G / (H + lambda).
self.weight = -self.grad_sum / (self.hess_sum + lambda);
}
fn clone_fresh(&self) -> Box<dyn LeafModel> {
Box::new(ClosedFormLeaf::new())
}
}
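/// Linear model over the leaf's features, trained online with SGD whose step
/// size is scaled by the local curvature, with optional weight decay and
/// AdaGrad-style per-coordinate learning rates.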
pub struct LinearLeafModel {
weights: Vec<f64>,
bias: f64,
learning_rate: f64,
decay: Option<f64>,
use_adagrad: bool,
sq_grad_accum: Vec<f64>,
sq_bias_accum: f64,
initialized: bool,
}
impl LinearLeafModel {
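/// `learning_rate` scales each SGD step, `decay` (if set) multiplies the
/// parameters before every update, and `use_adagrad` enables per-coordinate
/// adaptive step sizes.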
pub fn new(learning_rate: f64, decay: Option<f64>, use_adagrad: bool) -> Self {
Self {
weights: Vec::new(),
bias: 0.0,
learning_rate,
decay,
use_adagrad,
sq_grad_accum: Vec::new(),
sq_bias_accum: 0.0,
initialized: false,
}
}
}
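/// Numerical floor for the AdaGrad denominators.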
const ADAGRAD_EPS: f64 = 1e-8;
impl LeafModel for LinearLeafModel {
fn predict(&self, features: &[f64]) -> f64 {
if !self.initialized {
return 0.0;
}
let mut dot = self.bias;
for (w, x) in self.weights.iter().zip(features.iter()) {
dot += w * x;
}
dot
}
fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
// Lazily size the parameter vectors to the first feature vector seen.
if !self.initialized {
let d = features.len();
self.weights = vec![0.0; d];
self.sq_grad_accum = vec![0.0; d];
self.initialized = true;
}
// Optional multiplicative decay so old data is gradually forgotten.
if let Some(d) = self.decay {
for w in self.weights.iter_mut() {
*w *= d;
}
self.bias *= d;
}
// Scale the step by the local curvature, mirroring the closed-form leaf.
let base_lr = self.learning_rate / (math::abs(hessian) + lambda);
if self.use_adagrad {
// AdaGrad: divide each coordinate's step by the root of its
// accumulated squared gradients.
for (i, (w, x)) in self.weights.iter_mut().zip(features.iter()).enumerate() {
let g = gradient * x;
self.sq_grad_accum[i] += g * g;
let adaptive_lr = base_lr / (math::sqrt(self.sq_grad_accum[i]) + ADAGRAD_EPS);
*w -= adaptive_lr * g;
}
self.sq_bias_accum += gradient * gradient;
let bias_lr = base_lr / (math::sqrt(self.sq_bias_accum) + ADAGRAD_EPS);
self.bias -= bias_lr * gradient;
} else {
// Plain SGD: d(loss)/dw_j = gradient * x_j for a linear predictor.
for (w, x) in self.weights.iter_mut().zip(features.iter()) {
*w -= base_lr * gradient * x;
}
self.bias -= base_lr * gradient;
}
}
fn clone_fresh(&self) -> Box<dyn LeafModel> {
Box::new(LinearLeafModel::new(
self.learning_rate,
self.decay,
self.use_adagrad,
))
}
fn clone_warm(&self) -> Box<dyn LeafModel> {
// Keep the learned parameters but reset the AdaGrad accumulators so the
// clone can adapt quickly in its new context.
Box::new(LinearLeafModel {
weights: self.weights.clone(),
bias: self.bias,
learning_rate: self.learning_rate,
decay: self.decay,
use_adagrad: self.use_adagrad,
sq_grad_accum: vec![0.0; self.weights.len()],
sq_bias_accum: 0.0,
initialized: self.initialized,
})
}
}
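/// Single-hidden-layer ReLU perceptron leaf, trained online by
/// backpropagating the boosting gradient. The network is initialized lazily
/// from a xorshift64 stream once the input dimension is known.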
pub struct MLPLeafModel {
hidden_weights: Vec<Vec<f64>>,
hidden_bias: Vec<f64>,
output_weights: Vec<f64>,
output_bias: f64,
hidden_size: usize,
learning_rate: f64,
decay: Option<f64>,
seed: u64,
initialized: bool,
// Scratch buffers filled by `forward` and consumed by `update`.
hidden_activations: Vec<f64>,
hidden_pre_activations: Vec<f64>,
}
impl MLPLeafModel {
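/// Hyperparameters are stored immediately; the network itself is built on the
/// first `update`, once the input dimension is known.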
pub fn new(hidden_size: usize, learning_rate: f64, seed: u64, decay: Option<f64>) -> Self {
Self {
hidden_weights: Vec::new(),
hidden_bias: Vec::new(),
output_weights: Vec::new(),
output_bias: 0.0,
hidden_size,
learning_rate,
decay,
seed,
initialized: false,
hidden_activations: Vec::new(),
hidden_pre_activations: Vec::new(),
}
}
fn initialize(&mut self, input_size: usize) {
let mut state = self.seed ^ (self.hidden_size as u64);
// All parameters start uniform in [-0.1, 0.1], drawn from a deterministic
// xorshift64 stream so initialization is reproducible per seed.
let mut small_uniform = || (xorshift64(&mut state) as f64 / u64::MAX as f64) * 0.2 - 0.1;
self.hidden_weights = (0..self.hidden_size)
.map(|_| (0..input_size).map(|_| small_uniform()).collect())
.collect();
self.hidden_bias = (0..self.hidden_size).map(|_| small_uniform()).collect();
self.output_weights = (0..self.hidden_size).map(|_| small_uniform()).collect();
self.output_bias = small_uniform();
self.hidden_activations = vec![0.0; self.hidden_size];
self.hidden_pre_activations = vec![0.0; self.hidden_size];
self.initialized = true;
}
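/// Forward pass that caches pre- and post-activation values for the backward
/// step in `update`.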
fn forward(&mut self, features: &[f64]) -> f64 {
for h in 0..self.hidden_size {
let mut z = self.hidden_bias[h];
for (j, x) in features.iter().enumerate() {
if j < self.hidden_weights[h].len() {
z += self.hidden_weights[h][j] * x;
}
}
self.hidden_pre_activations[h] = z;
self.hidden_activations[h] = if z > 0.0 { z } else { 0.0 };
}
let mut out = self.output_bias;
for (w, a) in self
.output_weights
.iter()
.zip(self.hidden_activations.iter())
{
out += w * a;
}
out
}
}
impl LeafModel for MLPLeafModel {
fn predict(&self, features: &[f64]) -> f64 {
// `predict` takes &self, so it recomputes the forward pass instead of
// reusing the mutable scratch buffers that `forward` fills.
if !self.initialized {
return 0.0;
}
let hidden_acts: Vec<f64> = self
.hidden_weights
.iter()
.zip(self.hidden_bias.iter())
.map(|(hw, &hb)| {
let mut z = hb;
for (j, x) in features.iter().enumerate() {
if j < hw.len() {
z += hw[j] * x;
}
}
if z > 0.0 {
z
} else {
0.0
}
})
.collect();
let mut out = self.output_bias;
for (w, a) in self.output_weights.iter().zip(hidden_acts.iter()) {
out += w * a;
}
out
}
fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
if !self.initialized {
self.initialize(features.len());
}
// Optional multiplicative decay on every parameter.
if let Some(d) = self.decay {
for row in self.hidden_weights.iter_mut() {
for w in row.iter_mut() {
*w *= d;
}
}
for b in self.hidden_bias.iter_mut() {
*b *= d;
}
for w in self.output_weights.iter_mut() {
*w *= d;
}
self.output_bias *= d;
}
// Forward pass fills the cached activations consumed below.
let _output = self.forward(features);
// Curvature-scaled step size, mirroring the other leaf models.
let effective_lr = self.learning_rate / (math::abs(hessian) + lambda);
// The boosting gradient is d(loss)/d(output) at the current prediction.
let d_output = gradient;
// Update the hidden layer first so its gradient is taken at the pre-update
// output weights, as in standard simultaneous-update backpropagation.
for h in 0..self.hidden_size {
let d_hidden_act = d_output * self.output_weights[h];
// ReLU derivative: pass the gradient only where the unit was active.
let d_relu = if self.hidden_pre_activations[h] > 0.0 {
d_hidden_act
} else {
0.0
};
for (j, x) in features.iter().enumerate() {
if j < self.hidden_weights[h].len() {
self.hidden_weights[h][j] -= effective_lr * d_relu * x;
}
}
self.hidden_bias[h] -= effective_lr * d_relu;
}
// Output-layer step uses the activations cached by the forward pass.
for h in 0..self.hidden_size {
self.output_weights[h] -= effective_lr * d_output * self.hidden_activations[h];
}
self.output_bias -= effective_lr * d_output;
}
fn clone_fresh(&self) -> Box<dyn LeafModel> {
// Derive a new seed with a golden-ratio (splitmix64-style) multiply so
// sibling clones do not share an initialization.
let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
Box::new(MLPLeafModel::new(
self.hidden_size,
self.learning_rate,
derived_seed,
self.decay,
))
}
fn clone_warm(&self) -> Box<dyn LeafModel> {
Box::new(MLPLeafModel {
hidden_weights: self.hidden_weights.clone(),
hidden_bias: self.hidden_bias.clone(),
output_weights: self.output_weights.clone(),
output_bias: self.output_bias,
hidden_size: self.hidden_size,
learning_rate: self.learning_rate,
decay: self.decay,
seed: self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(2),
initialized: self.initialized,
hidden_activations: vec![0.0; self.hidden_size],
hidden_pre_activations: vec![0.0; self.hidden_size],
})
}
}
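/// Leaf that starts as a cheap closed-form model while training a more
/// expressive shadow model on the same stream, swapping the shadow in once
/// its mean loss advantage clears a Hoeffding-style confidence bound.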
pub struct AdaptiveLeafModel {
/// Currently serving model; starts as a closed-form leaf.
active: Box<dyn LeafModel>,
/// Candidate model trained in parallel on the same stream.
shadow: Box<dyn LeafModel>,
/// Configuration used to rebuild the shadow in fresh clones.
promote_to: LeafModelType,
/// Running sum of (active loss - shadow loss); positive favors the shadow.
cumulative_advantage: f64,
/// Number of updates observed.
n: u64,
/// Largest absolute loss difference seen; the range R in the Hoeffding bound.
max_loss_diff: f64,
/// Confidence parameter of the promotion test.
delta: f64,
/// Set once the shadow has been swapped in as the active model.
promoted: bool,
seed: u64,
}
impl AdaptiveLeafModel {
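/// `shadow` is the candidate trained alongside the initial closed-form active
/// model; smaller `delta` demands more evidence before promotion.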
pub fn new(
shadow: Box<dyn LeafModel>,
promote_to: LeafModelType,
delta: f64,
seed: u64,
) -> Self {
Self {
active: Box::new(ClosedFormLeaf::new()),
shadow,
promote_to,
cumulative_advantage: 0.0,
n: 0,
max_loss_diff: 0.0,
delta,
promoted: false,
seed,
}
}
}
impl LeafModel for AdaptiveLeafModel {
fn predict(&self, features: &[f64]) -> f64 {
self.active.predict(features)
}
fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
if self.promoted {
self.active.update(features, gradient, hessian, lambda);
return;
}
// Score both models with the second-order Taylor expansion of the loss,
// delta_loss(p) ~= g*p + 0.5*h*p^2; a positive diff means the shadow wins.
let pred_active = self.active.predict(features);
let pred_shadow = self.shadow.predict(features);
let loss_active = gradient * pred_active + 0.5 * hessian * pred_active * pred_active;
let loss_shadow = gradient * pred_shadow + 0.5 * hessian * pred_shadow * pred_shadow;
let diff = loss_active - loss_shadow;
self.cumulative_advantage += diff;
self.n += 1;
// Track the range of observed differences for the Hoeffding bound below.
let abs_diff = math::abs(diff);
if abs_diff > self.max_loss_diff {
self.max_loss_diff = abs_diff;
}
self.active.update(features, gradient, hessian, lambda);
self.shadow.update(features, gradient, hessian, lambda);
// Promote once the shadow's mean advantage clears the Hoeffding radius
// epsilon = sqrt(R^2 * ln(1/delta) / (2n)), after a short burn-in.
if self.n >= 10 && self.max_loss_diff > 0.0 {
let mean_advantage = self.cumulative_advantage / self.n as f64;
if mean_advantage > 0.0 {
let r_squared = self.max_loss_diff * self.max_loss_diff;
let ln_inv_delta = math::ln(1.0 / self.delta);
let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * self.n as f64));
if mean_advantage > epsilon {
self.promoted = true;
core::mem::swap(&mut self.active, &mut self.shadow);
}
}
}
}
fn clone_fresh(&self) -> Box<dyn LeafModel> {
// Restart the experiment: new shadow, reset statistics, derived seed.
let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
Box::new(AdaptiveLeafModel::new(
self.promote_to.create(derived_seed, self.delta),
self.promote_to.clone(),
self.delta,
derived_seed,
))
}
}
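/// Configuration selecting which leaf model a tree should build.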
#[derive(Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LeafModelType {
/// Constant leaf weight computed in closed form from gradient statistics.
#[default]
ClosedForm,
/// Online linear model trained by curvature-scaled SGD.
Linear {
learning_rate: f64,
#[cfg_attr(feature = "serde", serde(default))]
decay: Option<f64>,
#[cfg_attr(feature = "serde", serde(default))]
use_adagrad: bool,
},
/// Single-hidden-layer ReLU network trained by backpropagation.
MLP {
hidden_size: usize,
learning_rate: f64,
#[cfg_attr(feature = "serde", serde(default))]
decay: Option<f64>,
},
/// Closed-form leaf that promotes itself to `promote_to` once the candidate
/// demonstrably outperforms it.
Adaptive {
promote_to: Box<LeafModelType>,
},
}
impl LeafModelType {
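/// Build a boxed model from this configuration. `seed` drives any random
/// initialization; `delta` is the confidence parameter for `Adaptive` leaves.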
pub fn create(&self, seed: u64, delta: f64) -> Box<dyn LeafModel> {
match self {
Self::ClosedForm => Box::new(ClosedFormLeaf::new()),
Self::Linear {
learning_rate,
decay,
use_adagrad,
} => Box::new(LinearLeafModel::new(*learning_rate, *decay, *use_adagrad)),
Self::MLP {
hidden_size,
learning_rate,
decay,
} => Box::new(MLPLeafModel::new(
*hidden_size,
*learning_rate,
seed,
*decay,
)),
Self::Adaptive { promote_to } => Box::new(AdaptiveLeafModel::new(
promote_to.create(seed, delta),
*promote_to.clone(),
delta,
seed,
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn rand_f64(state: &mut u64) -> f64 {
xorshift64(state) as f64 / u64::MAX as f64
}
#[test]
fn closed_form_matches_formula() {
let mut leaf = ClosedFormLeaf::new();
let lambda = 1.0;
let updates = [(0.5, 1.0), (-0.3, 0.8), (1.2, 2.0), (-0.1, 0.5)];
let mut grad_sum = 0.0;
let mut hess_sum = 0.0;
for &(g, h) in &updates {
leaf.update(&[], g, h, lambda);
grad_sum += g;
hess_sum += h;
}
let expected = -grad_sum / (hess_sum + lambda);
let predicted = leaf.predict(&[]);
assert!(
(predicted - expected).abs() < 1e-12,
"closed form mismatch: got {predicted}, expected {expected}"
);
}
#[test]
fn closed_form_clone_fresh_resets() {
let mut leaf = ClosedFormLeaf::new();
leaf.update(&[], 5.0, 2.0, 1.0);
assert!(
leaf.predict(&[]).abs() > 0.0,
"leaf should have non-zero weight after update"
);
let fresh = leaf.clone_fresh();
assert!(
fresh.predict(&[]).abs() < 1e-15,
"fresh clone should predict 0, got {}",
fresh.predict(&[])
);
}
#[test]
fn linear_converges_on_linear_target() {
let mut model = LinearLeafModel::new(0.01, None, false);
let lambda = 0.1;
let mut rng = 42u64;
for _ in 0..2000 {
let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
let features = vec![x1, x2];
let target = 2.0 * x1 + 3.0 * x2;
let pred = model.predict(&features);
let gradient = 2.0 * (pred - target);
let hessian = 2.0;
model.update(&features, gradient, hessian, lambda);
}
let test_features = vec![0.5, -0.3];
let target = 2.0 * 0.5 + 3.0 * (-0.3);
let pred = model.predict(&test_features);
assert!(
(pred - target).abs() < 1.0,
"linear model should converge within 1.0 of target: pred={pred}, target={target}"
);
}
#[test]
fn linear_uninitialized_predicts_zero() {
let model = LinearLeafModel::new(0.01, None, false);
let pred = model.predict(&[1.0, 2.0, 3.0]);
assert!(
pred.abs() < 1e-15,
"uninitialized linear model should predict 0, got {pred}"
);
}
#[test]
fn linear_clone_warm_preserves_weights() {
let mut model = LinearLeafModel::new(0.01, None, false);
let features = vec![1.0, 2.0];
for _ in 0..100 {
let target = 3.0 * features[0] + 2.0 * features[1];
let pred = model.predict(&features);
let gradient = 2.0 * (pred - target);
model.update(&features, gradient, 2.0, 0.1);
}
let trained_pred = model.predict(&features);
assert!(
trained_pred.abs() > 0.01,
"model should have learned something"
);
let warm = model.clone_warm();
let warm_pred = warm.predict(&features);
assert!(
(warm_pred - trained_pred).abs() < 1e-12,
"warm clone should preserve weights: trained={trained_pred}, warm={warm_pred}"
);
let fresh = model.clone_fresh();
let fresh_pred = fresh.predict(&features);
assert!(
fresh_pred.abs() < 1e-15,
"fresh clone should predict 0, got {fresh_pred}"
);
}
#[test]
fn linear_decay_forgets_old_data() {
let mut model_decay = LinearLeafModel::new(0.05, Some(0.99), false);
let mut model_no_decay = LinearLeafModel::new(0.05, None, false);
let features = vec![1.0];
let lambda = 0.1;
for _ in 0..500 {
let pred_d = model_decay.predict(&features);
let pred_n = model_no_decay.predict(&features);
model_decay.update(&features, 2.0 * (pred_d - 5.0), 2.0, lambda);
model_no_decay.update(&features, 2.0 * (pred_n - 5.0), 2.0, lambda);
}
let pred_d_trained = model_decay.predict(&features);
let pred_n_trained = model_no_decay.predict(&features);
assert!(
(pred_d_trained - 5.0).abs() < 2.0,
"decay model should approximate target"
);
assert!(
(pred_n_trained - 5.0).abs() < 2.0,
"no-decay model should approximate target"
);
for _ in 0..200 {
model_decay.update(&features, 0.0, 1.0, lambda);
model_no_decay.update(&features, 0.0, 1.0, lambda);
}
let pred_d_after = model_decay.predict(&features);
let pred_n_after = model_no_decay.predict(&features);
assert!(
pred_d_after.abs() < pred_n_after.abs(),
"decay model should forget: decay pred={pred_d_after:.3}, no-decay pred={pred_n_after:.3}"
);
}
#[test]
fn mlp_produces_finite_predictions() {
let model_uninit = MLPLeafModel::new(4, 0.01, 42, None);
let features = vec![1.0, 2.0, 3.0];
let pred_before = model_uninit.predict(&features);
assert!(
pred_before.is_finite(),
"uninit prediction should be finite"
);
assert!(
pred_before.abs() < 1e-15,
"uninit prediction should be 0, got {pred_before}"
);
let mut model = MLPLeafModel::new(4, 0.01, 42, None);
for _ in 0..10 {
model.update(&features, 0.5, 1.0, 0.1);
}
let pred_after = model.predict(&features);
assert!(
pred_after.is_finite(),
"prediction after training should be finite, got {pred_after}"
);
}
#[test]
fn mlp_loss_decreases() {
let mut model = MLPLeafModel::new(8, 0.05, 123, None);
let features = vec![1.0, -0.5, 0.3];
let target = 2.5;
let lambda = 0.1;
// A zero-gradient update only initializes the network weights.
model.update(&features, 0.0, 1.0, lambda);
let initial_pred = model.predict(&features);
let initial_error = (initial_pred - target).abs();
for _ in 0..200 {
let pred = model.predict(&features);
let gradient = 2.0 * (pred - target);
let hessian = 2.0;
model.update(&features, gradient, hessian, lambda);
}
let final_pred = model.predict(&features);
let final_error = (final_pred - target).abs();
assert!(
final_error < initial_error,
"MLP error should decrease: initial={initial_error}, final={final_error}"
);
}
#[test]
fn mlp_clone_fresh_resets() {
let mut model = MLPLeafModel::new(4, 0.01, 42, None);
let features = vec![1.0, 2.0];
for _ in 0..20 {
model.update(&features, 0.5, 1.0, 0.1);
}
let trained_pred = model.predict(&features);
assert!(
trained_pred.abs() > 1e-10,
"trained model should have non-zero prediction"
);
let fresh = model.clone_fresh();
let fresh_pred = fresh.predict(&features);
assert!(
fresh_pred.abs() < 1e-15,
"fresh clone should predict 0, got {fresh_pred}"
);
}
#[test]
fn mlp_clone_warm_preserves_weights() {
let mut model = MLPLeafModel::new(4, 0.01, 42, None);
let features = vec![1.0, 2.0];
for _ in 0..50 {
model.update(&features, 0.5, 1.0, 0.1);
}
let trained_pred = model.predict(&features);
let warm = model.clone_warm();
let warm_pred = warm.predict(&features);
assert!(
(warm_pred - trained_pred).abs() < 1e-10,
"warm clone should preserve predictions: trained={trained_pred}, warm={warm_pred}"
);
}
#[test]
fn leaf_model_type_default_is_closed_form() {
let default_type = LeafModelType::default();
assert!(
matches!(default_type, LeafModelType::ClosedForm),
"default LeafModelType should be ClosedForm, got {default_type:?}"
);
}
#[test]
fn leaf_model_type_create_all_variants() {
let features = vec![1.0, 2.0, 3.0];
let delta = 1e-7;
let mut closed = LeafModelType::ClosedForm.create(0, delta);
closed.update(&features, 1.0, 1.0, 0.1);
let p = closed.predict(&features);
assert!(p.is_finite(), "ClosedForm prediction should be finite");
let mut linear = LeafModelType::Linear {
learning_rate: 0.01,
decay: None,
use_adagrad: false,
}
.create(0, delta);
linear.update(&features, 1.0, 1.0, 0.1);
let p = linear.predict(&features);
assert!(p.is_finite(), "Linear prediction should be finite");
let mut mlp = LeafModelType::MLP {
hidden_size: 4,
learning_rate: 0.01,
decay: None,
}
.create(99, delta);
mlp.update(&features, 1.0, 1.0, 0.1);
let p = mlp.predict(&features);
assert!(p.is_finite(), "MLP prediction should be finite");
let mut adaptive = LeafModelType::Adaptive {
promote_to: Box::new(LeafModelType::Linear {
learning_rate: 0.01,
decay: None,
use_adagrad: false,
}),
}
.create(42, delta);
adaptive.update(&features, 1.0, 1.0, 0.1);
let p = adaptive.predict(&features);
assert!(p.is_finite(), "Adaptive prediction should be finite");
}
#[test]
fn adaptive_promotes_on_linear_target() {
let promote_to = LeafModelType::Linear {
learning_rate: 0.01,
decay: None,
use_adagrad: false,
};
let shadow = promote_to.create(42, 1e-7);
let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
let mut rng = 42u64;
for _ in 0..5000 {
let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
let features = vec![x1, x2];
let target = 3.0 * x1 + 2.0 * x2;
let pred = model.predict(&features);
let gradient = 2.0 * (pred - target);
let hessian = 2.0;
model.update(&features, gradient, hessian, 0.1);
}
assert!(
model.promoted,
"adaptive model should have promoted on linear target after 5000 samples"
);
}
#[test]
fn adaptive_does_not_promote_on_constant_target() {
let promote_to = LeafModelType::Linear {
learning_rate: 0.01,
decay: None,
use_adagrad: false,
};
let shadow = promote_to.create(42, 1e-7);
let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-7, 42);
for _ in 0..2000 {
let features = vec![1.0, 2.0];
let target = 5.0;
let pred = model.predict(&features);
let gradient = 2.0 * (pred - target);
let hessian = 2.0;
model.update(&features, gradient, hessian, 0.1);
}
let pred = model.predict(&[1.0, 2.0]);
assert!(pred.is_finite(), "prediction should be finite");
assert!(
!model.promoted,
"shadow should not out-perform a converged closed-form leaf on a constant target"
);
}
#[test]
fn adaptive_clone_fresh_resets_promotion() {
let promote_to = LeafModelType::Linear {
learning_rate: 0.01,
decay: None,
use_adagrad: false,
};
let shadow = promote_to.create(42, 1e-3);
let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
let mut rng = 42u64;
for _ in 0..5000 {
let x = rand_f64(&mut rng) * 2.0 - 1.0;
let features = vec![x];
let pred = model.predict(&features);
model.update(&features, 2.0 * (pred - 3.0 * x), 2.0, 0.1);
}
let fresh = model.clone_fresh();
let p = fresh.predict(&[0.5]);
assert!(
p.abs() < 1e-10,
"fresh adaptive clone should predict ~0, got {p}"
);
}
}