gbrt_rs/utils/serialization.rs
1#![allow(dead_code)]
2
3//! Model serialization and persistence for gradient boosting models.
4//!
5//! This module provides robust utilities for saving and loading trained models
6//! with metadata, versioning, and compression support. It supports multiple
7//! serialization formats and ensures model compatibility across different
8//! versions of the library.
9//!
10//! # Features
11//!
12//! - **Multiple formats**: JSON (human-readable) and Bincode (compact binary)
13//! - **Compression**: Optional gzip compression to reduce file size
14//! - **Versioning**: Automatic model version tracking and validation
15//! - **Metadata**: Rich model information (timestamps, hyperparameters, etc.)
16//! - **Validation**: Verify model files without full deserialization
17
18use serde::{Serialize, de::DeserializeOwned};
19use std::fs::{File, create_dir_all};
20use std::io::{BufReader, BufWriter, Read, Write};
21use std::path::{Path, PathBuf};
22use thiserror::Error;
23use chrono::{DateTime, Utc};
24
25/// Errors that can occur during model serialization/deserialization.
26///
27/// Covers IO errors, format-specific errors, version mismatches, and
28/// data integrity issues.
29#[derive(Error, Debug)]
30pub enum SerializationError {
31 /// Underlying filesystem or IO operation failed.
32 #[error("IO error: {0}")]
33 IoError(#[from] std::io::Error),
34
35 /// JSON serialization/deserialization failed.
36 #[error("JSON serialization error: {0}")]
37 JsonError(#[from] serde_json::Error),
38
39 /// Bincode encoding failed.
40 #[error("Bincode encoding error: {0}")]
41 BincodeEncodeError(#[from] bincode::error::EncodeError),
42
43 /// Bincode decoding failed.
44 #[error("Bincode decoding error: {0}")]
45 BincodeDecodeError(#[from] bincode::error::DecodeError),
46
47 /// Model file is corrupted or has invalid structure.
48 #[error("Invalid model file: {0}")]
49 InvalidModelFile(String),
50
51 /// Model version doesn't match current library version.
52 #[error("Version mismatch: expected {expected}, got {actual}")]
53 VersionMismatch { expected: String, actual: String },
54
55 /// General serialization failure.
56 #[error("Serialization failed: {0}")]
57 SerializationFailed(String),
58}
59
60/// Result type for serialization operations.
61pub type SerializationResult<T> = std::result::Result<T, SerializationError>;
62
63/// Metadata describing a serialized model.
64///
65/// Tracks versioning, creation info, and model characteristics to ensure
66/// reproducibility and compatibility.
67#[derive(Debug, Clone, Serialize, serde::Deserialize)]
68pub struct ModelMetadata {
69 /// Semantic version of the library that created the model.
70 pub version: String,
71 /// Timestamp when the model was first saved.
72 pub created_at: DateTime<Utc>,
73 /// Timestamp when the model was last updated.
74 pub updated_at: DateTime<Utc>,
75 /// Type of model (e.g., "gradient_booster", "decision_tree").
76 pub model_type: String,
77 /// Number of features the model expects.
78 pub num_features: usize,
79 /// Number of trees in the ensemble (if applicable).
80 pub num_trees: usize,
81 /// Model hyperparameters and configuration.
82 pub parameters: std::collections::HashMap<String, String>,
83}
84
85impl ModelMetadata {
86 /// Creates new metadata for a model.
87 ///
88 /// Automatically sets version to current crate version and timestamps to now.
89 ///
90 /// # Arguments
91 ///
92 /// * `model_type` - Identifier for the model type
93 /// * `num_features` - Number of input features
94 /// * `num_trees` - Number of trees (0 for single models)
95 pub fn new(model_type: &str, num_features: usize, num_trees: usize) -> Self {
96 let now = Utc::now();
97 Self {
98 version: env!("CARGO_PKG_VERSION").to_string(),
99 created_at: now,
100 updated_at: now,
101 model_type: model_type.to_string(),
102 num_features,
103 num_trees,
104 parameters: std::collections::HashMap::new(),
105 }
106 }
107
108 /// Adds a hyperparameter to the metadata (builder pattern).
109 ///
110 /// # Arguments
111 ///
112 /// * `key` - Parameter name
113 /// * `value` - Parameter value (converted to string)
114 ///
115 /// # Returns
116 ///
117 /// Self with updated parameters
118 pub fn with_parameter(mut self, key: &str, value: &str) -> Self {
119 self.parameters.insert(key.to_string(), value.to_string());
120 self
121 }
122
123 /// Updates the `updated_at` timestamp to current time.
124 pub fn update(&mut self) {
125 self.updated_at = Utc::now();
126 }
127
128 /// Validates that the model version matches the current library version.
129 ///
130 /// # Returns
131 ///
132 /// `Ok(())` if versions match, `Err(SerializationError::VersionMismatch)` otherwise
133 pub fn validate_version(&self) -> SerializationResult<()> {
134 let current_version = env!("CARGO_PKG_VERSION");
135 if self.version != current_version {
136 return Err(SerializationError::VersionMismatch {
137 expected: current_version.to_string(),
138 actual: self.version.clone(),
139 });
140 }
141 Ok(())
142 }
143}
144
/// Main serializer with configurable format and compression.
///
/// Provides a unified interface for saving and loading models with
/// consistent error handling and metadata management. The default is
/// bincode with gzip compression.
#[derive(Debug, Clone)]
pub struct ModelSerializer {
    /// Whether to compress the output with gzip.
    compression: bool,
    /// Serialization format to use when saving.
    format: SerializationFormat,
}

/// Supported serialization formats.
///
/// Each format has different tradeoffs between size, speed, and readability.
/// The enum is a plain fieldless `Copy` type, so it also derives `Eq` and
/// `Hash` and can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SerializationFormat {
    /// JSON format (human-readable, good interoperability).
    Json,
    /// Bincode format (compact binary encoding).
    Bincode,
}

impl Default for ModelSerializer {
    /// Bincode output with gzip compression enabled.
    fn default() -> Self {
        Self {
            compression: true,
            format: SerializationFormat::Bincode,
        }
    }
}
175
176impl ModelSerializer {
177 /// Creates a new serializer with default settings (Bincode + compression).
178 pub fn new() -> Self {
179 Self::default()
180 }
181
182 /// Enables or disables gzip compression.
183 ///
184 /// # Arguments
185 ///
186 /// * `compression` - `true` to compress (default), `false` for uncompressed
187 ///
188 /// # Returns
189 ///
190 /// Self with updated compression setting (builder pattern)
191 pub fn with_compression(mut self, compression: bool) -> Self {
192 self.compression = compression;
193 self
194 }
195
196 /// Sets the serialization format.
197 ///
198 /// # Arguments
199 ///
200 /// * `format` - [`SerializationFormat`] to use
201 ///
202 /// # Returns
203 ///
204 /// Self with updated format (builder pattern)
205 pub fn with_format(mut self, format: SerializationFormat) -> Self {
206 self.format = format;
207 self
208 }
209
210 /// Saves a model to disk with metadata.
211 ///
212 /// Automatically creates parent directories if they don't exist.
213 ///
214 /// # Arguments
215 ///
216 /// * `model` - Model to serialize (must implement `Serialize`)
217 /// * `metadata` - Model metadata
218 /// * `path` - Destination file path
219 ///
220 /// # Returns
221 ///
222 /// `Ok(())` on success, [`SerializationError`] on failure
223 pub fn save_model<T: Serialize>(
224 &self,
225 model: &T,
226 metadata: &ModelMetadata,
227 path: &Path,
228 ) -> SerializationResult<()> {
229 save_model(model, metadata, path, self.format, self.compression)
230 }
231
232 /// Loads a model from disk with metadata.
233 ///
234 /// Automatically detects format and compression from file content.
235 ///
236 /// # Arguments
237 ///
238 /// * `path` - Path to model file
239 ///
240 /// # Returns
241 ///
242 /// Tuple of `(deserialized_model, metadata)`
243 ///
244 /// # Errors
245 ///
246 /// Returns error if file doesn't exist, format is invalid, or version mismatch
247 pub fn load_model<T: DeserializeOwned>(&self, path: &Path) -> SerializationResult<(T, ModelMetadata)> {
248 load_model(path)
249 }
250
251 /// Validates a model file without loading the full model.
252 ///
253 /// Useful for checking version compatibility or file integrity.
254 ///
255 /// # Arguments
256 ///
257 /// * `path` - Path to model file
258 ///
259 /// # Returns
260 ///
261 /// Metadata if validation succeeds
262 pub fn validate_model_file(&self, path: &Path) -> SerializationResult<ModelMetadata> {
263 validate_model_file(path)
264 }
265}
266
267
268// Standalone serialization functions
269
270/// Saves model to file with explicit format and compression settings.
271///
272/// This is a lower-level function; prefer [`ModelSerializer::save_model`] for most cases.
273///
274/// # Arguments
275///
276/// * `model` - Model to serialize
277/// * `metadata` - Model metadata
278/// * `path` - Destination file path
279/// * `format` - Serialization format
280/// * `compression` - Whether to compress
281///
282/// # Returns
283///
284/// `Ok(())` on success
285pub fn save_model<T: Serialize>(
286 model: &T,
287 metadata: &ModelMetadata,
288 path: &Path,
289 format: SerializationFormat,
290 compression: bool,
291) -> SerializationResult<()> {
292 // Create directory if it doesn't exist
293 if let Some(parent) = path.parent() {
294 create_dir_all(parent)?;
295 }
296
297 match format {
298 SerializationFormat::Json => {
299 let file = File::create(path)?;
300 let writer = BufWriter::new(file);
301
302 if compression {
303 let mut gz_writer = flate2::write::GzEncoder::new(writer, flate2::Compression::default());
304 serde_json::to_writer(&mut gz_writer, &(model, metadata))?;
305 gz_writer.finish()?;
306 } else {
307 serde_json::to_writer(writer, &(model, metadata))?;
308 }
309 }
310 SerializationFormat::Bincode => {
311 let file = File::create(path)?;
312 let mut writer = BufWriter::new(file);
313
314 // Use bincode's serde integration with default configuration
315 let config = bincode::config::standard();
316
317 if compression {
318 let mut gz_writer = flate2::write::GzEncoder::new(writer, flate2::Compression::default());
319 bincode::serde::encode_into_std_write(&(model, metadata), &mut gz_writer, config)?;
320 gz_writer.finish()?;
321 } else {
322 bincode::serde::encode_into_std_write(&(model, metadata), &mut writer, config)?;
323 }
324 }
325 }
326
327 Ok(())
328}
329
330/// Loads model from file (auto-detects format and compression).
331///
332/// Automatically infers format from file extension and compression from magic bytes.
333///
334/// # Arguments
335///
336/// * `path` - Path to model file
337///
338/// # Returns
339///
340/// Tuple of `(model, metadata)`
341///
342/// # Errors
343///
344/// Returns error if file is corrupted, format is unknown, or version mismatch
345pub fn load_model<T: DeserializeOwned>(path: &Path) -> SerializationResult<(T, ModelMetadata)> {
346 let file = File::open(path)?;
347 let mut reader = BufReader::new(file);
348
349 // Try to detect format and compression
350 let (model, metadata) = if path.extension().map_or(false, |ext| ext == "json") {
351 // JSON format
352 if is_gzip_compressed(path)? {
353 let gz_reader = flate2::read::GzDecoder::new(reader);
354 serde_json::from_reader(gz_reader)?
355 } else {
356 serde_json::from_reader(reader)?
357 }
358 } else if path.extension().map_or(false, |ext| ext == "bin") {
359 // Bincode format
360 let config = bincode::config::standard();
361
362 if is_gzip_compressed(path)? {
363 let mut gz_reader = flate2::read::GzDecoder::new(reader);
364 bincode::serde::decode_from_std_read(&mut gz_reader, config)?
365 } else {
366 bincode::serde::decode_from_std_read(&mut reader, config)?
367 }
368 } else {
369 // Try to auto-detect
370 if let Ok(result) = load_as_bincode(&mut reader, path) {
371 result
372 } else if let Ok(result) = load_as_json(&mut reader, path) {
373 result
374 } else {
375 return Err(SerializationError::InvalidModelFile(
376 "Could not determine file format".to_string()
377 ));
378 }
379 };
380
381 // Validate metadata
382 metadata.validate_version()?;
383
384 Ok((model, metadata))
385}
386
387/// Validates a model file without loading the complete model.
388///
389/// For JSON files, this reads only the metadata portion. For bincode files,
390/// full deserialization may be required.
391///
392/// # Arguments
393///
394/// * `path` - Path to model file
395///
396/// # Returns
397///
398/// Metadata if file is valid
399pub fn validate_model_file(path: &Path) -> SerializationResult<ModelMetadata> {
400 let file = File::open(path)?;
401 let reader = BufReader::new(file);
402
403 // Try to read just the metadata
404 if path.extension().map_or(false, |ext| ext == "json") {
405 if is_gzip_compressed(path)? {
406 let gz_reader = flate2::read::GzDecoder::new(reader);
407 let value: serde_json::Value = serde_json::from_reader(gz_reader)?;
408 extract_metadata(&value)
409 } else {
410 let value: serde_json::Value = serde_json::from_reader(reader)?;
411 extract_metadata(&value)
412 }
413 } else {
414 // For bincode, we need a different approach
415 load_model::<serde::de::IgnoredAny>(path)?;
416 Err(SerializationError::SerializationFailed(
417 "Metadata validation for bincode requires full deserialization".to_string()
418 ))
419 }
420}
421
422/// Serializes a value to a pretty-printed JSON string.
423///
424/// Useful for debugging or human-readable output.
425///
426/// # Arguments
427///
428/// * `value` - Value to serialize
429///
430/// # Returns
431///
432/// JSON string
433pub fn serialize_to_json<T: Serialize>(value: &T) -> SerializationResult<String> {
434 Ok(serde_json::to_string_pretty(value)?)
435}
436
437/// Deserializes from a JSON string.
438///
439/// # Arguments
440///
441/// * `json` - JSON string
442///
443/// # Returns
444///
445/// Deserialized value
446pub fn deserialize_from_json<T: DeserializeOwned>(json: &str) -> SerializationResult<T> {
447 Ok(serde_json::from_str(json)?)
448}
449
450/// Serializes a value to bincode bytes.
451///
452/// # Arguments
453///
454/// * `value` - Value to serialize
455///
456/// # Returns
457///
458/// Binary representation
459pub fn serialize_to_bincode<T: Serialize>(value: &T) -> SerializationResult<Vec<u8>> {
460 let config = bincode::config::standard();
461 let mut bytes = Vec::new();
462 bincode::serde::encode_into_std_write(value, &mut bytes, config)?;
463 Ok(bytes)
464}
465
466/// Deserializes from bincode bytes.
467///
468/// # Arguments
469///
470/// * `bytes` - Binary data
471///
472/// # Returns
473///
474/// Deserialized value
475pub fn deserialize_from_bincode<T: DeserializeOwned>(bytes: &[u8]) -> SerializationResult<T> {
476 let config = bincode::config::standard();
477 let mut cursor = std::io::Cursor::new(bytes);
478 Ok(bincode::serde::decode_from_std_read(&mut cursor, config)?)
479}
480
481// Helper functions
482
483/// Checks if a file is gzip compressed by reading magic bytes.
484///
485/// # Arguments
486///
487/// * `path` - File path
488///
489/// # Returns
490///
491/// `true` if file starts with gzip magic number
492fn is_gzip_compressed(path: &Path) -> SerializationResult<bool> {
493 let file = File::open(path)?;
494 let mut reader = BufReader::new(file);
495 let mut buffer = [0; 2];
496
497 reader.read_exact(&mut buffer)?;
498
499 // Gzip magic number: 0x1f 0x8b
500 Ok(buffer[0] == 0x1f && buffer[1] == 0x8b)
501}
502
503/// Loads model assuming bincode format.
504fn load_as_bincode<T: DeserializeOwned>(
505 reader: &mut BufReader<File>,
506 path: &Path,
507) -> SerializationResult<(T, ModelMetadata)> {
508 let config = bincode::config::standard();
509
510 if is_gzip_compressed(path)? {
511 let mut gz_reader = flate2::read::GzDecoder::new(reader);
512 Ok(bincode::serde::decode_from_std_read(&mut gz_reader, config)?)
513 } else {
514 Ok(bincode::serde::decode_from_std_read(reader, config)?)
515 }
516}
517
518/// Loads model assuming JSON format.
519fn load_as_json<T: DeserializeOwned>(
520 reader: &mut BufReader<File>,
521 path: &Path,
522) -> SerializationResult<(T, ModelMetadata)> {
523 if is_gzip_compressed(path)? {
524 let gz_reader = flate2::read::GzDecoder::new(reader);
525 Ok(serde_json::from_reader(gz_reader)?)
526 } else {
527 Ok(serde_json::from_reader(reader)?)
528 }
529}
530
531/// Extracts metadata from a JSON value.
532fn extract_metadata(value: &serde_json::Value) -> SerializationResult<ModelMetadata> {
533 if let Some(metadata_value) = value.get("1") {
534 Ok(serde_json::from_value(metadata_value.clone())?)
535 } else {
536 Err(SerializationError::InvalidModelFile(
537 "No metadata found in JSON file".to_string()
538 ))
539 }
540}
541
542/// Utility functions for model file management.
543///
544/// Provides helpers for working with model files on disk, including
545/// path construction, file listing, and format detection.
546pub struct ModelFileUtils;
547
548impl ModelFileUtils {
549 /// Gets the appropriate file extension for a serialization format.
550 ///
551 /// # Arguments
552 ///
553 /// * `format` - Serialization format
554 /// * `compressed` - Whether compression is enabled
555 ///
556 /// # Returns
557 ///
558 /// File extension string (e.g., "json.gz", "bin")
559 pub fn get_extension(format: SerializationFormat, compressed: bool) -> String {
560 let base_ext = match format {
561 SerializationFormat::Json => "json",
562 SerializationFormat::Bincode => "bin",
563 };
564
565 if compressed {
566 format!("{}.gz", base_ext)
567 } else {
568 base_ext.to_string()
569 }
570 }
571
572 /// Constructs a full model file path with appropriate extension.
573 ///
574 /// # Arguments
575 ///
576 /// * `base_path` - Directory for model files
577 /// * `model_name` - Base name of the model
578 /// * `format` - Serialization format
579 /// * `compressed` - Whether to compress
580 ///
581 /// # Returns
582 ///
583 /// Complete file path
584 pub fn create_model_path(
585 base_path: &Path,
586 model_name: &str,
587 format: SerializationFormat,
588 compressed: bool,
589 ) -> PathBuf {
590 let ext = Self::get_extension(format, compressed);
591 base_path.join(format!("{}.{}", model_name, ext))
592 }
593
594 /// Lists all model files in a directory.
595 ///
596 /// Scans the directory for files with recognized extensions (.json, .bin, .gz).
597 ///
598 /// # Arguments
599 ///
600 /// * `dir` - Directory to scan
601 ///
602 /// # Returns
603 ///
604 /// Vector of model file paths
605 pub fn list_model_files(dir: &Path) -> SerializationResult<Vec<PathBuf>> {
606 let mut model_files = Vec::new();
607
608 for entry in std::fs::read_dir(dir)? {
609 let entry = entry?;
610 let path = entry.path();
611
612 if path.is_file() {
613 if let Some(ext) = path.extension() {
614 if ext == "json" || ext == "bin" || ext == "gz" {
615 model_files.push(path);
616 }
617 }
618 }
619 }
620
621 Ok(model_files)
622 }
623
624 /// Checks if a file is a valid model file based on extension.
625 ///
626 /// # Arguments
627 ///
628 /// * `path` - File path to check
629 ///
630 /// # Returns
631 ///
632 /// `true` if file has a recognized model extension
633 pub fn is_model_file(path: &Path) -> bool {
634 path.is_file() && path.extension().map_or(false, |ext| {
635 ext == "json" || ext == "bin" || ext == "gz"
636 })
637 }
638}
639