use safetensors::SafeTensors;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Shape and dtype metadata for a single tensor in a safetensors file.
#[derive(Debug, Clone)]
pub struct TensorInfo {
/// Dimensions in row-major order, as recorded in the safetensors header.
pub shape: Vec<usize>,
/// On-disk element type (e.g. `F32`, `BF16`).
pub dtype: safetensors::Dtype,
}
/// A safetensors model held fully in memory.
///
/// The raw file bytes are kept in `data`; because `SafeTensors` views borrow
/// from the buffer they wrap, tensor accessors re-deserialize from `data` on
/// each call rather than storing a self-referential parsed view.
pub struct SafeTensorsModel {
// Raw bytes of the .safetensors file, exactly as read from disk.
data: Vec<u8>,
// Per-tensor shape/dtype metadata, keyed by tensor name.
info: HashMap<String, TensorInfo>,
}
impl SafeTensorsModel {
/// Downloads `model.safetensors` from the given HuggingFace Hub repository
/// (or reuses the local hub cache) and loads it.
///
/// Convenience wrapper around [`Self::download_file`] using the conventional
/// single-file checkpoint name.
pub fn download(repo_id: &str) -> Result<Self, Box<dyn std::error::Error>> {
Self::download_file(repo_id, "model.safetensors")
}
/// Fetches `filename` from the HuggingFace Hub repository `repo_id` (using
/// the local hub cache when available), then loads the resulting
/// safetensors file into memory.
///
/// # Errors
/// Fails if the hub API cannot be initialized, the download fails, or the
/// file is not valid safetensors.
pub fn download_file(
    repo_id: &str,
    filename: &str,
) -> Result<Self, Box<dyn std::error::Error>> {
    log::info!(
        "downloading {}/{} from HuggingFace Hub...",
        repo_id,
        filename
    );
    // The hub client resolves to a cached local path, downloading on miss.
    let cached = hf_hub::api::sync::Api::new()?
        .model(repo_id.to_string())
        .get(filename)?;
    log::info!("cached at: {}", cached.display());
    Self::load(cached)
}
pub fn load(path: PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let data = std::fs::read(&path)?;
let tensors = SafeTensors::deserialize(&data)?;
let mut info = HashMap::new();
for (name, view) in tensors.iter() {
info.insert(
name.to_string(),
TensorInfo {
shape: view.shape().to_vec(),
dtype: view.dtype(),
},
);
}
log::info!("loaded {} tensors from {}", info.len(), path.display());
Ok(Self { data, info })
}
/// Returns shape/dtype metadata for every tensor, keyed by tensor name.
pub fn tensor_info(&self) -> &HashMap<String, TensorInfo> {
&self.info
}
/// Decodes the named tensor as a flat `Vec<f32>` in row-major order.
///
/// # Errors
/// Fails if the tensor does not exist or its stored dtype is not `F32`.
pub fn tensor_f32(&self, name: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let tensors = SafeTensors::deserialize(&self.data)?;
    let view = tensors
        .tensor(name)
        .map_err(|e| format!("tensor '{}': {}", name, e))?;
    match view.dtype() {
        // safetensors stores values little-endian; decode 4 bytes per f32.
        safetensors::Dtype::F32 => Ok(view
            .data()
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect()),
        other => {
            Err(format!("tensor '{}' has dtype {:?}, expected F32", name, other).into())
        }
    }
}
/// Decodes the named tensor as a flat `Vec<f32>`, widening `BF16` values to
/// `f32` on the fly. Supported source dtypes: `F32` and `BF16`.
///
/// # Errors
/// Fails if the tensor does not exist or has any other dtype.
pub fn tensor_f32_auto(&self, name: &str) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    // bf16 is the top 16 bits of an IEEE-754 f32, so widening is a shift.
    fn bf16_to_f32(lo: u8, hi: u8) -> f32 {
        let bits = u16::from_le_bytes([lo, hi]);
        f32::from_bits((bits as u32) << 16)
    }
    let tensors = SafeTensors::deserialize(&self.data)?;
    let view = tensors
        .tensor(name)
        .map_err(|e| format!("tensor '{}': {}", name, e))?;
    let raw = view.data();
    match view.dtype() {
        safetensors::Dtype::F32 => Ok(raw
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect()),
        safetensors::Dtype::BF16 => Ok(raw
            .chunks_exact(2)
            .map(|b| bf16_to_f32(b[0], b[1]))
            .collect()),
        other => Err(format!(
            "tensor '{}' has dtype {:?}, expected F32 or BF16",
            name, other
        )
        .into()),
    }
}
/// Like [`Self::tensor_f32_auto`], but returns the 2-D tensor transposed:
/// source element `(r, c)` lands at index `c * rows + r` in the output.
///
/// # Errors
/// Fails if the tensor is missing, is not 2-dimensional, or cannot be
/// decoded as `F32`/`BF16`.
pub fn tensor_f32_auto_transposed(
    &self,
    name: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let info = self
        .info
        .get(name)
        .ok_or_else(|| format!("tensor '{}' not found", name))?;
    let (rows, cols) = match info.shape.as_slice() {
        &[r, c] => (r, c),
        dims => {
            return Err(format!(
                "tensor '{}' has {} dims, expected 2 for transpose",
                name,
                dims.len()
            )
            .into())
        }
    };
    let data = self.tensor_f32_auto(name)?;
    let src = data.as_slice();
    // Walking the row-major source column-by-column yields the transpose.
    let transposed: Vec<f32> = (0..cols)
        .flat_map(|c| (0..rows).map(move |r| src[r * cols + c]))
        .collect();
    Ok(transposed)
}
/// Like [`Self::tensor_f32`], but returns the 2-D tensor transposed:
/// source element `(r, c)` lands at index `c * rows + r` in the output.
///
/// # Errors
/// Fails if the tensor is missing, is not 2-dimensional, or is not `F32`.
pub fn tensor_f32_transposed(
    &self,
    name: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error>> {
    let info = self
        .info
        .get(name)
        .ok_or_else(|| format!("tensor '{}' not found", name))?;
    if info.shape.len() != 2 {
        return Err(format!(
            "tensor '{}' has {} dims, expected 2 for transpose",
            name,
            info.shape.len()
        )
        .into());
    }
    let (rows, cols) = (info.shape[0], info.shape[1]);
    let data = self.tensor_f32(name)?;
    let mut transposed = vec![0.0_f32; rows * cols];
    // Scatter each row-major row into its column of the output.
    for r in 0..rows {
        let row = &data[r * cols..(r + 1) * cols];
        for (c, &value) in row.iter().enumerate() {
            transposed[c * rows + r] = value;
        }
    }
    Ok(transposed)
}
}