1use super::format::{FormatResult, LayerData, NumRS2Model};
7use crate::error::NumRs2Error;
8use oxiarc_archive::zip::ZipCompressionLevel;
9use scirs2_core::ndarray::{Array1, Array2};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufWriter, Write};
14use std::path::Path;
15
/// File formats offered by the exporters in this module.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ExportFormat {
    /// JSON document.
    Json,
    /// MessagePack binary document (the exporter needs the `messagepack` feature).
    MessagePack,
    /// Single NumPy `.npy` array file.
    Npy,
    /// NumPy `.npz` archive (a ZIP container of array entries).
    Npz,
}
28
29pub struct ModelExporter;
31
32impl ModelExporter {
33 pub fn export_json<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
43 let export_data = ModelExportData::from_model(model);
44
45 let file = File::create(path.as_ref())
46 .map_err(|e| NumRs2Error::IOError(format!("Failed to create JSON file: {}", e)))?;
47
48 let writer = BufWriter::new(file);
49
50 serde_json::to_writer_pretty(writer, &export_data).map_err(|e| {
51 NumRs2Error::SerializationError(format!("Failed to serialize to JSON: {}", e))
52 })?;
53
54 Ok(())
55 }
56
57 #[cfg(feature = "messagepack")]
66 pub fn export_messagepack<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
67 let export_data = ModelExportData::from_model(model);
68
69 let file = File::create(path.as_ref()).map_err(|e| {
70 NumRs2Error::IOError(format!("Failed to create MessagePack file: {}", e))
71 })?;
72
73 let mut writer = BufWriter::new(file);
74
75 let bytes = rmp_serde::to_vec(&export_data).map_err(|e| {
76 NumRs2Error::SerializationError(format!("Failed to serialize to MessagePack: {}", e))
77 })?;
78
79 writer.write_all(&bytes).map_err(|e| {
80 NumRs2Error::IOError(format!("Failed to write MessagePack data: {}", e))
81 })?;
82
83 Ok(())
84 }
85
86 #[cfg(not(feature = "messagepack"))]
88 pub fn export_messagepack<P: AsRef<Path>>(_model: &NumRS2Model, _path: P) -> FormatResult<()> {
89 Err(NumRs2Error::FeatureNotEnabled(
90 "MessagePack export requires 'messagepack' feature".to_string(),
91 ))
92 }
93
94 pub fn export_weights_npz<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
103 use byteorder::{ByteOrder, LittleEndian};
104 use oxiarc_archive::zip::{ZipCompressionLevel, ZipWriter};
105
106 let file = File::create(path.as_ref())
107 .map_err(|e| NumRs2Error::IOError(format!("Failed to create NPZ file: {}", e)))?;
108
109 let mut zip = ZipWriter::new(file);
110
111 for (i, layer) in model.layers.iter().enumerate() {
113 let weights_name = format!("layer_{}_weights.npy", i);
115 Self::write_npy_to_zip(&mut zip, &weights_name, &layer.weights)?;
116
117 if let Some(ref bias) = layer.bias {
119 let bias_name = format!("layer_{}_bias.npy", i);
120 Self::write_npy_to_zip(&mut zip, &bias_name, bias)?;
121 }
122 }
123
124 zip.finish()
125 .map_err(|e| NumRs2Error::IOError(format!("Failed to finish NPZ file: {}", e)))?;
126
127 Ok(())
128 }
129
130 fn write_npy_to_zip(
132 zip: &mut oxiarc_archive::zip::ZipWriter<File>,
133 name: &str,
134 data: &[u8],
135 ) -> FormatResult<()> {
136 zip.add_file(name, data)
137 .map_err(|e| NumRs2Error::IOError(format!("Failed to add file to ZIP: {}", e)))?;
138
139 Ok(())
140 }
141
142 pub fn export_weights_npy<P: AsRef<Path>>(weights: &Array2<f64>, path: P) -> FormatResult<()> {
149 let file = File::create(path.as_ref())
150 .map_err(|e| NumRs2Error::IOError(format!("Failed to create NPY file: {}", e)))?;
151
152 let mut writer = BufWriter::new(file);
153
154 let header = Self::create_npy_header(weights.shape(), "f8")?;
156 writer
157 .write_all(&header)
158 .map_err(|e| NumRs2Error::IOError(format!("Failed to write NPY header: {}", e)))?;
159
160 use byteorder::{LittleEndian, WriteBytesExt};
162 for &value in weights.iter() {
163 writer
164 .write_f64::<LittleEndian>(value)
165 .map_err(|e| NumRs2Error::IOError(format!("Failed to write NPY data: {}", e)))?;
166 }
167
168 Ok(())
169 }
170
171 fn create_npy_header(shape: &[usize], dtype: &str) -> FormatResult<Vec<u8>> {
173 use byteorder::{LittleEndian, WriteBytesExt};
174
175 let magic = b"\x93NUMPY";
177
178 let version: [u8; 2] = [1, 0];
180
181 let mut dict = format!(
183 "{{'descr': '<{}', 'fortran_order': False, 'shape': (",
184 dtype
185 );
186
187 for (i, &dim) in shape.iter().enumerate() {
188 if i > 0 {
189 dict.push_str(", ");
190 }
191 dict.push_str(&dim.to_string());
192
193 if shape.len() == 1 && i == shape.len() - 1 {
195 dict.push(',');
196 }
197 }
198
199 dict.push_str("), }");
200
201 let header_len = 10 + dict.len(); let padding = (16 - (header_len % 16)) % 16;
204 dict.push_str(&" ".repeat(padding));
205
206 let mut header = Vec::new();
208 header.extend_from_slice(magic);
209 header.extend_from_slice(&version);
210
211 let dict_len = dict.len() as u16;
213 header.write_u16::<LittleEndian>(dict_len).map_err(|e| {
214 NumRs2Error::SerializationError(format!("Failed to write header length: {}", e))
215 })?;
216
217 header.extend_from_slice(dict.as_bytes());
218
219 Ok(header)
220 }
221
222 pub fn export_architecture<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
231 let arch = ArchitectureDescription::from_model(model);
232
233 let file = File::create(path.as_ref()).map_err(|e| {
234 NumRs2Error::IOError(format!("Failed to create architecture file: {}", e))
235 })?;
236
237 let writer = BufWriter::new(file);
238
239 serde_json::to_writer_pretty(writer, &arch).map_err(|e| {
240 NumRs2Error::SerializationError(format!("Failed to serialize architecture: {}", e))
241 })?;
242
243 Ok(())
244 }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
249pub struct ModelExportData {
250 pub name: String,
252 pub version: String,
254 pub architecture: String,
256 pub description: Option<String>,
258 pub hyperparameters: HashMap<String, String>,
260 pub layers: Vec<LayerExportInfo>,
262 pub total_parameters: usize,
264 pub created_at: String,
266}
267
268impl ModelExportData {
269 pub fn from_model(model: &NumRS2Model) -> Self {
271 let layers = model
272 .layers
273 .iter()
274 .map(LayerExportInfo::from_layer)
275 .collect();
276
277 Self {
278 name: model.metadata.name.clone(),
279 version: model.metadata.version.clone(),
280 architecture: model.metadata.architecture.clone(),
281 description: model.metadata.description.clone(),
282 hyperparameters: model.metadata.hyperparameters.clone(),
283 layers,
284 total_parameters: model.num_parameters(),
285 created_at: model.metadata.created_at.clone(),
286 }
287 }
288}
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct LayerExportInfo {
293 pub name: String,
295 pub layer_type: String,
297 pub input_shape: Vec<usize>,
299 pub output_shape: Vec<usize>,
301 pub num_parameters: usize,
303 pub activation: Option<String>,
305 pub parameters: HashMap<String, String>,
307}
308
309impl LayerExportInfo {
310 pub fn from_layer(layer: &LayerData) -> Self {
312 Self {
313 name: layer.name.clone(),
314 layer_type: format!("{:?}", layer.layer_type),
315 input_shape: layer.input_shape.clone(),
316 output_shape: layer.output_shape.clone(),
317 num_parameters: layer.num_parameters(),
318 activation: layer.activation.map(|a| format!("{:?}", a)),
319 parameters: layer.parameters.clone(),
320 }
321 }
322}
323
324#[derive(Debug, Clone, Serialize, Deserialize)]
326pub struct ArchitectureDescription {
327 pub name: String,
329 pub layers: Vec<LayerDescription>,
331 pub total_parameters: usize,
333}
334
335impl ArchitectureDescription {
336 pub fn from_model(model: &NumRS2Model) -> Self {
338 let layers = model
339 .layers
340 .iter()
341 .map(LayerDescription::from_layer)
342 .collect();
343
344 Self {
345 name: model.metadata.architecture.clone(),
346 layers,
347 total_parameters: model.num_parameters(),
348 }
349 }
350}
351
352#[derive(Debug, Clone, Serialize, Deserialize)]
354pub struct LayerDescription {
355 pub layer_type: String,
357 pub input_shape: Vec<usize>,
359 pub output_shape: Vec<usize>,
361 pub parameters: HashMap<String, String>,
363}
364
365impl LayerDescription {
366 pub fn from_layer(layer: &LayerData) -> Self {
368 Self {
369 layer_type: format!("{:?}", layer.layer_type),
370 input_shape: layer.input_shape.clone(),
371 output_shape: layer.output_shape.clone(),
372 parameters: layer.parameters.clone(),
373 }
374 }
375}
376
377pub fn export_to_json<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
379 ModelExporter::export_json(model, path)
380}
381
382pub fn export_to_messagepack<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
384 ModelExporter::export_messagepack(model, path)
385}
386
387pub fn export_weights_npz<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
389 ModelExporter::export_weights_npz(model, path)
390}
391
392pub fn export_weights_npy<P: AsRef<Path>>(weights: &Array2<f64>, path: P) -> FormatResult<()> {
394 ModelExporter::export_weights_npy(weights, path)
395}
396
397pub fn export_architecture<P: AsRef<Path>>(model: &NumRS2Model, path: P) -> FormatResult<()> {
399 ModelExporter::export_architecture(model, path)
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::new_modules::model_io::format::{LayerData, ModelMetadata};
    use scirs2_core::ndarray::Array2;
    use std::env;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// A file path in the system temp directory that is removed on drop.
    ///
    /// Because cleanup happens in `Drop`, files are removed even when an
    /// assertion panics. Names are prefixed with this module's name and the
    /// process id, so they do not collide with same-named files from other
    /// tests running in parallel against the shared temp directory.
    struct TempPath(PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            let file_name = format!("numrs2_model_export_{}_{}", std::process::id(), name);
            TempPath(env::temp_dir().join(file_name))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            // Best effort: some tests never create the file.
            let _ = fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_export_json() {
        let tmp = TempPath::new("export.json");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .version("1.0.0")
            .architecture("MLP")
            .description("Test model for export")
            .hyperparameter("hidden_size", "128")
            .build()
            .expect("test: valid metadata build");

        let layer = LayerData::dense("layer1", Array2::ones((10, 5)), None);
        let model = NumRS2Model::new(metadata, vec![layer]);

        let result = ModelExporter::export_json(&model, tmp.path());
        assert!(result.is_ok());
        assert!(tmp.path().exists());

        let contents = fs::read_to_string(tmp.path()).expect("test: valid file read");
        let parsed: serde_json::Value =
            serde_json::from_str(&contents).expect("test: valid JSON parse");
        assert_eq!(parsed["name"], "test_model");
        assert_eq!(parsed["version"], "1.0.0");
        assert_eq!(parsed["architecture"], "MLP");
        assert_eq!(parsed["layers"].as_array().map(Vec::len), Some(1));
    }

    #[test]
    fn test_export_architecture() {
        let tmp = TempPath::new("architecture.json");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .architecture("Transformer")
            .build()
            .expect("test: valid metadata build");

        let layer1 = LayerData::dense("layer1", Array2::ones((512, 256)), None);
        let layer2 = LayerData::dense("layer2", Array2::ones((256, 128)), None);
        let model = NumRS2Model::new(metadata, vec![layer1, layer2]);

        let result = ModelExporter::export_architecture(&model, tmp.path());
        assert!(result.is_ok());
        assert!(tmp.path().exists());

        let contents = fs::read_to_string(tmp.path()).expect("test: valid file read");
        let parsed: serde_json::Value =
            serde_json::from_str(&contents).expect("test: valid JSON parse");
        // The architecture file is keyed by the architecture label, not the model name.
        assert_eq!(parsed["name"], "Transformer");
        assert_eq!(parsed["layers"].as_array().map(Vec::len), Some(2));
    }

    #[test]
    fn test_export_weights_npy() {
        let tmp = TempPath::new("weights.npy");

        let (rows, cols) = (5, 3);
        let weights = Array2::from_shape_fn((rows, cols), |(i, j)| (i * cols + j) as f64);

        let result = ModelExporter::export_weights_npy(&weights, tmp.path());
        assert!(result.is_ok());

        let bytes = fs::read(tmp.path()).expect("test: valid file read");
        assert!(bytes.starts_with(b"\x93NUMPY"));

        // The little-endian u16 at offset 8 gives the dict length; data follows it.
        let header_len = u16::from_le_bytes([bytes[8], bytes[9]]) as usize;
        let data_start = 10 + header_len;
        let header_text =
            std::str::from_utf8(&bytes[10..data_start]).expect("test: ASCII NPY header");
        assert!(header_text.contains("'descr': '<f8'"));
        assert!(header_text.contains("'fortran_order': False"));
        assert!(header_text.contains("'shape': (5, 3)"));

        // Row-major storage: the k-th stored element holds the value k.
        assert_eq!(bytes.len(), data_start + rows * cols * 8);
        for (k, chunk) in bytes[data_start..].chunks_exact(8).enumerate() {
            let value = f64::from_le_bytes(chunk.try_into().expect("test: 8-byte chunk"));
            assert_eq!(value, k as f64);
        }
    }

    #[test]
    fn test_npy_header_creation() {
        let shape = [3, 4];
        let header = ModelExporter::create_npy_header(&shape, "f8");
        assert!(header.is_ok());

        let header = header.expect("test: valid NPY header creation");
        assert!(header.starts_with(b"\x93NUMPY"));
        assert!(header.len().is_multiple_of(16));
        // The length field must cover exactly the bytes after the 10-byte preamble.
        assert_eq!(
            u16::from_le_bytes([header[8], header[9]]) as usize + 10,
            header.len()
        );
    }

    #[test]
    fn test_npy_header_shape_tuples() {
        // Python tuple syntax: 1-D shapes need a trailing comma, 0-D is `()`.
        let one_d = ModelExporter::create_npy_header(&[7], "f8").expect("test: valid 1-D header");
        assert!(String::from_utf8_lossy(&one_d).contains("'shape': (7,)"));

        let scalar = ModelExporter::create_npy_header(&[], "f8").expect("test: valid 0-D header");
        assert!(String::from_utf8_lossy(&scalar).contains("'shape': ()"));
    }

    #[test]
    fn test_model_export_data_creation() {
        let metadata = ModelMetadata::builder()
            .name("test_model")
            .version("1.0.0")
            .architecture("CNN")
            .hyperparameter("kernel_size", "3")
            .build()
            .expect("test: valid metadata build");

        let layer = LayerData::dense("layer1", Array2::ones((10, 5)), None);
        let model = NumRS2Model::new(metadata, vec![layer]);

        let export_data = ModelExportData::from_model(&model);

        assert_eq!(export_data.name, "test_model");
        assert_eq!(export_data.version, "1.0.0");
        assert_eq!(export_data.architecture, "CNN");
        assert_eq!(export_data.layers.len(), 1);
        assert!(export_data.total_parameters > 0);
    }

    #[test]
    fn test_layer_export_info() {
        let weights = Array2::ones((10, 5));
        let layer = LayerData::dense("test_layer", weights, None);

        let info = LayerExportInfo::from_layer(&layer);

        assert_eq!(info.name, "test_layer");
        assert_eq!(info.layer_type, "Dense");
        assert_eq!(info.input_shape, vec![10]);
        assert_eq!(info.output_shape, vec![5]);
        assert!(info.num_parameters > 0);
    }

    #[test]
    fn test_architecture_description() {
        let metadata = ModelMetadata::builder()
            .name("test_model")
            .architecture("ResNet")
            .build()
            .expect("test: valid metadata build");

        let layer1 = LayerData::dense("layer1", Array2::ones((256, 128)), None);
        let layer2 = LayerData::dense("layer2", Array2::ones((128, 64)), None);
        let model = NumRS2Model::new(metadata, vec![layer1, layer2]);

        let arch = ArchitectureDescription::from_model(&model);

        assert_eq!(arch.name, "ResNet");
        assert_eq!(arch.layers.len(), 2);
        assert!(arch.total_parameters > 0);
    }

    #[test]
    fn test_export_format_enum() {
        assert_ne!(ExportFormat::Json, ExportFormat::MessagePack);
        assert_ne!(ExportFormat::Npy, ExportFormat::Npz);
    }

    #[test]
    fn test_convenience_functions() {
        let json_tmp = TempPath::new("convenience.json");
        let arch_tmp = TempPath::new("convenience_arch.json");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .build()
            .expect("test: valid metadata build");

        let layer = LayerData::dense("layer1", Array2::ones((10, 5)), None);
        let model = NumRS2Model::new(metadata, vec![layer]);

        assert!(export_to_json(&model, json_tmp.path()).is_ok());
        assert!(json_tmp.path().exists());

        assert!(export_architecture(&model, arch_tmp.path()).is_ok());
        assert!(arch_tmp.path().exists());
    }

    #[cfg(not(feature = "messagepack"))]
    #[test]
    fn test_messagepack_requires_feature() {
        let tmp = TempPath::new("feature_gate.msgpack");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .build()
            .expect("test: valid metadata build");
        let model = NumRS2Model::new(metadata, vec![]);

        // Without the feature the stub must fail and must not create a file.
        assert!(export_to_messagepack(&model, tmp.path()).is_err());
        assert!(!tmp.path().exists());
    }

    #[test]
    fn test_export_weights_npz() {
        let tmp = TempPath::new("weights.npz");

        let metadata = ModelMetadata::builder()
            .name("test_model")
            .build()
            .expect("test: valid metadata build");

        let layer1 = LayerData::dense("layer1", Array2::ones((10, 5)), Some(Array1::zeros(5)));
        let layer2 = LayerData::dense("layer2", Array2::ones((5, 2)), None);
        let model = NumRS2Model::new(metadata, vec![layer1, layer2]);

        let result = ModelExporter::export_weights_npz(&model, tmp.path());
        assert!(result.is_ok());
        assert!(tmp.path().exists());

        // A ZIP archive written from the start of a file opens with a "PK"
        // record signature (the local file header of its first entry).
        let bytes = fs::read(tmp.path()).expect("test: valid file read");
        assert!(bytes.starts_with(b"PK"));
    }
}