use crate::cryptotensors::{CryptoTensors, DeserializeCryptoConfig, SerializeCryptoConfig};
use crate::lib::{Cow, HashMap, String, ToString, Vec};
use crate::slice::{InvalidSlice, SliceIterator, TensorIndexer};
use core::fmt::Display;
use core::str::Utf8Error;
use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer};
use std::collections::BTreeMap;
#[cfg(feature = "std")]
use std::io::Write;
const MAX_HEADER_SIZE: usize = 100_000_000;
const N_LEN: usize = size_of::<u64>();
#[derive(Debug)]
pub enum SafeTensorError {
InvalidHeader(Utf8Error),
InvalidHeaderStart,
InvalidHeaderDeserialization(serde_json::Error),
HeaderTooLarge,
HeaderTooSmall,
InvalidHeaderLength,
TensorNotFound(String),
TensorInvalidInfo,
InvalidOffset(String),
#[cfg(feature = "std")]
IoError(std::io::Error),
JsonError(serde_json::Error),
InvalidTensorView(Dtype, Vec<usize>, usize),
MetadataIncompleteBuffer,
ValidationOverflow,
MisalignedSlice,
CryptoTensorsError(String),
}
#[cfg(feature = "std")]
impl From<std::io::Error> for SafeTensorError {
fn from(error: std::io::Error) -> SafeTensorError {
SafeTensorError::IoError(error)
}
}
impl From<serde_json::Error> for SafeTensorError {
fn from(error: serde_json::Error) -> SafeTensorError {
SafeTensorError::JsonError(error)
}
}
impl From<crate::cryptotensors::CryptoTensorsError> for SafeTensorError {
fn from(error: crate::cryptotensors::CryptoTensorsError) -> Self {
SafeTensorError::CryptoTensorsError(error.to_string())
}
}
impl Display for SafeTensorError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
use SafeTensorError::*;
match self {
InvalidHeader(error) => write!(f, "invalid UTF-8 in header: {error}"),
InvalidHeaderStart => write!(f, "invalid start character in header, must be `{{`"),
InvalidHeaderDeserialization(error) => write!(f, "invalid JSON in header: {error}"),
JsonError(error) => write!(f, "JSON error: {error}"),
HeaderTooLarge => write!(f, "header too large"),
HeaderTooSmall => write!(f, "header too small"),
InvalidHeaderLength => write!(f, "invalid header length"),
TensorNotFound(name) => write!(f, "tensor `{name}` not found"),
TensorInvalidInfo => write!(f, "invalid shape, data type, or offset for tensor"),
InvalidOffset(name) => write!(f, "invalid offset for tensor `{name}`"),
#[cfg(feature = "std")]
IoError(error) => write!(f, "I/O error: {error}"),
InvalidTensorView(dtype, shape, n_bytes) => {
write!(f, "tensor of type {dtype} and shape (")?;
for (i, &dim) in shape.iter().enumerate() {
write!(f, "{sep}{dim}", sep = if i == 0 { "" } else { ", " })?;
}
write!(f, ") can't be created from {n_bytes} bytes")
}
MetadataIncompleteBuffer => write!(f, "incomplete metadata, file not fully covered"),
ValidationOverflow => write!(f, "overflow computing buffer size from shape and/or element type"),
MisalignedSlice => write!(f, "The slice is slicing for subbytes dtypes, and the slice does not end up at a byte boundary, this is invalid."),
CryptoTensorsError(error) => write!(f, "CryptoTensors error: {error}"),
}
}
}
#[cfg(not(feature = "std"))]
impl core::error::Error for SafeTensorError {
fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
match self {
SafeTensorError::InvalidHeader(source) => Some(source),
SafeTensorError::JsonError(source) => Some(source),
SafeTensorError::InvalidHeaderDeserialization(source) => Some(source),
_ => None,
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for SafeTensorError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
SafeTensorError::InvalidHeader(source) => Some(source),
SafeTensorError::JsonError(source) => Some(source),
SafeTensorError::InvalidHeaderDeserialization(source) => Some(source),
SafeTensorError::IoError(source) => Some(source),
_ => None,
}
}
}
struct PreparedData {
n: u64,
header_bytes: Vec<u8>,
offset: usize,
}
pub trait View {
fn dtype(&self) -> Dtype;
fn shape(&self) -> &[usize];
fn data(&self) -> Cow<'_, [u8]>;
fn data_len(&self) -> usize;
}
type PrepareResult<'data, V> = (PreparedData, Vec<V>, Option<CryptoTensors>, Vec<String>);
fn prepare<'data, S, V, I>(
data: I,
data_info: Option<HashMap<String, String>>,
crypto_config: Option<&SerializeCryptoConfig>,
) -> Result<PrepareResult<'data, V>, SafeTensorError>
where
S: AsRef<str> + Ord + Display,
V: View,
I: IntoIterator<Item = (S, V)>,
{
let mut data: Vec<_> = data.into_iter().collect();
data.sort_by(|(lname, left), (rname, right)| {
right.dtype().cmp(&left.dtype()).then(lname.cmp(rname))
});
let tensor_names: Vec<String> = data
.iter()
.map(|(name, _)| name.as_ref().to_string())
.collect();
let crypto = if let Some(config) = crypto_config {
CryptoTensors::from_serialize_config(tensor_names.clone(), config)?
} else {
None
};
if let Some(ref c) = crypto {
let data_by_name: HashMap<&str, &V> = data
.iter()
.map(|(name, tensor)| (name.as_ref(), tensor))
.collect();
c.encrypt_tensors(|name| data_by_name.get(name).map(|t| t.data()))?;
}
let mut tensors: Vec<V> = Vec::with_capacity(data.len());
let mut hmetadata = Vec::with_capacity(data.len());
let mut offset = 0;
for (name, tensor) in data {
let tensor_name = name.as_ref().to_string();
let n = if let Some(ref c) = crypto {
c.get_buffer(&tensor_name)
.map(|b| b.len())
.unwrap_or(tensor.data_len())
} else {
tensor.data_len()
};
let tensor_info = TensorInfo {
dtype: tensor.dtype(),
shape: tensor.shape().to_vec(),
data_offsets: (offset, offset + n),
};
offset += n;
hmetadata.push((tensor_name, tensor_info.clone()));
tensors.push(tensor);
}
let final_metadata = if let Some(ref c) = crypto {
c.generate_metadata(hmetadata.clone(), data_info)?
} else {
data_info
};
let metadata: Metadata = Metadata::new(final_metadata, hmetadata)?;
let mut metadata_buf = serde_json::to_string(&metadata)?.into_bytes();
let aligned_metadata_len = metadata_buf.len().next_multiple_of(N_LEN);
metadata_buf.resize(aligned_metadata_len, b' ');
Ok((
PreparedData {
n: aligned_metadata_len as u64,
header_bytes: metadata_buf,
offset,
},
tensors,
crypto,
tensor_names,
))
}
pub fn serialize<
S: AsRef<str> + Ord + core::fmt::Display,
V: View,
I: IntoIterator<Item = (S, V)>,
>(
data: I,
data_info: Option<HashMap<String, String>>,
crypto_config: Option<&SerializeCryptoConfig>,
) -> Result<Vec<u8>, SafeTensorError> {
let (
PreparedData {
n,
header_bytes,
offset,
},
tensors,
crypto,
tensor_names,
) = prepare(data, data_info, crypto_config)?;
if n > MAX_HEADER_SIZE as u64 {
return Err(SafeTensorError::HeaderTooLarge);
}
let expected_size = N_LEN + header_bytes.len() + offset;
let mut buffer: Vec<u8> = Vec::with_capacity(expected_size);
buffer.extend(n.to_le_bytes());
buffer.extend(header_bytes);
for (tensor, name) in tensors.iter().zip(tensor_names.iter()) {
if let Some(ref c) = crypto {
if let Some(encrypted_data) = c.get_buffer(name) {
buffer.extend(encrypted_data.as_slice());
continue;
}
}
buffer.extend(tensor.data().as_ref());
}
Ok(buffer)
}
#[cfg(feature = "std")]
pub fn serialize_to_file<S, V, I>(
data: I,
data_info: Option<HashMap<String, String>>,
filename: &std::path::Path,
crypto_config: Option<&SerializeCryptoConfig>,
) -> Result<(), SafeTensorError>
where
S: AsRef<str> + Ord + Display,
V: View,
I: IntoIterator<Item = (S, V)>,
{
let (
PreparedData {
n, header_bytes, ..
},
tensors,
crypto,
tensor_names,
) = prepare(data, data_info, crypto_config)?;
if n > MAX_HEADER_SIZE as u64 {
return Err(SafeTensorError::HeaderTooLarge);
}
let mut f = std::io::BufWriter::new(std::fs::File::create(filename)?);
f.write_all(n.to_le_bytes().as_ref())?;
f.write_all(&header_bytes)?;
for (tensor, name) in tensors.iter().zip(tensor_names.iter()) {
if let Some(ref c) = crypto {
if let Some(encrypted_data) = c.get_buffer(name) {
f.write_all(encrypted_data.as_slice())?;
continue;
}
}
f.write_all(tensor.data().as_ref())?;
}
f.flush()?;
Ok(())
}
#[cfg(feature = "std")]
pub fn rewrap_header(
buffer: &[u8],
new_config: &SerializeCryptoConfig,
old_config: Option<&DeserializeCryptoConfig>,
) -> Result<Vec<u8>, SafeTensorError> {
let (_header_size, metadata) = SafeTensors::read_metadata_header_only(buffer)?;
let mut crypto =
CryptoTensors::from_header_with_config(&metadata, old_config)?.ok_or_else(|| {
SafeTensorError::CryptoTensorsError(
"Buffer is not encrypted, cannot rewrap".to_string(),
)
})?;
crypto.rewrap(new_config)?;
let tensor_infos: Vec<_> = metadata
.offset_keys()
.into_iter()
.filter_map(|k| metadata.info(&k).map(|i| (k.clone(), i.clone())))
.collect();
let metadata_dict = metadata.metadata().as_ref().cloned().unwrap_or_default();
let new_metadata_map = crypto.generate_metadata(tensor_infos.clone(), Some(metadata_dict))?;
let new_metadata = Metadata::new(new_metadata_map, tensor_infos)?;
let header_json = serde_json::to_string(&new_metadata)?;
let header_bytes = header_json.as_bytes();
let padding = (8 - (header_bytes.len() % 8)) % 8;
let n = (header_bytes.len() + padding) as u64;
if n > MAX_HEADER_SIZE as u64 {
return Err(SafeTensorError::HeaderTooLarge);
}
let mut result = Vec::with_capacity(N_LEN + header_bytes.len() + padding);
result.extend(n.to_le_bytes());
result.extend(header_bytes);
result.extend(vec![b' '; padding]);
Ok(result)
}
#[cfg(feature = "std")]
pub fn rewrap(
buffer: &[u8],
new_config: &SerializeCryptoConfig,
old_config: Option<&DeserializeCryptoConfig>,
) -> Result<Vec<u8>, SafeTensorError> {
let (header_size, _) = SafeTensors::read_metadata(buffer)?;
let header_end = N_LEN + header_size;
let new_header = rewrap_header(&buffer[..header_end], new_config, old_config)?;
let mut result = new_header;
result.extend_from_slice(&buffer[header_end..]);
Ok(result)
}
#[cfg(feature = "std")]
pub fn rewrap_file(
filename: &std::path::Path,
new_config: &SerializeCryptoConfig,
old_config: Option<&DeserializeCryptoConfig>,
) -> Result<(), SafeTensorError> {
let buffer = std::fs::read(filename)?;
let new_buffer = rewrap(&buffer, new_config, old_config)?;
std::fs::write(filename, new_buffer)?;
Ok(())
}
#[derive(Debug)]
pub struct SafeTensors<'data> {
metadata: Metadata,
data: &'data [u8],
crypto: Option<CryptoTensors>,
}
impl<'data> SafeTensors<'data> {
pub fn read_metadata(buffer: &'data [u8]) -> Result<(usize, Metadata), SafeTensorError> {
Self::read_metadata_with_validation(buffer, true)
}
pub fn read_metadata_header_only(
buffer: &'data [u8],
) -> Result<(usize, Metadata), SafeTensorError> {
Self::read_metadata_with_validation(buffer, false)
}
fn read_metadata_with_validation(
buffer: &'data [u8],
validate_buffer: bool,
) -> Result<(usize, Metadata), SafeTensorError> {
let buffer_len = buffer.len();
let Some(header_size_bytes) = buffer.get(..N_LEN) else {
return Err(SafeTensorError::HeaderTooSmall);
};
let arr: [u8; N_LEN] = header_size_bytes
.try_into()
.expect("this can't fail due to how `header_size_bytes` is defined above");
let n: usize = u64::from_le_bytes(arr)
.try_into()
.map_err(|_| SafeTensorError::HeaderTooLarge)?;
if n > MAX_HEADER_SIZE {
return Err(SafeTensorError::HeaderTooLarge);
}
let stop = n
.checked_add(N_LEN)
.ok_or(SafeTensorError::InvalidHeaderLength)?;
let Some(header_bytes) = buffer.get(N_LEN..stop) else {
return Err(SafeTensorError::InvalidHeaderLength);
};
let string = core::str::from_utf8(header_bytes).map_err(SafeTensorError::InvalidHeader)?;
let metadata: HashMetadata =
serde_json::from_str(string).map_err(SafeTensorError::InvalidHeaderDeserialization)?;
let metadata: Metadata = metadata.try_into()?;
let buffer_end = metadata.validate()?;
if validate_buffer && buffer_end + N_LEN + n != buffer_len {
return Err(SafeTensorError::MetadataIncompleteBuffer);
}
Ok((n, metadata))
}
pub fn deserialize(buffer: &'data [u8]) -> Result<Self, SafeTensorError> {
Self::deserialize_with_config(buffer, None)
}
pub fn deserialize_with_config(
buffer: &'data [u8],
config: Option<&DeserializeCryptoConfig>,
) -> Result<Self, SafeTensorError> {
let (n, metadata) = SafeTensors::read_metadata(buffer)?;
let data = &buffer[N_LEN + n..];
let crypto = CryptoTensors::from_header_with_config(&metadata, config)?;
Ok(Self {
metadata,
data,
crypto,
})
}
pub fn tensors(&'data self) -> Vec<(String, TensorView<'data>)> {
let mut tensors = Vec::with_capacity(self.metadata.index_map.len());
for (name, &index) in &self.metadata.index_map {
let info = &self.metadata.tensors[index];
let raw_data = &self.data[info.data_offsets.0..info.data_offsets.1];
let data = if let Some(ref crypto) = self.crypto {
crypto.silent_decrypt(name, raw_data).unwrap_or(raw_data)
} else {
raw_data
};
let tensorview = TensorView {
dtype: info.dtype,
shape: info.shape.clone(),
data,
};
tensors.push((name.to_string(), tensorview));
}
tensors
}
pub fn iter(&'data self) -> impl Iterator<Item = (&'data str, TensorView<'data>)> {
self.metadata.index_map.iter().map(move |(name, &idx)| {
let info = &self.metadata.tensors[idx];
let raw_data = &self.data[info.data_offsets.0..info.data_offsets.1];
let data = if let Some(ref crypto) = self.crypto {
crypto.silent_decrypt(name, raw_data).unwrap_or(raw_data)
} else {
raw_data
};
(
name.as_str(),
TensorView {
dtype: info.dtype,
shape: info.shape.clone(),
data,
},
)
})
}
pub fn tensor(&'data self, tensor_name: &str) -> Result<TensorView<'data>, SafeTensorError> {
let &index = self
.metadata
.index_map
.get(tensor_name)
.ok_or_else(|| SafeTensorError::TensorNotFound(tensor_name.to_string()))?;
let info = self
.metadata
.tensors
.get(index)
.ok_or_else(|| SafeTensorError::TensorNotFound(tensor_name.to_string()))?;
let raw_data = &self.data[info.data_offsets.0..info.data_offsets.1];
let data = if let Some(ref crypto) = self.crypto {
crypto.silent_decrypt(tensor_name, raw_data)?
} else {
raw_data
};
Ok(TensorView {
dtype: info.dtype,
shape: info.shape.clone(),
data,
})
}
pub fn names(&self) -> Vec<&'_ str> {
self.metadata.index_map.keys().map(String::as_str).collect()
}
#[inline]
pub fn len(&self) -> usize {
self.metadata.tensors.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.metadata.tensors.is_empty()
}
pub fn metadata(&self) -> Option<HashMap<String, String>> {
self.metadata
.metadata()
.as_ref()
.map(|meta| {
meta.iter()
.filter(|(k, _)| !k.starts_with("__"))
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<HashMap<_, _>>()
})
.and_then(|meta| if meta.is_empty() { None } else { Some(meta) })
}
pub fn reserved_metadata(&self) -> Option<HashMap<String, String>> {
self.metadata
.metadata()
.as_ref()
.map(|meta| {
meta.iter()
.filter(|(k, _)| k.starts_with("__"))
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<HashMap<_, _>>()
})
.and_then(|meta| if meta.is_empty() { None } else { Some(meta) })
}
}
#[derive(Debug, Clone)]
pub struct Metadata {
metadata: Option<HashMap<String, String>>,
tensors: Vec<TensorInfo>,
index_map: HashMap<String, usize>,
}
#[derive(Serialize, Deserialize)]
struct HashMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "__metadata__")]
metadata: Option<HashMap<String, String>>,
#[serde(flatten)]
tensors: HashMap<String, TensorInfo>,
}
impl TryFrom<HashMetadata> for Metadata {
type Error = SafeTensorError;
fn try_from(hashdata: HashMetadata) -> Result<Self, Self::Error> {
let (metadata, tensors) = (hashdata.metadata, hashdata.tensors);
let mut tensors: Vec<_> = tensors.into_iter().collect();
tensors.sort_by(|(_, left), (_, right)| left.data_offsets.cmp(&right.data_offsets));
Metadata::new(metadata, tensors)
}
}
impl<'de> Deserialize<'de> for Metadata {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let hashdata: HashMetadata = HashMetadata::deserialize(deserializer)?;
let metadata: Metadata = hashdata.try_into().map_err(serde::de::Error::custom)?;
Ok(metadata)
}
}
impl Serialize for Metadata {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut names = vec![""; self.index_map.len()];
for (name, &index) in &self.index_map {
names[index] = name;
}
let length = self.metadata.as_ref().map_or(0, HashMap::len);
let mut map = serializer.serialize_map(Some(self.tensors.len() + length))?;
if let Some(metadata) = &self.metadata {
let sorted_metadata: BTreeMap<_, _> = metadata.iter().collect();
map.serialize_entry("__metadata__", &sorted_metadata)?;
}
for (name, info) in names.iter().zip(&self.tensors) {
map.serialize_entry(name, info)?;
}
map.end()
}
}
impl Metadata {
pub fn new(
metadata: Option<HashMap<String, String>>,
tensors: Vec<(String, TensorInfo)>,
) -> Result<Self, SafeTensorError> {
let mut index_map = HashMap::with_capacity(tensors.len());
let tensors: Vec<_> = tensors
.into_iter()
.enumerate()
.map(|(index, (k, tensor))| {
index_map.insert(k, index);
tensor
})
.collect();
let metadata = Self {
metadata,
tensors,
index_map,
};
metadata.validate()?;
Ok(metadata)
}
fn validate(&self) -> Result<usize, SafeTensorError> {
let mut start = 0;
for (i, info) in self.tensors.iter().enumerate() {
let (s, e) = info.data_offsets;
if s != start || e < s {
let tensor_name = self
.index_map
.iter()
.find_map(|(name, &index)| if index == i { Some(&name[..]) } else { None })
.unwrap_or("no_tensor");
return Err(SafeTensorError::InvalidOffset(tensor_name.to_string()));
}
start = e;
let nelements: usize = info
.shape
.iter()
.copied()
.try_fold(1usize, usize::checked_mul)
.ok_or(SafeTensorError::ValidationOverflow)?;
let nbits = nelements
.checked_mul(info.dtype.bitsize())
.ok_or(SafeTensorError::ValidationOverflow)?;
if nbits % 8 != 0 {
return Err(SafeTensorError::MisalignedSlice);
}
let size = nbits
.checked_div(8)
.ok_or(SafeTensorError::ValidationOverflow)?;
if e - s != size {
return Err(SafeTensorError::TensorInvalidInfo);
}
}
Ok(start)
}
pub fn info(&self, name: &str) -> Option<&TensorInfo> {
let &index = self.index_map.get(name)?;
self.tensors.get(index)
}
pub fn tensors(&self) -> HashMap<String, &TensorInfo> {
self.index_map
.iter()
.map(|(tensor_name, &index)| (tensor_name.clone(), &self.tensors[index]))
.collect()
}
pub fn offset_keys(&self) -> Vec<String> {
let mut index_vec: Vec<_> = self.index_map.iter().collect();
index_vec.sort_by_key(|a| a.1);
index_vec.into_iter().map(|a| a.0.clone()).collect()
}
pub fn data_len(&self) -> usize {
if let Some(tensor) = self.tensors.last() {
tensor.data_offsets.1
} else {
0
}
}
pub fn metadata(&self) -> &Option<HashMap<String, String>> {
&self.metadata
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TensorView<'data> {
dtype: Dtype,
shape: Vec<usize>,
data: &'data [u8],
}
impl View for &TensorView<'_> {
fn dtype(&self) -> Dtype {
self.dtype
}
fn shape(&self) -> &[usize] {
&self.shape
}
fn data(&self) -> Cow<'_, [u8]> {
self.data.into()
}
fn data_len(&self) -> usize {
self.data.len()
}
}
impl View for TensorView<'_> {
fn dtype(&self) -> Dtype {
self.dtype
}
fn shape(&self) -> &[usize] {
&self.shape
}
fn data(&self) -> Cow<'_, [u8]> {
self.data.into()
}
fn data_len(&self) -> usize {
self.data.len()
}
}
impl<'data> TensorView<'data> {
pub fn new(
dtype: Dtype,
shape: Vec<usize>,
data: &'data [u8],
) -> Result<Self, SafeTensorError> {
let n_elements: usize = shape.iter().product();
let nbits = n_elements * dtype.bitsize();
if nbits % 8 != 0 {
return Err(SafeTensorError::MisalignedSlice);
}
let size = nbits
.checked_div(8)
.ok_or(SafeTensorError::ValidationOverflow)?;
if data.len() != size {
Err(SafeTensorError::InvalidTensorView(dtype, shape, data.len()))
} else {
Ok(Self { dtype, shape, data })
}
}
pub fn dtype(&self) -> Dtype {
self.dtype
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn data(&self) -> &'data [u8] {
self.data
}
pub fn sliced_data(
&'data self,
slices: &[TensorIndexer],
) -> Result<SliceIterator<'data>, InvalidSlice> {
SliceIterator::new(self, slices)
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct TensorInfo {
pub dtype: Dtype,
pub shape: Vec<usize>,
pub data_offsets: (usize, usize),
}
#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
#[non_exhaustive]
pub enum Dtype {
BOOL,
F4,
#[allow(non_camel_case_types)]
F6_E2M3,
#[allow(non_camel_case_types)]
F6_E3M2,
U8,
I8,
#[allow(non_camel_case_types)]
F8_E5M2,
#[allow(non_camel_case_types)]
F8_E4M3,
#[allow(non_camel_case_types)]
F8_E8M0,
I16,
U16,
F16,
BF16,
I32,
U32,
F32,
C64,
F64,
I64,
U64,
}
impl Dtype {
pub fn bitsize(&self) -> usize {
match self {
Dtype::F4 => 4,
Dtype::F6_E3M2 => 6,
Dtype::F6_E2M3 => 6,
Dtype::BOOL => 8,
Dtype::U8 => 8,
Dtype::I8 => 8,
Dtype::F8_E5M2 => 8,
Dtype::F8_E4M3 => 8,
Dtype::F8_E8M0 => 8,
Dtype::I16 => 16,
Dtype::U16 => 16,
Dtype::I32 => 32,
Dtype::U32 => 32,
Dtype::I64 => 64,
Dtype::U64 => 64,
Dtype::F16 => 16,
Dtype::BF16 => 16,
Dtype::F32 => 32,
Dtype::F64 => 64,
Dtype::C64 => 64,
}
}
#[deprecated(
since = "0.6.0",
note = "Use `bitsize` instead as some elements have smaller than a full byte of width"
)]
pub fn size(&self) -> usize {
self.bitsize() / 8
}
}
impl Display for Dtype {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(match *self {
Dtype::F4 => "F4",
Dtype::F6_E2M3 => "F6_E2M3",
Dtype::F6_E3M2 => "F6_E3M2",
Dtype::BOOL => "BOOL",
Dtype::I8 => "I8",
Dtype::U8 => "U8",
Dtype::F8_E5M2 => "F8_E5M2",
Dtype::F8_E4M3 => "F8_E4M3",
Dtype::F8_E8M0 => "F8_E8M0",
Dtype::I16 => "I16",
Dtype::U16 => "U16",
Dtype::I32 => "I32",
Dtype::U32 => "U32",
Dtype::I64 => "I64",
Dtype::U64 => "U64",
Dtype::F16 => "F16",
Dtype::BF16 => "BF16",
Dtype::F32 => "F32",
Dtype::F64 => "F64",
Dtype::C64 => "C64",
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::slice::IndexOp;
use proptest::prelude::*;
#[cfg(not(feature = "std"))]
extern crate std;
use std::io::Write;
const MAX_DIMENSION: usize = 8;
const MAX_SIZE: usize = 8;
const MAX_TENSORS: usize = 8;
fn arbitrary_dtype() -> impl Strategy<Value = Dtype> {
prop_oneof![
Just(Dtype::BOOL),
Just(Dtype::F4),
Just(Dtype::F6_E3M2),
Just(Dtype::F6_E2M3),
Just(Dtype::F8_E5M2),
Just(Dtype::F8_E4M3),
Just(Dtype::U8),
Just(Dtype::I8),
Just(Dtype::I16),
Just(Dtype::U16),
Just(Dtype::I32),
Just(Dtype::U32),
Just(Dtype::I64),
Just(Dtype::U64),
Just(Dtype::F16),
Just(Dtype::BF16),
Just(Dtype::F32),
Just(Dtype::F64),
Just(Dtype::C64),
]
}
fn arbitrary_shape() -> impl Strategy<Value = Vec<usize>> {
(1..MAX_DIMENSION).prop_flat_map(|length| prop::collection::vec(1..MAX_SIZE, length))
}
fn arbitrary_metadata() -> impl Strategy<Value = Metadata> {
(1..MAX_TENSORS)
.prop_flat_map(|size| {
(
prop::collection::vec(arbitrary_dtype(), size),
prop::collection::vec(arbitrary_shape(), size),
)
})
.prop_filter_map("Misaligned slices", |(dtypes, shapes)| {
let mut start = 0;
let tensors: Vec<TensorInfo> = dtypes
.iter()
.zip(shapes)
.flat_map(|(dtype, shape)| {
let bitlength: usize = shape.iter().product::<usize>() * dtype.bitsize();
if bitlength % 8 != 0 {
return None;
}
let length = bitlength.div_ceil(8);
let end = start + length;
let tensor = TensorInfo {
dtype: *dtype,
shape,
data_offsets: (start, end),
};
start = end;
Some(tensor)
})
.collect();
let index_map = (0..tensors.len())
.map(|index| (format!("t.{index}"), index))
.collect();
if tensors.is_empty() {
None
} else {
Some(Metadata {
metadata: None,
tensors,
index_map,
})
}
})
}
fn data_size(metadata: &Metadata) -> usize {
metadata.tensors.last().unwrap().data_offsets.1
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(20))]
#[test]
fn test_indexing(metadata in arbitrary_metadata()) {
let data = vec![0u8; data_size(&metadata)];
let tensors = SafeTensors { metadata, data: &data, crypto: None };
for name in tensors.names() {
assert!(tensors.tensor(name).is_ok());
}
}
#[test]
fn test_roundtrip(metadata in arbitrary_metadata()) {
let data: Vec<u8> = (0..data_size(&metadata)).map(|x| x as u8).collect();
let before = SafeTensors { metadata, data: &data, crypto: None };
let tensors = before.tensors();
let bytes = serialize(tensors.iter().map(|(name, view)| (name.to_string(), view)), None, None).unwrap();
let after = SafeTensors::deserialize(&bytes).unwrap();
assert_eq!(before.names().len(), after.names().len());
for name in before.names() {
let tensor_before = before.tensor(name).unwrap();
let tensor_after = after.tensor(name).unwrap();
assert_eq!(tensor_after.data().as_ptr() as usize % tensor_after.dtype().bitsize().div_ceil(8), 0);
assert_eq!(tensor_before, tensor_after);
}
}
}
#[test]
fn test_serialization() {
let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let shape = vec![1, 2, 3];
let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
let metadata: HashMap<String, TensorView> =
[("attn.0".to_string(), attn_0)].into_iter().collect();
let out = serialize(&metadata, None, None).unwrap();
assert_eq!(
out,
[
64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100,
116, 121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34,
58, 91, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115,
101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 0, 0, 0, 0, 0, 0, 128, 63,
0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0, 160, 64
]
);
let _parsed = SafeTensors::deserialize(&out).unwrap();
}
#[test]
fn test_serialization_fp4() {
let data: Vec<u8> = vec![0u8];
let shape = vec![1, 2];
let attn_0 = TensorView::new(Dtype::F4, shape, &data).unwrap();
let metadata: HashMap<String, TensorView> =
[("attn.0".to_string(), attn_0)].into_iter().collect();
let out = serialize(&metadata, None, None).unwrap();
assert_eq!(
out,
[
64, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 46, 48, 34, 58, 123, 34, 100,
116, 121, 112, 101, 34, 58, 34, 70, 52, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58,
91, 49, 44, 50, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102, 115, 101, 116,
115, 34, 58, 91, 48, 44, 49, 93, 125, 125, 32, 32, 32, 32, 0
]
);
let parsed = SafeTensors::deserialize(&out).unwrap();
let tensors: HashMap<_, _> = parsed.tensors().into_iter().collect();
assert_eq!(tensors, metadata);
}
#[test]
fn test_serialization_fp4_misaligned() {
let data: Vec<u8> = vec![0u8, 1u8];
let shape = vec![1, 3];
let attn_0 = TensorView::new(Dtype::F4, shape, &data);
assert!(matches!(attn_0, Err(SafeTensorError::MisalignedSlice)));
}
#[test]
fn test_serialization_fp4_invalid() {
let data: Vec<u8> = vec![0u8, 1u8];
let shape = vec![1, 2];
let attn_0 = TensorView::new(Dtype::F4, shape, &data);
assert!(matches!(
attn_0,
Err(SafeTensorError::InvalidTensorView(Dtype::F4, _shape, _size))
));
}
#[test]
fn test_empty() {
let tensors: HashMap<String, TensorView> = HashMap::new();
let out = serialize(&tensors, None, None).unwrap();
assert_eq!(
out,
[8, 0, 0, 0, 0, 0, 0, 0, 123, 125, 32, 32, 32, 32, 32, 32]
);
let _parsed = SafeTensors::deserialize(&out).unwrap();
let metadata: Option<HashMap<String, String>> = Some(
[("framework".to_string(), "pt".to_string())]
.into_iter()
.collect(),
);
let out = serialize(&tensors, metadata, None).unwrap();
assert_eq!(
out,
[
40, 0, 0, 0, 0, 0, 0, 0, 123, 34, 95, 95, 109, 101, 116, 97, 100, 97, 116, 97, 95,
95, 34, 58, 123, 34, 102, 114, 97, 109, 101, 119, 111, 114, 107, 34, 58, 34, 112,
116, 34, 125, 125, 32, 32, 32, 32, 32
]
);
let _parsed = SafeTensors::deserialize(&out).unwrap();
}
#[test]
fn test_serialization_forced_alignement() {
let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let shape = vec![1, 1, 2, 3];
let attn_0 = TensorView::new(Dtype::F32, shape, &data).unwrap();
let metadata: HashMap<String, TensorView> =
[("attn0".to_string(), attn_0)].into_iter().collect();
let out = serialize(&metadata, None, None).unwrap();
assert_eq!(
out,
[
72, 0, 0, 0, 0, 0, 0, 0, 123, 34, 97, 116, 116, 110, 48, 34, 58, 123, 34, 100, 116,
121, 112, 101, 34, 58, 34, 70, 51, 50, 34, 44, 34, 115, 104, 97, 112, 101, 34, 58,
91, 49, 44, 49, 44, 50, 44, 51, 93, 44, 34, 100, 97, 116, 97, 95, 111, 102, 102,
115, 101, 116, 115, 34, 58, 91, 48, 44, 50, 52, 93, 125, 125, 32, 32, 32, 32, 32,
32, 32, 0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128, 64, 0, 0,
160, 64
],
);
let parsed = SafeTensors::deserialize(&out).unwrap();
let tensor = parsed.tensor("attn0").unwrap();
assert_eq!(
tensor.data().as_ptr() as usize % tensor.dtype().bitsize().div_ceil(8),
0
);
}
#[test]
fn test_slicing() {
let data: Vec<u8> = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect();
let attn_0 = TensorView {
dtype: Dtype::F32,
shape: vec![1, 2, 3],
data: &data,
};
let metadata: HashMap<String, TensorView> =
[("attn.0".to_string(), attn_0)].into_iter().collect();
let out = serialize(&metadata, None, None).unwrap();
let parsed = SafeTensors::deserialize(&out).unwrap();
let out_buffer: Vec<u8> = parsed
.tensor("attn.0")
.unwrap()
.slice((.., ..1))
.unwrap()
.flat_map(|b| b.to_vec())
.collect();
assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64]);
assert_eq!(
out_buffer,
vec![0.0f32, 1.0, 2.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect::<Vec<_>>()
);
let out_buffer: Vec<u8> = parsed
.tensor("attn.0")
.unwrap()
.slice((.., .., ..1))
.unwrap()
.flat_map(|b| b.to_vec())
.collect();
assert_eq!(out_buffer, vec![0u8, 0, 0, 0, 0, 0, 64, 64]);
assert_eq!(
out_buffer,
vec![0.0f32, 3.0]
.into_iter()
.flat_map(|f| f.to_le_bytes())
.collect::<Vec<_>>()
);
}
#[test]
fn test_gpt2() {
gpt2_like(12, "gpt2");
}
#[test]
fn test_gpt2_tiny() {
gpt2_like(6, "gpt2_tiny");
}
fn gpt2_like(n_heads: usize, model_id: &str) {
let mut tensors_desc = vec![
("wte".to_string(), vec![50257, 768]),
("wpe".to_string(), vec![1024, 768]),
];
for i in 0..n_heads {
tensors_desc.push((format!("h.{i}.ln_1.weight"), vec![768]));
tensors_desc.push((format!("h.{i}.ln_1.bias"), vec![768]));
tensors_desc.push((format!("h.{i}.attn.bias"), vec![1, 1, 1024, 1024]));
tensors_desc.push((format!("h.{i}.attn.c_attn.weight"), vec![768, 2304]));
tensors_desc.push((format!("h.{i}.attn.c_attn.bias"), vec![2304]));
tensors_desc.push((format!("h.{i}.attn.c_proj.weight"), vec![768, 768]));
tensors_desc.push((format!("h.{i}.attn.c_proj.bias"), vec![768]));
tensors_desc.push((format!("h.{i}.ln_2.weight"), vec![768]));
tensors_desc.push((format!("h.{i}.ln_2.bias"), vec![768]));
tensors_desc.push((format!("h.{i}.mlp.c_fc.weight"), vec![768, 3072]));
tensors_desc.push((format!("h.{i}.mlp.c_fc.bias"), vec![3072]));
tensors_desc.push((format!("h.{i}.mlp.c_proj.weight"), vec![3072, 768]));
tensors_desc.push((format!("h.{i}.mlp.c_proj.bias"), vec![768]));
}
tensors_desc.push(("ln_f.weight".to_string(), vec![768]));
tensors_desc.push(("ln_f.bias".to_string(), vec![768]));
let dtype = Dtype::F32;
let nbits: usize = tensors_desc
.iter()
.map(|(_, shape)| shape.iter().product::<usize>())
.sum::<usize>()
* dtype.bitsize();
if nbits % 8 != 0 {
panic!("Misaligned slice");
}
let n = nbits
.checked_div(8)
.ok_or(SafeTensorError::ValidationOverflow)
.unwrap(); let all_data = vec![0; n];
let mut metadata = HashMap::with_capacity(tensors_desc.len());
let mut offset = 0;
for (name, shape) in tensors_desc {
let n: usize = shape.iter().product();
let buffer = &all_data[offset..offset + (n * dtype.bitsize()) / 8];
let tensor = TensorView::new(dtype, shape, buffer).unwrap();
metadata.insert(name, tensor);
offset += n;
}
let filename = format!("./out_{model_id}.safetensors");
let out = serialize(&metadata, None, None).unwrap();
std::fs::write(&filename, out).unwrap();
let raw = std::fs::read(&filename).unwrap();
let _deserialized = SafeTensors::deserialize(&raw).unwrap();
std::fs::remove_file(&filename).unwrap();
#[cfg(feature = "std")]
{
serialize_to_file(&metadata, None, std::path::Path::new(&filename), None).unwrap();
let raw = std::fs::read(&filename).unwrap();
let _deserialized = SafeTensors::deserialize(&raw).unwrap();
std::fs::remove_file(&filename).unwrap();
}
}
#[test]
fn test_empty_shapes_allowed() {
let serialized = b"8\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[],\"data_offsets\":[0,4]}}\x00\x00\x00\x00";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.names(), vec!["test"]);
let tensor = loaded.tensor("test").unwrap();
assert!(tensor.shape().is_empty());
assert_eq!(tensor.dtype(), Dtype::I32);
assert_eq!(tensor.data(), b"\0\0\0\0");
}
#[test]
fn test_deserialization() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.len(), 1);
assert_eq!(loaded.names(), vec!["test"]);
let tensor = loaded.tensor("test").unwrap();
assert_eq!(tensor.shape(), vec![2, 2]);
assert_eq!(tensor.dtype(), Dtype::I32);
assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
}
#[test]
fn test_lifetimes() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
let loaded = SafeTensors::deserialize(serialized).unwrap();
let tensor = loaded.tensor("test").unwrap();
assert_eq!(tensor.shape(), vec![2, 2]);
assert_eq!(tensor.dtype(), Dtype::I32);
assert_eq!(tensor.data(), b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
}
#[test]
fn test_json_attack() {
let mut tensors = HashMap::new();
let dtype = Dtype::F32;
let shape = vec![2, 2];
let data_offsets = (0, 16);
for i in 0..10 {
tensors.insert(
format!("weight_{i}"),
TensorInfo {
dtype,
shape: shape.clone(),
data_offsets,
},
);
}
let metadata = HashMetadata {
metadata: None,
tensors,
};
let serialized = serde_json::to_string(&metadata).unwrap();
let serialized = serialized.as_bytes();
let n = serialized.len();
let filename = "out.safetensors";
let mut f = std::io::BufWriter::new(std::fs::File::create(filename).unwrap());
f.write_all(n.to_le_bytes().as_ref()).unwrap();
f.write_all(serialized).unwrap();
f.write_all(b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0").unwrap();
f.flush().unwrap();
let reloaded = std::fs::read(filename).unwrap();
match SafeTensors::deserialize(&reloaded) {
Err(SafeTensorError::InvalidOffset(_)) => {
}
Err(err) => panic!("Unexpected error {err:?}"),
Ok(_) => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_metadata_incomplete_buffer() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00extra_bogus_data_for_polyglot_file";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::MetadataIncompleteBuffer) => {
}
_ => panic!("This should not be able to be deserialized"),
}
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::MetadataIncompleteBuffer) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_header_too_large() {
let serialized = b"<\x00\x00\x00\x00\xff\xff\xff{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::HeaderTooLarge) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_header_too_small() {
let serialized = b"";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::HeaderTooSmall) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_invalid_header_length() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::InvalidHeaderLength) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_invalid_header_non_utf8() {
let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00\xff";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::InvalidHeader(_)) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_invalid_header_not_json() {
let serialized = b"\x01\x00\x00\x00\x00\x00\x00\x00{";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::InvalidHeaderDeserialization(_)) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_whitespace_padded_header() {
let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00{}\x0D\x20\x09\x0A";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.len(), 0);
}
#[test]
fn test_zero_sized_tensor() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,0],\"data_offsets\":[0, 0]}}";
let loaded = SafeTensors::deserialize(serialized).unwrap();
assert_eq!(loaded.names(), vec!["test"]);
let tensor = loaded.tensor("test").unwrap();
assert_eq!(tensor.shape(), vec![2, 0]);
assert_eq!(tensor.dtype(), Dtype::I32);
assert_eq!(tensor.data(), b"");
}
#[test]
fn test_invalid_info() {
let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,2],\"data_offsets\":[0, 4]}}";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::TensorInvalidInfo) => {
}
something => panic!("This should not be able to be deserialized got {something:?}"),
}
}
#[test]
fn test_validation_overflow() {
let serialized = b"O\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,18446744073709551614],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::ValidationOverflow) => {
}
_ => panic!("This should not be able to be deserialized"),
}
let serialized = b"N\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,9223372036854775807],\"data_offsets\":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00";
match SafeTensors::deserialize(serialized) {
Err(SafeTensorError::ValidationOverflow) => {
}
_ => panic!("This should not be able to be deserialized"),
}
}
#[test]
fn test_invalid_header_size_serialization() {
let mut data_info = HashMap::<String, String>::new();
let tensors: HashMap<String, TensorView> = HashMap::new();
let very_large_metadata = "a".repeat(MAX_HEADER_SIZE);
data_info.insert("very_large_metadata".to_string(), very_large_metadata);
match serialize(&tensors, Some(data_info), None) {
Err(SafeTensorError::HeaderTooLarge) => {
}
_ => panic!("This should not be able to be serialized"),
}
}
#[test]
fn test_rewrap_buffer() {
use crate::key::KeyMaterial;
let old_enc_key =
KeyMaterial::new_enc_key(None, None, Some("old-enc".to_string()), None).unwrap();
let old_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("old-sign".to_string()), None)
.unwrap();
let new_enc_key =
KeyMaterial::new_enc_key(None, None, Some("new-enc".to_string()), None).unwrap();
let new_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("new-sign".to_string()), None)
.unwrap();
let data: Vec<u8> = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
let shape = vec![2, 2];
let tensor = TensorView {
dtype: Dtype::F32,
shape: shape.clone(),
data: &data,
};
let old_config =
SerializeCryptoConfig::with_keys(old_enc_key.clone(), old_sign_key.clone());
let buffer = serialize([("weight", tensor)], None, Some(&old_config)).unwrap();
let (_, metadata) = SafeTensors::read_metadata(&buffer).unwrap();
assert!(metadata.metadata().is_some(), "Metadata is None");
let meta = metadata.metadata().as_ref().unwrap();
assert!(meta.contains_key("__encryption__"), "Not encrypted");
let old_deser_config =
DeserializeCryptoConfig::with_keys(old_enc_key.clone(), old_sign_key.clone());
let new_ser_config =
SerializeCryptoConfig::with_keys(new_enc_key.clone(), new_sign_key.clone());
let new_buffer = rewrap(&buffer, &new_ser_config, Some(&old_deser_config)).unwrap();
let (_, new_metadata) = SafeTensors::read_metadata(&new_buffer).unwrap();
let new_meta = new_metadata.metadata().as_ref().unwrap();
let crypto_keys = new_meta.get("__crypto_keys__").unwrap();
assert!(
crypto_keys.contains("new-enc"),
"Should contain new-enc kid"
);
let new_deser_config = DeserializeCryptoConfig::with_keys(new_enc_key, new_sign_key);
let loaded =
SafeTensors::deserialize_with_config(&new_buffer, Some(&new_deser_config)).unwrap();
let loaded_tensor = loaded.tensor("weight").unwrap();
assert_eq!(loaded_tensor.shape(), shape);
assert_eq!(loaded_tensor.dtype(), Dtype::F32);
assert_eq!(loaded_tensor.data(), data);
}
#[test]
fn test_rewrap_header_only() {
use crate::key::KeyMaterial;
let old_enc_key =
KeyMaterial::new_enc_key(None, None, Some("old-enc".to_string()), None).unwrap();
let old_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("old-sign".to_string()), None)
.unwrap();
let new_enc_key =
KeyMaterial::new_enc_key(None, None, Some("new-enc".to_string()), None).unwrap();
let new_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("new-sign".to_string()), None)
.unwrap();
let data: Vec<u8> = vec![0u8; 16];
let tensor = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data,
};
let old_config =
SerializeCryptoConfig::with_keys(old_enc_key.clone(), old_sign_key.clone());
let buffer = serialize([("test", tensor)], None, Some(&old_config)).unwrap();
let (header_size, _) = SafeTensors::read_metadata(&buffer).unwrap();
let header_end = N_LEN + header_size;
let header_bytes = &buffer[..header_end];
let old_deser_config = DeserializeCryptoConfig::with_keys(old_enc_key, old_sign_key);
let new_ser_config = SerializeCryptoConfig::with_keys(new_enc_key, new_sign_key);
let new_header =
rewrap_header(header_bytes, &new_ser_config, Some(&old_deser_config)).unwrap();
let (_, new_metadata) = SafeTensors::read_metadata_header_only(&new_header).unwrap();
let new_meta = new_metadata.metadata().as_ref().unwrap();
let crypto_keys = new_meta.get("__crypto_keys__").unwrap();
assert!(crypto_keys.contains("new-enc"));
assert!(!crypto_keys.contains("old-enc"));
}
#[test]
fn test_rewrap_file() {
use crate::key::KeyMaterial;
use std::io::Read;
let old_enc_key =
KeyMaterial::new_enc_key(None, None, Some("old-enc-file".to_string()), None).unwrap();
let old_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("old-sign-file".to_string()), None)
.unwrap();
let new_enc_key =
KeyMaterial::new_enc_key(None, None, Some("new-enc-file".to_string()), None).unwrap();
let new_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("new-sign-file".to_string()), None)
.unwrap();
let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let tensor = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data,
};
let old_config =
SerializeCryptoConfig::with_keys(old_enc_key.clone(), old_sign_key.clone());
let temp_file = tempfile::NamedTempFile::new().unwrap();
let filename = temp_file.path().to_path_buf();
serialize_to_file([("data", tensor)], None, &filename, Some(&old_config)).unwrap();
let old_deser_config = DeserializeCryptoConfig::with_keys(old_enc_key, old_sign_key);
let new_ser_config =
SerializeCryptoConfig::with_keys(new_enc_key.clone(), new_sign_key.clone());
rewrap_file(&filename, &new_ser_config, Some(&old_deser_config)).unwrap();
let mut buffer = Vec::new();
std::fs::File::open(&filename)
.unwrap()
.read_to_end(&mut buffer)
.unwrap();
let new_deser_config = DeserializeCryptoConfig::with_keys(new_enc_key, new_sign_key);
let loaded =
SafeTensors::deserialize_with_config(&buffer, Some(&new_deser_config)).unwrap();
let loaded_tensor = loaded.tensor("data").unwrap();
assert_eq!(loaded_tensor.data(), data);
}
#[test]
fn test_rewrap_unencrypted_fails() {
use crate::key::KeyMaterial;
let data: Vec<u8> = vec![0u8; 16];
let tensor = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data,
};
let buffer = serialize([("test", tensor)], None, None).unwrap();
let new_enc_key =
KeyMaterial::new_enc_key(None, None, Some("new-enc".to_string()), None).unwrap();
let new_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("new-sign".to_string()), None)
.unwrap();
let new_ser_config = SerializeCryptoConfig::with_keys(new_enc_key, new_sign_key);
let result = rewrap(&buffer, &new_ser_config, None);
assert!(result.is_err());
match result {
Err(SafeTensorError::CryptoTensorsError(msg)) => {
assert!(msg.contains("not encrypted"));
}
other => panic!("Expected CryptoTensorsError, got {:?}", other),
}
}
static INDEPENDENT_PATHS_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(());
#[test]
fn test_independent_paths_serialize_enc_direct_sign_registry() {
use crate::key::KeyMaterial;
use crate::registry::{
clear_providers, disable_provider, register_provider_with_priority, DirectKeyProvider,
PRIORITY_DIRECT,
};
let _guard = INDEPENDENT_PATHS_MUTEX.lock().unwrap();
clear_providers();
let enc_key =
KeyMaterial::new_enc_key(None, None, Some("indep-enc".to_string()), None).unwrap();
let sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("indep-sign".to_string()), None)
.unwrap();
let enc_jwk: serde_json::Value = serde_json::from_str(&enc_key.to_jwk().unwrap()).unwrap();
let sign_jwk: serde_json::Value =
serde_json::from_str(&sign_key.to_jwk().unwrap()).unwrap();
let provider = DirectKeyProvider::from_single_keys(enc_jwk, sign_jwk);
register_provider_with_priority(Box::new(provider), PRIORITY_DIRECT);
let mut config = SerializeCryptoConfig::new();
config.enc_key = Some(enc_key);
config.sign_kid = Some("indep-sign".to_string());
let data: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let tensor = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data,
};
let buffer = serialize([("w", tensor)], None, Some(&config)).unwrap();
let (_, meta) = SafeTensors::read_metadata(&buffer).unwrap();
assert!(meta
.metadata()
.as_ref()
.unwrap()
.contains_key("__encryption__"));
disable_provider("DirectKeyProvider");
}
#[test]
fn test_independent_paths_serialize_enc_registry_sign_direct() {
use crate::key::KeyMaterial;
use crate::registry::{
clear_providers, disable_provider, register_provider_with_priority, DirectKeyProvider,
PRIORITY_DIRECT,
};
let _guard = INDEPENDENT_PATHS_MUTEX.lock().unwrap();
clear_providers();
let enc_key =
KeyMaterial::new_enc_key(None, None, Some("indep2-enc".to_string()), None).unwrap();
let sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("indep2-sign".to_string()), None)
.unwrap();
let enc_jwk: serde_json::Value = serde_json::from_str(&enc_key.to_jwk().unwrap()).unwrap();
let sign_jwk: serde_json::Value =
serde_json::from_str(&sign_key.to_jwk().unwrap()).unwrap();
let provider = DirectKeyProvider::from_single_keys(enc_jwk, sign_jwk);
register_provider_with_priority(Box::new(provider), PRIORITY_DIRECT);
let mut config = SerializeCryptoConfig::new();
config.enc_kid = Some("indep2-enc".to_string());
config.sign_key = Some(sign_key);
let data: Vec<u8> = vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17];
let tensor = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data,
};
let buffer = serialize([("w", tensor)], None, Some(&config)).unwrap();
let (_, meta) = SafeTensors::read_metadata(&buffer).unwrap();
assert!(meta
.metadata()
.as_ref()
.unwrap()
.contains_key("__encryption__"));
disable_provider("DirectKeyProvider");
}
#[test]
fn test_independent_paths_deserialize_enc_direct_sign_registry() {
use crate::key::KeyMaterial;
use crate::registry::{
clear_providers, disable_provider, register_provider_with_priority, DirectKeyProvider,
PRIORITY_DIRECT,
};
let _guard = INDEPENDENT_PATHS_MUTEX.lock().unwrap();
clear_providers();
let enc_key =
KeyMaterial::new_enc_key(None, None, Some("d-enc".to_string()), None).unwrap();
let sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("d-sign".to_string()), None).unwrap();
let enc_jwk: serde_json::Value = serde_json::from_str(&enc_key.to_jwk().unwrap()).unwrap();
let sign_jwk: serde_json::Value =
serde_json::from_str(&sign_key.to_jwk().unwrap()).unwrap();
let provider = DirectKeyProvider::from_single_keys(enc_jwk, sign_jwk);
register_provider_with_priority(Box::new(provider), PRIORITY_DIRECT);
let ser_config = SerializeCryptoConfig::with_keys(enc_key.clone(), sign_key);
let data: Vec<u8> = vec![5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20];
let tensor = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data,
};
let buffer = serialize([("x", tensor)], None, Some(&ser_config)).unwrap();
let mut deser_config = DeserializeCryptoConfig::new();
deser_config.enc_key = Some(enc_key);
let loaded = SafeTensors::deserialize_with_config(&buffer, Some(&deser_config)).unwrap();
let t = loaded.tensor("x").unwrap();
assert_eq!(t.shape(), &[2, 2]);
assert_eq!(t.data(), data);
disable_provider("DirectKeyProvider");
}
#[test]
fn test_independent_paths_deserialize_enc_registry_sign_direct() {
use crate::key::KeyMaterial;
use crate::registry::{
clear_providers, disable_provider, register_provider_with_priority, DirectKeyProvider,
PRIORITY_DIRECT,
};
let _guard = INDEPENDENT_PATHS_MUTEX.lock().unwrap();
clear_providers();
let enc_key =
KeyMaterial::new_enc_key(None, None, Some("d2-enc".to_string()), None).unwrap();
let sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("d2-sign".to_string()), None).unwrap();
let enc_jwk: serde_json::Value = serde_json::from_str(&enc_key.to_jwk().unwrap()).unwrap();
let sign_jwk: serde_json::Value =
serde_json::from_str(&sign_key.to_jwk().unwrap()).unwrap();
let provider = DirectKeyProvider::from_single_keys(enc_jwk, sign_jwk);
register_provider_with_priority(Box::new(provider), PRIORITY_DIRECT);
let ser_config = SerializeCryptoConfig::with_keys(enc_key, sign_key.clone());
let data: Vec<u8> = vec![
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
];
let tensor = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data,
};
let buffer = serialize([("y", tensor)], None, Some(&ser_config)).unwrap();
let mut deser_config = DeserializeCryptoConfig::new();
deser_config.sign_key = Some(sign_key);
let loaded = SafeTensors::deserialize_with_config(&buffer, Some(&deser_config)).unwrap();
let t = loaded.tensor("y").unwrap();
assert_eq!(t.shape(), &[2, 2]);
assert_eq!(t.data(), data);
disable_provider("DirectKeyProvider");
}
#[test]
fn test_partial_encryption_roundtrip() {
use crate::key::KeyMaterial;
let enc_key =
KeyMaterial::new_enc_key(None, None, Some("partial-enc".to_string()), None).unwrap();
let sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("partial-sign".to_string()), None)
.unwrap();
let data_secret: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let data_public: Vec<u8> = vec![
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
];
let tensor_secret = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data_secret,
};
let tensor_public = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data_public,
};
let config = SerializeCryptoConfig::with_keys(enc_key.clone(), sign_key.clone())
.tensor_filter(vec!["secret".to_string()]);
let buffer = serialize(
[("secret", tensor_secret), ("public", tensor_public)],
None,
Some(&config),
)
.unwrap();
let (_, metadata) = SafeTensors::read_metadata(&buffer).unwrap();
let meta = metadata.metadata().as_ref().unwrap();
let encryption_info = meta.get("__encryption__").unwrap();
assert!(
encryption_info.contains("secret"),
"secret should be in encryption info"
);
assert!(
!encryption_info.contains("public"),
"public should NOT be in encryption info"
);
let deser_config = DeserializeCryptoConfig::with_keys(enc_key, sign_key);
let loaded = SafeTensors::deserialize_with_config(&buffer, Some(&deser_config)).unwrap();
let loaded_secret = loaded.tensor("secret").unwrap();
assert_eq!(loaded_secret.shape(), &[2, 2]);
assert_eq!(loaded_secret.data(), data_secret);
let loaded_public = loaded.tensor("public").unwrap();
assert_eq!(loaded_public.shape(), &[2, 2]);
assert_eq!(loaded_public.data(), data_public);
}
#[test]
fn test_partial_encryption_rewrap() {
use crate::key::KeyMaterial;
let old_enc_key =
KeyMaterial::new_enc_key(None, None, Some("old-part-enc".to_string()), None).unwrap();
let old_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("old-part-sign".to_string()), None)
.unwrap();
let new_enc_key =
KeyMaterial::new_enc_key(None, None, Some("new-part-enc".to_string()), None).unwrap();
let new_sign_key =
KeyMaterial::new_sign_key(None, None, None, Some("new-part-sign".to_string()), None)
.unwrap();
let data_secret: Vec<u8> = vec![
10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160,
];
let data_public: Vec<u8> = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let tensor_secret = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data_secret,
};
let tensor_public = TensorView {
dtype: Dtype::F32,
shape: vec![2, 2],
data: &data_public,
};
let old_config =
SerializeCryptoConfig::with_keys(old_enc_key.clone(), old_sign_key.clone())
.tensor_filter(vec!["secret".to_string()]);
let buffer = serialize(
[("secret", tensor_secret), ("public", tensor_public)],
None,
Some(&old_config),
)
.unwrap();
let old_deser_config = DeserializeCryptoConfig::with_keys(old_enc_key, old_sign_key);
let new_ser_config =
SerializeCryptoConfig::with_keys(new_enc_key.clone(), new_sign_key.clone());
let new_buffer = rewrap(&buffer, &new_ser_config, Some(&old_deser_config)).unwrap();
let (_, new_metadata) = SafeTensors::read_metadata(&new_buffer).unwrap();
let new_meta = new_metadata.metadata().as_ref().unwrap();
let crypto_keys = new_meta.get("__crypto_keys__").unwrap();
assert!(crypto_keys.contains("new-part-enc"));
assert!(!crypto_keys.contains("old-part-enc"));
let encryption_info = new_meta.get("__encryption__").unwrap();
assert!(encryption_info.contains("secret"));
assert!(!encryption_info.contains("public"));
let new_deser_config = DeserializeCryptoConfig::with_keys(new_enc_key, new_sign_key);
let loaded =
SafeTensors::deserialize_with_config(&new_buffer, Some(&new_deser_config)).unwrap();
let loaded_secret = loaded.tensor("secret").unwrap();
assert_eq!(loaded_secret.data(), data_secret);
let loaded_public = loaded.tensor("public").unwrap();
assert_eq!(loaded_public.data(), data_public);
}
}