use crate::bundle::MappedFile;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs;
use std::path::Path;
/// Element type of a tensor stored in a safetensors file.
///
/// Only the three floating-point formats commonly found in model
/// checkpoints are supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SafeTensorsDType {
    F32,
    F16,
    BF16,
}
impl SafeTensorsDType {
    /// Width of a single element of this dtype, in bytes
    /// (4 for F32; 2 for both half-precision formats).
    #[must_use]
    pub fn bytes_per_element(self) -> usize {
        if self == Self::F32 { 4 } else { 2 }
    }
}
/// A tensor's raw storage as read from a safetensors file: the dtype tag,
/// the shape, and the still-undecoded little-endian element bytes.
#[derive(Debug, Clone)]
pub struct RawTensorData {
    /// Element type of `bytes` (F32, F16, or BF16).
    pub dtype: SafeTensorsDType,
    /// Tensor dimensions, outermost first.
    pub shape: Vec<usize>,
    /// Raw little-endian element bytes. NOTE(review): presumably
    /// `shape.iter().product::<usize>() * dtype.bytes_per_element()`
    /// bytes long — not enforced here, confirm at construction sites.
    pub bytes: Vec<u8>,
}
impl RawTensorData {
    /// Decode the raw little-endian bytes into `f32` values, widening
    /// F16/BF16 storage as needed via the per-dtype extractors.
    pub fn to_f32(&self) -> Result<Vec<f32>, String> {
        match self.dtype {
            SafeTensorsDType::F32 => extract_f32(&self.bytes),
            SafeTensorsDType::F16 => extract_f16_to_f32(&self.bytes),
            SafeTensorsDType::BF16 => extract_bf16_to_f32(&self.bytes),
        }
    }
    /// True when the stored dtype is IEEE half precision (F16).
    #[must_use]
    pub fn is_f16(&self) -> bool {
        matches!(self.dtype, SafeTensorsDType::F16)
    }
    /// True when the stored dtype is bfloat16 (BF16).
    #[must_use]
    pub fn is_bf16(&self) -> bool {
        matches!(self.dtype, SafeTensorsDType::BF16)
    }
}
/// One tensor's entry in a safetensors JSON header.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMetadata {
    /// Element type name exactly as written in the file (e.g. "F32", "F16", "BF16").
    pub dtype: String,
    /// Tensor dimensions, outermost first.
    pub shape: Vec<usize>,
    /// `[start, end)` byte offsets of this tensor within the data section
    /// (relative to the end of the JSON header, not the start of the file).
    pub data_offsets: [usize; 2],
}
/// Tensor name -> header entry; `BTreeMap` keeps serialization order deterministic.
pub type SafeTensorsMetadata = BTreeMap<String, TensorMetadata>;
/// Free-form string key/value pairs stored under the reserved "__metadata__" key.
pub type UserMetadata = BTreeMap<String, String>;
/// Serialize named F32 tensors to `path` in safetensors format: an 8-byte
/// little-endian header-length prefix, a JSON header describing every
/// tensor, then the concatenated little-endian tensor data.
///
/// Tensors are laid out in `BTreeMap` iteration order (sorted by name),
/// which keeps the output byte-for-byte deterministic.
///
/// # Errors
/// Returns a message string if JSON serialization or the file write fails.
pub fn save_safetensors<P: AsRef<Path>>(
    path: P,
    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
) -> Result<(), String> {
    let mut metadata = SafeTensorsMetadata::new();
    // Pre-size the data section: every element is a 4-byte F32, so the
    // total is known up front and the Vec never has to reallocate.
    let total_bytes: usize = tensors.values().map(|(data, _)| data.len() * 4).sum();
    let mut raw_data = Vec::with_capacity(total_bytes);
    let mut current_offset = 0;
    for (name, (data, shape)) in tensors {
        let start_offset = current_offset;
        let end_offset = current_offset + data.len() * 4;
        metadata.insert(
            name.clone(),
            TensorMetadata {
                dtype: "F32".to_string(),
                shape: shape.clone(),
                data_offsets: [start_offset, end_offset],
            },
        );
        raw_data.extend(data.iter().flat_map(|v| v.to_le_bytes()));
        current_offset = end_offset;
    }
    let metadata_json =
        serde_json::to_string(&metadata).map_err(|e| format!("JSON serialization failed: {e}"))?;
    let metadata_bytes = metadata_json.as_bytes();
    let metadata_len = metadata_bytes.len() as u64;
    // Assemble the whole file in memory: [len: u64 LE][header JSON][data].
    let mut output = Vec::with_capacity(8 + metadata_bytes.len() + raw_data.len());
    output.extend_from_slice(&metadata_len.to_le_bytes());
    output.extend_from_slice(metadata_bytes);
    output.extend_from_slice(&raw_data);
    fs::write(path, output).map_err(|e| format!("File write failed: {e}"))?;
    Ok(())
}
/// Like `save_safetensors`, but also embeds free-form string key/value
/// pairs under the reserved "__metadata__" header key (omitted entirely
/// when `user_metadata` is empty).
///
/// Note: a tensor literally named "__metadata__" would collide with the
/// reserved key and overwrite the user block.
///
/// # Errors
/// Returns a message string if JSON serialization or the file write fails.
pub fn save_safetensors_with_metadata<P: AsRef<Path>>(
    path: P,
    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    user_metadata: &UserMetadata,
) -> Result<(), String> {
    let mut header = serde_json::Map::new();
    if !user_metadata.is_empty() {
        let meta_obj: serde_json::Map<String, serde_json::Value> = user_metadata
            .iter()
            .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
            .collect();
        header.insert(
            "__metadata__".to_string(),
            serde_json::Value::Object(meta_obj),
        );
    }
    // Pre-size the data section: every element is a 4-byte F32.
    let total_bytes: usize = tensors.values().map(|(data, _)| data.len() * 4).sum();
    let mut raw_data = Vec::with_capacity(total_bytes);
    let mut current_offset = 0;
    for (name, (data, shape)) in tensors {
        let start_offset = current_offset;
        let end_offset = current_offset + data.len() * 4;
        #[allow(clippy::disallowed_methods)]
        let tensor_meta = serde_json::json!({
            "dtype": "F32",
            "shape": shape,
            "data_offsets": [start_offset, end_offset]
        });
        header.insert(name.clone(), tensor_meta);
        raw_data.extend(data.iter().flat_map(|v| v.to_le_bytes()));
        current_offset = end_offset;
    }
    let metadata_json =
        serde_json::to_string(&header).map_err(|e| format!("JSON serialization failed: {e}"))?;
    let metadata_bytes = metadata_json.as_bytes();
    let metadata_len = metadata_bytes.len() as u64;
    // Assemble the whole file in memory: [len: u64 LE][header JSON][data].
    let mut output = Vec::with_capacity(8 + metadata_bytes.len() + raw_data.len());
    output.extend_from_slice(&metadata_len.to_le_bytes());
    output.extend_from_slice(metadata_bytes);
    output.extend_from_slice(&raw_data);
    fs::write(path, output).map_err(|e| format!("File write failed: {e}"))?;
    Ok(())
}
/// Convert f32 values to raw little-endian bf16 bytes.
///
/// Two fixes over naive truncation (`bits >> 16`):
/// - NaN is handled explicitly: truncating a NaN whose payload bits live
///   only in the low 16 mantissa bits (e.g. 0x7F80_0001) would silently
///   become an infinity.
/// - Rounds to nearest, ties to even — the same convention hardware and
///   ML frameworks use for f32 -> bf16 casts — instead of truncating
///   toward zero.
fn f32_slice_to_bf16_bytes(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 2);
    for &value in data {
        let bits = value.to_bits();
        let bf16 = if value.is_nan() {
            // Quiet NaN with the original sign; payload is not preserved.
            ((bits >> 16) as u16 & 0x8000) | 0x7FC0
        } else {
            // Round-to-nearest-even: add 0x7FFF plus the lsb of the
            // surviving mantissa, then truncate. Inf passes through
            // unchanged (its low mantissa bits are zero); no overflow is
            // possible here since NaN (the largest bit patterns) was
            // excluded above.
            let rounded = bits + 0x7FFF + ((bits >> 16) & 1);
            (rounded >> 16) as u16
        };
        bytes.extend_from_slice(&bf16.to_le_bytes());
    }
    bytes
}
/// Convert one f32 to IEEE 754 binary16 bits.
///
/// Fixes over the previous converter:
/// - Rounds to nearest, ties to even, instead of truncating the mantissa.
/// - Produces subnormal f16 outputs for magnitudes in [2^-25, 2^-14)
///   instead of flushing everything below the normal range to zero.
fn f32_to_f16_bits(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;
    if exponent == 0xFF {
        // Inf stays Inf; NaN gets a quiet bit so it cannot decay to Inf.
        return sign | 0x7C00 | if mantissa != 0 { 0x0200 } else { 0 };
    }
    let unbiased = exponent - 127;
    if unbiased >= 16 {
        // Too large for binary16: overflow to infinity.
        return sign | 0x7C00;
    }
    if unbiased >= -14 {
        // Normal f16: 5-bit exponent (bias 15), top 10 mantissa bits.
        let m = mantissa >> 13;
        let h = u32::from(sign) | (((unbiased + 15) as u32) << 10) | m;
        let rest = mantissa & 0x1FFF;
        // Round to nearest, ties to even. A carry out of the mantissa
        // legitimately bumps the exponent (and can reach infinity).
        let round_up = rest > 0x1000 || (rest == 0x1000 && (m & 1) == 1);
        return (h + u32::from(round_up)) as u16;
    }
    if unbiased >= -25 {
        // Subnormal f16: shift the full 24-bit significand (implicit
        // leading 1 restored) into the 10-bit mantissa field, rounding.
        let full_m = mantissa | 0x0080_0000;
        let shift = (-unbiased - 1) as u32; // 14..=24
        let m = full_m >> shift;
        let rest = full_m & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        let round_up = rest > half || (rest == half && (m & 1) == 1);
        // A carry here rolls into the smallest normal number, which is
        // exactly the right result.
        return sign | (m + u32::from(round_up)) as u16;
    }
    // Magnitude at or below half the smallest subnormal: signed zero.
    sign
}

/// Convert f32 values to raw little-endian f16 (binary16) bytes using
/// round-to-nearest-even with subnormal support.
fn f32_slice_to_f16_bytes(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 2);
    for &value in data {
        bytes.extend_from_slice(&f32_to_f16_bits(value).to_le_bytes());
    }
    bytes
}
/// Encode `data` into raw bytes tagged with the safetensors dtype string.
///
/// `original_dtype` is the dtype recorded when the tensor was first
/// loaded. `None` and "F32" produce plain F32 bytes; an unrecognized tag
/// also falls back to F32 but logs a warning (GH-439).
fn encode_tensor_for_dtype(data: &[f32], original_dtype: Option<&str>) -> (&'static str, Vec<u8>) {
    let as_f32_bytes = |d: &[f32]| -> Vec<u8> { d.iter().flat_map(|f| f.to_le_bytes()).collect() };
    match original_dtype {
        Some("BF16") => ("BF16", f32_slice_to_bf16_bytes(data)),
        Some("F16") => ("F16", f32_slice_to_f16_bytes(data)),
        Some("F32") | None => ("F32", as_f32_bytes(data)),
        Some(unknown) => {
            eprintln!(
                "[GH-439] encode_tensor_for_dtype: unknown dtype '{}' — \
                falling back to F32. This may produce incorrect output.",
                unknown
            );
            ("F32", as_f32_bytes(data))
        }
    }
}
/// Serialize tensors in safetensors format, re-encoding each tensor back
/// to its original on-disk dtype (looked up by name in `original_dtypes`;
/// names that are missing, or mapped to "F32", are written as F32).
///
/// # Errors
/// Returns a message string if JSON serialization or the file write fails.
pub fn save_safetensors_typed<P: AsRef<Path>>(
    path: P,
    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    original_dtypes: &BTreeMap<String, String>,
) -> Result<(), String> {
    let mut metadata = SafeTensorsMetadata::new();
    let mut raw_data = Vec::new();
    let mut current_offset = 0;
    for (name, (data, shape)) in tensors {
        // Encode to the tensor's original dtype; unknown names stay F32.
        let orig = original_dtypes.get(name).map(String::as_str);
        let (dtype_str, tensor_bytes) = encode_tensor_for_dtype(data, orig);
        let start_offset = current_offset;
        let end_offset = current_offset + tensor_bytes.len();
        metadata.insert(
            name.clone(),
            TensorMetadata {
                dtype: dtype_str.to_string(),
                shape: shape.clone(),
                data_offsets: [start_offset, end_offset],
            },
        );
        raw_data.extend_from_slice(&tensor_bytes);
        current_offset = end_offset;
    }
    let metadata_json =
        serde_json::to_string(&metadata).map_err(|e| format!("JSON serialization failed: {e}"))?;
    let metadata_bytes = metadata_json.as_bytes();
    let metadata_len = metadata_bytes.len() as u64;
    // Assemble the file: [len: u64 LE][header JSON][data]. The final size
    // is known here, so reserve it once instead of growing incrementally.
    let mut output = Vec::with_capacity(8 + metadata_bytes.len() + raw_data.len());
    output.extend_from_slice(&metadata_len.to_le_bytes());
    output.extend_from_slice(metadata_bytes);
    output.extend_from_slice(&raw_data);
    fs::write(path, output).map_err(|e| format!("File write failed: {e}"))?;
    Ok(())
}
/// Typed counterpart of `save_safetensors_with_metadata`: embeds the user
/// key/value pairs under the reserved "__metadata__" header key and
/// re-encodes each tensor to its original dtype per `original_dtypes`
/// (names not present in that map are written as F32).
pub fn save_safetensors_with_metadata_typed<P: AsRef<Path>>(
    path: P,
    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
    user_metadata: &UserMetadata,
    original_dtypes: &BTreeMap<String, String>,
) -> Result<(), String> {
    let mut header = serde_json::Map::new();
    if !user_metadata.is_empty() {
        // The "__metadata__" entry holds the user's string pairs.
        let pairs = user_metadata
            .iter()
            .map(|(k, v)| (k.clone(), serde_json::Value::String(v.clone())))
            .collect::<serde_json::Map<String, serde_json::Value>>();
        header.insert("__metadata__".to_string(), serde_json::Value::Object(pairs));
    }
    let mut payload = Vec::new();
    let mut cursor = 0usize;
    for (name, (values, shape)) in tensors {
        let requested = original_dtypes.get(name).map(String::as_str);
        let (tag, encoded) = encode_tensor_for_dtype(values, requested);
        let span = [cursor, cursor + encoded.len()];
        #[allow(clippy::disallowed_methods)]
        let entry = serde_json::json!({
            "dtype": tag,
            "shape": shape,
            "data_offsets": span
        });
        header.insert(name.clone(), entry);
        payload.extend_from_slice(&encoded);
        cursor = span[1];
    }
    let header_json =
        serde_json::to_string(&header).map_err(|e| format!("JSON serialization failed: {e}"))?;
    let header_bytes = header_json.as_bytes();
    // File layout: [header length: u64 LE][header JSON][tensor data].
    let mut file_bytes = Vec::new();
    file_bytes.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
    file_bytes.extend_from_slice(header_bytes);
    file_bytes.extend_from_slice(&payload);
    fs::write(path, file_bytes).map_err(|e| format!("File write failed: {e}"))?;
    Ok(())
}
/// Read an entire safetensors file into memory, returning the parsed
/// per-tensor metadata and an owned copy of the raw data section.
/// Any user "__metadata__" block found in the header is discarded here.
pub fn load_safetensors<P: AsRef<Path>>(path: P) -> Result<(SafeTensorsMetadata, Vec<u8>), String> {
    let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
    let header_len = validate_and_read_header(&bytes)?;
    let (metadata, _user_metadata) = parse_metadata(&bytes, header_len)?;
    // Data section starts right after the 8-byte length prefix + JSON header.
    Ok((metadata, bytes[8 + header_len..].to_vec()))
}
/// A safetensors file accessed through a memory mapping instead of a
/// full in-memory copy. Accessors live in the included companion file.
#[derive(Debug)]
pub struct MappedSafeTensors {
    // The underlying mapping; keeps the file mapped while in use.
    mmap: MappedFile,
    // Parsed per-tensor header entries.
    metadata: SafeTensorsMetadata,
    // Key/value pairs from the header's "__metadata__" block, if any.
    user_metadata: UserMetadata,
    // Byte offset of the start of the data section within the mapping.
    // NOTE(review): presumably 8 + JSON header length, matching the
    // writers above — confirm against the constructor in the include file.
    data_offset: usize,
}
include!("safetensors_include_01.rs");