Skip to main content

linreg_core/serialization/
traits.rs

1//! Trait definitions for model serialization.
2//!
3//! These traits provide a unified API for saving and loading models
4//! across all regression types.
5
6use crate::error::Error;
7use crate::serialization::types::{ModelType, SerializedModel};
8use serde::Serialize;
9
10/// Trait for saving models to disk.
11///
12/// This trait is implemented by all regression result types that support
13/// serialization. Models are saved as JSON with a metadata wrapper.
14///
15/// # Example
16///
17/// ```ignore
18/// # use linreg_core::core::ols_regression;
19/// # use linreg_core::serialization::ModelSave;
20/// let y = vec![2.0, 4.0, 6.0, 8.0];
21/// let x1 = vec![1.0, 2.0, 3.0, 4.0];
22/// let names = vec!["Intercept".into(), "X1".into()];
23///
24/// let model = ols_regression(&y, &[x1], &names).unwrap();
25/// model.save("my_model.json").unwrap();
26/// ```
27pub trait ModelSave: Serialize {
28    /// Save the model to a file.
29    ///
30    /// The file will contain JSON with metadata (format version, model type,
31    /// timestamp) and the model data.
32    ///
33    /// # Arguments
34    ///
35    /// * `path` - File path to save to (will be created/overwritten)
36    ///
37    /// # Returns
38    ///
39    /// Returns `Ok(())` on success, or an `Error` if serialization or file I/O fails.
40    fn save(&self, path: &str) -> Result<(), Error> {
41        self.save_with_name(path, None)
42    }
43
44    /// Save the model to a file with a custom name.
45    ///
46    /// The name is stored in the model metadata and can be used to identify
47    /// the model later.
48    ///
49    /// # Arguments
50    ///
51    /// * `path` - File path to save to
52    /// * `name` - Optional custom name for the model
53    fn save_with_name(&self, path: &str, name: Option<String>) -> Result<(), Error>;
54
55    /// Get the model type identifier.
56    ///
57    /// This is used when serializing to store the model type in metadata.
58    fn model_type() -> ModelType;
59}
60
61/// Trait for loading models from disk.
62///
63/// This trait is implemented by all regression result types that support
64/// deserialization. Loading validates the format version and model type.
65///
66/// # Example
67///
68/// ```ignore
69/// # use linreg_core::core::RegressionOutput;
70/// # use linreg_core::serialization::ModelLoad;
71/// let model: RegressionOutput = RegressionOutput::load("my_model.json").unwrap();
72/// println!("R²: {}", model.r_squared);
73/// ```
74pub trait ModelLoad: Sized {
75    /// Load a model from a file.
76    ///
77    /// This validates that:
78    /// - The file exists and contains valid JSON
79    /// - The format version is compatible
80    /// - The model type matches the expected type
81    ///
82    /// # Arguments
83    ///
84    /// * `path` - File path to load from
85    ///
86    /// # Returns
87    ///
88    /// Returns the deserialized model on success, or an `Error` if loading fails.
89    fn load(path: &str) -> Result<Self, Error>;
90
91    /// Load a model from an already-deserialized wrapper.
92    ///
93    /// This is useful when you have a `SerializedModel` and want to convert
94    /// it to a specific model type.
95    ///
96    /// # Arguments
97    ///
98    /// * `model` - The serialized model wrapper
99    ///
100    /// # Returns
101    ///
102    /// Returns the deserialized model on success, or an `Error` if conversion fails.
103    fn from_serialized(model: SerializedModel) -> Result<Self, Error>;
104
105    /// Get the model type identifier.
106    ///
107    /// This is used to validate that the loaded file contains the correct model type.
108    fn model_type() -> ModelType;
109}
110
#[cfg(test)]
mod tests {
    use crate::serialization::ModelType;

    // This file only declares the save/load trait interface; the concrete
    // impls (and their round-trip tests) live next to each model type.

    /// Each `ModelType` variant must render to its canonical label.
    #[test]
    fn test_model_type_display() {
        let cases = [
            (ModelType::OLS, "OLS"),
            (ModelType::Ridge, "Ridge"),
            (ModelType::Lasso, "Lasso"),
            (ModelType::ElasticNet, "ElasticNet"),
            (ModelType::WLS, "WLS"),
            (ModelType::LOESS, "LOESS"),
        ];

        for (variant, label) in cases {
            assert_eq!(variant.to_string(), label);
        }
    }
}