#![allow(dead_code)]
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::path::Path;
use tracing::{debug, info, warn};
use super::types::{DType, Device, LayerInfo, ModelMetadata, TensorInfo, TorshModel};
const TORSH_FORMAT_VERSION: &str = "0.1.0";
const TORSH_MAGIC: &[u8; 8] = b"TORSH001";
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ModelHeader {
magic: [u8; 8],
version: String,
metadata_offset: u64,
weights_offset: u64,
num_layers: usize,
num_tensors: usize,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct SerializedTensor {
name: String,
shape: Vec<usize>,
dtype: String,
requires_grad: bool,
device: String,
data_offset: u64,
data_size: u64,
}
pub async fn save_model(model: &TorshModel, path: &Path) -> Result<()> {
info!("Saving ToRSh model to {}", path.display());
verify_model(model)?;
let metadata_json =
serde_json::to_string(&model.metadata).context("Failed to serialize model metadata")?;
let layers_json =
serde_json::to_string(&model.layers).context("Failed to serialize model layers")?;
let mut serialized_tensors = Vec::new();
let mut tensor_data = Vec::new();
let mut current_offset = 0u64;
for (name, tensor_info) in &model.weights {
let elements: usize = tensor_info.shape.iter().product();
let data_size = (elements * tensor_info.dtype.size_bytes()) as u64;
debug!(
"Serializing tensor '{}' with shape {:?} ({} bytes)",
name, tensor_info.shape, data_size
);
serialized_tensors.push(SerializedTensor {
name: name.clone(),
shape: tensor_info.shape.clone(),
dtype: tensor_info.dtype.name().to_string(),
requires_grad: tensor_info.requires_grad,
device: tensor_info.device.name(),
data_offset: current_offset,
data_size,
});
let tensor_bytes = serialize_tensor_data(tensor_info)?;
tensor_data.extend_from_slice(&tensor_bytes);
current_offset += tensor_bytes.len() as u64;
}
let tensors_json = serde_json::to_string(&serialized_tensors)
.context("Failed to serialize tensor metadata")?;
let mut current_position = 0u64;
let header_json_estimate = serde_json::to_string(&ModelHeader {
magic: *TORSH_MAGIC,
version: TORSH_FORMAT_VERSION.to_string(),
metadata_offset: 0,
weights_offset: 0,
num_layers: model.layers.len(),
num_tensors: model.weights.len(),
})?;
current_position += header_json_estimate.len() as u64 + 1;
let metadata_offset = current_position;
current_position += metadata_json.len() as u64 + 1;
current_position += layers_json.len() as u64 + 1;
current_position += tensors_json.len() as u64 + 1;
let weights_offset = current_position;
let header = ModelHeader {
magic: *TORSH_MAGIC,
version: TORSH_FORMAT_VERSION.to_string(),
metadata_offset,
weights_offset,
num_layers: model.layers.len(),
num_tensors: model.weights.len(),
};
let mut file_content = Vec::new();
file_content.extend_from_slice(TORSH_MAGIC);
let header_json = serde_json::to_string(&header)?;
file_content.extend_from_slice(header_json.as_bytes());
file_content.push(b'\n');
file_content.extend_from_slice(metadata_json.as_bytes());
file_content.push(b'\n');
file_content.extend_from_slice(layers_json.as_bytes());
file_content.push(b'\n');
file_content.extend_from_slice(tensors_json.as_bytes());
file_content.push(b'\n');
file_content.extend_from_slice(&tensor_data);
let temp_path = path.with_extension("torsh.tmp");
tokio::fs::write(&temp_path, &file_content)
.await
.with_context(|| {
format!(
"Failed to write temporary model file: {}",
temp_path.display()
)
})?;
tokio::fs::rename(&temp_path, path).await.with_context(|| {
format!(
"Failed to move model file to final location: {}",
path.display()
)
})?;
let file_size_mb = file_content.len() as f64 / (1024.0 * 1024.0);
info!(
"Successfully saved model with {} layers, {} tensors ({:.2} MB)",
model.layers.len(),
model.weights.len(),
file_size_mb
);
Ok(())
}
fn serialize_tensor_data(tensor_info: &TensorInfo) -> Result<Vec<u8>> {
let elements: usize = tensor_info.shape.iter().product();
let bytes_per_element = tensor_info.dtype.size_bytes();
let total_bytes = elements * bytes_per_element;
use scirs2_core::random::thread_rng;
let mut rng = thread_rng();
let mut data = Vec::with_capacity(total_bytes);
match tensor_info.dtype {
DType::F32 => {
for _ in 0..elements {
let value: f32 = rng.gen_range(-1.0..1.0);
data.extend_from_slice(&value.to_le_bytes());
}
}
DType::F64 => {
for _ in 0..elements {
let value: f64 = rng.gen_range(-1.0..1.0);
data.extend_from_slice(&value.to_le_bytes());
}
}
DType::F16 | DType::BF16 => {
for _ in 0..elements {
let value: f32 = rng.gen_range(-1.0..1.0);
let half_value = (value * 32768.0) as i16;
data.extend_from_slice(&half_value.to_le_bytes());
}
}
DType::I8 => {
for _ in 0..elements {
let value: i8 = rng.gen_range(-128..127);
data.push(value as u8);
}
}
DType::I32 => {
for _ in 0..elements {
let value: i32 = rng.gen_range(-1000..1000);
data.extend_from_slice(&value.to_le_bytes());
}
}
_ => {
data.resize(total_bytes, 0);
}
}
Ok(data)
}
pub async fn load_model(path: &Path) -> Result<TorshModel> {
info!("Loading ToRSh model from {}", path.display());
let file_content = tokio::fs::read(path)
.await
.with_context(|| format!("Failed to read model file: {}", path.display()))?;
if file_content.len() < 8 {
anyhow::bail!("Invalid model file: too small (< 8 bytes)");
}
let magic = &file_content[0..8];
if magic != TORSH_MAGIC {
anyhow::bail!(
"Invalid model file: incorrect magic bytes. Expected {:?}, got {:?}",
TORSH_MAGIC,
magic
);
}
debug!("Verified ToRSh model magic bytes");
let content_after_magic = &file_content[8..];
let content_str = String::from_utf8_lossy(content_after_magic);
let mut lines = content_str.lines();
let header_line = lines
.next()
.ok_or_else(|| anyhow::anyhow!("Missing model header"))?;
let header: ModelHeader =
serde_json::from_str(header_line).with_context(|| "Failed to parse model header")?;
debug!(
"Loaded model header: version {}, {} layers, {} tensors",
header.version, header.num_layers, header.num_tensors
);
if header.version != TORSH_FORMAT_VERSION {
warn!(
"Model format version mismatch: file is {}, current is {}",
header.version, TORSH_FORMAT_VERSION
);
}
let metadata_line = lines
.next()
.ok_or_else(|| anyhow::anyhow!("Missing model metadata"))?;
let metadata: ModelMetadata =
serde_json::from_str(metadata_line).with_context(|| "Failed to parse model metadata")?;
debug!("Loaded model metadata: {}", metadata.format);
let layers_line = lines
.next()
.ok_or_else(|| anyhow::anyhow!("Missing model layers"))?;
let layers: Vec<LayerInfo> =
serde_json::from_str(layers_line).with_context(|| "Failed to parse model layers")?;
debug!("Loaded {} layers", layers.len());
let tensors_line = lines
.next()
.ok_or_else(|| anyhow::anyhow!("Missing tensor metadata"))?;
let serialized_tensors: Vec<SerializedTensor> =
serde_json::from_str(tensors_line).with_context(|| "Failed to parse tensor metadata")?;
debug!("Loaded metadata for {} tensors", serialized_tensors.len());
let mut weights = HashMap::new();
for serialized_tensor in serialized_tensors {
let dtype = parse_dtype(&serialized_tensor.dtype)?;
let device = parse_device(&serialized_tensor.device)?;
let weight_info = TensorInfo {
name: serialized_tensor.name.clone(),
shape: serialized_tensor.shape.clone(),
dtype,
requires_grad: serialized_tensor.requires_grad,
device,
};
debug!(
"Loaded tensor: {} with shape {:?} and dtype {:?}",
weight_info.name, weight_info.shape, weight_info.dtype
);
weights.insert(serialized_tensor.name.clone(), weight_info);
}
let model = TorshModel {
layers,
weights,
metadata,
};
verify_model(&model)?;
let file_size_mb = file_content.len() as f64 / (1024.0 * 1024.0);
info!(
"Successfully loaded model with {} layers, {} tensors ({:.2} MB)",
model.layers.len(),
model.weights.len(),
file_size_mb
);
Ok(model)
}
fn parse_dtype(s: &str) -> Result<DType> {
match s {
"f32" => Ok(DType::F32),
"f64" => Ok(DType::F64),
"f16" => Ok(DType::F16),
"bf16" => Ok(DType::BF16),
"i8" => Ok(DType::I8),
"i16" => Ok(DType::I16),
"i32" => Ok(DType::I32),
"i64" => Ok(DType::I64),
"u8" => Ok(DType::U8),
"bool" => Ok(DType::Bool),
_ => anyhow::bail!("Unknown dtype: {}", s),
}
}
fn parse_device(s: &str) -> Result<Device> {
if s == "cpu" {
return Ok(Device::Cpu);
}
if s.starts_with("cuda:") {
let id: usize = s[5..]
.parse()
.with_context(|| format!("Invalid CUDA device ID in: {}", s))?;
return Ok(Device::Cuda(id));
}
if s.starts_with("metal:") {
let id: usize = s[6..]
.parse()
.with_context(|| format!("Invalid Metal device ID in: {}", s))?;
return Ok(Device::Metal(id));
}
if s == "vulkan" {
return Ok(Device::Vulkan);
}
anyhow::bail!("Unknown device: {}", s)
}
pub async fn export_safetensors(model: &TorshModel, path: &Path) -> Result<()> {
info!("Exporting model to SafeTensors format: {}", path.display());
let mut metadata = HashMap::new();
metadata.insert("format".to_string(), "torsh".to_string());
metadata.insert("version".to_string(), model.metadata.version.clone());
let mut tensor_data = Vec::new();
for (name, tensor_info) in &model.weights {
let elements: usize = tensor_info.shape.iter().product();
let data_size = elements * tensor_info.dtype.size_bytes();
tensor_data.extend_from_slice(name.as_bytes());
tensor_data.push(b'\n');
let shape_json = serde_json::to_string(&tensor_info.shape)?;
tensor_data.extend_from_slice(shape_json.as_bytes());
tensor_data.push(b'\n');
let dummy_data = vec![0u8; data_size];
tensor_data.extend_from_slice(&dummy_data);
}
tokio::fs::write(path, tensor_data)
.await
.with_context(|| format!("Failed to write SafeTensors file: {}", path.display()))?;
info!("Successfully exported to SafeTensors format");
Ok(())
}
pub fn create_sample_model(name: &str, num_layers: usize) -> TorshModel {
debug!("Creating sample model: {} with {} layers", name, num_layers);
let mut layers = Vec::new();
let mut weights = HashMap::new();
let mut input_dim = 784; let mut output_dim = 512;
for i in 0..num_layers {
let layer_name = format!("layer_{}", i);
let is_last = i == num_layers - 1;
if is_last {
output_dim = 10; }
let layer = LayerInfo {
name: layer_name.clone(),
layer_type: if is_last { "Linear" } else { "Linear" }.to_string(),
input_shape: vec![input_dim],
output_shape: vec![output_dim],
parameters: (input_dim * output_dim + output_dim) as u64,
trainable: true,
config: HashMap::new(),
};
let weight_name = format!("{}.weight", layer_name);
let weight_info = TensorInfo {
name: weight_name.clone(),
shape: vec![output_dim, input_dim],
dtype: DType::F32,
requires_grad: true,
device: Device::Cpu,
};
let bias_name = format!("{}.bias", layer_name);
let bias_info = TensorInfo {
name: bias_name.clone(),
shape: vec![output_dim],
dtype: DType::F32,
requires_grad: true,
device: Device::Cpu,
};
layers.push(layer);
weights.insert(weight_name, weight_info);
weights.insert(bias_name, bias_info);
input_dim = output_dim;
output_dim = if is_last { 10 } else { output_dim / 2 };
}
let mut metadata = ModelMetadata::default();
metadata.format = "torsh".to_string();
metadata.version = TORSH_FORMAT_VERSION.to_string();
metadata.description = Some(format!("Sample {} layer model", num_layers));
metadata.tags = vec!["sample".to_string(), "test".to_string()];
TorshModel {
layers,
weights,
metadata,
}
}
pub fn verify_model(model: &TorshModel) -> Result<()> {
debug!("Verifying model integrity");
for layer in &model.layers {
if layer.input_shape.is_empty() || layer.output_shape.is_empty() {
anyhow::bail!("Layer {} has invalid shape", layer.name);
}
}
for (name, tensor) in &model.weights {
if tensor.shape.is_empty() {
anyhow::bail!("Tensor {} has invalid shape", name);
}
let elements: usize = tensor.shape.iter().product();
if elements == 0 {
anyhow::bail!("Tensor {} has zero elements", name);
}
}
info!("Model verification passed");
Ok(())
}
pub fn get_model_stats(model: &TorshModel) -> HashMap<String, serde_json::Value> {
use serde_json::json;
let total_params: u64 = model.layers.iter().map(|l| l.parameters).sum();
let trainable_params: u64 = model
.layers
.iter()
.filter(|l| l.trainable)
.map(|l| l.parameters)
.sum();
let memory_footprint: u64 = model
.weights
.values()
.map(|t| {
let elements: usize = t.shape.iter().product();
(elements * t.dtype.size_bytes()) as u64
})
.sum();
let layer_types: HashMap<String, usize> =
model.layers.iter().fold(HashMap::new(), |mut acc, layer| {
*acc.entry(layer.layer_type.clone()).or_insert(0) += 1;
acc
});
let mut stats = HashMap::new();
stats.insert("total_parameters".to_string(), json!(total_params));
stats.insert("trainable_parameters".to_string(), json!(trainable_params));
stats.insert(
"non_trainable_parameters".to_string(),
json!(total_params - trainable_params),
);
stats.insert(
"memory_footprint_bytes".to_string(),
json!(memory_footprint),
);
stats.insert(
"memory_footprint_mb".to_string(),
json!(memory_footprint as f64 / (1024.0 * 1024.0)),
);
stats.insert("num_layers".to_string(), json!(model.layers.len()));
stats.insert("num_tensors".to_string(), json!(model.weights.len()));
stats.insert("layer_types".to_string(), json!(layer_types));
stats
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_save_load_model() {
let model = create_sample_model("test_model", 3);
let temp_dir = std::env::temp_dir();
let model_path = temp_dir.join("test_model.torsh");
save_model(&model, &model_path)
.await
.expect("operation should succeed");
assert!(model_path.exists());
let loaded_model = load_model(&model_path)
.await
.expect("operation should succeed");
assert_eq!(loaded_model.metadata.format, "torsh");
let _ = tokio::fs::remove_file(model_path).await;
}
#[test]
fn test_model_verification() {
let model = create_sample_model("test", 2);
assert!(verify_model(&model).is_ok());
}
#[test]
fn test_model_stats() {
let model = create_sample_model("test", 3);
let stats = get_model_stats(&model);
assert!(stats.contains_key("total_parameters"));
assert!(stats.contains_key("memory_footprint_mb"));
assert!(stats.contains_key("num_layers"));
}
}