Skip to main content

scirs2_neural/serialization/
traits.rs

1//! Generic model serialization traits
2//!
3//! This module provides the `ModelSerialize` and `ModelDeserialize` traits
4//! that allow any neural network architecture to be saved to and loaded from disk.
5//! These traits work with multiple formats (JSON, SafeTensors, etc.) and handle
6//! nested layers, attention heads, and normalization parameters.
7
8use crate::error::Result;
9use std::path::Path;
10
/// Supported serialization formats for model persistence.
///
/// The enum is `Copy` and implements `Eq` + `Hash`, so a format can be passed by
/// value and used as a key in a `HashMap`/`HashSet` (e.g. to register one encoder
/// or file extension per format).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelFormat {
    /// JSON format - human-readable, larger files
    Json,
    /// SafeTensors format - binary, HuggingFace-compatible
    SafeTensors,
    /// CBOR format - binary, compact
    Cbor,
    /// MessagePack format - binary, compact
    MessagePack,
}
23
24/// Trait for serializing a model to disk
25///
26/// Any neural network architecture that implements this trait can be saved
27/// to a file in one of the supported formats. The serialization captures
28/// both the model configuration (architecture) and the learned parameters (weights).
29///
30/// # Example
31///
32/// ```rust
33/// use scirs2_neural::serialization::traits::{ModelSerialize, ModelFormat};
34///
35/// // ModelSerialize is a trait implemented by model architectures.
36/// // Example usage (with a model that implements ModelSerialize):
37/// let format = ModelFormat::SafeTensors;
38/// assert_eq!(format, ModelFormat::SafeTensors);
39/// ```
40pub trait ModelSerialize {
41    /// Save the model to the specified path in the given format
42    ///
43    /// This method serializes both the model architecture (configuration)
44    /// and all learned parameters (weights, biases, normalization stats, etc.)
45    fn save(&self, path: &Path, format: ModelFormat) -> Result<()>;
46
47    /// Serialize the model to bytes in the given format
48    ///
49    /// This is useful when you want to store the serialized model in memory
50    /// or send it over a network rather than writing to disk.
51    fn to_bytes(&self, format: ModelFormat) -> Result<Vec<u8>>;
52
53    /// Get the architecture name for this model (e.g., "ResNet", "BERT", "GPT")
54    fn architecture_name(&self) -> &str;
55
56    /// Get the model version string
57    fn model_version(&self) -> String {
58        "0.1.0".to_string()
59    }
60}
61
62/// Trait for deserializing a model from disk
63///
64/// Any neural network architecture that implements this trait can be loaded
65/// from a file that was previously saved with `ModelSerialize`.
66///
67/// # Example
68///
69/// ```rust
70/// use scirs2_neural::serialization::traits::{ModelDeserialize, ModelFormat};
71///
72/// // ModelDeserialize is a trait implemented by model architectures.
73/// // Example usage (with a model that implements ModelDeserialize):
74/// let format = ModelFormat::Json;
75/// assert_eq!(format, ModelFormat::Json);
76/// ```
77pub trait ModelDeserialize: Sized {
78    /// Load the model from the specified path in the given format
79    ///
80    /// This method deserializes both the model architecture and all
81    /// learned parameters, reconstructing a fully functional model.
82    fn load(path: &Path, format: ModelFormat) -> Result<Self>;
83
84    /// Deserialize the model from bytes in the given format
85    ///
86    /// This is useful when loading from a network stream or in-memory buffer.
87    fn from_bytes(bytes: &[u8], format: ModelFormat) -> Result<Self>;
88}
89
/// Metadata about a serialized model, stored alongside the weights.
///
/// Derives `serde::Serialize` / `serde::Deserialize`.
/// NOTE(review): avoid reordering fields without checking the encoders — serde
/// formats that encode structs positionally (e.g. compact MessagePack) would
/// change their on-disk layout.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ModelMetadata {
    /// Architecture name (e.g., "ResNet", "BERT", "GPT")
    pub architecture: String,
    /// Model version; `ModelMetadata::new` initialises it to "0.1.0"
    pub version: String,
    /// Framework version that produced this file; `ModelMetadata::new` fills it
    /// from this crate's `CARGO_PKG_VERSION` at compile time
    pub framework_version: String,
    /// Number of parameters in the model, as supplied by the caller
    /// (presumably the scalar count — TODO confirm against producers)
    pub num_parameters: usize,
    /// Data type used for parameters (e.g., "f32", "f64")
    pub dtype: String,
    /// Additional free-form key-value metadata (see `ModelMetadata::with_extra`)
    pub extra: std::collections::HashMap<String, String>,
}
106
107impl ModelMetadata {
108    /// Create new metadata for a model
109    pub fn new(architecture: &str, dtype: &str, num_parameters: usize) -> Self {
110        Self {
111            architecture: architecture.to_string(),
112            version: "0.1.0".to_string(),
113            framework_version: env!("CARGO_PKG_VERSION").to_string(),
114            num_parameters,
115            dtype: dtype.to_string(),
116            extra: std::collections::HashMap::new(),
117        }
118    }
119
120    /// Add an extra metadata key-value pair
121    pub fn with_extra(mut self, key: &str, value: &str) -> Self {
122        self.extra.insert(key.to_string(), value.to_string());
123        self
124    }
125}
126
/// Information about a single tensor in a serialized model.
///
/// This is a plain descriptor: nothing here checks that `shape`, `dtype` and
/// `byte_length` agree with each other.
/// NOTE(review): derives serde; keep field order stable for positional formats.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TensorInfo {
    /// Name of the tensor (e.g., "layer1.weight", "encoder.attention.query")
    pub name: String,
    /// Data type (e.g., "F32", "F64")
    pub dtype: String,
    /// Shape of the tensor (one entry per dimension)
    pub shape: Vec<usize>,
    /// Byte offset in the data section (the data section layout is defined by
    /// the serializer — presumably relative to its start; TODO confirm)
    pub data_offset: usize,
    /// Number of bytes for this tensor
    pub byte_length: usize,
}
141
142impl TensorInfo {
143    /// Create a new TensorInfo
144    pub fn new(
145        name: &str,
146        dtype: &str,
147        shape: Vec<usize>,
148        data_offset: usize,
149        byte_length: usize,
150    ) -> Self {
151        Self {
152            name: name.to_string(),
153            dtype: dtype.to_string(),
154            shape,
155            data_offset,
156            byte_length,
157        }
158    }
159
160    /// Get the total number of elements in this tensor
161    pub fn num_elements(&self) -> usize {
162        if self.shape.is_empty() {
163            0
164        } else {
165            self.shape.iter().product()
166        }
167    }
168}
169
/// A named parameter collection that can be extracted from any model
///
/// This provides a uniform interface for accessing model parameters
/// regardless of the underlying architecture. Values are always held as `f64`.
#[derive(Debug, Clone)]
pub struct NamedParameters {
    /// Ordered list of (name, flattened_f64_values, shape) tuples, in insertion
    /// order. Names are not deduplicated (`add` simply appends). The flattening
    /// order is up to the producer — presumably row-major; TODO confirm.
    pub parameters: Vec<(String, Vec<f64>, Vec<usize>)>,
}
179
180impl NamedParameters {
181    /// Create a new empty NamedParameters collection
182    pub fn new() -> Self {
183        Self {
184            parameters: Vec::new(),
185        }
186    }
187
188    /// Add a parameter tensor
189    pub fn add(&mut self, name: &str, values: Vec<f64>, shape: Vec<usize>) {
190        self.parameters.push((name.to_string(), values, shape));
191    }
192
193    /// Get the total number of scalar parameters
194    pub fn total_parameters(&self) -> usize {
195        self.parameters.iter().map(|(_, v, _)| v.len()).sum()
196    }
197
198    /// Find a parameter by name
199    pub fn get(&self, name: &str) -> Option<&(String, Vec<f64>, Vec<usize>)> {
200        self.parameters.iter().find(|(n, _, _)| n == name)
201    }
202
203    /// Get the number of named parameter groups
204    pub fn len(&self) -> usize {
205        self.parameters.len()
206    }
207
208    /// Check if empty
209    pub fn is_empty(&self) -> bool {
210        self.parameters.is_empty()
211    }
212}
213
214impl Default for NamedParameters {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
220/// Trait for extracting named parameters from a model
221///
222/// This trait provides a standardized way to extract all named parameters
223/// from any model architecture, enabling format-agnostic serialization.
224pub trait ExtractParameters {
225    /// Extract all named parameters from the model
226    ///
227    /// Parameters are returned as named `(String, Vec<f64>, Vec<usize>)` tuples
228    /// where the first element is the name (e.g., "encoder.layer.0.attention.query.weight"),
229    /// the second is the flattened parameter values, and the third is the shape.
230    fn extract_named_parameters(&self) -> Result<NamedParameters>;
231
232    /// Load named parameters into the model
233    ///
234    /// This method takes a NamedParameters collection and sets the model's
235    /// parameters accordingly. Parameter names must match those returned
236    /// by `extract_named_parameters`.
237    fn load_named_parameters(&mut self, params: &NamedParameters) -> Result<()>;
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_metadata_creation() {
        let meta = ModelMetadata::new("ResNet", "f32", 11_000_000);
        assert_eq!(meta.architecture, "ResNet");
        assert_eq!(meta.dtype, "f32");
        assert_eq!(meta.num_parameters, 11_000_000);
    }

    #[test]
    fn test_model_metadata_with_extra() {
        let meta = ModelMetadata::new("BERT", "f32", 110_000_000)
            .with_extra("variant", "base-uncased")
            .with_extra("vocab_size", "30522");
        let lookup = |key: &str| meta.extra.get(key).map(String::as_str);
        assert_eq!(lookup("variant"), Some("base-uncased"));
        assert_eq!(lookup("vocab_size"), Some("30522"));
    }

    #[test]
    fn test_tensor_info() {
        let bytes = 768 * 3072 * 4;
        let info = TensorInfo::new("layer1.weight", "F32", vec![768, 3072], 0, bytes);
        assert_eq!(info.num_elements(), 768 * 3072);
        assert_eq!(info.byte_length, bytes);
    }

    #[test]
    fn test_tensor_info_empty_shape() {
        let empty = TensorInfo::new("empty", "F32", Vec::new(), 0, 0);
        assert_eq!(empty.num_elements(), 0);
    }

    #[test]
    fn test_named_parameters() {
        let mut params = NamedParameters::new();
        assert!(params.is_empty());
        assert_eq!(params.len(), 0);

        params.add("layer1.weight", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        params.add("layer1.bias", vec![0.1, 0.2], vec![2]);

        assert_eq!(params.len(), 2);
        assert!(!params.is_empty());
        assert_eq!(params.total_parameters(), 6);

        let (name, values, shape) = params
            .get("layer1.weight")
            .expect("parameter should exist");
        assert_eq!(name, "layer1.weight");
        assert_eq!(values, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(shape, &[2, 2]);

        assert!(params.get("nonexistent").is_none());
    }

    #[test]
    fn test_model_format_enum() {
        let fmt = ModelFormat::SafeTensors;
        assert_eq!(fmt, ModelFormat::SafeTensors);
        assert_ne!(fmt, ModelFormat::Json);

        // Every variant must be constructible.
        let all = [
            ModelFormat::Json,
            ModelFormat::SafeTensors,
            ModelFormat::Cbor,
            ModelFormat::MessagePack,
        ];
        assert_eq!(all.len(), 4);
    }
}