#![allow(dead_code)]
/// Element data type of a tensor entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StDtype {
    F32,
    F16,
    Bf16,
    F64,
    I64,
    I32,
    I16,
    I8,
    Bool,
}

impl StDtype {
    /// Size in bytes of a single element of this type.
    pub fn byte_width(&self) -> usize {
        match self {
            StDtype::F64 | StDtype::I64 => 8,
            StDtype::F32 | StDtype::I32 => 4,
            StDtype::F16 | StDtype::Bf16 | StDtype::I16 => 2,
            // `Bool` is stored as one full byte per element.
            StDtype::I8 | StDtype::Bool => 1,
        }
    }
}

/// One tensor's entry: its name, element type, shape, and the byte range its
/// data occupies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StTensorEntry {
    /// Tensor name, used as the lookup key.
    pub name: String,
    /// Element data type.
    pub dtype: StDtype,
    /// Dimension sizes; an empty vector denotes a scalar.
    pub shape: Vec<u64>,
    /// `(start, end)` byte offsets of the tensor data — presumably relative to
    /// the start of the data buffer, as in the safetensors format; confirm
    /// against the writer.
    pub data_offsets: (u64, u64),
}

impl StTensorEntry {
    /// Number of scalar elements described by `shape`.
    ///
    /// An empty shape is a scalar and yields `1`; any zero-length dimension
    /// yields `0`. The product saturates at `u64::MAX` instead of overflowing
    /// (which would panic in debug builds and silently wrap in release).
    pub fn element_count(&self) -> u64 {
        // The empty product is already 1, so scalars need no special case, and
        // a zero dimension must propagate so empty tensors report no elements.
        // `saturating_mul(0)` is 0, so a zero after saturation still wins.
        self.shape
            .iter()
            .fold(1u64, |acc, &dim| acc.saturating_mul(dim))
    }
}
/// In-memory description of a safetensors export: its tensor entries plus
/// free-form string metadata.
#[derive(Default, Clone, Debug)]
pub struct SafeTensorsExport {
    /// Tensor entries; `add_st_tensor` appends to the end.
    pub tensors: Vec<StTensorEntry>,
    /// `(key, value)` metadata pairs.
    pub metadata: Vec<(String, String)>,
}
pub fn new_safetensors_export() -> SafeTensorsExport {
SafeTensorsExport::default()
}
/// Appends `entry` to the end of the export's tensor list.
///
/// No duplicate-name or offset checks are performed here.
pub fn add_st_tensor(export: &mut SafeTensorsExport, entry: StTensorEntry) {
    let tensors = &mut export.tensors;
    tensors.push(entry);
}
/// Sets metadata `key` to `value`.
///
/// Metadata keys must be unique: the safetensors `__metadata__` entry is a
/// JSON object of string keys to string values, so duplicate keys would be
/// ambiguous. If `key` is already present, its value is replaced in place and
/// the pair keeps its original position. Otherwise a new pair is appended.
pub fn add_st_metadata(export: &mut SafeTensorsExport, key: &str, value: &str) {
    match export.metadata.iter_mut().find(|(k, _)| k.as_str() == key) {
        Some((_, existing)) => *existing = value.to_owned(),
        None => export.metadata.push((key.to_owned(), value.to_owned())),
    }
}
/// Looks up a tensor by exact name.
///
/// Returns the first entry (in list order) whose name equals `name`, or
/// `None` when no entry matches.
pub fn find_st_tensor<'a>(export: &'a SafeTensorsExport, name: &str) -> Option<&'a StTensorEntry> {
    let candidates = export.tensors.iter();
    candidates
        .into_iter()
        .find(|candidate| name == candidate.name)
}
pub fn st_tensor_count(export: &SafeTensorsExport) -> usize {
export.tensors.len()
}
/// Checks that `export` describes a consistent safetensors payload.
///
/// Returns `true` only when all of the following hold:
/// * there is at least one tensor;
/// * tensor names are unique, and none uses the reserved `__metadata__` key
///   (the safetensors header is a single JSON object keyed by tensor name);
/// * every tensor's `data_offsets` satisfy `start <= end`; and
/// * `end - start` equals the byte size implied by the tensor's shape and
///   dtype width. A size that overflows `u64` counts as invalid.
///
/// Overlaps or gaps between different tensors' byte ranges are NOT checked.
pub fn validate_safetensors(export: &SafeTensorsExport) -> bool {
    use std::collections::HashSet;

    if export.tensors.is_empty() {
        return false;
    }

    let mut seen_names: HashSet<&str> = HashSet::with_capacity(export.tensors.len());
    export.tensors.iter().all(|tensor| {
        let (start, end) = tensor.data_offsets;
        // Byte length implied by shape × dtype width, or `None` on overflow.
        // An empty shape is a scalar (the empty product is 1), and a zero
        // dimension makes the tensor legitimately empty.
        let expected_len = tensor
            .shape
            .iter()
            .try_fold(1u64, |acc, &dim| acc.checked_mul(dim))
            .and_then(|count| count.checked_mul(tensor.dtype.byte_width() as u64));

        // `&&` short-circuits, so `end - start` is only evaluated once
        // `start <= end` is known and therefore cannot underflow.
        tensor.name != "__metadata__"
            && seen_names.insert(tensor.name.as_str())
            && start <= end
            && expected_len == Some(end - start)
    })
}
/// Rough upper-bound estimate, in bytes, of the serialized export.
///
/// The estimate is the raw tensor payload (elements × dtype width per tensor)
/// plus a heuristic header size. Each tensor contributes its name length plus
/// a fixed overhead, each metadata pair contributes its key and value lengths
/// plus a fixed overhead, and a fixed 8-byte prefix is added.
pub fn st_data_size_estimate(export: &SafeTensorsExport) -> usize {
    // Fixed prefix — presumably the 8-byte header-length field of the
    // safetensors format.
    const HEADER_LEN_PREFIX: usize = 8;
    // Heuristic allowance per tensor for the header fields besides its name.
    const TENSOR_HEADER_OVERHEAD: usize = 64;
    // Heuristic allowance per metadata pair beyond the key/value text.
    const METADATA_PAIR_OVERHEAD: usize = 4;

    let mut payload = 0usize;
    let mut header = HEADER_LEN_PREFIX;

    for tensor in &export.tensors {
        payload += tensor.element_count() as usize * tensor.dtype.byte_width();
        header += tensor.name.len() + TENSOR_HEADER_OVERHEAD;
    }
    for (key, value) in &export.metadata {
        header += key.len() + value.len() + METADATA_PAIR_OVERHEAD;
    }

    payload + header
}
/// Returns a compact JSON summary of the export, e.g.
/// `{"tensors":1,"metadata_keys":1}`.
///
/// It reports only the number of tensor entries and the number of metadata
/// entries. It is NOT the full safetensors header: no names, dtypes, shapes or
/// offsets are emitted.
pub fn st_header_json(export: &SafeTensorsExport) -> String {
    let tensor_total = export.tensors.len();
    let metadata_total = export.metadata.len();
    format!(
        r#"{{"tensors":{t},"metadata_keys":{m}}}"#,
        t = tensor_total,
        m = metadata_total
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One F16 embedding tensor with consistent byte offsets, plus one
    /// metadata pair.
    fn sample_export() -> SafeTensorsExport {
        let mut e = new_safetensors_export();
        add_st_tensor(
            &mut e,
            StTensorEntry {
                name: "model.embed.weight".into(),
                dtype: StDtype::F16,
                shape: vec![32000, 4096],
                data_offsets: (0, 32000 * 4096 * 2),
            },
        );
        add_st_metadata(&mut e, "format", "pt");
        e
    }

    #[test]
    fn tensor_count() {
        let e = sample_export();
        assert_eq!(st_tensor_count(&e), 1);
    }

    #[test]
    fn validate_complete() {
        let e = sample_export();
        assert!(validate_safetensors(&e));
    }

    #[test]
    fn validate_empty_false() {
        let e = new_safetensors_export();
        assert!(!validate_safetensors(&e));
    }

    #[test]
    fn find_tensor_found() {
        let e = sample_export();
        assert!(find_st_tensor(&e, "model.embed.weight").is_some());
    }

    #[test]
    fn find_tensor_missing() {
        let e = sample_export();
        assert!(find_st_tensor(&e, "nonexistent").is_none());
    }

    #[test]
    fn find_tensor_returns_matching_entry() {
        let e = sample_export();
        let t = find_st_tensor(&e, "model.embed.weight").expect("sample tensor is present");
        assert_eq!(t.dtype, StDtype::F16);
        assert_eq!(t.shape, vec![32000, 4096]);
    }

    #[test]
    fn dtype_byte_width_f32() {
        assert_eq!(StDtype::F32.byte_width(), 4);
    }

    #[test]
    fn dtype_byte_width_f16() {
        assert_eq!(StDtype::F16.byte_width(), 2);
    }

    #[test]
    fn dtype_byte_width_all() {
        let expected = [
            (StDtype::F64, 8),
            (StDtype::I64, 8),
            (StDtype::F32, 4),
            (StDtype::I32, 4),
            (StDtype::F16, 2),
            (StDtype::Bf16, 2),
            (StDtype::I16, 2),
            (StDtype::I8, 1),
            (StDtype::Bool, 1),
        ];
        for (dtype, width) in expected.iter() {
            assert_eq!(dtype.byte_width(), *width, "width of {:?}", dtype);
        }
    }

    #[test]
    fn scalar_element_count_is_one() {
        let scalar = StTensorEntry {
            name: "scalar".into(),
            dtype: StDtype::F32,
            shape: vec![],
            data_offsets: (0, 4),
        };
        assert_eq!(scalar.element_count(), 1);
    }

    #[test]
    fn metadata_is_recorded() {
        let e = sample_export();
        assert_eq!(e.metadata, vec![("format".to_string(), "pt".to_string())]);
    }

    #[test]
    fn size_estimate_positive() {
        let e = sample_export();
        assert!(st_data_size_estimate(&e) > 0);
    }

    #[test]
    fn size_estimate_exact() {
        // Payload: 32000 * 4096 F16 elements. Header: (18-byte name + 64)
        // + ("format" + "pt" + 4) + 8-byte prefix.
        let e = sample_export();
        assert_eq!(
            st_data_size_estimate(&e),
            32000 * 4096 * 2 + (18 + 64) + (6 + 2 + 4) + 8
        );
    }

    #[test]
    fn size_estimate_empty_is_prefix_only() {
        assert_eq!(st_data_size_estimate(&new_safetensors_export()), 8);
    }

    #[test]
    fn header_json_has_tensors() {
        let e = sample_export();
        let json = st_header_json(&e);
        assert!(json.contains("tensors"));
    }

    #[test]
    fn header_json_exact() {
        let e = sample_export();
        assert_eq!(st_header_json(&e), "{\"tensors\":1,\"metadata_keys\":1}");
    }
}