Skip to main content

numrs2/new_modules/model_io/
export.rs

1//! Model Export to Various Formats
2//!
3//! Provides export functionality for NumRS2 models to different formats
4//! including JSON, MessagePack, and NPY/NPZ for interoperability.
5
6use super::format::{FormatResult, LayerData, NumRS2Model};
7use crate::error::NumRs2Error;
8use oxiarc_archive::zip::ZipCompressionLevel;
9use scirs2_core::ndarray::{Array1, Array2};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufWriter, Write};
14use std::path::Path;
15
/// Output formats understood by the model exporter.
///
/// JSON and MessagePack carry model metadata and layer layout, while the NumPy
/// formats carry raw weight arrays.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExportFormat {
    /// Pretty-printed JSON, intended to be read by humans.
    Json,
    /// MessagePack: a compact binary encoding of the same data as `Json`.
    MessagePack,
    /// NumPy `.npy`: one array per file.
    Npy,
    /// NumPy `.npz`: a ZIP archive holding several `.npy` arrays.
    Npz,
}
28
29/// Model exporter for various formats
30pub struct ModelExporter;
31
32impl ModelExporter {
33    /// Exports model to JSON format
34    ///
35    /// Creates a human-readable JSON representation of the model including
36    /// metadata, architecture, and layer configuration (but not weights).
37    ///
38    /// # Arguments
39    ///
40    /// * `model` - The model to export
41    /// * `path` - Output file path
42    pub fn export_json<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
43        let export_data = ModelExportData::from_model(model);
44
45        let file = File::create(path.as_ref())
46            .map_err(|e| NumRs2Error::IOError(format!("Failed to create JSON file: {}", e)))?;
47
48        let writer = BufWriter::new(file);
49
50        serde_json::to_writer_pretty(writer, &export_data).map_err(|e| {
51            NumRs2Error::SerializationError(format!("Failed to serialize to JSON: {}", e))
52        })?;
53
54        Ok(())
55    }
56
57    /// Exports model to MessagePack format
58    ///
59    /// Creates a compact binary representation using MessagePack.
60    ///
61    /// # Arguments
62    ///
63    /// * `model` - The model to export
64    /// * `path` - Output file path
65    #[cfg(feature = "messagepack")]
66    pub fn export_messagepack<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
67        let export_data = ModelExportData::from_model(model);
68
69        let file = File::create(path.as_ref()).map_err(|e| {
70            NumRs2Error::IOError(format!("Failed to create MessagePack file: {}", e))
71        })?;
72
73        let mut writer = BufWriter::new(file);
74
75        let bytes = rmp_serde::to_vec(&export_data).map_err(|e| {
76            NumRs2Error::SerializationError(format!("Failed to serialize to MessagePack: {}", e))
77        })?;
78
79        writer.write_all(&bytes).map_err(|e| {
80            NumRs2Error::IOError(format!("Failed to write MessagePack data: {}", e))
81        })?;
82
83        Ok(())
84    }
85
86    /// Exports model to MessagePack format (stub when feature is not enabled)
87    #[cfg(not(feature = "messagepack"))]
88    pub fn export_messagepack<P: AsRef<Path>>(_model: &NumRS2Model, _path: P) -> FormatResult<()> {
89        Err(NumRs2Error::FeatureNotEnabled(
90            "MessagePack export requires 'messagepack' feature".to_string(),
91        ))
92    }
93
94    /// Exports model weights to NPZ format
95    ///
96    /// Saves all layer weights and biases as separate arrays in a compressed NPZ file.
97    ///
98    /// # Arguments
99    ///
100    /// * `model` - The model to export
101    /// * `path` - Output file path
102    pub fn export_weights_npz<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
103        use byteorder::{ByteOrder, LittleEndian};
104        use oxiarc_archive::zip::{ZipCompressionLevel, ZipWriter};
105
106        let file = File::create(path.as_ref())
107            .map_err(|e| NumRs2Error::IOError(format!("Failed to create NPZ file: {}", e)))?;
108
109        let mut zip = ZipWriter::new(file);
110
111        // Export each layer's weights
112        for (i, layer) in model.layers.iter().enumerate() {
113            // Export weights
114            let weights_name = format!("layer_{}_weights.npy", i);
115            Self::write_npy_to_zip(&mut zip, &weights_name, &layer.weights)?;
116
117            // Export bias if present
118            if let Some(ref bias) = layer.bias {
119                let bias_name = format!("layer_{}_bias.npy", i);
120                Self::write_npy_to_zip(&mut zip, &bias_name, bias)?;
121            }
122        }
123
124        zip.finish()
125            .map_err(|e| NumRs2Error::IOError(format!("Failed to finish NPZ file: {}", e)))?;
126
127        Ok(())
128    }
129
130    /// Helper function to write NPY data to ZIP
131    fn write_npy_to_zip(
132        zip: &mut oxiarc_archive::zip::ZipWriter<File>,
133        name: &str,
134        data: &[u8],
135    ) -> FormatResult<()> {
136        zip.add_file(name, data)
137            .map_err(|e| NumRs2Error::IOError(format!("Failed to add file to ZIP: {}", e)))?;
138
139        Ok(())
140    }
141
142    /// Exports a single layer's weights to NPY format
143    ///
144    /// # Arguments
145    ///
146    /// * `weights` - Weight array to export
147    /// * `path` - Output file path
148    pub fn export_weights_npy<P: AsRef<Path>>(weights: &Array2<f64>, path: P) -> FormatResult<()> {
149        let file = File::create(path.as_ref())
150            .map_err(|e| NumRs2Error::IOError(format!("Failed to create NPY file: {}", e)))?;
151
152        let mut writer = BufWriter::new(file);
153
154        // Create NPY header
155        let header = Self::create_npy_header(weights.shape(), "f8")?;
156        writer
157            .write_all(&header)
158            .map_err(|e| NumRs2Error::IOError(format!("Failed to write NPY header: {}", e)))?;
159
160        // Write data in C-order (row-major)
161        use byteorder::{LittleEndian, WriteBytesExt};
162        for &value in weights.iter() {
163            writer
164                .write_f64::<LittleEndian>(value)
165                .map_err(|e| NumRs2Error::IOError(format!("Failed to write NPY data: {}", e)))?;
166        }
167
168        Ok(())
169    }
170
171    /// Creates NPY header
172    fn create_npy_header(shape: &[usize], dtype: &str) -> FormatResult<Vec<u8>> {
173        use byteorder::{LittleEndian, WriteBytesExt};
174
175        // Magic string
176        let magic = b"\x93NUMPY";
177
178        // Version 1.0
179        let version: [u8; 2] = [1, 0];
180
181        // Create dictionary string
182        let mut dict = format!(
183            "{{'descr': '<{}', 'fortran_order': False, 'shape': (",
184            dtype
185        );
186
187        for (i, &dim) in shape.iter().enumerate() {
188            if i > 0 {
189                dict.push_str(", ");
190            }
191            dict.push_str(&dim.to_string());
192
193            // Add trailing comma for 1D arrays
194            if shape.len() == 1 && i == shape.len() - 1 {
195                dict.push(',');
196            }
197        }
198
199        dict.push_str("), }");
200
201        // Pad to make total header length a multiple of 16
202        let header_len = 10 + dict.len(); // 6 (magic) + 2 (version) + 2 (len) + dict
203        let padding = (16 - (header_len % 16)) % 16;
204        dict.push_str(&" ".repeat(padding));
205
206        // Build header
207        let mut header = Vec::new();
208        header.extend_from_slice(magic);
209        header.extend_from_slice(&version);
210
211        // Write header length (little endian)
212        let dict_len = dict.len() as u16;
213        header.write_u16::<LittleEndian>(dict_len).map_err(|e| {
214            NumRs2Error::SerializationError(format!("Failed to write header length: {}", e))
215        })?;
216
217        header.extend_from_slice(dict.as_bytes());
218
219        Ok(header)
220    }
221
222    /// Exports model architecture description only
223    ///
224    /// Creates a JSON file with model architecture information without weights.
225    ///
226    /// # Arguments
227    ///
228    /// * `model` - The model to export
229    /// * `path` - Output file path
230    pub fn export_architecture<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
231        let arch = ArchitectureDescription::from_model(model);
232
233        let file = File::create(path.as_ref()).map_err(|e| {
234            NumRs2Error::IOError(format!("Failed to create architecture file: {}", e))
235        })?;
236
237        let writer = BufWriter::new(file);
238
239        serde_json::to_writer_pretty(writer, &arch).map_err(|e| {
240            NumRs2Error::SerializationError(format!("Failed to serialize architecture: {}", e))
241        })?;
242
243        Ok(())
244    }
245}
246
247/// Model export data (without weight values)
248#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct ModelExportData {
250    /// Model name
251    pub name: String,
252    /// Model version
253    pub version: String,
254    /// Architecture type
255    pub architecture: String,
256    /// Description
257    pub description: Option<String>,
258    /// Hyperparameters
259    pub hyperparameters: HashMap<String, String>,
260    /// Layer information
261    pub layers: Vec<LayerExportInfo>,
262    /// Total parameters
263    pub total_parameters: usize,
264    /// Created timestamp
265    pub created_at: String,
266}
267
268impl ModelExportData {
269    /// Creates export data from a model
270    pub fn from_model(model: &NumRS2Model) -> Self {
271        let layers = model
272            .layers
273            .iter()
274            .map(LayerExportInfo::from_layer)
275            .collect();
276
277        Self {
278            name: model.metadata.name.clone(),
279            version: model.metadata.version.clone(),
280            architecture: model.metadata.architecture.clone(),
281            description: model.metadata.description.clone(),
282            hyperparameters: model.metadata.hyperparameters.clone(),
283            layers,
284            total_parameters: model.num_parameters(),
285            created_at: model.metadata.created_at.clone(),
286        }
287    }
288}
289
290/// Layer export information (without weights)
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct LayerExportInfo {
293    /// Layer name
294    pub name: String,
295    /// Layer type
296    pub layer_type: String,
297    /// Input shape
298    pub input_shape: Vec<usize>,
299    /// Output shape
300    pub output_shape: Vec<usize>,
301    /// Number of parameters
302    pub num_parameters: usize,
303    /// Activation function
304    pub activation: Option<String>,
305    /// Layer parameters
306    pub parameters: HashMap<String, String>,
307}
308
309impl LayerExportInfo {
310    /// Creates layer export info from layer data
311    pub fn from_layer(layer: &LayerData) -> Self {
312        Self {
313            name: layer.name.clone(),
314            layer_type: format!("{:?}", layer.layer_type),
315            input_shape: layer.input_shape.clone(),
316            output_shape: layer.output_shape.clone(),
317            num_parameters: layer.num_parameters(),
318            activation: layer.activation.map(|a| format!("{:?}", a)),
319            parameters: layer.parameters.clone(),
320        }
321    }
322}
323
324/// Architecture description
325#[derive(Debug, Clone, Serialize, Deserialize)]
326pub struct ArchitectureDescription {
327    /// Architecture name
328    pub name: String,
329    /// Layer descriptions
330    pub layers: Vec<LayerDescription>,
331    /// Total parameters
332    pub total_parameters: usize,
333}
334
335impl ArchitectureDescription {
336    /// Creates architecture description from model
337    pub fn from_model(model: &NumRS2Model) -> Self {
338        let layers = model
339            .layers
340            .iter()
341            .map(LayerDescription::from_layer)
342            .collect();
343
344        Self {
345            name: model.metadata.architecture.clone(),
346            layers,
347            total_parameters: model.num_parameters(),
348        }
349    }
350}
351
352/// Layer description
353#[derive(Debug, Clone, Serialize, Deserialize)]
354pub struct LayerDescription {
355    /// Layer type
356    pub layer_type: String,
357    /// Input shape
358    pub input_shape: Vec<usize>,
359    /// Output shape
360    pub output_shape: Vec<usize>,
361    /// Parameters
362    pub parameters: HashMap<String, String>,
363}
364
365impl LayerDescription {
366    /// Creates layer description from layer data
367    pub fn from_layer(layer: &LayerData) -> Self {
368        Self {
369            layer_type: format!("{:?}", layer.layer_type),
370            input_shape: layer.input_shape.clone(),
371            output_shape: layer.output_shape.clone(),
372            parameters: layer.parameters.clone(),
373        }
374    }
375}
376
377/// Convenience function to export model to JSON
378pub fn export_to_json<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
379    ModelExporter::export_json(model, path)
380}
381
382/// Convenience function to export model to MessagePack
383pub fn export_to_messagepack<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
384    ModelExporter::export_messagepack(model, path)
385}
386
387/// Convenience function to export weights to NPZ
388pub fn export_weights_npz<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
389    ModelExporter::export_weights_npz(model, path)
390}
391
392/// Convenience function to export weights to NPY
393pub fn export_weights_npy<P: AsRef<Path>>(weights: &Array2<f64>, path: P) -> FormatResult<()> {
394    ModelExporter::export_weights_npy(weights, path)
395}
396
397/// Convenience function to export architecture
398pub fn export_architecture<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
399    ModelExporter::export_architecture(model, path)
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::new_modules::model_io::format::{LayerData, ModelMetadata};
    use scirs2_core::ndarray::Array2;
    use std::env;
    use std::fs;

    /// Length of the fixed NPY preamble: magic (6) + version (2) + u16 length (2).
    const NPY_PREAMBLE_LEN: usize = 10;

    #[test]
    fn test_export_json() {
        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_export.json");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .version("1.0.0")
            .architecture("MLP")
            .description("Test model for export")
            .hyperparameter("hidden_size", "128")
            .build()
            .expect("test: valid metadata build");

        let layer = LayerData::dense("layer1", Array2::ones((10, 5)), None);
        let model = NumRS2Model::new(metadata, vec![layer]);

        let result = ModelExporter::export_json(&model, &path);
        assert!(result.is_ok());

        // Verify file exists and is valid JSON
        assert!(path.exists());
        let contents = fs::read_to_string(&path).expect("test: valid file read");
        let parsed: serde_json::Value =
            serde_json::from_str(&contents).expect("test: valid JSON parse");
        assert_eq!(parsed["name"], "test_model");
        assert_eq!(parsed["architecture"], "MLP");

        // Cleanup
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_export_architecture() {
        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_architecture.json");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .architecture("Transformer")
            .build()
            .expect("test: valid metadata build");

        let layer1 = LayerData::dense("layer1", Array2::ones((512, 256)), None);
        let layer2 = LayerData::dense("layer2", Array2::ones((256, 128)), None);
        let model = NumRS2Model::new(metadata, vec![layer1, layer2]);

        let result = ModelExporter::export_architecture(&model, &path);
        assert!(result.is_ok());

        // Verify file exists
        assert!(path.exists());

        // Cleanup
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_export_weights_npy() {
        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_weights.npy");

        let weights = Array2::from_shape_fn((5, 3), |(i, j)| (i * 3 + j) as f64);

        let result = ModelExporter::export_weights_npy(&weights, &path);
        assert!(result.is_ok());

        // The file must be a well-formed NPY: preamble, header dict, then
        // 5 * 3 little-endian f64 values.
        let bytes = fs::read(&path).expect("test: valid file read");
        assert!(bytes.starts_with(b"\x93NUMPY"));
        let header_len = u16::from_le_bytes([bytes[8], bytes[9]]) as usize;
        let data_start = NPY_PREAMBLE_LEN + header_len;
        assert!(data_start.is_multiple_of(16));
        assert_eq!(bytes.len(), data_start + 15 * 8);

        let dict = String::from_utf8_lossy(&bytes[NPY_PREAMBLE_LEN..data_start]);
        assert!(dict.contains("'shape': (5, 3)"));

        let read_f64 = |idx: usize| {
            let start = data_start + idx * 8;
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&bytes[start..start + 8]);
            f64::from_le_bytes(raw)
        };
        // Row-major order: element (i, j) is stored at flat index i * 3 + j.
        assert_eq!(read_f64(0), 0.0);
        assert_eq!(read_f64(5), 5.0);
        assert_eq!(read_f64(14), 14.0);

        // Cleanup
        let _ = fs::remove_file(path);
    }

    #[test]
    fn test_npy_header_creation() {
        let shape: [usize; 2] = [3, 4];
        let header = ModelExporter::create_npy_header(&shape, "f8");
        assert!(header.is_ok());

        let header = header.expect("test: valid NPY header creation");
        assert!(header.starts_with(b"\x93NUMPY"));
        assert!(header.len().is_multiple_of(16)); // Should be aligned to 16 bytes

        // The length field must describe exactly the bytes that follow it.
        let declared = u16::from_le_bytes([header[8], header[9]]) as usize;
        assert_eq!(declared, header.len() - NPY_PREAMBLE_LEN);

        let dict = String::from_utf8_lossy(&header[NPY_PREAMBLE_LEN..]);
        assert!(dict.contains("'descr': '<f8'"));
        assert!(dict.contains("'fortran_order': False"));
        assert!(dict.contains("'shape': (3, 4)"));
    }

    #[test]
    fn test_npy_header_one_dimensional_shape() {
        let header = ModelExporter::create_npy_header(&[7], "f8")
            .expect("test: valid NPY header creation");
        assert!(header.len().is_multiple_of(16));

        // Python needs the trailing comma to read `(7,)` as a tuple.
        let dict = String::from_utf8_lossy(&header[NPY_PREAMBLE_LEN..]);
        assert!(dict.contains("'shape': (7,)"));
    }

    #[test]
    fn test_model_export_data_creation() {
        let metadata = ModelMetadata::builder()
            .name("test_model")
            .version("1.0.0")
            .architecture("CNN")
            .hyperparameter("kernel_size", "3")
            .build()
            .expect("test: valid metadata build");

        let layer = LayerData::dense("layer1", Array2::ones((10, 5)), None);
        let model = NumRS2Model::new(metadata, vec![layer]);

        let export_data = ModelExportData::from_model(&model);

        assert_eq!(export_data.name, "test_model");
        assert_eq!(export_data.version, "1.0.0");
        assert_eq!(export_data.architecture, "CNN");
        assert_eq!(export_data.layers.len(), 1);
        assert!(export_data.total_parameters > 0);
    }

    #[test]
    fn test_layer_export_info() {
        let weights = Array2::ones((10, 5));
        let layer = LayerData::dense("test_layer", weights, None);

        let info = LayerExportInfo::from_layer(&layer);

        assert_eq!(info.name, "test_layer");
        assert_eq!(info.layer_type, "Dense");
        assert_eq!(info.input_shape, vec![10]);
        assert_eq!(info.output_shape, vec![5]);
        assert!(info.num_parameters > 0);
    }

    #[test]
    fn test_architecture_description() {
        let metadata = ModelMetadata::builder()
            .name("test_model")
            .architecture("ResNet")
            .build()
            .expect("test: valid metadata build");

        let layer1 = LayerData::dense("layer1", Array2::ones((256, 128)), None);
        let layer2 = LayerData::dense("layer2", Array2::ones((128, 64)), None);
        let model = NumRS2Model::new(metadata, vec![layer1, layer2]);

        let arch = ArchitectureDescription::from_model(&model);

        assert_eq!(arch.name, "ResNet");
        assert_eq!(arch.layers.len(), 2);
        assert!(arch.total_parameters > 0);
    }

    #[test]
    fn test_export_format_enum() {
        assert_ne!(ExportFormat::Json, ExportFormat::MessagePack);
        assert_ne!(ExportFormat::Npy, ExportFormat::Npz);
    }

    #[test]
    fn test_convenience_functions() {
        let temp_dir = env::temp_dir();
        let json_path = temp_dir.join("test_convenience.json");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .build()
            .expect("test: valid metadata build");

        let layer = LayerData::dense("layer1", Array2::ones((10, 5)), None);
        let model = NumRS2Model::new(metadata, vec![layer]);

        // Test JSON export
        let result = export_to_json(&model, &json_path);
        assert!(result.is_ok());

        // Test architecture export
        let arch_path = temp_dir.join("test_arch.json");
        let result = export_architecture(&model, &arch_path);
        assert!(result.is_ok());

        // Cleanup
        let _ = fs::remove_file(json_path);
        let _ = fs::remove_file(arch_path);
    }

    #[test]
    fn test_export_weights_npz() {
        let temp_dir = env::temp_dir();
        let path = temp_dir.join("test_weights.npz");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .build()
            .expect("test: valid metadata build");

        let layer1 = LayerData::dense("layer1", Array2::ones((10, 5)), Some(Array1::zeros(5)));
        let layer2 = LayerData::dense("layer2", Array2::ones((5, 2)), None);
        let model = NumRS2Model::new(metadata, vec![layer1, layer2]);

        let result = ModelExporter::export_weights_npz(&model, &path);
        assert!(result.is_ok());

        // Verify file exists
        assert!(path.exists());

        // Cleanup
        let _ = fs::remove_file(path);
    }
}