use crate::arrow::{ArrowTensor, ArrowTensorStore, TensorDtype};
use bytes::Bytes;
use memmap2::Mmap;
use safetensors::tensor::{SafeTensorError, SafeTensors};
use safetensors::{Dtype, View};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
/// Reader for safetensors files, backed by either a memory map (when opened
/// from disk) or an owned in-memory byte buffer.
pub struct SafetensorsReader {
    // Exactly one of `mmap`/`bytes` is `Some`, depending on which
    // constructor was used (`open` vs `from_bytes`).
    mmap: Option<Mmap>,
    bytes: Option<Bytes>,
    // Per-tensor dtype/shape plus the byte range of the payload within the
    // backing buffer.
    metadata: HashMap<String, TensorInfo>,
    // File-level key/value metadata. NOTE(review): never populated by the
    // current constructors — always empty; confirm whether the file header's
    // "__metadata__" block should be parsed into it.
    global_metadata: HashMap<String, String>,
}
/// Location and type information for one tensor within a safetensors buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorInfo {
    pub name: String,
    pub dtype: TensorDtype,
    /// Dimensions as recorded in the file header.
    pub shape: Vec<usize>,
    /// Byte offset of the payload, relative to the start of the whole buffer
    /// (header included).
    pub data_offset: usize,
    /// Payload length in bytes.
    pub data_size: usize,
}
impl SafetensorsReader {
    /// Opens a safetensors file and memory-maps it.
    ///
    /// # Errors
    /// Returns [`SafetensorError::Io`] if the file cannot be opened or
    /// mapped, or a deserialization error if the header is malformed.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, SafetensorError> {
        let file = File::open(path.as_ref()).map_err(SafetensorError::Io)?;
        // SAFETY: the map is only read through this struct. As with any
        // mmap, behavior is undefined if another process truncates or
        // modifies the underlying file while it is mapped.
        let mmap = unsafe { Mmap::map(&file).map_err(SafetensorError::Io)? };
        Self::from_mmap(mmap)
    }

    /// Builds a reader around an already-created memory map.
    fn from_mmap(mmap: Mmap) -> Result<Self, SafetensorError> {
        let metadata = Self::extract_metadata(mmap.as_ref())?;
        Ok(Self {
            mmap: Some(mmap),
            bytes: None,
            metadata,
            // NOTE(review): the file-level "__metadata__" header is never
            // parsed, so this map is always empty — confirm whether callers
            // rely on it being populated.
            global_metadata: HashMap::new(),
        })
    }

    /// Builds a reader around an owned in-memory buffer.
    pub fn from_bytes(bytes: Bytes) -> Result<Self, SafetensorError> {
        let metadata = Self::extract_metadata(bytes.as_ref())?;
        Ok(Self {
            mmap: None,
            bytes: Some(bytes),
            metadata,
            global_metadata: HashMap::new(),
        })
    }

    /// Parses the safetensors header in `buffer` and records each tensor's
    /// dtype, shape, and byte range relative to the start of `buffer`.
    /// Shared by `from_mmap` and `from_bytes`.
    fn extract_metadata(buffer: &[u8]) -> Result<HashMap<String, TensorInfo>, SafetensorError> {
        let tensors = SafeTensors::deserialize(buffer)?;
        let base = buffer.as_ptr() as usize;
        let mut metadata = HashMap::new();
        for (name, view) in tensors.tensors() {
            let data = view.data();
            let info = TensorInfo {
                name: name.clone(),
                dtype: convert_safetensor_dtype(view.dtype()),
                shape: view.shape().to_vec(),
                // `view.data()` borrows from `buffer`, so the pointer
                // difference is the tensor's offset within the file image.
                data_offset: data.as_ptr() as usize - base,
                data_size: data.len(),
            };
            metadata.insert(name, info);
        }
        Ok(metadata)
    }

    /// Names of all tensors in the file (order is unspecified).
    pub fn tensor_names(&self) -> Vec<&str> {
        self.metadata.keys().map(String::as_str).collect()
    }

    /// Metadata for a single tensor, if present.
    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
        self.metadata.get(name)
    }

    /// File-level key/value metadata (currently always empty; see the note
    /// in `from_mmap`).
    pub fn global_metadata(&self) -> &HashMap<String, String> {
        &self.global_metadata
    }

    /// Number of tensors in the file.
    pub fn len(&self) -> usize {
        self.metadata.len()
    }

    /// True when the file contains no tensors.
    pub fn is_empty(&self) -> bool {
        self.metadata.is_empty()
    }

    /// Raw bytes of one tensor's payload, or `None` if the name is unknown
    /// or its recorded range falls outside the backing buffer.
    pub fn tensor_data(&self, name: &str) -> Option<&[u8]> {
        let info = self.metadata.get(name)?;
        let data = self.get_data()?;
        // Checked slicing: a corrupt offset yields None instead of a panic.
        let end = info.data_offset.checked_add(info.data_size)?;
        data.get(info.data_offset..end)
    }

    /// The full backing buffer (mmap or bytes), if any.
    fn get_data(&self) -> Option<&[u8]> {
        if let Some(ref mmap) = self.mmap {
            Some(mmap.as_ref())
        } else {
            self.bytes.as_deref()
        }
    }

    /// Shared decode pipeline for the typed `load_*` accessors: checks the
    /// dtype, then decodes the payload in fixed-width little-endian chunks.
    /// Returns `None` on a missing tensor or a dtype mismatch.
    fn load_scalars<T>(
        &self,
        name: &str,
        expected: TensorDtype,
        width: usize,
        decode: impl Fn(&[u8]) -> T,
    ) -> Option<Vec<T>> {
        let info = self.tensor_info(name)?;
        if info.dtype != expected {
            return None;
        }
        let data = self.tensor_data(name)?;
        Some(data.chunks_exact(width).map(decode).collect())
    }

    /// Loads a `Float32` tensor as a flat `Vec<f32>` (little-endian).
    pub fn load_f32(&self, name: &str) -> Option<Vec<f32>> {
        self.load_scalars(name, TensorDtype::Float32, 4, |c| {
            f32::from_le_bytes([c[0], c[1], c[2], c[3]])
        })
    }

    /// Loads a `Float64` tensor as a flat `Vec<f64>` (little-endian).
    pub fn load_f64(&self, name: &str) -> Option<Vec<f64>> {
        self.load_scalars(name, TensorDtype::Float64, 8, |c| {
            f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
        })
    }

    /// Loads an `Int32` tensor as a flat `Vec<i32>` (little-endian).
    pub fn load_i32(&self, name: &str) -> Option<Vec<i32>> {
        self.load_scalars(name, TensorDtype::Int32, 4, |c| {
            i32::from_le_bytes([c[0], c[1], c[2], c[3]])
        })
    }

    /// Loads an `Int64` tensor as a flat `Vec<i64>` (little-endian).
    pub fn load_i64(&self, name: &str) -> Option<Vec<i64>> {
        self.load_scalars(name, TensorDtype::Int64, 8, |c| {
            i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]])
        })
    }

    /// Converts one tensor into an [`ArrowTensor`]. Returns `None` for any
    /// dtype other than f32/f64/i32/i64.
    pub fn load_as_arrow(&self, name: &str) -> Option<ArrowTensor> {
        let info = self.tensor_info(name)?;
        match info.dtype {
            TensorDtype::Float32 => {
                let data = self.load_f32(name)?;
                Some(ArrowTensor::from_slice_f32(name, info.shape.clone(), &data))
            }
            TensorDtype::Float64 => {
                let data = self.load_f64(name)?;
                Some(ArrowTensor::from_slice_f64(name, info.shape.clone(), &data))
            }
            TensorDtype::Int32 => {
                let data = self.load_i32(name)?;
                Some(ArrowTensor::from_slice_i32(name, info.shape.clone(), &data))
            }
            TensorDtype::Int64 => {
                let data = self.load_i64(name)?;
                Some(ArrowTensor::from_slice_i64(name, info.shape.clone(), &data))
            }
            _ => None,
        }
    }

    /// Converts every supported tensor into an [`ArrowTensorStore`].
    /// Tensors with unsupported dtypes are silently skipped.
    pub fn load_all_as_arrow(&self) -> ArrowTensorStore {
        let mut store = ArrowTensorStore::new();
        for name in self.tensor_names() {
            if let Some(tensor) = self.load_as_arrow(name) {
                store.insert(tensor);
            }
        }
        store
    }

    /// Sum of all tensor payload sizes in bytes (the header is excluded).
    pub fn total_size_bytes(&self) -> usize {
        self.metadata.values().map(|info| info.data_size).sum()
    }

    /// Aggregates per-dtype counts, total parameter count, and total bytes.
    pub fn summary(&self) -> ModelSummary {
        let mut dtype_counts: HashMap<TensorDtype, usize> = HashMap::new();
        let mut total_params = 0usize;
        let mut total_bytes = 0usize;
        for info in self.metadata.values() {
            *dtype_counts.entry(info.dtype).or_insert(0) += 1;
            // An empty shape yields a product of 1 (scalar tensor).
            total_params += info.shape.iter().product::<usize>();
            total_bytes += info.data_size;
        }
        ModelSummary {
            num_tensors: self.metadata.len(),
            total_params,
            total_bytes,
            dtype_distribution: dtype_counts,
            metadata: self.global_metadata.clone(),
        }
    }
}
/// Aggregate statistics for a model file, produced by
/// [`SafetensorsReader::summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSummary {
    pub num_tensors: usize,
    /// Total element count across all tensors (product of each shape).
    pub total_params: usize,
    /// Total payload bytes (excludes the file header).
    pub total_bytes: usize,
    /// Number of tensors of each dtype.
    pub dtype_distribution: HashMap<TensorDtype, usize>,
    /// File-level metadata copied from the reader.
    pub metadata: HashMap<String, String>,
}
/// Builder that accumulates tensors in memory and serializes them to the
/// safetensors binary format.
pub struct SafetensorsWriter {
    // Tensors queued for serialization, in the order they were added.
    tensors: Vec<(String, TensorData)>,
    // Optional file-level key/value metadata ("__metadata__" block).
    metadata: HashMap<String, String>,
}
/// Owned payload for a single queued tensor.
struct TensorData {
    dtype: Dtype,
    shape: Vec<usize>,
    /// Raw little-endian bytes of the flattened tensor.
    data: Vec<u8>,
}
/// Borrowing adapter so a queued `TensorData` can be handed to
/// `safetensors::tensor::serialize`, which consumes `View` values.
struct TensorDataRef<'a>(&'a TensorData);

impl View for TensorDataRef<'_> {
    fn dtype(&self) -> Dtype {
        self.0.dtype
    }
    fn shape(&self) -> &[usize] {
        &self.0.shape
    }
    fn data(&self) -> std::borrow::Cow<'_, [u8]> {
        // Always borrowed: the bytes already live in the queued TensorData,
        // so no copy is made during serialization.
        std::borrow::Cow::Borrowed(&self.0.data)
    }
    fn data_len(&self) -> usize {
        self.0.data.len()
    }
}
impl SafetensorsWriter {
pub fn new() -> Self {
Self {
tensors: Vec::new(),
metadata: HashMap::new(),
}
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
self.metadata.insert(key, value);
self
}
pub fn add_f32(&mut self, name: &str, shape: Vec<usize>, data: &[f32]) {
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
self.tensors.push((
name.to_string(),
TensorData {
dtype: Dtype::F32,
shape,
data: bytes,
},
));
}
pub fn add_f64(&mut self, name: &str, shape: Vec<usize>, data: &[f64]) {
let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
self.tensors.push((
name.to_string(),
TensorData {
dtype: Dtype::F64,
shape,
data: bytes,
},
));
}
pub fn add_i32(&mut self, name: &str, shape: Vec<usize>, data: &[i32]) {
let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
self.tensors.push((
name.to_string(),
TensorData {
dtype: Dtype::I32,
shape,
data: bytes,
},
));
}
pub fn add_i64(&mut self, name: &str, shape: Vec<usize>, data: &[i64]) {
let bytes: Vec<u8> = data.iter().flat_map(|i| i.to_le_bytes()).collect();
self.tensors.push((
name.to_string(),
TensorData {
dtype: Dtype::I64,
shape,
data: bytes,
},
));
}
pub fn add_arrow_tensor(&mut self, tensor: &ArrowTensor) {
match tensor.metadata.dtype {
TensorDtype::Float32 => {
if let Some(data) = tensor.as_slice_f32() {
self.add_f32(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
}
}
TensorDtype::Float64 => {
if let Some(data) = tensor.as_slice_f64() {
self.add_f64(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
}
}
TensorDtype::Int32 => {
if let Some(data) = tensor.as_slice_i32() {
self.add_i32(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
}
}
TensorDtype::Int64 => {
if let Some(data) = tensor.as_slice_i64() {
self.add_i64(&tensor.metadata.name, tensor.metadata.shape.clone(), data);
}
}
_ => {} }
}
pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), SafetensorError> {
let bytes = self.serialize()?;
let mut file = File::create(path).map_err(SafetensorError::Io)?;
file.write_all(&bytes).map_err(SafetensorError::Io)?;
Ok(())
}
pub fn serialize(&self) -> Result<Vec<u8>, SafetensorError> {
let tensors: Vec<(&str, TensorDataRef)> = self
.tensors
.iter()
.map(|(name, data)| (name.as_str(), TensorDataRef(data)))
.collect();
let metadata = if self.metadata.is_empty() {
None
} else {
let meta: HashMap<String, String> = self.metadata.clone();
Some(meta)
};
Ok(safetensors::tensor::serialize(
tensors.into_iter(),
metadata,
)?)
}
}
impl Default for SafetensorsWriter {
fn default() -> Self {
Self::new()
}
}
/// Splits a tensor store across multiple safetensors files ("chunks") in a
/// directory, together with a JSON index describing the chunks.
pub struct ChunkedModelStorage {
    // Directory into which chunk files and the index are written.
    base_path: std::path::PathBuf,
    // Soft upper bound on the tensor payload bytes per chunk file.
    chunk_size: usize,
    // Index entries for the chunks written so far.
    chunks: Vec<ChunkInfo>,
}
/// Index entry describing one written chunk file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkInfo {
    /// Sequential chunk number, starting at 0.
    pub index: usize,
    /// File name relative to the storage's base directory.
    pub path: String,
    /// Names of the tensors stored in this chunk.
    pub tensors: Vec<String>,
    pub size_bytes: usize,
}
impl ChunkedModelStorage {
    /// Creates a chunked store rooted at `base_path`; each chunk file holds
    /// roughly `chunk_size` bytes of tensor payload.
    pub fn new<P: AsRef<Path>>(base_path: P, chunk_size: usize) -> Self {
        Self {
            base_path: base_path.as_ref().to_path_buf(),
            chunk_size,
            chunks: Vec::new(),
        }
    }

    /// Writes every tensor in `store` into sequentially numbered
    /// `chunk_NNNN.safetensors` files, starting a new chunk whenever adding
    /// the next tensor would exceed `chunk_size`. A single tensor larger
    /// than `chunk_size` still gets its own (oversized) chunk.
    pub fn write_chunked(&mut self, store: &ArrowTensorStore) -> Result<(), SafetensorError> {
        let mut current_chunk = SafetensorsWriter::new();
        let mut current_size = 0usize;
        let mut current_tensors = Vec::new();
        for name in store.names() {
            if let Some(tensor) = store.get(name) {
                let tensor_size = tensor.metadata.size_bytes();
                // Flush the open chunk first if this tensor would overflow
                // it (but never flush an empty chunk).
                if current_size + tensor_size > self.chunk_size && !current_tensors.is_empty() {
                    self.write_chunk(current_chunk, &current_tensors, current_size)?;
                    current_chunk = SafetensorsWriter::new();
                    current_tensors = Vec::new();
                    current_size = 0;
                }
                current_chunk.add_arrow_tensor(tensor);
                current_tensors.push(name.to_string());
                current_size += tensor_size;
            }
        }
        // Flush the trailing partial chunk.
        if !current_tensors.is_empty() {
            self.write_chunk(current_chunk, &current_tensors, current_size)?;
        }
        Ok(())
    }

    /// Writes one chunk file to disk and records it in the in-memory index.
    fn write_chunk(
        &mut self,
        writer: SafetensorsWriter,
        tensors: &[String],
        size: usize,
    ) -> Result<(), SafetensorError> {
        let index = self.chunks.len();
        let filename = format!("chunk_{:04}.safetensors", index);
        writer.write_to_file(self.base_path.join(&filename))?;
        self.chunks.push(ChunkInfo {
            index,
            // Stored relative to base_path so the directory is relocatable.
            path: filename,
            tensors: tensors.to_vec(),
            size_bytes: size,
        });
        Ok(())
    }

    /// Writes `model_index.json` describing all chunks written so far.
    pub fn write_index(&self) -> Result<(), std::io::Error> {
        let index_path = self.base_path.join("model_index.json");
        let json = serde_json::to_string_pretty(&self.chunks)?;
        std::fs::write(index_path, json)?;
        Ok(())
    }

    /// Reads a `model_index.json` previously produced by
    /// [`Self::write_index`] from the directory at `path`.
    pub fn load_index<P: AsRef<Path>>(path: P) -> Result<Vec<ChunkInfo>, std::io::Error> {
        let index_path = path.as_ref().join("model_index.json");
        let content = std::fs::read_to_string(index_path)?;
        let chunks: Vec<ChunkInfo> = serde_json::from_str(&content)?;
        Ok(chunks)
    }

    /// Finds the chunk containing `tensor_name`, if any.
    pub fn find_tensor_chunk(&self, tensor_name: &str) -> Option<&ChunkInfo> {
        // Compare as &str to avoid allocating a String per chunk scanned.
        self.chunks
            .iter()
            .find(|chunk| chunk.tensors.iter().any(|t| t == tensor_name))
    }
}
/// Errors produced while reading or writing safetensors files.
#[derive(Debug)]
pub enum SafetensorError {
    /// Underlying filesystem or memory-mapping failure.
    Io(std::io::Error),
    /// Generic parse failure. NOTE(review): not constructed anywhere in
    /// this file — confirm whether other modules build this variant.
    Parse(String),
    /// Error surfaced by the safetensors crate itself.
    Safetensors(SafeTensorError),
}
impl From<SafeTensorError> for SafetensorError {
fn from(err: SafeTensorError) -> Self {
SafetensorError::Safetensors(err)
}
}
impl std::fmt::Display for SafetensorError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SafetensorError::Io(e) => write!(f, "IO error: {}", e),
SafetensorError::Parse(s) => write!(f, "Parse error: {}", s),
SafetensorError::Safetensors(e) => write!(f, "Safetensors error: {:?}", e),
}
}
}
impl std::error::Error for SafetensorError {}
/// Maps a safetensors on-disk [`Dtype`] to the crate's [`TensorDtype`].
///
/// NOTE(review): the catch-all arm silently maps any unlisted dtype
/// (e.g. BOOL, or variants added to the non-exhaustive `Dtype` enum) to
/// `Float32`, which misreports the payload's element type downstream —
/// confirm whether callers would prefer an error here.
fn convert_safetensor_dtype(dtype: Dtype) -> TensorDtype {
    match dtype {
        Dtype::F32 => TensorDtype::Float32,
        Dtype::F64 => TensorDtype::Float64,
        Dtype::I8 => TensorDtype::Int8,
        Dtype::I16 => TensorDtype::Int16,
        Dtype::I32 => TensorDtype::Int32,
        Dtype::I64 => TensorDtype::Int64,
        Dtype::U8 => TensorDtype::UInt8,
        Dtype::U16 => TensorDtype::UInt16,
        Dtype::U32 => TensorDtype::UInt32,
        Dtype::U64 => TensorDtype::UInt64,
        Dtype::BF16 => TensorDtype::BFloat16,
        Dtype::F16 => TensorDtype::Float16,
        // Lossy fallback — see NOTE above.
        _ => TensorDtype::Float32,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Round-trips one f32 tensor through a real temp file, exercising the
    /// mmap-backed reader path (`open`).
    #[test]
    fn test_writer_and_reader() {
        let mut writer =
            SafetensorsWriter::new().with_metadata("format".to_string(), "test".to_string());
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
        writer.add_f32("test_tensor", vec![3, 4], &data);
        let mut temp_file = NamedTempFile::new().unwrap();
        let bytes = writer.serialize().unwrap();
        temp_file.write_all(&bytes).unwrap();
        temp_file.flush().unwrap();
        let reader = SafetensorsReader::open(temp_file.path()).unwrap();
        assert_eq!(reader.len(), 1);
        assert!(reader.tensor_info("test_tensor").is_some());
        let info = reader.tensor_info("test_tensor").unwrap();
        assert_eq!(info.shape, vec![3, 4]);
        assert_eq!(info.dtype, TensorDtype::Float32);
        let loaded = reader.load_f32("test_tensor").unwrap();
        assert_eq!(loaded, data);
    }

    /// summary() should count tensors and total elements across layers.
    #[test]
    fn test_model_summary() {
        let mut writer = SafetensorsWriter::new();
        writer.add_f32("layer1", vec![10, 10], &[0.0; 100]);
        writer.add_f32("layer2", vec![10, 5], &[0.0; 50]);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let summary = reader.summary();
        assert_eq!(summary.num_tensors, 2);
        assert_eq!(summary.total_params, 150);
    }

    /// load_as_arrow preserves name, shape, and f32 values.
    #[test]
    fn test_arrow_conversion() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        writer.add_f32("weights", vec![2, 3], &data);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let tensor = reader.load_as_arrow("weights").unwrap();
        assert_eq!(tensor.metadata.name, "weights");
        assert_eq!(tensor.metadata.shape, vec![2, 3]);
        assert_eq!(tensor.as_slice_f32().unwrap(), &data);
    }

    /// f64 round-trip through both the typed loader and the Arrow path.
    #[test]
    fn test_f64_support() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<f64> = vec![1.5, 2.5, 3.5, 4.5];
        writer.add_f64("weights_f64", vec![2, 2], &data);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_f64("weights_f64").unwrap();
        assert_eq!(loaded, data);
        let tensor = reader.load_as_arrow("weights_f64").unwrap();
        assert_eq!(tensor.metadata.name, "weights_f64");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Float64);
        assert_eq!(tensor.as_slice_f64().unwrap(), &data);
    }

    /// i32 round-trip, including negative values.
    #[test]
    fn test_i32_support() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<i32> = vec![-10, 20, -30, 40, 50, -60];
        writer.add_i32("indices", vec![2, 3], &data);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_i32("indices").unwrap();
        assert_eq!(loaded, data);
        let tensor = reader.load_as_arrow("indices").unwrap();
        assert_eq!(tensor.metadata.name, "indices");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Int32);
        assert_eq!(tensor.as_slice_i32().unwrap(), &data);
    }

    /// i64 round-trip with values outside the i32 range.
    #[test]
    fn test_i64_support() {
        let mut writer = SafetensorsWriter::new();
        let data: Vec<i64> = vec![-1000000000, 2000000000, -3000000000, 4000000000];
        writer.add_i64("large_indices", vec![2, 2], &data);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_i64("large_indices").unwrap();
        assert_eq!(loaded, data);
        let tensor = reader.load_as_arrow("large_indices").unwrap();
        assert_eq!(tensor.metadata.name, "large_indices");
        assert_eq!(tensor.metadata.dtype, TensorDtype::Int64);
        assert_eq!(tensor.as_slice_i64().unwrap(), &data);
    }

    /// All four supported dtypes in a single file.
    #[test]
    fn test_mixed_dtypes() {
        let mut writer = SafetensorsWriter::new();
        let f32_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        let f64_data: Vec<f64> = vec![5.5, 6.5];
        let i32_data: Vec<i32> = vec![10, 20, 30];
        let i64_data: Vec<i64> = vec![100, 200];
        writer.add_f32("layer1", vec![4], &f32_data);
        writer.add_f64("layer2", vec![2], &f64_data);
        writer.add_i32("layer3", vec![3], &i32_data);
        writer.add_i64("layer4", vec![2], &i64_data);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        assert_eq!(reader.len(), 4);
        assert_eq!(reader.load_f32("layer1").unwrap(), f32_data);
        assert_eq!(reader.load_f64("layer2").unwrap(), f64_data);
        assert_eq!(reader.load_i32("layer3").unwrap(), i32_data);
        assert_eq!(reader.load_i64("layer4").unwrap(), i64_data);
        assert!(reader.load_as_arrow("layer1").is_some());
        assert!(reader.load_as_arrow("layer2").is_some());
        assert!(reader.load_as_arrow("layer3").is_some());
        assert!(reader.load_as_arrow("layer4").is_some());
    }

    /// ArrowTensor -> writer -> reader -> ArrowTensor round-trip for each
    /// non-f32 dtype via add_arrow_tensor.
    #[test]
    fn test_arrow_tensor_roundtrip() {
        use crate::arrow::ArrowTensor;
        let f64_tensor = ArrowTensor::from_slice_f64("test_f64", vec![2, 2], &[1.0, 2.0, 3.0, 4.0]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&f64_tensor);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_as_arrow("test_f64").unwrap();
        assert_eq!(
            loaded.as_slice_f64().unwrap(),
            f64_tensor.as_slice_f64().unwrap()
        );
        let i32_tensor = ArrowTensor::from_slice_i32("test_i32", vec![3], &[10, 20, 30]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&i32_tensor);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_as_arrow("test_i32").unwrap();
        assert_eq!(
            loaded.as_slice_i32().unwrap(),
            i32_tensor.as_slice_i32().unwrap()
        );
        let i64_tensor = ArrowTensor::from_slice_i64("test_i64", vec![2], &[100, 200]);
        let mut writer = SafetensorsWriter::new();
        writer.add_arrow_tensor(&i64_tensor);
        let bytes = writer.serialize().unwrap();
        let reader = SafetensorsReader::from_bytes(Bytes::from(bytes)).unwrap();
        let loaded = reader.load_as_arrow("test_i64").unwrap();
        assert_eq!(
            loaded.as_slice_i64().unwrap(),
            i64_tensor.as_slice_i64().unwrap()
        );
    }
}