1use serde::{Serialize, Deserialize};
2use ndarray::{Array2, Dimension};
3use std::fs::File;
4use std::io::{Write, Read};
5use std::path::Path;
6
7use crate::models::lstm_network::LSTMNetwork;
8use crate::layers::lstm_cell::LSTMCell;
9
10#[derive(Serialize, Deserialize, Debug, Clone)]
12struct SerializableArray2 {
13 data: Vec<f64>,
14 shape: (usize, usize),
15}
16
17impl From<&Array2<f64>> for SerializableArray2 {
18 fn from(array: &Array2<f64>) -> Self {
19 Self {
20 data: array.iter().cloned().collect(),
21 shape: array.raw_dim().into_pattern(),
22 }
23 }
24}
25
26impl Into<Array2<f64>> for SerializableArray2 {
27 fn into(self) -> Array2<f64> {
28 Array2::from_shape_vec(self.shape, self.data)
29 .expect("Failed to reconstruct Array2 from serialized data")
30 }
31}
32
/// Serialisable snapshot of one `LSTMCell`'s learnable parameters.
///
/// Stores the four parameter matrices (`w_ih`, `w_hh`, `b_ih`, `b_hh`;
/// presumably input->hidden and hidden->hidden weights/biases going by the
/// names; confirm against `LSTMCell`) plus `hidden_size`. Dropout and zoneout
/// settings are NOT stored; they are reset to `None` when converting back.
///
/// NOTE: field order is part of the bincode encoding written by
/// `ModelPersistence::save_to_binary`; reordering breaks previously saved files.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SerializableLSTMCell {
    w_ih: SerializableArray2,
    w_hh: SerializableArray2,
    b_ih: SerializableArray2,
    b_hh: SerializableArray2,
    // Copied verbatim from the cell; not cross-checked against the matrix shapes.
    hidden_size: usize,
}
42
43impl From<&LSTMCell> for SerializableLSTMCell {
44 fn from(cell: &LSTMCell) -> Self {
45 Self {
46 w_ih: (&cell.w_ih).into(),
47 w_hh: (&cell.w_hh).into(),
48 b_ih: (&cell.b_ih).into(),
49 b_hh: (&cell.b_hh).into(),
50 hidden_size: cell.hidden_size,
51 }
52 }
53}
54
55impl Into<LSTMCell> for SerializableLSTMCell {
56 fn into(self) -> LSTMCell {
57 LSTMCell {
58 w_ih: self.w_ih.into(),
59 w_hh: self.w_hh.into(),
60 b_ih: self.b_ih.into(),
61 b_hh: self.b_hh.into(),
62 hidden_size: self.hidden_size,
63 input_dropout: None,
64 recurrent_dropout: None,
65 output_dropout: None,
66 zoneout: None,
67 is_training: true,
68 }
69 }
70}
71
/// Serialisable snapshot of an `LSTMNetwork`: its cells plus its size parameters.
///
/// NOTE: field order is part of the bincode encoding written by
/// `ModelPersistence::save_to_binary`; reordering breaks previously saved files.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SerializableLSTMNetwork {
    // One entry per cell returned by `LSTMNetwork::get_cells`, in that order.
    cells: Vec<SerializableLSTMCell>,
    input_size: usize,
    hidden_size: usize,
    // Stored as reported by the network; not checked against `cells.len()` here.
    num_layers: usize,
}
80
81impl From<&LSTMNetwork> for SerializableLSTMNetwork {
82 fn from(network: &LSTMNetwork) -> Self {
83 Self {
84 cells: network.get_cells().iter().map(|cell| cell.into()).collect(),
85 input_size: network.input_size,
86 hidden_size: network.hidden_size,
87 num_layers: network.num_layers,
88 }
89 }
90}
91
92impl Into<LSTMNetwork> for SerializableLSTMNetwork {
93 fn into(self) -> LSTMNetwork {
94 LSTMNetwork::from_cells(
95 self.cells.into_iter().map(|cell| cell.into()).collect(),
96 self.input_size,
97 self.hidden_size,
98 self.num_layers,
99 )
100 }
101}
102
103#[derive(Serialize, Deserialize, Clone)]
105pub struct ModelMetadata {
106 pub model_name: String,
107 pub version: String,
108 pub created_at: String,
109 pub input_size: usize,
110 pub hidden_size: usize,
111 pub num_layers: usize,
112 pub total_epochs: usize,
113 pub final_loss: Option<f64>,
114 pub description: Option<String>,
115}
116
117#[derive(Serialize, Deserialize)]
119pub struct SavedModel {
120 pub network: SerializableLSTMNetwork,
121 pub metadata: ModelMetadata,
122}
123
/// Errors returned by the model save/load routines in this module.
#[derive(Debug)]
pub enum PersistenceError {
    /// A filesystem operation (create, open, read, write) failed.
    IoError(std::io::Error),
    /// Encoding or decoding failed; holds the encoder's or decoder's message as text.
    SerializationError(String),
}

impl std::fmt::Display for PersistenceError {
    /// Renders `IO error: …` or `Serialization error: …` followed by the inner message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::IoError(ref io_err) => write!(f, "IO error: {}", io_err),
            Self::SerializationError(ref message) => {
                write!(f, "Serialization error: {}", message)
            }
        }
    }
}

impl std::error::Error for PersistenceError {
    /// Exposes the wrapped `std::io::Error` for I/O failures.
    ///
    /// Serialisation errors keep only text, so they report no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::IoError(io_err) = self {
            Some(io_err)
        } else {
            None
        }
    }
}

impl From<std::io::Error> for PersistenceError {
    /// Lets `?` lift I/O failures into `PersistenceError::IoError`.
    fn from(io_err: std::io::Error) -> Self {
        Self::IoError(io_err)
    }
}
154
155impl From<serde_json::Error> for PersistenceError {
156 fn from(error: serde_json::Error) -> Self {
157 PersistenceError::SerializationError(error.to_string())
158 }
159}
160
161impl From<bincode::Error> for PersistenceError {
162 fn from(error: bincode::Error) -> Self {
163 PersistenceError::SerializationError(error.to_string())
164 }
165}
166
167pub struct ModelPersistence;
169
170impl ModelPersistence {
171 pub fn save_to_json<P: AsRef<Path>>(
173 model: &SavedModel,
174 path: P,
175 ) -> Result<(), PersistenceError> {
176 let json = serde_json::to_string_pretty(model)?;
177 let mut file = File::create(path)?;
178 file.write_all(json.as_bytes())?;
179 Ok(())
180 }
181
182 pub fn load_from_json<P: AsRef<Path>>(
184 path: P,
185 ) -> Result<SavedModel, PersistenceError> {
186 let mut file = File::open(path)?;
187 let mut contents = String::new();
188 file.read_to_string(&mut contents)?;
189 let model = serde_json::from_str(&contents)?;
190 Ok(model)
191 }
192
193 pub fn save_to_binary<P: AsRef<Path>>(
195 model: &SavedModel,
196 path: P,
197 ) -> Result<(), PersistenceError> {
198 let encoded = bincode::serialize(model)?;
199 let mut file = File::create(path)?;
200 file.write_all(&encoded)?;
201 Ok(())
202 }
203
204 pub fn load_from_binary<P: AsRef<Path>>(
206 path: P,
207 ) -> Result<SavedModel, PersistenceError> {
208 let mut file = File::open(path)?;
209 let mut contents = Vec::new();
210 file.read_to_end(&mut contents)?;
211 let model = bincode::deserialize(&contents)?;
212 Ok(model)
213 }
214
215 pub fn create_saved_model(
217 network: &LSTMNetwork,
218 model_name: String,
219 total_epochs: usize,
220 final_loss: Option<f64>,
221 description: Option<String>,
222 ) -> SavedModel {
223 let metadata = ModelMetadata {
224 model_name,
225 version: env!("CARGO_PKG_VERSION").to_string(),
226 created_at: chrono::Utc::now().to_rfc3339(),
227 input_size: network.input_size,
228 hidden_size: network.hidden_size,
229 num_layers: network.num_layers,
230 total_epochs,
231 final_loss,
232 description,
233 };
234
235 SavedModel {
236 network: network.into(),
237 metadata,
238 }
239 }
240}
241
/// Save/load support for models that persist together with a [`ModelMetadata`].
///
/// The on-disk format is chosen by each implementor.
pub trait PersistentModel {
    /// Writes `self` and `metadata` to `path`.
    ///
    /// # Errors
    ///
    /// Returns a [`PersistenceError`] if encoding or file I/O fails.
    fn save<P: AsRef<Path>>(&self, path: P, metadata: ModelMetadata) -> Result<(), PersistenceError>;

    /// Reads a model and its metadata back from `path`.
    ///
    /// # Errors
    ///
    /// Returns a [`PersistenceError`] if file I/O or decoding fails.
    fn load<P: AsRef<Path>>(path: P) -> Result<(Self, ModelMetadata), PersistenceError>
    where
        Self: Sized;
}
252
253impl PersistentModel for LSTMNetwork {
254 fn save<P: AsRef<Path>>(&self, path: P, metadata: ModelMetadata) -> Result<(), PersistenceError> {
255 let saved_model = SavedModel {
256 network: self.into(),
257 metadata,
258 };
259
260 let path_ref = path.as_ref();
261 match path_ref.extension().and_then(|s| s.to_str()) {
262 Some("json") => ModelPersistence::save_to_json(&saved_model, path),
263 Some("bin") | Some("model") => ModelPersistence::save_to_binary(&saved_model, path),
264 _ => ModelPersistence::save_to_binary(&saved_model, path), }
266 }
267
268 fn load<P: AsRef<Path>>(path: P) -> Result<(Self, ModelMetadata), PersistenceError> {
269 let path_ref = path.as_ref();
270 let saved_model = match path_ref.extension().and_then(|s| s.to_str()) {
271 Some("json") => ModelPersistence::load_from_json(path)?,
272 Some("bin") | Some("model") => ModelPersistence::load_from_binary(path)?,
273 _ => ModelPersistence::load_from_binary(path)?, };
275
276 Ok((saved_model.network.into(), saved_model.metadata))
277 }
278}