Skip to main content

linreg_core/serialization/
mod.rs

//! Model serialization module for saving and loading regression models.
//!
//! This module provides a unified serialization framework that works across:
//! - Native Rust (direct file I/O)
//! - Python (PyO3 bindings)
//! - WASM (JSON string serialization)
//!
//! # Format Version
//!
//! The current serialization format version is `1.0` (see [`FORMAT_VERSION`]). This version:
//! - Wraps model data in metadata (format version, library version, model type, timestamp)
//! - Uses JSON for cross-platform compatibility
//! - Supports forward compatibility (unknown fields are ignored)
//!
//! # Module Structure
//!
//! - [`types`] — `ModelType` enum, `ModelMetadata`, `SerializedModel`
//! - [`traits`] — `ModelSave` and `ModelLoad` trait definitions
//! - [`json`] — File I/O and version validation
//! - `impl_serialization!` — macro that implements `ModelSave` and `ModelLoad` for a model type

21pub mod types;
22pub mod traits;
23pub mod json;
24
/// Version identifier of the on-disk / on-wire serialization format.
///
/// Compatibility policy: a change in the major component breaks existing
/// files, while a change in the minor component only adds to the format.
pub const FORMAT_VERSION: &str = "1.0";
29
30// Re-export core types for convenience
31pub use types::{ModelMetadata, ModelType, SerializedModel};
32pub use traits::{ModelLoad, ModelSave};
33
34/// Macro to generate ModelSave and ModelLoad implementations for a model type.
35///
36/// This macro eliminates the repetitive boilerplate of implementing the
37/// serialization traits. Each model type only needs to specify:
38/// - The type name
39/// - The ModelType variant
40/// - The type name string (for error messages)
41///
42/// # Example
43///
44/// ```ignore
45/// impl_serialization!(MyModel, ModelType::MyModel, "MyModel");
46/// ```
47///
48/// This expands to full implementations of both `ModelSave` and `ModelLoad`.
49#[macro_export]
50macro_rules! impl_serialization {
51    ($type_name:ty, $model_type:expr, $type_str:expr) => {
52        impl $crate::serialization::ModelSave for $type_name {
53            fn save_with_name(&self, path: &str, name: Option<String>) -> $crate::error::Result<()> {
54                use $crate::serialization::{ModelMetadata, SerializedModel};
55                use $crate::error::Error;
56
57                // Convert model to JSON value
58                let data = serde_json::to_value(self).map_err(|e| {
59                    Error::SerializationError(format!("Failed to serialize {}: {}", $type_str, e))
60                })?;
61
62                // Create metadata
63                let mut metadata = ModelMetadata::new($model_type, env!("CARGO_PKG_VERSION").to_string());
64                if let Some(n) = name {
65                    metadata = metadata.with_name(n);
66                }
67
68                // Create serialized model and save to file
69                let model = SerializedModel::new(metadata, data);
70                $crate::serialization::json::save_to_file(&model, path)
71            }
72
73            fn model_type() -> $crate::serialization::ModelType {
74                $model_type
75            }
76        }
77
78        impl $crate::serialization::ModelLoad for $type_name {
79            fn load(path: &str) -> $crate::error::Result<Self> {
80                let model = $crate::serialization::json::load_from_file(path)?;
81
82                // Validate model type
83                if model.metadata.model_type != $model_type {
84                    return Err($crate::error::Error::ModelTypeMismatch {
85                        expected: $type_str.to_string(),
86                        found: model.metadata.model_type.to_string(),
87                    });
88                }
89
90                Self::from_serialized(model)
91            }
92
93            fn from_serialized(model: $crate::serialization::SerializedModel) -> $crate::error::Result<Self> {
94                use $crate::error::Error;
95                serde_json::from_value(model.data).map_err(|e| {
96                    Error::DeserializationError(format!("Failed to deserialize {}: {}", $type_str, e))
97                })
98            }
99
100            fn model_type() -> $crate::serialization::ModelType {
101                $model_type
102            }
103        }
104    };
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal serde-compatible model used to exercise `impl_serialization!`.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestModel {
        value: f64,
    }

    impl_serialization!(TestModel, ModelType::OLS, "TestModel");

    /// Builds a `SerializedModel` tagged as OLS around an arbitrary JSON payload.
    fn wrap(data: serde_json::Value) -> SerializedModel {
        SerializedModel::new(ModelMetadata::new(ModelType::OLS, "test".to_string()), data)
    }

    #[test]
    fn test_macro_generates_save() {
        assert_eq!(<TestModel as ModelSave>::model_type(), ModelType::OLS);
    }

    #[test]
    fn test_macro_generates_load() {
        assert_eq!(<TestModel as ModelLoad>::model_type(), ModelType::OLS);
    }

    #[test]
    fn test_from_serialized_round_trips_data() {
        let original = TestModel { value: 42.0 };
        let data = serde_json::to_value(&original).expect("TestModel always serializes");

        // `.ok()` avoids requiring `Debug` on the crate's error type.
        let loaded = TestModel::from_serialized(wrap(data)).ok();
        assert_eq!(loaded, Some(original));
    }

    #[test]
    fn test_from_serialized_rejects_malformed_data() {
        // A payload missing the required `value` field must not deserialize.
        let result = TestModel::from_serialized(wrap(serde_json::json!({ "unexpected": true })));
        assert!(matches!(
            result,
            Err(crate::error::Error::DeserializationError(_))
        ));
    }
}