// irithyll_core/learner.rs
1//! Unified streaming learner trait for polymorphic model composition.
2//!
3//! [`StreamingLearner`] is an **object-safe** trait that abstracts over any
4//! online/streaming machine learning model -- gradient boosted trees, linear
5//! models, Naive Bayes, Mondrian forests, or anything else that can ingest
6//! samples one at a time and produce predictions.
7//!
8//! # Motivation
9//!
10//! Stacking ensembles and meta-learners need to treat heterogeneous base
11//! models uniformly: train them on the same stream, collect their predictions
12//! as features for a combiner, and manage their lifecycle (reset, clone,
13//! serialization). `StreamingLearner` provides exactly this interface.
14//!
15//! # Object Safety
16//!
17//! The trait is deliberately object-safe: every method uses `&self` /
18//! `&mut self` with concrete return types (no generics on methods, no
19//! `Self`-by-value in non-`Sized` positions). This means you can store
20//! `Box<dyn StreamingLearner>` in a `Vec`, enabling runtime-polymorphic
21//! stacking without monomorphization.
22
23use alloc::vec::Vec;
24
/// Object-safe abstraction over streaming (incremental) learners.
///
/// Every method takes `&self`/`&mut self` and returns a concrete type, so the
/// trait can live behind `Box<dyn StreamingLearner>` -- the property that
/// makes runtime-polymorphic stacking ensembles possible without
/// monomorphization.
///
/// The `Send + Sync` supertraits let a learner be shared between threads
/// (e.g., for parallel prediction in async pipelines).
///
/// # Required Methods
///
/// | Method | Purpose |
/// |--------|---------|
/// | [`train_one`](Self::train_one) | Ingest one weighted observation |
/// | [`predict`](Self::predict) | Score a single feature vector |
/// | [`n_samples_seen`](Self::n_samples_seen) | Count of observations ingested so far |
/// | [`reset`](Self::reset) | Drop all learned state, returning to a fresh model |
///
/// # Default Methods
///
/// | Method | Purpose |
/// |--------|---------|
/// | [`train`](Self::train) | `train_one` with a unit weight |
/// | [`predict_batch`](Self::predict_batch) | `predict` applied row by row |
pub trait StreamingLearner: Send + Sync {
    /// Ingest a single observation with an explicit sample weight.
    ///
    /// This is the core training primitive: every streaming model must accept
    /// weighted incremental updates, even if the weight does nothing more
    /// than scale its gradient contribution.
    ///
    /// # Arguments
    ///
    /// * `features` -- the observation's feature vector
    /// * `target` -- regression target or class label (classification)
    /// * `weight` -- sample weight (`1.0` for uniform weighting)
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64);

    /// Score a single feature vector.
    ///
    /// The value returned is the raw model output with no loss transform:
    /// for SGBT this is the summed tree outputs, for linear models the dot
    /// product plus bias.
    fn predict(&self, features: &[f64]) -> f64;

    /// Number of observations ingested since construction or the last reset.
    fn n_samples_seen(&self) -> u64;

    /// Discard all learned state, returning to the initial (untrained) state.
    ///
    /// A reset model must behave exactly like a freshly constructed instance
    /// with the same configuration; in particular, `n_samples_seen()` must
    /// report 0 afterwards.
    fn reset(&mut self);

    /// Ingest a single observation with unit weight.
    ///
    /// Shorthand for [`train_one`](Self::train_one) with `weight = 1.0` --
    /// the most common training call in practice.
    fn train(&mut self, features: &[f64], target: f64) {
        self.train_one(features, target, 1.0);
    }

    /// Score every row of a feature matrix.
    ///
    /// Returns a `Vec<f64>` with one prediction per input row. The default
    /// implementation calls [`predict`](Self::predict) once per row;
    /// implementors may override this with a SIMD or otherwise
    /// batch-optimized prediction path.
    ///
    /// # Arguments
    ///
    /// * `feature_matrix` -- slice of rows, each row one feature vector
    fn predict_batch(&self, feature_matrix: &[&[f64]]) -> Vec<f64> {
        let mut predictions = Vec::with_capacity(feature_matrix.len());
        for row in feature_matrix {
            predictions.push(self.predict(row));
        }
        predictions
    }
}
101}