use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use memmap2::Mmap;
use serde_json::Value;
use super::{SafeTensorsError, SafeTensorsResult};
/// Element types accepted in a tensor's `dtype` header field.
///
/// The header spelling of each variant is defined by `SafeTensorsDtype::from_str`
/// and its per-element byte width by `SafeTensorsDtype::element_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SafeTensorsDtype {
/// Parsed from `"F32"`; 4 bytes per element.
F32,
/// Parsed from `"F16"`; 2 bytes per element.
F16,
/// Parsed from `"BF16"`; 2 bytes per element.
BF16,
/// Parsed from `"F64"`; 8 bytes per element.
F64,
/// Parsed from `"I8"`; 1 byte per element.
I8,
/// Parsed from `"I16"`; 2 bytes per element.
I16,
/// Parsed from `"I32"`; 4 bytes per element.
I32,
/// Parsed from `"I64"`; 8 bytes per element.
I64,
/// Parsed from `"U8"`; 1 byte per element.
U8,
/// Parsed from `"BOOL"`; 1 byte per element.
Bool,
}
impl SafeTensorsDtype {
pub fn from_str(s: &str) -> SafeTensorsResult<Self> {
match s {
"F32" => Ok(Self::F32),
"F16" => Ok(Self::F16),
"BF16" => Ok(Self::BF16),
"F64" => Ok(Self::F64),
"I8" => Ok(Self::I8),
"I16" => Ok(Self::I16),
"I32" => Ok(Self::I32),
"I64" => Ok(Self::I64),
"U8" => Ok(Self::U8),
"BOOL" => Ok(Self::Bool),
_ => Err(SafeTensorsError::UnsupportedDtype(s.to_string())),
}
}
pub fn element_size(&self) -> usize {
match self {
Self::F32 | Self::I32 => 4,
Self::F16 | Self::BF16 | Self::I16 => 2,
Self::F64 | Self::I64 => 8,
Self::I8 | Self::U8 | Self::Bool => 1,
}
}
}
/// Per-tensor metadata: element type, shape, and byte range.
///
/// `SafeTensorsFile::open` fills this from a header entry's `dtype`, `shape` and
/// `data_offsets` fields.
#[derive(Debug, Clone)]
pub struct SafeTensorInfo {
/// Element type, parsed from the entry's `dtype` string.
pub dtype: SafeTensorsDtype,
/// Dimension sizes, in the order they appear in the entry's `shape` array.
pub shape: Vec<usize>,
// `data_start` and `data_end` are the two `data_offsets` values. `SafeTensorsFile::tensor_data`
// adds the file's data-section offset to both and slices `start..end`, so they are byte
// offsets relative to the start of the data section, with `data_end` exclusive.
pub data_start: usize, pub data_end: usize, }
impl SafeTensorInfo {
    /// Returns the length in bytes of the tensor's data range (`data_end - data_start`).
    ///
    /// The fields are public, so an inverted range (`data_end < data_start`) can be
    /// constructed by hand; it yields `0` instead of underflowing.
    pub fn byte_size(&self) -> usize {
        self.data_end.saturating_sub(self.data_start)
    }

    /// Returns the number of elements, i.e. the product of all dimensions in `shape`.
    ///
    /// An empty shape yields `1` (the empty product). The product saturates at
    /// `usize::MAX` rather than overflowing for absurdly large shapes.
    pub fn n_elements(&self) -> usize {
        self.shape
            .iter()
            .fold(1usize, |count, &dim| count.saturating_mul(dim))
    }
}
/// A single memory-mapped safetensors file together with its parsed tensor table.
pub struct SafeTensorsFile {
/// Memory mapping of the whole file, created with `Mmap::map` in `open`.
mmap: Mmap,
/// Tensor name to metadata, built from the JSON header (the `__metadata__` key is skipped).
tensors: HashMap<String, SafeTensorInfo>,
/// Byte offset in `mmap` where tensor data begins: the 8-byte length prefix plus the header size.
data_offset: usize, }
impl SafeTensorsFile {
pub fn open(path: impl AsRef<Path>) -> SafeTensorsResult<Self> {
let file = File::open(path.as_ref())?;
let mmap = unsafe { Mmap::map(&file)? };
if mmap.len() < 8 {
return Err(SafeTensorsError::InvalidFormat(
"File too small for header".to_string(),
));
}
let header_size = u64::from_le_bytes(mmap[0..8].try_into().unwrap()) as usize;
if header_size > mmap.len() - 8 {
return Err(SafeTensorsError::InvalidFormat(format!(
"Header size {} exceeds file size {}",
header_size,
mmap.len() - 8
)));
}
let data_offset = 8 + header_size;
let header_bytes = &mmap[8..data_offset];
let header_json: HashMap<String, Value> = serde_json::from_slice(header_bytes)?;
let mut tensors = HashMap::new();
for (name, value) in header_json.iter() {
if name == "__metadata__" {
continue;
}
let obj = value.as_object().ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!("Tensor {} is not an object", name))
})?;
let dtype_str = obj
.get("dtype")
.and_then(|v| v.as_str())
.ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!("Missing dtype for tensor {}", name))
})?;
let dtype = SafeTensorsDtype::from_str(dtype_str)?;
let shape_array = obj
.get("shape")
.and_then(|v| v.as_array())
.ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!("Missing shape for tensor {}", name))
})?;
let shape: Vec<usize> = shape_array
.iter()
.map(|v| {
v.as_u64()
.ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!(
"Invalid shape value for tensor {}",
name
))
})
.map(|x| x as usize)
})
.collect::<SafeTensorsResult<_>>()?;
let data_offsets = obj
.get("data_offsets")
.and_then(|v| v.as_array())
.ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!(
"Missing data_offsets for tensor {}",
name
))
})?;
if data_offsets.len() != 2 {
return Err(SafeTensorsError::InvalidFormat(format!(
"Expected 2 data_offsets for tensor {}, got {}",
name,
data_offsets.len()
)));
}
let data_start = data_offsets[0].as_u64().ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!(
"Invalid data_start for tensor {}",
name
))
})? as usize;
let data_end = data_offsets[1].as_u64().ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!("Invalid data_end for tensor {}", name))
})? as usize;
tensors.insert(
name.clone(),
SafeTensorInfo {
dtype,
shape,
data_start,
data_end,
},
);
}
Ok(Self {
mmap,
tensors,
data_offset,
})
}
pub fn tensor_data(&self, name: &str) -> Option<&[u8]> {
let info = self.tensors.get(name)?;
let start = self.data_offset + info.data_start;
let end = self.data_offset + info.data_end;
if end > self.mmap.len() {
return None;
}
Some(&self.mmap[start..end])
}
pub fn tensor_info(&self, name: &str) -> Option<&SafeTensorInfo> {
self.tensors.get(name)
}
pub fn tensor_names(&self) -> impl Iterator<Item = &str> {
self.tensors.keys().map(|s| s.as_str())
}
pub fn num_tensors(&self) -> usize {
self.tensors.len()
}
}
/// A model directory loaded either from several shard files listed in
/// `model.safetensors.index.json` or from a single `model.safetensors` file.
pub struct ShardedSafeTensors {
/// The opened shard files; a single-file model has exactly one entry.
shards: Vec<SafeTensorsFile>,
/// Tensor name to the index in `shards` of the file that is expected to hold it.
tensor_to_shard: HashMap<String, usize>, }
impl ShardedSafeTensors {
pub fn open(dir: impl AsRef<Path>) -> SafeTensorsResult<Self> {
let dir = dir.as_ref();
let index_path = dir.join("model.safetensors.index.json");
let single_path = dir.join("model.safetensors");
if index_path.exists() {
let index_file = File::open(&index_path)?;
let index_json: Value = serde_json::from_reader(index_file)?;
let weight_map = index_json
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| {
SafeTensorsError::InvalidFormat(
"Missing or invalid weight_map in index.json".to_string(),
)
})?;
let mut shard_filenames: Vec<String> = weight_map
.values()
.filter_map(|v| v.as_str())
.map(|s| s.to_string())
.collect();
shard_filenames.sort();
shard_filenames.dedup();
let mut shards = Vec::new();
let mut shard_name_to_idx = HashMap::new();
for (idx, filename) in shard_filenames.iter().enumerate() {
let shard_path = dir.join(filename);
let shard = SafeTensorsFile::open(&shard_path)?;
shards.push(shard);
shard_name_to_idx.insert(filename.clone(), idx);
}
let mut tensor_to_shard = HashMap::new();
for (tensor_name, shard_filename) in weight_map.iter() {
let shard_filename_str = shard_filename.as_str().ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!(
"Invalid shard filename for tensor {}",
tensor_name
))
})?;
let shard_idx = shard_name_to_idx.get(shard_filename_str).ok_or_else(|| {
SafeTensorsError::InvalidFormat(format!(
"Shard {} not found for tensor {}",
shard_filename_str, tensor_name
))
})?;
tensor_to_shard.insert(tensor_name.clone(), *shard_idx);
}
Ok(Self {
shards,
tensor_to_shard,
})
} else if single_path.exists() {
let shard = SafeTensorsFile::open(&single_path)?;
let tensor_to_shard: HashMap<String, usize> = shard
.tensor_names()
.map(|name| (name.to_string(), 0))
.collect();
Ok(Self {
shards: vec![shard],
tensor_to_shard,
})
} else {
Err(SafeTensorsError::InvalidFormat(
"No model.safetensors or model.safetensors.index.json found".to_string(),
))
}
}
pub fn tensor_data(&self, name: &str) -> Option<&[u8]> {
let shard_idx = self.tensor_to_shard.get(name)?;
let shard = self.shards.get(*shard_idx)?;
shard.tensor_data(name)
}
pub fn tensor_info(&self, name: &str) -> Option<&SafeTensorInfo> {
let shard_idx = self.tensor_to_shard.get(name)?;
let shard = self.shards.get(*shard_idx)?;
shard.tensor_info(name)
}
pub fn tensor_names(&self) -> Vec<String> {
self.tensor_to_shard.keys().cloned().collect()
}
pub fn num_tensors(&self) -> usize {
self.tensor_to_shard.len()
}
}
/// Converts little-endian bf16 bytes to `f32` values.
///
/// Each consecutive pair of bytes is read as a little-endian `u16` whose bits become
/// the upper 16 bits of the resulting `f32`; the lower 16 bits are zero.
///
/// `data` may start at any address: slices taken from a memory-mapped file are not
/// guaranteed to be 2-byte aligned, so the bytes are decoded pairwise instead of being
/// reinterpreted as `&[u16]`. A trailing unpaired byte is ignored.
pub fn bf16_to_f32(data: &[u8]) -> Vec<f32> {
    data.chunks_exact(2)
        .map(|pair| {
            let bits = u16::from_le_bytes([pair[0], pair[1]]);
            f32::from_bits(u32::from(bits) << 16)
        })
        .collect()
}
pub fn f16_to_f32(data: &[u8]) -> Vec<f32> {
let f16s: &[u16] = bytemuck::cast_slice(data);
f16s
.iter()
.map(|&bits| half::f16::from_bits(bits).to_f32())
.collect()
}
// Unit tests for dtype parsing, single-file loading, and the bf16/f16 conversion helpers.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
// Every supported dtype string parses to its variant; an unknown string is an error.
#[test]
fn test_dtype_from_str() {
assert_eq!(SafeTensorsDtype::from_str("F32").unwrap(), SafeTensorsDtype::F32);
assert_eq!(SafeTensorsDtype::from_str("F16").unwrap(), SafeTensorsDtype::F16);
assert_eq!(SafeTensorsDtype::from_str("BF16").unwrap(), SafeTensorsDtype::BF16);
assert_eq!(SafeTensorsDtype::from_str("F64").unwrap(), SafeTensorsDtype::F64);
assert_eq!(SafeTensorsDtype::from_str("I8").unwrap(), SafeTensorsDtype::I8);
assert_eq!(SafeTensorsDtype::from_str("I16").unwrap(), SafeTensorsDtype::I16);
assert_eq!(SafeTensorsDtype::from_str("I32").unwrap(), SafeTensorsDtype::I32);
assert_eq!(SafeTensorsDtype::from_str("I64").unwrap(), SafeTensorsDtype::I64);
assert_eq!(SafeTensorsDtype::from_str("U8").unwrap(), SafeTensorsDtype::U8);
assert_eq!(SafeTensorsDtype::from_str("BOOL").unwrap(), SafeTensorsDtype::Bool);
assert!(SafeTensorsDtype::from_str("INVALID").is_err());
}
// Writes a two-tensor file (length prefix, JSON header, tensor bytes) to a temp file
// and checks that metadata and data round-trip through `SafeTensorsFile`.
#[test]
fn test_parse_single_file() {
let mut tmpfile = NamedTempFile::new().unwrap();
let tensor1_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let tensor2_data: Vec<f32> = vec![10.0, 20.0, 30.0, 40.0];
let tensor1_bytes: Vec<u8> = tensor1_data
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let tensor2_bytes: Vec<u8> = tensor2_data
.iter()
.flat_map(|f| f.to_le_bytes())
.collect();
// tensor2's bytes follow tensor1's directly, so its range starts where tensor1's ends.
let header_json = serde_json::json!({
"__metadata__": {"format": "pt"},
"tensor1": {
"dtype": "F32",
"shape": [2, 3],
"data_offsets": [0, tensor1_bytes.len()]
},
"tensor2": {
"dtype": "F32",
"shape": [4],
"data_offsets": [tensor1_bytes.len(), tensor1_bytes.len() + tensor2_bytes.len()]
}
});
let header_str = serde_json::to_string(&header_json).unwrap();
let header_bytes = header_str.as_bytes();
// File layout: 8-byte little-endian header length, header JSON, then tensor data.
tmpfile.write_all(&(header_bytes.len() as u64).to_le_bytes()).unwrap();
tmpfile.write_all(header_bytes).unwrap();
tmpfile.write_all(&tensor1_bytes).unwrap();
tmpfile.write_all(&tensor2_bytes).unwrap();
tmpfile.flush().unwrap();
let st = SafeTensorsFile::open(tmpfile.path()).unwrap();
// `__metadata__` is not counted as a tensor.
assert_eq!(st.num_tensors(), 2);
let info1 = st.tensor_info("tensor1").unwrap();
assert_eq!(info1.dtype, SafeTensorsDtype::F32);
assert_eq!(info1.shape, vec![2, 3]);
assert_eq!(info1.n_elements(), 6);
assert_eq!(info1.byte_size(), 24);
let data1 = st.tensor_data("tensor1").unwrap();
let f32_data1: Vec<f32> = data1
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
assert_eq!(f32_data1, tensor1_data);
let info2 = st.tensor_info("tensor2").unwrap();
assert_eq!(info2.dtype, SafeTensorsDtype::F32);
assert_eq!(info2.shape, vec![4]);
assert_eq!(info2.n_elements(), 4);
assert_eq!(info2.byte_size(), 16);
let data2 = st.tensor_data("tensor2").unwrap();
let f32_data2: Vec<f32> = data2
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
assert_eq!(f32_data2, tensor2_data);
}
// The byte pairs [0x80, 0x3F] and [0x00, 0x40] are expected to decode to 1.0 and 2.0.
#[test]
fn test_bf16_to_f32() {
let bf16_bytes: Vec<u8> = vec![0x80, 0x3F, 0x00, 0x40]; let f32_vec = bf16_to_f32(&bf16_bytes);
assert_eq!(f32_vec.len(), 2);
assert!((f32_vec[0] - 1.0).abs() < 1e-6);
assert!((f32_vec[1] - 2.0).abs() < 1e-6);
}
// Builds the f16 encodings of 1.0 and 2.0 as little-endian bytes and converts them back.
#[test]
fn test_f16_to_f32() {
let one_f16 = half::f16::from_f32(1.0);
let two_f16 = half::f16::from_f32(2.0);
let f16_bytes: Vec<u8> = vec![
one_f16.to_bits().to_le_bytes()[0],
one_f16.to_bits().to_le_bytes()[1],
two_f16.to_bits().to_le_bytes()[0],
two_f16.to_bits().to_le_bytes()[1],
];
let f32_vec = f16_to_f32(&f16_bytes);
assert_eq!(f32_vec.len(), 2);
assert!((f32_vec[0] - 1.0).abs() < 1e-3);
assert!((f32_vec[1] - 2.0).abs() < 1e-3);
}
// Looking up a name that is absent from the header yields `None` for both info and data.
#[test]
fn test_tensor_not_found() {
let mut tmpfile = NamedTempFile::new().unwrap();
let header_json = serde_json::json!({
"tensor1": {
"dtype": "F32",
"shape": [2],
"data_offsets": [0, 8]
}
});
let header_str = serde_json::to_string(&header_json).unwrap();
let header_bytes = header_str.as_bytes();
tmpfile.write_all(&(header_bytes.len() as u64).to_le_bytes()).unwrap();
tmpfile.write_all(header_bytes).unwrap();
// Eight zero bytes back tensor1's declared [0, 8) range.
tmpfile.write_all(&[0u8; 8]).unwrap(); tmpfile.flush().unwrap();
let st = SafeTensorsFile::open(tmpfile.path()).unwrap();
assert!(st.tensor_info("nonexistent").is_none());
assert!(st.tensor_data("nonexistent").is_none());
}
}