linreg_core/serialization/traits.rs
1//! Trait definitions for model serialization.
2//!
3//! These traits provide a unified API for saving and loading models
4//! across all regression types.
5
6use crate::error::Error;
7use crate::serialization::types::{ModelType, SerializedModel};
8use serde::Serialize;
9
10/// Trait for saving models to disk.
11///
12/// This trait is implemented by all regression result types that support
13/// serialization. Models are saved as JSON with a metadata wrapper.
14///
15/// # Example
16///
17/// ```ignore
18/// # use linreg_core::core::ols_regression;
19/// # use linreg_core::serialization::ModelSave;
20/// let y = vec![2.0, 4.0, 6.0, 8.0];
21/// let x1 = vec![1.0, 2.0, 3.0, 4.0];
22/// let names = vec!["Intercept".into(), "X1".into()];
23///
24/// let model = ols_regression(&y, &[x1], &names).unwrap();
25/// model.save("my_model.json").unwrap();
26/// ```
27pub trait ModelSave: Serialize {
28 /// Save the model to a file.
29 ///
30 /// The file will contain JSON with metadata (format version, model type,
31 /// timestamp) and the model data.
32 ///
33 /// # Arguments
34 ///
35 /// * `path` - File path to save to (will be created/overwritten)
36 ///
37 /// # Returns
38 ///
39 /// Returns `Ok(())` on success, or an `Error` if serialization or file I/O fails.
40 fn save(&self, path: &str) -> Result<(), Error> {
41 self.save_with_name(path, None)
42 }
43
44 /// Save the model to a file with a custom name.
45 ///
46 /// The name is stored in the model metadata and can be used to identify
47 /// the model later.
48 ///
49 /// # Arguments
50 ///
51 /// * `path` - File path to save to
52 /// * `name` - Optional custom name for the model
53 fn save_with_name(&self, path: &str, name: Option<String>) -> Result<(), Error>;
54
55 /// Get the model type identifier.
56 ///
57 /// This is used when serializing to store the model type in metadata.
58 fn model_type() -> ModelType;
59}
60
61/// Trait for loading models from disk.
62///
63/// This trait is implemented by all regression result types that support
64/// deserialization. Loading validates the format version and model type.
65///
66/// # Example
67///
68/// ```ignore
69/// # use linreg_core::core::RegressionOutput;
70/// # use linreg_core::serialization::ModelLoad;
71/// let model: RegressionOutput = RegressionOutput::load("my_model.json").unwrap();
72/// println!("R²: {}", model.r_squared);
73/// ```
74pub trait ModelLoad: Sized {
75 /// Load a model from a file.
76 ///
77 /// This validates that:
78 /// - The file exists and contains valid JSON
79 /// - The format version is compatible
80 /// - The model type matches the expected type
81 ///
82 /// # Arguments
83 ///
84 /// * `path` - File path to load from
85 ///
86 /// # Returns
87 ///
88 /// Returns the deserialized model on success, or an `Error` if loading fails.
89 fn load(path: &str) -> Result<Self, Error>;
90
91 /// Load a model from an already-deserialized wrapper.
92 ///
93 /// This is useful when you have a `SerializedModel` and want to convert
94 /// it to a specific model type.
95 ///
96 /// # Arguments
97 ///
98 /// * `model` - The serialized model wrapper
99 ///
100 /// # Returns
101 ///
102 /// Returns the deserialized model on success, or an `Error` if conversion fails.
103 fn from_serialized(model: SerializedModel) -> Result<Self, Error>;
104
105 /// Get the model type identifier.
106 ///
107 /// This is used to validate that the loaded file contains the correct model type.
108 fn model_type() -> ModelType;
109}
110
#[cfg(test)]
mod tests {
    use crate::serialization::ModelType;

    // Concrete `ModelSave`/`ModelLoad` implementations live next to each model
    // type; this module only sanity-checks the `ModelType` identifiers that the
    // traits store in and compare against file metadata.

    #[test]
    fn test_model_type_display() {
        // Each variant paired with the exact text its `Display` impl must yield.
        let expectations = [
            (ModelType::OLS, "OLS"),
            (ModelType::Ridge, "Ridge"),
            (ModelType::Lasso, "Lasso"),
            (ModelType::ElasticNet, "ElasticNet"),
            (ModelType::WLS, "WLS"),
            (ModelType::LOESS, "LOESS"),
        ];

        for (variant, rendered) in expectations.iter() {
            assert_eq!(variant.to_string(), *rendered, "Display mismatch for {}", rendered);
        }
    }
}