ghostflow_nn/
serialization.rs

1//! Model Serialization
2//!
3//! Save and load trained models for deployment and inference.
4
5use ghostflow_core::tensor::Tensor;
6use std::collections::HashMap;
7use std::fs::File;
8use std::io::{Read, Write, BufReader, BufWriter};
9use std::path::Path;
10
11/// Model checkpoint containing parameters and metadata
12#[derive(Clone, Debug)]
13pub struct ModelCheckpoint {
14    /// Model parameters (weights and biases)
15    pub parameters: HashMap<String, Tensor>,
16    /// Model metadata
17    pub metadata: ModelMetadata,
18    /// Optimizer state (optional)
19    pub optimizer_state: Option<HashMap<String, Vec<f32>>>,
20}
21
/// Descriptive information stored next to a model's parameters.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Human-readable model name.
    pub name: String,
    /// Version string of the model itself.
    pub version: String,
    /// Version string of the framework that produced the model.
    pub framework_version: String,
    /// Training epoch recorded for this snapshot.
    pub epoch: usize,
    /// Training loss recorded for this snapshot.
    pub loss: f32,
    /// Free-form key/value pairs for anything not covered by the fields above.
    pub extra: HashMap<String, String>,
}
38
39impl Default for ModelMetadata {
40    fn default() -> Self {
41        Self {
42            name: "ghostflow_model".to_string(),
43            version: "1.0.0".to_string(),
44            framework_version: env!("CARGO_PKG_VERSION").to_string(),
45            epoch: 0,
46            loss: 0.0,
47            extra: HashMap::new(),
48        }
49    }
50}
51
52impl ModelCheckpoint {
53    /// Create a new checkpoint
54    pub fn new(parameters: HashMap<String, Tensor>) -> Self {
55        Self {
56            parameters,
57            metadata: ModelMetadata::default(),
58            optimizer_state: None,
59        }
60    }
61
62    /// Set metadata
63    pub fn with_metadata(mut self, metadata: ModelMetadata) -> Self {
64        self.metadata = metadata;
65        self
66    }
67
68    /// Set optimizer state
69    pub fn with_optimizer_state(mut self, state: HashMap<String, Vec<f32>>) -> Self {
70        self.optimizer_state = Some(state);
71        self
72    }
73
74    /// Save checkpoint to file
75    pub fn save<P: AsRef<Path>>(&self, path: P) -> std::io::Result<()> {
76        let file = File::create(path)?;
77        let mut writer = BufWriter::new(file);
78
79        // Write magic number
80        writer.write_all(b"GFCP")?; // GhostFlow CheckPoint
81
82        // Write version
83        writer.write_all(&[0, 4, 0])?; // v0.4.0
84
85        // Write metadata
86        self.write_metadata(&mut writer)?;
87
88        // Write parameters
89        self.write_parameters(&mut writer)?;
90
91        // Write optimizer state if present
92        if let Some(ref state) = self.optimizer_state {
93            writer.write_all(&[1])?; // Has optimizer state
94            self.write_optimizer_state(&mut writer, state)?;
95        } else {
96            writer.write_all(&[0])?; // No optimizer state
97        }
98
99        writer.flush()?;
100        Ok(())
101    }
102
103    /// Load checkpoint from file
104    pub fn load<P: AsRef<Path>>(path: P) -> std::io::Result<Self> {
105        let file = File::open(path)?;
106        let mut reader = BufReader::new(file);
107
108        // Read and verify magic number
109        let mut magic = [0u8; 4];
110        reader.read_exact(&mut magic)?;
111        if &magic != b"GFCP" {
112            return Err(std::io::Error::new(
113                std::io::ErrorKind::InvalidData,
114                "Invalid checkpoint file format",
115            ));
116        }
117
118        // Read version
119        let mut version = [0u8; 3];
120        reader.read_exact(&mut version)?;
121
122        // Read metadata
123        let metadata = Self::read_metadata(&mut reader)?;
124
125        // Read parameters
126        let parameters = Self::read_parameters(&mut reader)?;
127
128        // Read optimizer state if present
129        let mut has_optimizer = [0u8; 1];
130        reader.read_exact(&mut has_optimizer)?;
131        let optimizer_state = if has_optimizer[0] == 1 {
132            Some(Self::read_optimizer_state(&mut reader)?)
133        } else {
134            None
135        };
136
137        Ok(Self {
138            parameters,
139            metadata,
140            optimizer_state,
141        })
142    }
143
144    fn write_metadata<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
145        // Write name
146        self.write_string(writer, &self.metadata.name)?;
147        // Write version
148        self.write_string(writer, &self.metadata.version)?;
149        // Write framework version
150        self.write_string(writer, &self.metadata.framework_version)?;
151        // Write epoch
152        writer.write_all(&self.metadata.epoch.to_le_bytes())?;
153        // Write loss
154        writer.write_all(&self.metadata.loss.to_le_bytes())?;
155        // Write extra metadata count
156        writer.write_all(&(self.metadata.extra.len() as u32).to_le_bytes())?;
157        for (key, value) in &self.metadata.extra {
158            self.write_string(writer, key)?;
159            self.write_string(writer, value)?;
160        }
161        Ok(())
162    }
163
164    fn read_metadata<R: Read>(reader: &mut R) -> std::io::Result<ModelMetadata> {
165        let name = Self::read_string(reader)?;
166        let version = Self::read_string(reader)?;
167        let framework_version = Self::read_string(reader)?;
168        
169        let mut epoch_bytes = [0u8; 8];
170        reader.read_exact(&mut epoch_bytes)?;
171        let epoch = usize::from_le_bytes(epoch_bytes);
172        
173        let mut loss_bytes = [0u8; 4];
174        reader.read_exact(&mut loss_bytes)?;
175        let loss = f32::from_le_bytes(loss_bytes);
176        
177        let mut count_bytes = [0u8; 4];
178        reader.read_exact(&mut count_bytes)?;
179        let count = u32::from_le_bytes(count_bytes) as usize;
180        
181        let mut extra = HashMap::new();
182        for _ in 0..count {
183            let key = Self::read_string(reader)?;
184            let value = Self::read_string(reader)?;
185            extra.insert(key, value);
186        }
187
188        Ok(ModelMetadata {
189            name,
190            version,
191            framework_version,
192            epoch,
193            loss,
194            extra,
195        })
196    }
197
198    fn write_parameters<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
199        // Write parameter count
200        writer.write_all(&(self.parameters.len() as u32).to_le_bytes())?;
201
202        for (name, tensor) in &self.parameters {
203            // Write parameter name
204            self.write_string(writer, name)?;
205
206            // Write tensor shape
207            let shape = tensor.shape().dims();
208            writer.write_all(&(shape.len() as u32).to_le_bytes())?;
209            for &dim in shape {
210                writer.write_all(&(dim as u64).to_le_bytes())?;
211            }
212
213            // Write tensor data
214            let data = tensor.storage().as_slice::<f32>();
215            writer.write_all(&(data.len() as u64).to_le_bytes())?;
216            for &value in data.iter() {
217                writer.write_all(&value.to_le_bytes())?;
218            }
219        }
220
221        Ok(())
222    }
223
224    fn read_parameters<R: Read>(reader: &mut R) -> std::io::Result<HashMap<String, Tensor>> {
225        let mut count_bytes = [0u8; 4];
226        reader.read_exact(&mut count_bytes)?;
227        let count = u32::from_le_bytes(count_bytes) as usize;
228
229        let mut parameters = HashMap::new();
230
231        for _ in 0..count {
232            // Read parameter name
233            let name = Self::read_string(reader)?;
234
235            // Read tensor shape
236            let mut shape_len_bytes = [0u8; 4];
237            reader.read_exact(&mut shape_len_bytes)?;
238            let shape_len = u32::from_le_bytes(shape_len_bytes) as usize;
239
240            let mut shape = Vec::with_capacity(shape_len);
241            for _ in 0..shape_len {
242                let mut dim_bytes = [0u8; 8];
243                reader.read_exact(&mut dim_bytes)?;
244                shape.push(u64::from_le_bytes(dim_bytes) as usize);
245            }
246
247            // Read tensor data
248            let mut data_len_bytes = [0u8; 8];
249            reader.read_exact(&mut data_len_bytes)?;
250            let data_len = u64::from_le_bytes(data_len_bytes) as usize;
251
252            let mut data = Vec::with_capacity(data_len);
253            for _ in 0..data_len {
254                let mut value_bytes = [0u8; 4];
255                reader.read_exact(&mut value_bytes)?;
256                data.push(f32::from_le_bytes(value_bytes));
257            }
258
259            let tensor = Tensor::from_slice(&data, &shape)
260                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
261            
262            parameters.insert(name, tensor);
263        }
264
265        Ok(parameters)
266    }
267
268    fn write_optimizer_state<W: Write>(
269        &self,
270        writer: &mut W,
271        state: &HashMap<String, Vec<f32>>,
272    ) -> std::io::Result<()> {
273        writer.write_all(&(state.len() as u32).to_le_bytes())?;
274        
275        for (name, values) in state {
276            self.write_string(writer, name)?;
277            writer.write_all(&(values.len() as u64).to_le_bytes())?;
278            for &value in values {
279                writer.write_all(&value.to_le_bytes())?;
280            }
281        }
282
283        Ok(())
284    }
285
286    fn read_optimizer_state<R: Read>(reader: &mut R) -> std::io::Result<HashMap<String, Vec<f32>>> {
287        let mut count_bytes = [0u8; 4];
288        reader.read_exact(&mut count_bytes)?;
289        let count = u32::from_le_bytes(count_bytes) as usize;
290
291        let mut state = HashMap::new();
292
293        for _ in 0..count {
294            let name = Self::read_string(reader)?;
295            
296            let mut len_bytes = [0u8; 8];
297            reader.read_exact(&mut len_bytes)?;
298            let len = u64::from_le_bytes(len_bytes) as usize;
299
300            let mut values = Vec::with_capacity(len);
301            for _ in 0..len {
302                let mut value_bytes = [0u8; 4];
303                reader.read_exact(&mut value_bytes)?;
304                values.push(f32::from_le_bytes(value_bytes));
305            }
306
307            state.insert(name, values);
308        }
309
310        Ok(state)
311    }
312
313    fn write_string<W: Write>(&self, writer: &mut W, s: &str) -> std::io::Result<()> {
314        let bytes = s.as_bytes();
315        writer.write_all(&(bytes.len() as u32).to_le_bytes())?;
316        writer.write_all(bytes)?;
317        Ok(())
318    }
319
320    fn read_string<R: Read>(reader: &mut R) -> std::io::Result<String> {
321        let mut len_bytes = [0u8; 4];
322        reader.read_exact(&mut len_bytes)?;
323        let len = u32::from_le_bytes(len_bytes) as usize;
324
325        let mut bytes = vec![0u8; len];
326        reader.read_exact(&mut bytes)?;
327
328        String::from_utf8(bytes)
329            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
330    }
331}
332
333/// Simplified save/load functions
334pub fn save_model<P: AsRef<Path>>(
335    path: P,
336    parameters: HashMap<String, Tensor>,
337) -> std::io::Result<()> {
338    let checkpoint = ModelCheckpoint::new(parameters);
339    checkpoint.save(path)
340}
341
342pub fn load_model<P: AsRef<Path>>(path: P) -> std::io::Result<HashMap<String, Tensor>> {
343    let checkpoint = ModelCheckpoint::load(path)?;
344    Ok(checkpoint.parameters)
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::PathBuf;

    /// A uniquely named file in the system temp directory that is removed
    /// when the value is dropped, so a failing assertion does not leave
    /// files behind and nothing is written to the working directory.
    struct ScratchFile(PathBuf);

    impl ScratchFile {
        fn new(name: &str) -> Self {
            // The process id keeps concurrent test runs from sharing a file.
            let file_name = format!("ghostflow_{}_{}.gfcp", std::process::id(), name);
            ScratchFile(std::env::temp_dir().join(file_name))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchFile {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the test failed early.
            fs::remove_file(&self.0).ok();
        }
    }

    /// Copy a tensor's `f32` storage into a `Vec` for comparison.
    fn tensor_values(tensor: &Tensor) -> Vec<f32> {
        let data = tensor.storage().as_slice::<f32>();
        let mut values = Vec::new();
        for &value in data.iter() {
            values.push(value);
        }
        values
    }

    /// Copy a tensor's dimensions into a `Vec` for comparison.
    fn tensor_dims(tensor: &Tensor) -> Vec<usize> {
        let shape = tensor.shape().dims();
        let mut dims = Vec::new();
        for &dim in shape {
            dims.push(dim as usize);
        }
        dims
    }

    #[test]
    fn test_save_load_checkpoint() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "weight".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
        );
        parameters.insert(
            "bias".to_string(),
            Tensor::from_slice(&[0.5f32, 0.6], &[2]).unwrap(),
        );

        let checkpoint = ModelCheckpoint::new(parameters);

        let file = ScratchFile::new("checkpoint");
        checkpoint.save(file.path()).unwrap();

        let loaded = ModelCheckpoint::load(file.path()).unwrap();

        assert_eq!(loaded.parameters.len(), 2);
        assert!(loaded.optimizer_state.is_none());

        // The values and shapes must survive the round trip, not just the keys.
        let weight = &loaded.parameters["weight"];
        assert_eq!(tensor_dims(weight), vec![2, 2]);
        assert_eq!(tensor_values(weight), vec![1.0f32, 2.0, 3.0, 4.0]);

        let bias = &loaded.parameters["bias"];
        assert_eq!(tensor_dims(bias), vec![2]);
        assert_eq!(tensor_values(bias), vec![0.5f32, 0.6]);
    }

    #[test]
    fn test_checkpoint_with_metadata() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "layer1".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap(),
        );

        let mut extra = HashMap::new();
        extra.insert("dataset".to_string(), "mnist".to_string());

        let metadata = ModelMetadata {
            name: "test_model".to_string(),
            epoch: 10,
            loss: 0.123,
            extra,
            ..ModelMetadata::default()
        };
        let expected_version = metadata.version.clone();
        let expected_framework_version = metadata.framework_version.clone();

        let checkpoint = ModelCheckpoint::new(parameters).with_metadata(metadata);

        let file = ScratchFile::new("metadata");
        checkpoint.save(file.path()).unwrap();

        let loaded = ModelCheckpoint::load(file.path()).unwrap();

        assert_eq!(loaded.metadata.name, "test_model");
        assert_eq!(loaded.metadata.version, expected_version);
        assert_eq!(loaded.metadata.framework_version, expected_framework_version);
        assert_eq!(loaded.metadata.epoch, 10);
        // The loss is stored as its exact little-endian bytes, so it round-trips bit for bit.
        assert_eq!(loaded.metadata.loss, 0.123);
        assert_eq!(loaded.metadata.extra.len(), 1);
        assert_eq!(loaded.metadata.extra["dataset"], "mnist");
    }

    #[test]
    fn test_checkpoint_with_optimizer_state() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "weight".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap(),
        );

        let mut optimizer_state = HashMap::new();
        optimizer_state.insert("momentum".to_string(), vec![0.1f32, 0.2]);

        let checkpoint = ModelCheckpoint::new(parameters).with_optimizer_state(optimizer_state);

        let file = ScratchFile::new("optimizer");
        checkpoint.save(file.path()).unwrap();

        let loaded = ModelCheckpoint::load(file.path()).unwrap();

        let state = loaded
            .optimizer_state
            .expect("optimizer state should be present after loading");
        assert_eq!(state.len(), 1);
        assert_eq!(state["momentum"], vec![0.1f32, 0.2]);
    }

    #[test]
    fn test_simple_save_load() {
        let mut parameters = HashMap::new();
        parameters.insert(
            "test".to_string(),
            Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap(),
        );

        let file = ScratchFile::new("simple");
        save_model(file.path(), parameters).unwrap();

        let loaded = load_model(file.path()).unwrap();

        assert_eq!(loaded.len(), 1);
        assert_eq!(tensor_dims(&loaded["test"]), vec![3]);
        assert_eq!(tensor_values(&loaded["test"]), vec![1.0f32, 2.0, 3.0]);
    }

    #[test]
    fn test_load_rejects_bad_magic() {
        let file = ScratchFile::new("bad_magic");
        fs::write(file.path(), b"NOPE\x00\x04\x00").unwrap();

        let err = ModelCheckpoint::load(file.path()).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}