// irithyll_core/learner.rs

//! Unified streaming learner trait for polymorphic model composition.
//!
//! [`StreamingLearner`] is an **object-safe** trait that abstracts over any
//! online/streaming machine learning model -- gradient boosted trees, linear
//! models, Naive Bayes, Mondrian forests, or anything else that can ingest
//! samples one at a time and produce predictions.
//!
//! # Motivation
//!
//! Stacking ensembles and meta-learners need to treat heterogeneous base
//! models uniformly: train them on the same stream, collect their predictions
//! as features for a combiner, and manage their lifecycle (reset, clone,
//! serialization). `StreamingLearner` provides exactly this interface.
//!
//! # Object Safety
//!
//! The trait is deliberately object-safe: every method uses `&self` /
//! `&mut self` with concrete return types (no generics on methods, no
//! `Self`-by-value in non-`Sized` positions). This means you can store
//! `Box<dyn StreamingLearner>` in a `Vec`, enabling runtime-polymorphic
//! stacking without monomorphization.

use alloc::vec::Vec;

/// Object-safe trait for any streaming (online) machine learning model.
///
/// All methods use `&self` or `&mut self` with concrete return types,
/// ensuring the trait can be used behind `Box<dyn StreamingLearner>` for
/// runtime-polymorphic stacking ensembles.
///
/// The `Send + Sync` supertraits allow learners to be shared across threads
/// (e.g., for parallel prediction in async pipelines).
///
/// # Required Methods
///
/// | Method | Purpose |
/// |--------|---------|
/// | [`train_one`](Self::train_one) | Ingest a single weighted observation |
/// | [`predict`](Self::predict) | Produce a prediction for a feature vector |
/// | [`n_samples_seen`](Self::n_samples_seen) | Total observations ingested so far |
/// | [`reset`](Self::reset) | Clear all learned state, returning to a fresh model |
///
/// # Default Methods
///
/// | Method | Purpose |
/// |--------|---------|
/// | [`train`](Self::train) | Convenience wrapper calling `train_one` with unit weight |
/// | [`predict_batch`](Self::predict_batch) | Map `predict` over a slice of feature vectors |
/// | [`diagnostics_array`](Self::diagnostics_array) | Raw diagnostic signals for adaptive tuning (all zeros by default) |
/// | [`adjust_config`](Self::adjust_config) | Apply smooth LR/lambda adjustments (no-op by default) |
/// | [`apply_structural_change`](Self::apply_structural_change) | Apply depth/steps changes at replacement boundaries (no-op by default) |
/// | [`replacement_count`](Self::replacement_count) | Total internal model replacements (0 by default) |
/// | [`check_proactive_prune`](Self::check_proactive_prune) | Manually trigger a prune check (`false` by default) |
/// | [`set_prune_half_life`](Self::set_prune_half_life) | Set the contribution-accuracy EWMA half-life (no-op by default) |
/// | [`readout_weights`](Self::readout_weights) | RLS readout weights for supervised projection (`None` by default) |
/// | [`tree_structure`](Self::tree_structure) | Per-tree structure diagnostics (empty by default) |
pub trait StreamingLearner: Send + Sync {
    /// Train on a single observation with explicit sample weight.
    ///
    /// This is the fundamental training primitive. All streaming models must
    /// support weighted incremental updates -- even if the weight is simply
    /// used to scale gradient contributions.
    ///
    /// # Arguments
    ///
    /// * `features` -- feature vector for this observation
    /// * `target` -- target value (regression) or class label (classification)
    /// * `weight` -- sample weight (1.0 for uniform weighting)
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64);

    /// Predict the target for the given feature vector.
    ///
    /// Returns the raw model output (no loss transform applied). For SGBT
    /// this is the sum of tree predictions; for linear models this is the
    /// dot product plus bias.
    fn predict(&self, features: &[f64]) -> f64;

    /// Total number of observations trained on since creation or last reset.
    fn n_samples_seen(&self) -> u64;

    /// Reset the model to its initial (untrained) state.
    ///
    /// After calling `reset()`, the model should behave identically to a
    /// freshly constructed instance with the same configuration. In particular,
    /// `n_samples_seen()` must return 0.
    fn reset(&mut self);

    /// Train on a single observation with unit weight.
    ///
    /// Convenience wrapper around [`train_one`](Self::train_one) that passes
    /// `weight = 1.0`. This is the most common training call in practice.
    fn train(&mut self, features: &[f64], target: f64) {
        self.train_one(features, target, 1.0);
    }

    /// Predict for each row in a feature matrix.
    ///
    /// Returns a `Vec<f64>` with one prediction per input row. The default
    /// implementation simply maps [`predict`](Self::predict) over the slices;
    /// concrete implementations may override this for SIMD or batch-optimized
    /// prediction paths.
    ///
    /// # Arguments
    ///
    /// * `feature_matrix` -- each element is a feature vector (one row)
    fn predict_batch(&self, feature_matrix: &[&[f64]]) -> Vec<f64> {
        feature_matrix.iter().map(|row| self.predict(row)).collect()
    }

    /// Raw diagnostic signals for adaptive tuning.
    ///
    /// Returns `[residual_alignment, reg_sensitivity, depth_sufficiency,
    /// effective_dof, uncertainty]`. These five signals drive the
    /// diagnostic adaptor in the auto-builder pipeline.
    ///
    /// Default: all zeros (model does not provide diagnostics). Models with
    /// internal diagnostic caches (e.g. SGBT, DistributionalSGBT) override
    /// this to return real computed values.
    fn diagnostics_array(&self) -> [f64; 5] {
        [0.0; 5]
    }

    /// Apply smooth learning rate and regularization adjustments.
    ///
    /// * `lr_multiplier` -- scales the current learning rate (1.0 = no change,
    ///   0.99 = 1% decrease, 1.01 = 1% increase).
    /// * `lambda_delta` -- added to the L2 regularization parameter
    ///   (0.0 = no change, positive = increase, negative = decrease).
    ///
    /// Default: no-op. Override for models with adjustable hyperparameters
    /// (e.g. SGBT, DistributionalSGBT).
    fn adjust_config(&mut self, _lr_multiplier: f64, _lambda_delta: f64) {}

    /// Apply structural changes at model replacement boundaries.
    ///
    /// * `depth_delta` -- adjust maximum tree depth (+1, -1, or 0).
    /// * `steps_delta` -- adjust number of ensemble steps (+2, -2, or 0).
    ///
    /// Structural changes take effect on the *next* tree replacement, not
    /// immediately. Default: no-op for models without structural config.
    fn apply_structural_change(&mut self, _depth_delta: i32, _steps_delta: i32) {}

    /// Total number of internal model replacements (e.g. tree replacements
    /// triggered by drift detection or max-tree-samples).
    ///
    /// External callers (e.g. the auto-builder) use this to detect when a
    /// structural boundary has occurred and apply queued structural changes.
    /// Default: 0 for models without replacement semantics.
    fn replacement_count(&self) -> u64 {
        0
    }

    /// Manually trigger a proactive prune check.
    ///
    /// Returns `true` if an internal component was pruned/replaced.
    /// Default: no-op (returns `false`).
    fn check_proactive_prune(&mut self) -> bool {
        false
    }

    /// Dynamically set the contribution accuracy EWMA half-life.
    ///
    /// Recomputes `prune_alpha` so each correction batch contributes equally
    /// regardless of size. Default: no-op.
    fn set_prune_half_life(&mut self, _hl: usize) {}

    /// Return the readout weight vector for supervised projection, if available.
    ///
    /// Models with an RLS readout layer return `Some(&weights)`. Models
    /// without (KAN, SpikeNet, SGBT, etc.) return `None`. Used by
    /// `ProjectedLearner` for supervised projection updates.
    fn readout_weights(&self) -> Option<&[f64]> {
        None
    }

    /// Optional tree-level structure diagnostics.
    ///
    /// Returns per-tree: `(depth, n_leaves, leaf_weight_mean, leaf_weight_std, samples_seen)`.
    /// Default: empty vec (model has no trees).
    fn tree_structure(&self) -> Vec<(usize, usize, f64, f64, u64)> {
        Vec::new()
    }
}