use std::path::Path;
use crate::error::AnamnesisError;
use crate::parse::pth::PthTensor;
/// Writes `tensors` to `output` as a safetensors file.
///
/// Every tensor is wrapped in a `safetensors::tensor::TensorView` that borrows the
/// tensor's name and byte buffer, so no tensor data or names are copied before the
/// views are handed to `safetensors::tensor::serialize_to_file`. No extra header
/// metadata is written.
///
/// # Errors
///
/// - Propagates the error from `to_safetensors_dtype` when a tensor's dtype cannot
///   be converted.
/// - Returns [`AnamnesisError::Parse`] when a `TensorView` cannot be constructed for
///   a tensor; the message includes the tensor's name.
/// - Returns [`AnamnesisError::Io`] when serialization fails with an I/O error, and
///   [`AnamnesisError::Parse`] for any other serialization failure.
pub fn pth_to_safetensors(
    tensors: &[PthTensor<'_>],
    output: impl AsRef<Path>,
) -> crate::Result<()> {
    // Names are borrowed as `&str`: the serializer only needs `AsRef<str>`, so
    // cloning each name into an owned `String` would be a wasted allocation.
    let mut views: Vec<(&str, safetensors::tensor::TensorView<'_>)> =
        Vec::with_capacity(tensors.len());
    for tensor in tensors {
        let st_dtype = tensor.dtype.to_safetensors_dtype()?;
        let view =
            safetensors::tensor::TensorView::new(st_dtype, tensor.shape.clone(), &tensor.data)
                .map_err(|e| AnamnesisError::Parse {
                    reason: format!("failed to create TensorView for `{}`: {e}", tensor.name),
                })?;
        views.push((tensor.name.as_str(), view));
    }

    safetensors::tensor::serialize_to_file(views, &None, output.as_ref()).map_err(
        // `SafeTensorError` is defined by the `safetensors` crate and only its I/O
        // variant gets dedicated handling here, so the catch-all arm is intentional.
        #[allow(clippy::wildcard_enum_match_arm)]
        |e| match e {
            safetensors::SafeTensorError::IoError(io_err) => AnamnesisError::Io(io_err),
            other => AnamnesisError::Parse {
                reason: format!("failed to write safetensors file: {other}"),
            },
        },
    )?;
    Ok(())
}
/// Serializes `tensors` into an in-memory safetensors buffer.
///
/// Every tensor is wrapped in a `safetensors::tensor::TensorView` that borrows the
/// tensor's name and byte buffer; the views are then passed to
/// `safetensors::tensor::serialize` with no extra header metadata, and the
/// resulting bytes are returned.
///
/// # Errors
///
/// - Propagates the error from `to_safetensors_dtype` when a tensor's dtype cannot
///   be converted.
/// - Returns [`AnamnesisError::Parse`] when a `TensorView` cannot be constructed for
///   a tensor (the message includes the tensor's name) or when serialization fails.
pub fn pth_to_safetensors_bytes(tensors: &[PthTensor<'_>]) -> crate::Result<Vec<u8>> {
    // Names are borrowed as `&str`: the serializer only needs `AsRef<str>`, so
    // cloning each name into an owned `String` would be a wasted allocation.
    let mut views: Vec<(&str, safetensors::tensor::TensorView<'_>)> =
        Vec::with_capacity(tensors.len());
    for tensor in tensors {
        let st_dtype = tensor.dtype.to_safetensors_dtype()?;
        let view =
            safetensors::tensor::TensorView::new(st_dtype, tensor.shape.clone(), &tensor.data)
                .map_err(|e| AnamnesisError::Parse {
                    reason: format!("failed to create TensorView for `{}`: {e}", tensor.name),
                })?;
        views.push((tensor.name.as_str(), view));
    }

    // Every serialization failure is reported as a parse error, so no `match`
    // (and no lint suppression) is needed here.
    safetensors::tensor::serialize(views, &None).map_err(|e| AnamnesisError::Parse {
        reason: format!("failed to serialize safetensors: {e}"),
    })
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use crate::parse::pth::PthDtype;
    use std::borrow::Cow;

    /// Bytes of the `weight` fixture: the little-endian `f32` encodings of
    /// 1.0, 2.0, 3.0 and 4.0.
    const WEIGHT_BYTES: &[u8] = &[
        0x00, 0x00, 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80,
        0x40,
    ];

    /// Bytes of the `bias` fixture: the little-endian `f32` encodings of 0.5 and -0.5.
    const BIAS_BYTES: &[u8] = &[0x00, 0x00, 0x00, 0x3F, 0x00, 0x00, 0x00, 0xBF];

    /// Builds the two-tensor fixture (`weight` of shape 2x2, `bias` of shape 2)
    /// shared by the round-trip tests.
    fn sample_tensors() -> Vec<PthTensor<'static>> {
        let weight = PthTensor {
            name: "weight".to_owned(),
            shape: vec![2, 2],
            dtype: PthDtype::F32,
            data: Cow::Borrowed(WEIGHT_BYTES),
        };
        let bias = PthTensor {
            name: "bias".to_owned(),
            shape: vec![2],
            dtype: PthDtype::F32,
            data: Cow::Borrowed(BIAS_BYTES),
        };
        vec![weight, bias]
    }

    /// Deserializes `serialized` and checks that it holds exactly the fixture
    /// tensors with their original shapes, dtypes and bytes.
    fn assert_sample_contents(serialized: &[u8]) {
        let st = safetensors::SafeTensors::deserialize(serialized).unwrap();
        assert_eq!(st.len(), 2);

        let w = st.tensor("weight").unwrap();
        assert_eq!(w.shape(), &[2, 2]);
        assert_eq!(w.dtype(), safetensors::Dtype::F32);
        assert_eq!(w.data(), WEIGHT_BYTES);

        let b = st.tensor("bias").unwrap();
        assert_eq!(b.shape(), &[2]);
        assert_eq!(b.dtype(), safetensors::Dtype::F32);
        assert_eq!(b.data(), BIAS_BYTES);
    }

    /// Deserializes `serialized` and checks that it contains no tensors.
    fn assert_no_tensors(serialized: &[u8]) {
        let st = safetensors::SafeTensors::deserialize(serialized).unwrap();
        assert_eq!(st.len(), 0);
    }

    /// File output: the fixture survives a write-then-read round trip.
    #[test]
    fn roundtrip_simple() {
        let tensors = sample_tensors();
        let tmp = tempfile::NamedTempFile::new().unwrap();
        pth_to_safetensors(&tensors, tmp.path()).unwrap();

        let written = std::fs::read(tmp.path()).unwrap();
        assert_sample_contents(&written);
    }

    /// File output: an empty tensor list still produces a readable file.
    #[test]
    fn empty_tensors() {
        let tensors: Vec<PthTensor<'_>> = Vec::new();
        let tmp = tempfile::NamedTempFile::new().unwrap();
        pth_to_safetensors(&tensors, tmp.path()).unwrap();

        let written = std::fs::read(tmp.path()).unwrap();
        assert_no_tensors(&written);
    }

    /// In-memory output: the fixture survives a serialize-then-deserialize round trip.
    #[test]
    fn roundtrip_bytes() {
        let tensors = sample_tensors();
        let bytes = pth_to_safetensors_bytes(&tensors).unwrap();
        assert_sample_contents(&bytes);
    }

    /// In-memory output: an empty tensor list still produces a readable buffer.
    #[test]
    fn empty_tensors_bytes() {
        let tensors: Vec<PthTensor<'_>> = Vec::new();
        let bytes = pth_to_safetensors_bytes(&tensors).unwrap();
        assert_no_tensors(&bytes);
    }
}