1pub mod architecture;
23pub mod model_serializer;
24pub mod safetensors;
25pub mod traits;
26
27pub use architecture::{
29 detect_architecture, detect_architecture_from_bytes, ArchitectureConfig,
30 SerializableBertConfig, SerializableGPTConfig, SerializableMambaConfig,
31 SerializableMobileNetConfig, SerializableResNetConfig,
32};
33pub use model_serializer::{
34 load_bert, load_resnet, named_parameters_to_map, save_bert, save_resnet, ModelSerializer,
35};
36pub use safetensors::{
37 read_named_parameters, validate_safetensors_file, write_named_parameters, SafeTensorsDtype,
38 SafeTensorsHeaderEntry, SafeTensorsReader, SafeTensorsWriter,
39};
40pub use traits::{
41 ExtractParameters, ModelDeserialize, ModelFormat, ModelMetadata, ModelSerialize,
42 NamedParameters, TensorInfo,
43};
44
45use crate::activations::*;
47use crate::error::{NeuralError, Result};
48use scirs2_core::numeric::Float;
49use serde::{Deserialize, Serialize};
50use std::collections::HashMap;
51use std::fmt::Debug;
52
53#[cfg(feature = "legacy_serialization")]
55use crate::layers::conv::PaddingMode;
56#[cfg(feature = "legacy_serialization")]
57use crate::layers::*;
58#[cfg(feature = "legacy_serialization")]
59use crate::models::sequential::Sequential;
60#[cfg(feature = "legacy_serialization")]
61use scirs2_core::ndarray::{Array, ScalarOperand};
62#[cfg(feature = "legacy_serialization")]
63use scirs2_core::numeric::{FromPrimitive, NumAssign, ToPrimitive};
64#[cfg(feature = "legacy_serialization")]
65use scirs2_core::random::SeedableRng;
66#[cfg(feature = "legacy_serialization")]
67use std::fmt::Display;
68#[cfg(feature = "legacy_serialization")]
69use std::fs;
70#[cfg(feature = "legacy_serialization")]
71use std::path::Path;
72
/// Encoding requested by callers of the legacy save/load entry points.
///
/// NOTE(review): `save_model` and `load_model` in this file bind the value to
/// `_format` and always go through `serde_json`, so the selected variant has
/// no effect on those two functions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
    /// JSON encoding.
    JSON,
    /// CBOR encoding.
    CBOR,
    /// MessagePack encoding.
    MessagePack,
}
83
/// Payload-free tag for each layer kind known to the serialization configs.
///
/// The variant names match the variants of [`LayerConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LayerType {
    /// Dense layer (see [`DenseConfig`]).
    Dense,
    /// 2D convolution layer (see [`Conv2DConfig`]).
    Conv2D,
    /// Layer normalization (see [`LayerNormConfig`]).
    LayerNorm,
    /// Batch normalization (see [`BatchNormConfig`]).
    BatchNorm,
    /// Dropout (see [`DropoutConfig`]).
    Dropout,
    /// 2D max pooling (see [`MaxPool2DConfig`]).
    MaxPool2D,
}
100
/// Serializable configuration of a single layer.
///
/// Serde's internally tagged representation is used: the serialized form
/// carries a `type` field whose value selects the variant. The explicit
/// `rename` attributes pin those tag strings independently of the Rust
/// variant identifiers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum LayerConfig {
    /// Dense layer configuration; tagged `"Dense"`.
    #[serde(rename = "Dense")]
    Dense(DenseConfig),
    /// 2D convolution configuration; tagged `"Conv2D"`.
    #[serde(rename = "Conv2D")]
    Conv2D(Conv2DConfig),
    /// Layer normalization configuration; tagged `"LayerNorm"`.
    #[serde(rename = "LayerNorm")]
    LayerNorm(LayerNormConfig),
    /// Batch normalization configuration; tagged `"BatchNorm"`.
    #[serde(rename = "BatchNorm")]
    BatchNorm(BatchNormConfig),
    /// Dropout configuration; tagged `"Dropout"`.
    #[serde(rename = "Dropout")]
    Dropout(DropoutConfig),
    /// 2D max pooling configuration; tagged `"MaxPool2D"`.
    #[serde(rename = "MaxPool2D")]
    MaxPool2D(MaxPool2DConfig),
}
124
/// Serializable description of a `Dense` layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DenseConfig {
    /// Input dimension. The legacy loader passes it to `Dense::new` and uses
    /// it as the first axis of the restored weight array.
    pub input_dim: usize,
    /// Output dimension. The legacy loader passes it to `Dense::new` and uses
    /// it as the second axis of the weight array and as the bias length.
    pub output_dim: usize,
    /// Optional activation name, forwarded to `Dense::new` on load.
    /// The legacy `serialize_model` in this file always writes `None`.
    pub activation: Option<String>,
}
135
/// Serializable description of a 2D convolution layer.
///
/// The legacy loader in this file does not build this layer type (it returns
/// an error for it), so the field notes below are read off the field names
/// only and should be confirmed against the layer implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conv2DConfig {
    /// Number of input channels.
    pub in_channels: usize,
    /// Number of output channels.
    pub out_channels: usize,
    /// Kernel size (a single value; presumably used for both spatial axes — TODO confirm).
    pub kernel_size: usize,
    /// Stride (a single value; presumably used for both spatial axes — TODO confirm).
    pub stride: usize,
    /// Padding mode stored as a string (presumably the name of a
    /// `PaddingMode` variant — TODO confirm).
    pub padding_mode: String,
}
150
/// Serializable description of a layer-normalization layer.
///
/// The legacy loader in this file does not build this layer type; field notes
/// are read off the field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerNormConfig {
    /// Size of the normalized dimension (presumably the trailing feature axis — TODO confirm).
    pub normalizedshape: usize,
    /// Epsilon value (presumably added for numerical stability — TODO confirm).
    pub eps: f64,
}
159
/// Serializable description of a batch-normalization layer.
///
/// The legacy loader in this file does not build this layer type; field notes
/// are read off the field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchNormConfig {
    /// Number of features.
    pub num_features: usize,
    /// Momentum (presumably for the running statistics — TODO confirm).
    pub momentum: f64,
    /// Epsilon value (presumably added for numerical stability — TODO confirm).
    pub eps: f64,
}
170
/// Serializable description of a `Dropout` layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DropoutConfig {
    /// Value passed as the first argument of `Dropout::new` by the legacy
    /// loader (presumably the drop probability — TODO confirm).
    /// The legacy `serialize_model` in this file always writes `0.5`.
    pub p: f64,
}
177
/// Serializable description of a 2D max-pooling layer.
///
/// The legacy loader in this file does not build this layer type; field notes
/// are read off the field names only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MaxPool2DConfig {
    /// Pooling window size as a pair (presumably `(height, width)` — TODO confirm).
    pub kernel_size: (usize, usize),
    /// Stride as a pair (presumably `(height, width)` — TODO confirm).
    pub stride: (usize, usize),
    /// Optional padding as a pair; `None` when not specified.
    pub padding: Option<(usize, usize)>,
}
188
/// Self-contained, serde-friendly snapshot of a sequential model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedModel {
    /// Free-form model name.
    pub name: String,
    /// Version string of the snapshot.
    pub version: String,
    /// Layer configurations, in model order.
    pub layers: Vec<LayerConfig>,
    /// Parameters indexed like `layers`: for each layer, a list of parameter
    /// arrays, each flattened to a `Vec<f64>`. The legacy loader treats a
    /// missing entry as "no parameters".
    pub parameters: Vec<Vec<Vec<f64>>>,
}
201
/// Serializes `model` and writes it to `path`.
///
/// The model is converted with `serialize_model`, encoded as pretty-printed
/// JSON through `serde_json`, and written with `std::fs::write`.
///
/// NOTE(review): `_format` is accepted but never read; the output is JSON
/// whichever [`SerializationFormat`] is passed.
///
/// # Errors
///
/// * Whatever `serialize_model` returns (e.g. an unsupported layer type).
/// * [`NeuralError::SerializationError`] if JSON encoding fails.
/// * [`NeuralError::IOError`] if the file cannot be written.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
pub fn save_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    P: AsRef<Path>,
>(
    model: &Sequential<F>,
    path: P,
    _format: SerializationFormat,
) -> Result<()> {
    // Snapshot first, then encode; either step may fail before any I/O happens.
    let encoded = serialize_model(model).and_then(|snapshot| {
        serde_json::to_vec_pretty(&snapshot)
            .map_err(|err| NeuralError::SerializationError(err.to_string()))
    })?;

    fs::write(path, encoded).map_err(|err| NeuralError::IOError(err.to_string()))
}
226
/// Reads a model previously written by `save_model` from `path`.
///
/// The file is read in full, decoded as JSON into a [`SerializedModel`], and
/// rebuilt with `deserialize_model`.
///
/// NOTE(review): `_format` is accepted but never read; the file is always
/// parsed as JSON.
///
/// # Errors
///
/// * [`NeuralError::IOError`] if the file cannot be read.
/// * [`NeuralError::DeserializationError`] if the bytes are not a valid
///   JSON-encoded [`SerializedModel`].
/// * Whatever `deserialize_model` returns (e.g. an unsupported layer type).
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
pub fn load_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
    P: AsRef<Path>,
>(
    path: P,
    _format: SerializationFormat,
) -> Result<Sequential<F>> {
    let raw = fs::read(path).map_err(|err| NeuralError::IOError(err.to_string()))?;

    let snapshot = serde_json::from_slice::<SerializedModel>(&raw)
        .map_err(|err| NeuralError::DeserializationError(err.to_string()))?;

    deserialize_model::<F>(&snapshot)
}
245
/// Converts a [`Sequential`] model into a [`SerializedModel`] snapshot.
///
/// Each layer is downcast through `as_any()`:
///
/// * `Dense<F>` layers record their input/output dimensions and their
///   parameters (converted to `f64` by `extract_parameters`).
/// * `Dropout<F>` layers record a config and an empty parameter list.
/// * Any other layer type aborts with [`NeuralError::SerializationError`].
///
/// NOTE(review): the dense `activation` is always written as `None`, and the
/// dropout `p` is always written as `0.5`; neither value is read from the
/// layer itself.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn serialize_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    model: &Sequential<F>,
) -> Result<SerializedModel> {
    let mut layer_configs = Vec::new();
    let mut layer_params = Vec::new();

    for layer in model.layers() {
        let erased = layer.as_any();

        if let Some(dense) = erased.downcast_ref::<Dense<F>>() {
            let owned = dense.get_parameters();
            let borrowed: Vec<&Array<F, scirs2_core::ndarray::IxDyn>> = owned.iter().collect();
            layer_configs.push(LayerConfig::Dense(DenseConfig {
                input_dim: dense.input_dim(),
                output_dim: dense.output_dim(),
                activation: None,
            }));
            layer_params.push(extract_parameters(borrowed)?);
            continue;
        }

        if erased.downcast_ref::<Dropout<F>>().is_some() {
            // Dropout has no parameters; keep the two vectors index-aligned.
            layer_configs.push(LayerConfig::Dropout(DropoutConfig { p: 0.5 }));
            layer_params.push(Vec::new());
            continue;
        }

        return Err(NeuralError::SerializationError(
            "Unsupported layer type for legacy serialization. Use SafeTensors API instead."
                .to_string(),
        ));
    }

    Ok(SerializedModel {
        name: "SciRS2 Sequential Model".to_string(),
        version: "0.1.0".to_string(),
        layers: layer_configs,
        parameters: layer_params,
    })
}
290
/// Flattens each parameter array into a `Vec<f64>`, in the array's iteration order.
///
/// # Errors
///
/// Returns [`NeuralError::SerializationError`] as soon as one element cannot
/// be converted to `f64`.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn extract_parameters<F: Float + Debug + ScalarOperand + Send + Sync>(
    params: Vec<&Array<F, scirs2_core::ndarray::IxDyn>>,
) -> Result<Vec<Vec<f64>>> {
    params
        .iter()
        .map(|tensor| {
            tensor
                .iter()
                .map(|value| {
                    value.to_f64().ok_or_else(|| {
                        NeuralError::SerializationError(
                            "Cannot convert parameter to f64".to_string(),
                        )
                    })
                })
                .collect::<Result<Vec<f64>>>()
        })
        .collect()
}
311
/// Rebuilds a [`Sequential`] model from a [`SerializedModel`] snapshot.
///
/// Layers are recreated in order. The parameter entry at the same index is
/// handed to the layer constructor; a missing entry is treated as an empty
/// parameter list. Only `Dense` and `Dropout` configs are supported.
///
/// # Errors
///
/// * [`NeuralError::DeserializationError`] for any other layer config.
/// * Whatever `create_dense_layer` / `create_dropout` return.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn deserialize_model<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    serialized: &SerializedModel,
) -> Result<Sequential<F>> {
    let no_params: Vec<Vec<f64>> = Vec::new();
    let mut rebuilt: Vec<Box<dyn Layer<F> + Send + Sync>> = Vec::new();

    for (index, layer_config) in serialized.layers.iter().enumerate() {
        // Snapshots may carry fewer parameter entries than layers.
        let layer_params = serialized.parameters.get(index).unwrap_or(&no_params);

        let boxed: Box<dyn Layer<F> + Send + Sync> = match layer_config {
            LayerConfig::Dense(cfg) => Box::new(create_dense_layer::<F>(cfg, layer_params)?),
            LayerConfig::Dropout(cfg) => Box::new(create_dropout::<F>(cfg)?),
            _ => {
                return Err(NeuralError::DeserializationError(
                    "Layer type not supported in legacy deserialization. Use SafeTensors API."
                        .to_string(),
                ));
            }
        };
        rebuilt.push(boxed);
    }

    Ok(Sequential::from_layers(rebuilt))
}
350
/// Builds a `Dense` layer from its config and, when present, restores its
/// weights and bias.
///
/// The layer is first constructed with `Dense::new` using a `SmallRng` seeded
/// with a fixed seed. If `params` holds at least two entries, the first is
/// reshaped to `[input_dim, output_dim]` as the weights and the second to
/// `[output_dim]` as the bias, and both are installed with `set_parameters`.
/// With fewer than two entries the freshly initialized layer is returned
/// unchanged.
///
/// The earlier "retry with the transposed shape" fallback was removed: once
/// the weight length has been checked, `array_from_vec` can only fail on an
/// element conversion, and the retry would fail on the same element with the
/// same error.
///
/// # Errors
///
/// * Whatever `Dense::new` or `set_parameters` return.
/// * [`NeuralError::SerializationError`] if the weight vector length is not
///   `input_dim * output_dim`, if the bias length is not `output_dim`, or if
///   a value cannot be converted to `F`.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn create_dense_layer<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    config: &DenseConfig,
    params: &[Vec<f64>],
) -> Result<Dense<F>> {
    let mut rng = scirs2_core::random::rngs::SmallRng::from_seed([42; 32]);
    let mut layer = Dense::new(
        config.input_dim,
        config.output_dim,
        config.activation.as_deref(),
        &mut rng,
    )?;

    if let [weights, bias, ..] = params {
        let expected_len = config.input_dim * config.output_dim;
        if weights.len() != expected_len {
            return Err(NeuralError::SerializationError(format!(
                "Weight vector length ({}) doesn't match expected shape size ({})",
                weights.len(),
                expected_len
            )));
        }

        let weights_array =
            array_from_vec::<F>(weights, &[config.input_dim, config.output_dim])?;
        let bias_array = array_from_vec::<F>(bias, &[config.output_dim])?;
        layer.set_parameters(vec![weights_array, bias_array])?;
    }

    Ok(layer)
}
393
/// Builds a `Dropout` layer from its config.
///
/// `config.p` is forwarded to `Dropout::new` together with a `SmallRng`
/// created from a fixed seed.
///
/// # Errors
///
/// Returns whatever `Dropout::new` returns.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn create_dropout<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    config: &DropoutConfig,
) -> Result<Dropout<F>> {
    let fixed_seed = [42; 32];
    let mut seeded_rng = scirs2_core::random::rngs::SmallRng::from_seed(fixed_seed);
    let layer = Dropout::new(config.p, &mut seeded_rng);
    layer
}
405
/// Converts a flat `f64` slice into a dynamic-dimension array of `F` with the
/// given `shape`.
///
/// # Errors
///
/// Returns [`NeuralError::SerializationError`] if
/// * `vec.len()` differs from the product of `shape`,
/// * a value cannot be converted to `F`, or
/// * `Array::from_shape_vec` rejects the shape.
#[cfg(feature = "legacy_serialization")]
#[allow(dead_code)]
fn array_from_vec<
    F: Float + Debug + Display + ScalarOperand + FromPrimitive + NumAssign + Send + Sync + 'static,
>(
    vec: &[f64],
    shape: &[usize],
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
    let expected_len: usize = shape.iter().product();
    if vec.len() != expected_len {
        return Err(NeuralError::SerializationError(format!(
            "Parameter vector length ({}) doesn't match expected shape size ({})",
            vec.len(),
            expected_len
        )));
    }

    // Convert element by element, stopping at the first value that does not fit in `F`.
    let mut converted: Vec<F> = Vec::with_capacity(vec.len());
    for &value in vec {
        match F::from(value) {
            Some(cast) => converted.push(cast),
            None => {
                return Err(NeuralError::SerializationError(format!(
                    "Cannot convert {} to target type",
                    value
                )));
            }
        }
    }

    let dims = scirs2_core::ndarray::IxDyn(shape);
    match Array::from_shape_vec(dims, converted) {
        Ok(array) => Ok(array),
        Err(err) => Err(NeuralError::SerializationError(err.to_string())),
    }
}
435
/// Serializable identifier of an activation function.
///
/// Values can be parsed from and rendered to a string name with `from_name`
/// and `to_name`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationFunction {
    /// ReLU; named `"relu"`.
    ReLU,
    /// Sigmoid; named `"sigmoid"`.
    Sigmoid,
    /// Tanh; named `"tanh"`.
    Tanh,
    /// Softmax; named `"softmax"`.
    Softmax,
    /// Leaky ReLU carrying its alpha value; named `"leaky_relu(<alpha>)"`.
    LeakyReLU(f64),
    /// ELU carrying its alpha value; named `"elu(<alpha>)"`.
    ELU(f64),
    /// GELU; named `"gelu"`.
    GELU,
    /// Swish; named `"swish"`.
    Swish,
    /// Mish; named `"mish"`.
    Mish,
}
462
463impl ActivationFunction {
464 pub fn from_name(name: &str) -> Option<Self> {
466 match name {
467 "relu" | "ReLU" => Some(ActivationFunction::ReLU),
468 "sigmoid" | "Sigmoid" => Some(ActivationFunction::Sigmoid),
469 "tanh" | "Tanh" => Some(ActivationFunction::Tanh),
470 "softmax" | "Softmax" => Some(ActivationFunction::Softmax),
471 "gelu" | "GELU" => Some(ActivationFunction::GELU),
472 "swish" | "Swish" => Some(ActivationFunction::Swish),
473 "mish" | "Mish" => Some(ActivationFunction::Mish),
474 _ => {
475 if name.starts_with("leaky_relu") || name.starts_with("LeakyReLU") {
476 let parts: Vec<&str> = name.split('(').collect();
477 if parts.len() == 2 {
478 let alpha_str = parts[1].trim_end_matches(')');
479 if let Ok(alpha) = alpha_str.parse::<f64>() {
480 return Some(ActivationFunction::LeakyReLU(alpha));
481 }
482 }
483 Some(ActivationFunction::LeakyReLU(0.01))
484 } else if name.starts_with("elu") || name.starts_with("ELU") {
485 let parts: Vec<&str> = name.split('(').collect();
486 if parts.len() == 2 {
487 let alpha_str = parts[1].trim_end_matches(')');
488 if let Ok(alpha) = alpha_str.parse::<f64>() {
489 return Some(ActivationFunction::ELU(alpha));
490 }
491 }
492 Some(ActivationFunction::ELU(1.0))
493 } else {
494 None
495 }
496 }
497 }
498 }
499
500 pub fn to_name(&self) -> String {
502 match self {
503 ActivationFunction::ReLU => "relu".to_string(),
504 ActivationFunction::Sigmoid => "sigmoid".to_string(),
505 ActivationFunction::Tanh => "tanh".to_string(),
506 ActivationFunction::Softmax => "softmax".to_string(),
507 ActivationFunction::LeakyReLU(alpha) => format!("leaky_relu({})", alpha),
508 ActivationFunction::ELU(alpha) => format!("elu({})", alpha),
509 ActivationFunction::GELU => "gelu".to_string(),
510 ActivationFunction::Swish => "swish".to_string(),
511 ActivationFunction::Mish => "mish".to_string(),
512 }
513 }
514
515 pub fn create<
519 F: Float + Debug + scirs2_core::NumAssign + scirs2_core::ndarray::ScalarOperand + Send + Sync,
520 >(
521 &self,
522 ) -> Box<dyn Activation<F>> {
523 match self {
524 ActivationFunction::ReLU => Box::new(ReLU::new()),
525 ActivationFunction::Sigmoid => Box::new(Sigmoid::new()),
526 ActivationFunction::Tanh => Box::new(Tanh::new()),
527 ActivationFunction::Softmax => Box::new(Softmax::new(1)),
528 ActivationFunction::LeakyReLU(alpha) => Box::new(LeakyReLU::new(*alpha)),
529 ActivationFunction::ELU(alpha) => Box::new(LeakyReLU::new(*alpha)),
530 ActivationFunction::GELU => Box::new(GELU::new()),
531 ActivationFunction::Swish => Box::new(Swish::new(1.0)),
532 ActivationFunction::Mish => Box::new(Mish::new()),
533 }
534 }
535}
536
537pub struct ActivationFactory;
539
540impl ActivationFactory {
541 pub fn create<
543 F: Float + Debug + scirs2_core::NumAssign + scirs2_core::ndarray::ScalarOperand + Send + Sync,
544 >(
545 name: &str,
546 ) -> Option<Box<dyn Activation<F>>> {
547 ActivationFunction::from_name(name).map(|af| af.create::<F>())
548 }
549
550 pub fn get_activation_names() -> HashMap<&'static str, &'static str> {
552 let mut names = HashMap::new();
553 names.insert("relu", "ReLU activation function");
554 names.insert("sigmoid", "Sigmoid activation function");
555 names.insert("tanh", "Tanh activation function");
556 names.insert("softmax", "Softmax activation function");
557 names.insert("leaky_relu", "Leaky ReLU activation function");
558 names.insert("elu", "ELU activation function");
559 names.insert("gelu", "GELU activation function");
560 names.insert("swish", "Swish activation function");
561 names.insert("mish", "Mish activation function");
562 names
563 }
564}
565
566#[cfg(test)]
567mod tests;