// irithyll_core/ensemble/multi_target.rs

//! Multi-target regression via parallel SGBT ensembles.
//!
//! For T targets, maintains T independent `SGBT<L>` models, each trained on
//! a single target dimension. Predictions return a `Vec<f64>` of length T.

use alloc::vec::Vec;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::error::{ConfigError, IrithyllError};
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::SampleRef;
/// Multi-target regression SGBT.
///
/// Wraps `T` independent `SGBT<L>` models, one per target dimension.
/// Each model is trained and predicts independently, sharing the same
/// configuration and loss function.
///
/// # Examples
///
/// ```text
/// use irithyll::ensemble::multi_target::MultiTargetSGBT;
/// use irithyll::SGBTConfig;
///
/// let config = SGBTConfig::builder()
///     .n_steps(10)
///     .learning_rate(0.1)
///     .grace_period(10)
///     .build()
///     .unwrap();
///
/// let mut model = MultiTargetSGBT::new(config, 3).unwrap();
/// model.train_one(&[1.0, 2.0], &[0.5, 1.0, 1.5]);
/// let preds = model.predict(&[1.0, 2.0]);
/// assert_eq!(preds.len(), 3);
/// ```
#[derive(Debug)]
pub struct MultiTargetSGBT<L: Loss = SquaredLoss> {
    /// One independent `SGBT` per target dimension; `models.len() == n_targets`.
    models: Vec<SGBT<L>>,
    /// Number of target dimensions; fixed at construction and always >= 1.
    n_targets: usize,
    /// Total samples seen via `train_one`/`train_batch`; cleared by `reset`.
    samples_seen: u64,
}
48
49impl<L: Loss + Clone> Clone for MultiTargetSGBT<L> {
50    fn clone(&self) -> Self {
51        Self {
52            models: self.models.clone(),
53            n_targets: self.n_targets,
54            samples_seen: self.samples_seen,
55        }
56    }
57}
58
59impl MultiTargetSGBT<SquaredLoss> {
60    /// Create a new multi-target SGBT with squared loss (default).
61    ///
62    /// # Errors
63    ///
64    /// Returns [`IrithyllError::InvalidConfig`] if `n_targets < 1`.
65    pub fn new(config: SGBTConfig, n_targets: usize) -> crate::error::Result<Self> {
66        Self::with_loss(config, SquaredLoss, n_targets)
67    }
68}
69
70impl<L: Loss + Clone> MultiTargetSGBT<L> {
71    /// Create a new multi-target SGBT with a custom loss function.
72    ///
73    /// The loss is cloned for each target model.
74    ///
75    /// # Errors
76    ///
77    /// Returns [`IrithyllError::InvalidConfig`] if `n_targets < 1`.
78    pub fn with_loss(config: SGBTConfig, loss: L, n_targets: usize) -> crate::error::Result<Self> {
79        if n_targets < 1 {
80            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
81                "n_targets",
82                "must be >= 1",
83                n_targets,
84            )));
85        }
86
87        let models = (0..n_targets)
88            .map(|_| SGBT::with_loss(config.clone(), loss.clone()))
89            .collect();
90
91        Ok(Self {
92            models,
93            n_targets,
94            samples_seen: 0,
95        })
96    }
97
98    /// Train on a single sample with multiple target values.
99    ///
100    /// # Panics
101    ///
102    /// Panics if `targets.len() != n_targets`.
103    pub fn train_one(&mut self, features: &[f64], targets: &[f64]) {
104        assert_eq!(
105            targets.len(),
106            self.n_targets,
107            "expected {} targets, got {}",
108            self.n_targets,
109            targets.len()
110        );
111        self.samples_seen += 1;
112        for (model, &target) in self.models.iter_mut().zip(targets.iter()) {
113            let sample = SampleRef::new(features, target);
114            model.train_one(&sample);
115        }
116    }
117
118    /// Train on a batch of multi-target samples.
119    pub fn train_batch(&mut self, feature_matrix: &[Vec<f64>], target_matrix: &[Vec<f64>]) {
120        for (features, targets) in feature_matrix.iter().zip(target_matrix.iter()) {
121            self.train_one(features, targets);
122        }
123    }
124
125    /// Predict all target values for a feature vector.
126    pub fn predict(&self, features: &[f64]) -> Vec<f64> {
127        self.models.iter().map(|m| m.predict(features)).collect()
128    }
129
130    /// Number of target dimensions.
131    pub fn n_targets(&self) -> usize {
132        self.n_targets
133    }
134
135    /// Total samples trained.
136    pub fn n_samples_seen(&self) -> u64 {
137        self.samples_seen
138    }
139
140    /// Access the model for a specific target dimension.
141    ///
142    /// # Panics
143    ///
144    /// Panics if `idx >= n_targets`.
145    pub fn model(&self, idx: usize) -> &SGBT<L> {
146        &self.models[idx]
147    }
148
149    /// Access all target models.
150    pub fn models(&self) -> &[SGBT<L>] {
151        &self.models
152    }
153
154    /// Reset all target models.
155    pub fn reset(&mut self) {
156        for model in &mut self.models {
157            model.reset();
158        }
159        self.samples_seen = 0;
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::string::ToString;
    use alloc::vec;

    /// Shared small configuration (few steps, shallow trees, coarse bins)
    /// so individual tests stay fast.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    /// Construction records the target count and starts with zero samples.
    #[test]
    fn new_multi_target_creates_models() {
        let model = MultiTargetSGBT::new(test_config(), 3).unwrap();
        assert_eq!(model.n_targets(), 3);
        assert_eq!(model.n_samples_seen(), 0);
    }

    /// `n_targets == 0` is rejected with an error naming the bad parameter.
    #[test]
    fn rejects_zero_targets() {
        let err = MultiTargetSGBT::new(test_config(), 0).unwrap_err();
        assert!(
            err.to_string().contains("n_targets"),
            "error should mention n_targets: {}",
            err
        );
    }

    /// The degenerate single-target case trains and predicts a 1-vector.
    #[test]
    fn single_target_works() {
        let mut model = MultiTargetSGBT::new(test_config(), 1).unwrap();
        model.train_one(&[1.0, 2.0], &[5.0]);
        let preds = model.predict(&[1.0, 2.0]);
        assert_eq!(preds.len(), 1);
    }

    /// Training updates the sample counter and predictions stay finite.
    #[test]
    fn train_and_predict() {
        let mut model = MultiTargetSGBT::new(test_config(), 2).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&[x, x * 2.0], &[x * 3.0, -x]);
        }

        assert_eq!(model.n_samples_seen(), 100);
        let preds = model.predict(&[1.0, 2.0]);
        assert_eq!(preds.len(), 2);
        assert!(preds[0].is_finite());
        assert!(preds[1].is_finite());
    }

    /// Targets must not leak into each other: a multi-target model and a
    /// standalone single-target model trained on the same stream must give
    /// (near-)identical predictions for target 0.
    #[test]
    fn targets_are_independent() {
        let config = test_config();
        let mut multi = MultiTargetSGBT::new(config.clone(), 2).unwrap();
        let mut single = SGBT::new(config);

        // Deterministic 64-bit LCG so the stream is reproducible.
        let mut rng: u64 = 42;
        for _ in 0..200 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            // NOTE(review): `>> 33` keeps 31 bits, so x spans roughly
            // [0, 0.5) rather than [0, 1) — harmless here, but confirm intent.
            let x = (rng >> 33) as f64 / (u32::MAX as f64);
            let t0 = 2.0 * x;
            let t1 = -3.0 * x;

            multi.train_one(&[x], &[t0, t1]);
            single.train_one(&Sample::new(vec![x], t0));
        }

        // The first target model should match the single model.
        let pred_multi = multi.predict(&[0.5]);
        let pred_single = single.predict(&[0.5]);
        assert!(
            (pred_multi[0] - pred_single).abs() < 1e-10,
            "target 0 should match independent model: multi={}, single={}",
            pred_multi[0],
            pred_single
        );
    }

    /// After `reset`, the counter clears and predictions return to ~0.
    #[test]
    fn reset_clears_state() {
        let mut model = MultiTargetSGBT::new(test_config(), 2).unwrap();
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&[x], &[x, x * 2.0]);
        }
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        let preds = model.predict(&[1.0]);
        for &p in &preds {
            assert!(p.abs() < 1e-12, "after reset, prediction should be ~0.0");
        }
    }

    /// Per-index and whole-slice accessors expose the underlying models.
    #[test]
    fn model_accessor_works() {
        let model = MultiTargetSGBT::new(test_config(), 3).unwrap();
        assert_eq!(model.model(0).n_steps(), 5);
        assert_eq!(model.model(2).n_steps(), 5);
        assert_eq!(model.models().len(), 3);
    }

    /// Passing the wrong number of targets is a programming error and panics
    /// with the message asserted in `train_one`.
    #[test]
    #[should_panic(expected = "expected 2 targets")]
    fn wrong_target_count_panics() {
        let mut model = MultiTargetSGBT::new(test_config(), 2).unwrap();
        model.train_one(&[1.0], &[1.0, 2.0, 3.0]);
    }

    /// Prequential (predict-then-train) check: RMSE over a late window must
    /// beat RMSE over an early window for both targets of a linear signal.
    #[test]
    fn convergence_on_linear_signal() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = MultiTargetSGBT::new(config, 2).unwrap();

        // Deterministic LCG stream; with the `>> 33` quirk above, the scaled
        // features land in roughly [-5, 0).
        let mut rng: u64 = 99;
        let mut early_mse = [0.0f64; 2];
        let mut late_mse = [0.0f64; 2];
        let mut early_n = 0;
        let mut late_n = 0;

        for i in 0..500 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x0 = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0 - 5.0;
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x1 = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0 - 5.0;

            // Two independent linear targets over the same features.
            let t0 = 2.0 * x0 + 3.0 * x1;
            let t1 = -x0 + 0.5 * x1;

            // Predict BEFORE training on the sample (prequential protocol).
            let preds = model.predict(&[x0, x1]);

            // Accumulate squared error in an early window (after warm-up)
            // and a late window (after substantial training).
            if (50..150).contains(&i) {
                early_mse[0] += (preds[0] - t0).powi(2);
                early_mse[1] += (preds[1] - t1).powi(2);
                early_n += 1;
            }
            if i >= 400 {
                late_mse[0] += (preds[0] - t0).powi(2);
                late_mse[1] += (preds[1] - t1).powi(2);
                late_n += 1;
            }

            model.train_one(&[x0, x1], &[t0, t1]);
        }

        let early_rmse_0 = (early_mse[0] / early_n as f64).sqrt();
        let late_rmse_0 = (late_mse[0] / late_n as f64).sqrt();
        assert!(
            late_rmse_0 < early_rmse_0,
            "target 0 RMSE should improve: early={:.4}, late={:.4}",
            early_rmse_0,
            late_rmse_0
        );

        let early_rmse_1 = (early_mse[1] / early_n as f64).sqrt();
        let late_rmse_1 = (late_mse[1] / late_n as f64).sqrt();
        assert!(
            late_rmse_1 < early_rmse_1,
            "target 1 RMSE should improve: early={:.4}, late={:.4}",
            early_rmse_1,
            late_rmse_1
        );
    }
}