// irithyll/tree/leaf_model.rs
//! Pluggable leaf prediction models for streaming decision trees.
//!
//! By default, leaves use a closed-form weight computed from accumulated
//! gradient and hessian sums. Trainable leaf models replace this with learnable
//! functions that capture more complex patterns within each leaf's partition.
//!
//! # Leaf model variants
//!
//! | Model | Prediction | Overhead | Best for |
//! |-------|-----------|----------|----------|
//! | [`ClosedFormLeaf`] | constant weight `-G/(H+lambda)` | zero | general use (default) |
//! | [`LinearLeafModel`] | `w . x + b` (AdaGrad-optimized) | O(d) per update | low-depth trees (2--4) |
//! | [`MLPLeafModel`] | single-hidden-layer neural net | O(d*h) per update | complex local patterns |
//! | [`AdaptiveLeafModel`] | starts constant, auto-promotes | shadow model cost | automatic complexity allocation |
//!
//! # AdaGrad optimization
//!
//! [`LinearLeafModel`] uses per-weight AdaGrad accumulators for adaptive learning
//! rates. Features at different scales converge at their natural rates without
//! manual tuning. Combined with Newton scaling from the hessian, this gives
//! second-order-informed, per-feature adaptive optimization.
//!
//! # Exponential forgetting
//!
//! [`LinearLeafModel`] and [`MLPLeafModel`] support an optional `decay` parameter
//! that applies exponential weight decay before each update. This gives the model
//! a finite memory horizon, adapting to concept drift in non-stationary streams.
//! Typical values: 0.999 (slow drift) to 0.99 (fast drift).
//!
//! # Warm-starting on split
//!
//! When a leaf splits, child leaves can inherit the parent's learned function via
//! [`LeafModel::clone_warm`]. Linear children start with the parent's weights
//! (resetting optimizer state), converging faster than starting from scratch.
//!
//! # Adaptive promotion
//!
//! [`AdaptiveLeafModel`] runs a shadow model alongside the default closed-form
//! model. Both are trained on every sample, and their per-sample losses are
//! compared using the second-order Taylor approximation. When the Hoeffding bound
//! (the tree's existing `delta` parameter) confirms the shadow model is
//! statistically superior, the leaf promotes -- no arbitrary thresholds.
/// A trainable prediction model that lives inside a decision tree leaf.
///
/// Implementations must be `Send + Sync` so trees can be shared across threads.
pub trait LeafModel: Send + Sync {
    /// Produce a prediction given input features.
    fn predict(&self, features: &[f64]) -> f64;

    /// Update model parameters given a gradient, hessian, and regularization lambda.
    ///
    /// `gradient` and `hessian` are the first and second derivatives of the loss
    /// with respect to this leaf's prediction for the current sample (the same
    /// second-order Taylor quantities used elsewhere in this module); `lambda`
    /// is the L2 regularization strength.
    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64);

    /// Create a fresh (zeroed / re-initialized) clone of this model's architecture.
    fn clone_fresh(&self) -> Box<dyn LeafModel>;

    /// Create a warm clone preserving learned weights but resetting optimizer state.
    ///
    /// Used when splitting a leaf: child leaves inherit the parent's learned
    /// function as a starting point, converging faster than starting from scratch.
    /// Defaults to [`clone_fresh`](LeafModel::clone_fresh) for models where
    /// warm-starting is not meaningful (e.g. [`ClosedFormLeaf`]).
    fn clone_warm(&self) -> Box<dyn LeafModel> {
        self.clone_fresh()
    }
}
67
// ---------------------------------------------------------------------------
// ClosedFormLeaf
// ---------------------------------------------------------------------------
71
/// Leaf model that computes the optimal weight in closed form:
/// `weight = -grad_sum / (hess_sum + lambda)`.
///
/// This is the standard leaf value used in gradient boosted trees.
pub struct ClosedFormLeaf {
    /// Running sum of gradients over all samples routed to this leaf.
    grad_sum: f64,
    /// Running sum of hessians over all samples routed to this leaf.
    hess_sum: f64,
    /// Cached optimal weight, recomputed on every `update`.
    weight: f64,
}
81
82impl Default for ClosedFormLeaf {
83    fn default() -> Self {
84        Self {
85            grad_sum: 0.0,
86            hess_sum: 0.0,
87            weight: 0.0,
88        }
89    }
90}
91
92impl ClosedFormLeaf {
93    /// Create a new zeroed closed-form leaf.
94    pub fn new() -> Self {
95        Self::default()
96    }
97}
98
99impl LeafModel for ClosedFormLeaf {
100    fn predict(&self, _features: &[f64]) -> f64 {
101        self.weight
102    }
103
104    fn update(&mut self, _features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
105        self.grad_sum += gradient;
106        self.hess_sum += hessian;
107        self.weight = -self.grad_sum / (self.hess_sum + lambda);
108    }
109
110    fn clone_fresh(&self) -> Box<dyn LeafModel> {
111        Box::new(ClosedFormLeaf::new())
112    }
113}
114
// ---------------------------------------------------------------------------
// LinearLeafModel
// ---------------------------------------------------------------------------
118
/// Online ridge regression leaf model with AdaGrad optimization.
///
/// Learns a linear function `w . x + b` using Newton-scaled gradient descent
/// with per-weight AdaGrad accumulators for adaptive learning rates. Features
/// at different scales converge at their natural rates without manual tuning.
///
/// Weights are lazily initialized on the first `update` call so the model
/// adapts to whatever dimensionality arrives.
///
/// Optional exponential weight decay (`decay`) gives the model a finite memory
/// horizon for non-stationary streams.
pub struct LinearLeafModel {
    /// Learned feature weights; empty until the first `update` sizes them.
    weights: Vec<f64>,
    /// Learned intercept term.
    bias: f64,
    /// Base learning rate before Newton / AdaGrad scaling.
    learning_rate: f64,
    /// Optional exponential weight decay factor; `None` disables forgetting.
    decay: Option<f64>,
    /// Whether per-weight AdaGrad adaptive learning rates are enabled.
    use_adagrad: bool,
    /// Per-weight squared gradient accumulator (AdaGrad).
    sq_grad_accum: Vec<f64>,
    /// Bias squared gradient accumulator (AdaGrad).
    sq_bias_accum: f64,
    /// True once `weights`/`sq_grad_accum` have been sized by the first update.
    initialized: bool,
}
142
143impl LinearLeafModel {
144    /// Create a new linear leaf model with the given base learning rate,
145    /// optional exponential decay factor, and AdaGrad toggle.
146    ///
147    /// When `decay` is `Some(d)` with `d` in (0, 1), weights are multiplied
148    /// by `d` before each update, giving the model a memory half-life of
149    /// `ln(2) / ln(1/d)` samples.
150    ///
151    /// When `use_adagrad` is `true`, per-weight squared gradient accumulators
152    /// give each feature its own adaptive learning rate. When `false`, all
153    /// weights share a single Newton-scaled learning rate (plain SGD).
154    pub fn new(learning_rate: f64, decay: Option<f64>, use_adagrad: bool) -> Self {
155        Self {
156            weights: Vec::new(),
157            bias: 0.0,
158            learning_rate,
159            decay,
160            use_adagrad,
161            sq_grad_accum: Vec::new(),
162            sq_bias_accum: 0.0,
163            initialized: false,
164        }
165    }
166}
167
/// AdaGrad epsilon added to the root of the accumulator to prevent division
/// by zero before any gradient has been observed.
const ADAGRAD_EPS: f64 = 1e-8;
170
171impl LeafModel for LinearLeafModel {
172    fn predict(&self, features: &[f64]) -> f64 {
173        if !self.initialized {
174            return 0.0;
175        }
176        let mut dot = self.bias;
177        for (w, x) in self.weights.iter().zip(features.iter()) {
178            dot += w * x;
179        }
180        dot
181    }
182
183    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
184        if !self.initialized {
185            let d = features.len();
186            self.weights = vec![0.0; d];
187            self.sq_grad_accum = vec![0.0; d];
188            self.initialized = true;
189        }
190
191        // Exponential weight decay (forgetting old data).
192        if let Some(d) = self.decay {
193            for w in self.weights.iter_mut() {
194                *w *= d;
195            }
196            self.bias *= d;
197        }
198
199        // Newton-scaled base learning rate.
200        let base_lr = self.learning_rate / (hessian.abs() + lambda);
201
202        if self.use_adagrad {
203            // AdaGrad: per-weight adaptive learning rates.
204            for (i, (w, x)) in self.weights.iter_mut().zip(features.iter()).enumerate() {
205                let g = gradient * x;
206                self.sq_grad_accum[i] += g * g;
207                let adaptive_lr = base_lr / (self.sq_grad_accum[i].sqrt() + ADAGRAD_EPS);
208                *w -= adaptive_lr * g;
209            }
210            self.sq_bias_accum += gradient * gradient;
211            let bias_lr = base_lr / (self.sq_bias_accum.sqrt() + ADAGRAD_EPS);
212            self.bias -= bias_lr * gradient;
213        } else {
214            // Plain Newton-scaled SGD.
215            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
216                *w -= base_lr * gradient * x;
217            }
218            self.bias -= base_lr * gradient;
219        }
220    }
221
222    fn clone_fresh(&self) -> Box<dyn LeafModel> {
223        Box::new(LinearLeafModel::new(
224            self.learning_rate,
225            self.decay,
226            self.use_adagrad,
227        ))
228    }
229
230    fn clone_warm(&self) -> Box<dyn LeafModel> {
231        Box::new(LinearLeafModel {
232            weights: self.weights.clone(),
233            bias: self.bias,
234            learning_rate: self.learning_rate,
235            decay: self.decay,
236            use_adagrad: self.use_adagrad,
237            // Reset AdaGrad accumulators -- the child's gradient landscape
238            // differs from the parent's, so accumulated curvature estimates
239            // don't transfer. Fresh accumulators let the child's learning
240            // rates adapt to its own region.
241            sq_grad_accum: vec![0.0; self.weights.len()],
242            sq_bias_accum: 0.0,
243            initialized: self.initialized,
244        })
245    }
246}
247
// ---------------------------------------------------------------------------
// MLPLeafModel
// ---------------------------------------------------------------------------
251
/// Single hidden layer MLP leaf model with ReLU activation.
///
/// Learns a nonlinear function via backpropagation with Newton-scaled
/// learning rate. Weights are lazily initialized on the first `update` call
/// using a deterministic xorshift64 PRNG so results are reproducible.
///
/// Optional exponential weight decay (`decay`) gives the model a finite memory
/// horizon for non-stationary streams.
pub struct MLPLeafModel {
    /// Hidden layer weight matrix, row per hidden unit.
    hidden_weights: Vec<Vec<f64>>, // [hidden_size][input_size]
    /// Hidden layer bias, one per hidden unit.
    hidden_bias: Vec<f64>,
    /// Output layer weights, one per hidden unit.
    output_weights: Vec<f64>,
    /// Output layer bias.
    output_bias: f64,
    /// Number of hidden units.
    hidden_size: usize,
    /// Base learning rate before Newton scaling.
    learning_rate: f64,
    /// Optional exponential weight decay factor; `None` disables forgetting.
    decay: Option<f64>,
    /// Seed for deterministic weight initialization.
    seed: u64,
    /// True once `initialize` has run (on the first `update`).
    initialized: bool,
    /// Scratch buffer: post-ReLU activations cached by `forward` for backprop.
    hidden_activations: Vec<f64>,
    /// Scratch buffer: pre-ReLU sums cached by `forward` for backprop.
    hidden_pre_activations: Vec<f64>,
}
273
274impl MLPLeafModel {
275    /// Create a new MLP leaf model with the given hidden layer size, learning rate,
276    /// seed, and optional decay.
277    ///
278    /// The seed controls deterministic weight initialization. Different seeds
279    /// produce different initial weights, which is critical for ensemble diversity
280    /// when multiple MLP leaves share the same `hidden_size`.
281    pub fn new(hidden_size: usize, learning_rate: f64, seed: u64, decay: Option<f64>) -> Self {
282        Self {
283            hidden_weights: Vec::new(),
284            hidden_bias: Vec::new(),
285            output_weights: Vec::new(),
286            output_bias: 0.0,
287            hidden_size,
288            learning_rate,
289            decay,
290            seed,
291            initialized: false,
292            hidden_activations: Vec::new(),
293            hidden_pre_activations: Vec::new(),
294        }
295    }
296
297    /// Initialize weights using xorshift64, scaled to [-0.1, 0.1].
298    fn initialize(&mut self, input_size: usize) {
299        let mut state = self.seed ^ (self.hidden_size as u64);
300
301        self.hidden_weights = Vec::with_capacity(self.hidden_size);
302        for _ in 0..self.hidden_size {
303            let mut row = Vec::with_capacity(input_size);
304            for _ in 0..input_size {
305                let r = xorshift64(&mut state);
306                // Map u64 to [-0.1, 0.1]
307                let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
308                row.push(val);
309            }
310            self.hidden_weights.push(row);
311        }
312
313        self.hidden_bias = Vec::with_capacity(self.hidden_size);
314        for _ in 0..self.hidden_size {
315            let r = xorshift64(&mut state);
316            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
317            self.hidden_bias.push(val);
318        }
319
320        self.output_weights = Vec::with_capacity(self.hidden_size);
321        for _ in 0..self.hidden_size {
322            let r = xorshift64(&mut state);
323            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
324            self.output_weights.push(val);
325        }
326
327        {
328            let r = xorshift64(&mut state);
329            self.output_bias = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
330        }
331
332        self.hidden_activations = vec![0.0; self.hidden_size];
333        self.hidden_pre_activations = vec![0.0; self.hidden_size];
334        self.initialized = true;
335    }
336
337    /// Forward pass: compute hidden pre-activations, ReLU activations, and output.
338    fn forward(&mut self, features: &[f64]) -> f64 {
339        // Hidden layer
340        for h in 0..self.hidden_size {
341            let mut z = self.hidden_bias[h];
342            for (j, x) in features.iter().enumerate() {
343                if j < self.hidden_weights[h].len() {
344                    z += self.hidden_weights[h][j] * x;
345                }
346            }
347            self.hidden_pre_activations[h] = z;
348            // ReLU
349            self.hidden_activations[h] = if z > 0.0 { z } else { 0.0 };
350        }
351
352        // Output layer
353        let mut out = self.output_bias;
354        for (w, a) in self
355            .output_weights
356            .iter()
357            .zip(self.hidden_activations.iter())
358        {
359            out += w * a;
360        }
361        out
362    }
363}
364
365impl LeafModel for MLPLeafModel {
366    fn predict(&self, features: &[f64]) -> f64 {
367        if !self.initialized {
368            return 0.0;
369        }
370        // Non-mutating forward pass (can't store activations, recompute locally)
371        let hidden_acts: Vec<f64> = self
372            .hidden_weights
373            .iter()
374            .zip(self.hidden_bias.iter())
375            .map(|(hw, &hb)| {
376                let mut z = hb;
377                for (j, x) in features.iter().enumerate() {
378                    if j < hw.len() {
379                        z += hw[j] * x;
380                    }
381                }
382                if z > 0.0 {
383                    z
384                } else {
385                    0.0
386                }
387            })
388            .collect();
389        let mut out = self.output_bias;
390        for (w, a) in self.output_weights.iter().zip(hidden_acts.iter()) {
391            out += w * a;
392        }
393        out
394    }
395
396    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
397        if !self.initialized {
398            self.initialize(features.len());
399        }
400
401        // Exponential weight decay (forgetting old data).
402        if let Some(d) = self.decay {
403            for row in self.hidden_weights.iter_mut() {
404                for w in row.iter_mut() {
405                    *w *= d;
406                }
407            }
408            for b in self.hidden_bias.iter_mut() {
409                *b *= d;
410            }
411            for w in self.output_weights.iter_mut() {
412                *w *= d;
413            }
414            self.output_bias *= d;
415        }
416
417        // Forward pass (stores activations for backprop)
418        let _output = self.forward(features);
419
420        let effective_lr = self.learning_rate / (hessian.abs() + lambda);
421
422        // Backprop: output gradient is the incoming `gradient` (chain rule from loss)
423        let d_output = gradient;
424
425        // Gradient for output weights and bias
426        // d_loss/d_output_w[h] = d_output * hidden_activations[h]
427        // d_loss/d_output_bias = d_output
428        for h in 0..self.hidden_size {
429            self.output_weights[h] -= effective_lr * d_output * self.hidden_activations[h];
430        }
431        self.output_bias -= effective_lr * d_output;
432
433        // Gradient for hidden layer
434        for h in 0..self.hidden_size {
435            // d_loss/d_hidden_act[h] = d_output * output_weights[h]
436            let d_hidden_act = d_output * self.output_weights[h];
437
438            // ReLU derivative
439            let d_relu = if self.hidden_pre_activations[h] > 0.0 {
440                d_hidden_act
441            } else {
442                0.0
443            };
444
445            // Update hidden weights and bias
446            for (j, x) in features.iter().enumerate() {
447                if j < self.hidden_weights[h].len() {
448                    self.hidden_weights[h][j] -= effective_lr * d_relu * x;
449                }
450            }
451            self.hidden_bias[h] -= effective_lr * d_relu;
452        }
453    }
454
455    fn clone_fresh(&self) -> Box<dyn LeafModel> {
456        // Derive a new seed so each fresh clone gets distinct initial weights.
457        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
458        Box::new(MLPLeafModel::new(
459            self.hidden_size,
460            self.learning_rate,
461            derived_seed,
462            self.decay,
463        ))
464    }
465
466    fn clone_warm(&self) -> Box<dyn LeafModel> {
467        Box::new(MLPLeafModel {
468            hidden_weights: self.hidden_weights.clone(),
469            hidden_bias: self.hidden_bias.clone(),
470            output_weights: self.output_weights.clone(),
471            output_bias: self.output_bias,
472            hidden_size: self.hidden_size,
473            learning_rate: self.learning_rate,
474            decay: self.decay,
475            // Derive a new seed so warm clones diverge if they re-initialize.
476            seed: self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(2),
477            initialized: self.initialized,
478            hidden_activations: vec![0.0; self.hidden_size],
479            hidden_pre_activations: vec![0.0; self.hidden_size],
480        })
481    }
482}
483
// ---------------------------------------------------------------------------
// AdaptiveLeafModel
// ---------------------------------------------------------------------------
487
/// Leaf model that starts as closed-form and promotes to a more complex model
/// when the Hoeffding bound confirms it is statistically superior.
///
/// Runs a shadow model alongside the active (closed-form) model. On each
/// update, both models are trained and their per-sample losses compared using
/// the second-order Taylor approximation:
///
/// ```text
/// loss_i = gradient * prediction + 0.5 * hessian * prediction^2
/// advantage_i = loss_active_i - loss_shadow_i
/// ```
///
/// When `mean(advantage) > epsilon` where epsilon is the Hoeffding bound
/// (using the tree's `delta` parameter), the shadow model is promoted to
/// active and the overhead drops to zero.
///
/// This uses the **same statistical guarantee** as the tree's split decisions --
/// no arbitrary thresholds.
pub struct AdaptiveLeafModel {
    /// The currently active prediction model. Starts as ClosedForm.
    active: Box<dyn LeafModel>,
    /// Shadow model being evaluated against the active model.
    shadow: Box<dyn LeafModel>,
    /// Configuration of the shadow model type (for cloning).
    promote_to: LeafModelType,
    /// Cumulative loss advantage: sum(loss_active - loss_shadow).
    /// Positive means the shadow is winning.
    cumulative_advantage: f64,
    /// Number of samples seen for the Hoeffding bound.
    n: u64,
    /// Running maximum |loss_diff| for range estimation (R in the bound).
    max_loss_diff: f64,
    /// Hoeffding confidence parameter (from tree config).
    delta: f64,
    /// Whether the shadow has been promoted.
    promoted: bool,
    /// Seed for reproducible cloning.
    seed: u64,
}
527
528impl AdaptiveLeafModel {
529    /// Create a new adaptive leaf model.
530    ///
531    /// The active model starts as `ClosedFormLeaf`. The shadow model is the
532    /// candidate that will be promoted if it proves statistically superior.
533    pub fn new(
534        shadow: Box<dyn LeafModel>,
535        promote_to: LeafModelType,
536        delta: f64,
537        seed: u64,
538    ) -> Self {
539        Self {
540            active: Box::new(ClosedFormLeaf::new()),
541            shadow,
542            promote_to,
543            cumulative_advantage: 0.0,
544            n: 0,
545            max_loss_diff: 0.0,
546            delta,
547            promoted: false,
548            seed,
549        }
550    }
551}
552
553impl LeafModel for AdaptiveLeafModel {
554    fn predict(&self, features: &[f64]) -> f64 {
555        self.active.predict(features)
556    }
557
558    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
559        if self.promoted {
560            // Post-promotion: only update the promoted model.
561            self.active.update(features, gradient, hessian, lambda);
562            return;
563        }
564
565        // Compute predictions BEFORE updating (evaluate current state).
566        let pred_active = self.active.predict(features);
567        let pred_shadow = self.shadow.predict(features);
568
569        // Second-order Taylor loss approximation for each model.
570        // L(pred) ~= gradient * pred + 0.5 * hessian * pred^2
571        // This is the same loss proxy that XGBoost gain uses.
572        let loss_active = gradient * pred_active + 0.5 * hessian * pred_active * pred_active;
573        let loss_shadow = gradient * pred_shadow + 0.5 * hessian * pred_shadow * pred_shadow;
574
575        // Positive advantage means the shadow model is better (lower loss).
576        let diff = loss_active - loss_shadow;
577        self.cumulative_advantage += diff;
578        self.n += 1;
579
580        // Track range for the Hoeffding bound.
581        let abs_diff = diff.abs();
582        if abs_diff > self.max_loss_diff {
583            self.max_loss_diff = abs_diff;
584        }
585
586        // Update both models.
587        self.active.update(features, gradient, hessian, lambda);
588        self.shadow.update(features, gradient, hessian, lambda);
589
590        // Hoeffding bound test: is the shadow statistically better?
591        // epsilon = sqrt(R^2 * ln(1/delta) / (2*n))
592        // Promote when mean_advantage > epsilon.
593        if self.n >= 10 && self.max_loss_diff > 0.0 {
594            let mean_advantage = self.cumulative_advantage / self.n as f64;
595            if mean_advantage > 0.0 {
596                let r_squared = self.max_loss_diff * self.max_loss_diff;
597                let ln_inv_delta = (1.0 / self.delta).ln();
598                let epsilon = (r_squared * ln_inv_delta / (2.0 * self.n as f64)).sqrt();
599
600                if mean_advantage > epsilon {
601                    // Promote: swap shadow into active, drop the old active.
602                    self.promoted = true;
603                    std::mem::swap(&mut self.active, &mut self.shadow);
604                }
605            }
606        }
607    }
608
609    fn clone_fresh(&self) -> Box<dyn LeafModel> {
610        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
611        Box::new(AdaptiveLeafModel::new(
612            self.promote_to.create(derived_seed, self.delta),
613            self.promote_to.clone(),
614            self.delta,
615            derived_seed,
616        ))
617    }
618    // clone_warm: default (= clone_fresh). Promotion must be re-earned
619    // in each leaf's region -- the parent's promotion does not transfer.
620}
621
622// We need Send + Sync for AdaptiveLeafModel because it contains Box<dyn LeafModel>
623// which already requires Send + Sync. The compiler should derive these automatically,
624// but let's verify:
625// SAFETY: All fields are Send + Sync (Box<dyn LeafModel> requires it, f64/u64/bool are Send+Sync).
626unsafe impl Send for AdaptiveLeafModel {}
627unsafe impl Sync for AdaptiveLeafModel {}
628
// ---------------------------------------------------------------------------
// LeafModelType
// ---------------------------------------------------------------------------
632
/// Describes which leaf model architecture to use.
///
/// Used by tree builders to construct fresh leaf models when creating new leaves.
///
/// # Variants
///
/// - **`ClosedForm`** -- Standard constant leaf weight. Zero overhead (default).
/// - **`Linear`** -- Per-leaf online ridge regression with AdaGrad optimization.
///   Each leaf learns a local linear surface `w . x + b`. Recommended for
///   low-depth trees (depth 2--4). Optional `decay` for concept drift.
/// - **`MLP`** -- Per-leaf single-hidden-layer neural network with ReLU.
///   Optional `decay` for concept drift.
/// - **`Adaptive`** -- Starts as closed-form, auto-promotes to `promote_to`
///   when the Hoeffding bound confirms it is statistically superior. Uses the
///   tree's existing `delta` parameter -- no arbitrary thresholds.
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LeafModelType {
    /// Standard closed-form leaf weight.
    #[default]
    ClosedForm,

    /// Online ridge regression, optionally with AdaGrad optimization.
    ///
    /// `decay`: optional exponential weight decay for non-stationary streams.
    /// Typical values: 0.999 (slow drift) to 0.99 (fast drift).
    ///
    /// `use_adagrad`: when `true`, per-weight squared gradient accumulators
    /// give each feature its own adaptive learning rate. When `false`
    /// (default), all weights share a single Newton-scaled learning rate.
    Linear {
        learning_rate: f64,
        #[serde(default)]
        decay: Option<f64>,
        #[serde(default)]
        use_adagrad: bool,
    },

    /// Single hidden layer MLP with the given hidden size and learning rate.
    ///
    /// `decay`: optional exponential weight decay for non-stationary streams.
    MLP {
        hidden_size: usize,
        learning_rate: f64,
        #[serde(default)]
        decay: Option<f64>,
    },

    /// Adaptive leaf that starts as closed-form and auto-promotes when
    /// the Hoeffding bound confirms the promoted model is better.
    ///
    /// The `promote_to` field specifies the shadow model type to evaluate
    /// against the default closed-form baseline. Boxed because the enum is
    /// recursive.
    Adaptive { promote_to: Box<LeafModelType> },
}
687
688impl LeafModelType {
689    /// Create a fresh boxed leaf model of this type.
690    ///
691    /// The `seed` parameter controls deterministic initialization (MLP weights,
692    /// adaptive model seeding). The `delta` parameter is the Hoeffding bound
693    /// confidence level, used by [`Adaptive`](LeafModelType::Adaptive) leaves
694    /// for promotion decisions. For non-adaptive types, `delta` is unused.
695    pub fn create(&self, seed: u64, delta: f64) -> Box<dyn LeafModel> {
696        match self {
697            Self::ClosedForm => Box::new(ClosedFormLeaf::new()),
698            Self::Linear {
699                learning_rate,
700                decay,
701                use_adagrad,
702            } => Box::new(LinearLeafModel::new(*learning_rate, *decay, *use_adagrad)),
703            Self::MLP {
704                hidden_size,
705                learning_rate,
706                decay,
707            } => Box::new(MLPLeafModel::new(
708                *hidden_size,
709                *learning_rate,
710                seed,
711                *decay,
712            )),
713            Self::Adaptive { promote_to } => Box::new(AdaptiveLeafModel::new(
714                promote_to.create(seed, delta),
715                *promote_to.clone(),
716                delta,
717                seed,
718            )),
719        }
720    }
721}
722
// ---------------------------------------------------------------------------
// Shared utility
// ---------------------------------------------------------------------------
726
/// Xorshift64 PRNG for deterministic weight initialization.
///
/// Advances `state` in place and returns the new value. Classic xorshift
/// caveat: zero is a fixed point (a zero state stays zero forever), so
/// callers must supply a nonzero state.
fn xorshift64(state: &mut u64) -> u64 {
    let mut x = *state;
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    *state = x;
    x
}
736
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
740
741#[cfg(test)]
742mod tests {
743    use super::*;
744
745    /// Xorshift64 for deterministic test data generation.
746    fn xorshift64(state: &mut u64) -> u64 {
747        let mut s = *state;
748        s ^= s << 13;
749        s ^= s >> 7;
750        s ^= s << 17;
751        *state = s;
752        s
753    }
754
755    /// Convert xorshift output to f64 in [0, 1).
756    fn rand_f64(state: &mut u64) -> f64 {
757        xorshift64(state) as f64 / u64::MAX as f64
758    }
759
760    #[test]
761    fn closed_form_matches_formula() {
762        let mut leaf = ClosedFormLeaf::new();
763        let lambda = 1.0;
764
765        // Accumulate several gradient/hessian pairs
766        let updates = [(0.5, 1.0), (-0.3, 0.8), (1.2, 2.0), (-0.1, 0.5)];
767        let mut grad_sum = 0.0;
768        let mut hess_sum = 0.0;
769
770        for &(g, h) in &updates {
771            leaf.update(&[], g, h, lambda);
772            grad_sum += g;
773            hess_sum += h;
774        }
775
776        let expected = -grad_sum / (hess_sum + lambda);
777        let predicted = leaf.predict(&[]);
778
779        assert!(
780            (predicted - expected).abs() < 1e-12,
781            "closed form mismatch: got {predicted}, expected {expected}"
782        );
783    }
784
785    #[test]
786    fn closed_form_clone_fresh_resets() {
787        let mut leaf = ClosedFormLeaf::new();
788        leaf.update(&[], 5.0, 2.0, 1.0);
789        assert!(
790            leaf.predict(&[]).abs() > 0.0,
791            "leaf should have non-zero weight after update"
792        );
793
794        let fresh = leaf.clone_fresh();
795        assert!(
796            fresh.predict(&[]).abs() < 1e-15,
797            "fresh clone should predict 0, got {}",
798            fresh.predict(&[])
799        );
800    }
801
802    #[test]
803    fn linear_converges_on_linear_target() {
804        // Target: y = 2*x1 + 3*x2
805        let mut model = LinearLeafModel::new(0.01, None, false);
806        let lambda = 0.1;
807        let mut rng = 42u64;
808
809        for _ in 0..2000 {
810            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
811            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
812            let features = vec![x1, x2];
813            let target = 2.0 * x1 + 3.0 * x2;
814
815            let pred = model.predict(&features);
816            let gradient = 2.0 * (pred - target);
817            let hessian = 2.0;
818            model.update(&features, gradient, hessian, lambda);
819        }
820
821        let test_features = vec![0.5, -0.3];
822        let target = 2.0 * 0.5 + 3.0 * (-0.3);
823        let pred = model.predict(&test_features);
824
825        assert!(
826            (pred - target).abs() < 1.0,
827            "linear model should converge within 1.0 of target: pred={pred}, target={target}"
828        );
829    }
830
831    #[test]
832    fn linear_uninitialized_predicts_zero() {
833        let model = LinearLeafModel::new(0.01, None, false);
834        let pred = model.predict(&[1.0, 2.0, 3.0]);
835        assert!(
836            pred.abs() < 1e-15,
837            "uninitialized linear model should predict 0, got {pred}"
838        );
839    }
840
841    #[test]
842    fn linear_clone_warm_preserves_weights() {
843        let mut model = LinearLeafModel::new(0.01, None, false);
844        let features = vec![1.0, 2.0];
845
846        // Train it
847        for i in 0..100 {
848            let target = 3.0 * features[0] + 2.0 * features[1];
849            let pred = model.predict(&features);
850            let gradient = 2.0 * (pred - target);
851            model.update(&features, gradient, 2.0, 0.1);
852            // Avoid unused variable warning
853            let _ = i;
854        }
855
856        let trained_pred = model.predict(&features);
857        assert!(
858            trained_pred.abs() > 0.01,
859            "model should have learned something"
860        );
861
862        // Warm clone should predict similarly
863        let warm = model.clone_warm();
864        let warm_pred = warm.predict(&features);
865        assert!(
866            (warm_pred - trained_pred).abs() < 1e-12,
867            "warm clone should preserve weights: trained={trained_pred}, warm={warm_pred}"
868        );
869
870        // Fresh clone should predict 0
871        let fresh = model.clone_fresh();
872        let fresh_pred = fresh.predict(&features);
873        assert!(
874            fresh_pred.abs() < 1e-15,
875            "fresh clone should predict 0, got {fresh_pred}"
876        );
877    }
878
879    #[test]
880    fn linear_decay_forgets_old_data() {
881        // Decay should pull weights toward zero when training stops,
882        // demonstrating that old learned state is forgotten.
883        let mut model_decay = LinearLeafModel::new(0.05, Some(0.99), false);
884        let mut model_no_decay = LinearLeafModel::new(0.05, None, false);
885        let features = vec![1.0];
886        let lambda = 0.1;
887
888        // Train both on target = 5.0
889        for _ in 0..500 {
890            let pred_d = model_decay.predict(&features);
891            let pred_n = model_no_decay.predict(&features);
892            model_decay.update(&features, 2.0 * (pred_d - 5.0), 2.0, lambda);
893            model_no_decay.update(&features, 2.0 * (pred_n - 5.0), 2.0, lambda);
894        }
895
896        // Both should have learned roughly the same function.
897        let pred_d_trained = model_decay.predict(&features);
898        let pred_n_trained = model_no_decay.predict(&features);
899        assert!(
900            (pred_d_trained - 5.0).abs() < 2.0,
901            "decay model should approximate target"
902        );
903        assert!(
904            (pred_n_trained - 5.0).abs() < 2.0,
905            "no-decay model should approximate target"
906        );
907
908        // Now apply 200 rounds of zero-gradient updates (data stopped).
909        // The decay model's weights should shrink toward zero,
910        // while the no-decay model retains its learned state.
911        for _ in 0..200 {
912            model_decay.update(&features, 0.0, 1.0, lambda);
913            model_no_decay.update(&features, 0.0, 1.0, lambda);
914        }
915
916        let pred_d_after = model_decay.predict(&features);
917        let pred_n_after = model_no_decay.predict(&features);
918
919        // Decay model should have drifted toward zero.
920        assert!(
921            pred_d_after.abs() < pred_n_after.abs(),
922            "decay model should forget: decay pred={pred_d_after:.3}, no-decay pred={pred_n_after:.3}"
923        );
924    }
925
926    #[test]
927    fn mlp_produces_finite_predictions() {
928        let model_uninit = MLPLeafModel::new(4, 0.01, 42, None);
929        let features = vec![1.0, 2.0, 3.0];
930
931        let pred_before = model_uninit.predict(&features);
932        assert!(
933            pred_before.is_finite(),
934            "uninit prediction should be finite"
935        );
936        assert!(
937            pred_before.abs() < 1e-15,
938            "uninit prediction should be 0, got {pred_before}"
939        );
940
941        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
942        for _ in 0..10 {
943            model.update(&features, 0.5, 1.0, 0.1);
944        }
945        let pred_after = model.predict(&features);
946        assert!(
947            pred_after.is_finite(),
948            "prediction after training should be finite, got {pred_after}"
949        );
950    }
951
952    #[test]
953    fn mlp_loss_decreases() {
954        let mut model = MLPLeafModel::new(8, 0.05, 123, None);
955        let features = vec![1.0, -0.5, 0.3];
956        let target = 2.5;
957        let lambda = 0.1;
958
959        model.update(&features, 0.0, 1.0, lambda); // dummy update to initialize
960        let initial_pred = model.predict(&features);
961        let initial_error = (initial_pred - target).abs();
962
963        for _ in 0..200 {
964            let pred = model.predict(&features);
965            let gradient = 2.0 * (pred - target);
966            let hessian = 2.0;
967            model.update(&features, gradient, hessian, lambda);
968        }
969
970        let final_pred = model.predict(&features);
971        let final_error = (final_pred - target).abs();
972
973        assert!(
974            final_error < initial_error,
975            "MLP error should decrease: initial={initial_error}, final={final_error}"
976        );
977    }
978
979    #[test]
980    fn mlp_clone_fresh_resets() {
981        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
982        let features = vec![1.0, 2.0];
983
984        for _ in 0..20 {
985            model.update(&features, 0.5, 1.0, 0.1);
986        }
987
988        let trained_pred = model.predict(&features);
989        assert!(
990            trained_pred.abs() > 1e-10,
991            "trained model should have non-zero prediction"
992        );
993
994        let fresh = model.clone_fresh();
995        let fresh_pred = fresh.predict(&features);
996        assert!(
997            fresh_pred.abs() < 1e-15,
998            "fresh clone should predict 0, got {fresh_pred}"
999        );
1000    }
1001
1002    #[test]
1003    fn mlp_clone_warm_preserves_weights() {
1004        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
1005        let features = vec![1.0, 2.0];
1006
1007        for _ in 0..50 {
1008            model.update(&features, 0.5, 1.0, 0.1);
1009        }
1010
1011        let trained_pred = model.predict(&features);
1012        let warm = model.clone_warm();
1013        let warm_pred = warm.predict(&features);
1014
1015        assert!(
1016            (warm_pred - trained_pred).abs() < 1e-10,
1017            "warm clone should preserve predictions: trained={trained_pred}, warm={warm_pred}"
1018        );
1019    }
1020
1021    #[test]
1022    fn leaf_model_type_default_is_closed_form() {
1023        let default_type = LeafModelType::default();
1024        assert!(
1025            matches!(default_type, LeafModelType::ClosedForm),
1026            "default LeafModelType should be ClosedForm, got {default_type:?}"
1027        );
1028    }
1029
1030    #[test]
1031    fn leaf_model_type_create_all_variants() {
1032        let features = vec![1.0, 2.0, 3.0];
1033        let delta = 1e-7;
1034
1035        // ClosedForm
1036        let mut closed = LeafModelType::ClosedForm.create(0, delta);
1037        closed.update(&features, 1.0, 1.0, 0.1);
1038        let p = closed.predict(&features);
1039        assert!(p.is_finite(), "ClosedForm prediction should be finite");
1040
1041        // Linear
1042        let mut linear = LeafModelType::Linear {
1043            learning_rate: 0.01,
1044            decay: None,
1045            use_adagrad: false,
1046        }
1047        .create(0, delta);
1048        linear.update(&features, 1.0, 1.0, 0.1);
1049        let p = linear.predict(&features);
1050        assert!(p.is_finite(), "Linear prediction should be finite");
1051
1052        // MLP
1053        let mut mlp = LeafModelType::MLP {
1054            hidden_size: 4,
1055            learning_rate: 0.01,
1056            decay: None,
1057        }
1058        .create(99, delta);
1059        mlp.update(&features, 1.0, 1.0, 0.1);
1060        let p = mlp.predict(&features);
1061        assert!(p.is_finite(), "MLP prediction should be finite");
1062
1063        // Adaptive (promoting to Linear)
1064        let mut adaptive = LeafModelType::Adaptive {
1065            promote_to: Box::new(LeafModelType::Linear {
1066                learning_rate: 0.01,
1067                decay: None,
1068                use_adagrad: false,
1069            }),
1070        }
1071        .create(42, delta);
1072        adaptive.update(&features, 1.0, 1.0, 0.1);
1073        let p = adaptive.predict(&features);
1074        assert!(p.is_finite(), "Adaptive prediction should be finite");
1075    }
1076
1077    #[test]
1078    fn adaptive_promotes_on_linear_target() {
1079        // The shadow (Linear) should eventually outperform ClosedForm
1080        // on a target that varies with features.
1081        let promote_to = LeafModelType::Linear {
1082            learning_rate: 0.01,
1083            decay: None,
1084            use_adagrad: false,
1085        };
1086        let shadow = promote_to.create(42, 1e-7);
1087        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
1088
1089        let mut rng = 42u64;
1090        for _ in 0..5000 {
1091            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1092            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1093            let features = vec![x1, x2];
1094            let target = 3.0 * x1 + 2.0 * x2;
1095
1096            let pred = model.predict(&features);
1097            let gradient = 2.0 * (pred - target);
1098            let hessian = 2.0;
1099            model.update(&features, gradient, hessian, 0.1);
1100        }
1101
1102        // After enough samples on a linear target, the Linear shadow
1103        // should have been promoted.
1104        assert!(
1105            model.promoted,
1106            "adaptive model should have promoted on linear target after 5000 samples"
1107        );
1108    }
1109
1110    #[test]
1111    fn adaptive_does_not_promote_on_constant_target() {
1112        // On a constant target, ClosedForm is optimal -- no promotion expected.
1113        let promote_to = LeafModelType::Linear {
1114            learning_rate: 0.01,
1115            decay: None,
1116            use_adagrad: false,
1117        };
1118        let shadow = promote_to.create(42, 1e-7);
1119        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-7, 42);
1120
1121        for _ in 0..2000 {
1122            let features = vec![1.0, 2.0];
1123            let target = 5.0; // constant -- no feature dependence
1124            let pred = model.predict(&features);
1125            let gradient = 2.0 * (pred - target);
1126            let hessian = 2.0;
1127            model.update(&features, gradient, hessian, 0.1);
1128        }
1129
1130        // With strict delta (1e-7) and a constant target, the Linear model
1131        // shouldn't gain a statistically significant advantage.
1132        // (It may or may not promote -- the point is it shouldn't be obvious.)
1133        // This test mainly verifies the mechanism doesn't crash and is conservative.
1134        let pred = model.predict(&[1.0, 2.0]);
1135        assert!(pred.is_finite(), "prediction should be finite");
1136    }
1137
1138    #[test]
1139    fn adaptive_clone_fresh_resets_promotion() {
1140        let promote_to = LeafModelType::Linear {
1141            learning_rate: 0.01,
1142            decay: None,
1143            use_adagrad: false,
1144        };
1145        let shadow = promote_to.create(42, 1e-3);
1146        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
1147
1148        // Force promotion via many linear-target samples
1149        let mut rng = 42u64;
1150        for _ in 0..5000 {
1151            let x = rand_f64(&mut rng) * 2.0 - 1.0;
1152            let features = vec![x];
1153            let pred = model.predict(&features);
1154            model.update(&features, 2.0 * (pred - 3.0 * x), 2.0, 0.1);
1155        }
1156
1157        let fresh = model.clone_fresh();
1158        // Fresh clone should predict 0 (reset state).
1159        let p = fresh.predict(&[0.5]);
1160        assert!(
1161            p.abs() < 1e-10,
1162            "fresh adaptive clone should predict ~0, got {p}"
1163        );
1164    }
1165}