// irithyll_core/learner.rs

//! Unified streaming learner trait for polymorphic model composition.
//!
//! [`StreamingLearner`] is an **object-safe** trait that abstracts over any
//! online/streaming machine learning model -- gradient boosted trees, linear
//! models, Naive Bayes, Mondrian forests, or anything else that can ingest
//! samples one at a time and produce predictions.
//!
//! # Motivation
//!
//! Stacking ensembles and meta-learners need to treat heterogeneous base
//! models uniformly: train them on the same stream, collect their predictions
//! as features for a combiner, and manage their lifecycle (reset, clone,
//! serialization). `StreamingLearner` provides exactly this interface.
//!
//! # Object Safety
//!
//! The trait is deliberately object-safe: every method uses `&self` /
//! `&mut self` with concrete return types (no generics on methods, no
//! `Self`-by-value in non-`Sized` positions). This means you can store
//! `Box<dyn StreamingLearner>` in a `Vec`, enabling runtime-polymorphic
//! stacking without monomorphization.
//!
//! # Capability Traits
//!
//! Opt-in capability traits narrow the required bound to the actual capability
//! a wrapper or algorithm needs, enabling cleaner type-level documentation:
//!
//! - [`HasReadout`] -- models with a linear RLS readout layer (neural models,
//!   KRLS, RLS). Used by `ProjectedLearner` supervised-projection path.
//! - [`Tunable`] -- models that expose diagnostics and accept LR / lambda
//!   adjustments from AutoML components.
//! - [`Structural`] -- models whose capacity can grow or shrink at runtime
//!   (tree ensembles: SGBT, ARF, Mondrian).
use alloc::vec::Vec;

/// Object-safe trait for any streaming (online) machine learning model.
///
/// All methods use `&self` or `&mut self` with concrete return types,
/// ensuring the trait can be used behind `Box<dyn StreamingLearner>` for
/// runtime-polymorphic stacking ensembles.
///
/// The `Send + Sync` supertraits allow learners to be shared across threads
/// (e.g., for parallel prediction in async pipelines).
///
/// # Required Methods
///
/// | Method | Purpose |
/// |--------|---------|
/// | [`train_one`](Self::train_one) | Ingest a single weighted observation |
/// | [`predict`](Self::predict) | Produce a prediction for a feature vector |
/// | [`n_samples_seen`](Self::n_samples_seen) | Total observations ingested so far |
/// | [`reset`](Self::reset) | Clear all learned state, returning to a fresh model |
///
/// # Default Methods
///
/// | Method | Purpose |
/// |--------|---------|
/// | [`train`](Self::train) | Convenience wrapper calling `train_one` with unit weight |
/// | [`train_one_weighted`](Self::train_one_weighted) | Distillation-path weighted training; delegates to `train_one` |
/// | [`predict_batch`](Self::predict_batch) | Map `predict` over a slice of feature vectors |
///
/// # Deprecated Shims
///
/// The remaining default methods -- `diagnostics_array`, `adjust_config`,
/// `apply_structural_change`, `replacement_count`, `check_proactive_prune`,
/// `set_prune_half_life`, `readout_weights`, `tree_structure` -- are
/// `#[doc(hidden)]` deprecated shims kept only so existing trait-object
/// callers (`Box<dyn StreamingLearner>`) continue to work until v11.0
/// removes them. New code should use the capability traits [`Tunable`],
/// [`Structural`], and [`HasReadout`] instead.
pub trait StreamingLearner: Send + Sync {
    /// Train on a single observation with explicit sample weight.
    ///
    /// This is the fundamental training primitive. All streaming models must
    /// support weighted incremental updates -- even if the weight is simply
    /// used to scale gradient contributions.
    ///
    /// # Arguments
    ///
    /// * `features` -- feature vector for this observation
    /// * `target` -- target value (regression) or class label (classification)
    /// * `weight` -- sample weight (1.0 for uniform weighting)
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64);

    /// Predict the target for the given feature vector.
    ///
    /// Returns the raw model output (no loss transform applied). For SGBT
    /// this is the sum of tree predictions; for linear models this is the
    /// dot product plus bias.
    fn predict(&self, features: &[f64]) -> f64;

    /// Total number of observations trained on since creation or last reset.
    fn n_samples_seen(&self) -> u64;

    /// Reset the model to its initial (untrained) state.
    ///
    /// After calling `reset()`, the model should behave identically to a
    /// freshly constructed instance with the same configuration. In particular,
    /// `n_samples_seen()` must return 0.
    fn reset(&mut self);

    /// Train on a single observation with unit weight.
    ///
    /// Convenience wrapper around [`train_one`](Self::train_one) that passes
    /// `weight = 1.0`. This is the most common training call in practice.
    fn train(&mut self, features: &[f64], target: f64) {
        self.train_one(features, target, 1.0);
    }

    /// Train on a single observation with an explicit distillation weight.
    ///
    /// Used by the knowledge-distillation path (`distill` feature) to replay
    /// pseudo-targets from dominated candidates into the winner's model with a
    /// down-weighted loss contribution.
    ///
    /// The default implementation delegates to [`train_one`](Self::train_one),
    /// forwarding `weight` directly. Models that support weighted training
    /// (e.g. `DistributionalSGBT`) therefore use the weight correctly. Models
    /// that internally ignore the weight field still compile without changes --
    /// the default is correct and transparent for both cases.
    ///
    /// Non-distillation consumers are unaffected: this method is not called by
    /// any non-distillation code path.
    fn train_one_weighted(&mut self, features: &[f64], target: f64, weight: f64) {
        self.train_one(features, target, weight);
    }

    /// Predict for each row in a feature matrix.
    ///
    /// Returns a `Vec<f64>` with one prediction per input row. The default
    /// implementation simply maps [`predict`](Self::predict) over the slices;
    /// concrete implementations may override this for SIMD or batch-optimized
    /// prediction paths.
    ///
    /// # Arguments
    ///
    /// * `feature_matrix` -- each element is a feature vector (one row)
    fn predict_batch(&self, feature_matrix: &[&[f64]]) -> Vec<f64> {
        feature_matrix.iter().map(|row| self.predict(row)).collect()
    }

    /// Raw diagnostic signals for adaptive tuning.
    ///
    /// Returns `[residual_alignment, reg_sensitivity, depth_sufficiency,
    /// effective_dof, uncertainty]`. These five signals drive the
    /// diagnostic adaptor in the auto-builder pipeline.
    ///
    /// Default: all zeros (model does not provide diagnostics). Models with
    /// internal diagnostic caches (e.g. SGBT, DistributionalSGBT) override
    /// this to return real computed values.
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as Tunable>::diagnostics_array(model)` when the concrete type
    /// is known, or hold a `Box<dyn Tunable>` when dynamic dispatch over only
    /// tunable models is required. This shim keeps trait-object callers
    /// (`Box<dyn StreamingLearner>`) working until v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `Tunable` capability trait instead: `<T as Tunable>::diagnostics_array(model)` or hold `Box<dyn Tunable>`"
    )]
    #[doc(hidden)]
    fn diagnostics_array(&self) -> [f64; 5] {
        [0.0; 5]
    }

    /// Apply smooth learning rate and regularization adjustments.
    ///
    /// * `lr_multiplier` -- scales the current learning rate (1.0 = no change,
    ///   0.99 = 1% decrease, 1.01 = 1% increase).
    /// * `lambda_delta` -- added to the L2 regularization parameter
    ///   (0.0 = no change, positive = increase, negative = decrease).
    ///
    /// Default: no-op. Override for models with adjustable hyperparameters
    /// (e.g. SGBT, DistributionalSGBT).
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as Tunable>::adjust_config(model, lr, lambda)` when the
    /// concrete type is known. This shim keeps existing trait-object callers
    /// working until v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `Tunable` capability trait instead: `<T as Tunable>::adjust_config(model, lr_mult, lambda_delta)` or hold `Box<dyn Tunable>`"
    )]
    #[doc(hidden)]
    fn adjust_config(&mut self, _lr_multiplier: f64, _lambda_delta: f64) {}

    /// Apply structural changes at model replacement boundaries.
    ///
    /// * `depth_delta` -- adjust maximum tree depth (+1, -1, or 0).
    /// * `steps_delta` -- adjust number of ensemble steps (+2, -2, or 0).
    ///
    /// Structural changes take effect on the *next* tree replacement, not
    /// immediately. Default: no-op for models without structural config.
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as Structural>::apply_structural_change(model, ...)` when
    /// the concrete type is known. This shim keeps existing trait-object
    /// callers working until v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `Structural` capability trait instead: `<T as Structural>::apply_structural_change(model, depth_delta, steps_delta)` or hold `Box<dyn Structural>`"
    )]
    #[doc(hidden)]
    fn apply_structural_change(&mut self, _depth_delta: i32, _steps_delta: i32) {}

    /// Total number of internal model replacements (e.g. tree replacements
    /// triggered by drift detection or max-tree-samples).
    ///
    /// External callers (e.g. the auto-builder) use this to detect when a
    /// structural boundary has occurred and apply queued structural changes.
    /// Default: 0 for models without replacement semantics.
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as Structural>::replacement_count(model)` when the concrete
    /// type is known. This shim keeps existing trait-object callers working
    /// until v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `Structural` capability trait instead: `<T as Structural>::replacement_count(model)` or hold `Box<dyn Structural>`"
    )]
    #[doc(hidden)]
    fn replacement_count(&self) -> u64 {
        0
    }

    /// Manually trigger a proactive prune check.
    ///
    /// Returns `true` if an internal component was pruned/replaced.
    /// Default: no-op (returns `false`).
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as Structural>::check_proactive_prune(model)` when the
    /// concrete type is known. This shim keeps existing trait-object callers
    /// working until v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `Structural` capability trait instead: `<T as Structural>::check_proactive_prune(model)` or hold `Box<dyn Structural>`"
    )]
    #[doc(hidden)]
    fn check_proactive_prune(&mut self) -> bool {
        false
    }

    /// Dynamically set the contribution accuracy EWMA half-life.
    ///
    /// Recomputes `prune_alpha` so each correction batch contributes equally
    /// regardless of size. Default: no-op.
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as Structural>::set_prune_half_life(model, hl)` when the
    /// concrete type is known. This shim keeps existing trait-object callers
    /// working until v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `Structural` capability trait instead: `<T as Structural>::set_prune_half_life(model, hl)` or hold `Box<dyn Structural>`"
    )]
    #[doc(hidden)]
    fn set_prune_half_life(&mut self, _hl: usize) {}

    /// Return the readout weight vector for supervised projection, if available.
    ///
    /// Models with an RLS readout layer return `Some(&weights)`. Models
    /// without (KAN, SpikeNet, SGBT, etc.) return `None`. Used by
    /// `ProjectedLearner` for supervised projection updates.
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as HasReadout>::readout_weights(model)` when the concrete
    /// type is known, or hold a `Box<dyn HasReadout>`. This shim keeps
    /// existing trait-object callers working until v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `HasReadout` capability trait instead: `<T as HasReadout>::readout_weights(model)` or hold `Box<dyn HasReadout>`"
    )]
    #[doc(hidden)]
    fn readout_weights(&self) -> Option<&[f64]> {
        None
    }

    /// Optional tree-level structure diagnostics.
    ///
    /// Returns per-tree: `(depth, n_leaves, leaf_weight_mean, leaf_weight_std, samples_seen)`.
    /// Default: empty vec (model has no trees).
    ///
    /// # Deprecation
    ///
    /// Prefer `<T as Structural>::tree_structure(model)` when the concrete type
    /// is known. This shim keeps existing trait-object callers working until
    /// v11.0 removes it.
    #[deprecated(
        since = "10.0.0",
        note = "use the `Structural` capability trait instead: `<T as Structural>::tree_structure(model)` or hold `Box<dyn Structural>`"
    )]
    #[doc(hidden)]
    fn tree_structure(&self) -> Vec<(usize, usize, f64, f64, u64)> {
        Vec::new()
    }
}

// ===========================================================================
// Capability traits
// ===========================================================================

305/// Models that expose a linear readout weight vector.
306///
307/// Implemented by models with an RLS readout layer: neural models (ESN, Mamba,
308/// KAN, TTT, sLSTM, HGRN2, mGRADE, attention variants), kernel models (KRLS),
309/// and linear models (RLS). Used by `ProjectedLearner` for supervised
310/// projection updates.
311///
312/// # Object Safety
313///
314/// This trait is object-safe. `Box<dyn HasReadout>` is a legal type.
315pub trait HasReadout: StreamingLearner {
316    /// The linear readout weight vector.
317    ///
318    /// For RLS-family models this is the full coefficient vector including
319    /// the bias term if one is used. Length matches the model's internal
320    /// feature dimensionality.
321    fn readout_weights(&self) -> &[f64];
322}
323
324/// Models that expose diagnostics and accept smooth hyperparameter adjustments.
325///
326/// Implemented by models touched by AutoML components: SGBT, DistributionalSGBT,
327/// RLS, KAN, TTT, ESN, mGRADE, and any model with tunable LR or regularization.
328///
329/// # Object Safety
330///
331/// This trait is object-safe. `Box<dyn Tunable>` is a legal type.
332pub trait Tunable: StreamingLearner {
333    /// Raw diagnostic signals for adaptive tuning.
334    ///
335    /// Returns `[residual_alignment, reg_sensitivity, depth_sufficiency,
336    /// effective_dof, uncertainty]`. These five signals drive the diagnostic
337    /// adaptor in the AutoML pipeline.
338    fn diagnostics_array(&self) -> [f64; 5];
339
340    /// Apply smooth learning rate and regularization adjustments.
341    ///
342    /// * `lr_multiplier` -- scales the current learning rate (1.0 = no
343    ///   change, 0.99 = 1% decrease, 1.01 = 1% increase).
344    /// * `lambda_delta` -- additive delta applied to the L2 regularization
345    ///   parameter (0.0 = no change, positive = increase regularization).
346    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64);
347}
348
349/// Models whose internal capacity can grow or shrink at runtime.
350///
351/// Implemented by tree ensemble models: SGBT, AdaptiveRandomForest,
352/// DistributionalSGBT, BaggedSGBT, stacked ensembles that delegate to trees.
353///
354/// # Object Safety
355///
356/// This trait is object-safe. `Box<dyn Structural>` is a legal type.
357pub trait Structural: StreamingLearner {
358    /// Apply depth/step changes that take effect at the next tree replacement.
359    ///
360    /// * `depth_delta` -- signed adjustment to maximum tree depth (+1, -1, 0).
361    /// * `steps_delta` -- signed adjustment to ensemble step count (+2, -2, 0).
362    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32);
363
364    /// Total number of internal model replacements since creation or last reset.
365    ///
366    /// External callers use this counter to detect replacement boundaries and
367    /// apply queued structural changes.
368    fn replacement_count(&self) -> u64;
369
370    /// Manually trigger a proactive prune check.
371    ///
372    /// Returns `true` if an internal component was pruned or replaced.
373    /// Defaults to `false` (no-op) for models without proactive pruning.
374    fn check_proactive_prune(&mut self) -> bool {
375        false
376    }
377
378    /// Dynamically set the contribution-accuracy EWMA half-life.
379    ///
380    /// Recomputes `prune_alpha` so each correction batch contributes equally
381    /// regardless of batch size. Default: no-op for models without an EWMA
382    /// prune accumulator.
383    fn set_prune_half_life(&mut self, _hl: usize) {}
384
385    /// Per-tree structure diagnostics.
386    ///
387    /// Returns one tuple per tree:
388    /// `(max_depth, n_leaves, leaf_weight_mean, leaf_weight_std, samples_seen)`.
389    /// Defaults to an empty vec for models without tree diagnostics.
390    fn tree_structure(&self) -> Vec<(usize, usize, f64, f64, u64)> {
391        Vec::new()
392    }
393}
394
// ===========================================================================
// Object-safety smoke tests
// ===========================================================================
//
// These functions never run -- they exist only as compile-time checks of
// object safety. If any trait above inadvertently gains a non-object-safe
// method, this module stops compiling.

#[cfg(test)]
mod _object_safety_tests {
    use super::{HasReadout, StreamingLearner, Structural, Tunable};
    use alloc::boxed::Box;

    // Each helper type-checks only if its trait is object-safe: naming
    // `Box<dyn Trait>` as a parameter type is exactly the usage the crate
    // promises to support. None of these functions is ever called.
    fn _object_safe_streaming_learner(_: Box<dyn StreamingLearner>) {}
    fn _object_safe_has_readout(_: Box<dyn HasReadout>) {}
    fn _object_safe_tunable(_: Box<dyn Tunable>) {}
    fn _object_safe_structural(_: Box<dyn Structural>) {}
}