1use alloc::vec::Vec;
7
8use crate::ensemble::config::SGBTConfig;
9use crate::ensemble::SGBT;
10use crate::error::{ConfigError, IrithyllError};
11use crate::loss::squared::SquaredLoss;
12use crate::loss::Loss;
13use crate::sample::SampleRef;
14
/// Multi-output gradient boosting: one independent [`SGBT`] model per
/// target, all trained on the same feature vector. Targets never share
/// trees or state (see the `targets_are_independent` test below).
#[derive(Debug)]
pub struct MultiTargetSGBT<L: Loss = SquaredLoss> {
    // One underlying model per target; invariant: models.len() == n_targets.
    models: Vec<SGBT<L>>,
    // Number of output targets, fixed at construction and always >= 1.
    n_targets: usize,
    // Count of samples fed to `train_one` since construction or last `reset`.
    samples_seen: u64,
}
48
// Plain field-by-field clone; there are no shared resources or extra
// invariants to uphold. NOTE(review): `#[derive(Clone)]` would likely
// infer the same `L: Clone` bound — manual impl kept as-is.
impl<L: Loss + Clone> Clone for MultiTargetSGBT<L> {
    fn clone(&self) -> Self {
        Self {
            models: self.models.clone(),
            n_targets: self.n_targets,
            samples_seen: self.samples_seen,
        }
    }
}
58
impl MultiTargetSGBT<SquaredLoss> {
    /// Creates a multi-target ensemble using [`SquaredLoss`] for every target.
    ///
    /// Delegates to [`MultiTargetSGBT::with_loss`].
    ///
    /// # Errors
    ///
    /// Returns an error when `n_targets` is 0 (see `with_loss`).
    pub fn new(config: SGBTConfig, n_targets: usize) -> crate::error::Result<Self> {
        Self::with_loss(config, SquaredLoss, n_targets)
    }
}
69
70impl<L: Loss + Clone> MultiTargetSGBT<L> {
71 pub fn with_loss(config: SGBTConfig, loss: L, n_targets: usize) -> crate::error::Result<Self> {
79 if n_targets < 1 {
80 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
81 "n_targets",
82 "must be >= 1",
83 n_targets,
84 )));
85 }
86
87 let models = (0..n_targets)
88 .map(|_| SGBT::with_loss(config.clone(), loss.clone()))
89 .collect();
90
91 Ok(Self {
92 models,
93 n_targets,
94 samples_seen: 0,
95 })
96 }
97
98 pub fn train_one(&mut self, features: &[f64], targets: &[f64]) {
104 assert_eq!(
105 targets.len(),
106 self.n_targets,
107 "expected {} targets, got {}",
108 self.n_targets,
109 targets.len()
110 );
111 self.samples_seen += 1;
112 for (model, &target) in self.models.iter_mut().zip(targets.iter()) {
113 let sample = SampleRef::new(features, target);
114 model.train_one(&sample);
115 }
116 }
117
118 pub fn train_batch(&mut self, feature_matrix: &[Vec<f64>], target_matrix: &[Vec<f64>]) {
120 for (features, targets) in feature_matrix.iter().zip(target_matrix.iter()) {
121 self.train_one(features, targets);
122 }
123 }
124
125 pub fn predict(&self, features: &[f64]) -> Vec<f64> {
127 self.models.iter().map(|m| m.predict(features)).collect()
128 }
129
130 pub fn n_targets(&self) -> usize {
132 self.n_targets
133 }
134
135 pub fn n_samples_seen(&self) -> u64 {
137 self.samples_seen
138 }
139
140 pub fn model(&self, idx: usize) -> &SGBT<L> {
146 &self.models[idx]
147 }
148
149 pub fn models(&self) -> &[SGBT<L>] {
151 &self.models
152 }
153
154 pub fn reset(&mut self) {
156 for model in &mut self.models {
157 model.reset();
158 }
159 self.samples_seen = 0;
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::string::ToString;
    use alloc::vec;

    // Small, fast config shared by most tests below.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    // Construction sanity: reported target count and zero samples seen.
    #[test]
    fn new_multi_target_creates_models() {
        let model = MultiTargetSGBT::new(test_config(), 3).unwrap();
        assert_eq!(model.n_targets(), 3);
        assert_eq!(model.n_samples_seen(), 0);
    }

    // n_targets == 0 must be rejected with an error that names the field.
    #[test]
    fn rejects_zero_targets() {
        let err = MultiTargetSGBT::new(test_config(), 0).unwrap_err();
        assert!(
            err.to_string().contains("n_targets"),
            "error should mention n_targets: {}",
            err
        );
    }

    // Degenerate case: a single target behaves like a plain regressor.
    #[test]
    fn single_target_works() {
        let mut model = MultiTargetSGBT::new(test_config(), 1).unwrap();
        model.train_one(&[1.0, 2.0], &[5.0]);
        let preds = model.predict(&[1.0, 2.0]);
        assert_eq!(preds.len(), 1);
    }

    // Smoke test: after training, predictions have the right arity and are finite.
    #[test]
    fn train_and_predict() {
        let mut model = MultiTargetSGBT::new(test_config(), 2).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&[x, x * 2.0], &[x * 3.0, -x]);
        }

        assert_eq!(model.n_samples_seen(), 100);
        let preds = model.predict(&[1.0, 2.0]);
        assert_eq!(preds.len(), 2);
        assert!(preds[0].is_finite());
        assert!(preds[1].is_finite());
    }

    // Core invariant: target 0 of the multi-target model must track a
    // standalone SGBT trained on the same (x, t0) stream exactly — targets
    // share no state.
    #[test]
    fn targets_are_independent() {
        let config = test_config();
        let mut multi = MultiTargetSGBT::new(config.clone(), 2).unwrap();
        let mut single = SGBT::new(config);

        // Simple deterministic LCG so both models see an identical stream.
        let mut rng: u64 = 42;
        for _ in 0..200 {
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x = (rng >> 33) as f64 / (u32::MAX as f64);
            let t0 = 2.0 * x;
            let t1 = -3.0 * x;

            multi.train_one(&[x], &[t0, t1]);
            single.train_one(&Sample::new(vec![x], t0));
        }

        let pred_multi = multi.predict(&[0.5]);
        let pred_single = single.predict(&[0.5]);
        assert!(
            (pred_multi[0] - pred_single).abs() < 1e-10,
            "target 0 should match independent model: multi={}, single={}",
            pred_multi[0],
            pred_single
        );
    }

    // reset() must zero the sample counter and bring predictions back to ~0.
    #[test]
    fn reset_clears_state() {
        let mut model = MultiTargetSGBT::new(test_config(), 2).unwrap();
        for i in 0..50 {
            let x = i as f64 * 0.1;
            model.train_one(&[x], &[x, x * 2.0]);
        }
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        let preds = model.predict(&[1.0]);
        for &p in &preds {
            assert!(p.abs() < 1e-12, "after reset, prediction should be ~0.0");
        }
    }

    // Per-model and all-models accessors expose the configured sub-models.
    #[test]
    fn model_accessor_works() {
        let model = MultiTargetSGBT::new(test_config(), 3).unwrap();
        assert_eq!(model.model(0).n_steps(), 5);
        assert_eq!(model.model(2).n_steps(), 5);
        assert_eq!(model.models().len(), 3);
    }

    // train_one with a mismatched target arity must panic with a clear message.
    #[test]
    #[should_panic(expected = "expected 2 targets")]
    fn wrong_target_count_panics() {
        let mut model = MultiTargetSGBT::new(test_config(), 2).unwrap();
        model.train_one(&[1.0], &[1.0, 2.0, 3.0]);
    }

    // Prequential (predict-then-train) evaluation on a noiseless linear
    // signal: late-window RMSE must beat the early window for both targets.
    #[test]
    fn convergence_on_linear_signal() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = MultiTargetSGBT::new(config, 2).unwrap();

        let mut rng: u64 = 99;
        let mut early_mse = [0.0f64; 2];
        let mut late_mse = [0.0f64; 2];
        let mut early_n = 0;
        let mut late_n = 0;

        for i in 0..500 {
            // Two features drawn from an LCG, scaled into [-5, 5).
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x0 = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0 - 5.0;
            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let x1 = (rng >> 33) as f64 / (u32::MAX as f64) * 10.0 - 5.0;

            let t0 = 2.0 * x0 + 3.0 * x1;
            let t1 = -x0 + 0.5 * x1;

            // Predict before training so errors measure generalization.
            let preds = model.predict(&[x0, x1]);

            if (50..150).contains(&i) {
                early_mse[0] += (preds[0] - t0).powi(2);
                early_mse[1] += (preds[1] - t1).powi(2);
                early_n += 1;
            }
            if i >= 400 {
                late_mse[0] += (preds[0] - t0).powi(2);
                late_mse[1] += (preds[1] - t1).powi(2);
                late_n += 1;
            }

            model.train_one(&[x0, x1], &[t0, t1]);
        }

        let early_rmse_0 = (early_mse[0] / early_n as f64).sqrt();
        let late_rmse_0 = (late_mse[0] / late_n as f64).sqrt();
        assert!(
            late_rmse_0 < early_rmse_0,
            "target 0 RMSE should improve: early={:.4}, late={:.4}",
            early_rmse_0,
            late_rmse_0
        );

        let early_rmse_1 = (early_mse[1] / early_n as f64).sqrt();
        let late_rmse_1 = (late_mse[1] / late_n as f64).sqrt();
        assert!(
            late_rmse_1 < early_rmse_1,
            "target 1 RMSE should improve: early={:.4}, late={:.4}",
            early_rmse_1,
            late_rmse_1
        );
    }
}