use std::{
collections::HashMap,
io::{Cursor, Read},
};
use serde::{Deserialize, Serialize};
use crate::error::{RealizarError, Result};
/// Element data types supported by the safetensors format.
///
/// Variant names match the `dtype` strings in safetensors JSON metadata
/// (e.g. `"F32"`, `"BF16"`), so serde derives the mapping directly.
///
/// The enum is fieldless, so `Copy`, `Eq`, and `Hash` are derived in
/// addition to `PartialEq` — this lets callers use dtypes as map keys
/// and compare them without borrowing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
pub enum SafetensorsDtype {
    /// 32-bit IEEE 754 float.
    F32,
    /// 16-bit IEEE 754 float.
    F16,
    /// 16-bit brain float.
    BF16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// 8-bit unsigned integer.
    U8,
    /// Boolean, one byte per element.
    Bool,
}
/// Raw per-tensor entry exactly as it appears in the safetensors JSON
/// header. Converted into the public [`SafetensorsTensorInfo`] (which
/// additionally carries the tensor's name) during parsing.
#[derive(Debug, Deserialize)]
struct TensorMetadata {
    /// Element type, e.g. `F32`.
    dtype: SafetensorsDtype,
    /// Tensor dimensions; empty for a scalar.
    shape: Vec<usize>,
    /// `[start, end)` byte range, relative to the start of the data
    /// section (the bytes following the JSON metadata).
    data_offsets: [usize; 2],
}
/// Public description of one tensor in a parsed safetensors file.
#[derive(Debug, Clone, PartialEq)]
pub struct SafetensorsTensorInfo {
    /// Tensor name, as used as the key in the metadata JSON.
    pub name: String,
    /// Element type of the tensor data.
    pub dtype: SafetensorsDtype,
    /// Tensor dimensions; empty for a scalar.
    pub shape: Vec<usize>,
    /// `[start, end)` byte range relative to the start of the data section.
    pub data_offsets: [usize; 2],
}
/// A fully parsed safetensors file: tensor metadata keyed by name plus
/// the raw data section the offsets refer to.
#[derive(Debug, Clone)]
pub struct SafetensorsModel {
    /// Per-tensor metadata, keyed by tensor name.
    pub tensors: HashMap<String, SafetensorsTensorInfo>,
    /// The data section bytes (everything after the JSON metadata);
    /// `data_offsets` in each tensor info index into this buffer.
    pub data: Vec<u8>,
}
impl SafetensorsModel {
    /// Parses a safetensors file from its raw bytes.
    ///
    /// File layout: an 8-byte little-endian `u64` metadata length, the
    /// JSON metadata of that length, then the tensor data section.
    ///
    /// # Errors
    /// Returns [`RealizarError::UnsupportedOperation`] when the header is
    /// truncated, the metadata JSON is malformed, or a length does not
    /// fit in `usize` on this platform.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        let mut cursor = Cursor::new(data);
        let metadata_len = Self::parse_header(&mut cursor)?;
        let tensors = Self::parse_metadata(&mut cursor, metadata_len)?;
        // parse_metadata has already read `metadata_len` bytes after the
        // 8-byte prefix, so `8 + metadata_len <= data.len()` holds here
        // and the range slice below cannot panic.
        let data_start =
            usize::try_from(8 + metadata_len).map_err(|_| RealizarError::UnsupportedOperation {
                operation: "convert_data_offset".to_string(),
                reason: format!(
                    "Data offset {} exceeds platform usize limit",
                    8 + metadata_len
                ),
            })?;
        let data = data[data_start..].to_vec();
        Ok(Self { tensors, data })
    }

    /// Returns the named tensor decoded as little-endian `f32` values.
    ///
    /// # Errors
    /// Fails when the tensor is missing, its dtype is not `F32`, its
    /// offsets lie outside the data section (including `start > end`),
    /// or its byte length is not a multiple of 4.
    pub fn get_tensor_f32(&self, name: &str) -> Result<Vec<f32>> {
        let tensor = self
            .tensors
            .get(name)
            .ok_or_else(|| RealizarError::UnsupportedOperation {
                operation: "get_tensor_f32".to_string(),
                reason: format!("Tensor '{name}' not found"),
            })?;
        if tensor.dtype != SafetensorsDtype::F32 {
            let dtype = &tensor.dtype;
            return Err(RealizarError::UnsupportedOperation {
                operation: "get_tensor_f32".to_string(),
                reason: format!("Tensor '{name}' has dtype {dtype:?}, expected F32"),
            });
        }
        let [start, end] = tensor.data_offsets;
        // Validate BOTH offset bounds before slicing. The offsets come
        // from untrusted file metadata: `start > end` (or `end` past the
        // data section) would otherwise panic on the range index below
        // instead of returning an error.
        if start > end || end > self.data.len() {
            let data_len = self.data.len();
            return Err(RealizarError::UnsupportedOperation {
                operation: "get_tensor_f32".to_string(),
                reason: format!(
                    "Data offsets [{start}, {end}] invalid for data size {data_len}"
                ),
            });
        }
        let bytes = &self.data[start..end];
        if bytes.len() % 4 != 0 {
            let len = bytes.len();
            return Err(RealizarError::UnsupportedOperation {
                operation: "get_tensor_f32".to_string(),
                reason: format!("Data size {len} is not a multiple of 4"),
            });
        }
        let values = bytes
            .chunks_exact(4)
            .map(|chunk| {
                f32::from_le_bytes(
                    chunk
                        .try_into()
                        .expect("chunks_exact(4) guarantees 4-byte slices"),
                )
            })
            .collect();
        Ok(values)
    }

    /// Reads the 8-byte little-endian metadata length from the cursor.
    fn parse_header(cursor: &mut Cursor<&[u8]>) -> Result<u64> {
        let mut buf = [0u8; 8];
        cursor
            .read_exact(&mut buf)
            .map_err(|e| RealizarError::UnsupportedOperation {
                operation: "read_metadata_len".to_string(),
                reason: e.to_string(),
            })?;
        Ok(u64::from_le_bytes(buf))
    }

    /// Reads `len` bytes of JSON metadata from the cursor and converts
    /// each entry into a [`SafetensorsTensorInfo`] keyed by tensor name.
    fn parse_metadata(
        cursor: &mut Cursor<&[u8]>,
        len: u64,
    ) -> Result<HashMap<String, SafetensorsTensorInfo>> {
        let len_usize = usize::try_from(len).map_err(|_| RealizarError::UnsupportedOperation {
            operation: "convert_metadata_len".to_string(),
            reason: format!("Metadata length {len} exceeds platform usize limit"),
        })?;
        let mut json_bytes = vec![0u8; len_usize];
        cursor
            .read_exact(&mut json_bytes)
            .map_err(|e| RealizarError::UnsupportedOperation {
                operation: "read_metadata_json".to_string(),
                reason: e.to_string(),
            })?;
        let json_map: HashMap<String, TensorMetadata> = serde_json::from_slice(&json_bytes)
            .map_err(|e| RealizarError::UnsupportedOperation {
                operation: "parse_json".to_string(),
                reason: e.to_string(),
            })?;
        // Re-key into the public info type, embedding the name in each entry.
        let mut tensors = HashMap::with_capacity(json_map.len());
        for (name, meta) in json_map {
            tensors.insert(
                name.clone(),
                SafetensorsTensorInfo {
                    name,
                    dtype: meta.dtype,
                    shape: meta.shape,
                    data_offsets: meta.data_offsets,
                },
            );
        }
        Ok(tensors)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal valid file: 8-byte length prefix + "{}" metadata, no data.
    #[test]
    fn test_parse_empty_safetensors() {
        let mut data = Vec::new();
        data.extend_from_slice(&2u64.to_le_bytes()); data.extend_from_slice(b"{}");
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.tensors.len(), 0);
        assert_eq!(model.data.len(), 0);
    }

    // Fewer than 8 bytes: the length prefix itself cannot be read.
    #[test]
    fn test_invalid_header_truncated() {
        let data = [0u8; 4];
        let result = SafetensorsModel::from_bytes(&data);
        assert!(result.is_err());
    }

    // A zero-byte input must fail the same way, not panic.
    #[test]
    fn test_empty_file() {
        let data = &[];
        let result = SafetensorsModel::from_bytes(data);
        assert!(result.is_err());
    }

    // One F32 tensor: every metadata field must round-trip into
    // SafetensorsTensorInfo, including the name taken from the JSON key.
    #[test]
    fn test_parse_single_tensor() {
        let json = r#"{"weight":{"dtype":"F32","shape":[2,3],"data_offsets":[0,24]}}"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&[0u8; 24]);
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.tensors.len(), 1);
        let tensor = model.tensors.get("weight").unwrap();
        assert_eq!(tensor.name, "weight");
        assert_eq!(tensor.dtype, SafetensorsDtype::F32);
        assert_eq!(tensor.shape, vec![2, 3]);
        assert_eq!(tensor.data_offsets, [0, 24]);
    }

    // Two tensors with contiguous data ranges (131072 = 128*256*4 bytes,
    // 512 = 128*4 bytes); both entries must be preserved independently.
    #[test]
    fn test_parse_multiple_tensors() {
        let json = r#"{
        "layer1.weight":{"dtype":"F32","shape":[128,256],"data_offsets":[0,131072]},
        "layer1.bias":{"dtype":"F32","shape":[128],"data_offsets":[131072,131584]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&vec![0u8; 131_584]);
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.tensors.len(), 2);
        let weight = model.tensors.get("layer1.weight").unwrap();
        assert_eq!(weight.dtype, SafetensorsDtype::F32);
        assert_eq!(weight.shape, vec![128, 256]);
        assert_eq!(weight.data_offsets, [0, 131_072]);
        let bias = model.tensors.get("layer1.bias").unwrap();
        assert_eq!(bias.dtype, SafetensorsDtype::F32);
        assert_eq!(bias.shape, vec![128]);
        assert_eq!(bias.data_offsets, [131_072, 131_584]);
    }

    // A sample of distinct dtypes deserializes to the right enum variants.
    #[test]
    fn test_parse_various_dtypes() {
        let json = r#"{
        "f32_tensor":{"dtype":"F32","shape":[2],"data_offsets":[0,8]},
        "i32_tensor":{"dtype":"I32","shape":[2],"data_offsets":[8,16]},
        "u8_tensor":{"dtype":"U8","shape":[4],"data_offsets":[16,20]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&[0u8; 20]);
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.tensors.len(), 3);
        assert_eq!(
            model.tensors.get("f32_tensor").unwrap().dtype,
            SafetensorsDtype::F32
        );
        assert_eq!(
            model.tensors.get("i32_tensor").unwrap().dtype,
            SafetensorsDtype::I32
        );
        assert_eq!(
            model.tensors.get("u8_tensor").unwrap().dtype,
            SafetensorsDtype::U8
        );
    }

    // Metadata that is not valid JSON yields an UnsupportedOperation error.
    #[test]
    fn test_invalid_json_error() {
        let mut data = Vec::new();
        data.extend_from_slice(&10u64.to_le_bytes()); data.extend_from_slice(b"not json!!");
        let result = SafetensorsModel::from_bytes(&data);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RealizarError::UnsupportedOperation { .. }
        ));
    }

    // Declared metadata length (100) exceeds the available bytes (2):
    // read_exact fails and the error is surfaced, not a panic.
    #[test]
    fn test_truncated_json_error() {
        let mut data = Vec::new();
        data.extend_from_slice(&100u64.to_le_bytes()); data.extend_from_slice(b"{}");
        let result = SafetensorsModel::from_bytes(&data);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RealizarError::UnsupportedOperation { .. }
        ));
    }

    // Exhaustive check: every SafetensorsDtype variant parses from its
    // JSON dtype string.
    #[test]
    fn test_parse_all_dtypes() {
        let json = r#"{
        "f32":{"dtype":"F32","shape":[1],"data_offsets":[0,4]},
        "f16":{"dtype":"F16","shape":[1],"data_offsets":[4,6]},
        "bf16":{"dtype":"BF16","shape":[1],"data_offsets":[6,8]},
        "i32":{"dtype":"I32","shape":[1],"data_offsets":[8,12]},
        "i64":{"dtype":"I64","shape":[1],"data_offsets":[12,20]},
        "u8":{"dtype":"U8","shape":[1],"data_offsets":[20,21]},
        "bool":{"dtype":"Bool","shape":[1],"data_offsets":[21,22]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&[0u8; 22]);
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.tensors.len(), 7);
        assert_eq!(
            model.tensors.get("f32").unwrap().dtype,
            SafetensorsDtype::F32
        );
        assert_eq!(
            model.tensors.get("f16").unwrap().dtype,
            SafetensorsDtype::F16
        );
        assert_eq!(
            model.tensors.get("bf16").unwrap().dtype,
            SafetensorsDtype::BF16
        );
        assert_eq!(
            model.tensors.get("i32").unwrap().dtype,
            SafetensorsDtype::I32
        );
        assert_eq!(
            model.tensors.get("i64").unwrap().dtype,
            SafetensorsDtype::I64
        );
        assert_eq!(model.tensors.get("u8").unwrap().dtype, SafetensorsDtype::U8);
        assert_eq!(
            model.tensors.get("bool").unwrap().dtype,
            SafetensorsDtype::Bool
        );
    }

    // The data section must be copied verbatim: two known f32 values are
    // written and read back byte-for-byte from model.data.
    #[test]
    fn test_tensor_data_preserved() {
        let json = r#"{"weight":{"dtype":"F32","shape":[2],"data_offsets":[0,8]}}"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&1.0f32.to_le_bytes());
        data.extend_from_slice(&2.0f32.to_le_bytes());
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.data.len(), 8);
        let val1 = f32::from_le_bytes(model.data[0..4].try_into().unwrap());
        let val2 = f32::from_le_bytes(model.data[4..8].try_into().unwrap());
        assert!((val1 - 1.0).abs() < 1e-6);
        assert!((val2 - 2.0).abs() < 1e-6);
    }

    // Shapes of rank 0 through 3 all parse; a scalar has an empty shape.
    #[test]
    fn test_multidimensional_shapes() {
        let json = r#"{
        "scalar":{"dtype":"F32","shape":[],"data_offsets":[0,4]},
        "vector":{"dtype":"F32","shape":[10],"data_offsets":[4,44]},
        "matrix":{"dtype":"F32","shape":[3,4],"data_offsets":[44,92]},
        "tensor3d":{"dtype":"F32","shape":[2,3,4],"data_offsets":[92,188]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&[0u8; 188]);
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.tensors.len(), 4);
        assert_eq!(
            model.tensors.get("scalar").unwrap().shape,
            Vec::<usize>::new()
        );
        assert_eq!(model.tensors.get("vector").unwrap().shape, vec![10]);
        assert_eq!(model.tensors.get("matrix").unwrap().shape, vec![3, 4]);
        assert_eq!(model.tensors.get("tensor3d").unwrap().shape, vec![2, 3, 4]);
    }

    // Compatibility fixture mirroring the layout written by the aprender
    // linear-regression exporter: coefficients + intercept tensors.
    #[test]
    fn test_aprender_linear_regression_format_compatibility() {
        let json = r#"{
        "coefficients":{"dtype":"F32","shape":[3],"data_offsets":[0,12]},
        "intercept":{"dtype":"F32","shape":[1],"data_offsets":[12,16]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&2.0f32.to_le_bytes());
        data.extend_from_slice(&3.0f32.to_le_bytes());
        data.extend_from_slice(&1.5f32.to_le_bytes());
        data.extend_from_slice(&0.5f32.to_le_bytes());
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        assert_eq!(model.tensors.len(), 2);
        let coef = model.tensors.get("coefficients").unwrap();
        assert_eq!(coef.dtype, SafetensorsDtype::F32);
        assert_eq!(coef.shape, vec![3]);
        assert_eq!(coef.data_offsets, [0, 12]);
        let intercept = model.tensors.get("intercept").unwrap();
        assert_eq!(intercept.dtype, SafetensorsDtype::F32);
        assert_eq!(intercept.shape, vec![1]);
        assert_eq!(intercept.data_offsets, [12, 16]);
        let coef_vals: Vec<f32> = (0..3)
            .map(|i| {
                let offset = i * 4;
                f32::from_le_bytes(model.data[offset..offset + 4].try_into().unwrap())
            })
            .collect();
        assert!((coef_vals[0] - 2.0).abs() < 1e-6);
        assert!((coef_vals[1] - 3.0).abs() < 1e-6);
        assert!((coef_vals[2] - 1.5).abs() < 1e-6);
        let intercept_val = f32::from_le_bytes(model.data[12..16].try_into().unwrap());
        assert!((intercept_val - 0.5).abs() < 1e-6);
    }

    // get_tensor_f32: decodes each tensor's slice into f32s, and a
    // missing tensor name returns an error.
    #[test]
    fn test_get_tensor_f32_helper() {
        let json = r#"{
        "weights":{"dtype":"F32","shape":[4],"data_offsets":[0,16]},
        "bias":{"dtype":"F32","shape":[2],"data_offsets":[16,24]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&1.0f32.to_le_bytes());
        data.extend_from_slice(&2.0f32.to_le_bytes());
        data.extend_from_slice(&3.0f32.to_le_bytes());
        data.extend_from_slice(&4.0f32.to_le_bytes());
        data.extend_from_slice(&0.5f32.to_le_bytes());
        data.extend_from_slice(&0.25f32.to_le_bytes());
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        let weights = model.get_tensor_f32("weights").unwrap();
        assert_eq!(weights.len(), 4);
        assert!((weights[0] - 1.0).abs() < 1e-6);
        assert!((weights[1] - 2.0).abs() < 1e-6);
        assert!((weights[2] - 3.0).abs() < 1e-6);
        assert!((weights[3] - 4.0).abs() < 1e-6);
        let bias = model.get_tensor_f32("bias").unwrap();
        assert_eq!(bias.len(), 2);
        assert!((bias[0] - 0.5).abs() < 1e-6);
        assert!((bias[1] - 0.25).abs() < 1e-6);
        let result = model.get_tensor_f32("nonexistent");
        assert!(result.is_err());
    }

    // get_tensor_f32 refuses tensors whose dtype is not F32.
    #[test]
    fn test_get_tensor_f32_wrong_dtype() {
        let json = r#"{
        "int_tensor":{"dtype":"I32","shape":[2],"data_offsets":[0,8]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&1i32.to_le_bytes());
        data.extend_from_slice(&2i32.to_le_bytes());
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        let result = model.get_tensor_f32("int_tensor");
        assert!(result.is_err());
    }

    // End-to-end read of an aprender-style model through get_tensor_f32.
    #[test]
    fn test_get_tensor_f32_with_aprender_model() {
        let json = r#"{
        "coefficients":{"dtype":"F32","shape":[3],"data_offsets":[0,12]},
        "intercept":{"dtype":"F32","shape":[1],"data_offsets":[12,16]}
        }"#;
        let json_bytes = json.as_bytes();
        let mut data = Vec::new();
        data.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        data.extend_from_slice(json_bytes);
        data.extend_from_slice(&2.0f32.to_le_bytes());
        data.extend_from_slice(&3.0f32.to_le_bytes());
        data.extend_from_slice(&1.5f32.to_le_bytes());
        data.extend_from_slice(&0.5f32.to_le_bytes());
        let model = SafetensorsModel::from_bytes(&data).unwrap();
        let coefficients = model.get_tensor_f32("coefficients").unwrap();
        assert_eq!(coefficients, vec![2.0, 3.0, 1.5]);
        let intercept = model.get_tensor_f32("intercept").unwrap();
        assert_eq!(intercept, vec![0.5]);
    }
}