// gbrt_rs/utils/serialization.rs

1#![allow(dead_code)]
2
3//! Model serialization and persistence for gradient boosting models.
4//!
5//! This module provides robust utilities for saving and loading trained models
6//! with metadata, versioning, and compression support. It supports multiple
7//! serialization formats and ensures model compatibility across different
8//! versions of the library.
9//!
10//! # Features
11//!
12//! - **Multiple formats**: JSON (human-readable) and Bincode (compact binary)
13//! - **Compression**: Optional gzip compression to reduce file size
14//! - **Versioning**: Automatic model version tracking and validation
15//! - **Metadata**: Rich model information (timestamps, hyperparameters, etc.)
16//! - **Validation**: Verify model files without full deserialization
17
18use serde::{Serialize, de::DeserializeOwned};
19use std::fs::{File, create_dir_all};
20use std::io::{BufReader, BufWriter, Read, Write};
21use std::path::{Path, PathBuf};
22use thiserror::Error;
23use chrono::{DateTime, Utc};
24
/// Errors that can occur during model serialization/deserialization.
///
/// Covers IO errors, format-specific errors, version mismatches, and
/// data integrity issues. Variants whose field is marked `#[from]` are
/// produced automatically when the `?` operator propagates the wrapped
/// error type.
#[derive(Error, Debug)]
pub enum SerializationError {
    /// Underlying filesystem or IO operation failed.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    /// JSON serialization/deserialization failed.
    #[error("JSON serialization error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Bincode encoding failed.
    #[error("Bincode encoding error: {0}")]
    BincodeEncodeError(#[from] bincode::error::EncodeError),

    /// Bincode decoding failed.
    #[error("Bincode decoding error: {0}")]
    BincodeDecodeError(#[from] bincode::error::DecodeError),

    /// Model file is corrupted or has invalid structure.
    ///
    /// The payload is a human-readable description of the problem.
    #[error("Invalid model file: {0}")]
    InvalidModelFile(String),

    /// Model version doesn't match current library version.
    ///
    /// `expected` is the running library's version; `actual` is the version
    /// recorded in the model's metadata.
    #[error("Version mismatch: expected {expected}, got {actual}")]
    VersionMismatch { expected: String, actual: String },

    /// General serialization failure.
    #[error("Serialization failed: {0}")]
    SerializationFailed(String),
}

/// Result type for serialization operations.
///
/// Shorthand for `Result<T, SerializationError>`.
pub type SerializationResult<T> = std::result::Result<T, SerializationError>;
62
/// Metadata describing a serialized model.
///
/// Tracks versioning, creation info, and model characteristics to ensure
/// reproducibility and compatibility. It is written to disk together with
/// the model as a `(model, metadata)` tuple.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct ModelMetadata {
    /// Semantic version of the library that created the model.
    ///
    /// `ModelMetadata::validate_version` compares this string for exact
    /// equality with the running crate's `CARGO_PKG_VERSION`.
    pub version: String,
    /// Timestamp (UTC) at which this metadata was created by `ModelMetadata::new`.
    pub created_at: DateTime<Utc>,
    /// Timestamp (UTC) of the last update; initialised to `created_at` and
    /// refreshed by `ModelMetadata::update`.
    pub updated_at: DateTime<Utc>,
    /// Type of model (e.g., "gradient_booster", "decision_tree").
    pub model_type: String,
    /// Number of features the model expects.
    pub num_features: usize,
    /// Number of trees in the ensemble (0 for single models).
    pub num_trees: usize,
    /// Model hyperparameters and configuration, stored as string key/value pairs.
    pub parameters: std::collections::HashMap<String, String>,
}
84
85impl ModelMetadata {
86    /// Creates new metadata for a model.
87    ///
88    /// Automatically sets version to current crate version and timestamps to now.
89    ///
90    /// # Arguments
91    ///
92    /// * `model_type` - Identifier for the model type
93    /// * `num_features` - Number of input features
94    /// * `num_trees` - Number of trees (0 for single models)
95    pub fn new(model_type: &str, num_features: usize, num_trees: usize) -> Self {
96        let now = Utc::now();
97        Self {
98            version: env!("CARGO_PKG_VERSION").to_string(),
99            created_at: now,
100            updated_at: now,
101            model_type: model_type.to_string(),
102            num_features,
103            num_trees,
104            parameters: std::collections::HashMap::new(),
105        }
106    }
107    
108    /// Adds a hyperparameter to the metadata (builder pattern).
109    ///
110    /// # Arguments
111    ///
112    /// * `key` - Parameter name
113    /// * `value` - Parameter value (converted to string)
114    ///
115    /// # Returns
116    ///
117    /// Self with updated parameters
118    pub fn with_parameter(mut self, key: &str, value: &str) -> Self {
119        self.parameters.insert(key.to_string(), value.to_string());
120        self
121    }
122    
123    /// Updates the `updated_at` timestamp to current time.
124    pub fn update(&mut self) {
125        self.updated_at = Utc::now();
126    }
127   
128    /// Validates that the model version matches the current library version.
129    ///
130    /// # Returns
131    ///
132    /// `Ok(())` if versions match, `Err(SerializationError::VersionMismatch)` otherwise
133    pub fn validate_version(&self) -> SerializationResult<()> {
134        let current_version = env!("CARGO_PKG_VERSION");
135        if self.version != current_version {
136            return Err(SerializationError::VersionMismatch {
137                expected: current_version.to_string(),
138                actual: self.version.clone(),
139            });
140        }
141        Ok(())
142    }
143}
144
145/// Main serializer with configurable format and compression.
146///
147/// Provides a unified interface for saving and loading models with
148/// consistent error handling and metadata management.
149pub struct ModelSerializer {
150    /// Whether to compress the output with gzip.
151    compression: bool,
152    /// Serialization format to use.
153    format: SerializationFormat,
154}
155
/// Supported serialization formats.
///
/// Each format has different tradeoffs between size, speed, and readability.
/// `Eq` and `Hash` are derived so formats can be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SerializationFormat {
    /// JSON format (human-readable, good interoperability).
    Json,
    /// Bincode format (compact binary, fastest serialization).
    Bincode,
}
166
167impl Default for ModelSerializer {
168    fn default() -> Self {
169        Self {
170            compression: true,
171            format: SerializationFormat::Bincode,
172        }
173    }
174}
175
176impl ModelSerializer {
177    /// Creates a new serializer with default settings (Bincode + compression).
178    pub fn new() -> Self {
179        Self::default()
180    }
181    
182    /// Enables or disables gzip compression.
183    ///
184    /// # Arguments
185    ///
186    /// * `compression` - `true` to compress (default), `false` for uncompressed
187    ///
188    /// # Returns
189    ///
190    /// Self with updated compression setting (builder pattern)
191    pub fn with_compression(mut self, compression: bool) -> Self {
192        self.compression = compression;
193        self
194    }
195   
196    /// Sets the serialization format.
197    ///
198    /// # Arguments
199    ///
200    /// * `format` - [`SerializationFormat`] to use
201    ///
202    /// # Returns
203    ///
204    /// Self with updated format (builder pattern)
205    pub fn with_format(mut self, format: SerializationFormat) -> Self {
206        self.format = format;
207        self
208    }
209    
210    /// Saves a model to disk with metadata.
211    ///
212    /// Automatically creates parent directories if they don't exist.
213    ///
214    /// # Arguments
215    ///
216    /// * `model` - Model to serialize (must implement `Serialize`)
217    /// * `metadata` - Model metadata
218    /// * `path` - Destination file path
219    ///
220    /// # Returns
221    ///
222    /// `Ok(())` on success, [`SerializationError`] on failure 
223    pub fn save_model<T: Serialize>(
224        &self,
225        model: &T,
226        metadata: &ModelMetadata,
227        path: &Path,
228    ) -> SerializationResult<()> {
229        save_model(model, metadata, path, self.format, self.compression)
230    }
231    
232    /// Loads a model from disk with metadata.
233    ///
234    /// Automatically detects format and compression from file content.
235    ///
236    /// # Arguments
237    ///
238    /// * `path` - Path to model file
239    ///
240    /// # Returns
241    ///
242    /// Tuple of `(deserialized_model, metadata)`
243    ///
244    /// # Errors
245    ///
246    /// Returns error if file doesn't exist, format is invalid, or version mismatch 
247    pub fn load_model<T: DeserializeOwned>(&self, path: &Path) -> SerializationResult<(T, ModelMetadata)> {
248        load_model(path)
249    }
250    
251    /// Validates a model file without loading the full model.
252    ///
253    /// Useful for checking version compatibility or file integrity.
254    ///
255    /// # Arguments
256    ///
257    /// * `path` - Path to model file
258    ///
259    /// # Returns
260    ///
261    /// Metadata if validation succeeds
262    pub fn validate_model_file(&self, path: &Path) -> SerializationResult<ModelMetadata> {
263        validate_model_file(path)
264    }
265}
266
267
268// Standalone serialization functions
269
270/// Saves model to file with explicit format and compression settings.
271///
272/// This is a lower-level function; prefer [`ModelSerializer::save_model`] for most cases.
273///
274/// # Arguments
275///
276/// * `model` - Model to serialize
277/// * `metadata` - Model metadata
278/// * `path` - Destination file path
279/// * `format` - Serialization format
280/// * `compression` - Whether to compress
281///
282/// # Returns
283///
284/// `Ok(())` on success
285pub fn save_model<T: Serialize>(
286    model: &T,
287    metadata: &ModelMetadata,
288    path: &Path,
289    format: SerializationFormat,
290    compression: bool,
291) -> SerializationResult<()> {
292    // Create directory if it doesn't exist
293    if let Some(parent) = path.parent() {
294        create_dir_all(parent)?;
295    }
296    
297    match format {
298        SerializationFormat::Json => {
299            let file = File::create(path)?;
300            let writer = BufWriter::new(file);
301            
302            if compression {
303                let mut gz_writer = flate2::write::GzEncoder::new(writer, flate2::Compression::default());
304                serde_json::to_writer(&mut gz_writer, &(model, metadata))?;
305                gz_writer.finish()?;
306            } else {
307                serde_json::to_writer(writer, &(model, metadata))?;
308            }
309        }
310        SerializationFormat::Bincode => {
311            let file = File::create(path)?;
312            let mut writer = BufWriter::new(file);
313            
314            // Use bincode's serde integration with default configuration
315            let config = bincode::config::standard();
316            
317            if compression {
318                let mut gz_writer = flate2::write::GzEncoder::new(writer, flate2::Compression::default());
319                bincode::serde::encode_into_std_write(&(model, metadata), &mut gz_writer, config)?;
320                gz_writer.finish()?;
321            } else {
322                bincode::serde::encode_into_std_write(&(model, metadata), &mut writer, config)?;
323            }
324        }
325    }
326    
327    Ok(())
328}
329
330/// Loads model from file (auto-detects format and compression).
331///
332/// Automatically infers format from file extension and compression from magic bytes.
333///
334/// # Arguments
335///
336/// * `path` - Path to model file
337///
338/// # Returns
339///
340/// Tuple of `(model, metadata)`
341///
342/// # Errors
343///
344/// Returns error if file is corrupted, format is unknown, or version mismatch
345pub fn load_model<T: DeserializeOwned>(path: &Path) -> SerializationResult<(T, ModelMetadata)> {
346    let file = File::open(path)?;
347    let mut reader = BufReader::new(file);
348    
349    // Try to detect format and compression
350    let (model, metadata) = if path.extension().map_or(false, |ext| ext == "json") {
351        // JSON format
352        if is_gzip_compressed(path)? {
353            let gz_reader = flate2::read::GzDecoder::new(reader);
354            serde_json::from_reader(gz_reader)?
355        } else {
356            serde_json::from_reader(reader)?
357        }
358    } else if path.extension().map_or(false, |ext| ext == "bin") {
359        // Bincode format
360        let config = bincode::config::standard();
361        
362        if is_gzip_compressed(path)? {
363            let mut gz_reader = flate2::read::GzDecoder::new(reader);
364            bincode::serde::decode_from_std_read(&mut gz_reader, config)?
365        } else {
366            bincode::serde::decode_from_std_read(&mut reader, config)?
367        }
368    } else {
369        // Try to auto-detect
370        if let Ok(result) = load_as_bincode(&mut reader, path) {
371            result
372        } else if let Ok(result) = load_as_json(&mut reader, path) {
373            result
374        } else {
375            return Err(SerializationError::InvalidModelFile(
376                "Could not determine file format".to_string()
377            ));
378        }
379    };
380    
381    // Validate metadata
382    metadata.validate_version()?;
383    
384    Ok((model, metadata))
385}
386
387/// Validates a model file without loading the complete model.
388///
389/// For JSON files, this reads only the metadata portion. For bincode files,
390/// full deserialization may be required.
391///
392/// # Arguments
393///
394/// * `path` - Path to model file
395///
396/// # Returns
397///
398/// Metadata if file is valid
399pub fn validate_model_file(path: &Path) -> SerializationResult<ModelMetadata> {
400    let file = File::open(path)?;
401    let reader = BufReader::new(file);
402    
403    // Try to read just the metadata
404    if path.extension().map_or(false, |ext| ext == "json") {
405        if is_gzip_compressed(path)? {
406            let gz_reader = flate2::read::GzDecoder::new(reader);
407            let value: serde_json::Value = serde_json::from_reader(gz_reader)?;
408            extract_metadata(&value)
409        } else {
410            let value: serde_json::Value = serde_json::from_reader(reader)?;
411            extract_metadata(&value)
412        }
413    } else {
414        // For bincode, we need a different approach
415        load_model::<serde::de::IgnoredAny>(path)?;
416        Err(SerializationError::SerializationFailed(
417            "Metadata validation for bincode requires full deserialization".to_string()
418        ))
419    }
420}
421
422/// Serializes a value to a pretty-printed JSON string.
423///
424/// Useful for debugging or human-readable output.
425///
426/// # Arguments
427///
428/// * `value` - Value to serialize
429///
430/// # Returns
431///
432/// JSON string
433pub fn serialize_to_json<T: Serialize>(value: &T) -> SerializationResult<String> {
434    Ok(serde_json::to_string_pretty(value)?)
435}
436
437/// Deserializes from a JSON string.
438///
439/// # Arguments
440///
441/// * `json` - JSON string
442///
443/// # Returns
444///
445/// Deserialized value
446pub fn deserialize_from_json<T: DeserializeOwned>(json: &str) -> SerializationResult<T> {
447    Ok(serde_json::from_str(json)?)
448}
449
450/// Serializes a value to bincode bytes.
451///
452/// # Arguments
453///
454/// * `value` - Value to serialize
455///
456/// # Returns
457///
458/// Binary representation
459pub fn serialize_to_bincode<T: Serialize>(value: &T) -> SerializationResult<Vec<u8>> {
460    let config = bincode::config::standard();
461    let mut bytes = Vec::new();
462    bincode::serde::encode_into_std_write(value, &mut bytes, config)?;
463    Ok(bytes)
464}
465
466/// Deserializes from bincode bytes.
467///
468/// # Arguments
469///
470/// * `bytes` - Binary data
471///
472/// # Returns
473///
474/// Deserialized value
475pub fn deserialize_from_bincode<T: DeserializeOwned>(bytes: &[u8]) -> SerializationResult<T> {
476    let config = bincode::config::standard();
477    let mut cursor = std::io::Cursor::new(bytes);
478    Ok(bincode::serde::decode_from_std_read(&mut cursor, config)?)
479}
480
481// Helper functions
482
483/// Checks if a file is gzip compressed by reading magic bytes.
484///
485/// # Arguments
486///
487/// * `path` - File path
488///
489/// # Returns
490///
491/// `true` if file starts with gzip magic number
492fn is_gzip_compressed(path: &Path) -> SerializationResult<bool> {
493    let file = File::open(path)?;
494    let mut reader = BufReader::new(file);
495    let mut buffer = [0; 2];
496    
497    reader.read_exact(&mut buffer)?;
498    
499    // Gzip magic number: 0x1f 0x8b
500    Ok(buffer[0] == 0x1f && buffer[1] == 0x8b)
501}
502
503/// Loads model assuming bincode format.
504fn load_as_bincode<T: DeserializeOwned>(
505    reader: &mut BufReader<File>,
506    path: &Path,
507) -> SerializationResult<(T, ModelMetadata)> {
508    let config = bincode::config::standard();
509    
510    if is_gzip_compressed(path)? {
511        let mut gz_reader = flate2::read::GzDecoder::new(reader);
512        Ok(bincode::serde::decode_from_std_read(&mut gz_reader, config)?)
513    } else {
514        Ok(bincode::serde::decode_from_std_read(reader, config)?)
515    }
516}
517
518/// Loads model assuming JSON format.
519fn load_as_json<T: DeserializeOwned>(
520    reader: &mut BufReader<File>,
521    path: &Path,
522) -> SerializationResult<(T, ModelMetadata)> {
523    if is_gzip_compressed(path)? {
524        let gz_reader = flate2::read::GzDecoder::new(reader);
525        Ok(serde_json::from_reader(gz_reader)?)
526    } else {
527        Ok(serde_json::from_reader(reader)?)
528    }
529}
530
531/// Extracts metadata from a JSON value.
532fn extract_metadata(value: &serde_json::Value) -> SerializationResult<ModelMetadata> {
533    if let Some(metadata_value) = value.get("1") {
534        Ok(serde_json::from_value(metadata_value.clone())?)
535    } else {
536        Err(SerializationError::InvalidModelFile(
537            "No metadata found in JSON file".to_string()
538        ))
539    }
540}
541
542/// Utility functions for model file management.
543///
544/// Provides helpers for working with model files on disk, including
545/// path construction, file listing, and format detection.
546pub struct ModelFileUtils;
547
548impl ModelFileUtils {
549    /// Gets the appropriate file extension for a serialization format.
550    ///
551    /// # Arguments
552    ///
553    /// * `format` - Serialization format
554    /// * `compressed` - Whether compression is enabled
555    ///
556    /// # Returns
557    ///
558    /// File extension string (e.g., "json.gz", "bin") 
559    pub fn get_extension(format: SerializationFormat, compressed: bool) -> String {
560        let base_ext = match format {
561            SerializationFormat::Json => "json",
562            SerializationFormat::Bincode => "bin",
563        };
564        
565        if compressed {
566            format!("{}.gz", base_ext)
567        } else {
568            base_ext.to_string()
569        }
570    }
571    
572    /// Constructs a full model file path with appropriate extension.
573    ///
574    /// # Arguments
575    ///
576    /// * `base_path` - Directory for model files
577    /// * `model_name` - Base name of the model
578    /// * `format` - Serialization format
579    /// * `compressed` - Whether to compress
580    ///
581    /// # Returns
582    ///
583    /// Complete file path 
584    pub fn create_model_path(
585        base_path: &Path,
586        model_name: &str,
587        format: SerializationFormat,
588        compressed: bool,
589    ) -> PathBuf {
590        let ext = Self::get_extension(format, compressed);
591        base_path.join(format!("{}.{}", model_name, ext))
592    }
593    
594    /// Lists all model files in a directory.
595    ///
596    /// Scans the directory for files with recognized extensions (.json, .bin, .gz).
597    ///
598    /// # Arguments
599    ///
600    /// * `dir` - Directory to scan
601    ///
602    /// # Returns
603    ///
604    /// Vector of model file paths 
605    pub fn list_model_files(dir: &Path) -> SerializationResult<Vec<PathBuf>> {
606        let mut model_files = Vec::new();
607        
608        for entry in std::fs::read_dir(dir)? {
609            let entry = entry?;
610            let path = entry.path();
611            
612            if path.is_file() {
613                if let Some(ext) = path.extension() {
614                    if ext == "json" || ext == "bin" || ext == "gz" {
615                        model_files.push(path);
616                    }
617                }
618            }
619        }
620        
621        Ok(model_files)
622    }
623    
624    /// Checks if a file is a valid model file based on extension.
625    ///
626    /// # Arguments
627    ///
628    /// * `path` - File path to check
629    ///
630    /// # Returns
631    ///
632    /// `true` if file has a recognized model extension
633    pub fn is_model_file(path: &Path) -> bool {
634        path.is_file() && path.extension().map_or(false, |ext| {
635            ext == "json" || ext == "bin" || ext == "gz"
636        })
637    }
638}
639