Skip to main content

irithyll_core/tree/
leaf_model.rs

//! Pluggable leaf prediction models for streaming decision trees.
//!
//! By default, leaves use a closed-form weight computed from accumulated
//! gradient and hessian sums. Trainable leaf models replace this with learnable
//! functions that capture more complex patterns within each leaf's partition.
//!
//! # Leaf model variants
//!
//! | Model | Prediction | Overhead | Best for |
//! |-------|-----------|----------|----------|
//! | [`ClosedFormLeaf`] | constant weight `-G/(H+lambda)` | zero | general use (default) |
//! | [`LinearLeafModel`] | `w . x + b` (AdaGrad-optimized) | O(d) per update | low-depth trees (2--4) |
//! | [`MLPLeafModel`] | single-hidden-layer neural net | O(d*h) per update | complex local patterns |
//! | [`AdaptiveLeafModel`] | starts constant, auto-promotes | shadow model cost | automatic complexity allocation |
//!
//! # AdaGrad optimization
//!
//! [`LinearLeafModel`] uses per-weight AdaGrad accumulators for adaptive learning
//! rates. Features at different scales converge at their natural rates without
//! manual tuning. Combined with Newton scaling from the hessian, this gives
//! second-order-informed, per-feature adaptive optimization.
//!
//! # Exponential forgetting
//!
//! [`LinearLeafModel`] and [`MLPLeafModel`] support an optional `decay` parameter
//! that applies exponential weight decay before each update. This gives the model
//! a finite memory horizon, adapting to concept drift in non-stationary streams.
//! Typical values: 0.999 (slow drift) to 0.99 (fast drift).
//!
//! # Warm-starting on split
//!
//! When a leaf splits, child leaves can inherit the parent's learned function via
//! [`LeafModel::clone_warm`]. Linear children start with the parent's weights
//! (resetting optimizer state), converging faster than starting from scratch.
//!
//! # Adaptive promotion
//!
//! [`AdaptiveLeafModel`] runs a shadow model alongside the default closed-form
//! model. Both are trained on every sample, and their per-sample losses are
//! compared using the second-order Taylor approximation. When the Hoeffding bound
//! (the tree's existing `delta` parameter) confirms the shadow model is
//! statistically superior, the leaf promotes -- no arbitrary thresholds.

44use alloc::boxed::Box;
45use alloc::vec;
46use alloc::vec::Vec;
47
48use crate::math;
49
/// A trainable prediction model that lives inside a decision tree leaf.
///
/// Implementations must be `Send + Sync` so trees can be shared across threads.
pub trait LeafModel: Send + Sync {
    /// Produce a prediction given input features.
    fn predict(&self, features: &[f64]) -> f64;

    /// Update model parameters given a gradient, hessian, and regularization lambda.
    ///
    /// `gradient` and `hessian` are the first and second derivatives of the
    /// loss at the current prediction (the same second-order quantities used
    /// by the tree's gain computation); `lambda` is the L2 regularization
    /// strength.
    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64);

    /// Create a fresh (zeroed / re-initialized) clone of this model's architecture.
    fn clone_fresh(&self) -> Box<dyn LeafModel>;

    /// Create a warm clone preserving learned weights but resetting optimizer state.
    ///
    /// Used when splitting a leaf: child leaves inherit the parent's learned
    /// function as a starting point, converging faster than starting from scratch.
    /// Defaults to [`clone_fresh`](LeafModel::clone_fresh) for models where
    /// warm-starting is not meaningful (e.g. [`ClosedFormLeaf`]).
    fn clone_warm(&self) -> Box<dyn LeafModel> {
        self.clone_fresh()
    }
}
73
74// ---------------------------------------------------------------------------
75// ClosedFormLeaf
76// ---------------------------------------------------------------------------
77
/// Leaf model that computes the optimal weight in closed form:
/// `weight = -grad_sum / (hess_sum + lambda)`.
///
/// This is the standard leaf value used in gradient boosted trees.
//
// The previous hand-written `Default` impl produced exactly the all-zero
// value that `#[derive(Default)]` generates for three `f64` fields, so the
// derive replaces it.
#[derive(Default)]
pub struct ClosedFormLeaf {
    /// Running sum of incoming gradients (G).
    grad_sum: f64,
    /// Running sum of incoming hessians (H).
    hess_sum: f64,
    /// Cached closed-form weight, refreshed on every update.
    weight: f64,
}

impl ClosedFormLeaf {
    /// Create a new zeroed closed-form leaf.
    pub fn new() -> Self {
        Self::default()
    }
}
104
105impl LeafModel for ClosedFormLeaf {
106    fn predict(&self, _features: &[f64]) -> f64 {
107        self.weight
108    }
109
110    fn update(&mut self, _features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
111        self.grad_sum += gradient;
112        self.hess_sum += hessian;
113        self.weight = -self.grad_sum / (self.hess_sum + lambda);
114    }
115
116    fn clone_fresh(&self) -> Box<dyn LeafModel> {
117        Box::new(ClosedFormLeaf::new())
118    }
119}
120
121// ---------------------------------------------------------------------------
122// LinearLeafModel
123// ---------------------------------------------------------------------------
124
/// Online ridge regression leaf model with AdaGrad optimization.
///
/// Learns a linear function `w . x + b` using Newton-scaled gradient descent
/// with per-weight AdaGrad accumulators for adaptive learning rates. Features
/// at different scales converge at their natural rates without manual tuning.
///
/// Weights are lazily initialized on the first `update` call so the model
/// adapts to whatever dimensionality arrives.
///
/// Optional exponential weight decay (`decay`) gives the model a finite memory
/// horizon for non-stationary streams.
pub struct LinearLeafModel {
    /// Learned feature weights `w` (sized lazily on the first update).
    weights: Vec<f64>,
    /// Learned intercept `b`.
    bias: f64,
    /// Base learning rate, before Newton and AdaGrad scaling.
    learning_rate: f64,
    /// Optional multiplicative weight decay applied before each update.
    decay: Option<f64>,
    /// When true, per-weight AdaGrad accumulators drive the learning rates.
    use_adagrad: bool,
    /// Per-weight squared gradient accumulator (AdaGrad).
    sq_grad_accum: Vec<f64>,
    /// Bias squared gradient accumulator (AdaGrad).
    sq_bias_accum: f64,
    /// Set once `weights`/`sq_grad_accum` have been sized to the input dim.
    initialized: bool,
}
148
149impl LinearLeafModel {
150    /// Create a new linear leaf model with the given base learning rate,
151    /// optional exponential decay factor, and AdaGrad toggle.
152    ///
153    /// When `decay` is `Some(d)` with `d` in (0, 1), weights are multiplied
154    /// by `d` before each update, giving the model a memory half-life of
155    /// `ln(2) / ln(1/d)` samples.
156    ///
157    /// When `use_adagrad` is `true`, per-weight squared gradient accumulators
158    /// give each feature its own adaptive learning rate. When `false`, all
159    /// weights share a single Newton-scaled learning rate (plain SGD).
160    pub fn new(learning_rate: f64, decay: Option<f64>, use_adagrad: bool) -> Self {
161        Self {
162            weights: Vec::new(),
163            bias: 0.0,
164            learning_rate,
165            decay,
166            use_adagrad,
167            sq_grad_accum: Vec::new(),
168            sq_bias_accum: 0.0,
169            initialized: false,
170        }
171    }
172}
173
/// AdaGrad epsilon to prevent division by zero.
///
/// Added to the square root of the squared-gradient accumulator so the
/// adaptive learning rate stays finite before any gradient mass accumulates.
const ADAGRAD_EPS: f64 = 1e-8;
176
177impl LeafModel for LinearLeafModel {
178    fn predict(&self, features: &[f64]) -> f64 {
179        if !self.initialized {
180            return 0.0;
181        }
182        let mut dot = self.bias;
183        for (w, x) in self.weights.iter().zip(features.iter()) {
184            dot += w * x;
185        }
186        dot
187    }
188
189    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
190        if !self.initialized {
191            let d = features.len();
192            self.weights = vec![0.0; d];
193            self.sq_grad_accum = vec![0.0; d];
194            self.initialized = true;
195        }
196
197        // Exponential weight decay (forgetting old data).
198        if let Some(d) = self.decay {
199            for w in self.weights.iter_mut() {
200                *w *= d;
201            }
202            self.bias *= d;
203        }
204
205        // Newton-scaled base learning rate.
206        let base_lr = self.learning_rate / (math::abs(hessian) + lambda);
207
208        if self.use_adagrad {
209            // AdaGrad: per-weight adaptive learning rates.
210            for (i, (w, x)) in self.weights.iter_mut().zip(features.iter()).enumerate() {
211                let g = gradient * x;
212                self.sq_grad_accum[i] += g * g;
213                let adaptive_lr = base_lr / (math::sqrt(self.sq_grad_accum[i]) + ADAGRAD_EPS);
214                *w -= adaptive_lr * g;
215            }
216            self.sq_bias_accum += gradient * gradient;
217            let bias_lr = base_lr / (math::sqrt(self.sq_bias_accum) + ADAGRAD_EPS);
218            self.bias -= bias_lr * gradient;
219        } else {
220            // Plain Newton-scaled SGD.
221            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
222                *w -= base_lr * gradient * x;
223            }
224            self.bias -= base_lr * gradient;
225        }
226    }
227
228    fn clone_fresh(&self) -> Box<dyn LeafModel> {
229        Box::new(LinearLeafModel::new(
230            self.learning_rate,
231            self.decay,
232            self.use_adagrad,
233        ))
234    }
235
236    fn clone_warm(&self) -> Box<dyn LeafModel> {
237        Box::new(LinearLeafModel {
238            weights: self.weights.clone(),
239            bias: self.bias,
240            learning_rate: self.learning_rate,
241            decay: self.decay,
242            use_adagrad: self.use_adagrad,
243            // Reset AdaGrad accumulators -- the child's gradient landscape
244            // differs from the parent's, so accumulated curvature estimates
245            // don't transfer. Fresh accumulators let the child's learning
246            // rates adapt to its own region.
247            sq_grad_accum: vec![0.0; self.weights.len()],
248            sq_bias_accum: 0.0,
249            initialized: self.initialized,
250        })
251    }
252}
253
254// ---------------------------------------------------------------------------
255// MLPLeafModel
256// ---------------------------------------------------------------------------
257
/// Single hidden layer MLP leaf model with ReLU activation.
///
/// Learns a nonlinear function via backpropagation with Newton-scaled
/// learning rate. Weights are lazily initialized on the first `update` call
/// using a deterministic xorshift64 PRNG so results are reproducible.
///
/// Optional exponential weight decay (`decay`) gives the model a finite memory
/// horizon for non-stationary streams.
pub struct MLPLeafModel {
    /// Hidden-layer weight matrix, indexed `[hidden unit][input feature]`.
    hidden_weights: Vec<Vec<f64>>, // [hidden_size][input_size]
    /// Per-hidden-unit bias terms.
    hidden_bias: Vec<f64>,
    /// Hidden-to-output weights, one per hidden unit.
    output_weights: Vec<f64>,
    /// Output bias term.
    output_bias: f64,
    /// Number of hidden units.
    hidden_size: usize,
    /// Base learning rate, before Newton scaling.
    learning_rate: f64,
    /// Optional multiplicative weight decay applied before each update.
    decay: Option<f64>,
    /// Seed for deterministic weight initialization.
    seed: u64,
    /// Set once weights have been initialized to the input dimensionality.
    initialized: bool,
    /// Post-ReLU activations cached by `forward` for use in backprop.
    hidden_activations: Vec<f64>,
    /// Pre-ReLU sums cached by `forward` for the ReLU derivative.
    hidden_pre_activations: Vec<f64>,
}
279
280impl MLPLeafModel {
281    /// Create a new MLP leaf model with the given hidden layer size, learning rate,
282    /// seed, and optional decay.
283    ///
284    /// The seed controls deterministic weight initialization. Different seeds
285    /// produce different initial weights, which is critical for ensemble diversity
286    /// when multiple MLP leaves share the same `hidden_size`.
287    pub fn new(hidden_size: usize, learning_rate: f64, seed: u64, decay: Option<f64>) -> Self {
288        Self {
289            hidden_weights: Vec::new(),
290            hidden_bias: Vec::new(),
291            output_weights: Vec::new(),
292            output_bias: 0.0,
293            hidden_size,
294            learning_rate,
295            decay,
296            seed,
297            initialized: false,
298            hidden_activations: Vec::new(),
299            hidden_pre_activations: Vec::new(),
300        }
301    }
302
303    /// Initialize weights using xorshift64, scaled to [-0.1, 0.1].
304    fn initialize(&mut self, input_size: usize) {
305        let mut state = self.seed ^ (self.hidden_size as u64);
306
307        self.hidden_weights = Vec::with_capacity(self.hidden_size);
308        for _ in 0..self.hidden_size {
309            let mut row = Vec::with_capacity(input_size);
310            for _ in 0..input_size {
311                let r = xorshift64(&mut state);
312                // Map u64 to [-0.1, 0.1]
313                let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
314                row.push(val);
315            }
316            self.hidden_weights.push(row);
317        }
318
319        self.hidden_bias = Vec::with_capacity(self.hidden_size);
320        for _ in 0..self.hidden_size {
321            let r = xorshift64(&mut state);
322            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
323            self.hidden_bias.push(val);
324        }
325
326        self.output_weights = Vec::with_capacity(self.hidden_size);
327        for _ in 0..self.hidden_size {
328            let r = xorshift64(&mut state);
329            let val = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
330            self.output_weights.push(val);
331        }
332
333        {
334            let r = xorshift64(&mut state);
335            self.output_bias = (r as f64 / u64::MAX as f64) * 0.2 - 0.1;
336        }
337
338        self.hidden_activations = vec![0.0; self.hidden_size];
339        self.hidden_pre_activations = vec![0.0; self.hidden_size];
340        self.initialized = true;
341    }
342
343    /// Forward pass: compute hidden pre-activations, ReLU activations, and output.
344    fn forward(&mut self, features: &[f64]) -> f64 {
345        // Hidden layer
346        for h in 0..self.hidden_size {
347            let mut z = self.hidden_bias[h];
348            for (j, x) in features.iter().enumerate() {
349                if j < self.hidden_weights[h].len() {
350                    z += self.hidden_weights[h][j] * x;
351                }
352            }
353            self.hidden_pre_activations[h] = z;
354            // ReLU
355            self.hidden_activations[h] = if z > 0.0 { z } else { 0.0 };
356        }
357
358        // Output layer
359        let mut out = self.output_bias;
360        for (w, a) in self
361            .output_weights
362            .iter()
363            .zip(self.hidden_activations.iter())
364        {
365            out += w * a;
366        }
367        out
368    }
369}
370
371impl LeafModel for MLPLeafModel {
372    fn predict(&self, features: &[f64]) -> f64 {
373        if !self.initialized {
374            return 0.0;
375        }
376        // Non-mutating forward pass (can't store activations, recompute locally)
377        let hidden_acts: Vec<f64> = self
378            .hidden_weights
379            .iter()
380            .zip(self.hidden_bias.iter())
381            .map(|(hw, &hb)| {
382                let mut z = hb;
383                for (j, x) in features.iter().enumerate() {
384                    if j < hw.len() {
385                        z += hw[j] * x;
386                    }
387                }
388                if z > 0.0 {
389                    z
390                } else {
391                    0.0
392                }
393            })
394            .collect();
395        let mut out = self.output_bias;
396        for (w, a) in self.output_weights.iter().zip(hidden_acts.iter()) {
397            out += w * a;
398        }
399        out
400    }
401
402    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
403        if !self.initialized {
404            self.initialize(features.len());
405        }
406
407        // Exponential weight decay (forgetting old data).
408        if let Some(d) = self.decay {
409            for row in self.hidden_weights.iter_mut() {
410                for w in row.iter_mut() {
411                    *w *= d;
412                }
413            }
414            for b in self.hidden_bias.iter_mut() {
415                *b *= d;
416            }
417            for w in self.output_weights.iter_mut() {
418                *w *= d;
419            }
420            self.output_bias *= d;
421        }
422
423        // Forward pass (stores activations for backprop)
424        let _output = self.forward(features);
425
426        let effective_lr = self.learning_rate / (math::abs(hessian) + lambda);
427
428        // Backprop: output gradient is the incoming `gradient` (chain rule from loss)
429        let d_output = gradient;
430
431        // Gradient for output weights and bias
432        // d_loss/d_output_w[h] = d_output * hidden_activations[h]
433        // d_loss/d_output_bias = d_output
434        for h in 0..self.hidden_size {
435            self.output_weights[h] -= effective_lr * d_output * self.hidden_activations[h];
436        }
437        self.output_bias -= effective_lr * d_output;
438
439        // Gradient for hidden layer
440        for h in 0..self.hidden_size {
441            // d_loss/d_hidden_act[h] = d_output * output_weights[h]
442            let d_hidden_act = d_output * self.output_weights[h];
443
444            // ReLU derivative
445            let d_relu = if self.hidden_pre_activations[h] > 0.0 {
446                d_hidden_act
447            } else {
448                0.0
449            };
450
451            // Update hidden weights and bias
452            for (j, x) in features.iter().enumerate() {
453                if j < self.hidden_weights[h].len() {
454                    self.hidden_weights[h][j] -= effective_lr * d_relu * x;
455                }
456            }
457            self.hidden_bias[h] -= effective_lr * d_relu;
458        }
459    }
460
461    fn clone_fresh(&self) -> Box<dyn LeafModel> {
462        // Derive a new seed so each fresh clone gets distinct initial weights.
463        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
464        Box::new(MLPLeafModel::new(
465            self.hidden_size,
466            self.learning_rate,
467            derived_seed,
468            self.decay,
469        ))
470    }
471
472    fn clone_warm(&self) -> Box<dyn LeafModel> {
473        Box::new(MLPLeafModel {
474            hidden_weights: self.hidden_weights.clone(),
475            hidden_bias: self.hidden_bias.clone(),
476            output_weights: self.output_weights.clone(),
477            output_bias: self.output_bias,
478            hidden_size: self.hidden_size,
479            learning_rate: self.learning_rate,
480            decay: self.decay,
481            // Derive a new seed so warm clones diverge if they re-initialize.
482            seed: self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(2),
483            initialized: self.initialized,
484            hidden_activations: vec![0.0; self.hidden_size],
485            hidden_pre_activations: vec![0.0; self.hidden_size],
486        })
487    }
488}
489
490// ---------------------------------------------------------------------------
491// AdaptiveLeafModel
492// ---------------------------------------------------------------------------
493
/// Leaf model that starts as closed-form and promotes to a more complex model
/// when the Hoeffding bound confirms it is statistically superior.
///
/// Runs a shadow model alongside the active (closed-form) model. On each
/// update, both models are trained and their per-sample losses compared using
/// the second-order Taylor approximation:
///
/// ```text
/// loss_i = gradient * prediction + 0.5 * hessian * prediction^2
/// advantage_i = loss_active_i - loss_shadow_i
/// ```
///
/// When `mean(advantage) > epsilon` where epsilon is the Hoeffding bound
/// (using the tree's `delta` parameter), the shadow model is promoted to
/// active and the overhead drops to zero.
///
/// This uses the **same statistical guarantee** as the tree's split decisions --
/// no arbitrary thresholds.
pub struct AdaptiveLeafModel {
    /// The currently active prediction model. Starts as ClosedForm.
    active: Box<dyn LeafModel>,
    /// Shadow model being evaluated against the active model.
    /// After promotion the two are swapped, so this then holds the demoted
    /// former active model and is no longer trained.
    shadow: Box<dyn LeafModel>,
    /// Configuration of the shadow model type (for cloning).
    promote_to: LeafModelType,
    /// Cumulative loss advantage: sum(loss_active - loss_shadow).
    /// Positive means the shadow is winning.
    cumulative_advantage: f64,
    /// Number of samples seen for the Hoeffding bound.
    n: u64,
    /// Running maximum |loss_diff| for range estimation (R in the bound).
    max_loss_diff: f64,
    /// Hoeffding confidence parameter (from tree config).
    delta: f64,
    /// Whether the shadow has been promoted.
    promoted: bool,
    /// Seed for reproducible cloning.
    seed: u64,
}
533
534impl AdaptiveLeafModel {
535    /// Create a new adaptive leaf model.
536    ///
537    /// The active model starts as `ClosedFormLeaf`. The shadow model is the
538    /// candidate that will be promoted if it proves statistically superior.
539    pub fn new(
540        shadow: Box<dyn LeafModel>,
541        promote_to: LeafModelType,
542        delta: f64,
543        seed: u64,
544    ) -> Self {
545        Self {
546            active: Box::new(ClosedFormLeaf::new()),
547            shadow,
548            promote_to,
549            cumulative_advantage: 0.0,
550            n: 0,
551            max_loss_diff: 0.0,
552            delta,
553            promoted: false,
554            seed,
555        }
556    }
557}
558
impl LeafModel for AdaptiveLeafModel {
    fn predict(&self, features: &[f64]) -> f64 {
        // Predictions always come from the active model; the shadow only
        // trains in the background and never serves predictions.
        self.active.predict(features)
    }

    fn update(&mut self, features: &[f64], gradient: f64, hessian: f64, lambda: f64) {
        if self.promoted {
            // Post-promotion: only update the promoted model.
            self.active.update(features, gradient, hessian, lambda);
            return;
        }

        // Compute predictions BEFORE updating (evaluate current state).
        let pred_active = self.active.predict(features);
        let pred_shadow = self.shadow.predict(features);

        // Second-order Taylor loss approximation for each model.
        // L(pred) ~= gradient * pred + 0.5 * hessian * pred^2
        // This is the same loss proxy that XGBoost gain uses.
        let loss_active = gradient * pred_active + 0.5 * hessian * pred_active * pred_active;
        let loss_shadow = gradient * pred_shadow + 0.5 * hessian * pred_shadow * pred_shadow;

        // Positive advantage means the shadow model is better (lower loss).
        let diff = loss_active - loss_shadow;
        self.cumulative_advantage += diff;
        self.n += 1;

        // Track range for the Hoeffding bound.
        let abs_diff = math::abs(diff);
        if abs_diff > self.max_loss_diff {
            self.max_loss_diff = abs_diff;
        }

        // Update both models.
        self.active.update(features, gradient, hessian, lambda);
        self.shadow.update(features, gradient, hessian, lambda);

        // Hoeffding bound test: is the shadow statistically better?
        // epsilon = sqrt(R^2 * ln(1/delta) / (2*n))
        // Promote when mean_advantage > epsilon.
        // The n >= 10 floor avoids deciding on a handful of samples; the
        // max_loss_diff > 0.0 guard avoids a degenerate zero-range bound.
        if self.n >= 10 && self.max_loss_diff > 0.0 {
            let mean_advantage = self.cumulative_advantage / self.n as f64;
            if mean_advantage > 0.0 {
                let r_squared = self.max_loss_diff * self.max_loss_diff;
                let ln_inv_delta = math::ln(1.0 / self.delta);
                let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * self.n as f64));

                if mean_advantage > epsilon {
                    // Promote: swap shadow into active, drop the old active.
                    self.promoted = true;
                    core::mem::swap(&mut self.active, &mut self.shadow);
                }
            }
        }
    }

    fn clone_fresh(&self) -> Box<dyn LeafModel> {
        // Derive a new seed so fresh clones evaluate distinct shadow instances.
        let derived_seed = self.seed.wrapping_mul(0x9E3779B97F4A7C15).wrapping_add(1);
        Box::new(AdaptiveLeafModel::new(
            self.promote_to.create(derived_seed, self.delta),
            self.promote_to.clone(),
            self.delta,
            derived_seed,
        ))
    }
    // clone_warm: default (= clone_fresh). Promotion must be re-earned
    // in each leaf's region -- the parent's promotion does not transfer.
}
627
628// We need Send + Sync for AdaptiveLeafModel because it contains Box<dyn LeafModel>
629// which already requires Send + Sync. The compiler should derive these automatically,
630// but let's verify:
631// SAFETY: All fields are Send + Sync (Box<dyn LeafModel> requires it, f64/u64/bool are Send+Sync).
632unsafe impl Send for AdaptiveLeafModel {}
633unsafe impl Sync for AdaptiveLeafModel {}
634
635// ---------------------------------------------------------------------------
636// LeafModelType
637// ---------------------------------------------------------------------------
638
/// Describes which leaf model architecture to use.
///
/// Used by tree builders to construct fresh leaf models when creating new leaves.
///
/// # Variants
///
/// - **`ClosedForm`** -- Standard constant leaf weight. Zero overhead (default).
/// - **`Linear`** -- Per-leaf online ridge regression with AdaGrad optimization.
///   Each leaf learns a local linear surface `w . x + b`. Recommended for
///   low-depth trees (depth 2--4). Optional `decay` for concept drift.
/// - **`MLP`** -- Per-leaf single-hidden-layer neural network with ReLU.
///   Optional `decay` for concept drift.
/// - **`Adaptive`** -- Starts as closed-form, auto-promotes to `promote_to`
///   when the Hoeffding bound confirms it is statistically superior. Uses the
///   tree's existing `delta` parameter -- no arbitrary thresholds.
#[derive(Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LeafModelType {
    /// Standard closed-form leaf weight.
    #[default]
    ClosedForm,

    /// Online ridge regression, optionally with AdaGrad optimization.
    ///
    /// `decay`: optional exponential weight decay for non-stationary streams.
    /// Typical values: 0.999 (slow drift) to 0.99 (fast drift).
    ///
    /// `use_adagrad`: when `true`, per-weight squared gradient accumulators
    /// give each feature its own adaptive learning rate. When `false`
    /// (default), all weights share a single Newton-scaled learning rate.
    Linear {
        /// SGD learning rate for the linear leaf weights.
        learning_rate: f64,
        /// Optional exponential weight decay for non-stationary streams.
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
        /// When true, per-weight AdaGrad accumulators give adaptive rates.
        #[cfg_attr(feature = "serde", serde(default))]
        use_adagrad: bool,
    },

    /// Single hidden layer MLP with the given hidden size and learning rate.
    ///
    /// `decay`: optional exponential weight decay for non-stationary streams.
    MLP {
        /// Number of hidden units in the single hidden layer.
        hidden_size: usize,
        /// SGD learning rate for MLP weights.
        learning_rate: f64,
        /// Optional exponential weight decay.
        #[cfg_attr(feature = "serde", serde(default))]
        decay: Option<f64>,
    },

    /// Adaptive leaf that starts as closed-form and auto-promotes when
    /// the Hoeffding bound confirms the promoted model is better.
    ///
    /// The `promote_to` field specifies the shadow model type to evaluate
    /// against the default closed-form baseline.
    Adaptive {
        /// Shadow model type to evaluate for potential promotion.
        /// Boxed because the enum is recursive (`Adaptive` nests another
        /// `LeafModelType`).
        promote_to: Box<LeafModelType>,
    },
}
704
705impl LeafModelType {
706    /// Create a fresh boxed leaf model of this type.
707    ///
708    /// The `seed` parameter controls deterministic initialization (MLP weights,
709    /// adaptive model seeding). The `delta` parameter is the Hoeffding bound
710    /// confidence level, used by [`Adaptive`](LeafModelType::Adaptive) leaves
711    /// for promotion decisions. For non-adaptive types, `delta` is unused.
712    pub fn create(&self, seed: u64, delta: f64) -> Box<dyn LeafModel> {
713        match self {
714            Self::ClosedForm => Box::new(ClosedFormLeaf::new()),
715            Self::Linear {
716                learning_rate,
717                decay,
718                use_adagrad,
719            } => Box::new(LinearLeafModel::new(*learning_rate, *decay, *use_adagrad)),
720            Self::MLP {
721                hidden_size,
722                learning_rate,
723                decay,
724            } => Box::new(MLPLeafModel::new(
725                *hidden_size,
726                *learning_rate,
727                seed,
728                *decay,
729            )),
730            Self::Adaptive { promote_to } => Box::new(AdaptiveLeafModel::new(
731                promote_to.create(seed, delta),
732                *promote_to.clone(),
733                delta,
734                seed,
735            )),
736        }
737    }
738}
739
740use crate::rng::xorshift64;
741
742// ---------------------------------------------------------------------------
743// Tests
744// ---------------------------------------------------------------------------
745
746#[cfg(test)]
747mod tests {
748    use super::*;
749
750    /// Xorshift64 for deterministic test data generation.
751    fn xorshift64(state: &mut u64) -> u64 {
752        let mut s = *state;
753        s ^= s << 13;
754        s ^= s >> 7;
755        s ^= s << 17;
756        *state = s;
757        s
758    }
759
    /// Convert xorshift output to f64 in [0, 1).
    ///
    /// NOTE(review): dividing by `u64::MAX` technically yields [0, 1] --
    /// the upper bound is attainable only on the all-ones draw, which is
    /// harmless for test data generation.
    fn rand_f64(state: &mut u64) -> f64 {
        xorshift64(state) as f64 / u64::MAX as f64
    }
764
765    #[test]
766    fn closed_form_matches_formula() {
767        let mut leaf = ClosedFormLeaf::new();
768        let lambda = 1.0;
769
770        // Accumulate several gradient/hessian pairs
771        let updates = [(0.5, 1.0), (-0.3, 0.8), (1.2, 2.0), (-0.1, 0.5)];
772        let mut grad_sum = 0.0;
773        let mut hess_sum = 0.0;
774
775        for &(g, h) in &updates {
776            leaf.update(&[], g, h, lambda);
777            grad_sum += g;
778            hess_sum += h;
779        }
780
781        let expected = -grad_sum / (hess_sum + lambda);
782        let predicted = leaf.predict(&[]);
783
784        assert!(
785            (predicted - expected).abs() < 1e-12,
786            "closed form mismatch: got {predicted}, expected {expected}"
787        );
788    }
789
790    #[test]
791    fn closed_form_clone_fresh_resets() {
792        let mut leaf = ClosedFormLeaf::new();
793        leaf.update(&[], 5.0, 2.0, 1.0);
794        assert!(
795            leaf.predict(&[]).abs() > 0.0,
796            "leaf should have non-zero weight after update"
797        );
798
799        let fresh = leaf.clone_fresh();
800        assert!(
801            fresh.predict(&[]).abs() < 1e-15,
802            "fresh clone should predict 0, got {}",
803            fresh.predict(&[])
804        );
805    }
806
807    #[test]
808    fn linear_converges_on_linear_target() {
809        // Target: y = 2*x1 + 3*x2
810        let mut model = LinearLeafModel::new(0.01, None, false);
811        let lambda = 0.1;
812        let mut rng = 42u64;
813
814        for _ in 0..2000 {
815            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
816            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
817            let features = vec![x1, x2];
818            let target = 2.0 * x1 + 3.0 * x2;
819
820            let pred = model.predict(&features);
821            let gradient = 2.0 * (pred - target);
822            let hessian = 2.0;
823            model.update(&features, gradient, hessian, lambda);
824        }
825
826        let test_features = vec![0.5, -0.3];
827        let target = 2.0 * 0.5 + 3.0 * (-0.3);
828        let pred = model.predict(&test_features);
829
830        assert!(
831            (pred - target).abs() < 1.0,
832            "linear model should converge within 1.0 of target: pred={pred}, target={target}"
833        );
834    }
835
836    #[test]
837    fn linear_uninitialized_predicts_zero() {
838        let model = LinearLeafModel::new(0.01, None, false);
839        let pred = model.predict(&[1.0, 2.0, 3.0]);
840        assert!(
841            pred.abs() < 1e-15,
842            "uninitialized linear model should predict 0, got {pred}"
843        );
844    }
845
846    #[test]
847    fn linear_clone_warm_preserves_weights() {
848        let mut model = LinearLeafModel::new(0.01, None, false);
849        let features = vec![1.0, 2.0];
850
851        // Train it
852        for i in 0..100 {
853            let target = 3.0 * features[0] + 2.0 * features[1];
854            let pred = model.predict(&features);
855            let gradient = 2.0 * (pred - target);
856            model.update(&features, gradient, 2.0, 0.1);
857            // Avoid unused variable warning
858            let _ = i;
859        }
860
861        let trained_pred = model.predict(&features);
862        assert!(
863            trained_pred.abs() > 0.01,
864            "model should have learned something"
865        );
866
867        // Warm clone should predict similarly
868        let warm = model.clone_warm();
869        let warm_pred = warm.predict(&features);
870        assert!(
871            (warm_pred - trained_pred).abs() < 1e-12,
872            "warm clone should preserve weights: trained={trained_pred}, warm={warm_pred}"
873        );
874
875        // Fresh clone should predict 0
876        let fresh = model.clone_fresh();
877        let fresh_pred = fresh.predict(&features);
878        assert!(
879            fresh_pred.abs() < 1e-15,
880            "fresh clone should predict 0, got {fresh_pred}"
881        );
882    }
883
884    #[test]
885    fn linear_decay_forgets_old_data() {
886        // Decay should pull weights toward zero when training stops,
887        // demonstrating that old learned state is forgotten.
888        let mut model_decay = LinearLeafModel::new(0.05, Some(0.99), false);
889        let mut model_no_decay = LinearLeafModel::new(0.05, None, false);
890        let features = vec![1.0];
891        let lambda = 0.1;
892
893        // Train both on target = 5.0
894        for _ in 0..500 {
895            let pred_d = model_decay.predict(&features);
896            let pred_n = model_no_decay.predict(&features);
897            model_decay.update(&features, 2.0 * (pred_d - 5.0), 2.0, lambda);
898            model_no_decay.update(&features, 2.0 * (pred_n - 5.0), 2.0, lambda);
899        }
900
901        // Both should have learned roughly the same function.
902        let pred_d_trained = model_decay.predict(&features);
903        let pred_n_trained = model_no_decay.predict(&features);
904        assert!(
905            (pred_d_trained - 5.0).abs() < 2.0,
906            "decay model should approximate target"
907        );
908        assert!(
909            (pred_n_trained - 5.0).abs() < 2.0,
910            "no-decay model should approximate target"
911        );
912
913        // Now apply 200 rounds of zero-gradient updates (data stopped).
914        // The decay model's weights should shrink toward zero,
915        // while the no-decay model retains its learned state.
916        for _ in 0..200 {
917            model_decay.update(&features, 0.0, 1.0, lambda);
918            model_no_decay.update(&features, 0.0, 1.0, lambda);
919        }
920
921        let pred_d_after = model_decay.predict(&features);
922        let pred_n_after = model_no_decay.predict(&features);
923
924        // Decay model should have drifted toward zero.
925        assert!(
926            pred_d_after.abs() < pred_n_after.abs(),
927            "decay model should forget: decay pred={pred_d_after:.3}, no-decay pred={pred_n_after:.3}"
928        );
929    }
930
931    #[test]
932    fn mlp_produces_finite_predictions() {
933        let model_uninit = MLPLeafModel::new(4, 0.01, 42, None);
934        let features = vec![1.0, 2.0, 3.0];
935
936        let pred_before = model_uninit.predict(&features);
937        assert!(
938            pred_before.is_finite(),
939            "uninit prediction should be finite"
940        );
941        assert!(
942            pred_before.abs() < 1e-15,
943            "uninit prediction should be 0, got {pred_before}"
944        );
945
946        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
947        for _ in 0..10 {
948            model.update(&features, 0.5, 1.0, 0.1);
949        }
950        let pred_after = model.predict(&features);
951        assert!(
952            pred_after.is_finite(),
953            "prediction after training should be finite, got {pred_after}"
954        );
955    }
956
957    #[test]
958    fn mlp_loss_decreases() {
959        let mut model = MLPLeafModel::new(8, 0.05, 123, None);
960        let features = vec![1.0, -0.5, 0.3];
961        let target = 2.5;
962        let lambda = 0.1;
963
964        model.update(&features, 0.0, 1.0, lambda); // dummy update to initialize
965        let initial_pred = model.predict(&features);
966        let initial_error = (initial_pred - target).abs();
967
968        for _ in 0..200 {
969            let pred = model.predict(&features);
970            let gradient = 2.0 * (pred - target);
971            let hessian = 2.0;
972            model.update(&features, gradient, hessian, lambda);
973        }
974
975        let final_pred = model.predict(&features);
976        let final_error = (final_pred - target).abs();
977
978        assert!(
979            final_error < initial_error,
980            "MLP error should decrease: initial={initial_error}, final={final_error}"
981        );
982    }
983
984    #[test]
985    fn mlp_clone_fresh_resets() {
986        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
987        let features = vec![1.0, 2.0];
988
989        for _ in 0..20 {
990            model.update(&features, 0.5, 1.0, 0.1);
991        }
992
993        let trained_pred = model.predict(&features);
994        assert!(
995            trained_pred.abs() > 1e-10,
996            "trained model should have non-zero prediction"
997        );
998
999        let fresh = model.clone_fresh();
1000        let fresh_pred = fresh.predict(&features);
1001        assert!(
1002            fresh_pred.abs() < 1e-15,
1003            "fresh clone should predict 0, got {fresh_pred}"
1004        );
1005    }
1006
1007    #[test]
1008    fn mlp_clone_warm_preserves_weights() {
1009        let mut model = MLPLeafModel::new(4, 0.01, 42, None);
1010        let features = vec![1.0, 2.0];
1011
1012        for _ in 0..50 {
1013            model.update(&features, 0.5, 1.0, 0.1);
1014        }
1015
1016        let trained_pred = model.predict(&features);
1017        let warm = model.clone_warm();
1018        let warm_pred = warm.predict(&features);
1019
1020        assert!(
1021            (warm_pred - trained_pred).abs() < 1e-10,
1022            "warm clone should preserve predictions: trained={trained_pred}, warm={warm_pred}"
1023        );
1024    }
1025
1026    #[test]
1027    fn leaf_model_type_default_is_closed_form() {
1028        let default_type = LeafModelType::default();
1029        assert!(
1030            matches!(default_type, LeafModelType::ClosedForm),
1031            "default LeafModelType should be ClosedForm, got {default_type:?}"
1032        );
1033    }
1034
1035    #[test]
1036    fn leaf_model_type_create_all_variants() {
1037        let features = vec![1.0, 2.0, 3.0];
1038        let delta = 1e-7;
1039
1040        // ClosedForm
1041        let mut closed = LeafModelType::ClosedForm.create(0, delta);
1042        closed.update(&features, 1.0, 1.0, 0.1);
1043        let p = closed.predict(&features);
1044        assert!(p.is_finite(), "ClosedForm prediction should be finite");
1045
1046        // Linear
1047        let mut linear = LeafModelType::Linear {
1048            learning_rate: 0.01,
1049            decay: None,
1050            use_adagrad: false,
1051        }
1052        .create(0, delta);
1053        linear.update(&features, 1.0, 1.0, 0.1);
1054        let p = linear.predict(&features);
1055        assert!(p.is_finite(), "Linear prediction should be finite");
1056
1057        // MLP
1058        let mut mlp = LeafModelType::MLP {
1059            hidden_size: 4,
1060            learning_rate: 0.01,
1061            decay: None,
1062        }
1063        .create(99, delta);
1064        mlp.update(&features, 1.0, 1.0, 0.1);
1065        let p = mlp.predict(&features);
1066        assert!(p.is_finite(), "MLP prediction should be finite");
1067
1068        // Adaptive (promoting to Linear)
1069        let mut adaptive = LeafModelType::Adaptive {
1070            promote_to: Box::new(LeafModelType::Linear {
1071                learning_rate: 0.01,
1072                decay: None,
1073                use_adagrad: false,
1074            }),
1075        }
1076        .create(42, delta);
1077        adaptive.update(&features, 1.0, 1.0, 0.1);
1078        let p = adaptive.predict(&features);
1079        assert!(p.is_finite(), "Adaptive prediction should be finite");
1080    }
1081
1082    #[test]
1083    fn adaptive_promotes_on_linear_target() {
1084        // The shadow (Linear) should eventually outperform ClosedForm
1085        // on a target that varies with features.
1086        let promote_to = LeafModelType::Linear {
1087            learning_rate: 0.01,
1088            decay: None,
1089            use_adagrad: false,
1090        };
1091        let shadow = promote_to.create(42, 1e-7);
1092        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
1093
1094        let mut rng = 42u64;
1095        for _ in 0..5000 {
1096            let x1 = rand_f64(&mut rng) * 2.0 - 1.0;
1097            let x2 = rand_f64(&mut rng) * 2.0 - 1.0;
1098            let features = vec![x1, x2];
1099            let target = 3.0 * x1 + 2.0 * x2;
1100
1101            let pred = model.predict(&features);
1102            let gradient = 2.0 * (pred - target);
1103            let hessian = 2.0;
1104            model.update(&features, gradient, hessian, 0.1);
1105        }
1106
1107        // After enough samples on a linear target, the Linear shadow
1108        // should have been promoted.
1109        assert!(
1110            model.promoted,
1111            "adaptive model should have promoted on linear target after 5000 samples"
1112        );
1113    }
1114
1115    #[test]
1116    fn adaptive_does_not_promote_on_constant_target() {
1117        // On a constant target, ClosedForm is optimal -- no promotion expected.
1118        let promote_to = LeafModelType::Linear {
1119            learning_rate: 0.01,
1120            decay: None,
1121            use_adagrad: false,
1122        };
1123        let shadow = promote_to.create(42, 1e-7);
1124        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-7, 42);
1125
1126        for _ in 0..2000 {
1127            let features = vec![1.0, 2.0];
1128            let target = 5.0; // constant -- no feature dependence
1129            let pred = model.predict(&features);
1130            let gradient = 2.0 * (pred - target);
1131            let hessian = 2.0;
1132            model.update(&features, gradient, hessian, 0.1);
1133        }
1134
1135        // With strict delta (1e-7) and a constant target, the Linear model
1136        // shouldn't gain a statistically significant advantage.
1137        // (It may or may not promote -- the point is it shouldn't be obvious.)
1138        // This test mainly verifies the mechanism doesn't crash and is conservative.
1139        let pred = model.predict(&[1.0, 2.0]);
1140        assert!(pred.is_finite(), "prediction should be finite");
1141    }
1142
1143    #[test]
1144    fn adaptive_clone_fresh_resets_promotion() {
1145        let promote_to = LeafModelType::Linear {
1146            learning_rate: 0.01,
1147            decay: None,
1148            use_adagrad: false,
1149        };
1150        let shadow = promote_to.create(42, 1e-3);
1151        let mut model = AdaptiveLeafModel::new(shadow, promote_to, 1e-3, 42);
1152
1153        // Force promotion via many linear-target samples
1154        let mut rng = 42u64;
1155        for _ in 0..5000 {
1156            let x = rand_f64(&mut rng) * 2.0 - 1.0;
1157            let features = vec![x];
1158            let pred = model.predict(&features);
1159            model.update(&features, 2.0 * (pred - 3.0 * x), 2.0, 0.1);
1160        }
1161
1162        let fresh = model.clone_fresh();
1163        // Fresh clone should predict 0 (reset state).
1164        let p = fresh.predict(&[0.5]);
1165        assert!(
1166            p.abs() < 1e-10,
1167            "fresh adaptive clone should predict ~0, got {p}"
1168        );
1169    }
1170}