gbrt_rs/
lib.rs

#![allow(unused)]

//! Gradient Boosted Regression Trees (GBRT) in Rust
//!
//! A high-performance, feature-rich implementation of gradient boosting
//! for machine learning tasks including regression and binary classification.
//!
//! This library provides a complete gradient boosting framework with:
//! - **Multiple objective functions** (MSE, MAE, Huber, Log Loss)
//! - **Categorical feature encoding** with auto-detection
//! - **Early stopping** with validation support
//! - **Cross-validation** with proper categorical handling
//! - **Feature importance** analysis
//! - **Model serialization** with metadata
//! - **Comprehensive error handling** with unified error types
//!
//! # Architecture
//!
//! The library is organized into modular components:
//!
//! - [`core`]: Core configurations and loss functions
//! - [`boosting`]: Gradient boosting algorithms and training loop
//! - [`tree`]: Decision tree construction with pluggable split criteria
//! - [`data`]: Dataset management and feature preprocessing
//! - [`objective`]: Objective functions for optimization
//! - [`utils`]: Mathematical utilities, validation, and serialization
//! - [`io`]: Data loading and model persistence
//!
//! # Examples
//!
//! Basic regression example (marked `no_run` because it reads a file from disk):
//!
//! ```no_run
//! use gbrt_rs::{BoosterFactory, DataLoader};
//!
//! // Load data from CSV
//! let data_loader = DataLoader::new()?;
//! let dataset = data_loader.load_csv("data/training.csv")?;
//!
//! // Train regression model
//! let mut booster = BoosterFactory::create_regression_booster()?;
//! booster.fit(&dataset, None)?;
//!
//! // Make predictions
//! let predictions = booster.predict(dataset.features())?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
//!
//! Training with validation and early stopping:
//!
//! ```no_run
//! use gbrt_rs::{GBRTModel, DataLoader, GBRTConfig};
//!
//! // Configure model
//! let config = GBRTConfig {
//!     n_estimators: 100,
//!     learning_rate: 0.1,
//!     max_depth: 6,
//!     early_stopping_rounds: Some(10),
//!     ..Default::default()
//! };
//!
//! let mut model = GBRTModel::with_config(config)?;
//!
//! // Split data for validation
//! let data_loader = DataLoader::new()?;
//! let dataset = data_loader.load_csv("data.csv")?;
//! let (train_data, val_data) = data_loader.split_data(&dataset, 0.2, true, None)?;
//!
//! // Train with early stopping
//! model.fit_with_validation(&train_data, &val_data)?;
//!
//! // Evaluate
//! let predictions = model.predict(val_data.features())?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
// Core modules
/// Core configurations and loss functions
pub mod core;
/// Gradient boosting algorithms and ensemble training
pub mod boosting;
/// Decision tree construction and split criteria
pub mod tree;
/// Dataset handling, feature matrices, and preprocessing
pub mod data;
/// Loss functions and objective functions for training
pub mod objective;
/// Mathematical utilities and validation helpers
pub mod utils;
/// Data loading and model serialization
pub mod io;

/// Data preprocessing utilities
pub mod preprocessing {
    pub use crate::data::preprocessing::*;
}

use std::collections::HashMap;

// Re-export main types for easy access
pub use boosting::{
    GradientBooster, BoosterFactory, create_booster, GBRTConfig,
    BoostingError, BoostingResult, BoostingType,
    TrainingState, IterationState
};
pub use data::{
    Dataset, FeatureMatrix, DataError, DataResult,
    Scaler, StandardScaler, MinMaxScaler
};
pub use objective::{
    Objective, RegressionObjective, BinaryClassificationObjective,
    ObjectiveFactory, create_objective, ObjectiveType,
    ObjectiveError, ObjectiveResult, ObjectiveConfig,
    MSEObjective, MAEObjective, HuberObjective, LogLossObjective
};
pub use tree::{
    DecisionTree, TreeBuilder, TreeError, TreeResult,
    Splitter, BestSplitter, SplitCriterion, MSECriterion
};
pub use utils::{
    MathError, MathResult, Statistics, VectorMath,
    ValidationError, ValidationResult, DataValidator,
    SerializationError, SerializationResult, ModelSerializer
};
pub use io::{
    DataLoader, DataLoaderError, DataLoaderResult, DataFormat,
    ModelIO, ModelIOError, ModelIOResult, ModelFormat,
    LoadOptions as DataLoadOptions, SaveOptions as ModelSaveOptions
};

/// Unified error type for the entire crate.
///
/// This enum aggregates all possible errors that can occur during
/// model training, prediction, or data processing.
///
/// # Variants
///
/// * `BoostingError` - Errors during gradient boosting
/// * `DataError` - Dataset loading or preprocessing errors
/// * `TreeError` - Decision tree construction errors
/// * `ObjectiveError` - Loss function computation errors
/// * `MathError` - Mathematical operation errors
/// * `ValidationError` - Data validation failures
/// * `SerializationError` - Model serialization/deserialization errors
/// * `DataLoaderError` - Data loading errors
/// * `ModelIOError` - Model I/O errors
/// * `IoError` - General I/O errors
/// * `CsvError` - CSV parsing errors
/// * `JsonError` - JSON serialization errors
/// * `ConfigError` - Invalid configuration
/// * `TrainingError` - Training process errors
/// * `PredictionError` - Prediction errors
/// * `NotTrained` - Model not yet trained
#[derive(thiserror::Error, Debug)]
pub enum GBRTError {
    #[error("Boosting error: {0}")]
    BoostingError(#[from] BoostingError),

    #[error("Data error: {0}")]
    DataError(#[from] DataError),

    #[error("Tree error: {0}")]
    TreeError(#[from] TreeError),

    #[error("Objective error: {0}")]
    ObjectiveError(#[from] ObjectiveError),

    #[error("Math error: {0}")]
    MathError(#[from] MathError),

    #[error("Validation error: {0}")]
    ValidationError(#[from] ValidationError),

    #[error("Serialization error: {0}")]
    SerializationError(#[from] SerializationError),

    #[error("Data loader error: {0}")]
    DataLoaderError(#[from] DataLoaderError),

    #[error("Model IO error: {0}")]
    ModelIOError(#[from] ModelIOError),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("CSV error: {0}")]
    CsvError(#[from] csv::Error),

    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("Training error: {0}")]
    TrainingError(String),

    #[error("Prediction error: {0}")]
    PredictionError(String),

    #[error("Model not trained")]
    NotTrained,
}

/// Result type alias for the crate using [`GBRTError`].
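///
/// # Examples
///
/// Sub-errors convert automatically through the `#[from]` impls on
/// [`GBRTError`], so `?` works across module boundaries. A minimal sketch
/// (the CSV path is illustrative):
///
/// ```no_run
/// use gbrt_rs::{DataLoader, Result};
///
/// fn load_training_data(path: &str) -> Result<()> {
///     let loader = DataLoader::new()?;        // DataLoaderError -> GBRTError
///     let _dataset = loader.load_csv(path)?;  // likewise
///     Ok(())
/// }
/// ```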
pub type Result<T> = std::result::Result<T, GBRTError>;

/// Main GBRT model struct that provides a high-level API for training and prediction.
///
/// This struct encapsulates the gradient boosting algorithm and provides
/// a simple interface for regression and classification tasks.
///
/// # Examples
///
/// ```
/// # use gbrt_rs::{GBRTModel, Dataset, FeatureMatrix};
/// # use ndarray::Array2;
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let features = FeatureMatrix::new(Array2::ones((10, 3)))?;
/// let targets = vec![1.0; 10];
/// let dataset = Dataset::new(features, targets)?;
///
/// let mut model = GBRTModel::new()?;
/// model.fit(&dataset)?;
/// # Ok(())
/// # }
/// ```
pub struct GBRTModel {
    /// The underlying gradient booster implementation
    booster: GradientBooster,
    /// Optional feature names for better interpretability
    feature_names: Option<Vec<String>>,
    /// Optional target name for better interpretability
    target_name: Option<String>,
    /// Minimum samples required to split a node (default: 2)
    pub min_samples_split: usize,
    /// Minimum samples required in a leaf node (default: 1)
    pub min_samples_leaf: usize,
    /// L2 regularization parameter (default: 0.0)
    pub lambda_l2: f64,
}

impl GBRTModel {
    /// Creates a new GBRT model with default configuration for regression.
    ///
    /// # Returns
    ///
    /// A `Result` containing the configured model or an error if creation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying booster cannot be created.
    pub fn new() -> Result<Self> {
        let booster = BoosterFactory::create_regression_booster()?;
        Ok(Self {
            booster,
            feature_names: None,
            target_name: None,
            min_samples_split: 2, // default: split nodes with >= 2 samples
            min_samples_leaf: 1,  // default: allow leaves with a single sample
            lambda_l2: 0.0,       // default: no L2 regularization
        })
    }

    /// Creates a new GBRT model for binary classification.
    ///
    /// # Returns
    ///
    /// A `Result` containing the configured model or an error if creation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying booster cannot be created.
    pub fn new_classifier() -> Result<Self> {
        let booster = BoosterFactory::create_classification_booster()?;
        Ok(Self {
            booster,
            feature_names: None,
            target_name: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            lambda_l2: 0.0,
        })
    }

    /// Creates a new GBRT model with custom configuration.
    ///
    /// # Parameters
    ///
    /// * `config` - Custom gradient boosting configuration
    ///
    /// # Returns
    ///
    /// A `Result` containing the configured model or an error if creation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the booster cannot be created with the given config.
    pub fn with_config(config: boosting::GBRTConfig) -> Result<Self> {
        let booster = GradientBooster::new(config.clone())?;
        Ok(Self {
            booster,
            feature_names: None,
            target_name: None,
            min_samples_split: config.tree_config.min_samples_split,
            min_samples_leaf: config.tree_config.min_samples_leaf,
            lambda_l2: config.tree_config.lambda,
        })
    }

    /// Trains the model on a dataset.
    ///
    /// # Parameters
    ///
    /// * `dataset` - The training dataset containing features and targets
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure.
    ///
    /// # Errors
    ///
    /// Returns an error if training fails or the dataset is invalid.
    pub fn fit(&mut self, dataset: &Dataset) -> Result<()> {
        self.booster.fit(dataset, None)?;
        Ok(())
    }

    /// Trains the model with validation data for early stopping.
    ///
    /// # Parameters
    ///
    /// * `train_data` - The training dataset
    /// * `val_data` - The validation dataset for monitoring overfitting
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure.
    ///
    /// # Errors
    ///
    /// Returns an error if training fails or either dataset is invalid.
    pub fn fit_with_validation(&mut self, train_data: &Dataset, val_data: &Dataset) -> Result<()> {
        self.booster.fit(train_data, Some(val_data))?;
        Ok(())
    }

    /// Makes predictions for a feature matrix.
    ///
    /// # Parameters
    ///
    /// * `features` - The feature matrix to predict on
    ///
    /// # Returns
    ///
    /// A `Result` containing a vector of predictions.
    ///
    /// # Errors
    ///
    /// Returns an error if the model is not trained or prediction fails.
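    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming an already trained model and a loaded dataset:
    ///
    /// ```no_run
    /// # use gbrt_rs::{GBRTModel, Dataset};
    /// # fn run(model: &GBRTModel, data: &Dataset) -> gbrt_rs::Result<()> {
    /// let predictions = model.predict(data.features())?;
    /// println!("{} predictions", predictions.len());
    /// # Ok(())
    /// # }
    /// ```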
    pub fn predict(&self, features: &FeatureMatrix) -> Result<Vec<f64>> {
        self.booster.predict(features).map_err(GBRTError::from)
    }

    /// Makes a prediction for a single sample.
    ///
    /// # Parameters
    ///
    /// * `features` - A slice of feature values for one sample
    ///
    /// # Returns
    ///
    /// A `Result` containing the predicted value.
    ///
    /// # Errors
    ///
    /// Returns an error if the model is not trained or the feature length is invalid.
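    ///
    /// # Examples
    ///
    /// A sketch, assuming a trained model with three input features:
    ///
    /// ```no_run
    /// # use gbrt_rs::GBRTModel;
    /// # fn run(model: &GBRTModel) -> gbrt_rs::Result<()> {
    /// let y = model.predict_single(&[0.5, 1.2, 3.4])?;
    /// println!("prediction: {y}");
    /// # Ok(())
    /// # }
    /// ```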
    pub fn predict_single(&self, features: &[f64]) -> Result<f64> {
        self.booster.predict_single(features).map_err(GBRTError::from)
    }

    /// Gets feature importance scores.
    ///
    /// # Returns
    ///
    /// A vector of importance scores, one per feature.
    pub fn feature_importance(&self) -> Vec<f64> {
        self.booster.feature_importance().to_vec()
    }

    /// Gets training history and state.
    ///
    /// # Returns
    ///
    /// An optional reference to the training state.
    pub fn training_history(&self) -> Option<&TrainingState> {
        self.booster.training_state()
    }

    /// Checks if the model has been trained.
    ///
    /// # Returns
    ///
    /// `true` if the model has been trained, `false` otherwise.
    pub fn is_trained(&self) -> bool {
        self.booster.is_trained()
    }

    /// Gets the number of trees in the ensemble.
    ///
    /// # Returns
    ///
    /// The number of trees, or 0 if untrained.
    pub fn n_trees(&self) -> usize {
        self.booster.n_trees()
    }

    /// Sets feature names for better interpretability.
    ///
    /// # Parameters
    ///
    /// * `names` - Vector of feature names in order
    pub fn set_feature_names(&mut self, names: Vec<String>) {
        self.feature_names = Some(names);
    }

    /// Gets the feature names if they have been set.
    ///
    /// # Returns
    ///
    /// An optional reference to the feature names vector.
    pub fn feature_names(&self) -> Option<&Vec<String>> {
        self.feature_names.as_ref()
    }

    /// Sets the target name for better interpretability.
    ///
    /// # Parameters
    ///
    /// * `name` - The target variable name
    pub fn set_target_name(&mut self, name: String) {
        self.target_name = Some(name);
    }

    /// Gets a reference to the underlying booster.
    ///
    /// # Returns
    ///
    /// Reference to the gradient booster.
    pub fn booster(&self) -> &GradientBooster {
        &self.booster
    }
}

impl Default for GBRTModel {
    /// Creates a default GBRT model for regression.
    ///
    /// # Panics
    ///
    /// Panics if the default regression booster cannot be constructed.
    fn default() -> Self {
        Self::new().expect("default regression booster should be constructible")
    }
}

// Convenience functions for common tasks

/// Trains a GBRT model on data from a CSV file.
///
/// # Parameters
///
/// * `train_path` - Path to the CSV training data file
/// * `target_column` - Optional name of the target column (defaults to "target")
/// * `model_type` - Type of model to train (regression or binary classification)
///
/// # Returns
///
/// A `Result` containing the trained model.
///
/// # Errors
///
/// Returns an error if data loading or training fails.
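///
/// # Examples
///
/// A sketch with an illustrative file path and target column:
///
/// ```no_run
/// use gbrt_rs::{train_from_csv, ObjectiveType};
///
/// let model = train_from_csv("data/train.csv", Some("price"), ObjectiveType::Regression)?;
/// assert!(model.is_trained());
/// # Ok::<(), gbrt_rs::GBRTError>(())
/// ```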
pub fn train_from_csv(
    train_path: &str,
    target_column: Option<&str>,
    model_type: ObjectiveType,
) -> Result<GBRTModel> {
    let data_loader = DataLoader::new()?;

    // Load via the categorical-aware loader so that `target_column` is
    // actually honored (plain `load_csv` takes no target column, and the
    // previously constructed `DataLoadOptions` was never used). Passing
    // `None` requests no explicit categorical columns; the 0.0 threshold
    // is assumed to disable auto-detection.
    let dataset = data_loader.load_csv_with_categorical(
        train_path,
        target_column.unwrap_or("target"),
        None,
        0.0,
    )?;

    let mut model = match model_type {
        ObjectiveType::Regression => GBRTModel::new()?,
        ObjectiveType::BinaryClassification => GBRTModel::new_classifier()?,
    };

    model.fit(&dataset)?;
    Ok(model)
}

/// Trains a GBRT model with train/validation split.
///
/// # Parameters
///
/// * `data_path` - Path to the CSV data file
/// * `target_column` - Optional name of the target column (defaults to "target")
/// * `test_size` - Fraction of data to use for validation (e.g., 0.2 for 20%)
/// * `model_type` - Type of model to train
///
/// # Returns
///
/// A `Result` containing the trained model with validation.
///
/// # Errors
///
/// Returns an error if data loading, splitting, or training fails.
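///
/// # Examples
///
/// A sketch holding out 20% of the rows for validation (illustrative path):
///
/// ```no_run
/// use gbrt_rs::{train_with_validation_split, ObjectiveType};
///
/// let model = train_with_validation_split("data.csv", None, 0.2, ObjectiveType::Regression)?;
/// assert!(model.is_trained());
/// # Ok::<(), gbrt_rs::GBRTError>(())
/// ```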
pub fn train_with_validation_split(
    data_path: &str,
    target_column: Option<&str>,
    test_size: f64,
    model_type: ObjectiveType,
) -> Result<GBRTModel> {
    let data_loader = DataLoader::new()?;

    // As in `train_from_csv`, use the categorical-aware loader so the
    // target column is honored (assumed 0.0 threshold = no auto-detection).
    let dataset = data_loader.load_csv_with_categorical(
        data_path,
        target_column.unwrap_or("target"),
        None,
        0.0,
    )?;
    let (train_data, val_data) = data_loader.split_data(&dataset, test_size, true, None)?;

    let mut model = match model_type {
        ObjectiveType::Regression => GBRTModel::new()?,
        ObjectiveType::BinaryClassification => GBRTModel::new_classifier()?,
    };

    model.fit_with_validation(&train_data, &val_data)?;
    Ok(model)
}

/// Performs k-fold cross-validation for model evaluation.
///
/// # Parameters
///
/// * `data_path` - Path to the CSV data file
/// * `target_column` - Optional name of the target column (defaults to "target")
/// * `n_splits` - Number of cross-validation folds
/// * `model_type` - Type of model to evaluate
/// * `categorical_columns` - Optional comma-separated list of categorical column names
/// * `categorical_threshold` - Threshold for auto-detecting categorical features
///
/// # Returns
///
/// A `Result` containing a vector of R² scores, one per fold.
///
/// # Errors
///
/// Returns an error if data loading, model training, or evaluation fails.
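///
/// # Examples
///
/// A sketch running 5-fold CV with two explicitly categorical columns
/// (the path and column names are illustrative):
///
/// ```no_run
/// use gbrt_rs::{cross_validate, ObjectiveType};
///
/// let scores = cross_validate(
///     "data.csv",
///     Some("price"),
///     5,
///     ObjectiveType::Regression,
///     Some("city,zip_code"),
///     0.1,
/// )?;
/// let mean_r2 = scores.iter().sum::<f64>() / scores.len() as f64;
/// println!("mean R²: {mean_r2:.4}");
/// # Ok::<(), gbrt_rs::GBRTError>(())
/// ```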
pub fn cross_validate(
    data_path: &str,
    target_column: Option<&str>,
    n_splits: usize,
    model_type: ObjectiveType,
    categorical_columns: Option<&str>, // explicit categorical column names
    categorical_threshold: f64,        // threshold for auto-detection
) -> Result<Vec<f64>> {
    let data_loader = DataLoader::new()?;

    // Parse the CSV header to build a column-name-to-index map
    // (same logic as model training).
    let mut temp_reader = csv::Reader::from_path(data_path)?;
    let headers: Vec<String> = temp_reader
        .headers()?
        .iter()
        .map(|s| s.to_string())
        .collect();

    let header_map: HashMap<_, _> = headers.iter()
        .enumerate()
        .map(|(i, name)| (name.as_str(), i))
        .collect();

    // Resolve explicit categorical column names to indices.
    let categorical_indices = if let Some(col_names) = categorical_columns {
        col_names.split(',')
            .map(|s| s.trim())
            .map(|name| {
                header_map.get(name)
                    .copied()
                    .ok_or_else(|| {
                        GBRTError::ConfigError(format!(
                            "Unknown categorical column '{}'. Available columns: {:?}",
                            name, headers
                        ))
                    })
            })
            .collect::<Result<Vec<_>>>()?
    } else {
        vec![] // auto-detection is handled by load_csv_with_categorical
    };

    // Load the data with proper categorical encoding.
    let dataset = data_loader.load_csv_with_categorical(
        data_path,
        target_column.unwrap_or("target"),
        if categorical_columns.is_some() {
            Some(&categorical_indices)
        } else {
            None
        },
        categorical_threshold,
    )?;

    // Create cross-validation splits.
    let splits = data_loader.cross_validation_splits(&dataset, n_splits, true, None)?;

    let mut scores = Vec::new();

    for (train_data, test_data) in splits {
        let mut model = match model_type {
            ObjectiveType::Regression => GBRTModel::new()?,
            ObjectiveType::BinaryClassification => GBRTModel::new_classifier()?,
        };

        model.fit(&train_data)?;
        let predictions = model.predict(test_data.features())?;

        // Score each fold with R² (coefficient of determination).
        // `to_vec` copies the targets out regardless of memory layout,
        // avoiding the panic risked by `as_slice().unwrap()`.
        let targets = test_data.targets().to_vec();
        scores.push(calculate_r2(&predictions, &targets)?);
    }

    Ok(scores)
}

/// Computes the R² (coefficient of determination) metric.
///
/// # Parameters
///
/// * `predictions` - Slice of predicted values
/// * `targets` - Slice of true target values
///
/// # Returns
///
/// A `Result` containing the R² score.
///
/// # Errors
///
/// Returns an error if the input slices are empty or have different lengths.
///
/// # Formula
///
/// R² = 1 - (Σ(y_true - y_pred)² / Σ(y_true - y_mean)²)
///
/// Predictions that exactly match the targets give R² = 1.0; predicting
/// the target mean for every sample gives R² = 0.0.
fn calculate_r2(predictions: &[f64], targets: &[f64]) -> Result<f64> {
    if predictions.len() != targets.len() {
        return Err(GBRTError::ValidationError(
            ValidationError::ValidationFailed("Predictions and targets length mismatch".to_string())
        ));
    }
    if targets.is_empty() {
        return Err(GBRTError::ValidationError(
            ValidationError::ValidationFailed("Cannot compute R² on empty slices".to_string())
        ));
    }

    let y_mean = targets.iter().sum::<f64>() / targets.len() as f64;
    let total_ss: f64 = targets.iter()
        .map(|&y| (y - y_mean).powi(2))
        .sum();
    let residual_ss: f64 = targets.iter()
        .zip(predictions.iter())
        .map(|(&true_val, &pred_val)| (true_val - pred_val).powi(2))
        .sum();

    let r2 = if total_ss == 0.0 {
        1.0 // perfect prediction for a constant target
    } else {
        1.0 - (residual_ss / total_ss)
    };

    Ok(r2)
}

/// Container for common model evaluation metrics.
pub struct ModelMetrics {
    /// Root Mean Squared Error
    pub rmse: f64,
    /// Mean Absolute Error
    pub mae: f64,
    /// R² (coefficient of determination)
    pub r2: f64,
}

impl ModelMetrics {
    /// Calculates regression metrics from predictions and true values.
    ///
    /// # Parameters
    ///
    /// * `y_true` - Slice of true target values
    /// * `y_pred` - Slice of predicted values
    ///
    /// # Returns
    ///
    /// A `Result` containing the computed metrics.
    ///
    /// # Errors
    ///
    /// Returns an error if the input slices have different lengths.
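    ///
    /// # Examples
    ///
    /// A minimal example with hand-picked values:
    ///
    /// ```
    /// use gbrt_rs::ModelMetrics;
    ///
    /// let y_true = vec![1.0, 2.0, 3.0];
    /// let y_pred = vec![1.1, 1.9, 3.2];
    /// let metrics = ModelMetrics::regression_metrics(&y_true, &y_pred)?;
    /// assert!(metrics.r2 > 0.9);
    /// # Ok::<(), gbrt_rs::GBRTError>(())
    /// ```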
    pub fn regression_metrics(y_true: &[f64], y_pred: &[f64]) -> Result<Self> {
        if y_true.len() != y_pred.len() {
            return Err(GBRTError::ValidationError(
                ValidationError::ValidationFailed("True and predicted values must have same length".to_string())
            ));
        }

        let rmse = utils::rmse(y_pred, y_true)?;
        let mae = utils::compute_mean(
            &y_true.iter()
                .zip(y_pred.iter())
                .map(|(&true_val, &pred_val)| (true_val - pred_val).abs())
                .collect::<Vec<f64>>()
        )?;
        // Reuse the shared helper instead of duplicating the R² computation.
        let r2 = calculate_r2(y_pred, y_true)?;

        Ok(Self { rmse, mae, r2 })
    }
}

/// Returns the crate version string.
///
/// # Returns
///
/// The version from Cargo.toml
pub fn version() -> &'static str {
    env!("CARGO_PKG_VERSION")
}

/// Feature importance analysis with named features.
pub struct FeatureAnalysis {
    /// Sorted vector of (feature_name, importance_score) pairs
    pub importance: Vec<(String, f64)>,
    /// Summary statistics of feature importances
    pub summary: FeatureSummary,
}

/// Summary statistics for feature importance analysis.
pub struct FeatureSummary {
    /// Total number of features
    pub n_features: usize,
    /// The (at most) 10 most important features
    pub top_features: Vec<(String, f64)>,
    /// Sum of all importance scores
    pub total_importance: f64,
}

impl FeatureAnalysis {
    /// Creates a feature analysis from a trained model.
    ///
    /// # Parameters
    ///
    /// * `model` - Reference to a trained gradient booster
    /// * `feature_names` - Optional vector of feature names
    ///
    /// # Returns
    ///
    /// A new `FeatureAnalysis` instance.
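    ///
    /// # Examples
    ///
    /// A sketch, assuming an already trained booster:
    ///
    /// ```no_run
    /// # use gbrt_rs::{GradientBooster, FeatureAnalysis};
    /// # fn run(booster: &GradientBooster) {
    /// let analysis = FeatureAnalysis::from_model(booster, None);
    /// for (name, importance) in &analysis.summary.top_features {
    ///     println!("{name}: {importance:.4}");
    /// }
    /// # }
    /// ```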
    pub fn from_model(model: &GradientBooster, feature_names: Option<Vec<String>>) -> Self {
        let importance_scores = model.feature_importance();
        let names = feature_names.unwrap_or_else(|| {
            (0..importance_scores.len())
                .map(|i| format!("feature_{}", i))
                .collect()
        });

        let mut importance: Vec<(String, f64)> = names
            .into_iter()
            .zip(importance_scores.iter().copied()) // `.copied()` turns &f64 into f64
            .collect();

        // Sort by importance (descending); `total_cmp` is NaN-safe,
        // unlike unwrapping `partial_cmp`.
        importance.sort_by(|a, b| b.1.total_cmp(&a.1));

        let total_importance: f64 = importance.iter().map(|(_, imp)| imp).sum();
        let top_features = importance.iter().take(10).cloned().collect();

        let summary = FeatureSummary {
            n_features: importance.len(),
            top_features,
            total_importance,
        };

        Self { importance, summary }
    }
}