// Source file: irithyll_core/tree/leaf_model.rs

1//! Pluggable leaf prediction models for streaming decision trees.
2//!
3//! By default, leaves use a closed-form weight computed from accumulated
4//! gradient and hessian sums. Trainable leaf models replace this with learnable
5//! functions that capture more complex patterns within each leaf's partition.
6//!
7//! # Leaf model variants
8//!
9//! | Model | Prediction | Overhead | Best for |
10//! |-------|-----------|----------|----------|
11//! | [`ClosedFormLeaf`] | constant weight `-G/(H+lambda)` | zero | general use (default) |
12//! | [`LinearLeafModel`] | `w . x + b` (AdaGrad-optimized) | O(d) per update | low-depth trees (2--4) |
13//! | [`MLPLeafModel`] | single-hidden-layer neural net | O(d*h) per update | complex local patterns |
14//! | [`AdaptiveLeafModel`] | starts constant, auto-promotes | shadow model cost | automatic complexity allocation |
15//!
16//! # AdaGrad optimization
17//!
18//! [`LinearLeafModel`] uses per-weight AdaGrad accumulators for adaptive learning
19//! rates. Features at different scales converge at their natural rates without
20//! manual tuning. Combined with Newton scaling from the hessian, this gives
21//! second-order-informed, per-feature adaptive optimization.
22//!
23//! # Exponential forgetting
24//!
25//! [`LinearLeafModel`] and [`MLPLeafModel`] support an optional `decay` parameter
26//! that applies exponential weight decay before each update. This gives the model
27//! a finite memory horizon, adapting to concept drift in non-stationary streams.
28//! Typical values: 0.999 (slow drift) to 0.99 (fast drift).
29//!
30//! # Warm-starting on split
31//!
32//! When a leaf splits, child leaves can inherit the parent's learned function via
33//! [`LeafModel::clone_warm`]. Linear children start with the parent's weights
34//! (resetting optimizer state), converging faster than starting from scratch.
35//!
36//! # Adaptive promotion
37//!
38//! [`AdaptiveLeafModel`] runs a shadow model alongside the default closed-form
39//! model. Both are trained on every sample, and their per-sample losses are
40//! compared using the second-order Taylor approximation. When the Hoeffding bound
41//! (the tree's existing `delta` parameter) confirms the shadow model is
42//! statistically superior, the leaf promotes -- no arbitrary thresholds.
43
44use alloc::boxed::Box;
45use alloc::vec;
46use alloc::vec::Vec;
47
48use crate::math;
49
/// A trainable prediction model that lives inside a decision tree leaf.
///
/// Implementations must be `Send + Sync` so trees can be shared across threads.
pub trait LeafModel: Send + Sync {
    /// Produce a prediction given input features.
    fn predict(&self, features: &[f64]) -> f64;

    /// Update model parameters given a gradient, hessian, and regularization lambda.
    ///
    /// Called once per training sample. Implementations may lazily size their
    /// internal state from `features.len()` on the first call.
    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64);

    /// Create a fresh (zeroed / re-initialized) clone of this model's architecture.
    fn clone_fresh(&self) -> Box<dyn LeafModel>;

    /// Create a warm clone preserving learned weights but resetting optimizer state.
    ///
    /// Used when splitting a leaf: child leaves inherit the parent's learned
    /// function as a starting point, converging faster than starting from scratch.
    /// Defaults to [`clone_fresh`](LeafModel::clone_fresh) for models where
    /// warm-starting is not meaningful (e.g. [`ClosedFormLeaf`]).
    fn clone_warm(&self) -> Box<dyn LeafModel> {
        self.clone_fresh()
    }
}
73
74// ---------------------------------------------------------------------------
75// ClosedFormLeaf
76// ---------------------------------------------------------------------------
77
78/// Leaf model that computes the optimal weight in closed form:
79/// `weight = -grad_sum / (hess_sum + lambda)`.
80///
81/// This is the standard leaf value used in gradient boosted trees.
82pub struct ClosedFormLeaf {
83    grad_sum: f64,
84    hess_sum: f64,
85    weight: f64,
86}
87
88impl Default for ClosedFormLeaf {
89    fn default() -> Self {
90        Self {
91            grad_sum: 0.0,
92            hess_sum: 0.0,
93            weight: 0.0,
94        }
95    }
96}
97
98impl ClosedFormLeaf {
99    /// Create a new zeroed closed-form leaf.
100    pub fn new() -> Self {
101        Self::default()
102    }
103}
104
105impl LeafModel for ClosedFormLeaf {
106    fn predict(&self, _features: &[f64]) -> f64 {
107        self.weight
108    }
109
110    fn update(&mut self, _features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
111        self.grad_sum += gradient;
112        self.hess_sum += hessian;
113        self.weight = -self.grad_sum / (self.hess_sum + lambda);
114    }
115
116    fn clone_fresh(&self) -> Box<dyn LeafModel> {
117        Box::new(ClosedFormLeaf::new())
118    }
119}
120
121// ---------------------------------------------------------------------------
122// LinearLeafModel
123// ---------------------------------------------------------------------------
124
/// Online ridge regression leaf model with AdaGrad optimization.
///
/// Learns a linear function `w . x + b` using Newton-scaled gradient descent
/// with per-weight AdaGrad accumulators for adaptive learning rates. Features
/// at different scales converge at their natural rates without manual tuning.
///
/// Weights are lazily initialized on the first `update` call so the model
/// adapts to whatever dimensionality arrives.
///
/// Optional exponential weight decay (`decay`) gives the model a finite memory
/// horizon for non-stationary streams.
pub struct LinearLeafModel {
    /// Learned feature weights; lazily sized on the first `update`.
    weights: Vec<f64>,
    /// Learned intercept term.
    bias: f64,
    /// Base learning rate, before Newton scaling by the hessian.
    learning_rate: f64,
    /// Optional exponential forgetting factor in (0, 1); `None` disables decay.
    decay: Option<f64>,
    /// When `true`, use per-weight AdaGrad learning rates; otherwise plain SGD.
    use_adagrad: bool,
    /// Per-weight squared gradient accumulator (AdaGrad).
    sq_grad_accum: Vec<f64>,
    /// Bias squared gradient accumulator (AdaGrad).
    sq_bias_accum: f64,
    /// Whether `weights`/`sq_grad_accum` have been sized from the first sample.
    initialized: bool,
}
148
149impl LinearLeafModel {
150    /// Create a new linear leaf model with the given base learning rate,
151    /// optional exponential decay factor, and AdaGrad toggle.
152    ///
153    /// When `decay` is `Some(d)` with `d` in (0, 1), weights are multiplied
154    /// by `d` before each update, giving the model a memory half-life of
155    /// `ln(2) / ln(1/d)` samples.
156    ///
157    /// When `use_adagrad` is `true`, per-weight squared gradient accumulators
158    /// give each feature its own adaptive learning rate. When `false`, all
159    /// weights share a single Newton-scaled learning rate (plain SGD).
160    pub fn new(learning_rate: f64, decay: Option<f64>, use_adagrad: bool) -> Self {
161        Self {
162            weights: Vec::new(),
163            bias: 0.0,
164            learning_rate,
165            decay,
166            use_adagrad,
167            sq_grad_accum: Vec::new(),
168            sq_bias_accum: 0.0,
169            initialized: false,
170        }
171    }
172}
173
/// AdaGrad epsilon to prevent division by zero when a squared-gradient
/// accumulator is still (near) zero.
const ADAGRAD_EPS: f64 = 1e-8;
176
177impl LeafModel for LinearLeafModel {
178    fn predict(&self, features: &[f64]) -> f64 {
179        if !self.initialized {
180            return 0.0;
181        }
182        let mut dot = self.bias;
183        for (w, x) in self.weights.iter().zip(features.iter()) {
184            dot += w * x;
185        }
186        dot
187    }
188
189    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
190        if !self.initialized {
191            let d = features.len();
192            self.weights = vec![0.0; d];
193            self.sq_grad_accum = vec![0.0; d];
194            self.initialized = true;
195        }
196
197        // Exponential weight decay (forgetting old data).
198        if let Some(d) = self.decay {
199            for w in self.weights.iter_mut() {
200                *w *= d;
201            }
202            self.bias *= d;
203        }
204
205        // Newton-scaled base learning rate.
206        let base_lr = self.learning_rate / (math::abs(hessian) + lambda);
207
208        if self.use_adagrad {
209            // AdaGrad: per-weight adaptive learning rates.
210            for (i, (w, x)) in self.weights.iter_mut().zip(features.iter()).enumerate() {
211                let g = gradient * x;
212                self.sq_grad_accum[i] += g * g;
213                let adaptive_lr = base_lr / (math::sqrt(self.sq_grad_accum[i]) + ADAGRAD_EPS);
214                *w -= adaptive_lr * g;
215            }
216            self.sq_bias_accum += gradient * gradient;
217            let bias_lr = base_lr / (math::sqrt(self.sq_bias_accum) + ADAGRAD_EPS);
218            self.bias -= bias_lr * gradient;
219        } else {
220            // Plain Newton-scaled SGD.
221            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
222                *w -= base_lr * gradient * x;
223            }
224            self.bias -= base_lr * gradient;
225        }
226    }
227
228    fn clone_fresh(&self) -> Box<dyn LeafModel> {
229        Box::new(LinearLeafModel::new(
230            self.learning_rate,
231            self.decay,
232            self.use_adagrad,
233        ))
234    }
235
236    fn clone_warm(&self) -> Box<dyn LeafModel> {
237        Box::new(LinearLeafModel {
238            weights: self.weights.clone(),
239            bias: self.bias,
240            learning_rate: self.learning_rate,
241            decay: self.decay,
242            use_adagrad: self.use_adagrad,
243            // Reset AdaGrad accumulators -- the child's gradient landscape
244            // differs from the parent's, so accumulated curvature estimates
245            // don't transfer. Fresh accumulators let the child's learning
246            // rates adapt to its own region.
247            sq_grad_accum: vec![0.0; self.weights.len()],
248            sq_bias_accum: 0.0,
249            initialized: self.initialized,
250        })
251    }
252}
253
254// ---------------------------------------------------------------------------
255// MLPLeafModel
256// ---------------------------------------------------------------------------
257
/// Single hidden layer MLP leaf model with ReLU activation.
///
/// Learns a nonlinear function via backpropagation with Newton-scaled
/// learning rate. Weights are lazily initialized on the first `update` call
/// using a deterministic xorshift64 PRNG so results are reproducible.
///
/// Optional exponential weight decay (`decay`) gives the model a finite memory
/// horizon for non-stationary streams.
pub struct MLPLeafModel {
    /// Hidden-layer weight matrix, one row per hidden unit: [hidden_size][input_size]
    hidden_weights: Vec<Vec<f64>>, // [hidden_size][input_size]
    /// Hidden-layer biases, one per hidden unit.
    hidden_bias: Vec<f64>,
    /// Output-layer weights, one per hidden unit.
    output_weights: Vec<f64>,
    /// Output-layer bias (the network has a single scalar output).
    output_bias: f64,
    /// Number of hidden units.
    hidden_size: usize,
    /// Base learning rate, before Newton scaling by the hessian.
    learning_rate: f64,
    /// Optional exponential forgetting factor in (0, 1); `None` disables decay.
    decay: Option<f64>,
    /// Seed for deterministic xorshift64 weight initialization.
    seed: u64,
    /// Whether weights have been sized/initialized from the first sample.
    initialized: bool,
    /// Scratch: hidden activations cached by `forward` for backprop.
    hidden_activations: Vec<f64>,
    /// Scratch: hidden pre-activations cached by `forward` for the ReLU derivative.
    hidden_pre_activations: Vec<f64>,
}
279
280impl MLPLeafModel {
281    /// Create a new MLP leaf model with the given hidden layer size, learning rate,
282    /// seed, and optional decay.
283    ///
284    /// The seed controls deterministic weight initialization. Different seeds
285    /// produce different initial weights, which is critical for ensemble diversity
286    /// when multiple MLP leaves share the same `hidden_size`.
287    pub fn new(hidden_size: usize, learning_rate: f64, seed: u64, decay: Option<f64>) -> Self {
288        Self {
289            hidden_weights: Vec::new(),
290            hidden_bias: Vec::new(),
291            output_weights: Vec::new(),
292            output_bias: 0.0,
293            hidden_size,
294            learning_rate,
295            decay,
296            seed,
297            initialized: false,
298            hidden_activations: Vec::new(),
299            hidden_pre_activations: Vec::new(),
300        }
301    }
302
303    /// Initialize weights using xorshift64, scaled to [-0.1, 0.1].
304    fn initialize(&mut self, input_size: usize) {
305        let mut state = self.seed ^ (self.hidden_size as u64);
306
307        self.hidden_weights = Vec::with_capacity(self.hidden_size);
308        for _ in 0..self.hidden_size {
309            let mut row = Vec::with_capacity(input_size);
310            for _ in 0..input_size {
311                let r = xorshift64(&mut state);
312                // Map u64 to [-0.1, 0.1]
313                let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
314                row.push(val);
315            }
316            self.hidden_weights.push(row);
317        }
318
319        self.hidden_bias = Vec::with_capacity(self.hidden_size);
320        for _ in 0..self.hidden_size {
321            let r = xorshift64(&mut state);
322            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
323            self.hidden_bias.push(val);
324        }
325
326        self.output_weights = Vec::with_capacity(self.hidden_size);
327        for _ in 0..self.hidden_size {
328            let r = xorshift64(&mut state);
329            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
330            self.output_weights.push(val);
331        }
332
333        {
334            let r = xorshift64(&mut state);
335            self.output_bias = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
336        }
337
338        self.hidden_activations = vec![0.0; self.hidden_size];
339        self.hidden_pre_activations = vec![0.0; self.hidden_size];
340        self.initialized = true;
341    }
342
343    /// Forward pass: compute hidden pre-activations, ReLU activations, and output.
344    fn forward(&mut self, features: &[f64]) -> f64 {
345        // Hidden layer
346        for h in 0..self.hidden_size {
347            let mut z = self.hidden_bias[h];
348            for (j, x) in features.iter().enumerate() {
349                if j < self.hidden_weights[h].len() {
350                    z += self.hidden_weights[h][j] * x;
351                }
352            }
353            self.hidden_pre_activations[h] = z;
354            // ReLU
355            self.hidden_activations[h] = if z > 0.0 { z } else { 0.0 };
356        }
357
358        // Output layer
359        let mut out = self.output_bias;
360        for (w, a) in self
361            .output_weights
362            .iter()
363            .zip(self.hidden_activations.iter())
364        {
365            out += w * a;
366        }
367        out
368    }
369}
370
371impl LeafModel for MLPLeafModel {
372    fn predict(&self, features: &[f64]) -> f64 {
373        if !self.initialized {
374            return 0.0;
375        }
376        // Non-mutating forward pass (can't store activations, recompute locally)
377        let hidden_acts: Vec<f64> = self
378            .hidden_weights
379            .iter()
380            .zip(self.hidden_bias.iter())
381            .map(|(hw, &hb)| {
382                let mut z = hb;
383                for (j, x) in features.iter().enumerate() {
384                    if j < hw.len() {
385                        z += hw[j] * x;
386                    }
387                }
388                if z > 0.0 {
389                    z
390                } else {
391                    0.0
392                }
393            })
394            .collect();
395        let mut out = self.output_bias;
396        for (w, a) in self.output_weights.iter().zip(hidden_acts.iter()) {
397            out += w * a;
398        }
399        out
400    }
401
402    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
403        if !self.initialized {
404            self.initialize(features.len());
405        }
406
407        // Exponential weight decay (forgetting old data).
408        if let Some(d) = self.decay {
409            for row in self.hidden_weights.iter_mut() {
410                for w in row.iter_mut() {
411                    *w *= d;
412                }
413            }
414            for b in self.hidden_bias.iter_mut() {
415                *b *= d;
416            }
417            for w in self.output_weights.iter_mut() {
418                *w *= d;
419            }
420            self.output_bias *= d;
421        }
422
423        // Forward pass (stores activations for backprop)
424        let _output = self.forward(features);
425
426        let effective_lr = self.learning_rate / (math::abs(hessian) + lambda);
427
428        // Backprop: output gradient is the incoming `gradient` (chain rule from loss)
429        let d_output = gradient;
430
431        // Gradient for output weights and bias
432        // d_loss/d_output_w[h] = d_output * hidden_activations[h]
433        // d_loss/d_output_bias = d_output
434        for h in 0..self.hidden_size {
435            self.output_weights[h] -= effective_lr * d_output * self.hidden_activations[h];
436        }
437        self.output_bias -= effective_lr * d_output;
438
439        // Gradient for hidden layer
440        for h in 0..self.hidden_size {
441            // d_loss/d_hidden_act[h] = d_output * output_weights[h]
442            let d_hidden_act = d_output * self.output_weights[h];
443
444            // ReLU derivative
445            let d_relu = if self.hidden_pre_activations[h] > 0.0 {
446                d_hidden_act
447            } else {
448                0.0
449            };
450
451            // Update hidden weights and bias
452            for (j, x) in features.iter().enumerate() {
453                if j < self.hidden_weights[h].len() {
454                    self.hidden_weights[h][j] -= effective_lr * d_relu * x;
455                }
456            }
457            self.hidden_bias[h] -= effective_lr * d_relu;
458        }
459    }
460
461    fn clone_fresh(&self) -> Box<dyn LeafModel> {
462        // Derive a new seed so each fresh clone gets distinct initial weights.
463        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
464        Box::new(MLPLeafModel::new(
465            self.hidden_size,
466            self.learning_rate,
467            derived_seed,
468            self.decay,
469        ))
470    }
471
472    fn clone_warm(&self) -> Box<dyn LeafModel> {
473        Box::new(MLPLeafModel {
474            hidden_weights: self.hidden_weights.clone(),
475            hidden_bias: self.hidden_bias.clone(),
476            output_weights: self.output_weights.clone(),
477            output_bias: self.output_bias,
478            hidden_size: self.hidden_size,
479            learning_rate: self.learning_rate,
480            decay: self.decay,
481            // Derive a new seed so warm clones diverge if they re-initialize.
482            seed: self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(2),
483            initialized: self.initialized,
484            hidden_activations: vec![0.0; self.hidden_size],
485            hidden_pre_activations: vec![0.0; self.hidden_size],
486        })
487    }
488}
489
490// ---------------------------------------------------------------------------
491// AdaptiveLeafModel
492// ---------------------------------------------------------------------------
493
/// Leaf model that starts as closed-form and promotes to a more complex model
/// when the Hoeffding bound confirms it is statistically superior.
///
/// Runs a shadow model alongside the active (closed-form) model. On each
/// update, both models are trained and their per-sample losses compared using
/// the second-order Taylor approximation:
///
/// ```text
/// loss_i = gradient * prediction + 0.5 * hessian * prediction^2
/// advantage_i = loss_active_i - loss_shadow_i
/// ```
///
/// When `mean(advantage) > epsilon` where epsilon is the Hoeffding bound
/// (using the tree's `delta` parameter), the shadow model is promoted to
/// active and the overhead drops to zero.
///
/// This uses the **same statistical guarantee** as the tree's split decisions --
/// no arbitrary thresholds.
pub struct AdaptiveLeafModel {
    /// The currently active prediction model. Starts as ClosedForm; after
    /// promotion it holds the former shadow model.
    active: Box<dyn LeafModel>,
    /// Shadow model being evaluated against the active model. After
    /// promotion it holds the retired closed-form model (no longer trained).
    shadow: Box<dyn LeafModel>,
    /// Configuration of the shadow model type (for cloning).
    promote_to: LeafModelType,
    /// Cumulative loss advantage: sum(loss_active - loss_shadow).
    /// Positive means the shadow is winning.
    cumulative_advantage: f64,
    /// Number of samples seen for the Hoeffding bound.
    n: u64,
    /// Running maximum |loss_diff| for range estimation (R in the bound).
    max_loss_diff: f64,
    /// Hoeffding confidence parameter (from tree config).
    delta: f64,
    /// Whether the shadow has been promoted.
    promoted: bool,
    /// Seed for reproducible cloning.
    seed: u64,
}
533
534impl AdaptiveLeafModel {
535    /// Create a new adaptive leaf model.
536    ///
537    /// The active model starts as `ClosedFormLeaf`. The shadow model is the
538    /// candidate that will be promoted if it proves statistically superior.
539    pub fn new(
540        shadow: Box<dyn LeafModel>,
541        promote_to: LeafModelType,
542        delta: f64,
543        seed: u64,
544    ) -> Self {
545        Self {
546            active: Box::new(ClosedFormLeaf::new()),
547            shadow,
548            promote_to,
549            cumulative_advantage: 0.0,
550            n: 0,
551            max_loss_diff: 0.0,
552            delta,
553            promoted: false,
554            seed,
555        }
556    }
557}
558
559impl LeafModel for AdaptiveLeafModel {
560    fn predict(&self, features: &[f64]) -> f64 {
561        self.active.predict(features)
562    }
563
564    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
565        if self.promoted {
566            // Post-promotion: only update the promoted model.
567            self.active.update(features, gradient, hessian, lambda);
568            return;
569        }
570
571        // Compute predictions BEFORE updating (evaluate current state).
572        let pred_active = self.active.predict(features);
573        let pred_shadow = self.shadow.predict(features);
574
575        // Second-order Taylor loss approximation for each model.
576        // L(pred) ~= gradient * pred + 0.5 * hessian * pred^2
577        // This is the same loss proxy that XGBoost gain uses.
578        let loss_active = gradient * pred_active + 0.5 * hessian * pred_active * pred_active;
579        let loss_shadow = gradient * pred_shadow + 0.5 * hessian * pred_shadow * pred_shadow;
580
581        // Positive advantage means the shadow model is better (lower loss).
582        let diff = loss_active - loss_shadow;
583        self.cumulative_advantage += diff;
584        self.n += 1;
585
586        // Track range for the Hoeffding bound.
587        let abs_diff = math::abs(diff);
588        if abs_diff > self.max_loss_diff {
589            self.max_loss_diff = abs_diff;
590        }
591
592        // Update both models.
593        self.active.update(features, gradient, hessian, lambda);
594        self.shadow.update(features, gradient, hessian, lambda);
595
596        // Hoeffding bound test: is the shadow statistically better?
597        // epsilon = sqrt(R^2 * ln(1/delta) / (2*n))
598        // Promote when mean_advantage > epsilon.
599        if self.n >= 10 && self.max_loss_diff > 0.0 {
600            let mean_advantage = self.cumulative_advantage / self.n as f64;
601            if mean_advantage > 0.0 {
602                let r_squared = self.max_loss_diff * self.max_loss_diff;
603                let ln_inv_delta = math::ln(1.0 / self.delta);
604                let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * self.n as f64));
605
606                if mean_advantage > epsilon {
607                    // Promote: swap shadow into active, drop the old active.
608                    self.promoted = true;
609                    core::mem::swap(&mut self.active, &mut self.shadow);
610                }
611            }
612        }
613    }
614
615    fn clone_fresh(&self) -> Box<dyn LeafModel> {
616        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
617        Box::new(AdaptiveLeafModel::new(
618            self.promote_to.create(derived_seed, self.delta),
619            self.promote_to.clone(),
620            self.delta,
621            derived_seed,
622        ))
623    }
624    // clone_warm: default (= clone_fresh). Promotion must be re-earned
625    // in each leaf's region -- the parent's promotion does not transfer.
626}
627
// `AdaptiveLeafModel` is `Send + Sync` automatically: `Box<dyn LeafModel>` is
// `Send + Sync` because the `LeafModel` trait itself carries those bounds, and
// every other field is a primitive. The manual `unsafe impl Send/Sync` blocks
// that used to live here were redundant -- and unlike compiler-derived auto
// traits, they would have silently remained (and become unsound) if a
// non-thread-safe field were ever added to the struct. They have been removed.
634
635// ---------------------------------------------------------------------------
636// LeafModelType
637// ---------------------------------------------------------------------------
638
/// Describes which leaf model architecture to use.
///
/// Used by tree builders to construct fresh leaf models when creating new leaves.
///
/// # Variants
///
/// - **`ClosedForm`** -- Standard constant leaf weight. Zero overhead (default).
/// - **`Linear`** -- Per-leaf online ridge regression with AdaGrad optimization.
///   Each leaf learns a local linear surface `w . x + b`. Recommended for
///   low-depth trees (depth 2--4). Optional `decay` for concept drift.
/// - **`MLP`** -- Per-leaf single-hidden-layer neural network with ReLU.
///   Optional `decay` for concept drift.
/// - **`Adaptive`** -- Starts as closed-form, auto-promotes to `promote_to`
///   when the Hoeffding bound confirms it is statistically superior. Uses the
///   tree's existing `delta` parameter -- no arbitrary thresholds.
#[derive(Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum LeafModelType {
    /// Standard closed-form leaf weight.
    #[default]
    ClosedForm,

    /// Online ridge regression, optionally with AdaGrad optimization.
    ///
    /// `decay`: optional exponential weight decay for non-stationary streams.
    /// Typical values: 0.999 (slow drift) to 0.99 (fast drift).
    ///
    /// `use_adagrad`: when `true`, per-weight squared gradient accumulators
    /// give each feature its own adaptive learning rate. When `false`
    /// (default), all weights share a single Newton-scaled learning rate.
    Linear {
        /// Base learning rate, before Newton scaling by the hessian.
        learning_rate: f64,
        /// Optional exponential forgetting factor in (0, 1).
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
        /// Enable per-weight AdaGrad adaptive learning rates.
        #[cfg_attr(feature = "serde", serde(default))]
        use_adagrad: bool,
    },

    /// Single hidden layer MLP with the given hidden size and learning rate.
    ///
    /// `decay`: optional exponential weight decay for non-stationary streams.
    MLP {
        /// Number of hidden ReLU units.
        hidden_size: usize,
        /// Base learning rate, before Newton scaling by the hessian.
        learning_rate: f64,
        /// Optional exponential forgetting factor in (0, 1).
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
    },

    /// Adaptive leaf that starts as closed-form and auto-promotes when
    /// the Hoeffding bound confirms the promoted model is better.
    ///
    /// The `promote_to` field specifies the shadow model type to evaluate
    /// against the default closed-form baseline.
    Adaptive { promote_to: Box<LeafModelType> },
}
694
695impl LeafModelType {
696    /// Create a fresh boxed leaf model of this type.
697    ///
698    /// The `seed` parameter controls deterministic initialization (MLP weights,
699    /// adaptive model seeding). The `delta` parameter is the Hoeffding bound
700    /// confidence level, used by [`Adaptive`](LeafModelType::Adaptive) leaves
701    /// for promotion decisions. For non-adaptive types, `delta` is unused.
702    pub fn create(&self, seed: u64, delta: f64) -> Box<dyn LeafModel> {
703        match self {
704            Self::ClosedForm => Box::new(ClosedFormLeaf::new()),
705            Self::Linear {
706                learning_rate,
707                decay,
708                use_adagrad,
709            } => Box::new(LinearLeafModel::new(*learning_rate, *decay, *use_adagrad)),
710            Self::MLP {
711                hidden_size,
712                learning_rate,
713                decay,
714            } => Box::new(MLPLeafModel::new(
715                *hidden_size,
716                *learning_rate,
717                seed,
718                *decay,
719            )),
720            Self::Adaptive { promote_to } => Box::new(AdaptiveLeafModel::new(
721                promote_to.create(seed, delta),
722                *promote_to.clone(),
723                delta,
724                seed,
725            )),
726        }
727    }
728}
729
730// ---------------------------------------------------------------------------
731// Shared utility
732// ---------------------------------------------------------------------------
733
/// Xorshift64 PRNG for deterministic weight initialization.
///
/// Classic 13/7/17 shift-xor generator. A zero state is a fixed point, so
/// callers should seed with a non-zero value to get a useful stream.
fn xorshift64(state: &mut u64) -> u64 {
    let x0 = *state;
    let x1 = x0 ^ (x0 << 13);
    let x2 = x1 ^ (x1 >> 7);
    let x3 = x2 ^ (x2 << 17);
    *state = x3;
    x3
}
743
744// ---------------------------------------------------------------------------
745// Tests
746// ---------------------------------------------------------------------------
747
748#[cfg(test)]
749mod tests {
750    use super::*;
751
752    /// Xorshift64 for deterministic test data generation.
753    fn xorshift64(state: &mut u64) -> u64 {
754        let mut s = *state;
755        s ^= s << 13;
756        s ^= s >> 7;
757        s ^= s << 17;
758        *state = s;
759        s
760    }
761
762    /// Convert xorshift output to f64 in [0, 1).
763    fn rand_f64(state: &mut u64) -> f64 {
764        xorshift64(state) as f64 / u64::MAX as f64
765    }
766
767    #[test]
768    fn closed_form_matches_formula() {
769        let mut leaf = ClosedFormLeaf::new();
770        let lambda = 1.0;
771
772        // Accumulate several gradient/hessian pairs
773        let updates = [(0.5, 1.0), (-0.3, 0.8), (1.2, 2.0), (-0.1, 0.5)];
774        let mut grad_sum = 0.0;
775        let mut hess_sum = 0.0;
776
777        for &(g, h) in &updates {
778            leaf.update(&[], g, h, lambda);
779            grad_sum += g;
780            hess_sum += h;
781        }
782
783        let expected = -grad_sum / (hess_sum + lambda);
784        let predicted = leaf.predict(&[]);
785
786        assert!(
787            (predicted - expected).abs() < 1e-12,
788            "closed form mismatch: got {predicted}, expected {expected}"
789        );
790    }
791
792    #[test]
793    fn closed_form_clone_fresh_resets() {
794        let mut leaf = ClosedFormLeaf::new();
795        leaf.update(&[], 5.0, 2.0, 1.0);
796        assert!(
797            leaf.predict(&[]).abs() > 0.0,
798            "leaf should have non-zero weight after update"
799        );
800
801        let fresh = leaf.clone_fresh();
802        assert!(
803            fresh.predict(&[]).abs() < 1e-15,
804            "fresh clone should predict 0, got {}",
805            fresh.predict(&[])
806        );
807    }
808
809    #[test]
810    fn linear_converges_on_linear_target() {
811        // Target: y = 2*x1 + 3*x2
812        let mut model = LinearLeafModel::new(0.01, None, false);
813        let lambda = 0.1;
814        let mut rng = 42u64;
815
816        for _ in 0..2000 {
817            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
818            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
819            let features = vec![x1, x2];
820            let target = 2.0 * x1 + 3.0 * x2;
821
822            let pred = model.predict(&features);
823            let gradient = 2.0 * (pred - target);
824            let hessian = 2.0;
825            model.update(&features, gradient, hessian, lambda);
826        }
827
828        let test_features = vec![0.5, -0.3];
829        let target = 2.0 * 0.5 + 3.0 * (-0.3);
830        let pred = model.predict(&test_features);
831
832        assert!(
833            (pred - target).abs() < 1.0,
834            "linear model should converge within 1.0 of target: pred={pred}, target={target}"
835        );
836    }
837
838    #[test]
839    fn linear_uninitialized_predicts_zero() {
840        let model = LinearLeafModel::new(0.01, None, false);
841        let pred = model.predict(&[1.0, 2.0, 3.0]);
842        assert!(
843            pred.abs() < 1e-15,
844            "uninitialized linear model should predict 0, got {pred}"
845        );
846    }
847
848    #[test]
849    fn linear_clone_warm_preserves_weights() {
850        let mut model = LinearLeafModel::new(0.01, None, false);
851        let features = vec![1.0, 2.0];
852
853        // Train it
854        for i in 0..100 {
855            let target = 3.0 * features[0] + 2.0 * features[1];
856            let pred = model.predict(&features);
857            let gradient = 2.0 * (pred - target);
858            model.update(&features, gradient, 2.0, 0.1);
859            // Avoid unused variable warning
860            let _ = i;
861        }
862
863        let trained_pred = model.predict(&features);
864        assert!(
865            trained_pred.abs() > 0.01,
866            "model should have learned something"
867        );
868
869        // Warm clone should predict similarly
870        let warm = model.clone_warm();
871        let warm_pred = warm.predict(&features);
872        assert!(
873            (warm_pred - trained_pred).abs() < 1e-12,
874            "warm clone should preserve weights: trained={trained_pred}, warm={warm_pred}"
875        );
876
877        // Fresh clone should predict 0
878        let fresh = model.clone_fresh();
879        let fresh_pred = fresh.predict(&features);
880        assert!(
881            fresh_pred.abs() < 1e-15,
882            "fresh clone should predict 0, got {fresh_pred}"
883        );
884    }
885
886    #[test]
887    fn linear_decay_forgets_old_data() {
888        // Decay should pull weights toward zero when training stops,
889        // demonstrating that old learned state is forgotten.
890        let mut model_decay = LinearLeafModel::new(0.05, Some(0.99), false);
891        let mut model_no_decay = LinearLeafModel::new(0.05, None, false);
892        let features = vec![1.0];
893        let lambda = 0.1;
894
895        // Train both on target = 5.0
896        for _ in 0..500 {
897            let pred_d = model_decay.predict(&features);
898            let pred_n = model_no_decay.predict(&features);
899            model_decay.update(&features, 2.0 * (pred_d - 5.0), 2.0, lambda);
900            model_no_decay.update(&features, 2.0 * (pred_n - 5.0), 2.0, lambda);
901        }
902
903        // Both should have learned roughly the same function.
904        let pred_d_trained = model_decay.predict(&features);
905        let pred_n_trained = model_no_decay.predict(&features);
906        assert!(
907            (pred_d_trained - 5.0).abs() < 2.0,
908            "decay model should approximate target"
909        );
910        assert!(
911            (pred_n_trained - 5.0).abs() < 2.0,
912            "no-decay model should approximate target"
913        );
914
915        // Now apply 200 rounds of zero-gradient updates (data stopped).
916        // The decay model's weights should shrink toward zero,
917        // while the no-decay model retains its learned state.
918        for _ in 0..200 {
919            model_decay.update(&features, 0.0, 1.0, lambda);
920            model_no_decay.update(&features, 0.0, 1.0, lambda);
921        }
922
923        let pred_d_after = model_decay.predict(&features);
924        let pred_n_after = model_no_decay.predict(&features);
925
926        // Decay model should have drifted toward zero.
927        assert!(
928            pred_d_after.abs() < pred_n_after.abs(),
929            "decay model should forget: decay pred={pred_d_after:.3}, no-decay pred={pred_n_after:.3}"
930        );
931    }
932
933    #[test]
934    fn mlp_produces_finite_predictions() {
935        let model_uninit = MLPLeafModel::new(4, 0.01, 42, None);
936        let features = vec![1.0, 2.0, 3.0];
937
938        let pred_before = model_uninit.predict(&features);
939        assert!(
940            pred_before.is_finite(),
941            "uninit prediction should be finite"
942        );
943        assert!(
944            pred_before.abs() < 1e-15,
945            "uninit prediction should be 0, got {pred_before}"
946        );
947
948        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
949        for _ in 0..10 {
950            model.update(&features, 0.5, 1.0, 0.1);
951        }
952        let pred_after = model.predict(&features);
953        assert!(
954            pred_after.is_finite(),
955            "prediction after training should be finite, got {pred_after}"
956        );
957    }
958
959    #[test]
960    fn mlp_loss_decreases() {
961        let mut model = MLPLeafModel::new(8, 0.05, 123, None);
962        let features = vec![1.0, -0.5, 0.3];
963        let target = 2.5;
964        let lambda = 0.1;
965
966        model.update(&features, 0.0, 1.0, lambda); // dummy update to initialize
967        let initial_pred = model.predict(&features);
968        let initial_error = (initial_pred - target).abs();
969
970        for _ in 0..200 {
971            let pred = model.predict(&features);
972            let gradient = 2.0 * (pred - target);
973            let hessian = 2.0;
974            model.update(&features, gradient, hessian, lambda);
975        }
976
977        let final_pred = model.predict(&features);
978        let final_error = (final_pred - target).abs();
979
980        assert!(
981            final_error < initial_error,
982            "MLP error should decrease: initial={initial_error}, final={final_error}"
983        );
984    }
985
986    #[test]
987    fn mlp_clone_fresh_resets() {
988        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
989        let features = vec![1.0, 2.0];
990
991        for _ in 0..20 {
992            model.update(&features, 0.5, 1.0, 0.1);
993        }
994
995        let trained_pred = model.predict(&features);
996        assert!(
997            trained_pred.abs() > 1e-10,
998            "trained model should have non-zero prediction"
999        );
1000
1001        let fresh = model.clone_fresh();
1002        let fresh_pred = fresh.predict(&features);
1003        assert!(
1004            fresh_pred.abs() < 1e-15,
1005            "fresh clone should predict 0, got {fresh_pred}"
1006        );
1007    }
1008
1009    #[test]
1010    fn mlp_clone_warm_preserves_weights() {
1011        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
1012        let features = vec![1.0, 2.0];
1013
1014        for _ in 0..50 {
1015            model.update(&features, 0.5, 1.0, 0.1);
1016        }
1017
1018        let trained_pred = model.predict(&features);
1019        let warm = model.clone_warm();
1020        let warm_pred = warm.predict(&features);
1021
1022        assert!(
1023            (warm_pred - trained_pred).abs() < 1e-10,
1024            "warm clone should preserve predictions: trained={trained_pred}, warm={warm_pred}"
1025        );
1026    }
1027
1028    #[test]
1029    fn leaf_model_type_default_is_closed_form() {
1030        let default_type = LeafModelType::default();
1031        assert!(
1032            matches!(default_type, LeafModelType::ClosedForm),
1033            "default LeafModelType should be ClosedForm, got {default_type:?}"
1034        );
1035    }
1036
1037    #[test]
1038    fn leaf_model_type_create_all_variants() {
1039        let features = vec![1.0, 2.0, 3.0];
1040        let delta = 1e-7;
1041
1042        // ClosedForm
1043        let mut closed = LeafModelType::ClosedForm.create(0, delta);
1044        closed.update(&features, 1.0, 1.0, 0.1);
1045        let p = closed.predict(&features);
1046        assert!(p.is_finite(), "ClosedForm prediction should be finite");
1047
1048        // Linear
1049        let mut linear = LeafModelType::Linear {
1050            learning_rate: 0.01,
1051            decay: None,
1052            use_adagrad: false,
1053        }
1054        .create(0, delta);
1055        linear.update(&features, 1.0, 1.0, 0.1);
1056        let p = linear.predict(&features);
1057        assert!(p.is_finite(), "Linear prediction should be finite");
1058
1059        // MLP
1060        let mut mlp = LeafModelType::MLP {
1061            hidden_size: 4,
1062            learning_rate: 0.01,
1063            decay: None,
1064        }
1065        .create(99, delta);
1066        mlp.update(&features, 1.0, 1.0, 0.1);
1067        let p = mlp.predict(&features);
1068        assert!(p.is_finite(), "MLP prediction should be finite");
1069
1070        // Adaptive (promoting to Linear)
1071        let mut adaptive = LeafModelType::Adaptive {
1072            promote_to: Box::new(LeafModelType::Linear {
1073                learning_rate: 0.01,
1074                decay: None,
1075                use_adagrad: false,
1076            }),
1077        }
1078        .create(42, delta);
1079        adaptive.update(&features, 1.0, 1.0, 0.1);
1080        let p = adaptive.predict(&features);
1081        assert!(p.is_finite(), "Adaptive prediction should be finite");
1082    }
1083
1084    #[test]
1085    fn adaptive_promotes_on_linear_target() {
1086        // The shadow (Linear) should eventually outperform ClosedForm
1087        // on a target that varies with features.
1088        let promote_to = LeafModelType::Linear {
1089            learning_rate: 0.01,
1090            decay: None,
1091            use_adagrad: false,
1092        };
1093        let shadow = promote_to.create(42, 1e-7);
1094        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
1095
1096        let mut rng = 42u64;
1097        for _ in 0..5000 {
1098            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1099            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1100            let features = vec![x1, x2];
1101            let target = 3.0 * x1 + 2.0 * x2;
1102
1103            let pred = model.predict(&features);
1104            let gradient = 2.0 * (pred - target);
1105            let hessian = 2.0;
1106            model.update(&features, gradient, hessian, 0.1);
1107        }
1108
1109        // After enough samples on a linear target, the Linear shadow
1110        // should have been promoted.
1111        assert!(
1112            model.promoted,
1113            "adaptive model should have promoted on linear target after 5000 samples"
1114        );
1115    }
1116
1117    #[test]
1118    fn adaptive_does_not_promote_on_constant_target() {
1119        // On a constant target, ClosedForm is optimal -- no promotion expected.
1120        let promote_to = LeafModelType::Linear {
1121            learning_rate: 0.01,
1122            decay: None,
1123            use_adagrad: false,
1124        };
1125        let shadow = promote_to.create(42, 1e-7);
1126        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-7, 42);
1127
1128        for _ in 0..2000 {
1129            let features = vec![1.0, 2.0];
1130            let target = 5.0; // constant -- no feature dependence
1131            let pred = model.predict(&features);
1132            let gradient = 2.0 * (pred - target);
1133            let hessian = 2.0;
1134            model.update(&features, gradient, hessian, 0.1);
1135        }
1136
1137        // With strict delta (1e-7) and a constant target, the Linear model
1138        // shouldn't gain a statistically significant advantage.
1139        // (It may or may not promote -- the point is it shouldn't be obvious.)
1140        // This test mainly verifies the mechanism doesn't crash and is conservative.
1141        let pred = model.predict(&[1.0, 2.0]);
1142        assert!(pred.is_finite(), "prediction should be finite");
1143    }
1144
1145    #[test]
1146    fn adaptive_clone_fresh_resets_promotion() {
1147        let promote_to = LeafModelType::Linear {
1148            learning_rate: 0.01,
1149            decay: None,
1150            use_adagrad: false,
1151        };
1152        let shadow = promote_to.create(42, 1e-3);
1153        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
1154
1155        // Force promotion via many linear-target samples
1156        let mut rng = 42u64;
1157        for _ in 0..5000 {
1158            let x = rand_f64(&mut rng) * 2.0 - 1.0;
1159            let features = vec![x];
1160            let pred = model.predict(&features);
1161            model.update(&features, 2.0 * (pred - 3.0 * x), 2.0, 0.1);
1162        }
1163
1164        let fresh = model.clone_fresh();
1165        // Fresh clone should predict 0 (reset state).
1166        let p = fresh.predict(&[0.5]);
1167        assert!(
1168            p.abs() < 1e-10,
1169            "fresh adaptive clone should predict ~0, got {p}"
1170        );
1171    }
1172}