use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
/// A named-tensor map in the spirit of a PyTorch `state_dict`, together with
/// optional free-form string metadata. Serialized as JSON by `save_to_file`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct StateDict {
    /// Tensor snapshots keyed by their full name (e.g. `"layer.weight"`).
    pub tensors: HashMap<String, TensorData>,
    /// Optional string annotations (e.g. a version tag); `None` when none were attached.
    pub metadata: Option<HashMap<String, String>>,
}
/// Serializable snapshot of one tensor: its shape, its elements widened to
/// `f64`, and a label for the element type it was captured from.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct TensorData {
    /// Dimension sizes; their product is expected to equal `data.len()`.
    pub shape: Vec<usize>,
    /// Flattened element values, stored as `f64` regardless of the source type.
    pub data: Vec<f64>,
    /// Element-type label (a Rust type name when written by `StateDict::add_tensor`).
    pub dtype: String,
}
impl StateDict {
    /// Creates an empty state dict with no tensors and no metadata.
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
            metadata: None,
        }
    }

    /// Stores a snapshot of `tensor` under `name`, replacing any existing entry.
    ///
    /// Elements are widened to `f64` in the order yielded by `tensor.data.iter()`,
    /// and `std::any::type_name::<T>()` is recorded as the `dtype` label.
    ///
    /// # Panics
    ///
    /// Panics if an element cannot be represented as `f64` via `ToPrimitive::to_f64`.
    pub fn add_tensor<T: Float + 'static>(&mut self, name: String, tensor: &Tensor<T>) {
        let tensor_data = TensorData {
            shape: tensor.shape().to_vec(),
            data: tensor
                .data
                .iter()
                .map(|&x| x.to_f64().expect("Float element should convert to f64"))
                .collect(),
            // NOTE: `type_name` output is not guaranteed stable across compiler
            // versions; treat this label as informational only.
            dtype: std::any::type_name::<T>().to_string(),
        };
        self.tensors.insert(name, tensor_data);
    }

    /// Reconstructs the tensor stored under `name` as a `Tensor<T>`.
    ///
    /// Returns `None` if there is no entry with that name, if the stored element
    /// count does not equal the product of the stored shape (or that product
    /// overflows `usize`), or if any element cannot be cast to `T`. These checks
    /// matter because entries may come from an edited or truncated file via
    /// `load_from_file`.
    pub fn get_tensor<T: Float + 'static>(&self, name: &str) -> Option<Tensor<T>> {
        let tensor_data = self.tensors.get(name)?;

        let expected_len = tensor_data
            .shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))?;
        if tensor_data.data.len() != expected_len {
            return None;
        }

        let data = tensor_data
            .data
            .iter()
            .map(|&x| T::from(x))
            .collect::<Option<Vec<T>>>()?;
        Some(Tensor::from_vec(data, tensor_data.shape.clone()))
    }

    /// Returns the names of all stored tensors, in the map's arbitrary order.
    pub fn tensor_names(&self) -> Vec<&String> {
        self.tensors.keys().collect()
    }

    /// Writes the state dict to `path` as pretty-printed JSON, creating or
    /// truncating the file.
    ///
    /// NOTE(review): serde_json writes non-finite floats (NaN, ±inf) as `null`.
    /// Such values will not deserialize back into `f64` in `load_from_file`;
    /// confirm whether weights can contain them.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, if serialization fails, or
    /// if the buffered output cannot be flushed.
    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> RusTorchResult<()> {
        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, self)?;
        // Flush explicitly. Dropping a `BufWriter` also flushes, but it silently
        // discards any error, which would make a truncated file look like a
        // successful save.
        std::io::Write::flush(&mut writer)?;
        Ok(())
    }

    /// Reads a state dict previously written by `save_to_file`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened, or if its contents are not
    /// JSON matching the `StateDict` layout.
    pub fn load_from_file<P: AsRef<Path>>(path: P) -> RusTorchResult<Self> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let state_dict = serde_json::from_reader(reader)?;
        Ok(state_dict)
    }

    /// Inserts (or overwrites) a metadata entry, creating the metadata map on first use.
    pub fn add_metadata(&mut self, key: String, value: String) {
        self.metadata
            .get_or_insert_with(HashMap::new)
            .insert(key, value);
    }

    /// Looks up a metadata value; returns `None` if there is no metadata map or no such key.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.as_ref()?.get(key)
    }
}
impl Default for StateDict {
fn default() -> Self {
Self::new()
}
}
/// A PyTorch-style model, represented by its [`StateDict`] plus an optional
/// free-form architecture label.
#[derive(Debug, Clone)]
pub struct PyTorchModel {
    /// Named parameter tensors (e.g. `"fc.weight"`, `"fc.bias"`).
    pub state_dict: StateDict,
    /// Optional architecture label. Only `state_dict` is written by `save`, so
    /// this label is not persisted.
    pub architecture: Option<String>,
}
impl PyTorchModel {
pub fn new() -> Self {
Self {
state_dict: StateDict::new(),
architecture: None,
}
}
pub fn from_state_dict(state_dict: StateDict) -> Self {
Self {
state_dict,
architecture: None,
}
}
pub fn set_architecture(&mut self, architecture: String) {
self.architecture = Some(architecture);
}
pub fn layer_names(&self) -> Vec<&String> {
self.state_dict.tensor_names()
}
pub fn get_layer_weights<T: Float + 'static>(&self, layer_name: &str) -> Option<Tensor<T>> {
let weight_key = format!("{}.weight", layer_name);
self.state_dict
.get_tensor(&weight_key)
.or_else(|| self.state_dict.get_tensor(layer_name))
}
pub fn get_layer_bias<T: Float + 'static>(&self, layer_name: &str) -> Option<Tensor<T>> {
let bias_key = format!("{}.bias", layer_name);
self.state_dict.get_tensor(&bias_key)
}
pub fn set_layer_weights<T: Float + 'static>(&mut self, layer_name: &str, weights: &Tensor<T>) {
let weight_key = format!("{}.weight", layer_name);
self.state_dict.add_tensor(weight_key, weights);
}
pub fn set_layer_bias<T: Float + 'static>(&mut self, layer_name: &str, bias: &Tensor<T>) {
let bias_key = format!("{}.bias", layer_name);
self.state_dict.add_tensor(bias_key, bias);
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> RusTorchResult<()> {
self.state_dict.save_to_file(path)
}
pub fn load<P: AsRef<Path>>(path: P) -> RusTorchResult<Self> {
let state_dict = StateDict::load_from_file(path)?;
Ok(Self::from_state_dict(state_dict))
}
}
impl Default for PyTorchModel {
fn default() -> Self {
Self::new()
}
}
pub mod utils {
    use super::*;

    /// Converts a dotted PyTorch parameter name to RusTorch's underscore style.
    ///
    /// Each `.`-separated segment that is exactly `weight` or `bias` is
    /// abbreviated to `w` / `b`, and the segments are joined with `_`
    /// (`"encoder.layer1.weight"` → `"encoder_layer1_w"`). Only whole segments
    /// are abbreviated, so names that merely contain those words (`weights`,
    /// `weight_orig`, `biased`) are left intact rather than mangled.
    pub fn convert_layer_name(pytorch_name: &str) -> String {
        pytorch_name
            .split('.')
            .map(|segment| match segment {
                "weight" => "w",
                "bias" => "b",
                other => other,
            })
            .collect::<Vec<_>>()
            .join("_")
    }

    /// Best-effort inverse of [`convert_layer_name`].
    ///
    /// Splits on `_`, expands segments that are exactly `w` / `b` to `weight` /
    /// `bias`, and joins with `.` (`"backbone_bn1_b"` → `"backbone.bn1.bias"`).
    /// Segments that only *start* with `w`/`b` (e.g. `bn1`, `block`) are left
    /// untouched. Underscores that belonged to an original PyTorch segment
    /// (e.g. `running_mean`) cannot be told apart and also become `.`.
    pub fn to_pytorch_layer_name(rustorch_name: &str) -> String {
        rustorch_name
            .split('_')
            .map(|segment| match segment {
                "w" => "weight",
                "b" => "bias",
                other => other,
            })
            .collect::<Vec<_>>()
            .join(".")
    }

    /// Summarizes a model's parameters.
    ///
    /// The returned map has two entries:
    /// - `"total_parameters"`: the sum, over all tensors, of the product of each shape;
    /// - `"layer_count"`: the number of tensors in the state dict. A layer with
    ///   both a weight and a bias therefore counts twice.
    pub fn model_statistics(model: &PyTorchModel) -> HashMap<String, usize> {
        let tensors = &model.state_dict.tensors;
        let total_params: usize = tensors
            .values()
            .map(|tensor_data| tensor_data.shape.iter().product::<usize>())
            .sum();

        let mut stats = HashMap::with_capacity(2);
        stats.insert("total_parameters".to_string(), total_params);
        stats.insert("layer_count".to_string(), tensors.len());
        stats
    }

    /// Checks that the model has at least one tensor and that every tensor's
    /// stored element count matches its shape.
    ///
    /// A shape of `[]` denotes a 0-d (scalar) tensor, which must hold exactly one
    /// element. PyTorch emits these in ordinary state dicts, e.g. BatchNorm's
    /// `num_batches_tracked`. Tensors are checked in sorted name order, so the
    /// reported error does not depend on hash-map iteration order.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message naming the first offending tensor.
    pub fn validate_model(model: &PyTorchModel) -> Result<(), String> {
        let tensors = &model.state_dict.tensors;
        if tensors.is_empty() {
            return Err("Model has no tensors".to_string());
        }

        let mut entries: Vec<_> = tensors.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

        for (name, tensor_data) in entries {
            if tensor_data.data.is_empty() {
                return Err(format!("Tensor '{}' has no data", name));
            }
            // Checked product: shapes can come from an untrusted file.
            let expected_size = match tensor_data
                .shape
                .iter()
                .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            {
                Some(size) => size,
                None => {
                    return Err(format!(
                        "Tensor '{}' shape {:?} overflows the element count",
                        name, tensor_data.shape
                    ))
                }
            };
            if tensor_data.data.len() != expected_size {
                return Err(format!(
                    "Tensor '{}' data size mismatch: expected {}, got {}",
                    name,
                    expected_size,
                    tensor_data.data.len()
                ));
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    #[test]
    fn test_state_dict_operations() {
        let original = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let mut dict = StateDict::new();
        dict.add_tensor("test_layer.weight".to_string(), &original);

        let restored = dict.get_tensor::<f32>("test_layer.weight").unwrap();
        assert_eq!(restored.shape(), &[2, 2]);
        assert_eq!(restored.data.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_pytorch_model_operations() {
        let w = Tensor::<f32>::from_vec(vec![0.1, 0.2, 0.3, 0.4], vec![2, 2]);
        let b = Tensor::<f32>::from_vec(vec![0.1, 0.2], vec![2]);
        let mut model = PyTorchModel::new();
        model.set_layer_weights("linear1", &w);
        model.set_layer_bias("linear1", &b);

        let got_w = model.get_layer_weights::<f32>("linear1").unwrap();
        let got_b = model.get_layer_bias::<f32>("linear1").unwrap();
        assert_eq!(got_w.data.as_slice().unwrap(), w.data.as_slice().unwrap());
        assert_eq!(got_b.data.as_slice().unwrap(), b.data.as_slice().unwrap());
    }

    #[test]
    fn test_save_load_state_dict() {
        let original = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        let mut dict = StateDict::new();
        dict.add_tensor("test".to_string(), &original);
        dict.add_metadata("version".to_string(), "1.0".to_string());

        let scratch = NamedTempFile::new().unwrap();
        dict.save_to_file(scratch.path()).unwrap();
        let reloaded = StateDict::load_from_file(scratch.path()).unwrap();

        let restored = reloaded.get_tensor::<f32>("test").unwrap();
        assert_eq!(restored.data.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(reloaded.get_metadata("version"), Some(&"1.0".to_string()));
    }

    #[test]
    fn test_model_statistics() {
        let first = Tensor::<f32>::from_vec(vec![1.0; 12], vec![3, 4]);
        let second = Tensor::<f32>::from_vec(vec![1.0; 8], vec![2, 4]);
        let mut model = PyTorchModel::new();
        model.set_layer_weights("layer1", &first);
        model.set_layer_weights("layer2", &second);

        let stats = utils::model_statistics(&model);
        assert_eq!(stats["total_parameters"], 12 + 8);
        assert_eq!(stats["layer_count"], 2);
    }

    #[test]
    fn test_model_validation() {
        let empty_model = PyTorchModel::new();
        assert!(utils::validate_model(&empty_model).is_err());

        let square = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let mut model = PyTorchModel::new();
        model.set_layer_weights("linear", &square);
        assert!(utils::validate_model(&model).is_ok());
    }
}