use super::io::{build_tensor_from_bytes, find_archive_layout, read_zip_entry};
use super::pickle::parse_pickle;
use super::types::PtTensorMeta;
use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use std::collections::HashMap;
use std::path::Path;
pub fn load_tensor_pt<R: Runtime<DType = DType>>(
path: impl AsRef<Path>,
key: Option<&str>,
device: &R::Device,
) -> Result<Tensor<R>> {
let path = path.as_ref();
let file = std::fs::File::open(path).map_err(|e| Error::ModelError {
reason: format!("opening {}: {e}", path.display()),
})?;
let mut archive = zip::ZipArchive::new(file).map_err(|e| Error::ModelError {
reason: format!(
"{} is not a valid PyTorch .pt (ZIP) file: {e}",
path.display()
),
})?;
let (pkl_name, data_prefix) = find_archive_layout(&mut archive)?;
let pkl_bytes = read_zip_entry(&mut archive, &pkl_name)?;
let contents = parse_pickle(&pkl_bytes)?;
let wanted_key = key.unwrap_or("");
let meta = contents.tensors.get(wanted_key).ok_or_else(|| {
let available: Vec<&str> = contents.tensors.keys().map(String::as_str).collect();
Error::ModelError {
reason: format!(
".pt file {} has no tensor at key {:?}; available: {:?}",
path.display(),
wanted_key,
available
),
}
})?;
let storage_path = format!("{data_prefix}/{}", meta.storage_id);
let storage_bytes = read_zip_entry(&mut archive, &storage_path)?;
let view = tensor_view(meta, &storage_bytes, &storage_path)?;
build_tensor_from_bytes::<R>(meta.dtype, &meta.shape, view, device)
}
pub fn load_voice_pt<R: Runtime<DType = DType>>(
path: impl AsRef<Path>,
device: &R::Device,
) -> Result<Tensor<R>> {
let path = path.as_ref();
match load_tensor_pt::<R>(path, None, device) {
Ok(t) => Ok(t),
Err(_) => load_tensor_pt::<R>(path, Some("style"), device),
}
}
fn tensor_view<'a>(
meta: &PtTensorMeta,
storage_bytes: &'a [u8],
storage_path: &str,
) -> Result<&'a [u8]> {
let expected_bytes = meta.storage_numel * meta.storage_elem_size;
if storage_bytes.len() != expected_bytes {
return Err(Error::ModelError {
reason: format!(
"storage {storage_path} size mismatch: expected {expected_bytes} bytes \
({} elements × {} B), got {}",
meta.storage_numel,
meta.storage_elem_size,
storage_bytes.len()
),
});
}
let view_numel: usize = meta.shape.iter().product();
let dtype_bytes = meta.dtype.size_in_bytes();
let byte_offset = meta.storage_offset * dtype_bytes;
let byte_len = view_numel * dtype_bytes;
if byte_offset + byte_len > storage_bytes.len() {
return Err(Error::ModelError {
reason: format!(
"tensor view exceeds storage: offset={byte_offset} len={byte_len} storage_bytes={}",
storage_bytes.len()
),
});
}
Ok(&storage_bytes[byte_offset..byte_offset + byte_len])
}
pub struct TorchStateDict {
path: std::path::PathBuf,
tensors: HashMap<String, PtTensorMeta>,
data_prefix: String,
}
impl TorchStateDict {
pub fn open(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let file = std::fs::File::open(path).map_err(|e| Error::ModelError {
reason: format!("opening {}: {e}", path.display()),
})?;
let mut archive = zip::ZipArchive::new(file).map_err(|e| Error::ModelError {
reason: format!(
"{} is not a valid PyTorch .pt (ZIP) file: {e}",
path.display()
),
})?;
let (pkl_name, data_prefix) = find_archive_layout(&mut archive)?;
let pkl_bytes = read_zip_entry(&mut archive, &pkl_name)?;
let contents = parse_pickle(&pkl_bytes)?;
Ok(Self {
path: path.to_path_buf(),
tensors: contents.tensors,
data_prefix,
})
}
pub fn keys(&self) -> impl Iterator<Item = &str> {
self.tensors.keys().map(String::as_str)
}
pub fn has(&self, name: &str) -> bool {
self.tensors.contains_key(name)
}
pub fn load_tensor<R: Runtime<DType = DType>>(
&self,
name: &str,
device: &R::Device,
) -> Result<Tensor<R>> {
let meta = self.tensors.get(name).ok_or_else(|| Error::ModelError {
reason: format!(
"tensor {name:?} not in .pt state dict (have {} tensors)",
self.tensors.len()
),
})?;
let file = std::fs::File::open(&self.path).map_err(|e| Error::ModelError {
reason: format!("reopening {}: {e}", self.path.display()),
})?;
let mut archive = zip::ZipArchive::new(file).map_err(|e| Error::ModelError {
reason: format!("reopening archive {}: {e}", self.path.display()),
})?;
let storage_path = format!("{}/{}", self.data_prefix, meta.storage_id);
let storage_bytes = read_zip_entry(&mut archive, &storage_path)?;
let view = tensor_view(meta, &storage_bytes, &storage_path)?;
build_tensor_from_bytes::<R>(meta.dtype, &meta.shape, view, device)
}
}