use std::collections::HashMap;
use std::hash::BuildHasher;
use std::path::Path;
use crate::error::AnamnesisError;
use crate::parse::npz::{NpzDtype, NpzTensor};
/// Maps an NPZ element type onto the safetensors dtype of the same name.
///
/// The match is exhaustive and one-to-one: each `NpzDtype` variant listed
/// here has an identically named `safetensors::Dtype` counterpart, so the
/// conversion cannot fail.
fn npz_dtype_to_safetensors(dtype: NpzDtype) -> safetensors::Dtype {
    match dtype {
        // Floating point.
        NpzDtype::F64 => safetensors::Dtype::F64,
        NpzDtype::F32 => safetensors::Dtype::F32,
        NpzDtype::F16 => safetensors::Dtype::F16,
        NpzDtype::BF16 => safetensors::Dtype::BF16,
        // Signed integers.
        NpzDtype::I64 => safetensors::Dtype::I64,
        NpzDtype::I32 => safetensors::Dtype::I32,
        NpzDtype::I16 => safetensors::Dtype::I16,
        NpzDtype::I8 => safetensors::Dtype::I8,
        // Unsigned integers.
        NpzDtype::U64 => safetensors::Dtype::U64,
        NpzDtype::U32 => safetensors::Dtype::U32,
        NpzDtype::U16 => safetensors::Dtype::U16,
        NpzDtype::U8 => safetensors::Dtype::U8,
        // Boolean.
        NpzDtype::Bool => safetensors::Dtype::BOOL,
    }
}
/// Builds safetensors views over every tensor in `tensors`, ordered by name.
///
/// The map entries are sorted by key so the resulting list does not depend on
/// the map's iteration order. Each view borrows the tensor's `data` buffer;
/// only the name and the shape vector are cloned.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] for the first tensor (in name order)
/// whose dtype/shape/data combination is rejected by `TensorView::new`.
fn sorted_tensor_views<S: BuildHasher>(
    tensors: &HashMap<String, NpzTensor, S>,
) -> crate::Result<Vec<(String, safetensors::tensor::TensorView<'_>)>> {
    // Sort the entries themselves rather than the keys: this avoids a second
    // map lookup per tensor and the unreachable "missing key" error path.
    let mut entries: Vec<(&String, &NpzTensor)> = tensors.iter().collect();
    // Keys of a map are unique, so an unstable sort cannot reorder equals.
    entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

    entries
        .into_iter()
        .map(|(name, tensor)| {
            let st_dtype = npz_dtype_to_safetensors(tensor.dtype);
            safetensors::tensor::TensorView::new(st_dtype, tensor.shape.clone(), &tensor.data)
                .map(|view| (name.clone(), view))
                .map_err(|e| AnamnesisError::Parse {
                    reason: format!("failed to create TensorView for `{name}`: {e}"),
                })
        })
        .collect()
}

/// Writes the given NPZ tensors to `output` as a safetensors file.
///
/// Tensors are passed to the serializer in ascending name order and no
/// metadata map is supplied.
///
/// # Errors
///
/// * [`AnamnesisError::Parse`] if a tensor view cannot be created, or if the
///   serializer fails with anything other than an I/O error.
/// * [`AnamnesisError::Io`] if the serializer reports an I/O error.
pub fn npz_to_safetensors<S: BuildHasher>(
    tensors: &HashMap<String, NpzTensor, S>,
    output: impl AsRef<Path>,
) -> crate::Result<()> {
    let views = sorted_tensor_views(tensors)?;
    safetensors::tensor::serialize_to_file(views, &None, output.as_ref()).map_err(
        // `SafeTensorError` belongs to the safetensors crate: only its I/O
        // variant gets dedicated handling, every other variant is reported as
        // a parse error, so the catch-all arm is intentional.
        #[allow(clippy::wildcard_enum_match_arm)]
        |e| match e {
            safetensors::SafeTensorError::IoError(io_err) => AnamnesisError::Io(io_err),
            other => AnamnesisError::Parse {
                reason: format!("failed to write safetensors file: {other}"),
            },
        },
    )
}

/// Serializes the given NPZ tensors into an in-memory safetensors buffer.
///
/// Tensors are passed to the serializer in ascending name order and no
/// metadata map is supplied.
///
/// # Errors
///
/// Returns [`AnamnesisError::Parse`] if a tensor view cannot be created or if
/// serialization fails.
pub fn npz_to_safetensors_bytes<S: BuildHasher>(
    tensors: &HashMap<String, NpzTensor, S>,
) -> crate::Result<Vec<u8>> {
    let views = sorted_tensor_views(tensors)?;
    safetensors::tensor::serialize(views, &None).map_err(|e| AnamnesisError::Parse {
        reason: format!("failed to serialize safetensors: {e}"),
    })
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds an `NpzTensor` from its parts; `data` is the raw element bytes.
    fn make_tensor(name: &str, dtype: NpzDtype, shape: Vec<usize>, data: Vec<u8>) -> NpzTensor {
        NpzTensor {
            name: name.to_owned(),
            shape,
            dtype,
            data,
        }
    }

    /// An empty map must still serialize to a buffer that parses back.
    #[test]
    fn empty_archive() {
        let tensors: HashMap<String, NpzTensor> = HashMap::new();
        let bytes = npz_to_safetensors_bytes(&tensors).unwrap();
        let parsed = safetensors::SafeTensors::deserialize(&bytes).unwrap();
        assert!(parsed.names().is_empty());
    }

    /// Dtype, shape and raw bytes survive a round trip for several dtypes.
    #[test]
    fn roundtrip_mixed_dtypes() {
        // 0.0, 1.0, 2.0, 3.0 as little-endian f32; `u16 -> f32` is lossless.
        let f32_data: Vec<u8> = (0..4u16)
            .flat_map(|i| f32::from(i).to_le_bytes())
            .collect();
        let i32_data: Vec<u8> = (0..2i32).flat_map(i32::to_le_bytes).collect();
        let u8_data: Vec<u8> = vec![1, 2, 3, 4];

        let mut map: HashMap<String, NpzTensor> = HashMap::new();
        map.insert(
            "w".to_owned(),
            make_tensor("w", NpzDtype::F32, vec![2, 2], f32_data.clone()),
        );
        map.insert(
            "idx".to_owned(),
            make_tensor("idx", NpzDtype::I32, vec![2], i32_data.clone()),
        );
        map.insert(
            "bytes".to_owned(),
            make_tensor("bytes", NpzDtype::U8, vec![4], u8_data.clone()),
        );

        let bytes = npz_to_safetensors_bytes(&map).unwrap();
        let parsed = safetensors::SafeTensors::deserialize(&bytes).unwrap();

        let w = parsed.tensor("w").unwrap();
        assert_eq!(w.dtype(), safetensors::Dtype::F32);
        assert_eq!(w.shape(), &[2, 2]);
        assert_eq!(w.data(), f32_data.as_slice());

        let idx = parsed.tensor("idx").unwrap();
        assert_eq!(idx.dtype(), safetensors::Dtype::I32);
        assert_eq!(idx.shape(), &[2]);
        assert_eq!(idx.data(), i32_data.as_slice());

        let b = parsed.tensor("bytes").unwrap();
        assert_eq!(b.dtype(), safetensors::Dtype::U8);
        assert_eq!(b.shape(), &[4]);
        assert_eq!(b.data(), u8_data.as_slice());
    }

    /// BF16 keeps its dtype tag and its bytes are passed through untouched.
    #[test]
    fn bf16_dtype_roundtrips() {
        let bf16_data: Vec<u8> = (0..4u16).flat_map(u16::to_le_bytes).collect();
        let mut map: HashMap<String, NpzTensor> = HashMap::new();
        map.insert(
            "w".to_owned(),
            make_tensor("w", NpzDtype::BF16, vec![2, 2], bf16_data.clone()),
        );

        let bytes = npz_to_safetensors_bytes(&map).unwrap();
        let parsed = safetensors::SafeTensors::deserialize(&bytes).unwrap();

        let w = parsed.tensor("w").unwrap();
        assert_eq!(w.dtype(), safetensors::Dtype::BF16);
        assert_eq!(w.data(), bf16_data.as_slice());
    }

    /// Two separately built maps with equal contents serialize identically.
    #[test]
    fn deterministic_output() {
        let mut map1: HashMap<String, NpzTensor> = HashMap::new();
        let mut map2: HashMap<String, NpzTensor> = HashMap::new();
        // Zipping with a `u32` counter avoids a lossy `usize as u32` cast.
        for (value, name) in (0u32..).zip(["alpha", "beta", "gamma"]) {
            let payload = value.to_le_bytes().to_vec();
            map1.insert(
                name.to_owned(),
                make_tensor(name, NpzDtype::U32, vec![1], payload.clone()),
            );
            map2.insert(
                name.to_owned(),
                make_tensor(name, NpzDtype::U32, vec![1], payload),
            );
        }

        let b1 = npz_to_safetensors_bytes(&map1).unwrap();
        let b2 = npz_to_safetensors_bytes(&map2).unwrap();
        assert_eq!(b1, b2, "NPZ→safetensors output must be deterministic");
    }

    /// The file-writing entry point produces a file that parses back to the
    /// tensor that was written.
    #[test]
    fn file_output_roundtrips() {
        let payload: Vec<u8> = vec![9, 8, 7, 6];
        let mut map: HashMap<String, NpzTensor> = HashMap::new();
        map.insert(
            "w".to_owned(),
            make_tensor("w", NpzDtype::U8, vec![2, 2], payload.clone()),
        );

        // The process id keeps concurrent test runs from sharing a file.
        let path = std::env::temp_dir().join(format!(
            "anamnesis_npz_to_safetensors_{}.safetensors",
            std::process::id()
        ));
        npz_to_safetensors(&map, &path).unwrap();
        let written = std::fs::read(&path).unwrap();
        std::fs::remove_file(&path).unwrap();

        let parsed = safetensors::SafeTensors::deserialize(&written).unwrap();
        let w = parsed.tensor("w").unwrap();
        assert_eq!(w.dtype(), safetensors::Dtype::U8);
        assert_eq!(w.shape(), &[2, 2]);
        assert_eq!(w.data(), payload.as_slice());
    }

    /// A data buffer whose length disagrees with the declared shape is
    /// reported as a parse error instead of being serialized.
    #[test]
    fn shape_data_mismatch_is_parse_error() {
        // Three U8 elements declared, four bytes supplied.
        let mut map: HashMap<String, NpzTensor> = HashMap::new();
        map.insert(
            "bad".to_owned(),
            make_tensor("bad", NpzDtype::U8, vec![3], vec![0; 4]),
        );

        let result = npz_to_safetensors_bytes(&map);
        assert!(matches!(result, Err(AnamnesisError::Parse { .. })));
    }
}