use std::collections::HashMap;
use std::io::{self, Read, Seek};
use serde::Deserialize;
use crate::ir::DType;
#[derive(Debug, Clone)]
pub struct SafeTensorsFile {
pub tensors: Vec<SafeTensorInfo>,
pub data_offset: u64,
pub metadata: HashMap<String, String>,
}
#[derive(Debug, Clone)]
pub struct SafeTensorInfo {
pub name: String,
pub dtype: DType,
pub shape: Vec<usize>,
pub data_start: usize,
pub data_end: usize,
}
impl SafeTensorInfo {
pub fn numel(&self) -> usize {
self.shape.iter().product()
}
pub fn data_size(&self) -> usize {
self.data_end - self.data_start
}
}
#[derive(Debug, Deserialize)]
struct RawTensorEntry {
dtype: String,
shape: Vec<usize>,
data_offsets: [usize; 2],
}
#[derive(Debug, thiserror::Error)]
pub enum SafeTensorsError {
#[error("I/O error: {0}")]
Io(#[from] io::Error),
#[error("JSON parse error: {0}")]
Json(#[from] serde_json::Error),
#[error("header size {0} exceeds maximum allowed (100MB)")]
HeaderTooLarge(u64),
#[error("unknown dtype: {0}")]
UnknownDType(String),
}
const MAX_HEADER_SIZE: u64 = 100 * 1024 * 1024;
fn parse_dtype(s: &str) -> Result<DType, SafeTensorsError> {
match s {
"F32" => Ok(DType::F32),
"F16" => Ok(DType::F16),
"BF16" => Ok(DType::BF16),
"I32" => Ok(DType::I32),
"I64" => Ok(DType::I64),
"F64" => Ok(DType::I64), "U8" | "I8" | "BOOL" => Ok(DType::I32), "F8_E4M3" => Ok(DType::F8E4M3),
"F8_E5M2" => Ok(DType::F8E5M2),
_ => Err(SafeTensorsError::UnknownDType(s.to_string())),
}
}
pub fn parse<R: Read + Seek>(mut reader: R) -> Result<SafeTensorsFile, SafeTensorsError> {
let mut size_buf = [0u8; 8];
reader.read_exact(&mut size_buf)?;
let header_size = u64::from_le_bytes(size_buf);
if header_size > MAX_HEADER_SIZE {
return Err(SafeTensorsError::HeaderTooLarge(header_size));
}
let mut header_buf = vec![0u8; header_size as usize];
reader.read_exact(&mut header_buf)?;
let raw: HashMap<String, serde_json::Value> = serde_json::from_slice(&header_buf)?;
let mut tensors = Vec::new();
let mut metadata = HashMap::new();
for (key, value) in &raw {
if key == "__metadata__" {
if let Some(obj) = value.as_object() {
for (mk, mv) in obj {
if let Some(s) = mv.as_str() {
metadata.insert(mk.clone(), s.to_string());
}
}
}
continue;
}
let entry: RawTensorEntry = serde_json::from_value(value.clone())?;
let dtype = parse_dtype(&entry.dtype)?;
tensors.push(SafeTensorInfo {
name: key.clone(),
dtype,
shape: entry.shape,
data_start: entry.data_offsets[0],
data_end: entry.data_offsets[1],
});
}
tensors.sort_by_key(|t| t.data_start);
let data_offset = 8 + header_size;
Ok(SafeTensorsFile {
tensors,
data_offset,
metadata,
})
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
fn build_safetensors_bytes(header_json: &str) -> Vec<u8> {
let header_bytes = header_json.as_bytes();
let mut buf = Vec::new();
buf.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
buf.extend_from_slice(header_bytes);
buf.extend_from_slice(&[0u8; 256]);
buf
}
#[test]
fn parse_simple_safetensors() {
let header = r#"{
"weight": {"dtype": "F16", "shape": [768, 768], "data_offsets": [0, 1179648]}
}"#;
let bytes = build_safetensors_bytes(header);
let file = parse(Cursor::new(bytes)).unwrap();
assert_eq!(file.tensors.len(), 1);
assert_eq!(file.tensors[0].name, "weight");
assert_eq!(file.tensors[0].dtype, DType::F16);
assert_eq!(file.tensors[0].shape, vec![768, 768]);
assert_eq!(file.tensors[0].data_start, 0);
assert_eq!(file.tensors[0].data_end, 1179648);
}
#[test]
fn parse_multiple_tensors() {
let header = r#"{
"__metadata__": {"format": "pt"},
"model.embed_tokens.weight": {"dtype": "F32", "shape": [32000, 2048], "data_offsets": [0, 262144000]},
"model.layers.0.self_attn.q_proj.weight": {"dtype": "F16", "shape": [2048, 2048], "data_offsets": [262144000, 270532608]}
}"#;
let bytes = build_safetensors_bytes(header);
let file = parse(Cursor::new(bytes)).unwrap();
assert_eq!(file.tensors.len(), 2);
assert_eq!(file.metadata.get("format"), Some(&"pt".to_string()));
assert_eq!(file.tensors[0].name, "model.embed_tokens.weight");
assert_eq!(file.tensors[0].dtype, DType::F32);
assert_eq!(
file.tensors[1].name,
"model.layers.0.self_attn.q_proj.weight"
);
}
#[test]
fn parse_bf16_tensors() {
let header =
r#"{"w": {"dtype": "BF16", "shape": [1024, 512], "data_offsets": [0, 1048576]}}"#;
let bytes = build_safetensors_bytes(header);
let file = parse(Cursor::new(bytes)).unwrap();
assert_eq!(file.tensors[0].dtype, DType::BF16);
}
#[test]
fn reject_too_large_header() {
let mut bytes = Vec::new();
bytes.extend_from_slice(&(MAX_HEADER_SIZE + 1).to_le_bytes());
let result = parse(Cursor::new(bytes));
assert!(matches!(result, Err(SafeTensorsError::HeaderTooLarge(_))));
}
#[test]
fn reject_unknown_dtype() {
let header = r#"{"w": {"dtype": "COMPLEX128", "shape": [10], "data_offsets": [0, 100]}}"#;
let bytes = build_safetensors_bytes(header);
let result = parse(Cursor::new(bytes));
assert!(matches!(result, Err(SafeTensorsError::UnknownDType(_))));
}
#[test]
fn tensor_info_helpers() {
let t = SafeTensorInfo {
name: "test".into(),
dtype: DType::F32,
shape: vec![32, 64],
data_start: 0,
data_end: 8192,
};
assert_eq!(t.numel(), 2048);
assert_eq!(t.data_size(), 8192);
}
#[test]
fn data_offset_calculation() {
let header = r#"{"w": {"dtype": "F32", "shape": [4], "data_offsets": [0, 16]}}"#;
let header_len = header.len() as u64;
let bytes = build_safetensors_bytes(header);
let file = parse(Cursor::new(bytes)).unwrap();
assert_eq!(file.data_offset, 8 + header_len);
}
}