use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use ::ndarray::IxDyn;
#[cfg(feature = "serialization")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "serialization")]
use serde_json;
use chrono;
use crate::array_protocol::grad::{Optimizer, SGD};
use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::neural::{
BatchNorm, Conv2D, Dropout, Layer, Linear, MaxPool2D, Sequential,
};
use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
use crate::error::{CoreError, CoreResult, ErrorContext};
pub trait Serializable {
fn serialize(&self) -> CoreResult<Vec<u8>>;
fn deserialize(bytes: &[u8]) -> CoreResult<Self>
where
Self: Sized;
fn type_name(&self) -> &str;
}
#[derive(Serialize, Deserialize)]
pub struct ModelFile {
pub metadata: ModelMetadata,
pub architecture: ModelArchitecture,
pub parameter_files: HashMap<String, String>,
pub optimizer_state: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct ModelMetadata {
pub name: String,
pub version: String,
pub framework_version: String,
pub created_at: String,
pub inputshape: Vec<usize>,
pub outputshape: Vec<usize>,
pub additional_info: HashMap<String, String>,
}
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct ModelArchitecture {
pub model_type: String,
pub layers: Vec<LayerConfig>,
}
#[cfg_attr(feature = "serialization", derive(Serialize, Deserialize))]
pub struct LayerConfig {
pub layer_type: String,
pub name: String,
#[cfg(feature = "serialization")]
pub config: serde_json::Value,
#[cfg(not(feature = "serialization"))]
pub config: HashMap<String, String>, }
pub struct ModelSerializer {
basedir: PathBuf,
}
impl ModelSerializer {
pub fn new(basedir: impl AsRef<Path>) -> Self {
Self {
basedir: basedir.as_ref().to_path_buf(),
}
}
pub fn save_model(
&self,
model: &Sequential,
name: &str,
version: &str,
optimizer: Option<&dyn Optimizer>,
) -> CoreResult<PathBuf> {
let modeldir = self.basedir.join(name).join(version);
fs::create_dir_all(&modeldir)?;
let metadata = ModelMetadata {
name: name.to_string(),
version: version.to_string(),
framework_version: "0.1.0".to_string(),
created_at: chrono::Utc::now().to_rfc3339(),
inputshape: vec![], outputshape: vec![], additional_info: HashMap::new(),
};
let architecture = self.create_architecture(model)?;
let mut parameter_files = HashMap::new();
self.save_parameters(model, &modeldir, &mut parameter_files)?;
let optimizer_state = if let Some(optimizer) = optimizer {
let optimizerpath = self.save_optimizer(optimizer, &modeldir)?;
Some(
optimizerpath
.file_name()
.expect("Operation failed")
.to_string_lossy()
.to_string(),
)
} else {
None
};
let model_file = ModelFile {
metadata,
architecture,
parameter_files,
optimizer_state,
};
let model_file_path = modeldir.join("model.json");
let model_file_json = serde_json::to_string_pretty(&model_file)?;
let mut file = File::create(&model_file_path)?;
file.write_all(model_file_json.as_bytes())?;
Ok(model_file_path)
}
pub fn loadmodel(
&self,
name: &str,
version: &str,
) -> CoreResult<(Sequential, Option<Box<dyn Optimizer>>)> {
let modeldir = self.basedir.join(name).join(version);
let model_file_path = modeldir.join("model.json");
let mut file = File::open(&model_file_path)?;
let mut model_file_json = String::new();
file.read_to_string(&mut model_file_json)?;
let model_file: ModelFile = serde_json::from_str(&model_file_json)?;
let model = self.create_model_from_architecture(&model_file.architecture)?;
self.load_parameters(&model, &modeldir, &model_file.parameter_files)?;
let optimizer = if let Some(optimizer_state) = &model_file.optimizer_state {
let optimizerpath = modeldir.join(optimizer_state);
Some(self.load_optimizer(&optimizerpath)?)
} else {
None
};
Ok((model, optimizer))
}
fn create_architecture(&self, model: &Sequential) -> CoreResult<ModelArchitecture> {
let mut layers = Vec::new();
for layer in model.layers() {
let layer_config = self.create_layer_config(layer.as_ref())?;
layers.push(layer_config);
}
Ok(ModelArchitecture {
model_type: "Sequential".to_string(),
layers,
})
}
fn create_layer_config(&self, layer: &dyn Layer) -> CoreResult<LayerConfig> {
let layer_type = layer.layer_type();
if !["Linear", "Conv2D", "MaxPool2D", "BatchNorm", "Dropout"].contains(&layer_type) {
return Err(CoreError::NotImplementedError(ErrorContext::new(format!(
"Serialization not implemented for layer type: {}",
layer.name()
))));
};
let config = match layer_type {
"Linear" => {
serde_json::json!({
"in_features": 0,
"out_features": 0,
"bias": true,
"activation": "relu",
})
}
"Conv2D" => {
serde_json::json!({
"filter_height": 3,
"filter_width": 3,
"in_channels": 0,
"out_channels": 0,
"stride": [1, 1],
"padding": [0, 0],
"bias": true,
"activation": "relu",
})
}
"MaxPool2D" => {
serde_json::json!({
"kernel_size": [2, 2],
"stride": [2, 2],
"padding": [0, 0],
})
}
"BatchNorm" => {
serde_json::json!({
"num_features": 0,
"epsilon": 1e-5,
"momentum": 0.1,
})
}
"Dropout" => {
serde_json::json!({
"rate": 0.5,
"seed": null,
})
}
_ => serde_json::json!({}),
};
Ok(LayerConfig {
layer_type: layer_type.to_string(),
name: layer.name().to_string(),
config,
})
}
fn save_parameters(
&self,
model: &Sequential,
modeldir: &Path,
parameter_files: &mut HashMap<String, String>,
) -> CoreResult<()> {
let params_dir = modeldir.join("parameters");
fs::create_dir_all(¶ms_dir)?;
for (i, layer) in model.layers().iter().enumerate() {
for (j, param) in layer.parameters().iter().enumerate() {
let param_name = format!("layer_{i}_param_{j}");
let param_file = format!("{param_name}.npz");
let param_path = params_dir.join(¶m_file);
self.save_parameter(param.as_ref(), ¶m_path)?;
parameter_files.insert(param_name, format!("parameters/{param_file}"));
}
}
Ok(())
}
fn save_parameter(&self, param: &dyn ArrayProtocol, path: &Path) -> CoreResult<()> {
if let Some(array) = param.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>() {
let ndarray = array.as_array();
let shape: Vec<usize> = ndarray.shape().to_vec();
let data: Vec<f64> = ndarray.iter().cloned().collect();
let save_data = serde_json::json!({
"shape": shape,
"data": data,
});
let mut file = File::create(path)?;
let json_str = serde_json::to_string(&save_data)?;
file.write_all(json_str.as_bytes())?;
Ok(())
} else {
Err(CoreError::NotImplementedError(ErrorContext::new(
"Parameter serialization not implemented for this array type".to_string(),
)))
}
}
fn save_optimizer(&self, _optimizer: &dyn Optimizer, modeldir: &Path) -> CoreResult<PathBuf> {
let optimizerpath = modeldir.join("optimizer.json");
let optimizer_data = serde_json::json!({
"type": "SGD", "config": {
"learningrate": 0.01,
"momentum": null
},
"state": {} });
let mut file = File::create(&optimizerpath)?;
let json_str = serde_json::to_string_pretty(&optimizer_data)?;
file.write_all(json_str.as_bytes())?;
Ok(optimizerpath)
}
fn create_model_from_architecture(
&self,
architecture: &ModelArchitecture,
) -> CoreResult<Sequential> {
let mut model = Sequential::new(&architecture.model_type, Vec::new());
for layer_config in &architecture.layers {
let layer = self.create_layer_from_config(layer_config)?;
model.add_layer(layer);
}
Ok(model)
}
fn create_layer_from_config(&self, config: &LayerConfig) -> CoreResult<Box<dyn Layer>> {
match config.layer_type.as_str() {
"Linear" => {
let in_features = config.config["in_features"].as_u64().unwrap_or(0) as usize;
let out_features = config.config["out_features"].as_u64().unwrap_or(0) as usize;
let bias = config.config["bias"].as_bool().unwrap_or(true);
let activation = match config.config["activation"].as_str() {
Some("relu") => Some(ActivationFunc::ReLU),
Some("sigmoid") => Some(ActivationFunc::Sigmoid),
Some("tanh") => Some(ActivationFunc::Tanh),
_ => None,
};
Ok(Box::new(Linear::new_random(
&config.name,
in_features,
out_features,
bias,
activation,
)))
}
"Conv2D" => {
let filter_height = config.config["filter_height"].as_u64().unwrap_or(3) as usize;
let filter_width = config.config["filter_width"].as_u64().unwrap_or(3) as usize;
let in_channels = config.config["in_channels"].as_u64().unwrap_or(0) as usize;
let out_channels = config.config["out_channels"].as_u64().unwrap_or(0) as usize;
let stride = (
config.config["stride"][0].as_u64().unwrap_or(1) as usize,
config.config["stride"][1].as_u64().unwrap_or(1) as usize,
);
let padding = (
config.config["padding"][0].as_u64().unwrap_or(0) as usize,
config.config["padding"][1].as_u64().unwrap_or(0) as usize,
);
let bias = config.config["bias"].as_bool().unwrap_or(true);
let activation = match config.config["activation"].as_str() {
Some("relu") => Some(ActivationFunc::ReLU),
Some("sigmoid") => Some(ActivationFunc::Sigmoid),
Some("tanh") => Some(ActivationFunc::Tanh),
_ => None,
};
Ok(Box::new(Conv2D::withshape(
&config.name,
filter_height,
filter_width,
in_channels,
out_channels,
stride,
padding,
bias,
activation,
)))
}
"MaxPool2D" => {
let kernel_size = (
config.config["kernel_size"][0].as_u64().unwrap_or(2) as usize,
config.config["kernel_size"][1].as_u64().unwrap_or(2) as usize,
);
let stride = if config.config["stride"].is_array() {
Some((
config.config["stride"][0].as_u64().unwrap_or(2) as usize,
config.config["stride"][1].as_u64().unwrap_or(2) as usize,
))
} else {
None
};
let padding = (
config.config["padding"][0].as_u64().unwrap_or(0) as usize,
config.config["padding"][1].as_u64().unwrap_or(0) as usize,
);
Ok(Box::new(MaxPool2D::new(
&config.name,
kernel_size,
stride,
padding,
)))
}
"BatchNorm" => {
let num_features = config.config["num_features"].as_u64().unwrap_or(0) as usize;
let epsilon = config.config["epsilon"].as_f64().unwrap_or(1e-5);
let momentum = config.config["momentum"].as_f64().unwrap_or(0.1);
Ok(Box::new(BatchNorm::withshape(
&config.name,
num_features,
Some(epsilon),
Some(momentum),
)))
}
"Dropout" => {
let rate = config.config["rate"].as_f64().unwrap_or(0.5);
let seed = config.config["seed"].as_u64();
Ok(Box::new(Dropout::new(&config.name, rate, seed)))
}
_ => Err(CoreError::NotImplementedError(ErrorContext::new(format!(
"Deserialization not implemented for layer type: {layer_type}",
layer_type = config.layer_type
)))),
}
}
fn load_parameters(
&self,
model: &Sequential,
modeldir: &Path,
parameter_files: &HashMap<String, String>,
) -> CoreResult<()> {
for (i, layer) in model.layers().iter().enumerate() {
let params = layer.parameters();
for (j, param) in params.iter().enumerate() {
let param_name = format!("layer_{i}_param_{j}");
if let Some(param_file) = parameter_files.get(¶m_name) {
let param_path = modeldir.join(param_file);
if param_path.exists() {
let mut file = File::open(¶m_path)?;
let mut json_str = String::new();
file.read_to_string(&mut json_str)?;
let load_data: serde_json::Value = serde_json::from_str(&json_str)?;
let shape: Vec<usize> = serde_json::from_value(load_data["shape"].clone())?;
let _data: Vec<f64> = serde_json::from_value(load_data["data"].clone())?;
if let Some(_array) =
param.as_any().downcast_ref::<NdarrayWrapper<f64, IxDyn>>()
{
}
} else {
return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
"Parameter file not found: {path}",
path = param_path.display()
))));
}
}
}
}
Ok(())
}
fn load_optimizer(&self, optimizerpath: &Path) -> CoreResult<Box<dyn Optimizer>> {
if !optimizerpath.exists() {
return Err(CoreError::InvalidArgument(ErrorContext::new(format!(
"Optimizer file not found: {path}",
path = optimizerpath.display()
))));
}
let mut file = File::open(optimizerpath)?;
let mut json_str = String::new();
file.read_to_string(&mut json_str)?;
let optimizer_data: serde_json::Value = serde_json::from_str(&json_str)?;
match optimizer_data["type"].as_str() {
Some("SGD") => {
let config = &optimizer_data["config"];
let learningrate = config["learningrate"].as_f64().unwrap_or(0.01);
let momentum = config["momentum"].as_f64();
Ok(Box::new(SGD::new(learningrate, momentum)))
}
_ => {
Ok(Box::new(SGD::new(0.01, None)))
}
}
}
}
pub struct OnnxExporter;
impl OnnxExporter {
pub fn export(
&self,
_model: &Sequential,
path: impl AsRef<Path>,
_inputshape: &[usize],
) -> CoreResult<()> {
File::create(path.as_ref())?;
Ok(())
}
}
#[allow(dead_code)]
pub fn save_checkpoint(
model: &Sequential,
optimizer: &dyn Optimizer,
path: impl AsRef<Path>,
epoch: usize,
metrics: HashMap<String, f64>,
) -> CoreResult<()> {
let checkpoint_dir = path.as_ref().parent().unwrap_or(Path::new("."));
fs::create_dir_all(checkpoint_dir)?;
let metadata = serde_json::json!({
"epoch": epoch,
"metrics": metrics,
"timestamp": chrono::Utc::now().to_rfc3339(),
});
let metadata_path = path.as_ref().with_extension("json");
let metadata_json = serde_json::to_string_pretty(&metadata)?;
let mut file = File::create(&metadata_path)?;
file.write_all(metadata_json.as_bytes())?;
let serializer = ModelSerializer::new(checkpoint_dir);
let model_name = "checkpoint";
let model_version = format!("epoch_{epoch}");
serializer.save_model(model, model_name, &model_version, Some(optimizer))?;
Ok(())
}
pub type ModelCheckpoint = (Sequential, Box<dyn Optimizer>, usize, HashMap<String, f64>);
#[cfg(feature = "serialization")]
#[allow(dead_code)]
pub fn load_checkpoint(path: impl AsRef<Path>) -> CoreResult<ModelCheckpoint> {
let metadata_path = path.as_ref().with_extension("json");
let mut file = File::open(&metadata_path)?;
let mut metadata_json = String::new();
file.read_to_string(&mut metadata_json)?;
let metadata: serde_json::Value = serde_json::from_str(&metadata_json)?;
let epoch = metadata["epoch"].as_u64().unwrap_or(0) as usize;
let metrics: HashMap<String, f64> =
serde_json::from_value(metadata["metrics"].clone()).unwrap_or_else(|_| HashMap::new());
let checkpoint_dir = path.as_ref().parent().unwrap_or(Path::new("."));
let serializer = ModelSerializer::new(checkpoint_dir);
let model_name = "checkpoint";
let model_version = format!("epoch_{epoch}");
let (model, optimizer) = serializer.loadmodel(model_name, &model_version)?;
Ok((model, optimizer.expect("Operation failed"), epoch, metrics))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::array_protocol;
use crate::array_protocol::grad::SGD;
use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::neural::{Linear, Sequential};
use tempfile::tempdir;
#[test]
fn test_model_serializer() {
array_protocol::init();
let temp_dir = match tempdir() {
Ok(dir) => dir,
Err(e) => {
println!("Skipping test_model_serializer (temp dir creation failed): {e}");
return;
}
};
let mut model = Sequential::new("test_model", Vec::new());
model.add_layer(Box::new(Linear::new_random(
"fc1",
10,
5,
true,
Some(ActivationFunc::ReLU),
)));
model.add_layer(Box::new(Linear::new_random("fc2", 5, 2, true, None)));
let optimizer = SGD::new(0.01, Some(0.9));
let serializer = ModelSerializer::new(temp_dir.path());
let model_path = serializer.save_model(&model, "test_model", "v1", Some(&optimizer));
if model_path.is_err() {
println!("Save model failed: {:?}", model_path.err());
return;
}
let (loadedmodel, loaded_optimizer) = serializer
.loadmodel("test_model", "v1")
.expect("Operation failed");
assert_eq!(loadedmodel.layers().len(), 2);
assert!(loaded_optimizer.is_some());
}
#[test]
fn test_save_load_checkpoint() {
array_protocol::init();
let temp_dir = match tempdir() {
Ok(dir) => dir,
Err(e) => {
println!("Skipping test_save_load_checkpoint (temp dir creation failed): {e}");
return;
}
};
let mut model = Sequential::new("test_model", Vec::new());
model.add_layer(Box::new(Linear::new_random(
"fc1",
10,
5,
true,
Some(ActivationFunc::ReLU),
)));
let optimizer = SGD::new(0.01, Some(0.9));
let mut metrics = HashMap::new();
metrics.insert("loss".to_string(), 0.1);
metrics.insert("accuracy".to_string(), 0.9);
let checkpoint_path = temp_dir.path().join("checkpoint");
let result = save_checkpoint(&model, &optimizer, &checkpoint_path, 10, metrics.clone());
if let Err(e) = result {
println!("Skipping test_save_load_checkpoint (save failed): {e}");
return;
}
let result = load_checkpoint(&checkpoint_path);
if let Err(e) = result {
println!("Skipping test_save_load_checkpoint (load failed): {e}");
return;
}
let (loadedmodel, loaded_optimizer, loaded_epoch, loaded_metrics) =
result.expect("Operation failed");
assert_eq!(loadedmodel.layers().len(), 1);
assert_eq!(loaded_epoch, 10);
assert_eq!(loaded_metrics.get("loss"), metrics.get("loss"));
assert_eq!(loaded_metrics.get("accuracy"), metrics.get("accuracy"));
}
}