use std::collections::BTreeMap;
use std::path::Path;
use super::module::Module;
use crate::autograd::Tensor;
use crate::serialization::safetensors::{extract_tensor, load_safetensors, save_safetensors};
/// Flat map from parameter name to its raw `f32` buffer and shape.
/// A `BTreeMap` keeps keys sorted, so iteration (and thus serialized
/// output) is deterministic across runs.
pub type StateDict = BTreeMap<String, (Vec<f32>, Vec<usize>)>;
/// Collects a module's parameters into a flat name → (data, shape) map.
///
/// Parameters are keyed by their index in `module.parameters()`, namespaced
/// as `"{prefix}.{index}"` when `prefix` is non-empty and just `"{index}"`
/// otherwise.
pub fn state_dict<M: Module + ?Sized>(module: &M, prefix: &str) -> StateDict {
    module
        .parameters()
        .iter()
        .enumerate()
        .map(|(idx, param)| {
            let key = if prefix.is_empty() {
                idx.to_string()
            } else {
                format!("{prefix}.{idx}")
            };
            (key, (param.data().to_vec(), param.shape().to_vec()))
        })
        .collect()
}
/// Copies tensors from `state` back into `module`'s parameters in place.
///
/// Each parameter at index `i` is looked up under `"{i}"`, or
/// `"{prefix}.{i}"` when `prefix` is non-empty. Fails with a descriptive
/// message if a key is missing or a stored shape disagrees with the
/// parameter's current shape. On success the module's internal caches are
/// refreshed so derived state matches the new weights.
pub fn load_state_dict_into<M: Module + ?Sized>(
    module: &mut M,
    state: &StateDict,
    prefix: &str,
) -> Result<(), String> {
    for (idx, param) in module.parameters_mut().into_iter().enumerate() {
        let name = if prefix.is_empty() {
            idx.to_string()
        } else {
            format!("{prefix}.{idx}")
        };
        let (data, shape) = match state.get(&name) {
            Some(entry) => entry,
            None => return Err(format!("Missing parameter '{name}' in state dict")),
        };
        if shape.as_slice() != param.shape() {
            return Err(format!(
                "Shape mismatch for parameter '{name}': expected {:?}, got {:?}",
                param.shape(),
                shape
            ));
        }
        // Replace the tensor wholesale; requires_grad() re-marks it trainable.
        *param = Tensor::new(data, shape).requires_grad();
    }
    module.refresh_caches();
    Ok(())
}
/// Serializes all of `module`'s parameters to a safetensors file at `path`.
///
/// Parameter names are the bare indices produced by [`state_dict`] with an
/// empty prefix.
pub fn save_model<M: Module + ?Sized, P: AsRef<Path>>(module: &M, path: P) -> Result<(), String> {
    save_safetensors(path, &state_dict(module, ""))
}
/// Reads a safetensors file into a [`StateDict`].
///
/// Propagates any error from parsing the file header or extracting an
/// individual tensor's data.
pub fn load_state_dict<P: AsRef<Path>>(path: P) -> Result<StateDict, String> {
    let (metadata, raw_data) = load_safetensors(path)?;
    metadata
        .into_iter()
        .map(|(name, meta)| {
            // Short-circuits the collect on the first extraction failure.
            let data = extract_tensor(&raw_data, &meta)?;
            Ok((name, (data, meta.shape)))
        })
        .collect()
}
/// Loads parameters from a safetensors file at `path` into `module`.
///
/// Convenience wrapper: [`load_state_dict`] followed by
/// [`load_state_dict_into`] with an empty prefix.
pub fn load_model<M: Module + ?Sized, P: AsRef<Path>>(
    module: &mut M,
    path: P,
) -> Result<(), String> {
    load_state_dict_into(module, &load_state_dict(path)?, "")
}
/// Total number of scalar elements across all of the module's parameters.
pub fn count_parameters<M: Module + ?Sized>(module: &M) -> usize {
    module
        .parameters()
        .iter()
        .fold(0, |total, param| total + param.numel())
}
/// Approximate in-memory size of the module's parameters, in bytes.
///
/// Parameter data is stored as `f32` (see [`StateDict`]), so each element
/// occupies `size_of::<f32>()` bytes. Non-parameter state and container
/// overhead are not counted.
pub fn model_size_bytes<M: Module + ?Sized>(module: &M) -> usize {
    // Spell out the element size instead of a magic `4` so the intent
    // survives a future change of storage dtype.
    count_parameters(module) * std::mem::size_of::<f32>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nn::{Linear, ReLU, Sequential};

    /// Scratch file inside the platform temp directory. The previous
    /// hard-coded `/tmp/...` paths fail on platforms without `/tmp`
    /// (e.g. Windows); `std::env::temp_dir()` is portable.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(name)
    }

    #[test]
    fn test_state_dict_linear() {
        let layer = Linear::with_seed(10, 5, Some(42));
        let state = state_dict(&layer, "");
        // A Linear layer exposes exactly two parameters: weight then bias.
        assert_eq!(state.len(), 2);
        let (weight_data, weight_shape) = &state["0"];
        assert_eq!(weight_shape, &[5, 10]);
        assert_eq!(weight_data.len(), 50);
        let (bias_data, bias_shape) = &state["1"];
        assert_eq!(bias_shape, &[5]);
        assert_eq!(bias_data.len(), 5);
    }

    #[test]
    fn test_state_dict_sequential() {
        let model = Sequential::new()
            .add(Linear::with_seed(10, 8, Some(42)))
            .add(ReLU::new())
            .add(Linear::with_seed(8, 5, Some(43)));
        let state = state_dict(&model, "");
        // Two Linear layers × (weight + bias); ReLU contributes nothing.
        assert_eq!(state.len(), 4);
    }

    #[test]
    fn test_load_state_dict_into() {
        let layer1 = Linear::with_seed(10, 5, Some(42));
        let state = state_dict(&layer1, "");
        let mut layer2 = Linear::with_seed(10, 5, Some(99));
        // Different seeds must start with different weights.
        assert_ne!(layer1.parameters()[0].data(), layer2.parameters()[0].data());
        load_state_dict_into(&mut layer2, &state, "").expect("load_state_dict_into should succeed");
        assert_eq!(layer1.parameters()[0].data(), layer2.parameters()[0].data());
    }

    #[test]
    fn test_save_and_load_model() {
        let path = temp_path("test_nn_serialize.safetensors");
        let model1 = Linear::with_seed(10, 5, Some(42));
        save_model(&model1, &path).expect("save_model should succeed");
        let mut model2 = Linear::with_seed(10, 5, Some(99));
        load_model(&mut model2, &path).expect("load_model should succeed");
        assert_eq!(model1.parameters()[0].data(), model2.parameters()[0].data());
        assert_eq!(model1.parameters()[1].data(), model2.parameters()[1].data());
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_save_and_load_sequential() {
        let path = temp_path("test_nn_serialize_seq.safetensors");
        let model1 = Sequential::new()
            .add(Linear::with_seed(10, 8, Some(42)))
            .add(ReLU::new())
            .add(Linear::with_seed(8, 5, Some(43)));
        save_model(&model1, &path).expect("save_model should succeed");
        let mut model2 = Sequential::new()
            .add(Linear::with_seed(10, 8, Some(99)))
            .add(ReLU::new())
            .add(Linear::with_seed(8, 5, Some(100)));
        load_model(&mut model2, &path).expect("load_model should succeed");
        for (p1, p2) in model1.parameters().iter().zip(model2.parameters().iter()) {
            assert_eq!(p1.data(), p2.data());
            assert_eq!(p1.shape(), p2.shape());
        }
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_load_state_dict_shape_mismatch() {
        let layer1 = Linear::with_seed(10, 5, Some(42));
        let state = state_dict(&layer1, "");
        // Incompatible dimensions must be rejected, not silently truncated.
        let mut layer2 = Linear::with_seed(20, 10, Some(99));
        let result = load_state_dict_into(&mut layer2, &state, "");
        assert!(result.is_err());
        let err = result.expect_err("Should fail with shape mismatch");
        assert!(err.contains("Shape mismatch"));
    }

    #[test]
    fn test_count_parameters() {
        let model = Sequential::new()
            .add(Linear::new(10, 8))
            .add(Linear::new(8, 5));
        // 10*8 + 8 + 8*5 + 5 = 133
        assert_eq!(count_parameters(&model), 133);
    }

    #[test]
    fn test_model_size_bytes() {
        let model = Linear::new(10, 5);
        // 55 parameters × 4 bytes per f32.
        assert_eq!(model_size_bytes(&model), 55 * 4);
    }

    #[test]
    fn test_model_forward_after_load() {
        let path = temp_path("test_nn_forward_after_load.safetensors");
        let model1 = Linear::with_seed(10, 5, Some(42));
        let x = Tensor::ones(&[2, 10]);
        let y1 = model1.forward(&x);
        save_model(&model1, &path).expect("save_model should succeed");
        let mut model2 = Linear::with_seed(10, 5, Some(99));
        load_model(&mut model2, &path).expect("load_model should succeed");
        // A round-tripped model must produce bit-identical outputs.
        let y2 = model2.forward(&x);
        assert_eq!(y1.data(), y2.data());
        std::fs::remove_file(&path).ok();
    }
}