use crate::error::{NeuralError, Result};
use crate::models::architectures::{
BertConfig, BertModel, ResNet, ResNetBlock, ResNetConfig, ResNetLayer,
};
use crate::serialization::architecture::{
ArchitectureConfig, SerializableBertConfig, SerializableResNetConfig,
};
use crate::serialization::safetensors::{SafeTensorsReader, SafeTensorsWriter};
use crate::serialization::traits::{ModelMetadata, NamedParameters};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
use scirs2_core::simd_ops::SimdUnifiedOps;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::fs;
use std::path::Path;
/// Unified persistence interface for neural-network models.
///
/// Implementors expose their architecture config as JSON plus named
/// parameter access; the provided `save_model`/`load_model` defaults
/// handle the on-disk layout: a `config.json` file next to a
/// `weights.safetensors` file inside the target directory.
pub trait ModelSerializer<
    F: Float + Debug + ScalarOperand + NumAssign + ToPrimitive + FromPrimitive + 'static,
>
{
    /// Architecture configuration as a JSON value.
    fn get_config(&self) -> Result<serde_json::Value>;

    /// Every model parameter as a `(name, array)` pair.
    fn named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>>;

    /// Load parameters from a name -> array map.
    fn load_params(&mut self, params: &HashMap<String, Array<F, IxDyn>>) -> Result<()>;

    /// Persist the model into `dir` as `config.json` + `weights.safetensors`.
    ///
    /// Parameter values are widened to `f64` for storage; any value that
    /// cannot be represented yields a `SerializationError`.
    fn save_model(&self, dir: &Path) -> Result<()> {
        fs::create_dir_all(dir).map_err(|e| NeuralError::IOError(e.to_string()))?;

        // Write the architecture config first so even a partially
        // written directory remains identifiable.
        let config_value = self.get_config()?;
        let pretty = serde_json::to_string_pretty(&config_value)
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;
        fs::write(dir.join("config.json"), pretty)
            .map_err(|e| NeuralError::IOError(e.to_string()))?;

        // Flatten every parameter array into f64 values for storage.
        let named = self.named_params()?;
        let total_params: usize = named.iter().map(|(_, a)| a.len()).sum();
        let mut np = NamedParameters::new();
        for (name, arr) in &named {
            let shape = arr.shape().to_vec();
            let mut values = Vec::with_capacity(arr.len());
            for &x in arr.iter() {
                let v = x.to_f64().ok_or_else(|| {
                    NeuralError::SerializationError(
                        "Cannot convert parameter to f64".to_string(),
                    )
                })?;
                values.push(v);
            }
            np.add(name, values, shape);
        }

        // Embed the config JSON in the tensor-file metadata as well, so
        // the weights file is self-describing on its own.
        let metadata = ModelMetadata::new("model", "f64", total_params).with_extra(
            "config",
            &serde_json::to_string(&config_value)
                .map_err(|e| NeuralError::SerializationError(e.to_string()))?,
        );
        let mut writer = SafeTensorsWriter::new();
        writer.add_model_metadata(&metadata);
        writer.add_named_parameters(&np)?;
        writer.write_to_file(&dir.join("weights.safetensors"))?;
        Ok(())
    }

    /// Restore parameters from `dir/weights.safetensors`.
    ///
    /// Values are read back as `f64` and narrowed to `F`; a value that
    /// does not round-trip yields a `DeserializationError`.
    fn load_model(&mut self, dir: &Path) -> Result<()> {
        let reader = SafeTensorsReader::from_file(&dir.join("weights.safetensors"))?;
        let mut params_map: HashMap<String, Array<F, IxDyn>> = HashMap::new();
        for name in reader.tensor_names() {
            let (raw, shape) = reader.read_f64_tensor(name)?;
            let mut f_values = Vec::with_capacity(raw.len());
            for &x in &raw {
                let v = F::from(x).ok_or_else(|| {
                    NeuralError::DeserializationError(format!(
                        "Cannot convert {x} to target float type"
                    ))
                })?;
                f_values.push(v);
            }
            let arr = Array::from_shape_vec(IxDyn(&shape), f_values)?;
            params_map.insert(name.to_string(), arr);
        }
        self.load_params(&params_map)
    }
}
impl<F> ModelSerializer<F> for ResNet<F>
where
    F: Float
        + Debug
        + ScalarOperand
        + NumAssign
        + ToPrimitive
        + FromPrimitive
        + Send
        + Sync
        + 'static,
{
    /// Wraps the ResNet config in the versioned `ArchitectureConfig` envelope.
    fn get_config(&self) -> Result<serde_json::Value> {
        let inner = serde_json::to_value(SerializableResNetConfig::from(self.config()))
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;
        serde_json::to_value(&ArchitectureConfig {
            architecture: "ResNet".to_string(),
            format_version: "1.0".to_string(),
            config: inner,
        })
        .map_err(|e| NeuralError::SerializationError(e.to_string()))
    }

    /// Delegates to the model's own parameter extraction.
    fn named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>> {
        self.extract_named_params()
    }

    /// Delegates to the model's own parameter loading.
    fn load_params(&mut self, params: &HashMap<String, Array<F, IxDyn>>) -> Result<()> {
        self.load_named_params(params)
    }
}
impl<F> ModelSerializer<F> for BertModel<F>
where
    F: Float
        + Debug
        + ScalarOperand
        + NumAssign
        + ToPrimitive
        + FromPrimitive
        + Send
        + Sync
        + SimdUnifiedOps
        + 'static,
{
    /// Wraps the BERT config in the versioned `ArchitectureConfig` envelope.
    fn get_config(&self) -> Result<serde_json::Value> {
        let inner = serde_json::to_value(SerializableBertConfig::from(self.config()))
            .map_err(|e| NeuralError::SerializationError(e.to_string()))?;
        serde_json::to_value(&ArchitectureConfig {
            architecture: "BERT".to_string(),
            format_version: "1.0".to_string(),
            config: inner,
        })
        .map_err(|e| NeuralError::SerializationError(e.to_string()))
    }

    /// Delegates to the model's own parameter extraction.
    fn named_params(&self) -> Result<Vec<(String, Array<F, IxDyn>)>> {
        self.extract_named_params()
    }

    /// Delegates to the model's own parameter loading.
    fn load_params(&mut self, params: &HashMap<String, Array<F, IxDyn>>) -> Result<()> {
        self.load_named_params(params)
    }
}
/// Converts a flat `NamedParameters` container into a name -> ndarray map.
///
/// Each `(name, values, shape)` triple is narrowed from `f64` to `F` and
/// reshaped into an `Array<F, IxDyn>`.
///
/// # Errors
///
/// Returns `DeserializationError` if a value cannot be represented in `F`,
/// or a shape error if `values.len()` does not match the product of `shape`.
pub fn named_parameters_to_map<F>(
    params: &NamedParameters,
) -> Result<HashMap<String, Array<F, IxDyn>>>
where
    F: Float + FromPrimitive + 'static,
{
    // The final entry count is known up front, so preallocate and avoid
    // repeated rehash/grow cycles while inserting.
    let mut map = HashMap::with_capacity(params.parameters.len());
    for (name, values, shape) in &params.parameters {
        let f_values: Vec<F> = values
            .iter()
            .map(|&x| {
                F::from(x).ok_or_else(|| {
                    NeuralError::DeserializationError(format!(
                        "Cannot convert {x} to target float type"
                    ))
                })
            })
            .collect::<Result<Vec<F>>>()?;
        let arr = Array::from_shape_vec(IxDyn(shape), f_values)?;
        map.insert(name.clone(), arr);
    }
    Ok(map)
}
pub fn save_resnet<F>(model: &ResNet<F>, dir: &Path) -> Result<()>
where
F: Float
+ Debug
+ ScalarOperand
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ Send
+ Sync
+ 'static,
{
model.save_model(dir)
}
pub fn load_resnet<F>(model: &mut ResNet<F>, dir: &Path) -> Result<()>
where
F: Float
+ Debug
+ ScalarOperand
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ Send
+ Sync
+ 'static,
{
model.load_model(dir)
}
pub fn save_bert<F>(model: &BertModel<F>, dir: &Path) -> Result<()>
where
F: Float
+ Debug
+ ScalarOperand
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ Send
+ Sync
+ SimdUnifiedOps
+ 'static,
{
model.save_model(dir)
}
pub fn load_bert<F>(model: &mut BertModel<F>, dir: &Path) -> Result<()>
where
F: Float
+ Debug
+ ScalarOperand
+ NumAssign
+ ToPrimitive
+ FromPrimitive
+ Send
+ Sync
+ SimdUnifiedOps
+ 'static,
{
model.load_model(dir)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::architectures::ResNetConfig;

    /// Extracted ResNet parameters include the well-known PyTorch-style
    /// layer names (conv1/bn1/fc).
    #[test]
    fn test_resnet_extract_named_params() -> Result<()> {
        let model = ResNet::<f64>::resnet18(3, 10)?;
        let params = model.extract_named_params()?;
        assert!(!params.is_empty(), "ResNet should have parameters");
        let names: Vec<&str> = params.iter().map(|(n, _)| n.as_str()).collect();
        assert!(
            names.contains(&"conv1.weight"),
            "Should have conv1.weight, got: {:?}",
            // Only show the first few names to keep failure output readable.
            &names[..names.len().min(5)]
        );
        assert!(names.contains(&"fc.weight"), "Should have fc.weight");
        assert!(names.contains(&"fc.bias"), "Should have fc.bias");
        assert!(names.contains(&"bn1.weight"), "Should have bn1.weight");
        assert!(names.contains(&"bn1.bias"), "Should have bn1.bias");
        Ok(())
    }

    /// Save a ResNet to disk, load it into a fresh model, and verify every
    /// parameter survives the roundtrip bit-for-bit (within 1e-10).
    #[test]
    fn test_resnet_save_load_roundtrip() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_resnet_serialization_test");
        fs::create_dir_all(&test_dir).map_err(|e| NeuralError::IOError(e.to_string()))?;
        let model = ResNet::<f64>::resnet18(3, 10)?;
        model.save_model(&test_dir)?;
        // Both artifacts of the on-disk layout must exist.
        assert!(
            test_dir.join("config.json").exists(),
            "config.json should exist"
        );
        assert!(
            test_dir.join("weights.safetensors").exists(),
            "weights.safetensors should exist"
        );
        let mut loaded_model = ResNet::<f64>::resnet18(3, 10)?;
        loaded_model.load_model(&test_dir)?;
        let original_params = model.extract_named_params()?;
        let loaded_params = loaded_model.extract_named_params()?;
        assert_eq!(
            original_params.len(),
            loaded_params.len(),
            "Parameter count mismatch after roundtrip"
        );
        // Compare pairwise in extraction order: names, shapes, then values.
        for ((orig_name, orig_arr), (load_name, load_arr)) in
            original_params.iter().zip(loaded_params.iter())
        {
            assert_eq!(orig_name, load_name, "Parameter name mismatch");
            assert_eq!(
                orig_arr.shape(),
                load_arr.shape(),
                "Shape mismatch for {orig_name}"
            );
            let max_diff = orig_arr
                .iter()
                .zip(load_arr.iter())
                .map(|(a, b)| (a - b).abs())
                .fold(0.0_f64, f64::max);
            assert!(
                max_diff < 1e-10,
                "Value mismatch for {orig_name}: max_diff = {max_diff}"
            );
        }
        // Best-effort cleanup; a failure here should not fail the test.
        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    /// The JSON config envelope identifies the architecture as "ResNet".
    #[test]
    fn test_resnet_get_config() -> Result<()> {
        let model = ResNet::<f64>::resnet18(3, 100)?;
        let config = model.get_config()?;
        assert!(config.get("architecture").is_some());
        let arch = config["architecture"]
            .as_str()
            .expect("architecture should be a string");
        assert_eq!(arch, "ResNet");
        Ok(())
    }

    /// Loading an empty parameter map must not error — partial/missing
    /// parameters are handled gracefully.
    #[test]
    fn test_resnet_partial_load_graceful() -> Result<()> {
        let mut model = ResNet::<f64>::resnet18(3, 10)?;
        let empty_map: HashMap<String, Array<f64, IxDyn>> = HashMap::new();
        let result = model.load_params(&empty_map);
        assert!(
            result.is_ok(),
            "Loading empty param map should succeed gracefully"
        );
        Ok(())
    }

    /// Parameter names must be unique — duplicates would silently collide
    /// in the save/load HashMap.
    #[test]
    fn test_resnet_named_params_no_duplicates() -> Result<()> {
        let model = ResNet::<f64>::resnet18(3, 10)?;
        let params = model.extract_named_params()?;
        let mut seen = std::collections::HashSet::new();
        for (name, _) in &params {
            assert!(
                seen.insert(name.clone()),
                "Duplicate parameter name: {name}"
            );
        }
        Ok(())
    }

    /// Extracted BERT parameters follow the HuggingFace-style naming
    /// scheme (embeddings.*, encoder.layer.N.*, pooler.*).
    #[test]
    fn test_bert_extract_named_params() -> Result<()> {
        // NOTE(review): assumed argument order is (vocab, hidden, layers,
        // heads) based on usage below — confirm against BertConfig::custom.
        let config = BertConfig::custom(100, 32, 2, 4);
        let model = BertModel::<f64>::new(config)?;
        let params = model.extract_named_params()?;
        assert!(!params.is_empty(), "BERT should have parameters");
        let names: Vec<&str> = params.iter().map(|(n, _)| n.as_str()).collect();
        assert!(
            names.contains(&"embeddings.word_embeddings.weight"),
            "Should have word embeddings, got names: {:?}",
            &names[..names.len().min(10)]
        );
        assert!(names.contains(&"embeddings.position_embeddings.weight"));
        assert!(names.contains(&"embeddings.token_type_embeddings.weight"));
        assert!(names.contains(&"embeddings.LayerNorm.weight"));
        assert!(names.contains(&"embeddings.LayerNorm.bias"));
        assert!(names.contains(&"encoder.layer.0.attention.self.query.weight"));
        assert!(names.contains(&"encoder.layer.0.attention.self.key.weight"));
        assert!(names.contains(&"encoder.layer.0.attention.self.value.weight"));
        assert!(names.contains(&"encoder.layer.0.attention.output.dense.weight"));
        assert!(names.contains(&"encoder.layer.0.attention.output.LayerNorm.weight"));
        assert!(names.contains(&"encoder.layer.0.intermediate.dense.weight"));
        assert!(names.contains(&"encoder.layer.0.output.dense.weight"));
        assert!(names.contains(&"encoder.layer.0.output.LayerNorm.weight"));
        assert!(names.contains(&"pooler.dense.weight"));
        assert!(names.contains(&"pooler.dense.bias"));
        Ok(())
    }

    /// Save a small BERT to disk, reload it, and verify every parameter
    /// value survives the roundtrip (within 1e-10).
    #[test]
    fn test_bert_save_load_roundtrip() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_bert_serialization_test");
        fs::create_dir_all(&test_dir).map_err(|e| NeuralError::IOError(e.to_string()))?;
        let config = BertConfig::custom(100, 32, 2, 4);
        let model = BertModel::<f64>::new(config.clone())?;
        model.save_model(&test_dir)?;
        assert!(test_dir.join("config.json").exists());
        assert!(test_dir.join("weights.safetensors").exists());
        let mut loaded_model = BertModel::<f64>::new(config)?;
        loaded_model.load_model(&test_dir)?;
        let original_params = model.extract_named_params()?;
        let loaded_params = loaded_model.extract_named_params()?;
        assert_eq!(
            original_params.len(),
            loaded_params.len(),
            "BERT parameter count mismatch after roundtrip"
        );
        // Compare by name lookup rather than order, so the test does not
        // depend on extraction ordering.
        let loaded_map: HashMap<String, &Array<f64, IxDyn>> =
            loaded_params.iter().map(|(n, a)| (n.clone(), a)).collect();
        for (name, orig_arr) in &original_params {
            let load_arr = loaded_map.get(name).ok_or_else(|| {
                NeuralError::DeserializationError(format!("Missing parameter: {name}"))
            })?;
            let max_diff = orig_arr
                .iter()
                .zip(load_arr.iter())
                .map(|(a, b)| (a - b).abs())
                .fold(0.0_f64, f64::max);
            assert!(
                max_diff < 1e-10,
                "Value mismatch for {name}: max_diff = {max_diff}"
            );
        }
        // Best-effort cleanup; a failure here should not fail the test.
        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }

    /// BERT parameter names must be unique (same rationale as the ResNet
    /// duplicate check above).
    #[test]
    fn test_bert_no_duplicate_param_names() -> Result<()> {
        let config = BertConfig::custom(100, 32, 2, 4);
        let model = BertModel::<f64>::new(config)?;
        let params = model.extract_named_params()?;
        let mut seen = std::collections::HashSet::new();
        for (name, _) in &params {
            assert!(
                seen.insert(name.clone()),
                "Duplicate BERT parameter name: {name}"
            );
        }
        Ok(())
    }

    /// Loading an empty parameter map into BERT must not error.
    #[test]
    fn test_bert_partial_load_graceful() -> Result<()> {
        let config = BertConfig::custom(100, 32, 1, 4);
        let mut model = BertModel::<f64>::new(config)?;
        let empty_map: HashMap<String, Array<f64, IxDyn>> = HashMap::new();
        let result = model.load_params(&empty_map);
        assert!(
            result.is_ok(),
            "Loading empty param map should succeed gracefully"
        );
        Ok(())
    }

    /// Weights saved from a 2-layer BERT can be loaded into a 1-layer
    /// model without error — extra layers' parameters are tolerated.
    #[test]
    fn test_bert_cross_version_compatibility() -> Result<()> {
        let test_dir = std::env::temp_dir().join("scirs2_bert_cross_version_test");
        fs::create_dir_all(&test_dir).map_err(|e| NeuralError::IOError(e.to_string()))?;
        let config_2_layers = BertConfig::custom(100, 32, 2, 4);
        let model_2l = BertModel::<f64>::new(config_2_layers)?;
        model_2l.save_model(&test_dir)?;
        let config_1_layer = BertConfig::custom(100, 32, 1, 4);
        let mut model_1l = BertModel::<f64>::new(config_1_layer)?;
        let result = model_1l.load_model(&test_dir);
        assert!(
            result.is_ok(),
            "Cross-version load (2-layer into 1-layer) should succeed gracefully: {:?}",
            result
        );
        // Best-effort cleanup; a failure here should not fail the test.
        let _ = fs::remove_dir_all(&test_dir);
        Ok(())
    }
}