gbrt_rs/io/
model_io.rs

1#![allow(dead_code)]
2
3//! Model Persistence and Management for Gradient Boosted Models
4//! 
5//! This module provides comprehensive functionality for saving, loading, and managing
6//! trained [`GradientBooster`] models. It supports multiple serialization formats,
7//! model versioning, integrity verification, and a model registry for organizing
8//! multiple models.
9//! 
10//! # Features
11//! 
12//! - **Multiple Formats**: Save models as binary (compact) or JSON (human-readable)
13//! - **Metadata Tracking**: Store model configuration, training parameters, and custom metadata
14//! - **Integrity Verification**: Validate model integrity during loading
15//! - **Version Management**: Track model versions and detect compatibility issues
16//! - **Model Registry**: Centralized management of multiple models with name-based lookup
17//! - **Feature Importance**: Optionally export feature importance scores
18//! - **Compression**: Optional gzip compression for reduced disk usage
19
20use crate::boosting::{GradientBooster, GBRTConfig};  // Changed from core to boosting
21use crate::utils::{ModelSerializer, ModelMetadata, SerializationFormat, SerializationError};
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::fs::{File, create_dir_all, read_dir};
25use std::io::{BufReader, BufWriter};
26use std::path::{Path, PathBuf};
27use thiserror::Error;
28
29
30/// Errors that can occur during model I/O operations.
31///
32/// This error type covers all failure modes: file system errors, serialization
33/// failures, version mismatches, integrity violations, and missing models.
34#[derive(Error, Debug)]
35pub enum ModelIOError {
36    /// File system I/O error.
37    #[error("IO error: {0}")]
38    IoError(#[from] std::io::Error),
39
40    /// Model serialization/deserialization error.
41    #[error("Serialization error: {0}")]
42    SerializationError(#[from] crate::utils::SerializationError),
43
44    /// Parquet file reading error.
45    #[error("Parquet error: {0}")]
46    ParquetError(#[from] parquet::errors::ParquetError),
47
48    /// General model error (e.g., untrained model).
49    #[error("Model error: {0}")]
50    ModelError(String),
51
52    /// Model file is invalid or corrupted.
53    #[error("Invalid model file: {0}")]
54    InvalidModelFile(String),
55
56    /// Version incompatibility between saved model and current library.
57    #[error("Version mismatch: expected {expected}, got {actual}")]
58    VersionMismatch { expected: String, actual: String },
59
60    /// Unsupported file format.
61    #[error("Unsupported format: {format}")]
62    UnsupportedFormat { format: String },
63
64    /// Model file not found at specified path.
65    #[error("Model not found: {0}")]
66    ModelNotFound(String),
67}
68
69pub type ModelIOResult<T> = std::result::Result<T, ModelIOError>;
70
71/// Supported model serialization formats.
72/// 
73/// - **Binary**: Compact, fast serialization using bincode (default)
74/// - **JSON**: Human-readable text format for inspection and debugging
75#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
76pub enum ModelFormat {
77    /// Binary format for compact storage and fast loading.
78    Binary,
79    /// JSON format for human-readability and interoperability.
80    Json,
81}
82
83impl std::str::FromStr for ModelFormat {
84    type Err = ModelIOError;
85
86    /// Parses ModelFormat from string.
87    /// 
88    /// # Supported Values
89    /// - "binary", "bin" → `ModelFormat::Binary`
90    /// - "json" → `ModelFormat::Json`
91    /// 
92    /// # Errors
93    /// - `ModelIOError::UnsupportedFormat` if string doesn't match any known format
94    fn from_str(s: &str) -> Result<Self, Self::Err> {
95        match s.to_lowercase().as_str() {
96            "binary" | "bin" => Ok(ModelFormat::Binary),
97            "json" => Ok(ModelFormat::Json),
98            _ => Err(ModelIOError::UnsupportedFormat {
99                format: s.to_string(),
100            }),
101        }
102    }
103}
104
105impl std::fmt::Display for ModelFormat {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        match self {
108            ModelFormat::Binary => write!(f, "binary"),
109            ModelFormat::Json => write!(f, "json"),
110        }
111    }
112}
113
114/// Configuration options for saving models.
115/// 
116/// Controls serialization format, compression, and what additional data to include.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct SaveOptions {
119    /// Serialization format to use (Binary or JSON).
120    pub format: ModelFormat,
121    /// Whether to compress the model file.
122    pub compression: bool,
123    /// Whether to include training data in the save (not yet implemented).
124    pub include_training_data: bool,
125    /// Whether to include feature importance scores.
126    pub include_feature_importance: bool,
127    /// Additional metadata to store with the model.
128    pub metadata: HashMap<String, String>,
129}
130
131impl Default for SaveOptions {
132    fn default() -> Self {
133        Self {
134            format: ModelFormat::Json,
135            compression: false,
136            include_training_data: false,
137            include_feature_importance: true,
138            metadata: HashMap::new(),
139        }
140    }
141}
142
143impl SaveOptions {
144    /// Creates default SaveOptions.
145    pub fn new() -> Self {
146        Self::default()
147    }
148
149    /// Sets the serialization format.
150    pub fn with_format(mut self, format: ModelFormat) -> Self {
151        self.format = format;
152        self
153    }
154
155    /// Enables or disables compression.
156    pub fn with_compression(mut self, compression: bool) -> Self {
157        self.compression = compression;
158        self
159    }
160
161    /// Sets whether to include training data (future feature).
162    pub fn with_training_data(mut self, include: bool) -> Self {
163        self.include_training_data = include;
164        self
165    }
166
167    /// Sets whether to include feature importance scores.
168    pub fn with_feature_importance(mut self, include: bool) -> Self {
169        self.include_feature_importance = include;
170        self
171    }
172
173    /// Replaces all metadata with provided map.
174    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
175        self.metadata = metadata;
176        self
177    }
178
179    /// Adds a single metadata key-value pair.
180    pub fn add_metadata(mut self, key: &str, value: &str) -> Self {
181        self.metadata.insert(key.to_string(), value.to_string());
182        self
183    }
184}
185
186/// Configuration options for loading models.
187/// 
188/// Controls validation and integrity checks during model loading.
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct LoadOptions {
191    /// Whether to verify model integrity (default: true).
192    pub verify_integrity: bool,
193    /// Whether to enforce strict version checking (default: false).
194    pub strict_version_check: bool,
195}
196
197impl Default for LoadOptions {
198    fn default() -> Self {
199        Self {
200            verify_integrity: true,
201            strict_version_check: false,
202        }
203    }
204}
205
206impl LoadOptions {
207    /// Creates default LoadOptions.
208    pub fn new() -> Self {
209        Self::default()
210    }
211
212    /// Enables or disables integrity verification.
213    pub fn with_verify_integrity(mut self, verify: bool) -> Self {
214        self.verify_integrity = verify;
215        self
216    }
217
218    /// Enables or disables strict version checking.
219    pub fn with_strict_version_check(mut self, strict: bool) -> Self {
220        self.strict_version_check = strict;
221        self
222    }
223}
224
225/// Main model I/O handler for saving and loading gradient boosted models.
226/// 
227/// `ModelIO` provides a high-level interface for model persistence operations.
228/// It can be configured with custom save/load options and handles all aspects
229/// of model serialization, metadata management, and validation.
230pub struct ModelIO {
231    save_options: SaveOptions,
232    load_options: LoadOptions,
233}
234
235impl ModelIO {
236    /// Creates a new ModelIO with default options.
237    /// 
238    /// # Returns
239    /// `Ok(ModelIO)` with default save and load options.
240    pub fn new() -> ModelIOResult<Self> {
241        Ok(Self {
242            save_options: SaveOptions::default(),
243            load_options: LoadOptions::default(),
244        })
245    }
246
247    /// Creates a ModelIO with custom save options.
248    pub fn with_save_options(save_options: SaveOptions) -> Self {
249        Self {
250            save_options,
251            load_options: LoadOptions::default(),
252        }
253    }
254
255    /// Creates a ModelIO with custom load options.
256    pub fn with_load_options(load_options: LoadOptions) -> Self {
257        Self {
258            save_options: SaveOptions::default(),
259            load_options,
260        }
261    }
262
263    /// Saves a gradient boosting model to file.
264    /// 
265    /// The model is saved as a directory containing:
266    /// - `{model_name}.json`: The main model file
267    /// - `{model_name}_metadata.json`: Model metadata and hyperparameters
268    /// - `{model_name}_feature_importance.json`: Feature importance scores (if enabled)
269    /// - Additional files for training data (future feature)
270    /// 
271    /// # Parameters
272    /// - `model`: Trained gradient boosting model
273    /// - `path`: Directory path where model will be saved
274    /// - `model_name`: Base name for model files
275    /// 
276    /// # Errors
277    /// - `ModelIOError::ModelError` if model is untrained
278    /// - `ModelIOError::IoError` if directory creation or file writing fails
279    /// - `ModelIOError::SerializationError` if serialization fails
280    pub fn save_model<P: AsRef<Path>>(
281        &self,
282        model: &GradientBooster,  
283        path: P,
284        model_name: &str,
285    ) -> ModelIOResult<()> {
286        save_model(model, path, model_name, &self.save_options)
287    }
288
289    /// Loads a gradient boosting model from file or directory.
290    /// 
291    /// # Path Resolution
292    /// 
293    /// - If `path` is a file: Loads that file directly
294    /// - If `path` is a directory: Looks for `gbrt_model.json` inside the directory
295    /// 
296    /// # Parameters
297    /// - `path`: Path to model file or directory
298    /// 
299    /// # Errors
300    /// - `ModelIOError::ModelNotFound` if model file doesn't exist
301    /// - `ModelIOError::SerializationError` if deserialization fails
302    /// - `ModelIOError::InvalidModelFile` if integrity check fails
303    /// - `ModelIOError::VersionMismatch` if version check fails and strict mode enabled
304    pub fn load_model<P: AsRef<Path>>(&self, path: P) -> ModelIOResult<GradientBooster> {
305        let path = path.as_ref();
306
307        // ✅ If path is a directory, look for gbrt_model.json inside
308        let model_file = if path.is_dir() {
309            path.join("gbrt_model.json")
310        } else {
311            path.to_path_buf()
312        };
313
314        if !model_file.exists() {
315            return Err(ModelIOError::ModelNotFound(
316                format!("Model file not found: {}", model_file.display())
317            ));
318        }
319
320        let serializer = ModelSerializer::new();
321        let (model, metadata) = serializer.load_model(&model_file)?;
322
323        // ✅ Verify integrity if requested
324        if self.load_options.verify_integrity {
325            verify_model_integrity(&model, &metadata)?;
326        }
327
328        // ✅ Version check
329        if self.load_options.strict_version_check {
330            metadata.validate_version()
331                .map_err(|e| {
332                    match e {
333                        SerializationError::VersionMismatch { expected, actual } => {
334                            ModelIOError::VersionMismatch {
335                                expected: expected.clone(),
336                                actual: actual.clone(),
337                            }
338                        }
339                        other => ModelIOError::SerializationError(other),
340                    }
341                })?;
342        }
343
344        Ok(model)
345    }
346
347    /// Exports a model to a different format.
348    /// 
349    /// This allows converting between Binary and JSON formats without retraining.
350    /// 
351    /// # Parameters
352    /// - `model`: Model to export
353    /// - `path`: Output file path
354    /// - `format`: Target format (Binary or JSON)
355    /// 
356    /// # Errors
357    /// - `ModelIOError::SerializationError` if conversion fails
358    /// - `ModelIOError::IoError` if file writing fails
359    pub fn export_model<P: AsRef<Path>>(
360        &self,
361        model: &GradientBooster,  // Changed from GBRT to GradientBooster
362        path: P,
363        format: ModelFormat,
364    ) -> ModelIOResult<()> {
365        export_model(model, path, format)
366    }
367
368    /// Imports a model from a specific format.
369    /// 
370    /// # Parameters
371    /// - `path`: Path to model file
372    /// - `format`: Format of the model file
373    /// 
374    /// # Returns
375    /// Loaded model
376    /// 
377    /// # Errors
378    /// - `ModelIOError::SerializationError` if deserialization fails
379    /// - `ModelIOError::IoError` if file reading fails
380    pub fn import_model<P: AsRef<Path>>(&self, path: P, format: ModelFormat) -> ModelIOResult<GradientBooster> {  // Changed return type
381        import_model(path, format)
382    }
383
384    /// Lists all saved model files in a directory.
385    /// 
386    /// Searches for files with extensions: `.bin`, `.json`, `.gz`
387    /// 
388    /// # Parameters
389    /// - `directory`: Directory to search
390    /// 
391    /// # Returns
392    /// Sorted vector of model file paths
393    /// 
394    /// # Errors
395    /// - `ModelIOError::IoError` if directory cannot be read
396    pub fn list_saved_models<P: AsRef<Path>>(&self, directory: P) -> ModelIOResult<Vec<PathBuf>> {
397        list_saved_models(directory)
398    }
399
400    /// Retrieves metadata for a saved model.
401    /// 
402    /// # Parameters
403    /// - `path`: Path to model file or directory
404    /// 
405    /// # Returns
406    /// Model metadata including configuration, version, and parameters
407    /// 
408    /// # Errors
409    /// - `ModelIOError::InvalidModelFile` if metadata file is missing or corrupted
410    /// - `ModelIOError::SerializationError` if metadata cannot be parsed
411    pub fn get_model_info<P: AsRef<Path>>(&self, path: P) -> ModelIOResult<ModelMetadata> {
412        get_model_info(path)
413    }
414
415    /// Validates a model file for integrity and compatibility.
416    /// 
417    /// Performs basic checks:
418    /// - Model is trained
419    /// - Number of trees matches metadata
420    /// - Feature dimensions are consistent
421    /// - Model type is correct
422    /// 
423    /// # Parameters
424    /// - `path`: Path to model file
425    /// 
426    /// # Returns
427    /// `Ok(true)` if model is valid, `Err` with details otherwise
428    /// 
429    /// # Errors
430    /// - `ModelIOError::InvalidModelFile` if any validation check fails
431    pub fn validate_model_file<P: AsRef<Path>>(&self, path: P) -> ModelIOResult<bool> {
432        validate_model_file(path)
433    }
434}
435
436impl Default for ModelIO {
437    fn default() -> Self {
438        Self::new().unwrap()
439    }
440}
441
442// ============================================================================
443// Standalone Model I/O Functions
444// ============================================================================
445
446/// Saves a gradient boosting model to a directory.
447/// 
448/// Creates a directory structure with:
449/// - `{model_name}.json`: Main model file with tree ensembles
450/// - `{model_name}_metadata.json`: Hyperparameters and configuration
451/// - `{model_name}_feature_importance.json`: Per-feature importance scores (if enabled)
452/// 
453/// # Parameters
454/// - `model`: Trained gradient boosting model
455/// - `path`: Directory path for saving
456/// - `model_name`: Base name for output files
457/// - `options`: Save configuration (format, compression, metadata, etc.)
458/// 
459    /// # Errors
460/// - `ModelIOError::ModelError` if model is untrained
461/// - `ModelIOError::IoError` if directory or file creation fails
462/// - `ModelIOError::SerializationError` if serialization fails
463pub fn save_model<P: AsRef<Path>>(
464    model: &GradientBooster,  // Changed from GBRT to GradientBooster
465    path: P,
466    model_name: &str,
467    options: &SaveOptions,
468) -> ModelIOResult<()> {
469    if !model.is_trained() {
470        return Err(ModelIOError::ModelError("Cannot save untrained model".to_string()));
471    }
472
473    let path = path.as_ref();
474    create_dir_all(path.parent().unwrap_or(Path::new(".")))?;
475
476    // Create metadata
477    let mut metadata = ModelMetadata::new(
478        "GBRT",
479        model.feature_importance().len(),
480        model.n_trees(),
481    );
482
483    // Add configuration to metadata
484    let config = model.config();
485    metadata = metadata
486        .with_parameter("n_estimators", &config.n_estimators.to_string())
487        .with_parameter("learning_rate", &config.learning_rate.to_string())
488        .with_parameter("loss", &config.loss.to_string())
489        .with_parameter("subsample", &config.subsample.to_string());
490
491    // Add custom metadata
492    for (key, value) in &options.metadata {
493        metadata = metadata.with_parameter(key, value);
494    }
495
496    // Choose serialization format
497    let serialization_format = match options.format {
498        ModelFormat::Binary => SerializationFormat::Bincode,
499        ModelFormat::Json => SerializationFormat::Json,
500    };
501
502    let serializer = ModelSerializer::new()
503        .with_format(serialization_format)
504        .with_compression(options.compression);
505
506    // Create file path
507    let file_path = path.join(format!("{}.json", model_name));
508
509    serializer.save_model(model, &metadata, &file_path)?;
510
511    // Save additional information if requested
512    if options.include_feature_importance {
513        save_feature_importance(model, path, model_name)?;
514    }
515
516
517    // Save metadata
518    save_metadata(&metadata, path, model_name)?;  // Add this line
519
520    Ok(())
521}
522
523/// Loads a gradient boosting model from file.
524/// 
525/// # Parameters
526/// - `path`: Path to model file
527/// - `options`: Load configuration (verification, version checking)
528/// 
529/// # Returns
530/// Loaded model with validated metadata
531/// 
532/// # Errors
533/// - `ModelIOError::ModelNotFound` if file doesn't exist
534/// - `ModelIOError::SerializationError` if deserialization fails
535/// - `ModelIOError::InvalidModelFile` if integrity check fails
536/// - `ModelIOError::VersionMismatch` if strict version check fails
537pub fn load_model<P: AsRef<Path>>(
538    path: P,
539    options: &LoadOptions,
540) -> ModelIOResult<GradientBooster> {  // Changed return type
541    let path = path.as_ref();
542
543    if !path.exists() {
544        return Err(ModelIOError::ModelNotFound(path.to_string_lossy().to_string()));
545    }
546
547    let serializer = ModelSerializer::new();
548    let (model, metadata) = serializer.load_model(path)?;
549
550    // Verify integrity if requested
551    if options.verify_integrity {
552        verify_model_integrity(&model, &metadata)?;
553    }
554
555    // Version check
556    if options.strict_version_check {
557        metadata.validate_version()
558            .map_err(|e| {
559                match e {
560                    SerializationError::VersionMismatch { expected, actual } => {
561                        ModelIOError::VersionMismatch {
562                            expected: expected.clone(),
563                            actual: actual.clone(),
564                        }
565                    }
566                    other => ModelIOError::SerializationError(other),
567                }
568            })?;
569    }
570
571    Ok(model)
572}
573
574
575/// Exports a model to a different format.
576/// 
577/// Allows format conversion without retraining. Useful for sharing models
578/// or switching between compact binary and readable JSON formats.
579/// 
580/// # Parameters
581/// - `model`: Model to export
582/// - `path`: Output file path
583/// - `format`: Target format
584/// 
585/// # Errors
586/// - `ModelIOError::SerializationError` if conversion fails
587/// - `ModelIOError::IoError` if file cannot be written
588pub fn export_model<P: AsRef<Path>>(
589    model: &GradientBooster,  // Changed from GBRT to GradientBooster
590    path: P,
591    format: ModelFormat,
592) -> ModelIOResult<()> {
593    match format {
594        ModelFormat::Json => {
595            // Export as human-readable JSON
596            let json = serde_json::to_string_pretty(model)
597                .map_err(|e| ModelIOError::SerializationError(e.into()))?;
598            std::fs::write(path, json)?;
599            Ok(())
600        }
601        ModelFormat::Binary => {
602            // Use standard save with binary format
603            let options = SaveOptions::new().with_format(ModelFormat::Binary);
604            save_model(model, path, "export", &options)
605        }
606    }
607}
608
609/// Imports a model from a specific format.
610/// 
611/// # Parameters
612/// - `path`: Path to model file
613/// - `format`: Expected format of the file
614/// 
615/// # Returns
616/// Loaded model
617/// 
618/// # Errors
619/// - `ModelIOError::SerializationError` if deserialization fails
620/// - `ModelIOError::IoError` if file cannot be read
621/// - `ModelIOError::InvalidModelFile` if format doesn't match content
622pub fn import_model<P: AsRef<Path>>(
623    path: P,
624    format: ModelFormat,
625) -> ModelIOResult<GradientBooster> {  // Changed return type
626    match format {
627        ModelFormat::Json => {
628            let file = File::open(path)?;
629            let reader = BufReader::new(file);
630            let model: GradientBooster = serde_json::from_reader(reader)  // Changed type
631                .map_err(|e| ModelIOError::SerializationError(e.into()))?;
632            Ok(model)
633        }
634        ModelFormat::Binary => {
635            load_model(path, &LoadOptions::default())
636        }
637    }
638}
639
640/// Lists all model files in a directory.
641/// 
642/// Searches for files with extensions: `.bin`, `.json`, `.gz`
643/// 
644/// # Parameters
645/// - `directory`: Directory to search
646/// 
647/// # Returns
648/// Sorted vector of absolute paths to model files
649/// 
650/// # Errors
651/// - `ModelIOError::IoError` if directory cannot be read
652pub fn list_saved_models<P: AsRef<Path>>(directory: P) -> ModelIOResult<Vec<PathBuf>> {
653    let directory = directory.as_ref();
654    let mut model_files = Vec::new();
655
656    if !directory.exists() {
657        return Ok(model_files);
658    }
659
660    for entry in read_dir(directory)? {
661        let entry = entry?;
662        let path = entry.path();
663
664        if path.is_file() {
665            if let Some(ext) = path.extension() {
666                if ext == "bin" || ext == "json" || ext == "gz" {
667                    model_files.push(path);
668                }
669            }
670        }
671    }
672
673    model_files.sort();
674    Ok(model_files)
675}
676
677/// Retrieves metadata for a saved model.
678/// 
679/// Metadata includes model type, version, hyperparameters, and custom parameters.
680/// 
681/// # Parameters
682/// - `path`: Path to model file or directory
683/// 
684/// # Returns
685/// Deserialized ModelMetadata
686/// 
687/// # Errors
688/// - `ModelIOError::InvalidModelFile` if metadata file is missing
689/// - `ModelIOError::SerializationError` if JSON parsing fails
690/// 
691/// # Path Resolution
692/// - If `path` is a file like `model.json`, looks for `model_metadata.json`
693/// - If `path` is a directory, looks for `gbrt_model_metadata.json`
694pub fn get_model_info<P: AsRef<Path>>(path: P) -> ModelIOResult<ModelMetadata> {
695    let path = path.as_ref();
696
697    // Determine metadata file path
698    let metadata_path = if path.is_file() {
699        // If given a file like "gbrt_model.json", look for "gbrt_model_metadata.json"
700        if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
701            path.with_file_name(format!("{}_metadata.json", stem))
702        } else {
703            return Err(ModelIOError::InvalidModelFile("Invalid file name".to_string()));
704        }
705    } else {
706        // If given a directory, look for "gbrt_model_metadata.json"
707        path.join("gbrt_model_metadata.json")
708    };
709
710    if metadata_path.exists() {
711        let file = File::open(metadata_path)?;
712        let reader = BufReader::new(file);
713        let metadata: ModelMetadata = serde_json::from_reader(reader)
714            .map_err(|e| ModelIOError::SerializationError(e.into()))?;
715        Ok(metadata)
716    } else {
717        Err(ModelIOError::InvalidModelFile(
718            format!("Metadata file not found: {}", metadata_path.display())
719        ))
720    }
721}
722
723/// Validates a model file by checking integrity and consistency.
724/// 
725/// Performs comprehensive validation:
726/// - Model is trained
727/// - Tree count matches metadata
728/// - Feature dimensions match metadata
729/// - Model type is correct
730/// 
731/// # Parameters
732/// - `path`: Path to model file
733/// 
734/// # Returns
735/// `Ok(true)` if all checks pass
736/// 
737/// # Errors
738/// `Err` with specific violation details if any check fails
739pub fn validate_model_file<P: AsRef<Path>>(path: P) -> ModelIOResult<bool> {
740    let metadata = get_model_info(path)?;
741    
742    // Basic validation checks
743    if metadata.num_features == 0 {
744        return Err(ModelIOError::InvalidModelFile("Invalid number of features".to_string()));
745    }
746
747    if metadata.model_type != "GBRT" {
748        return Err(ModelIOError::InvalidModelFile(format!("Invalid model type: {}", metadata.model_type)));
749    }
750
751    Ok(true)
752}
753
754// ============================================================================
755// Helper Functions
756// ============================================================================
757
758/// Saves feature importance scores to a JSON file.
759/// 
760/// Creates a sorted list of features by importance score.
761fn save_feature_importance<P: AsRef<Path>>(
762    model: &GradientBooster,  // Changed from GBRT to GradientBooster
763    path: P,
764    model_name: &str,
765) -> ModelIOResult<()> {
766    let importance = model.feature_importance();
767    let feature_names: Vec<String> = (0..importance.len())
768        .map(|i| format!("feature_{}", i))
769        .collect();
770
771    let mut importance_data = Vec::new();
772    for (i, (&importance, name)) in importance.iter().zip(feature_names.iter()).enumerate() {
773        importance_data.push(FeatureImportance {
774            feature_index: i,
775            feature_name: name.clone(),
776            importance,
777        });
778    }
779
780    // Sort by importance (descending)
781    importance_data.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap());
782
783    let file_path = path.as_ref().join(format!("{}_feature_importance.json", model_name));
784    let file = File::create(file_path)?;
785    let writer = BufWriter::new(file);
786
787    serde_json::to_writer_pretty(writer, &importance_data)
788        .map_err(|e| ModelIOError::SerializationError(e.into()))?;
789
790    Ok(())
791}
792
793/// Saves model metadata to a JSON file.
794fn save_metadata<P: AsRef<Path>>(
795    metadata: &ModelMetadata,
796    path: P,
797    model_name: &str,
798) -> ModelIOResult<()> {
799    let file_path = path.as_ref().join(format!("{}_metadata.json", model_name));
800    let file = File::create(file_path)?;
801    let writer = BufWriter::new(file);
802
803    serde_json::to_writer_pretty(writer, metadata)
804        .map_err(|e| ModelIOError::SerializationError(e.into()))?;
805
806    Ok(())
807}
808
809/// Verifies model integrity by comparing against metadata.
810///
811/// Checks:
812/// - Model is trained
813/// - Tree count matches metadata
814/// - Feature dimension matches metadata
815fn verify_model_integrity(model: &GradientBooster, metadata: &ModelMetadata) -> ModelIOResult<()> {  // Changed parameter type
816    // Check if model is trained
817    if !model.is_trained() {
818        return Err(ModelIOError::InvalidModelFile("Model appears to be untrained".to_string()));
819    }
820
821    // Check number of trees
822    if model.n_trees() != metadata.num_trees {
823        return Err(ModelIOError::InvalidModelFile(
824            format!("Tree count mismatch: expected {}, got {}", metadata.num_trees, model.n_trees())
825        ));
826    }
827
828    // Check feature dimensions
829    if model.feature_importance().len() != metadata.num_features {
830        return Err(ModelIOError::InvalidModelFile(
831            format!("Feature dimension mismatch: expected {}, got {}", metadata.num_features, model.feature_importance().len())
832        ));
833    }
834
835    Ok(())
836}
837
838/// Internal representation of feature importance for JSON export.
839#[derive(Serialize, Deserialize)]
840struct FeatureImportance {
841    feature_index: usize,
842    feature_name: String,
843    importance: f64,
844}
845
846
847// ============================================================================
848// Model Registry for Managing Multiple Models
849// ============================================================================
850
/// Registry that maps model names to the paths of saved models.
///
/// The mapping is a single JSON object stored as `models.json` inside the
/// directory passed to `ModelRegistry::new`:
/// ```json
/// {
///   "iris_classifier": "models/iris.json",
///   "housing_predictor": "models/housing.json"
/// }
/// ```
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Directory that contains `models.json`.
    registry_path: PathBuf,
}
868
869impl ModelRegistry {
870    /// Creates a new model registry in the specified directory.
871    ///
872    /// # Parameters
873    /// - `path`: Directory where registry file (`models.json`) will be stored
874    pub fn new<P: AsRef<Path>>(path: P) -> Self {
875        Self {
876            registry_path: path.as_ref().to_path_buf(),
877        }
878    }
879
880    /// Registers a model in the registry.
881    /// 
882    /// Creates or updates the `models.json` file with the model name and path.
883    /// 
884    /// # Parameters
885    /// - `model_name`: Unique identifier for the model
886    /// - `model_path`: Path to the saved model file
887    /// 
888    /// # Errors
889    /// - `ModelIOError::SerializationError` if registry JSON cannot be written
890    /// - `ModelIOError::IoError` if registry directory cannot be created
891    pub fn register_model(&self, model_name: &str, model_path: &Path) -> ModelIOResult<()> {
892        let registry_file = self.registry_path.join("models.json");
893        let mut registry: HashMap<String, String> = if registry_file.exists() {
894            let file = File::open(&registry_file)?;
895            let reader = BufReader::new(file);
896            serde_json::from_reader(reader).unwrap_or_default()
897        } else {
898            HashMap::new()
899        };
900
901        registry.insert(model_name.to_string(), model_path.to_string_lossy().to_string());
902
903        let file = File::create(registry_file)?;
904        let writer = BufWriter::new(file);
905        serde_json::to_writer_pretty(writer, &registry)
906            .map_err(|e| ModelIOError::SerializationError(e.into()))?;
907
908        Ok(())
909    }
910
911   /// Retrieves the file path for a registered model.
912    /// 
913    /// # Parameters
914    /// - `model_name`: Name of the registered model
915    /// 
916    /// # Returns
917    /// Path to the model file
918    /// 
919    /// # Errors
920    /// - `ModelIOError::ModelNotFound` if model name is not in registry
921    /// - `ModelIOError::SerializationError` if registry JSON cannot be read
922    pub fn get_model_path(&self, model_name: &str) -> ModelIOResult<PathBuf> {
923        let registry_file = self.registry_path.join("models.json");
924        if !registry_file.exists() {
925            return Err(ModelIOError::ModelNotFound(format!("Model registry not found: {}", model_name)));
926        }
927
928        let file = File::open(registry_file)?;
929        let reader = BufReader::new(file);
930        let registry: HashMap<String, String> = serde_json::from_reader(reader)
931            .map_err(|e| ModelIOError::SerializationError(e.into()))?;
932
933        registry
934            .get(model_name)
935            .map(|path| PathBuf::from(path))
936            .ok_or_else(|| ModelIOError::ModelNotFound(model_name.to_string()))
937    }
938
939    /// Lists all registered model names.
940    /// 
941    /// # Returns
942    /// Vector of model names sorted alphabetically
943    /// 
944    /// # Errors
945    /// - `ModelIOError::SerializationError` if registry file cannot be read
946    pub fn list_registered_models(&self) -> ModelIOResult<Vec<String>> {
947        let registry_file = self.registry_path.join("models.json");
948        if !registry_file.exists() {
949            return Ok(Vec::new());
950        }
951
952        let file = File::open(registry_file)?;
953        let reader = BufReader::new(file);
954        let registry: HashMap<String, String> = serde_json::from_reader(reader)
955            .map_err(|e| ModelIOError::SerializationError(e.into()))?;
956
957        Ok(registry.keys().cloned().collect())
958    }
959}
960