// gbrt_rs/lib.rs
1#![allow(unused)]
2
3//! Gradient Boosted Regression Trees (GBRT) in Rust
4//!
5//! A high-performance, feature-rich implementation of gradient boosting
6//! for machine learning tasks including regression and binary classification.
7//!
8//! This library provides a complete gradient boosting framework with:
9//! - **Multiple objective functions** (MSE, MAE, Huber, Log Loss)
10//! - **Categorical feature encoding** with auto-detection
11//! - **Early stopping** with validation support
12//! - **Cross-validation** with proper categorical handling
13//! - **Feature importance** analysis
14//! - **Model serialization** with metadata
15//! - **Comprehensive error handling** with unified error types
16//!
17//! # Architecture
18//!
19//! The library is organized into modular components:
20//!
21//! - [`core`]: Core configurations and loss functions
22//! - [`boosting`]: Gradient boosting algorithms and training loop
23//! - [`tree`]: Decision tree construction with pluggable split criteria
24//! - [`data`]: Dataset management and feature preprocessing
25//! - [`objective`]: Objective functions for optimization
26//! - [`utils`]: Mathematical utilities, validation, and serialization
27//! - [`io`]: Data loading and model persistence
28//!
29//! # Examples
30//!
31//! Basic regression example:
32//!
//! ```no_run
34//! use gbrt_rs::{
35//! GradientBooster, BoosterFactory, Dataset, FeatureMatrix,
36//! DataLoader, ObjectiveType
37//! };
38//! use ndarray::Array2;
39//!
40//! // Load data from CSV
41//! let data_loader = DataLoader::new()?;
42//! let dataset = data_loader.load_csv("data/training.csv")?;
43//!
44//! // Train regression model
45//! let mut booster = BoosterFactory::create_regression_booster()?;
46//! booster.fit(&dataset, None)?;
47//!
48//! // Make predictions
49//! let predictions = booster.predict(dataset.features())?;
50//! # Ok::<(), Box<dyn std::error::Error>>(())
51//! ```
52//!
53//! Training with validation and early stopping:
54//!
//! ```no_run
56//! use gbrt_rs::{
57//! GBRTModel, Dataset, DataLoader, GBRTConfig,
58//! LossFunction, ObjectiveType
59//! };
60//!
61//! // Configure model
62//! let config = GBRTConfig {
63//! n_estimators: 100,
64//! learning_rate: 0.1,
65//! max_depth: 6,
66//! early_stopping_rounds: Some(10),
67//! ..Default::default()
68//! };
69//!
70//! let mut model = GBRTModel::with_config(config)?;
71//!
72//! // Split data for validation
73//! let data_loader = DataLoader::new()?;
74//! let dataset = data_loader.load_csv("data.csv")?;
75//! let (train_data, val_data) = data_loader.split_data(&dataset, 0.2, true, None)?;
76//!
77//! // Train with early stopping
78//! model.fit_with_validation(&train_data, &val_data)?;
79//!
80//! // Evaluate
81//! let predictions = model.predict(val_data.features())?;
82//! # Ok::<(), Box<dyn std::error::Error>>(())
83//! ```
84
85// Core modules
86pub mod core;
87/// Gradient boosting algorithms and ensemble training
88pub mod boosting;
89/// Decision tree construction and split criteria
90pub mod tree;
91/// Dataset handling, feature matrices, and preprocessing
92pub mod data;
93/// Loss functions and objective functions for training
94pub mod objective;
95/// Mathematical utilities and validation helpers
96pub mod utils;
97/// Data loading and model serialization
98pub mod io;
99
/// Data preprocessing utilities.
///
/// Thin facade that re-exports everything from [`data::preprocessing`] so
/// callers can write `gbrt_rs::preprocessing::...` directly.
pub mod preprocessing {
    pub use crate::data::preprocessing::*;
}
104
105use std::collections::HashMap;
106
107// Re-export main types for easy access
108pub use boosting::{
109 GradientBooster, BoosterFactory, create_booster,
110 BoostingError, BoostingResult, BoostingType,
111 TrainingState, IterationState
112};
113pub use data::{
114 Dataset, FeatureMatrix, DataError, DataResult,
115 Scaler, StandardScaler, MinMaxScaler
116};
117pub use objective::{
118 Objective, RegressionObjective, BinaryClassificationObjective,
119 ObjectiveFactory, create_objective, ObjectiveType,
120 ObjectiveError, ObjectiveResult, ObjectiveConfig,
121 MSEObjective, MAEObjective, HuberObjective, LogLossObjective
122};
123pub use tree::{
124 DecisionTree, TreeBuilder, TreeError, TreeResult,
125 Splitter, BestSplitter, SplitCriterion, MSECriterion
126};
127pub use utils::{
128 MathError, MathResult, Statistics, VectorMath,
129 ValidationError, ValidationResult, DataValidator,
130 SerializationError, SerializationResult, ModelSerializer
131};
132pub use io::{
133 DataLoader, DataLoaderError, DataLoaderResult, DataFormat,
134 ModelIO, ModelIOError, ModelIOResult, ModelFormat,
135 LoadOptions as DataLoadOptions, SaveOptions as ModelSaveOptions
136};
137
/// Unified error type for the entire crate.
///
/// This enum aggregates all possible errors that can occur during
/// model training, prediction, or data processing. Every sub-error type
/// carries a `#[from]` conversion, so `?` propagates module-level errors
/// into a `GBRTError` automatically.
///
/// # Variants
///
/// * `BoostingError` - Errors during gradient boosting
/// * `DataError` - Dataset loading or preprocessing errors
/// * `TreeError` - Decision tree construction errors
/// * `ObjectiveError` - Loss function computation errors
/// * `MathError` - Mathematical operation errors
/// * `ValidationError` - Data validation failures
/// * `SerializationError` - Model serialization/deserialization errors
/// * `DataLoaderError` - Data loading errors
/// * `ModelIOError` - Model I/O errors
/// * `IoError` - General I/O errors
/// * `CsvError` - CSV parsing errors
/// * `JsonError` - JSON serialization errors
/// * `ConfigError` - Invalid configuration
/// * `TrainingError` - Training process errors
/// * `PredictionError` - Prediction errors
/// * `NotTrained` - Model not yet trained
#[derive(thiserror::Error, Debug)]
pub enum GBRTError {
    #[error("Boosting error: {0}")]
    BoostingError(#[from] BoostingError),

    #[error("Data error: {0}")]
    DataError(#[from] DataError),

    #[error("Tree error: {0}")]
    TreeError(#[from] TreeError),

    #[error("Objective error: {0}")]
    ObjectiveError(#[from] ObjectiveError),

    #[error("Math error: {0}")]
    MathError(#[from] MathError),

    #[error("Validation error: {0}")]
    ValidationError(#[from] ValidationError),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] SerializationError),

    #[error("Data loader error: {0}")]
    DataLoaderError(#[from] DataLoaderError),

    #[error("Model IO error: {0}")]
    ModelIOError(#[from] ModelIOError),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("CSV error: {0}")]
    CsvError(#[from] csv::Error),

    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    // String-payload variants below have no source error; they are
    // constructed directly at the point of failure.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("Training error: {0}")]
    TrainingError(String),

    #[error("Prediction error: {0}")]
    PredictionError(String),

    #[error("Model not trained")]
    NotTrained,
}
211
/// Result type alias for the crate using [`GBRTError`] as the error type.
pub type Result<T> = std::result::Result<T, GBRTError>;
214
/// Main GBRT model struct that provides a high-level API for training and prediction.
///
/// This struct encapsulates the gradient boosting algorithm and provides
/// a simple interface for regression and classification tasks.
///
/// # Examples
///
/// ```
/// # use gbrt_rs::{GBRTModel, Dataset, FeatureMatrix};
/// # use ndarray::Array2;
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let features = FeatureMatrix::new(Array2::ones((10, 3)))?;
/// let targets = vec![1.0; 10];
/// let dataset = Dataset::new(features, targets)?;
///
/// let mut model = GBRTModel::new()?;
/// model.fit(&dataset)?;
/// # Ok(())
/// # }
/// ```
pub struct GBRTModel {
    /// The underlying gradient booster implementation.
    booster: GradientBooster,
    /// Optional feature names for better interpretability.
    feature_names: Option<Vec<String>>,
    /// Optional target name for better interpretability.
    target_name: Option<String>,
    /// Minimum samples required to split a node (default: 2).
    pub min_samples_split: usize,
    /// Minimum samples required in a leaf node (default: 1).
    pub min_samples_leaf: usize,
    /// L2 regularization parameter (default: 0.0).
    pub lambda_l2: f64,
}
249
250impl GBRTModel {
251 /// Creates a new GBRT model with default configuration for regression.
252 ///
253 /// # Returns
254 ///
255 /// A `Result` containing the configured model or an error if creation fails.
256 ///
257 /// # Errors
258 ///
259 /// Returns an error if the underlying booster cannot be created.
260 pub fn new() -> Result<Self> {
261 let booster = BoosterFactory::create_regression_booster()?;
262 Ok(Self {
263 booster,
264 feature_names: None,
265 target_name: None,
266 min_samples_split: 2, // ✅ Default: split nodes with ≥2 samples
267 min_samples_leaf: 1, // ✅ Default: allow leaves with 1 sample
268 lambda_l2: 0.0, // ✅ Default: no L2 regularization
269 })
270 }
271
272 /// Creates a new GBRT model for binary classification.
273 ///
274 /// # Returns
275 ///
276 /// A `Result` containing the configured model or an error if creation fails.
277 ///
278 /// # Errors
279 ///
280 /// Returns an error if the underlying booster cannot be created.
281 pub fn new_classifier() -> Result<Self> {
282 let booster = BoosterFactory::create_classification_booster()?;
283 Ok(Self {
284 booster,
285 feature_names: None,
286 target_name: None,
287 min_samples_split: 2,
288 min_samples_leaf: 1,
289 lambda_l2: 0.0,
290 })
291 }
292
293 /// Creates a new GBRT model with custom configuration.
294 ///
295 /// # Parameters
296 ///
297 /// * `config` - Custom gradient boosting configuration
298 ///
299 /// # Returns
300 ///
301 /// A `Result` containing the configured model or an error if creation fails.
302 ///
303 /// # Errors
304 ///
305 /// Returns an error if the booster cannot be created with the given config.
306 pub fn with_config(config: boosting::GBRTConfig) -> Result<Self> {
307 let booster = GradientBooster::new(config.clone())?;
308 Ok(Self {
309 booster,
310 feature_names: None,
311 target_name: None,
312 min_samples_split: config.tree_config.min_samples_split,
313 min_samples_leaf: config.tree_config.min_samples_leaf,
314 lambda_l2: config.tree_config.lambda,
315 })
316 }
317
318 /// Trains the model on a dataset.
319 ///
320 /// # Parameters
321 ///
322 /// * `dataset` - The training dataset containing features and targets
323 ///
324 /// # Returns
325 ///
326 /// A `Result` indicating success or failure.
327 ///
328 /// # Errors
329 ///
330 /// Returns an error if training fails or the dataset is invalid.
331 pub fn fit(&mut self, dataset: &Dataset) -> Result<()> {
332 self.booster.fit(dataset, None)?;
333 Ok(())
334 }
335
336 /// Trains the model with validation data for early stopping.
337 ///
338 /// # Parameters
339 ///
340 /// * `train_data` - The training dataset
341 /// * `val_data` - The validation dataset for monitoring overfitting
342 ///
343 /// # Returns
344 ///
345 /// A `Result` indicating success or failure.
346 ///
347 /// # Errors
348 ///
349 /// Returns an error if training fails or either dataset is invalid.
350 pub fn fit_with_validation(&mut self, train_data: &Dataset, val_data: &Dataset) -> Result<()> {
351 self.booster.fit(train_data, Some(val_data))?;
352 Ok(())
353 }
354
355 /// Makes predictions for a feature matrix.
356 ///
357 /// # Parameters
358 ///
359 /// * `features` - The feature matrix to predict on
360 ///
361 /// # Returns
362 ///
363 /// A `Result` containing a vector of predictions.
364 ///
365 /// # Errors
366 ///
367 /// Returns an error if the model is not trained or prediction fails.
368 pub fn predict(&self, features: &FeatureMatrix) -> Result<Vec<f64>> {
369 self.booster.predict(features).map_err(GBRTError::from)
370 }
371
372 /// Makes prediction for a single sample.
373 ///
374 /// # Parameters
375 ///
376 /// * `features` - A slice of feature values for one sample
377 ///
378 /// # Returns
379 ///
380 /// A `Result` containing the predicted value.
381 ///
382 /// # Errors
383 ///
384 /// Returns an error if the model is not trained or the feature length is invalid.
385 pub fn predict_single(&self, features: &[f64]) -> Result<f64> {
386 self.booster.predict_single(features).map_err(GBRTError::from)
387 }
388
389 /// Gets feature importance scores.
390 ///
391 /// # Returns
392 ///
393 /// A vector of importance scores, one per feature.
394 pub fn feature_importance(&self) -> Vec<f64> {
395 self.booster.feature_importance().to_vec()
396 }
397
398 /// Gets training history and state.
399 ///
400 /// # Returns
401 ///
402 /// An optional reference to the training state.
403 pub fn training_history(&self) -> Option<&TrainingState> {
404 self.booster.training_state()
405 }
406
407 /// Checks if the model has been trained.
408 ///
409 /// # Returns
410 ///
411 /// `true` if the model has been trained, `false` otherwise.
412 pub fn is_trained(&self) -> bool {
413 self.booster.is_trained()
414 }
415
416 /// Gets the number of trees in the ensemble.
417 ///
418 /// # Returns
419 ///
420 /// The number of trees, or 0 if untrained.
421 pub fn n_trees(&self) -> usize {
422 self.booster.n_trees()
423 }
424
425 /// Sets feature names for better interpretability.
426 ///
427 /// # Parameters
428 ///
429 /// * `names` - Vector of feature names in order
430 pub fn set_feature_names(&mut self, names: Vec<String>) {
431 self.feature_names = Some(names);
432 }
433
434 /// Gets the feature names if they have been set.
435 ///
436 /// # Returns
437 ///
438 /// An optional reference to the feature names vector.
439 pub fn feature_names(&self) -> Option<&Vec<String>> {
440 self.feature_names.as_ref()
441 }
442
443 /// Sets the target name for better interpretability.
444 ///
445 /// # Parameters
446 ///
447 /// * `name` - The target variable name
448 pub fn set_target_name(&mut self, name: String) {
449 self.target_name = Some(name);
450 }
451
452 /// Gets a reference to the underlying booster.
453 ///
454 /// # Returns
455 ///
456 /// Reference to the gradient booster.
457 pub fn booster(&self) -> &GradientBooster {
458 &self.booster
459 }
460}
461
462impl Default for GBRTModel {
463 /// Creates a default GBRT model for regression.
464 fn default() -> Self {
465 Self::new().unwrap()
466 }
467}
468
469/// Convenience functions for common tasks
470
/// Trains a GBRT model on data from a CSV file.
///
/// # Parameters
///
/// * `train_path` - Path to the CSV training data file
/// * `target_column` - Optional name of the target column (defaults to "target")
/// * `model_type` - Type of model to train (regression or binary classification)
///
/// # Returns
///
/// A `Result` containing the trained model.
///
/// # Errors
///
/// Returns an error if data loading or training fails.
pub fn train_from_csv(
    train_path: &str,
    target_column: Option<&str>,
    model_type: ObjectiveType,
) -> Result<GBRTModel> {
    let data_loader = DataLoader::new()?;
    // NOTE(review): `options` (carrying the target column) is constructed but
    // never passed to `load_csv` below, so `target_column` is silently
    // ignored — likely a bug. Confirm whether `DataLoader` exposes an
    // options-aware load method and forward `options` to it.
    let options = DataLoadOptions::default()
        .with_target_column(target_column.unwrap_or("target"));

    let dataset = data_loader.load_csv(train_path)?;

    // Pick the booster flavour matching the requested objective.
    let mut model = match model_type {
        ObjectiveType::Regression => GBRTModel::new()?,
        ObjectiveType::BinaryClassification => GBRTModel::new_classifier()?,
    };

    model.fit(&dataset)?;
    Ok(model)
}
505
/// Trains a GBRT model with train/validation split.
///
/// # Parameters
///
/// * `data_path` - Path to the CSV data file
/// * `target_column` - Optional name of the target column (defaults to "target")
/// * `test_size` - Fraction of data to use for validation (e.g., 0.2 for 20%)
/// * `model_type` - Type of model to train
///
/// # Returns
///
/// A `Result` containing the trained model with validation.
///
/// # Errors
///
/// Returns an error if data loading, splitting, or training fails.
pub fn train_with_validation_split(
    data_path: &str,
    target_column: Option<&str>,
    test_size: f64,
    model_type: ObjectiveType,
) -> Result<GBRTModel> {
    let data_loader = DataLoader::new()?;
    // NOTE(review): `options` is built but never passed to `load_csv`, so
    // `target_column` is silently ignored — same suspected bug as in
    // `train_from_csv`. Verify the intended `DataLoader` API.
    let options = DataLoadOptions::default()
        .with_target_column(target_column.unwrap_or("target"));

    let dataset = data_loader.load_csv(data_path)?;
    // Split with shuffling enabled and no fixed seed.
    let (train_data, val_data) = data_loader.split_data(&dataset, test_size, true, None)?;

    let mut model = match model_type {
        ObjectiveType::Regression => GBRTModel::new()?,
        ObjectiveType::BinaryClassification => GBRTModel::new_classifier()?,
    };

    model.fit_with_validation(&train_data, &val_data)?;
    Ok(model)
}
543
544/// Performs k-fold cross-validation for model evaluation.
545///
546/// # Parameters
547///
548/// * `data_path` - Path to the CSV data file
549/// * `target_column` - Optional name of the target column (defaults to "target")
550/// * `n_splits` - Number of cross-validation folds
551/// * `model_type` - Type of model to evaluate
552/// * `categorical_columns` - Optional comma-separated list of categorical column names
553/// * `categorical_threshold` - Threshold for auto-detecting categorical features
554///
555/// # Returns
556///
557/// A `Result` containing a vector of R² scores for each fold.
558///
559/// # Errors
560///
561/// Returns an error if data loading, model training, or evaluation fails.
562pub fn cross_validate(
563 data_path: &str,
564 target_column: Option<&str>,
565 n_splits: usize,
566 model_type: ObjectiveType,
567 categorical_columns: Option<&str>, // ✅ NEW: Explicit categorical columns
568 categorical_threshold: f64, // ✅ NEW: Auto-detect threshold
569) -> Result<Vec<f64>> {
570 let data_loader = DataLoader::new()?;
571
572 // ✅ Parse CSV headers to get column name to index mapping (same logic as train_model)
573 let mut temp_reader = csv::Reader::from_path(data_path)?;
574 let headers: Vec<String> = temp_reader
575 .headers()?
576 .iter()
577 .map(|s| s.to_string())
578 .collect();
579
580 let header_map: HashMap<_, _> = headers.iter()
581 .enumerate()
582 .map(|(i, name)| (name.as_str(), i))
583 .collect();
584
585 // ✅ Parse categorical column names to indices (same logic as train_model)
586 let categorical_indices = if let Some(col_names) = categorical_columns {
587 col_names.split(',')
588 .map(|s| s.trim())
589 .map(|name| {
590 header_map.get(name)
591 .copied()
592 .ok_or_else(|| {
593 GBRTError::ConfigError(format!(
594 "Unknown categorical column '{}'. Available columns: {:?}",
595 name, headers
596 ))
597 })
598 })
599 .collect::<Result<Vec<_>>>()?
600 } else {
601 vec![] // Auto-detect will be handled by load_csv_with_categorical
602 };
603
604 // ✅ Load data WITH proper categorical encoding (critical fix)
605 let dataset = data_loader.load_csv_with_categorical(
606 data_path,
607 target_column.unwrap_or("target"),
608 if categorical_columns.is_some() {
609 Some(&categorical_indices)
610 } else {
611 None
612 },
613 categorical_threshold,
614 )?;
615
616 // Create cross-validation splits
617 let splits = data_loader.cross_validation_splits(&dataset, n_splits, true, None)?;
618
619 let mut scores = Vec::new();
620
621 for (fold_idx, (train_data, test_data)) in splits.into_iter().enumerate() {
622 let mut model = match model_type {
623 ObjectiveType::Regression => GBRTModel::new()?,
624 ObjectiveType::BinaryClassification => GBRTModel::new_classifier()?,
625 };
626
627 model.fit(&train_data)?;
628 let predictions = model.predict(test_data.features())?;
629
630 // ✅ Calculate R² (coefficient of determination) instead of RMSE
631 let targets = test_data.targets().as_slice().unwrap();
632 let score = calculate_r2(&predictions, targets)?;
633
634 scores.push(score);
635 }
636
637 Ok(scores)
638}
639
640/// Computes the R² (coefficient of determination) metric.
641///
642/// # Parameters
643///
644/// * `predictions` - Slice of predicted values
645/// * `targets` - Slice of true target values
646///
647/// # Returns
648///
649/// A `Result` containing the R² score.
650///
651/// # Errors
652///
653/// Returns an error if the input slices have different lengths.
654///
655/// # Formula
656///
657/// R² = 1 - (Σ(y_true - y_pred)² / Σ(y_true - y_mean)²)
658fn calculate_r2(predictions: &[f64], targets: &[f64]) -> Result<f64> {
659 if predictions.len() != targets.len() {
660 return Err(GBRTError::ValidationError(
661 ValidationError::ValidationFailed("Predictions and targets length mismatch".to_string())
662 ));
663 }
664
665 let y_mean = targets.iter().sum::<f64>() / targets.len() as f64;
666 let total_ss: f64 = targets.iter()
667 .map(|&y| (y - y_mean).powi(2))
668 .sum();
669 let residual_ss: f64 = targets.iter()
670 .zip(predictions.iter())
671 .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
672 .sum();
673
674 let r2 = if total_ss == 0.0 {
675 1.0 // Perfect prediction for constant target
676 } else {
677 1.0 - (residual_ss / total_ss)
678 };
679
680 Ok(r2)
681}
682
/// Container for common model evaluation metrics.
///
/// Populated by [`ModelMetrics::regression_metrics`].
pub struct ModelMetrics {
    /// Root Mean Squared Error
    pub rmse: f64,
    /// Mean Absolute Error
    pub mae: f64,
    /// R² (coefficient of determination)
    pub r2: f64,
}
692
693impl ModelMetrics {
694 /// Calculates regression metrics from predictions and true values.
695 ///
696 /// # Parameters
697 ///
698 /// * `y_true` - Slice of true target values
699 /// * `y_pred` - Slice of predicted values
700 ///
701 /// # Returns
702 ///
703 /// A `Result` containing the computed metrics.
704 ///
705 /// # Errors
706 ///
707 /// Returns an error if the input slices have different lengths.
708 pub fn regression_metrics(y_true: &[f64], y_pred: &[f64]) -> Result<Self> {
709 if y_true.len() != y_pred.len() {
710 return Err(GBRTError::ValidationError(
711 ValidationError::ValidationFailed("True and predicted values must have same length".to_string())
712 ));
713 }
714
715 let rmse = utils::rmse(y_pred, y_true)?;
716 let mae = utils::compute_mean(
717 &y_true.iter()
718 .zip(y_pred.iter())
719 .map(|(&true_val, &pred_val)| (true_val - pred_val).abs())
720 .collect::<Vec<f64>>()
721 )?;
722
723 // Calculate R²
724 let y_mean = utils::compute_mean(y_true)?;
725 let total_ss: f64 = y_true.iter()
726 .map(|&y| (y - y_mean).powi(2))
727 .sum();
728 let residual_ss: f64 = y_true.iter()
729 .zip(y_pred.iter())
730 .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
731 .sum();
732
733 let r2 = if total_ss == 0.0 {
734 1.0
735 } else {
736 1.0 - (residual_ss / total_ss)
737 };
738
739 Ok(Self { rmse, mae, r2 })
740 }
741}
742
/// Returns the crate version string.
///
/// # Returns
///
/// The version baked in at compile time from `Cargo.toml`
/// (`CARGO_PKG_VERSION`).
pub fn version() -> &'static str {
    env!("CARGO_PKG_VERSION")
}
751
/// Feature importance analysis with named features.
///
/// Built from a trained booster via [`FeatureAnalysis::from_model`].
pub struct FeatureAnalysis {
    /// (feature_name, importance_score) pairs sorted by descending importance
    pub importance: Vec<(String, f64)>,
    /// Summary statistics of feature importances
    pub summary: FeatureSummary,
}
759
/// Summary statistics for feature importance analysis.
pub struct FeatureSummary {
    /// Total number of features
    pub n_features: usize,
    /// Up to the 10 most important features, in descending order
    pub top_features: Vec<(String, f64)>,
    /// Sum of all importance scores
    pub total_importance: f64,
}
769
770impl FeatureAnalysis {
771 /// Creates a feature analysis from a trained model.
772 ///
773 /// # Parameters
774 ///
775 /// * `model` - Reference to a trained gradient booster
776 /// * `feature_names` - Optional vector of feature names
777 ///
778 /// # Returns
779 ///
780 /// A new `FeatureAnalysis` instance.
781 pub fn from_model(model: &GradientBooster, feature_names: Option<Vec<String>>) -> Self {
782 let importance_scores = model.feature_importance();
783 let names = feature_names.unwrap_or_else(|| {
784 (0..importance_scores.len())
785 .map(|i| format!("feature_{}", i))
786 .collect()
787 });
788
789 let mut importance: Vec<(String, f64)> = names
790 .into_iter()
791 .zip(importance_scores.iter().copied()) // Use .copied() to get f64 values from &[f64]
792 .collect();
793
794 // Sort by importance (descending)
795 importance.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
796
797 let total_importance: f64 = importance.iter().map(|(_, imp)| imp).sum();
798 let top_features = importance.iter().take(10).cloned().collect();
799
800 let summary = FeatureSummary {
801 n_features: importance.len(),
802 top_features,
803 total_importance,
804 };
805
806 Self { importance, summary }
807 }
808}
809