#[cfg(feature = "safetensors-compare")]
use safetensors::SafeTensors;
use super::WeightDiff;
/// Errors reported by the safetensors loading and comparison helpers.
#[derive(Debug)]
pub enum SafetensorsError {
    /// A file was not found; the payload is shown after `File not found: `.
    FileNotFound(String),
    /// Parsing failed; the payload is a human-readable reason.
    ParseError(String),
    /// A tensor lookup failed; the payload is the requested tensor name.
    TensorNotFound(String),
    /// A download failed; the payload is a human-readable reason.
    DownloadError(String),
    /// An underlying I/O error.
    IoError(std::io::Error),
}

impl std::fmt::Display for SafetensorsError {
    /// Renders the error as `<label>: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fixed, variant-specific prefix.
        let label = match self {
            Self::FileNotFound(_) => "File not found",
            Self::ParseError(_) => "Parse error",
            Self::TensorNotFound(_) => "Tensor not found",
            Self::DownloadError(_) => "Download error",
            Self::IoError(_) => "IO error",
        };
        // Variant payload, displayed verbatim after the prefix.
        let detail: &dyn std::fmt::Display = match self {
            Self::FileNotFound(text)
            | Self::ParseError(text)
            | Self::TensorNotFound(text)
            | Self::DownloadError(text) => text,
            Self::IoError(source) => source,
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for SafetensorsError {}

impl From<std::io::Error> for SafetensorsError {
    /// Wraps an I/O error so `?` can be used on `std::io` results.
    fn from(source: std::io::Error) -> Self {
        SafetensorsError::IoError(source)
    }
}

/// Result alias whose error type is [`SafetensorsError`].
pub type Result<T> = std::result::Result<T, SafetensorsError>;
/// A tensor held as `f32` values together with its descriptive metadata.
#[derive(Debug, Clone)]
pub struct TensorData {
    /// Tensor name.
    pub name: String,
    /// Dimensions of the tensor; see [`TensorData::numel`].
    pub shape: Vec<usize>,
    /// Element values as `f32`.
    pub data: Vec<f32>,
    /// Textual label of the element type.
    pub dtype: String,
}

impl TensorData {
    /// Returns the product of all dimensions in `shape`.
    ///
    /// An empty `shape` yields `1` (the empty product).
    #[must_use]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns `true` when `data` holds no values.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the Euclidean (L2) norm of `data`.
    ///
    /// Squares are accumulated in `f64` so that small contributions are not
    /// absorbed by a large running sum and so that squaring large values does
    /// not overflow `f32` prematurely; the result is narrowed once at the end.
    #[must_use]
    // Narrowing the final f64 result back to the public f32 return type is intentional.
    #[allow(clippy::cast_possible_truncation)]
    pub fn l2_norm(&self) -> f32 {
        let sum_of_squares: f64 = self
            .data
            .iter()
            .map(|&value| {
                let wide = f64::from(value);
                wide * wide
            })
            .sum();
        sum_of_squares.sqrt() as f32
    }

    /// Returns the arithmetic mean of `data`, or `0.0` when `data` is empty.
    ///
    /// The sum is accumulated in `f64` to limit rounding error on long
    /// tensors; the result is narrowed once at the end.
    #[must_use]
    // The element count is converted to f64 and the result narrowed to f32 on purpose.
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    pub fn mean(&self) -> f32 {
        if self.data.is_empty() {
            return 0.0;
        }
        let total: f64 = self.data.iter().copied().map(f64::from).sum();
        (total / self.data.len() as f64) as f32
    }
}
/// Outcome of comparing one reference tensor against a flat `f32` slice.
#[derive(Debug, Clone)]
pub struct TensorComparison {
    /// Name under which the comparison was made.
    pub name: String,
    /// Whether the reference element count equalled the candidate length.
    pub shape_match: bool,
    /// Shape of the reference tensor.
    pub shape_a: Vec<usize>,
    /// Candidate "shape": a single dimension holding the slice length.
    pub shape_b: Vec<usize>,
    /// Difference statistics; `None` when the element counts differed.
    pub weight_diff: Option<WeightDiff>,
    /// Whether the maximum difference was strictly below the threshold
    /// supplied to [`TensorComparison::compare`].
    pub passes_threshold: bool,
}

impl TensorComparison {
    /// Compares reference tensor `a` with candidate values `b`.
    ///
    /// The two sides are considered shape-compatible when `a.numel()` equals
    /// `b.len()`; only then is a [`WeightDiff`] computed. The comparison
    /// passes when that diff's `max_diff` is strictly less than `threshold`.
    // NOTE(review): the compatibility check uses `a.numel()` while the diff
    // reads `a.data`; this assumes `a.data.len() == a.numel()` — confirm how
    // `WeightDiff::from_slices` treats slices of unequal length.
    #[must_use]
    pub fn compare(name: &str, a: &TensorData, b: &[f32], threshold: f64) -> Self {
        let candidate_len = b.len();
        let shape_match = a.numel() == candidate_len;
        let weight_diff = shape_match.then(|| WeightDiff::from_slices(&a.data, b));
        let passes_threshold = match &weight_diff {
            Some(diff) => diff.max_diff < threshold,
            None => false,
        };

        Self {
            name: name.to_owned(),
            shape_match,
            shape_a: a.shape.clone(),
            shape_b: vec![candidate_len],
            weight_diff,
            passes_threshold,
        }
    }

    /// Re-evaluates the stored diff against a different `threshold`.
    ///
    /// Returns `true` only when the shapes matched and the recorded
    /// `max_diff` is strictly below `threshold`.
    #[must_use]
    pub fn is_close(&self, threshold: f64) -> bool {
        if !self.shape_match {
            return false;
        }
        matches!(&self.weight_diff, Some(diff) if diff.max_diff < threshold)
    }
}
/// An in-memory safetensors file plus the tensor names found in it.
#[cfg(feature = "safetensors-compare")]
#[derive(Debug)]
pub struct HfSafetensors {
    /// Raw bytes of the safetensors file; parsed again whenever tensors are read.
    data: Vec<u8>,
    /// Tensor names collected when the bytes were first parsed.
    tensor_names: Vec<String>,
}

#[cfg(feature = "safetensors-compare")]
impl HfSafetensors {
    /// Reads `path` fully into memory and parses it as safetensors.
    ///
    /// # Errors
    /// Returns [`SafetensorsError::IoError`] when the file cannot be read and
    /// [`SafetensorsError::ParseError`] when the contents do not deserialize.
    pub fn from_file(path: &std::path::Path) -> Result<Self> {
        let data = std::fs::read(path)?;
        Self::from_bytes(data)
    }

    /// Parses `data` as safetensors, takes ownership of the bytes and records
    /// the tensor names.
    ///
    /// # Errors
    /// Returns [`SafetensorsError::ParseError`] when deserialization fails.
    pub fn from_bytes(data: Vec<u8>) -> Result<Self> {
        let tensors = Self::parse_bytes(&data)?;
        let tensor_names: Vec<String> = tensors.names().into_iter().map(String::from).collect();
        Ok(Self { data, tensor_names })
    }

    /// Fetches a safetensors file for `repo_id` through the `hf_hub` sync API
    /// and loads it.
    ///
    /// `model.safetensors` is requested first, then
    /// `pytorch_model.safetensors` as a fallback.
    ///
    /// # Errors
    /// Returns [`SafetensorsError::DownloadError`] when the API client cannot
    /// be built or neither file can be obtained, plus any error from
    /// [`HfSafetensors::from_file`].
    #[cfg(feature = "hf-hub-integration")]
    pub fn from_hub(repo_id: &str) -> Result<Self> {
        use hf_hub::api::sync::ApiBuilder;
        let api = ApiBuilder::new()
            .build()
            .map_err(|e| SafetensorsError::DownloadError(e.to_string()))?;
        let repo = api.model(repo_id.to_string());
        let path = repo
            .get("model.safetensors")
            .or_else(|_| repo.get("pytorch_model.safetensors"))
            .map_err(|e| SafetensorsError::DownloadError(format!("No safetensors file: {e}")))?;
        Self::from_file(&path)
    }

    /// Returns the tensor names recorded at load time.
    #[must_use]
    pub fn tensor_names(&self) -> &[String] {
        &self.tensor_names
    }

    /// Loads the tensor called `name`, converting its elements to `f32`.
    ///
    /// # Errors
    /// Returns [`SafetensorsError::ParseError`] when the stored bytes fail to
    /// deserialize or the dtype is unsupported, and
    /// [`SafetensorsError::TensorNotFound`] when the lookup of `name` fails.
    pub fn tensor(&self, name: &str) -> Result<TensorData> {
        let tensors = Self::parse_bytes(&self.data)?;
        Self::tensor_from(&tensors, name)
    }

    /// Deserializes `bytes`, mapping any failure to `ParseError`.
    fn parse_bytes(bytes: &[u8]) -> Result<SafeTensors<'_>> {
        SafeTensors::deserialize(bytes).map_err(|e| SafetensorsError::ParseError(e.to_string()))
    }

    /// Extracts `name` from an already-parsed file as `f32` [`TensorData`].
    fn tensor_from(tensors: &SafeTensors<'_>, name: &str) -> Result<TensorData> {
        let view = tensors
            .tensor(name)
            .map_err(|_| SafetensorsError::TensorNotFound(name.to_string()))?;
        let shape: Vec<usize> = view.shape().to_vec();
        let dtype = format!("{:?}", view.dtype());
        let data = Self::convert_to_f32(view.data(), view.dtype())?;
        Ok(TensorData {
            name: name.to_string(),
            shape,
            data,
            dtype,
        })
    }

    /// Decodes little-endian `bytes` of the given `dtype` into `f32` values.
    ///
    /// `F32`, `F16` and `BF16` are supported. Trailing bytes that do not fill
    /// a whole element are ignored (`chunks_exact`).
    ///
    /// # Errors
    /// Returns [`SafetensorsError::ParseError`] for any other dtype.
    fn convert_to_f32(bytes: &[u8], dtype: safetensors::Dtype) -> Result<Vec<f32>> {
        use safetensors::Dtype;
        match dtype {
            Dtype::F32 => {
                let floats: Vec<f32> = bytes
                    .chunks_exact(4)
                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                    .collect();
                Ok(floats)
            }
            Dtype::F16 => {
                let floats: Vec<f32> = bytes
                    .chunks_exact(2)
                    .map(|chunk| {
                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
                        half::f16::from_bits(bits).to_f32()
                    })
                    .collect();
                Ok(floats)
            }
            Dtype::BF16 => {
                let floats: Vec<f32> = bytes
                    .chunks_exact(2)
                    .map(|chunk| {
                        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
                        half::bf16::from_bits(bits).to_f32()
                    })
                    .collect();
                Ok(floats)
            }
            _ => Err(SafetensorsError::ParseError(format!(
                "Unsupported dtype: {dtype:?}"
            ))),
        }
    }

    /// Compares every recorded tensor against the values supplied by
    /// `get_apr_tensor`.
    ///
    /// A tensor is skipped — best effort, without reporting an error — when it
    /// cannot be loaded from this file (for example an unsupported dtype) or
    /// when `get_apr_tensor` returns `None` for its name. `get_apr_tensor` is
    /// only invoked for tensors that loaded successfully.
    pub fn compare_all<F>(&self, get_apr_tensor: F, threshold: f64) -> Vec<TensorComparison>
    where
        F: Fn(&str) -> Option<Vec<f32>>,
    {
        // Parse the file once for the whole batch instead of once per tensor.
        // If parsing fails every individual load would fail too, so the
        // result is the same empty list.
        let Ok(tensors) = Self::parse_bytes(&self.data) else {
            return Vec::new();
        };

        self.tensor_names
            .iter()
            .filter_map(|name| {
                let hf_tensor = Self::tensor_from(&tensors, name).ok()?;
                let apr_data = get_apr_tensor(name)?;
                Some(TensorComparison::compare(
                    name, &hf_tensor, &apr_data, threshold,
                ))
            })
            .collect()
    }
}
/// Aggregate statistics over a list of [`TensorComparison`] results.
#[derive(Debug)]
pub struct BatchComparison {
    /// The individual comparisons, in the order they were supplied.
    pub comparisons: Vec<TensorComparison>,
    /// Number of comparisons in the batch.
    pub total_compared: usize,
    /// Number of comparisons whose `passes_threshold` flag is set.
    pub total_passed: usize,
    /// Number of comparisons whose `shape_match` flag is not set.
    pub shape_mismatches: usize,
    /// Name of the tensor with the largest `max_diff`, if any diff exists.
    pub worst_tensor: Option<String>,
    /// The largest `max_diff` seen, or `0.0` when no comparison has a diff.
    pub worst_diff: f64,
}

impl BatchComparison {
    /// Summarises `comparisons` into counts and the worst offender.
    ///
    /// The worst tensor is the one with the greatest `max_diff` among the
    /// comparisons that carry a diff. A NaN `max_diff` is ranked above every
    /// numeric value so that an undefined difference is never hidden behind a
    /// smaller, well-defined one.
    #[must_use]
    pub fn from_comparisons(comparisons: Vec<TensorComparison>) -> Self {
        let total_compared = comparisons.len();
        let total_passed = comparisons.iter().filter(|c| c.passes_threshold).count();
        let shape_mismatches = comparisons.iter().filter(|c| !c.shape_match).count();
        let (worst_tensor, worst_diff) = comparisons
            .iter()
            .filter_map(|c| c.weight_diff.as_ref().map(|d| (&c.name, d.max_diff)))
            .max_by(|a, b| {
                // `partial_cmp` alone is not a total order once NaN appears,
                // which makes `max_by` pick an arbitrary element. Order by
                // "is NaN" first (NaN greatest); the fallback to `Equal` is
                // then only reached when both sides are NaN.
                a.1.is_nan()
                    .cmp(&b.1.is_nan())
                    .then_with(|| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            })
            .map_or((None, 0.0), |(name, diff)| (Some(name.clone()), diff));
        Self {
            comparisons,
            total_compared,
            total_passed,
            shape_mismatches,
            worst_tensor,
            worst_diff,
        }
    }

    /// Returns `true` when every comparison passed and none had a shape
    /// mismatch.
    ///
    /// An empty batch (zero comparisons) also returns `true`.
    #[must_use]
    pub fn all_passed(&self) -> bool {
        self.total_passed == self.total_compared && self.shape_mismatches == 0
    }

    /// Formats a one-line, human-readable summary of the batch.
    ///
    /// The worst tensor is printed as `none` when no diff was recorded.
    #[must_use]
    pub fn summary(&self) -> String {
        format!(
            "Compared {} tensors: {} passed, {} shape mismatches, worst diff: {:.6} ({})",
            self.total_compared,
            self.total_passed,
            self.shape_mismatches,
            self.worst_diff,
            self.worst_tensor.as_deref().unwrap_or("none")
        )
    }
}
#[cfg(test)]
#[path = "safetensors_tests.rs"]
mod tests;