use std::collections::BTreeMap;
use std::fs::File;
use std::io::{self, Read, Seek};
use std::path::Path;
use memmap2::{Mmap, MmapOptions};
use crate::error::Error;
use crate::models::{DType, Manifest, Object};
use crate::pickle_vm::{parse_pytorch_pickle, PtTensorInfo};
use crate::reader::{TensorData, TensorElement, TensorReader};
const ZIP_LOCAL_HEADER_MAGIC: [u8; 4] = [0x50, 0x4b, 0x03, 0x04];
enum StorageBackend {
Mmap {
mmap: Mmap,
storage_offsets: BTreeMap<String, (usize, usize)>,
},
Buffered { data: BTreeMap<String, Vec<u8>> },
}
pub struct PyTorchReader {
pub manifest: Manifest,
storage: StorageBackend,
tensor_locations: BTreeMap<String, TensorLocation>,
}
#[derive(Debug, Clone)]
struct TensorLocation {
storage_key: String,
byte_offset: usize,
byte_length: usize,
dtype: DType,
}
impl PyTorchReader {
pub fn open(path: impl AsRef<Path>) -> Result<Self, Error> {
let path = path.as_ref();
let file = File::open(path)?;
let mut archive = zip::ZipArchive::new(file).map_err(|e| {
Error::InvalidFileStructure(format!("Not a valid ZIP/PyTorch file: {}", e))
})?;
let pickle_name = find_pickle_file(&archive)?;
let tensor_infos = {
let mut pkl_file = archive.by_name(&pickle_name).map_err(|e| {
Error::InvalidFileStructure(format!("Cannot read '{}': {}", pickle_name, e))
})?;
let mut pkl_data = Vec::new();
pkl_file.read_to_end(&mut pkl_data)?;
parse_pytorch_pickle(&mut io::Cursor::new(pkl_data))?
};
if tensor_infos.is_empty() {
return Err(Error::InvalidFileStructure(
"No tensors found in PyTorch file".to_string(),
));
}
let prefix = find_data_prefix(&archive, &tensor_infos);
let (manifest, tensor_locations) = build_manifest_and_locations(&tensor_infos);
let storage_keys = unique_storage_keys(&tensor_infos);
let mut entry_headers: BTreeMap<String, (u64, u64)> = BTreeMap::new();
let mut all_stored = true;
for i in 0..archive.len() {
let entry = match archive.by_index_raw(i) {
Ok(e) => e,
Err(_) => continue,
};
let entry_name = entry.name().to_string();
let header_start = entry.header_start();
let size = entry.size();
let is_stored = entry.compression() == zip::CompressionMethod::Stored;
drop(entry);
for key in &storage_keys {
let full_path = format!("{}{}", prefix, key);
if entry_name == full_path || entry_name == *key {
if !is_stored {
all_stored = false;
}
entry_headers
.entry(key.clone())
.or_insert((header_start, size));
}
}
}
if all_stored && entry_headers.len() == storage_keys.len() {
let file = File::open(path)?;
let mmap = unsafe { MmapOptions::new().map(&file)? };
let storage_offsets = compute_mmap_offsets(&mmap, &entry_headers)?;
Ok(Self {
manifest,
storage: StorageBackend::Mmap {
mmap,
storage_offsets,
},
tensor_locations,
})
} else {
let data = read_all_storage(&mut archive, &storage_keys, &prefix)?;
Ok(Self {
manifest,
storage: StorageBackend::Buffered { data },
tensor_locations,
})
}
}
pub fn from_reader<R: Read + Seek>(reader: R) -> Result<Self, Error> {
let mut archive = zip::ZipArchive::new(reader).map_err(|e| {
Error::InvalidFileStructure(format!("Not a valid ZIP/PyTorch file: {}", e))
})?;
let pickle_name = find_pickle_file(&archive)?;
let tensor_infos = {
let mut pkl_file = archive.by_name(&pickle_name).map_err(|e| {
Error::InvalidFileStructure(format!("Cannot read '{}': {}", pickle_name, e))
})?;
let mut pkl_data = Vec::new();
pkl_file.read_to_end(&mut pkl_data)?;
parse_pytorch_pickle(&mut io::Cursor::new(pkl_data))?
};
if tensor_infos.is_empty() {
return Err(Error::InvalidFileStructure(
"No tensors found in PyTorch file".to_string(),
));
}
let prefix = find_data_prefix(&archive, &tensor_infos);
let (manifest, tensor_locations) = build_manifest_and_locations(&tensor_infos);
let storage_keys = unique_storage_keys(&tensor_infos);
let data = read_all_storage(&mut archive, &storage_keys, &prefix)?;
Ok(Self {
manifest,
storage: StorageBackend::Buffered { data },
tensor_locations,
})
}
fn get_storage_slice(&self, key: &str) -> Result<&[u8], Error> {
match &self.storage {
StorageBackend::Mmap {
mmap,
storage_offsets,
} => {
let &(offset, len) = storage_offsets.get(key).ok_or_else(|| {
Error::InvalidFileStructure(format!(
"Storage '{}' not found in mmap offsets",
key
))
})?;
Ok(&mmap[offset..offset + len])
}
StorageBackend::Buffered { data } => {
data.get(key).map(Vec::as_slice).ok_or_else(|| {
Error::InvalidFileStructure(format!("Storage '{}' not found in buffer", key))
})
}
}
}
fn get_tensor_slice(&self, name: &str) -> Result<&[u8], Error> {
let loc = self
.tensor_locations
.get(name)
.ok_or_else(|| Error::ObjectNotFound(name.to_string()))?;
let storage = self.get_storage_slice(&loc.storage_key)?;
let end = loc.byte_offset + loc.byte_length;
if end > storage.len() {
return Err(Error::InvalidFileStructure(format!(
"Tensor '{}' data out of bounds: {}..{} > {}",
name,
loc.byte_offset,
end,
storage.len()
)));
}
Ok(&storage[loc.byte_offset..end])
}
pub fn read(&self, name: &str) -> Result<Vec<u8>, Error> {
Ok(self.get_tensor_slice(name)?.to_vec())
}
pub fn read_as<T: TensorElement>(&self, name: &str) -> Result<Vec<T>, Error> {
let loc = self
.tensor_locations
.get(name)
.ok_or_else(|| Error::ObjectNotFound(name.to_string()))?;
if T::DTYPE != loc.dtype {
return Err(Error::TypeMismatch {
expected: loc.dtype.as_str().to_string(),
found: std::any::type_name::<T>().to_string(),
context: format!("object '{}'", name),
});
}
crate::reader::bytes_to_typed_vec(self.get_tensor_slice(name)?)
}
}
impl TensorReader for PyTorchReader {
fn manifest(&self) -> &Manifest {
&self.manifest
}
fn read_data(&self, name: &str) -> Result<TensorData<'_>, Error> {
let slice = self.get_tensor_slice(name)?;
Ok(TensorData::Borrowed(slice))
}
}
fn build_manifest_and_locations(
tensor_infos: &[PtTensorInfo],
) -> (Manifest, BTreeMap<String, TensorLocation>) {
let mut objects = BTreeMap::new();
let mut tensor_locations = BTreeMap::new();
for info in tensor_infos {
let byte_length = info.numel * info.dtype.byte_size();
let obj = Object::dense(
info.shape.clone(),
info.dtype,
info.storage_offset as u64,
byte_length as u64,
);
tensor_locations.insert(
info.name.clone(),
TensorLocation {
storage_key: info.storage_key.clone(),
byte_offset: info.storage_offset,
byte_length,
dtype: info.dtype,
},
);
objects.insert(info.name.clone(), obj);
}
let manifest = Manifest {
version: "pytorch".to_string(),
attributes: None,
objects,
};
(manifest, tensor_locations)
}
fn unique_storage_keys(tensor_infos: &[PtTensorInfo]) -> Vec<String> {
let mut keys: Vec<String> = tensor_infos
.iter()
.map(|info| info.storage_key.clone())
.collect();
keys.sort();
keys.dedup();
keys
}
fn compute_mmap_offsets(
mmap: &Mmap,
entry_headers: &BTreeMap<String, (u64, u64)>,
) -> Result<BTreeMap<String, (usize, usize)>, Error> {
let mut storage_offsets = BTreeMap::new();
for (key, (header_start, size)) in entry_headers {
let hs = *header_start as usize;
if hs + 30 > mmap.len() {
return Err(Error::InvalidFileStructure(format!(
"ZIP local file header for storage '{}' is truncated",
key
)));
}
if mmap[hs..hs + 4] != ZIP_LOCAL_HEADER_MAGIC {
return Err(Error::InvalidFileStructure(format!(
"Invalid ZIP local file header for storage '{}'",
key
)));
}
let fname_len = u16::from_le_bytes([mmap[hs + 26], mmap[hs + 27]]) as usize;
let extra_len = u16::from_le_bytes([mmap[hs + 28], mmap[hs + 29]]) as usize;
let data_offset = hs + 30 + fname_len + extra_len;
let data_len = *size as usize;
if data_offset + data_len > mmap.len() {
return Err(Error::InvalidFileStructure(format!(
"Storage '{}' data extends beyond file",
key
)));
}
storage_offsets.insert(key.clone(), (data_offset, data_len));
}
Ok(storage_offsets)
}
fn read_all_storage<R: Read + Seek>(
archive: &mut zip::ZipArchive<R>,
storage_keys: &[String],
prefix: &str,
) -> Result<BTreeMap<String, Vec<u8>>, Error> {
let mut data = BTreeMap::new();
for key in storage_keys {
let storage_path = format!("{}{}", prefix, key);
let entry_data = read_zip_entry(archive, &storage_path)
.or_else(|| read_zip_entry(archive, key))
.ok_or_else(|| {
Error::InvalidFileStructure(format!(
"Storage file '{}' not found in PyTorch archive",
key
))
})?;
data.insert(key.clone(), entry_data);
}
Ok(data)
}
fn read_zip_entry<R: Read + Seek>(archive: &mut zip::ZipArchive<R>, name: &str) -> Option<Vec<u8>> {
let mut entry = archive.by_name(name).ok()?;
let mut data = Vec::new();
entry.read_to_end(&mut data).ok()?;
Some(data)
}
fn find_pickle_file<R: Read + Seek>(archive: &zip::ZipArchive<R>) -> Result<String, Error> {
let names: Vec<String> = (0..archive.len())
.filter_map(|i| archive.name_for_index(i).map(|s| s.to_string()))
.collect();
for pattern in &["archive/data.pkl", "data.pkl"] {
if names.iter().any(|n| n == *pattern) {
return Ok(pattern.to_string());
}
}
for name in &names {
if name.ends_with(".pkl") {
return Ok(name.clone());
}
}
Err(Error::InvalidFileStructure(
"No pickle file found in PyTorch archive".to_string(),
))
}
fn find_data_prefix<R: Read + Seek>(
archive: &zip::ZipArchive<R>,
tensor_infos: &[PtTensorInfo],
) -> String {
if tensor_infos.is_empty() {
return String::new();
}
let names: Vec<String> = (0..archive.len())
.filter_map(|i| archive.name_for_index(i).map(|s| s.to_string()))
.collect();
let first_key = &tensor_infos[0].storage_key;
for prefix in &["archive/data/", "data/", ""] {
let path = format!("{}{}", prefix, first_key);
if names.iter().any(|n| n == &path) {
return prefix.to_string();
}
}
let suffix = format!("/{}", first_key);
for name in &names {
if name.ends_with(&suffix) {
return name[..name.len() - first_key.len()].to_string();
}
}
String::new()
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_pytorch_reader_structure() {
let mut file = NamedTempFile::new().unwrap();
let path = file.path().to_path_buf();
{
let mut zip = zip::ZipWriter::new(&mut file);
let pickle_data = build_test_pickle();
let options = zip::write::SimpleFileOptions::default()
.compression_method(zip::CompressionMethod::Stored);
zip.start_file("archive/data.pkl", options).unwrap();
zip.write_all(&pickle_data).unwrap();
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let bytes: &[u8] = bytemuck::cast_slice(&data);
zip.start_file("archive/data/0", options).unwrap();
zip.write_all(bytes).unwrap();
zip.finish().unwrap();
}
let reader = PyTorchReader::open(&path);
match reader {
Ok(r) => {
assert!(!r.tensors().is_empty());
for (name, obj) in r.tensors() {
assert_eq!(obj.shape, vec![2, 3]);
let data: Vec<f32> = r.read_as(name).unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
}
Err(e) => {
eprintln!("PyTorch reader test (expected in some configs): {}", e);
}
}
}
fn build_test_pickle() -> Vec<u8> {
let mut p = Vec::new();
p.push(0x80);
p.push(2);
p.push(0x7d);
p.push(0x28);
p.push(0x8c);
p.push(6);
p.extend_from_slice(b"weight");
p.push(0x63);
p.extend_from_slice(b"torch._utils\n_rebuild_tensor_v2\n");
p.push(0x28);
p.push(0x28); p.push(0x8c);
p.push(7);
p.extend_from_slice(b"storage");
p.push(0x63);
p.extend_from_slice(b"torch\nFloatStorage\n");
p.push(0x8c);
p.push(1);
p.extend_from_slice(b"0");
p.push(0x8c);
p.push(3);
p.extend_from_slice(b"cpu");
p.push(0x4b);
p.push(6); p.push(0x74); p.push(0x51);
p.push(0x4b);
p.push(0);
p.push(0x4b);
p.push(2);
p.push(0x4b);
p.push(3);
p.push(0x86);
p.push(0x4b);
p.push(3);
p.push(0x4b);
p.push(1);
p.push(0x86);
p.push(0x89);
p.push(0x63);
p.extend_from_slice(b"collections\nOrderedDict\n");
p.push(0x29); p.push(0x52);
p.push(0x74);
p.push(0x52);
p.push(0x75);
p.push(0x2e);
p
}
}