rust_lstm/
persistence.rs

1use serde::{Serialize, Deserialize};
2use ndarray::{Array2, Dimension};
3use std::fs::File;
4use std::io::{Write, Read};
5use std::path::Path;
6
7use crate::models::lstm_network::LSTMNetwork;
8use crate::layers::lstm_cell::LSTMCell;
9
10/// Serializable version of Array2<f64> for persistence
11#[derive(Serialize, Deserialize, Debug, Clone)]
12struct SerializableArray2 {
13    data: Vec<f64>,
14    shape: (usize, usize),
15}
16
17impl From<&Array2<f64>> for SerializableArray2 {
18    fn from(array: &Array2<f64>) -> Self {
19        Self {
20            data: array.iter().cloned().collect(),
21            shape: array.raw_dim().into_pattern(),
22        }
23    }
24}
25
26impl Into<Array2<f64>> for SerializableArray2 {
27    fn into(self) -> Array2<f64> {
28        Array2::from_shape_vec(self.shape, self.data)
29            .expect("Failed to reconstruct Array2 from serialized data")
30    }
31}
32
33/// Serializable LSTM cell parameters
34#[derive(Serialize, Deserialize, Debug, Clone)]
35pub struct SerializableLSTMCell {
36    w_ih: SerializableArray2,
37    w_hh: SerializableArray2,
38    b_ih: SerializableArray2,
39    b_hh: SerializableArray2,
40    hidden_size: usize,
41}
42
43impl From<&LSTMCell> for SerializableLSTMCell {
44    fn from(cell: &LSTMCell) -> Self {
45        Self {
46            w_ih: (&cell.w_ih).into(),
47            w_hh: (&cell.w_hh).into(),
48            b_ih: (&cell.b_ih).into(),
49            b_hh: (&cell.b_hh).into(),
50            hidden_size: cell.hidden_size,
51        }
52    }
53}
54
55impl Into<LSTMCell> for SerializableLSTMCell {
56    fn into(self) -> LSTMCell {
57        LSTMCell {
58            w_ih: self.w_ih.into(),
59            w_hh: self.w_hh.into(),
60            b_ih: self.b_ih.into(),
61            b_hh: self.b_hh.into(),
62            hidden_size: self.hidden_size,
63            input_dropout: None,
64            recurrent_dropout: None,
65            output_dropout: None,
66            zoneout: None,
67            is_training: true,
68        }
69    }
70}
71
72/// Serializable LSTM network
73#[derive(Serialize, Deserialize, Debug, Clone)]
74pub struct SerializableLSTMNetwork {
75    cells: Vec<SerializableLSTMCell>,
76    input_size: usize,
77    hidden_size: usize,
78    num_layers: usize,
79}
80
81impl From<&LSTMNetwork> for SerializableLSTMNetwork {
82    fn from(network: &LSTMNetwork) -> Self {
83        Self {
84            cells: network.get_cells().iter().map(|cell| cell.into()).collect(),
85            input_size: network.input_size,
86            hidden_size: network.hidden_size,
87            num_layers: network.num_layers,
88        }
89    }
90}
91
92impl Into<LSTMNetwork> for SerializableLSTMNetwork {
93    fn into(self) -> LSTMNetwork {
94        LSTMNetwork::from_cells(
95            self.cells.into_iter().map(|cell| cell.into()).collect(),
96            self.input_size,
97            self.hidden_size,
98            self.num_layers,
99        )
100    }
101}
102
103/// Model metadata for tracking training information
104#[derive(Serialize, Deserialize, Clone)]
105pub struct ModelMetadata {
106    pub model_name: String,
107    pub version: String,
108    pub created_at: String,
109    pub input_size: usize,
110    pub hidden_size: usize,
111    pub num_layers: usize,
112    pub total_epochs: usize,
113    pub final_loss: Option<f64>,
114    pub description: Option<String>,
115}
116
117/// Complete saved model including network and metadata
118#[derive(Serialize, Deserialize)]
119pub struct SavedModel {
120    pub network: SerializableLSTMNetwork,
121    pub metadata: ModelMetadata,
122}
123
/// Failure modes of the save/load routines in this module.
#[derive(Debug)]
pub enum PersistenceError {
    /// An I/O failure, carrying the originating `std::io::Error`.
    IoError(std::io::Error),
    /// An encoding or decoding failure, carrying a textual description.
    SerializationError(String),
}
130
131impl std::fmt::Display for PersistenceError {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        match self {
134            PersistenceError::IoError(err) => write!(f, "IO error: {}", err),
135            PersistenceError::SerializationError(err) => write!(f, "Serialization error: {}", err),
136        }
137    }
138}
139
140impl std::error::Error for PersistenceError {
141    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
142        match self {
143            PersistenceError::IoError(err) => Some(err),
144            PersistenceError::SerializationError(_) => None,
145        }
146    }
147}
148
149impl From<std::io::Error> for PersistenceError {
150    fn from(error: std::io::Error) -> Self {
151        PersistenceError::IoError(error)
152    }
153}
154
155impl From<serde_json::Error> for PersistenceError {
156    fn from(error: serde_json::Error) -> Self {
157        PersistenceError::SerializationError(error.to_string())
158    }
159}
160
161impl From<bincode::Error> for PersistenceError {
162    fn from(error: bincode::Error) -> Self {
163        PersistenceError::SerializationError(error.to_string())
164    }
165}
166
167/// Model persistence operations
168pub struct ModelPersistence;
169
170impl ModelPersistence {
171    /// Save model to JSON format (human-readable)
172    pub fn save_to_json<P: AsRef<Path>>(
173        model: &SavedModel,
174        path: P,
175    ) -> Result<(), PersistenceError> {
176        let json = serde_json::to_string_pretty(model)?;
177        let mut file = File::create(path)?;
178        file.write_all(json.as_bytes())?;
179        Ok(())
180    }
181
182    /// Load model from JSON format
183    pub fn load_from_json<P: AsRef<Path>>(
184        path: P,
185    ) -> Result<SavedModel, PersistenceError> {
186        let mut file = File::open(path)?;
187        let mut contents = String::new();
188        file.read_to_string(&mut contents)?;
189        let model = serde_json::from_str(&contents)?;
190        Ok(model)
191    }
192
193    /// Save model to binary format (compact and fast)
194    pub fn save_to_binary<P: AsRef<Path>>(
195        model: &SavedModel,
196        path: P,
197    ) -> Result<(), PersistenceError> {
198        let encoded = bincode::serialize(model)?;
199        let mut file = File::create(path)?;
200        file.write_all(&encoded)?;
201        Ok(())
202    }
203
204    /// Load model from binary format
205    pub fn load_from_binary<P: AsRef<Path>>(
206        path: P,
207    ) -> Result<SavedModel, PersistenceError> {
208        let mut file = File::open(path)?;
209        let mut contents = Vec::new();
210        file.read_to_end(&mut contents)?;
211        let model = bincode::deserialize(&contents)?;
212        Ok(model)
213    }
214
215    /// Create a model with metadata
216    pub fn create_saved_model(
217        network: &LSTMNetwork,
218        model_name: String,
219        total_epochs: usize,
220        final_loss: Option<f64>,
221        description: Option<String>,
222    ) -> SavedModel {
223        let metadata = ModelMetadata {
224            model_name,
225            version: env!("CARGO_PKG_VERSION").to_string(),
226            created_at: chrono::Utc::now().to_rfc3339(),
227            input_size: network.input_size,
228            hidden_size: network.hidden_size,
229            num_layers: network.num_layers,
230            total_epochs,
231            final_loss,
232            description,
233        };
234
235        SavedModel {
236            network: network.into(),
237            metadata,
238        }
239    }
240}
241
242/// Convenience trait for easy model saving/loading
243pub trait PersistentModel {
244    /// Save model to file (format determined by file extension)
245    fn save<P: AsRef<Path>>(&self, path: P, metadata: ModelMetadata) -> Result<(), PersistenceError>;
246    
247    /// Load model from file (format determined by file extension)
248    fn load<P: AsRef<Path>>(path: P) -> Result<(Self, ModelMetadata), PersistenceError>
249    where
250        Self: Sized;
251}
252
253impl PersistentModel for LSTMNetwork {
254    fn save<P: AsRef<Path>>(&self, path: P, metadata: ModelMetadata) -> Result<(), PersistenceError> {
255        let saved_model = SavedModel {
256            network: self.into(),
257            metadata,
258        };
259
260        let path_ref = path.as_ref();
261        match path_ref.extension().and_then(|s| s.to_str()) {
262            Some("json") => ModelPersistence::save_to_json(&saved_model, path),
263            Some("bin") | Some("model") => ModelPersistence::save_to_binary(&saved_model, path),
264            _ => ModelPersistence::save_to_binary(&saved_model, path), // Default to binary
265        }
266    }
267
268    fn load<P: AsRef<Path>>(path: P) -> Result<(Self, ModelMetadata), PersistenceError> {
269        let path_ref = path.as_ref();
270        let saved_model = match path_ref.extension().and_then(|s| s.to_str()) {
271            Some("json") => ModelPersistence::load_from_json(path)?,
272            Some("bin") | Some("model") => ModelPersistence::load_from_binary(path)?,
273            _ => ModelPersistence::load_from_binary(path)?, // Default to binary
274        };
275
276        Ok((saved_model.network.into(), saved_model.metadata))
277    }
278}