gbrt_rs/io/model_io.rs
1#![allow(dead_code)]
2
3//! Model Persistence and Management for Gradient Boosted Models
4//!
5//! This module provides comprehensive functionality for saving, loading, and managing
6//! trained [`GradientBooster`] models. It supports multiple serialization formats,
7//! model versioning, integrity verification, and a model registry for organizing
8//! multiple models.
9//!
10//! # Features
11//!
12//! - **Multiple Formats**: Save models as binary (compact) or JSON (human-readable)
13//! - **Metadata Tracking**: Store model configuration, training parameters, and custom metadata
14//! - **Integrity Verification**: Validate model integrity during loading
15//! - **Version Management**: Track model versions and detect compatibility issues
16//! - **Model Registry**: Centralized management of multiple models with name-based lookup
17//! - **Feature Importance**: Optionally export feature importance scores
18//! - **Compression**: Optional gzip compression for reduced disk usage
19
20use crate::boosting::{GradientBooster, GBRTConfig}; // Changed from core to boosting
21use crate::utils::{ModelSerializer, ModelMetadata, SerializationFormat, SerializationError};
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::fs::{File, create_dir_all, read_dir};
25use std::io::{BufReader, BufWriter};
26use std::path::{Path, PathBuf};
27use thiserror::Error;
28
29
30/// Errors that can occur during model I/O operations.
31///
32/// This error type covers all failure modes: file system errors, serialization
33/// failures, version mismatches, integrity violations, and missing models.
34#[derive(Error, Debug)]
35pub enum ModelIOError {
36 /// File system I/O error.
37 #[error("IO error: {0}")]
38 IoError(#[from] std::io::Error),
39
40 /// Model serialization/deserialization error.
41 #[error("Serialization error: {0}")]
42 SerializationError(#[from] crate::utils::SerializationError),
43
44 /// Parquet file reading error.
45 #[error("Parquet error: {0}")]
46 ParquetError(#[from] parquet::errors::ParquetError),
47
48 /// General model error (e.g., untrained model).
49 #[error("Model error: {0}")]
50 ModelError(String),
51
52 /// Model file is invalid or corrupted.
53 #[error("Invalid model file: {0}")]
54 InvalidModelFile(String),
55
56 /// Version incompatibility between saved model and current library.
57 #[error("Version mismatch: expected {expected}, got {actual}")]
58 VersionMismatch { expected: String, actual: String },
59
60 /// Unsupported file format.
61 #[error("Unsupported format: {format}")]
62 UnsupportedFormat { format: String },
63
64 /// Model file not found at specified path.
65 #[error("Model not found: {0}")]
66 ModelNotFound(String),
67}
68
69pub type ModelIOResult<T> = std::result::Result<T, ModelIOError>;
70
71/// Supported model serialization formats.
72///
73/// - **Binary**: Compact, fast serialization using bincode (default)
74/// - **JSON**: Human-readable text format for inspection and debugging
75#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
76pub enum ModelFormat {
77 /// Binary format for compact storage and fast loading.
78 Binary,
79 /// JSON format for human-readability and interoperability.
80 Json,
81}
82
83impl std::str::FromStr for ModelFormat {
84 type Err = ModelIOError;
85
86 /// Parses ModelFormat from string.
87 ///
88 /// # Supported Values
89 /// - "binary", "bin" → `ModelFormat::Binary`
90 /// - "json" → `ModelFormat::Json`
91 ///
92 /// # Errors
93 /// - `ModelIOError::UnsupportedFormat` if string doesn't match any known format
94 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 match s.to_lowercase().as_str() {
96 "binary" | "bin" => Ok(ModelFormat::Binary),
97 "json" => Ok(ModelFormat::Json),
98 _ => Err(ModelIOError::UnsupportedFormat {
99 format: s.to_string(),
100 }),
101 }
102 }
103}
104
105impl std::fmt::Display for ModelFormat {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 match self {
108 ModelFormat::Binary => write!(f, "binary"),
109 ModelFormat::Json => write!(f, "json"),
110 }
111 }
112}
113
114/// Configuration options for saving models.
115///
116/// Controls serialization format, compression, and what additional data to include.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct SaveOptions {
119 /// Serialization format to use (Binary or JSON).
120 pub format: ModelFormat,
121 /// Whether to compress the model file.
122 pub compression: bool,
123 /// Whether to include training data in the save (not yet implemented).
124 pub include_training_data: bool,
125 /// Whether to include feature importance scores.
126 pub include_feature_importance: bool,
127 /// Additional metadata to store with the model.
128 pub metadata: HashMap<String, String>,
129}
130
131impl Default for SaveOptions {
132 fn default() -> Self {
133 Self {
134 format: ModelFormat::Json,
135 compression: false,
136 include_training_data: false,
137 include_feature_importance: true,
138 metadata: HashMap::new(),
139 }
140 }
141}
142
143impl SaveOptions {
144 /// Creates default SaveOptions.
145 pub fn new() -> Self {
146 Self::default()
147 }
148
149 /// Sets the serialization format.
150 pub fn with_format(mut self, format: ModelFormat) -> Self {
151 self.format = format;
152 self
153 }
154
155 /// Enables or disables compression.
156 pub fn with_compression(mut self, compression: bool) -> Self {
157 self.compression = compression;
158 self
159 }
160
161 /// Sets whether to include training data (future feature).
162 pub fn with_training_data(mut self, include: bool) -> Self {
163 self.include_training_data = include;
164 self
165 }
166
167 /// Sets whether to include feature importance scores.
168 pub fn with_feature_importance(mut self, include: bool) -> Self {
169 self.include_feature_importance = include;
170 self
171 }
172
173 /// Replaces all metadata with provided map.
174 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
175 self.metadata = metadata;
176 self
177 }
178
179 /// Adds a single metadata key-value pair.
180 pub fn add_metadata(mut self, key: &str, value: &str) -> Self {
181 self.metadata.insert(key.to_string(), value.to_string());
182 self
183 }
184}
185
186/// Configuration options for loading models.
187///
188/// Controls validation and integrity checks during model loading.
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct LoadOptions {
191 /// Whether to verify model integrity (default: true).
192 pub verify_integrity: bool,
193 /// Whether to enforce strict version checking (default: false).
194 pub strict_version_check: bool,
195}
196
197impl Default for LoadOptions {
198 fn default() -> Self {
199 Self {
200 verify_integrity: true,
201 strict_version_check: false,
202 }
203 }
204}
205
206impl LoadOptions {
207 /// Creates default LoadOptions.
208 pub fn new() -> Self {
209 Self::default()
210 }
211
212 /// Enables or disables integrity verification.
213 pub fn with_verify_integrity(mut self, verify: bool) -> Self {
214 self.verify_integrity = verify;
215 self
216 }
217
218 /// Enables or disables strict version checking.
219 pub fn with_strict_version_check(mut self, strict: bool) -> Self {
220 self.strict_version_check = strict;
221 self
222 }
223}
224
225/// Main model I/O handler for saving and loading gradient boosted models.
226///
227/// `ModelIO` provides a high-level interface for model persistence operations.
228/// It can be configured with custom save/load options and handles all aspects
229/// of model serialization, metadata management, and validation.
230pub struct ModelIO {
231 save_options: SaveOptions,
232 load_options: LoadOptions,
233}
234
235impl ModelIO {
236 /// Creates a new ModelIO with default options.
237 ///
238 /// # Returns
239 /// `Ok(ModelIO)` with default save and load options.
240 pub fn new() -> ModelIOResult<Self> {
241 Ok(Self {
242 save_options: SaveOptions::default(),
243 load_options: LoadOptions::default(),
244 })
245 }
246
247 /// Creates a ModelIO with custom save options.
248 pub fn with_save_options(save_options: SaveOptions) -> Self {
249 Self {
250 save_options,
251 load_options: LoadOptions::default(),
252 }
253 }
254
255 /// Creates a ModelIO with custom load options.
256 pub fn with_load_options(load_options: LoadOptions) -> Self {
257 Self {
258 save_options: SaveOptions::default(),
259 load_options,
260 }
261 }
262
263 /// Saves a gradient boosting model to file.
264 ///
265 /// The model is saved as a directory containing:
266 /// - `{model_name}.json`: The main model file
267 /// - `{model_name}_metadata.json`: Model metadata and hyperparameters
268 /// - `{model_name}_feature_importance.json`: Feature importance scores (if enabled)
269 /// - Additional files for training data (future feature)
270 ///
271 /// # Parameters
272 /// - `model`: Trained gradient boosting model
273 /// - `path`: Directory path where model will be saved
274 /// - `model_name`: Base name for model files
275 ///
276 /// # Errors
277 /// - `ModelIOError::ModelError` if model is untrained
278 /// - `ModelIOError::IoError` if directory creation or file writing fails
279 /// - `ModelIOError::SerializationError` if serialization fails
280 pub fn save_model<P: AsRef<Path>>(
281 &self,
282 model: &GradientBooster,
283 path: P,
284 model_name: &str,
285 ) -> ModelIOResult<()> {
286 save_model(model, path, model_name, &self.save_options)
287 }
288
289 /// Loads a gradient boosting model from file or directory.
290 ///
291 /// # Path Resolution
292 ///
293 /// - If `path` is a file: Loads that file directly
294 /// - If `path` is a directory: Looks for `gbrt_model.json` inside the directory
295 ///
296 /// # Parameters
297 /// - `path`: Path to model file or directory
298 ///
299 /// # Errors
300 /// - `ModelIOError::ModelNotFound` if model file doesn't exist
301 /// - `ModelIOError::SerializationError` if deserialization fails
302 /// - `ModelIOError::InvalidModelFile` if integrity check fails
303 /// - `ModelIOError::VersionMismatch` if version check fails and strict mode enabled
304 pub fn load_model<P: AsRef<Path>>(&self, path: P) -> ModelIOResult<GradientBooster> {
305 let path = path.as_ref();
306
307 // ✅ If path is a directory, look for gbrt_model.json inside
308 let model_file = if path.is_dir() {
309 path.join("gbrt_model.json")
310 } else {
311 path.to_path_buf()
312 };
313
314 if !model_file.exists() {
315 return Err(ModelIOError::ModelNotFound(
316 format!("Model file not found: {}", model_file.display())
317 ));
318 }
319
320 let serializer = ModelSerializer::new();
321 let (model, metadata) = serializer.load_model(&model_file)?;
322
323 // ✅ Verify integrity if requested
324 if self.load_options.verify_integrity {
325 verify_model_integrity(&model, &metadata)?;
326 }
327
328 // ✅ Version check
329 if self.load_options.strict_version_check {
330 metadata.validate_version()
331 .map_err(|e| {
332 match e {
333 SerializationError::VersionMismatch { expected, actual } => {
334 ModelIOError::VersionMismatch {
335 expected: expected.clone(),
336 actual: actual.clone(),
337 }
338 }
339 other => ModelIOError::SerializationError(other),
340 }
341 })?;
342 }
343
344 Ok(model)
345 }
346
347 /// Exports a model to a different format.
348 ///
349 /// This allows converting between Binary and JSON formats without retraining.
350 ///
351 /// # Parameters
352 /// - `model`: Model to export
353 /// - `path`: Output file path
354 /// - `format`: Target format (Binary or JSON)
355 ///
356 /// # Errors
357 /// - `ModelIOError::SerializationError` if conversion fails
358 /// - `ModelIOError::IoError` if file writing fails
359 pub fn export_model<P: AsRef<Path>>(
360 &self,
361 model: &GradientBooster, // Changed from GBRT to GradientBooster
362 path: P,
363 format: ModelFormat,
364 ) -> ModelIOResult<()> {
365 export_model(model, path, format)
366 }
367
368 /// Imports a model from a specific format.
369 ///
370 /// # Parameters
371 /// - `path`: Path to model file
372 /// - `format`: Format of the model file
373 ///
374 /// # Returns
375 /// Loaded model
376 ///
377 /// # Errors
378 /// - `ModelIOError::SerializationError` if deserialization fails
379 /// - `ModelIOError::IoError` if file reading fails
380 pub fn import_model<P: AsRef<Path>>(&self, path: P, format: ModelFormat) -> ModelIOResult<GradientBooster> { // Changed return type
381 import_model(path, format)
382 }
383
384 /// Lists all saved model files in a directory.
385 ///
386 /// Searches for files with extensions: `.bin`, `.json`, `.gz`
387 ///
388 /// # Parameters
389 /// - `directory`: Directory to search
390 ///
391 /// # Returns
392 /// Sorted vector of model file paths
393 ///
394 /// # Errors
395 /// - `ModelIOError::IoError` if directory cannot be read
396 pub fn list_saved_models<P: AsRef<Path>>(&self, directory: P) -> ModelIOResult<Vec<PathBuf>> {
397 list_saved_models(directory)
398 }
399
400 /// Retrieves metadata for a saved model.
401 ///
402 /// # Parameters
403 /// - `path`: Path to model file or directory
404 ///
405 /// # Returns
406 /// Model metadata including configuration, version, and parameters
407 ///
408 /// # Errors
409 /// - `ModelIOError::InvalidModelFile` if metadata file is missing or corrupted
410 /// - `ModelIOError::SerializationError` if metadata cannot be parsed
411 pub fn get_model_info<P: AsRef<Path>>(&self, path: P) -> ModelIOResult<ModelMetadata> {
412 get_model_info(path)
413 }
414
415 /// Validates a model file for integrity and compatibility.
416 ///
417 /// Performs basic checks:
418 /// - Model is trained
419 /// - Number of trees matches metadata
420 /// - Feature dimensions are consistent
421 /// - Model type is correct
422 ///
423 /// # Parameters
424 /// - `path`: Path to model file
425 ///
426 /// # Returns
427 /// `Ok(true)` if model is valid, `Err` with details otherwise
428 ///
429 /// # Errors
430 /// - `ModelIOError::InvalidModelFile` if any validation check fails
431 pub fn validate_model_file<P: AsRef<Path>>(&self, path: P) -> ModelIOResult<bool> {
432 validate_model_file(path)
433 }
434}
435
436impl Default for ModelIO {
437 fn default() -> Self {
438 Self::new().unwrap()
439 }
440}
441
442// ============================================================================
443// Standalone Model I/O Functions
444// ============================================================================
445
446/// Saves a gradient boosting model to a directory.
447///
448/// Creates a directory structure with:
449/// - `{model_name}.json`: Main model file with tree ensembles
450/// - `{model_name}_metadata.json`: Hyperparameters and configuration
451/// - `{model_name}_feature_importance.json`: Per-feature importance scores (if enabled)
452///
453/// # Parameters
454/// - `model`: Trained gradient boosting model
455/// - `path`: Directory path for saving
456/// - `model_name`: Base name for output files
457/// - `options`: Save configuration (format, compression, metadata, etc.)
458///
459 /// # Errors
460/// - `ModelIOError::ModelError` if model is untrained
461/// - `ModelIOError::IoError` if directory or file creation fails
462/// - `ModelIOError::SerializationError` if serialization fails
463pub fn save_model<P: AsRef<Path>>(
464 model: &GradientBooster, // Changed from GBRT to GradientBooster
465 path: P,
466 model_name: &str,
467 options: &SaveOptions,
468) -> ModelIOResult<()> {
469 if !model.is_trained() {
470 return Err(ModelIOError::ModelError("Cannot save untrained model".to_string()));
471 }
472
473 let path = path.as_ref();
474 create_dir_all(path.parent().unwrap_or(Path::new(".")))?;
475
476 // Create metadata
477 let mut metadata = ModelMetadata::new(
478 "GBRT",
479 model.feature_importance().len(),
480 model.n_trees(),
481 );
482
483 // Add configuration to metadata
484 let config = model.config();
485 metadata = metadata
486 .with_parameter("n_estimators", &config.n_estimators.to_string())
487 .with_parameter("learning_rate", &config.learning_rate.to_string())
488 .with_parameter("loss", &config.loss.to_string())
489 .with_parameter("subsample", &config.subsample.to_string());
490
491 // Add custom metadata
492 for (key, value) in &options.metadata {
493 metadata = metadata.with_parameter(key, value);
494 }
495
496 // Choose serialization format
497 let serialization_format = match options.format {
498 ModelFormat::Binary => SerializationFormat::Bincode,
499 ModelFormat::Json => SerializationFormat::Json,
500 };
501
502 let serializer = ModelSerializer::new()
503 .with_format(serialization_format)
504 .with_compression(options.compression);
505
506 // Create file path
507 let file_path = path.join(format!("{}.json", model_name));
508
509 serializer.save_model(model, &metadata, &file_path)?;
510
511 // Save additional information if requested
512 if options.include_feature_importance {
513 save_feature_importance(model, path, model_name)?;
514 }
515
516
517 // Save metadata
518 save_metadata(&metadata, path, model_name)?; // Add this line
519
520 Ok(())
521}
522
523/// Loads a gradient boosting model from file.
524///
525/// # Parameters
526/// - `path`: Path to model file
527/// - `options`: Load configuration (verification, version checking)
528///
529/// # Returns
530/// Loaded model with validated metadata
531///
532/// # Errors
533/// - `ModelIOError::ModelNotFound` if file doesn't exist
534/// - `ModelIOError::SerializationError` if deserialization fails
535/// - `ModelIOError::InvalidModelFile` if integrity check fails
536/// - `ModelIOError::VersionMismatch` if strict version check fails
537pub fn load_model<P: AsRef<Path>>(
538 path: P,
539 options: &LoadOptions,
540) -> ModelIOResult<GradientBooster> { // Changed return type
541 let path = path.as_ref();
542
543 if !path.exists() {
544 return Err(ModelIOError::ModelNotFound(path.to_string_lossy().to_string()));
545 }
546
547 let serializer = ModelSerializer::new();
548 let (model, metadata) = serializer.load_model(path)?;
549
550 // Verify integrity if requested
551 if options.verify_integrity {
552 verify_model_integrity(&model, &metadata)?;
553 }
554
555 // Version check
556 if options.strict_version_check {
557 metadata.validate_version()
558 .map_err(|e| {
559 match e {
560 SerializationError::VersionMismatch { expected, actual } => {
561 ModelIOError::VersionMismatch {
562 expected: expected.clone(),
563 actual: actual.clone(),
564 }
565 }
566 other => ModelIOError::SerializationError(other),
567 }
568 })?;
569 }
570
571 Ok(model)
572}
573
574
575/// Exports a model to a different format.
576///
577/// Allows format conversion without retraining. Useful for sharing models
578/// or switching between compact binary and readable JSON formats.
579///
580/// # Parameters
581/// - `model`: Model to export
582/// - `path`: Output file path
583/// - `format`: Target format
584///
585/// # Errors
586/// - `ModelIOError::SerializationError` if conversion fails
587/// - `ModelIOError::IoError` if file cannot be written
588pub fn export_model<P: AsRef<Path>>(
589 model: &GradientBooster, // Changed from GBRT to GradientBooster
590 path: P,
591 format: ModelFormat,
592) -> ModelIOResult<()> {
593 match format {
594 ModelFormat::Json => {
595 // Export as human-readable JSON
596 let json = serde_json::to_string_pretty(model)
597 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
598 std::fs::write(path, json)?;
599 Ok(())
600 }
601 ModelFormat::Binary => {
602 // Use standard save with binary format
603 let options = SaveOptions::new().with_format(ModelFormat::Binary);
604 save_model(model, path, "export", &options)
605 }
606 }
607}
608
609/// Imports a model from a specific format.
610///
611/// # Parameters
612/// - `path`: Path to model file
613/// - `format`: Expected format of the file
614///
615/// # Returns
616/// Loaded model
617///
618/// # Errors
619/// - `ModelIOError::SerializationError` if deserialization fails
620/// - `ModelIOError::IoError` if file cannot be read
621/// - `ModelIOError::InvalidModelFile` if format doesn't match content
622pub fn import_model<P: AsRef<Path>>(
623 path: P,
624 format: ModelFormat,
625) -> ModelIOResult<GradientBooster> { // Changed return type
626 match format {
627 ModelFormat::Json => {
628 let file = File::open(path)?;
629 let reader = BufReader::new(file);
630 let model: GradientBooster = serde_json::from_reader(reader) // Changed type
631 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
632 Ok(model)
633 }
634 ModelFormat::Binary => {
635 load_model(path, &LoadOptions::default())
636 }
637 }
638}
639
640/// Lists all model files in a directory.
641///
642/// Searches for files with extensions: `.bin`, `.json`, `.gz`
643///
644/// # Parameters
645/// - `directory`: Directory to search
646///
647/// # Returns
648/// Sorted vector of absolute paths to model files
649///
650/// # Errors
651/// - `ModelIOError::IoError` if directory cannot be read
652pub fn list_saved_models<P: AsRef<Path>>(directory: P) -> ModelIOResult<Vec<PathBuf>> {
653 let directory = directory.as_ref();
654 let mut model_files = Vec::new();
655
656 if !directory.exists() {
657 return Ok(model_files);
658 }
659
660 for entry in read_dir(directory)? {
661 let entry = entry?;
662 let path = entry.path();
663
664 if path.is_file() {
665 if let Some(ext) = path.extension() {
666 if ext == "bin" || ext == "json" || ext == "gz" {
667 model_files.push(path);
668 }
669 }
670 }
671 }
672
673 model_files.sort();
674 Ok(model_files)
675}
676
677/// Retrieves metadata for a saved model.
678///
679/// Metadata includes model type, version, hyperparameters, and custom parameters.
680///
681/// # Parameters
682/// - `path`: Path to model file or directory
683///
684/// # Returns
685/// Deserialized ModelMetadata
686///
687/// # Errors
688/// - `ModelIOError::InvalidModelFile` if metadata file is missing
689/// - `ModelIOError::SerializationError` if JSON parsing fails
690///
691/// # Path Resolution
692/// - If `path` is a file like `model.json`, looks for `model_metadata.json`
693/// - If `path` is a directory, looks for `gbrt_model_metadata.json`
694pub fn get_model_info<P: AsRef<Path>>(path: P) -> ModelIOResult<ModelMetadata> {
695 let path = path.as_ref();
696
697 // Determine metadata file path
698 let metadata_path = if path.is_file() {
699 // If given a file like "gbrt_model.json", look for "gbrt_model_metadata.json"
700 if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
701 path.with_file_name(format!("{}_metadata.json", stem))
702 } else {
703 return Err(ModelIOError::InvalidModelFile("Invalid file name".to_string()));
704 }
705 } else {
706 // If given a directory, look for "gbrt_model_metadata.json"
707 path.join("gbrt_model_metadata.json")
708 };
709
710 if metadata_path.exists() {
711 let file = File::open(metadata_path)?;
712 let reader = BufReader::new(file);
713 let metadata: ModelMetadata = serde_json::from_reader(reader)
714 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
715 Ok(metadata)
716 } else {
717 Err(ModelIOError::InvalidModelFile(
718 format!("Metadata file not found: {}", metadata_path.display())
719 ))
720 }
721}
722
723/// Validates a model file by checking integrity and consistency.
724///
725/// Performs comprehensive validation:
726/// - Model is trained
727/// - Tree count matches metadata
728/// - Feature dimensions match metadata
729/// - Model type is correct
730///
731/// # Parameters
732/// - `path`: Path to model file
733///
734/// # Returns
735/// `Ok(true)` if all checks pass
736///
737/// # Errors
738/// `Err` with specific violation details if any check fails
739pub fn validate_model_file<P: AsRef<Path>>(path: P) -> ModelIOResult<bool> {
740 let metadata = get_model_info(path)?;
741
742 // Basic validation checks
743 if metadata.num_features == 0 {
744 return Err(ModelIOError::InvalidModelFile("Invalid number of features".to_string()));
745 }
746
747 if metadata.model_type != "GBRT" {
748 return Err(ModelIOError::InvalidModelFile(format!("Invalid model type: {}", metadata.model_type)));
749 }
750
751 Ok(true)
752}
753
754// ============================================================================
755// Helper Functions
756// ============================================================================
757
758/// Saves feature importance scores to a JSON file.
759///
760/// Creates a sorted list of features by importance score.
761fn save_feature_importance<P: AsRef<Path>>(
762 model: &GradientBooster, // Changed from GBRT to GradientBooster
763 path: P,
764 model_name: &str,
765) -> ModelIOResult<()> {
766 let importance = model.feature_importance();
767 let feature_names: Vec<String> = (0..importance.len())
768 .map(|i| format!("feature_{}", i))
769 .collect();
770
771 let mut importance_data = Vec::new();
772 for (i, (&importance, name)) in importance.iter().zip(feature_names.iter()).enumerate() {
773 importance_data.push(FeatureImportance {
774 feature_index: i,
775 feature_name: name.clone(),
776 importance,
777 });
778 }
779
780 // Sort by importance (descending)
781 importance_data.sort_by(|a, b| b.importance.partial_cmp(&a.importance).unwrap());
782
783 let file_path = path.as_ref().join(format!("{}_feature_importance.json", model_name));
784 let file = File::create(file_path)?;
785 let writer = BufWriter::new(file);
786
787 serde_json::to_writer_pretty(writer, &importance_data)
788 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
789
790 Ok(())
791}
792
793/// Saves model metadata to a JSON file.
794fn save_metadata<P: AsRef<Path>>(
795 metadata: &ModelMetadata,
796 path: P,
797 model_name: &str,
798) -> ModelIOResult<()> {
799 let file_path = path.as_ref().join(format!("{}_metadata.json", model_name));
800 let file = File::create(file_path)?;
801 let writer = BufWriter::new(file);
802
803 serde_json::to_writer_pretty(writer, metadata)
804 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
805
806 Ok(())
807}
808
809/// Verifies model integrity by comparing against metadata.
810///
811/// Checks:
812/// - Model is trained
813/// - Tree count matches metadata
814/// - Feature dimension matches metadata
815fn verify_model_integrity(model: &GradientBooster, metadata: &ModelMetadata) -> ModelIOResult<()> { // Changed parameter type
816 // Check if model is trained
817 if !model.is_trained() {
818 return Err(ModelIOError::InvalidModelFile("Model appears to be untrained".to_string()));
819 }
820
821 // Check number of trees
822 if model.n_trees() != metadata.num_trees {
823 return Err(ModelIOError::InvalidModelFile(
824 format!("Tree count mismatch: expected {}, got {}", metadata.num_trees, model.n_trees())
825 ));
826 }
827
828 // Check feature dimensions
829 if model.feature_importance().len() != metadata.num_features {
830 return Err(ModelIOError::InvalidModelFile(
831 format!("Feature dimension mismatch: expected {}, got {}", metadata.num_features, model.feature_importance().len())
832 ));
833 }
834
835 Ok(())
836}
837
838/// Internal representation of feature importance for JSON export.
839#[derive(Serialize, Deserialize)]
840struct FeatureImportance {
841 feature_index: usize,
842 feature_name: String,
843 importance: f64,
844}
845
846
847// ============================================================================
848// Model Registry for Managing Multiple Models
849// ============================================================================
850
/// Centralized registry for managing multiple models by name.
///
/// The registry stores a JSON map of model names to file paths, enabling
/// easy lookup and organization of models.
///
/// # Registry Structure
///
/// The registry is stored as `models.json` in the specified directory:
/// ```json
/// {
///   "iris_classifier": "models/iris.json",
///   "housing_predictor": "models/housing.json"
/// }
/// ```
///
/// The handle only holds the directory path, so it derives `Debug` and
/// `Clone` for diagnostics and cheap duplication.
#[derive(Debug, Clone)]
pub struct ModelRegistry {
    /// Directory that contains (or will contain) `models.json`.
    registry_path: PathBuf,
}
868
869impl ModelRegistry {
870 /// Creates a new model registry in the specified directory.
871 ///
872 /// # Parameters
873 /// - `path`: Directory where registry file (`models.json`) will be stored
874 pub fn new<P: AsRef<Path>>(path: P) -> Self {
875 Self {
876 registry_path: path.as_ref().to_path_buf(),
877 }
878 }
879
880 /// Registers a model in the registry.
881 ///
882 /// Creates or updates the `models.json` file with the model name and path.
883 ///
884 /// # Parameters
885 /// - `model_name`: Unique identifier for the model
886 /// - `model_path`: Path to the saved model file
887 ///
888 /// # Errors
889 /// - `ModelIOError::SerializationError` if registry JSON cannot be written
890 /// - `ModelIOError::IoError` if registry directory cannot be created
891 pub fn register_model(&self, model_name: &str, model_path: &Path) -> ModelIOResult<()> {
892 let registry_file = self.registry_path.join("models.json");
893 let mut registry: HashMap<String, String> = if registry_file.exists() {
894 let file = File::open(®istry_file)?;
895 let reader = BufReader::new(file);
896 serde_json::from_reader(reader).unwrap_or_default()
897 } else {
898 HashMap::new()
899 };
900
901 registry.insert(model_name.to_string(), model_path.to_string_lossy().to_string());
902
903 let file = File::create(registry_file)?;
904 let writer = BufWriter::new(file);
905 serde_json::to_writer_pretty(writer, ®istry)
906 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
907
908 Ok(())
909 }
910
911 /// Retrieves the file path for a registered model.
912 ///
913 /// # Parameters
914 /// - `model_name`: Name of the registered model
915 ///
916 /// # Returns
917 /// Path to the model file
918 ///
919 /// # Errors
920 /// - `ModelIOError::ModelNotFound` if model name is not in registry
921 /// - `ModelIOError::SerializationError` if registry JSON cannot be read
922 pub fn get_model_path(&self, model_name: &str) -> ModelIOResult<PathBuf> {
923 let registry_file = self.registry_path.join("models.json");
924 if !registry_file.exists() {
925 return Err(ModelIOError::ModelNotFound(format!("Model registry not found: {}", model_name)));
926 }
927
928 let file = File::open(registry_file)?;
929 let reader = BufReader::new(file);
930 let registry: HashMap<String, String> = serde_json::from_reader(reader)
931 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
932
933 registry
934 .get(model_name)
935 .map(|path| PathBuf::from(path))
936 .ok_or_else(|| ModelIOError::ModelNotFound(model_name.to_string()))
937 }
938
939 /// Lists all registered model names.
940 ///
941 /// # Returns
942 /// Vector of model names sorted alphabetically
943 ///
944 /// # Errors
945 /// - `ModelIOError::SerializationError` if registry file cannot be read
946 pub fn list_registered_models(&self) -> ModelIOResult<Vec<String>> {
947 let registry_file = self.registry_path.join("models.json");
948 if !registry_file.exists() {
949 return Ok(Vec::new());
950 }
951
952 let file = File::open(registry_file)?;
953 let reader = BufReader::new(file);
954 let registry: HashMap<String, String> = serde_json::from_reader(reader)
955 .map_err(|e| ModelIOError::SerializationError(e.into()))?;
956
957 Ok(registry.keys().cloned().collect())
958 }
959}
960