use crate::core::{create_loss, GBRTConfig, GradientLoss, LossFunction};
use crate::data::{Dataset, FeatureMatrix};
use crate::tree::{DecisionTree, Tree};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use rand::seq::SliceRandom;

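/// Errors that can occur while configuring, training, or querying a
/// [`GradientBooster`].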
#[derive(Error, Debug)]
pub enum BoostingError {
    #[error("Invalid input data: {0}")]
    InvalidInput(String),

    #[error("Training error: {0}")]
    TrainingError(String),

    #[error("Prediction error: {0}")]
    PredictionError(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("Tree building error: {0}")]
    TreeError(String),

    #[error("Loss function error: {0}")]
    LossError(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),
}

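/// Convenience alias for results whose error type is [`BoostingError`].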
pub type BoostingResult<T> = std::result::Result<T, BoostingError>;

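/// Per-iteration training metrics recorded while boosting.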
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IterationState {
    /// Zero-based boosting iteration index.
    pub iteration: usize,
    /// Training loss after this iteration's tree was added.
    pub train_loss: f64,
    /// Validation loss after this iteration, if validation data was provided.
    pub validation_loss: Option<f64>,
    /// Number of trees in the ensemble after this iteration.
    pub n_trees: usize,
    /// Number of leaves in the tree fit at this iteration.
    pub n_leaves: usize,
}

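/// Accumulated training history, including the best iteration found by
/// early stopping (when validation data is provided).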
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
    /// Metrics recorded at each boosting iteration.
    pub iterations: Vec<IterationState>,
    /// Iteration with the lowest validation loss, if validation was used.
    pub best_iteration: Option<usize>,
    /// Lowest validation loss observed, if validation was used.
    pub best_validation_loss: Option<f64>,
}

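/// A gradient-boosted regression tree (GBRT) ensemble.
///
/// Trees are fit sequentially to the gradients (and Hessians) of the
/// configured loss, and a prediction is the initial base score plus the
/// learning-rate-scaled sum of all tree outputs.
///
/// A minimal usage sketch; the `Dataset` and `GBRTConfig` constructors shown
/// here are assumptions about the surrounding crate, so the block is not
/// compiled as a doctest:
///
/// ```ignore
/// let config = GBRTConfig::default();
/// let mut booster = GradientBooster::new(config)?;
/// booster.fit(&train_data, Some(&validation_data))?;
/// let predictions = booster.predict(test_data.features())?;
/// ```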
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GradientBooster {
    /// Hyperparameters controlling boosting and tree construction.
    config: GBRTConfig,
    /// The fitted trees, in the order they were added.
    trees: Vec<Tree>,
    /// Base score added to every prediction before any tree contributions.
    initial_prediction: f64,
    /// Training history, populated by [`GradientBooster::fit`].
    training_state: Option<TrainingState>,
    /// Normalized gain-based feature importance (sums to 1 when non-zero).
    feature_importance: Vec<f64>,
    /// Whether [`GradientBooster::fit`] has completed successfully.
    is_trained: bool,
}

impl GradientBooster {
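    /// Creates an untrained booster after validating the configuration.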
    pub fn new(config: GBRTConfig) -> BoostingResult<Self> {
        config.validate().map_err(BoostingError::ConfigError)?;

        Ok(Self {
            config,
            trees: Vec::new(),
            initial_prediction: 0.0,
            training_state: None,
            feature_importance: Vec::new(),
            is_trained: false,
        })
    }

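    /// Trains the ensemble on `train_data`, optionally monitoring
    /// `validation_data` for early stopping.
    ///
    /// Each iteration computes per-sample gradients and Hessians of the
    /// configured loss at the current predictions, optionally subsamples
    /// rows, fits one tree to that gradient signal, and adds the
    /// learning-rate-scaled tree to the ensemble.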
    pub fn fit(
        &mut self,
        train_data: &Dataset,
        validation_data: Option<&Dataset>,
    ) -> BoostingResult<()> {
        self.validate_training_data(train_data, validation_data)?;

        self.trees.clear();
        self.initial_prediction = 0.0;
        self.is_trained = false;
        self.training_state = Some(TrainingState {
            iterations: Vec::new(),
            best_iteration: None,
            best_validation_loss: None,
        });

        self.initial_prediction =
            self.compute_initial_prediction(train_data.targets().as_slice().unwrap())?;
        let mut predictions = vec![self.initial_prediction; train_data.n_samples()];

        let loss_fn = create_loss(&self.config.loss);

        let mut val_predictions = self.prepare_validation_data(validation_data)?;

        let mut best_val_loss = f64::INFINITY;
        let mut no_improvement_count = 0;
        let mut best_iteration = 0;

        for iteration in 0..self.config.n_estimators {
            let (gradients, hessians) = loss_fn.gradient_hessian(
                train_data.targets().as_slice().unwrap(),
                &predictions,
            );

            let gradients_slice = gradients.as_slice().unwrap();
            let hessians_slice = hessians.as_slice().unwrap();

            // Row subsampling (stochastic gradient boosting): the tree is fit
            // on a subset of rows, but predictions are updated for all rows
            // below, since the new tree contributes to every sample's output.
            let (sampled_features, sampled_gradients, sampled_hessians, _sample_indices) =
                self.sample_data(train_data.features(), gradients_slice, hessians_slice)?;

            let tree = self.fit_tree(&sampled_features, &sampled_gradients, &sampled_hessians)?;

            self.update_predictions(train_data.features(), &tree, &mut predictions)?;

            let should_stop = self.update_training_state(
                iteration,
                train_data,
                validation_data,
                &predictions,
                &mut val_predictions,
                &tree,
                &mut best_val_loss,
                &mut no_improvement_count,
                &mut best_iteration,
                loss_fn.as_ref(),
            )?;

            self.trees.push(tree);

            if should_stop {
                break;
            }
        }

        // If training ran to completion without early stopping, still record
        // the best iteration observed on the validation set (when one exists).
        if let Some(state) = &mut self.training_state {
            if state.best_iteration.is_none() && best_val_loss.is_finite() {
                state.best_iteration = Some(best_iteration);
                state.best_validation_loss = Some(best_val_loss);
            }
        }

        self.compute_feature_importance(train_data.n_features());
        self.is_trained = true;

        Ok(())
    }

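    /// Predicts targets for every row in `features`.
    ///
    /// Each prediction is the initial base score plus the learning-rate-scaled
    /// sum of all tree outputs; for `LogLoss` the raw score is then mapped
    /// through the loss transform to the output scale.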
    pub fn predict(&self, features: &FeatureMatrix) -> BoostingResult<Vec<f64>> {
        if !self.is_trained {
            return Err(BoostingError::PredictionError("Model not trained".to_string()));
        }

        if features.n_features() != self.feature_importance.len() {
            return Err(BoostingError::PredictionError(
                format!("Expected {} features, got {}", self.feature_importance.len(), features.n_features())
            ));
        }

        let mut predictions = vec![self.initial_prediction; features.n_samples()];

        // Fetch each sample once and accumulate every tree's contribution,
        // rather than re-fetching the sample for every tree.
        for (i, pred) in predictions.iter_mut().enumerate() {
            let sample = features.get_sample(i)
                .map_err(|e| BoostingError::PredictionError(e.to_string()))?;
            let sample = sample.to_vec();
            for tree in &self.trees {
                *pred += self.config.learning_rate * tree.predict(&sample);
            }
        }

        Ok(self.apply_prediction_transform(&predictions))
    }

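    /// Predicts the target for a single sample given as a feature slice.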
    pub fn predict_single(&self, features: &[f64]) -> BoostingResult<f64> {
        if !self.is_trained {
            return Err(BoostingError::PredictionError("Model not trained".to_string()));
        }

        if features.len() != self.feature_importance.len() {
            return Err(BoostingError::PredictionError(
                format!("Expected {} features, got {}", self.feature_importance.len(), features.len())
            ));
        }

        let mut prediction = self.initial_prediction;

        for tree in &self.trees {
            prediction += self.config.learning_rate * tree.predict(features);
        }

        Ok(self.apply_single_prediction_transform(prediction))
    }

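    /// Returns the normalized gain-based feature importance computed during
    /// training (all zeros if importance computation was disabled).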
    pub fn feature_importance(&self) -> &[f64] {
        &self.feature_importance
    }

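    /// Returns the recorded training history, if the model has been fit.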
    pub fn training_state(&self) -> Option<&TrainingState> {
        self.training_state.as_ref()
    }

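    /// Number of trees currently in the ensemble.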
    pub fn n_trees(&self) -> usize {
        self.trees.len()
    }

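    /// The configuration this booster was constructed with.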
    pub fn config(&self) -> &GBRTConfig {
        &self.config
    }

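    /// Whether the booster has been successfully trained.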
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }

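    /// Checks that the training set is non-empty and that any validation set
    /// has the same number of features.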
    fn validate_training_data(
        &self,
        train_data: &Dataset,
        validation_data: Option<&Dataset>,
    ) -> BoostingResult<()> {
        if train_data.n_samples() == 0 {
            return Err(BoostingError::InvalidInput("Training dataset is empty".to_string()));
        }

        if let Some(val_data) = validation_data {
            if val_data.n_features() != train_data.n_features() {
                return Err(BoostingError::InvalidInput(
                    "Validation data has different number of features".to_string()
                ));
            }
        }

        Ok(())
    }

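    /// Computes the loss-minimizing constant base score: the mean for MSE
    /// (also used here for Huber), the median for MAE, and the log-odds
    /// ln(p / (1 - p)) of the target mean for LogLoss.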
    fn compute_initial_prediction(&self, targets: &[f64]) -> BoostingResult<f64> {
        match self.config.loss {
            LossFunction::MSE | LossFunction::Huber(_) => {
                Ok(targets.iter().sum::<f64>() / targets.len() as f64)
            }
            LossFunction::MAE => {
                // Upper median; total_cmp keeps the sort well-defined even if
                // a NaN slips through validation.
                let mut sorted = targets.to_vec();
                sorted.sort_by(|a, b| a.total_cmp(b));
                Ok(sorted[sorted.len() / 2])
            }
            LossFunction::LogLoss => {
                // Log-odds of the mean target, clamped away from 0 and 1 so
                // the logit stays finite.
                let mean = targets.iter().sum::<f64>() / targets.len() as f64;
                let mean_clamped = mean.clamp(1e-15, 1.0 - 1e-15);
                Ok((mean_clamped / (1.0 - mean_clamped)).ln())
            }
        }
    }

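    /// Allocates the running validation predictions, initialized to the base
    /// score (empty when no validation data is provided).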
    fn prepare_validation_data(
        &self,
        validation_data: Option<&Dataset>,
    ) -> BoostingResult<Vec<f64>> {
        // Validation features and targets are read directly from
        // `validation_data` during training; only the running prediction
        // buffer needs to be allocated here.
        match validation_data {
            Some(val_data) => Ok(vec![self.initial_prediction; val_data.n_samples()]),
            None => Ok(Vec::new()),
        }
    }

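    /// Draws a uniform random subsample of rows (stochastic gradient
    /// boosting) when `subsample < 1.0`; otherwise returns the full data.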
    fn sample_data(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
    ) -> BoostingResult<(FeatureMatrix, Vec<f64>, Vec<f64>, Option<Vec<usize>>)> {
        if self.config.subsample >= 1.0 {
            return Ok((
                features.clone(),
                gradients.to_vec(),
                hessians.to_vec(),
                None,
            ));
        }

        let n_samples = (features.n_samples() as f64 * self.config.subsample) as usize;
        let n_samples = n_samples.clamp(1, features.n_samples());

        // Shuffle-and-truncate gives a uniform sample without replacement.
        let mut rng = rand::thread_rng();
        let mut indices: Vec<usize> = (0..features.n_samples()).collect();
        indices.shuffle(&mut rng);
        indices.truncate(n_samples);

        let sampled_features = features.select_samples(&indices)
            .map_err(|e| BoostingError::TrainingError(e.to_string()))?;
        let sampled_gradients: Vec<f64> = indices.iter().map(|&i| gradients[i]).collect();
        let sampled_hessians: Vec<f64> = indices.iter().map(|&i| hessians[i]).collect();

        Ok((sampled_features, sampled_gradients, sampled_hessians, Some(indices)))
    }

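    /// Fits a single regression tree to the per-sample gradients and Hessians
    /// using the configured tree hyperparameters.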
    fn fit_tree(
        &self,
        features: &FeatureMatrix,
        gradients: &[f64],
        hessians: &[f64],
    ) -> BoostingResult<Tree> {
        let tree_config = &self.config.tree_config;

        let tree_builder = DecisionTree::new(
            tree_config.max_depth,
            tree_config.min_samples_split,
            tree_config.min_samples_leaf,
            tree_config.min_impurity_decrease,
            tree_config.max_features,
            tree_config.lambda,
        );

        tree_builder.fit(features, gradients, hessians)
            .map_err(|e| BoostingError::TreeError(format!("Tree fitting failed: {}", e)))
    }

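    /// Adds the new tree's learning-rate-scaled output to every sample's
    /// running prediction.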
    fn update_predictions(
        &self,
        features: &FeatureMatrix,
        tree: &Tree,
        predictions: &mut [f64],
    ) -> BoostingResult<()> {
        // Every sample's running prediction must include every tree, even
        // when the tree was fit on a subsample; otherwise the gradients
        // computed at the next iteration drift out of sync with the ensemble.
        for (i, pred) in predictions.iter_mut().enumerate() {
            let sample = features.get_sample(i)
                .map_err(|e| BoostingError::TrainingError(e.to_string()))?;
            *pred += self.config.learning_rate * tree.predict(&sample.to_vec());
        }

        Ok(())
    }

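    /// Records per-iteration metrics and applies patience-based early
    /// stopping against the validation loss. Returns `Ok(true)` when
    /// training should stop.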
    fn update_training_state(
        &mut self,
        iteration: usize,
        train_data: &Dataset,
        validation_data: Option<&Dataset>,
        predictions: &[f64],
        val_predictions: &mut Vec<f64>,
        tree: &Tree,
        best_val_loss: &mut f64,
        no_improvement_count: &mut usize,
        best_iteration: &mut usize,
        loss_fn: &dyn GradientLoss,
    ) -> BoostingResult<bool> {
        let mut iteration_state = IterationState {
            iteration,
            train_loss: 0.0,
            validation_loss: None,
            n_trees: self.trees.len() + 1,
            n_leaves: tree.n_leaves(),
        };

        iteration_state.train_loss = loss_fn.loss(
            train_data.targets().as_slice().unwrap(),
            predictions,
        );

        let should_stop = if let Some(val_data) = validation_data {
            // Fold the new tree into the running validation predictions.
            for (i, pred) in val_predictions.iter_mut().enumerate() {
                let sample = val_data.features().get_sample(i)
                    .map_err(|e| BoostingError::TrainingError(e.to_string()))?;
                *pred += self.config.learning_rate * tree.predict(&sample.to_vec());
            }

            let current_val_loss = loss_fn.loss(
                val_data.targets().as_slice().unwrap(),
                val_predictions,
            );

            if !current_val_loss.is_finite() {
                eprintln!("Warning: Validation loss became NaN/Inf at iteration {}", iteration);
                return Ok(true);
            }

            iteration_state.validation_loss = Some(current_val_loss);

            // Count an iteration as an improvement only if the validation
            // loss drops by at least the relative tolerance; this keeps
            // `best_val_loss` monotonically non-increasing.
            let improvement_threshold =
                *best_val_loss * (1.0 - self.config.early_stopping_tolerance);

            if current_val_loss < improvement_threshold {
                *best_val_loss = current_val_loss;
                *no_improvement_count = 0;
                *best_iteration = iteration;
            } else {
                *no_improvement_count += 1;
            }

            if let Some(patience) = self.config.early_stopping_rounds {
                if *no_improvement_count >= patience {
                    println!(
                        "Early stopping triggered at iteration {}. Best validation loss: {:.6} at iteration {}",
                        iteration, *best_val_loss, *best_iteration
                    );
                    if let Some(state) = &mut self.training_state {
                        state.best_iteration = Some(*best_iteration);
                        state.best_validation_loss = Some(*best_val_loss);
                    }
                    return Ok(true);
                }
            }

            false
        } else {
            false
        };

        if let Some(state) = &mut self.training_state {
            state.iterations.push(iteration_state);
        }

        Ok(should_stop)
    }

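    /// Aggregates per-tree split gains into a per-feature importance vector,
    /// normalized so the importances sum to 1 when any gain was recorded.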
    fn compute_feature_importance(&mut self, n_features: usize) {
        if !self.config.compute_feature_importance {
            self.feature_importance = vec![0.0; n_features];
            return;
        }

        let mut importance = vec![0.0; n_features];
        let mut total_gain = 0.0;
        for tree in &self.trees {
            for &(feature, gain) in tree.feature_importance().iter() {
                importance[feature] += gain;
                total_gain += gain;
            }
        }

        if total_gain > 0.0 {
            for imp in &mut importance {
                *imp /= total_gain;
            }
        }

        self.feature_importance = importance;
    }

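    /// Maps raw scores to the loss's output scale (e.g. the LogLoss
    /// transform); other losses pass through unchanged.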
    fn apply_prediction_transform(&self, predictions: &[f64]) -> Vec<f64> {
        if matches!(self.config.loss, LossFunction::LogLoss) {
            let loss_fn = create_loss(&self.config.loss);
            loss_fn.transform(predictions)
        } else {
            predictions.to_vec()
        }
    }

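    /// Single-value counterpart of [`GradientBooster::apply_prediction_transform`].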
    fn apply_single_prediction_transform(&self, prediction: f64) -> f64 {
        if matches!(self.config.loss, LossFunction::LogLoss) {
            let loss_fn = create_loss(&self.config.loss);
            loss_fn.transform(&[prediction])[0]
        } else {
            prediction
        }
    }
}

impl std::fmt::Display for GradientBooster {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "GradientBooster")?;
        writeln!(f, "  Trained: {}", self.is_trained)?;
        writeln!(f, "  Trees: {}", self.trees.len())?;
        writeln!(f, "  Loss: {}", self.config.loss)?;
        writeln!(f, "  Learning Rate: {:.4}", self.config.learning_rate)?;
        writeln!(f, "  Subsampling: {:.2}", self.config.subsample)?;

        if let Some(state) = &self.training_state {
            if let Some(iter_state) = state.iterations.last() {
                writeln!(f, "  Final Training Loss: {:.6}", iter_state.train_loss)?;
                if let Some(val_loss) = iter_state.validation_loss {
                    writeln!(f, "  Final Validation Loss: {:.6}", val_loss)?;
                }
            }
        }

        Ok(())
    }
}