Skip to main content

scirs2_neural/serialization/
mod.rs

1//! Module for model serialization and deserialization
2//!
3//! This module provides comprehensive serialization support for all neural network
4//! architectures in scirs2-neural, including:
5//!
6//! - **Generic traits**: `ModelSerialize`, `ModelDeserialize`, `ExtractParameters`
7//! - **SafeTensors format**: HuggingFace-compatible binary format (safe, no pickle)
8//! - **Architecture-specific serialization**: ResNet, BERT, GPT, Mamba, EfficientNet, MobileNet
9//! - **Legacy support**: JSON for Sequential models (via `legacy_serialization` feature)
10//!
11//! ## Quick Start
12//!
13//! ```rust
14//! use scirs2_neural::serialization::{ModelSerialize, ModelDeserialize, ModelFormat};
15//!
16//! // ModelFormat enumerates the supported serialization formats
17//! let format = ModelFormat::SafeTensors;
18//! assert_eq!(format, ModelFormat::SafeTensors);
19//! ```
20
21// Sub-modules
22pub mod architecture;
23pub mod model_serializer;
24pub mod safetensors;
25pub mod traits;
26
27// Re-export key types from sub-modules
28pub use architecture::{
29    detect_architecture, detect_architecture_from_bytes, ArchitectureConfig,
30    SerializableBertConfig, SerializableGPTConfig, SerializableMambaConfig,
31    SerializableMobileNetConfig, SerializableResNetConfig,
32};
33pub use model_serializer::{
34    load_bert, load_resnet, named_parameters_to_map, save_bert, save_resnet, ModelSerializer,
35};
36pub use safetensors::{
37    read_named_parameters, validate_safetensors_file, write_named_parameters, SafeTensorsDtype,
38    SafeTensorsHeaderEntry, SafeTensorsReader, SafeTensorsWriter,
39};
40pub use traits::{
41    ExtractParameters, ModelDeserialize, ModelFormat, ModelMetadata, ModelSerialize,
42    NamedParameters, TensorInfo,
43};
44
45// Legacy imports for existing serialization code
46use crate::activations::*;
47use crate::error::{NeuralError, Result};
48use scirs2_core::numeric::Float;
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use std::fmt::Debug;
52
53// Imports needed by legacy feature
54#[cfg(feature = "legacy_serialization")]
55use crate::layers::conv::PaddingMode;
56#[cfg(feature = "legacy_serialization")]
57use crate::layers::*;
58#[cfg(feature = "legacy_serialization")]
59use crate::models::sequential::Sequential;
60#[cfg(feature = "legacy_serialization")]
61use scirs2_core::ndarray::{Array, ScalarOperand};
62#[cfg(feature = "legacy_serialization")]
63use scirs2_core::numeric::{FromPrimitive, NumAssign, ToPrimitive};
64#[cfg(feature = "legacy_serialization")]
65use scirs2_core::random::SeedableRng;
66#[cfg(feature = "legacy_serialization")]
67use std::fmt::Display;
68#[cfg(feature = "legacy_serialization")]
69use std::fs;
70#[cfg(feature = "legacy_serialization")]
71use std::path::Path;
72
/// Model serialization format (legacy)
///
/// Selector accepted by the legacy `save_model` / `load_model` entry points.
/// Both functions currently ignore the selector (their parameter is named
/// `_format`) and always go through `serde_json`, so every variant produces
/// and expects JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
    /// JSON serialization format
    JSON,
    /// CBOR serialization format (serialized as JSON in legacy mode)
    CBOR,
    /// MessagePack serialization format (serialized as JSON in legacy mode)
    MessagePack,
}
83
/// Layer type
///
/// Fieldless tag with one variant per layer kind that also has a variant in
/// `LayerConfig`. Derives `Serialize`/`Deserialize` so it can appear in
/// serialized data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LayerType {
    /// Dense (fully connected) layer
    Dense,
    /// Convolutional 2D layer
    Conv2D,
    /// Layer normalization
    LayerNorm,
    /// Batch normalization
    BatchNorm,
    /// Dropout layer
    Dropout,
    /// Max pooling 2D layer
    MaxPool2D,
}
100
/// Layer configuration for serialization
///
/// `#[serde(tag = "type")]` makes this an internally tagged enum: the variant
/// name is written to a `type` field next to the fields of the wrapped config
/// struct. The per-variant `rename` attributes pin the tag strings explicitly.
///
/// Note: the legacy `serialize_model` / `deserialize_model` functions in this
/// module only handle the `Dense` and `Dropout` variants; every other variant
/// makes them return an error.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum LayerConfig {
    /// Dense layer configuration
    #[serde(rename = "Dense")]
    Dense(DenseConfig),
    /// Conv2D layer configuration
    #[serde(rename = "Conv2D")]
    Conv2D(Conv2DConfig),
    /// LayerNorm layer configuration
    #[serde(rename = "LayerNorm")]
    LayerNorm(LayerNormConfig),
    /// BatchNorm layer configuration
    #[serde(rename = "BatchNorm")]
    BatchNorm(BatchNormConfig),
    /// Dropout layer configuration
    #[serde(rename = "Dropout")]
    Dropout(DropoutConfig),
    /// MaxPool2D layer configuration
    #[serde(rename = "MaxPool2D")]
    MaxPool2D(MaxPool2DConfig),
}
124
/// Dense layer configuration
///
/// Carries what the legacy loader needs to rebuild a `Dense` layer: the two
/// dimensions and an optional activation name, all of which are handed to
/// `Dense::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DenseConfig {
    /// Input dimension
    pub input_dim: usize,
    /// Output dimension
    pub output_dim: usize,
    /// Activation function name
    ///
    /// The legacy serializer currently always writes `None` here because it
    /// cannot read the activation name back from a `Dense` layer.
    pub activation: Option<String>,
}
135
/// Conv2D layer configuration
///
/// Plain data description of a 2D convolution layer. The legacy
/// serialize/deserialize functions in this module do not produce or consume
/// this variant; it is only part of the serializable schema.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conv2DConfig {
    /// Number of input channels
    pub in_channels: usize,
    /// Number of output channels
    pub out_channels: usize,
    /// Kernel size (square)
    pub kernel_size: usize,
    /// Stride
    pub stride: usize,
    /// Padding mode, stored as a string
    // NOTE(review): presumably the textual form of `PaddingMode` — confirm
    // which spellings are expected before relying on it.
    pub padding_mode: String,
}
150
/// LayerNorm layer configuration
///
/// Plain data description of a layer-normalization layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNormConfig {
    /// Normalized shape
    // The field name (without an underscore) is what serde writes to and
    // expects in serialized data, so renaming it would change the format.
    pub normalizedshape: usize,
    /// Epsilon for numerical stability
    pub eps: f64,
}
159
/// BatchNorm layer configuration
///
/// Plain data description of a batch-normalization layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchNormConfig {
    /// Number of features
    pub num_features: usize,
    /// Momentum
    // NOTE(review): presumably the running-statistics momentum — confirm
    // against the BatchNorm layer implementation.
    pub momentum: f64,
    /// Epsilon for numerical stability
    pub eps: f64,
}
170
/// Dropout layer configuration
///
/// The legacy loader passes `p` straight to `Dropout::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DropoutConfig {
    /// Dropout probability
    ///
    /// The legacy serializer always records `0.5`, because it cannot read the
    /// configured probability back from a `Dropout` layer.
    pub p: f64,
}
177
/// MaxPool2D layer configuration
///
/// Plain data description of a 2D max-pooling layer.
// NOTE(review): the tuple fields are presumably (height, width) — confirm
// against the MaxPool2D layer before documenting the order publicly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaxPool2DConfig {
    /// Kernel size
    pub kernel_size: (usize, usize),
    /// Stride
    pub stride: (usize, usize),
    /// Padding (`None` when no padding is recorded)
    pub padding: Option<(usize, usize)>,
}
188
/// Serialized model
///
/// The serde-serializable snapshot that the legacy JSON path writes to and
/// reads from disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedModel {
    /// Model name
    pub name: String,
    /// Model version
    pub version: String,
    /// Model layers configuration
    pub layers: Vec<LayerConfig>,
    /// Model parameters (weights and biases)
    ///
    /// Indexed as `parameters[layer][tensor][element]`: one entry per layer
    /// (aligned with `layers` by position), one inner vector per parameter
    /// tensor of that layer, each holding the tensor's elements flattened in
    /// iteration order and converted to `f64`. Layers without parameters
    /// (e.g. dropout) get an empty entry.
    pub parameters: Vec<Vec<Vec<f64>>>,
}
201
202// =============================================================================
203// Legacy serialization functions (available under legacy_serialization feature)
204// =============================================================================
205
/// Save model to file
///
/// Legacy function for saving `Sequential` models to JSON.
/// For new code, prefer the `SafeTensors`-based API via `ModelSerialize::save()`.
///
/// The model is first converted to a [`SerializedModel`], pretty-printed as
/// JSON and then written to `path`. The `_format` argument is accepted for
/// API compatibility but is not consulted: the output is always JSON.
///
/// # Errors
///
/// * `NeuralError::SerializationError` if the model contains an unsupported
///   layer or JSON encoding fails.
/// * `NeuralError::IOError` if the file cannot be written.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
pub fn save_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    P: AsRef<Path>,
>(
    model: &Sequential<F>,
    path: P,
    _format: SerializationFormat,
) -> Result<()> {
    // Build the snapshot and encode it in one fallible pipeline.
    let json_bytes = serialize_model(model).and_then(|snapshot| {
        serde_json::to_vec_pretty(&snapshot)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))
    })?;

    fs::write(path, json_bytes).map_err(|e| NeuralError::IOError(e.to_string()))
}
226
/// Load model from file
///
/// Legacy function for loading `Sequential` models from JSON.
/// For new code, prefer the `SafeTensors`-based API via `ModelDeserialize::load()`.
///
/// Reads the whole file, decodes it as a JSON [`SerializedModel`] and rebuilds
/// the layers from it. The `_format` argument is accepted for API
/// compatibility but is not consulted: the input is always parsed as JSON.
///
/// # Errors
///
/// * `NeuralError::IOError` if the file cannot be read.
/// * `NeuralError::DeserializationError` if the contents are not a valid
///   serialized model or contain an unsupported layer type.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
pub fn load_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    P: AsRef<Path>,
>(
    path: P,
    _format: SerializationFormat,
) -> Result<Sequential<F>> {
    let raw = fs::read(path).map_err(|e| NeuralError::IOError(e.to_string()))?;

    match serde_json::from_slice::<SerializedModel>(&raw) {
        Ok(snapshot) => deserialize_model(&snapshot),
        Err(e) => Err(NeuralError::DeserializationError(e.to_string())),
    }
}
245
/// Serialize model to SerializedModel
///
/// Walks the layers of `model` in order and records, for each one, a
/// [`LayerConfig`] plus its parameter tensors flattened to `f64`.
///
/// Only `Dense` and `Dropout` layers are understood:
/// * `Dense` — dimensions are recorded; the activation is always written as
///   `None` because the layer does not expose its activation name.
/// * `Dropout` — the probability is always written as `0.5` because the layer
///   does not expose it; no parameters are stored.
///
/// # Errors
///
/// Returns `NeuralError::SerializationError` for any other layer type or when
/// a parameter value cannot be converted to `f64`.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn serialize_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    model: &Sequential<F>,
) -> Result<SerializedModel> {
    let mut layers = Vec::new();
    let mut parameters = Vec::new();

    for layer in model.layers() {
        let erased = layer.as_any();

        // Resolve the concrete layer type into (config, flattened parameters).
        let (config, layer_values) = if let Some(dense) = erased.downcast_ref::<Dense<F>>() {
            let dense_config = LayerConfig::Dense(DenseConfig {
                input_dim: dense.input_dim(),
                output_dim: dense.output_dim(),
                activation: None, // Dense::activation_name() not available
            });
            let owned_tensors = dense.get_parameters();
            let values = extract_parameters::<F>(owned_tensors.iter().collect())?;
            (dense_config, values)
        } else if erased.is::<Dropout<F>>() {
            // p() not available on Dropout, so a fixed probability is recorded.
            (LayerConfig::Dropout(DropoutConfig { p: 0.5 }), Vec::new())
        } else {
            return Err(NeuralError::SerializationError(
                "Unsupported layer type for legacy serialization. Use SafeTensors API instead."
                    .to_string(),
            ));
        };

        layers.push(config);
        parameters.push(layer_values);
    }

    Ok(SerializedModel {
        name: "SciRS2 Sequential Model".to_string(),
        version: "0.1.0".to_string(),
        layers,
        parameters,
    })
}
290
/// Extract parameters from layer
///
/// Flattens each tensor in `params` (in its iteration order) into a
/// `Vec<f64>`, returning one vector per input tensor in the same order.
///
/// # Errors
///
/// Returns `NeuralError::SerializationError` as soon as an element cannot be
/// represented as `f64`.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn extract_parameters<F: Float + Debug + ScalarOperand + Send + Sync>(
    params: Vec<&Array<F, scirs2_core::ndarray::IxDyn>>,
) -> Result<Vec<Vec<f64>>> {
    params
        .iter()
        .map(|tensor| {
            tensor
                .iter()
                .map(|value| {
                    value.to_f64().ok_or_else(|| {
                        NeuralError::SerializationError(
                            "Cannot convert parameter to f64".to_string(),
                        )
                    })
                })
                .collect::<Result<Vec<f64>>>()
        })
        .collect()
}
311
/// Deserialize model from SerializedModel
///
/// Rebuilds a `Sequential` model by turning each [`LayerConfig`] back into a
/// layer, pairing it by position with the matching entry of
/// `serialized.parameters`. A layer with no matching parameter entry is built
/// with an empty parameter list.
///
/// # Errors
///
/// Returns `NeuralError::DeserializationError` for any layer kind other than
/// `Dense` or `Dropout`, and propagates errors from layer construction.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn deserialize_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    serialized: &SerializedModel,
) -> Result<Sequential<F>> {
    let no_params: Vec<Vec<f64>> = Vec::new();
    let mut rebuilt: Vec<Box<dyn Layer<F> + Send + Sync>> = Vec::new();

    for (index, layer_config) in serialized.layers.iter().enumerate() {
        // Missing parameter entries fall back to an empty list.
        let layer_params = serialized.parameters.get(index).unwrap_or(&no_params);

        let layer: Box<dyn Layer<F> + Send + Sync> = match layer_config {
            LayerConfig::Dense(dense_config) => {
                Box::new(create_dense_layer::<F>(dense_config, layer_params)?)
            }
            LayerConfig::Dropout(dropout_config) => Box::new(create_dropout::<F>(dropout_config)?),
            LayerConfig::Conv2D(_)
            | LayerConfig::LayerNorm(_)
            | LayerConfig::BatchNorm(_)
            | LayerConfig::MaxPool2D(_) => {
                return Err(NeuralError::DeserializationError(
                    "Layer type not supported in legacy deserialization. Use SafeTensors API."
                        .to_string(),
                ));
            }
        };
        rebuilt.push(layer);
    }

    Ok(Sequential::from_layers(rebuilt))
}
350
/// Create a Dense layer from configuration and parameters
///
/// Builds a `Dense` layer with the configured dimensions and activation
/// (using a fixed-seed RNG for the initial weights) and, when serialized
/// parameters are available, overwrites its weights and bias with them.
///
/// `params` is expected to hold the flattened weights at index 0 (interpreted
/// with shape `[input_dim, output_dim]`) and the flattened bias at index 1
/// (shape `[output_dim]`). Additional entries are ignored. When fewer than two
/// entries are present the layer is returned with its freshly initialised
/// parameters.
///
/// # Errors
///
/// Returns `NeuralError::SerializationError` when `input_dim * output_dim`
/// overflows `usize`, when the weight or bias vector length does not match
/// the configured shape, or when a value cannot be converted to `F`.
/// Errors from `Dense::new` and `set_parameters` are propagated.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn create_dense_layer<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    config: &DenseConfig,
    params: &[Vec<f64>],
) -> Result<Dense<F>> {
    // The dimensions come from a file, so guard the product against overflow
    // before using it (a plain `*` panics in debug and wraps in release).
    let expected_weights = config
        .input_dim
        .checked_mul(config.output_dim)
        .ok_or_else(|| {
            NeuralError::SerializationError(format!(
                "Dense layer dimensions ({} x {}) overflow the addressable size",
                config.input_dim, config.output_dim
            ))
        })?;

    let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
    let mut layer = Dense::new(
        config.input_dim,
        config.output_dim,
        config.activation.as_deref(),
        &mut rng,
    )?;

    // Without both a weight and a bias tensor there is nothing to restore.
    let (weights, bias) = match params {
        [weights, bias, ..] => (weights, bias),
        _ => return Ok(layer),
    };

    if weights.len() != expected_weights {
        return Err(NeuralError::SerializationError(format!(
            "Weight vector length ({}) doesn't match expected shape size ({})",
            weights.len(),
            expected_weights
        )));
    }

    // The length is already validated, so the only remaining failure here is
    // an element that cannot be converted to `F`. Retrying with a transposed
    // shape (as older code did) would fail in exactly the same way.
    let weights_array = array_from_vec::<F>(weights, &[config.input_dim, config.output_dim])?;
    let bias_array = array_from_vec::<F>(bias, &[config.output_dim])?;
    layer.set_parameters(vec![weights_array, bias_array])?;

    Ok(layer)
}
393
/// Create a Dropout layer from configuration
///
/// Constructs the layer with the probability stored in `config`, handing
/// `Dropout::new` an RNG created from a fixed seed.
///
/// # Errors
///
/// Propagates whatever error `Dropout::new` reports.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn create_dropout<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    config: &DropoutConfig,
) -> Result<Dropout<F>> {
    let probability = config.p;
    let mut seeded_rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
    Dropout::new(probability, &mut seeded_rng)
}
405
/// Convert a vector of f64 values to an ndarray with the given shape
///
/// Converts every element of `vec` to `F` and reshapes the result into a
/// dynamic-dimension array of shape `shape`.
///
/// # Errors
///
/// Returns `NeuralError::SerializationError` when `vec.len()` differs from the
/// product of `shape`, when an element cannot be converted to `F`, or when the
/// array cannot be built from the shape.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn array_from_vec<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    vec: &[f64],
    shape: &[usize],
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
    let expected_len = shape.iter().product::<usize>();
    if vec.len() != expected_len {
        return Err(NeuralError::SerializationError(format!(
            "Parameter vector length ({}) doesn't match expected shape size ({})",
            vec.len(),
            expected_len
        )));
    }

    // Convert element by element, stopping at the first value `F` cannot hold.
    let mut converted: Vec<F> = Vec::with_capacity(vec.len());
    for &value in vec {
        match F::from(value) {
            Some(as_target) => converted.push(as_target),
            None => {
                return Err(NeuralError::SerializationError(format!(
                    "Cannot convert {} to target type",
                    value
                )));
            }
        }
    }

    Array::from_shape_vec(scirs2_core::ndarray::IxDyn(shape), converted)
        .map_err(|e| NeuralError::SerializationError(e.to_string()))
}
435
436// =============================================================================
437// Activation function utilities (always available)
438// =============================================================================
439
/// Serializable activation function
///
/// A serde-serializable description of an activation. It converts to and from
/// a textual name (`from_name` / `to_name`) and can build the corresponding
/// runtime activation object (`create`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationFunction {
    /// ReLU activation
    ReLU,
    /// Sigmoid activation
    Sigmoid,
    /// Tanh activation
    Tanh,
    /// Softmax activation
    Softmax,
    /// LeakyReLU activation; the payload is the `alpha` value passed to
    /// `LeakyReLU::new`
    LeakyReLU(f64),
    /// ELU activation (serialized; implemented as LeakyReLU for forward compat);
    /// the payload is the `alpha` value
    ELU(f64),
    /// GELU activation
    GELU,
    /// Swish activation
    Swish,
    /// Mish activation
    Mish,
}
462
463impl ActivationFunction {
464    /// Convert activation function name to ActivationFunction enum
465    pub fn from_name(name: &str) -> Option<Self> {
466        match name {
467            "relu" | "ReLU" => Some(ActivationFunction::ReLU),
468            "sigmoid" | "Sigmoid" => Some(ActivationFunction::Sigmoid),
469            "tanh" | "Tanh" => Some(ActivationFunction::Tanh),
470            "softmax" | "Softmax" => Some(ActivationFunction::Softmax),
471            "gelu" | "GELU" => Some(ActivationFunction::GELU),
472            "swish" | "Swish" => Some(ActivationFunction::Swish),
473            "mish" | "Mish" => Some(ActivationFunction::Mish),
474            _ => {
475                if name.starts_with("leaky_relu") || name.starts_with("LeakyReLU") {
476                    let parts: Vec<&str> = name.split('(').collect();
477                    if parts.len() == 2 {
478                        let alpha_str = parts[1].trim_end_matches(')');
479                        if let Ok(alpha) = alpha_str.parse::<f64>() {
480                            return Some(ActivationFunction::LeakyReLU(alpha));
481                        }
482                    }
483                    Some(ActivationFunction::LeakyReLU(0.01))
484                } else if name.starts_with("elu") || name.starts_with("ELU") {
485                    let parts: Vec<&str> = name.split('(').collect();
486                    if parts.len() == 2 {
487                        let alpha_str = parts[1].trim_end_matches(')');
488                        if let Ok(alpha) = alpha_str.parse::<f64>() {
489                            return Some(ActivationFunction::ELU(alpha));
490                        }
491                    }
492                    Some(ActivationFunction::ELU(1.0))
493                } else {
494                    None
495                }
496            }
497        }
498    }
499
500    /// Convert ActivationFunction enum to activation function name
501    pub fn to_name(&self) -> String {
502        match self {
503            ActivationFunction::ReLU => "relu".to_string(),
504            ActivationFunction::Sigmoid => "sigmoid".to_string(),
505            ActivationFunction::Tanh => "tanh".to_string(),
506            ActivationFunction::Softmax => "softmax".to_string(),
507            ActivationFunction::LeakyReLU(alpha) => format!("leaky_relu({})", alpha),
508            ActivationFunction::ELU(alpha) => format!("elu({})", alpha),
509            ActivationFunction::GELU => "gelu".to_string(),
510            ActivationFunction::Swish => "swish".to_string(),
511            ActivationFunction::Mish => "mish".to_string(),
512        }
513    }
514
515    /// Create activation function from enum
516    ///
517    /// Note: ELU is not currently implemented; it falls back to LeakyReLU.
518    pub fn create<
519        F: Float + Debug + scirs2_core::NumAssign + scirs2_core::ndarray::ScalarOperand + Send + Sync,
520    >(
521        &self,
522    ) -> Box<dyn Activation<F>> {
523        match self {
524            ActivationFunction::ReLU => Box::new(ReLU::new()),
525            ActivationFunction::Sigmoid => Box::new(Sigmoid::new()),
526            ActivationFunction::Tanh => Box::new(Tanh::new()),
527            ActivationFunction::Softmax => Box::new(Softmax::new(1)),
528            ActivationFunction::LeakyReLU(alpha) => Box::new(LeakyReLU::new(*alpha)),
529            ActivationFunction::ELU(alpha) => Box::new(LeakyReLU::new(*alpha)),
530            ActivationFunction::GELU => Box::new(GELU::new()),
531            ActivationFunction::Swish => Box::new(Swish::new(1.0)),
532            ActivationFunction::Mish => Box::new(Mish::new()),
533        }
534    }
535}
536
537/// Activation function factory
538pub struct ActivationFactory;
539
540impl ActivationFactory {
541    /// Create activation function from name
542    pub fn create<
543        F: Float + Debug + scirs2_core::NumAssign + scirs2_core::ndarray::ScalarOperand + Send + Sync,
544    >(
545        name: &str,
546    ) -> Option<Box<dyn Activation<F>>> {
547        ActivationFunction::from_name(name).map(|af| af.create::<F>())
548    }
549
550    /// Get activation function names
551    pub fn get_activation_names() -> HashMap<&'static str, &'static str> {
552        let mut names = HashMap::new();
553        names.insert("relu", "ReLU activation function");
554        names.insert("sigmoid", "Sigmoid activation function");
555        names.insert("tanh", "Tanh activation function");
556        names.insert("softmax", "Softmax activation function");
557        names.insert("leaky_relu", "Leaky ReLU activation function");
558        names.insert("elu", "ELU activation function");
559        names.insert("gelu", "GELU activation function");
560        names.insert("swish", "Swish activation function");
561        names.insert("mish", "Mish activation function");
562        names
563    }
564}
565
566#[cfg(test)]
567mod tests;