use crate::TensorSnapshot;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_core::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer};
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::path::Path;
use super::lazy_data::LazyDataSource;
use super::pickle_reader::{Object, PickleError, read_pickle, read_pickle_with_data};
use std::sync::Arc;
/// Errors produced while reading PyTorch checkpoint files.
#[derive(Debug)]
pub enum PytorchError {
    /// Underlying filesystem / stream I/O failure.
    Io(std::io::Error),
    /// Failure while decoding the embedded pickle stream.
    Pickle(PickleError),
    /// Failure while reading the ZIP container.
    Zip(zip::result::ZipError),
    /// Failure while reading the TAR container.
    Tar(std::io::Error),
    /// The file does not match any supported PyTorch layout.
    InvalidFormat(String),
    /// A requested (top-level) key was absent from the file.
    KeyNotFound(String),
    /// Error while deserializing into a user type via burn's serde layer.
    Serde(burn_core::record::serde::error::Error),
}
impl From<std::io::Error> for PytorchError {
fn from(e: std::io::Error) -> Self {
PytorchError::Io(e)
}
}
impl From<PickleError> for PytorchError {
fn from(e: PickleError) -> Self {
PytorchError::Pickle(e)
}
}
impl From<zip::result::ZipError> for PytorchError {
    /// Wrap a ZIP-archive error.
    fn from(err: zip::result::ZipError) -> Self {
        Self::Zip(err)
    }
}
impl From<burn_core::record::serde::error::Error> for PytorchError {
fn from(e: burn_core::record::serde::error::Error) -> Self {
PytorchError::Serde(e)
}
}
impl std::fmt::Display for PytorchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PytorchError::Io(e) => write!(f, "IO error: {}", e),
PytorchError::Pickle(e) => write!(
f,
"Pickle parsing error: {}. This may indicate an unsupported PyTorch file format or corrupted file.",
e
),
PytorchError::Zip(e) => write!(f, "Zip archive error: {}", e),
PytorchError::Tar(e) => write!(f, "TAR archive error: {}", e),
PytorchError::InvalidFormat(msg) => write!(f, "Invalid PyTorch file format: {}", msg),
PytorchError::KeyNotFound(key) => write!(
f,
"Key '{}' not found in PyTorch file. Available keys may be listed with the keys() method.",
key
),
PytorchError::Serde(e) => write!(f, "Serde deserialization error: {}", e),
}
}
}
// Marker impl: the error message comes from the Display impl above.
impl std::error::Error for PytorchError {}

/// Module-local result alias using [`PytorchError`].
type Result<T> = std::result::Result<T, PytorchError>;
/// Metadata describing the container format of a loaded PyTorch file.
#[derive(Debug, Clone)]
pub struct PytorchMetadata {
    /// Contents of the ".format_version" archive entry, when present (ZIP only).
    pub format_version: Option<String>,
    /// Which container layout the file used.
    pub format_type: FileFormat,
    /// Byte order of the stored tensor data.
    pub byte_order: ByteOrder,
    /// True when the archive contains a ".storage_alignment" entry (ZIP only).
    pub has_storage_alignment: bool,
    /// Contents of the "version" archive entry, when present.
    pub pytorch_version: Option<String>,
    /// Number of tensors extracted from the file.
    pub tensor_count: usize,
    /// Total size in bytes of the tensor storage payload, when known.
    pub total_data_size: Option<usize>,
}
impl PytorchMetadata {
    /// True when the file uses the modern (PyTorch >= 1.6) ZIP container.
    pub fn is_modern_format(&self) -> bool {
        self.format_type == FileFormat::Zip
    }

    /// True when the file uses the legacy multi-pickle stream layout.
    pub fn is_legacy_format(&self) -> bool {
        self.format_type == FileFormat::Legacy
    }
}
/// Container layouts a PyTorch file can use.
#[derive(Debug, Clone, PartialEq)]
pub enum FileFormat {
    /// Modern ZIP container (PyTorch >= 1.6).
    Zip,
    /// Very old TAR container with separate "pickle"/"storages" entries.
    Tar,
    /// Legacy multi-pickle stream with trailing raw storage data.
    Legacy,
    /// Bare pickle with no tensor storage section.
    Pickle,
}
/// Byte order of the stored tensor data.
#[derive(Debug, Clone, PartialEq)]
pub enum ByteOrder {
    LittleEndian,
    BigEndian,
}
/// Eagerly-parsed view of a PyTorch file: tensor snapshots keyed by their
/// dotted state_dict path, plus container metadata gathered during loading.
pub struct PytorchReader {
    // Tensor name (dotted path) -> snapshot.
    tensors: HashMap<String, TensorSnapshot>,
    // Format information recorded while the file was loaded.
    metadata: PytorchMetadata,
}
impl PytorchReader {
    /// Open the PyTorch file at `path` and parse every tensor entry.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        load_pytorch_file_with_metadata(path.as_ref(), None)
            .map(|(tensors, metadata)| Self { tensors, metadata })
    }

    /// Like [`Self::new`], but first descends into the top-level dictionary
    /// entry named `key` (e.g. "state_dict").
    pub fn with_top_level_key<P: AsRef<Path>>(path: P, key: &str) -> Result<Self> {
        load_pytorch_file_with_metadata(path.as_ref(), Some(key))
            .map(|(tensors, metadata)| Self { tensors, metadata })
    }

    /// Parse tensors from an arbitrary byte stream holding a bare pickle.
    ///
    /// No container metadata is available in this mode, so the metadata is
    /// filled with `Pickle`/little-endian defaults.
    pub fn from_reader<R: Read>(reader: R, top_level_key: Option<&str>) -> Result<Self> {
        let tensors = load_from_reader(reader, top_level_key)?;
        let tensor_count = tensors.len();
        let metadata = PytorchMetadata {
            format_version: None,
            format_type: FileFormat::Pickle,
            byte_order: ByteOrder::LittleEndian,
            has_storage_alignment: false,
            pytorch_version: None,
            tensor_count,
            total_data_size: None,
        };
        Ok(Self { tensors, metadata })
    }

    /// Names of all tensors, in arbitrary (hash-map) order.
    pub fn keys(&self) -> Vec<String> {
        self.tensors.keys().map(Clone::clone).collect()
    }

    /// Look up a single tensor snapshot by its dotted name.
    pub fn get(&self, name: &str) -> Option<&TensorSnapshot> {
        self.tensors.get(name)
    }

    /// Borrow the full name -> snapshot map.
    pub fn tensors(&self) -> &HashMap<String, TensorSnapshot> {
        &self.tensors
    }

    /// Consume the reader, yielding the name -> snapshot map.
    pub fn into_tensors(self) -> HashMap<String, TensorSnapshot> {
        self.tensors
    }

    /// Container metadata recorded while loading.
    pub fn metadata(&self) -> &PytorchMetadata {
        &self.metadata
    }

    /// Number of tensors in the file.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// True when the file contained no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Read the file's pickle payload as plain data (no tensor contents).
    pub fn read_pickle_data<P: AsRef<Path>>(
        path: P,
        top_level_key: Option<&str>,
    ) -> Result<PickleValue> {
        read_pickle_as_value(path.as_ref(), top_level_key)
    }

    /// Deserialize the pickle payload into a user-defined config type `D`.
    pub fn load_config<D, P>(path: P, top_level_key: Option<&str>) -> Result<D>
    where
        D: DeserializeOwned,
        P: AsRef<Path>,
    {
        let pickle = Self::read_pickle_data(path, top_level_key)?;
        let nested = convert_pickle_to_nested_value(pickle)?;
        Ok(D::deserialize(Deserializer::<DefaultAdapter>::new(nested, false))?)
    }
}
/// Plain-data subset of Python pickle values, used when reading non-tensor
/// content (e.g. hyperparameter/config dictionaries) from a checkpoint.
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    None,
    Bool(bool),
    Int(i64),
    Float(f64),
    String(String),
    List(Vec<PickleValue>),
    /// String-keyed dictionary (non-string keys are not representable here).
    Dict(HashMap<String, PickleValue>),
    /// Raw bytes, e.g. persistent-id payloads.
    Bytes(Vec<u8>),
}
/// Load all tensors plus container metadata from a PyTorch file.
///
/// Detection order:
/// 1. ZIP archive (modern `torch.save` format, PyTorch >= 1.6),
/// 2. TAR archive (very old `torch.save` format),
/// 3. legacy multi-storage pickle stream (identified by its magic header),
/// 4. bare pickle with no tensor storage section.
///
/// `top_level_key` optionally selects a nested dictionary (e.g. "state_dict")
/// before tensors are collected.
fn load_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    // --- Attempt 1: modern ZIP container ---
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        // Locate the pickled object graph inside the archive.
        let mut pickle_data = Vec::new();
        let mut pickle_found = false;
        // Common locations written by torch.save.
        let possible_pickle_paths = [
            "data.pkl",
            "archive/data.pkl",
        ];
        for pickle_path in &possible_pickle_paths {
            if archive.by_name(pickle_path).is_ok() {
                let mut pickle_file = archive.by_name(pickle_path)?;
                pickle_file.read_to_end(&mut pickle_data)?;
                pickle_found = true;
                break;
            }
        }
        // Fallback: some writers nest data.pkl under a custom prefix.
        if !pickle_found {
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                drop(file); // release the borrow before re-opening by index
                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    pickle_found = true;
                    break;
                }
            }
        }
        if !pickle_found {
            return Err(PytorchError::InvalidFormat(
                "No data.pkl file found in ZIP archive. Expected PyTorch 1.6+ format with data.pkl or archive/data.pkl".to_string(),
            ));
        }
        // Optional ".format_version" marker entry.
        let format_version = if let Ok(mut version_file) = archive.by_name(".format_version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            let version_str = String::from_utf8_lossy(&version_data);
            let version = version_str.trim().to_string();
            Some(version)
        } else {
            None
        };
        // Optional "byteorder" entry; absence means little-endian.
        let is_big_endian = if let Ok(mut byteorder_file) = archive.by_name("byteorder") {
            let mut byteorder_data = Vec::new();
            byteorder_file.read_to_end(&mut byteorder_data)?;
            let byteorder_str = String::from_utf8_lossy(&byteorder_data);
            byteorder_str.trim() == "big"
        } else {
            false };
        // Big-endian payloads would need byte swapping, which is unimplemented.
        if is_big_endian {
            return Err(PytorchError::InvalidFormat(
                "Big-endian PyTorch files are not yet supported. The file was saved on a big-endian system and requires byte order conversion.".to_string()
            ));
        }
        let has_storage_alignment = archive.by_name(".storage_alignment").is_ok();
        // Optional "version" entry records the producing PyTorch version.
        let pytorch_version = if let Ok(mut version_file) = archive.by_name("version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            Some(String::from_utf8_lossy(&version_data).trim().to_string())
        } else {
            None
        };
        // Tensor bytes are resolved lazily from the archive on demand.
        let data_source = Arc::new(LazyDataSource::from_zip(path)?);
        // Sum the sizes of the per-tensor storage entries under data/.
        let mut total_data_size = 0usize;
        for i in 0..archive.len() {
            let file = archive.by_index(i)?;
            let name = file.name();
            let is_data_file = (name.contains("/data/")
                || name.starts_with("data/")
                || name.starts_with("archive/data/"))
                && !name.ends_with(".pkl")
                && !name.ends_with("/");
            if is_data_file {
                total_data_size += file.size() as usize;
            }
        }
        // Decode the object graph, wiring tensor storages to the lazy source.
        let mut pickle_reader = BufReader::new(pickle_data.as_slice());
        let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;
        let tensors = extract_tensors_with_data(obj, top_level_key)?;
        let metadata = PytorchMetadata {
            format_version,
            format_type: FileFormat::Zip,
            byte_order: if is_big_endian {
                ByteOrder::BigEndian
            } else {
                ByteOrder::LittleEndian
            },
            has_storage_alignment,
            pytorch_version,
            tensor_count: tensors.len(),
            total_data_size: Some(total_data_size),
        };
        return Ok((tensors, metadata));
    }
    // --- Attempt 2: old TAR container ---
    if is_tar_file(path) {
        return load_tar_pytorch_file_with_metadata(path, top_level_key);
    }
    // --- Attempt 3: legacy stream, detected via its pickled magic number ---
    let mut file = File::open(path)?;
    let mut header = [0u8; 15];
    let bytes_read = file.read(&mut header)?;
    file.seek(std::io::SeekFrom::Start(0))?;
    // Byte-for-byte match of the pickle-protocol-2 prefix (PROTO 2, then a
    // LONG1 of 10 bytes, then STOP) that encodes PyTorch's legacy magic number.
    let is_legacy_format = bytes_read >= 15
        && header[0] == 0x80 && header[1] == 0x02 && header[2] == 0x8a && header[3] == 0x0a && header[4] == 0x6c
        && header[5] == 0xfc
        && header[6] == 0x9c
        && header[7] == 0x46
        && header[8] == 0xf9
        && header[9] == 0x20
        && header[10] == 0x6a
        && header[11] == 0xa8
        && header[12] == 0x50
        && header[13] == 0x19
        && header[14] == 0x2e;
    if is_legacy_format {
        return load_legacy_pytorch_file_with_metadata(path, top_level_key);
    }
    // --- Attempt 4: bare pickle without a storage section ---
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    match read_pickle(&mut reader) {
        Ok(obj) => {
            let tensors = extract_tensors_with_data(obj, top_level_key)?;
            let tensor_count = tensors.len();
            Ok((
                tensors,
                PytorchMetadata {
                    format_version: None,
                    format_type: FileFormat::Pickle,
                    byte_order: ByteOrder::LittleEndian,
                    has_storage_alignment: false,
                    pytorch_version: None,
                    tensor_count,
                    total_data_size: None,
                },
            ))
        }
        // NOTE(review): the pickle referenced tensor storages, so the file is
        // a container we failed to recognize above — report a clearer error.
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            Err(PytorchError::InvalidFormat(
                "Pickle file contains tensor data but no data source is available. This file should be loaded as ZIP or legacy format.".to_string()
            ))
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}
/// Decode a bare pickle from `reader` and collect its tensor snapshots.
///
/// Streams cannot supply lazy tensor storage, so a pickle that references
/// storage data yields an `InvalidFormat` error directing callers to the
/// file-based loaders.
fn load_from_reader<R: Read>(
    reader: R,
    top_level_key: Option<&str>,
) -> Result<HashMap<String, TensorSnapshot>> {
    let mut buffered = BufReader::new(reader);
    match read_pickle(&mut buffered) {
        Ok(obj) => extract_tensors_with_data(obj, top_level_key),
        Err(e) => {
            // Distinguish "needs a data source" from genuine pickle errors.
            if e.to_string()
                .contains("Cannot load tensor data without a data source")
            {
                Err(PytorchError::InvalidFormat(
                    "Reader contains tensor data but no data source is available. Use file-based loading instead.".to_string()
                ))
            } else {
                Err(PytorchError::Pickle(e))
            }
        }
    }
}
/// Collect all tensor snapshots from a decoded pickle object graph.
///
/// The root must be a dictionary; with `top_level_key` set, the value under
/// that key must itself be a dictionary and becomes the new root.
fn extract_tensors_with_data(
    obj: Object,
    top_level_key: Option<&str>,
) -> Result<HashMap<String, TensorSnapshot>> {
    // A state_dict file must have a dictionary at its root.
    let Object::Dict(dict) = obj else {
        return Err(PytorchError::InvalidFormat(
            "Expected a dictionary at the root of the PyTorch file, but found a different type. The file may be a full model save rather than a state_dict.".to_string(),
        ));
    };
    // Optionally descend one level into the requested entry.
    let dict = match top_level_key {
        None => dict,
        Some(key) => match dict.get(key) {
            Some(Object::Dict(nested)) => nested.clone(),
            _ => {
                return Err(PytorchError::KeyNotFound(format!(
                    "Top-level key '{}' not found or is not a dictionary. Available top-level keys in file: {:?}",
                    key,
                    dict.keys().collect::<Vec<_>>()
                )));
            }
        },
    };
    // Walk the nested dictionaries, recording tensors by dotted path.
    let mut tensors = HashMap::new();
    let mut path = Vec::new();
    extract_tensors_recursive(&Object::Dict(dict), &mut path, &mut tensors);
    Ok(tensors)
}
/// Depth-first walk over nested dictionaries, inserting each tensor leaf
/// under its dot-joined path (e.g. "layer1.weight").
fn extract_tensors_recursive<'a>(
    obj: &'a Object,
    path: &mut Vec<&'a str>,
    tensors: &mut HashMap<String, TensorSnapshot>,
) {
    if let Object::TorchParam(snapshot) = obj {
        // Leaf: record the tensor under its dotted path.
        tensors.insert(path.join("."), snapshot.clone());
    } else if let Object::Dict(dict) = obj {
        // Recurse into the nested dictionary, extending the path per key.
        for (key, value) in dict {
            path.push(key);
            extract_tensors_recursive(value, path, tensors);
            path.pop();
        }
    }
    // Every other object kind carries no tensors and is ignored.
}
/// Load a legacy (pre-1.6, non-TAR) PyTorch file.
///
/// The stream layout is: four consecutive pickles (magic number, protocol
/// version, system info, main object), then a pickled list of storage keys,
/// then the concatenated raw storage bytes until end of file.
fn load_legacy_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    // Skip the magic-number pickle.
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read magic number from legacy format: {}",
            e
        ))
    })?;
    // Skip the protocol-version pickle.
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read protocol version from legacy format: {}",
            e
        ))
    })?;
    // Skip the system-info pickle.
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read system info from legacy format: {}",
            e
        ))
    })?;
    // Remember where the main object pickle begins; it is re-read below once
    // the storage data source has been set up.
    let main_pickle_pos = reader.stream_position()?;
    use crate::pytorch::pickle_reader::skip_pickle;
    skip_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to skip main object in legacy format: {}",
            e
        ))
    })?;
    // After the main object comes a pickled list of storage-key strings;
    // anything malformed degrades to an empty key list.
    let storage_keys = match read_pickle(&mut reader) {
        Ok(Object::List(keys)) => keys
            .into_iter()
            .filter_map(|obj| match obj {
                Object::String(s) => Some(s),
                _ => None,
            })
            .collect::<Vec<_>>(),
        _ => vec![],
    };
    // Everything from here to EOF is raw tensor storage bytes.
    let data_start_pos = reader.stream_position()?;
    let file_size = reader.seek(SeekFrom::End(0))?;
    let data_size = file_size - data_start_pos;
    let data_source = Arc::new(LazyDataSource::from_legacy_multi_storage(
        path,
        data_start_pos,
        data_size,
    ));
    // Hand the storage-key order to the data source so lookups resolve.
    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source
        && !storage_keys.is_empty()
    {
        let source = source
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        source.set_storage_keys(storage_keys.clone());
    }
    // Re-read the main object now that tensor data can be resolved.
    reader.seek(SeekFrom::Start(main_pickle_pos))?;
    let main_obj = read_pickle_with_data(&mut reader, data_source.clone())?;
    let tensors = extract_tensors_with_data(main_obj, top_level_key)?;
    let metadata = PytorchMetadata {
        format_version: None, format_type: FileFormat::Legacy,
        byte_order: ByteOrder::LittleEndian, has_storage_alignment: false,
        pytorch_version: None, tensor_count: tensors.len(),
        total_data_size: Some(data_size as usize),
    };
    Ok((tensors, metadata))
}
/// Heuristically detect a TAR file by the "ustar" magic at byte offset 257
/// (shared by the POSIX ustar and GNU tar header layouts).
fn is_tar_file(path: &Path) -> bool {
    let Ok(mut file) = File::open(path) else {
        return false;
    };
    let mut header = [0u8; 263];
    match file.read_exact(&mut header) {
        Ok(()) => &header[257..262] == b"ustar",
        // Too short to contain a tar header (or unreadable).
        Err(_) => false,
    }
}
/// Load a PyTorch file stored in the very old TAR container format.
///
/// The archive holds three entries: "sys_info" (pickled platform info,
/// optional), "pickle" (the object graph), and "storages" (raw tensor data).
fn load_tar_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    use tar::Archive;
    let file = File::open(path)?;
    let mut archive = Archive::new(BufReader::new(file));
    let mut sys_info_data: Option<Vec<u8>> = None;
    let mut pickle_data: Option<Vec<u8>> = None;
    let mut storages_data: Option<Vec<u8>> = None;
    // Single pass over the archive, collecting the three known entries.
    for entry in archive.entries().map_err(PytorchError::Tar)? {
        let mut entry = entry.map_err(PytorchError::Tar)?;
        let entry_path = entry
            .path()
            .map_err(PytorchError::Tar)?
            .to_string_lossy()
            .to_string();
        // Skip PAX extended-header pseudo entries.
        if entry_path.contains("@PaxHeader") {
            continue;
        }
        // Entries may be stored with a leading "./".
        let normalized = entry_path.trim_start_matches("./");
        match normalized {
            "sys_info" => {
                let mut data = Vec::new();
                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
                sys_info_data = Some(data);
            }
            "pickle" => {
                let mut data = Vec::new();
                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
                pickle_data = Some(data);
            }
            "storages" => {
                let mut data = Vec::new();
                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
                storages_data = Some(data);
            }
            _ => {}
        }
    }
    // "pickle" and "storages" are mandatory.
    let pickle_data = pickle_data.ok_or_else(|| {
        PytorchError::InvalidFormat("TAR file missing 'pickle' entry".to_string())
    })?;
    let storages_data = storages_data.ok_or_else(|| {
        PytorchError::InvalidFormat("TAR file missing 'storages' entry".to_string())
    })?;
    // Missing sys_info implies little-endian.
    let is_little_endian = if let Some(ref data) = sys_info_data {
        parse_tar_sys_info(data)?
    } else {
        true };
    if !is_little_endian {
        return Err(PytorchError::InvalidFormat(
            "Big-endian TAR PyTorch files are not supported".to_string(),
        ));
    }
    // Wire the raw storages blob into a data source, then decode the pickle.
    let data_source = Arc::new(LazyDataSource::from_tar(&storages_data)?);
    let mut pickle_reader = BufReader::new(pickle_data.as_slice());
    let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;
    let tensors = extract_tensors_with_data(obj, top_level_key)?;
    let metadata = PytorchMetadata {
        format_version: None,
        format_type: FileFormat::Tar,
        byte_order: ByteOrder::LittleEndian,
        has_storage_alignment: false,
        pytorch_version: None,
        tensor_count: tensors.len(),
        total_data_size: Some(storages_data.len()),
    };
    Ok((tensors, metadata))
}
/// Read the pickled TAR "sys_info" blob and report whether the file claims
/// little-endian data. Absent or malformed flags default to `true`.
fn parse_tar_sys_info(data: &[u8]) -> Result<bool> {
    let obj = read_pickle(&mut BufReader::new(data))?;
    if let Object::Dict(dict) = obj {
        if let Some(Object::Bool(little_endian)) = dict.get("little_endian") {
            return Ok(*little_endian);
        }
    }
    // No explicit flag: assume little-endian.
    Ok(true)
}
/// Read a PyTorch file's pickle payload as plain [`PickleValue`] data.
///
/// Handles both the ZIP container (pickle at data.pkl / */data.pkl) and bare
/// pickle files. In the bare-pickle fallback, tensors cannot be materialized
/// and are replaced by "<Tensor:name>" placeholder strings.
fn read_pickle_as_value(path: &Path, top_level_key: Option<&str>) -> Result<PickleValue> {
    use crate::pytorch::lazy_data::LazyDataSource;
    use crate::pytorch::pickle_reader::{read_pickle, read_pickle_with_data};
    use std::sync::Arc;
    // --- ZIP container: locate and decode the embedded pickle ---
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        let mut pickle_data = Vec::new();
        // Standard locations first.
        for pickle_path in &["data.pkl", "archive/data.pkl"] {
            if let Ok(mut pickle_file) = archive.by_name(pickle_path) {
                pickle_file.read_to_end(&mut pickle_data)?;
                break;
            }
        }
        // Fallback: any entry whose name ends in data.pkl.
        if pickle_data.is_empty() {
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                drop(file); // release the borrow before re-opening by index
                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    break;
                }
            }
        }
        if !pickle_data.is_empty() {
            // Decode with a lazy data source so tensor references resolve.
            let data_source = LazyDataSource::from_zip(path)?;
            let data_source_arc = Arc::new(data_source);
            let mut reader = BufReader::new(pickle_data.as_slice());
            let obj = read_pickle_with_data(&mut reader, data_source_arc)?;
            return convert_object_to_value(obj, top_level_key);
        }
    }
    // --- Bare pickle file ---
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    match read_pickle(&mut reader) {
        Ok(obj) => convert_object_to_value(obj, top_level_key),
        // The pickle references tensor storages we cannot resolve here: fall
        // back to listing tensor names as placeholder strings instead.
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            let reader = PytorchReader::new(path)?;
            let mut result = std::collections::HashMap::new();
            for key in reader.keys() {
                result.insert(
                    key.clone(),
                    PickleValue::String(format!("<Tensor:{}>", key)),
                );
            }
            // Mirror the requested nesting when a top-level key was given.
            if let Some(key) = top_level_key {
                Ok(PickleValue::Dict(
                    [(key.to_string(), PickleValue::Dict(result))]
                        .into_iter()
                        .collect(),
                ))
            } else {
                Ok(PickleValue::Dict(result))
            }
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}
/// Convert a decoded pickle object to [`PickleValue`], optionally selecting a
/// single top-level dictionary entry first.
///
/// When `top_level_key` is set but the root is not a dictionary, the whole
/// object is converted as-is (matching the no-key path).
fn convert_object_to_value(obj: Object, top_level_key: Option<&str>) -> Result<PickleValue> {
    use crate::pytorch::pickle_reader::Object;
    match (top_level_key, obj) {
        (Some(key), Object::Dict(dict)) => match dict.get(key) {
            Some(value) => object_to_pickle_value(value.clone()),
            None => Err(PytorchError::KeyNotFound(format!(
                "Key '{}' not found in pickle data",
                key
            ))),
        },
        (_, other) => object_to_pickle_value(other),
    }
}
/// Lossily convert a parsed pickle [`Object`] tree into a [`PickleValue`].
///
/// Tuples are flattened to lists, persistent-id payloads become raw bytes,
/// and tensors / arbitrary Python objects collapse to `PickleValue::None`.
fn object_to_pickle_value(obj: Object) -> Result<PickleValue> {
    use crate::pytorch::pickle_reader::Object;
    let value = match obj {
        Object::None => PickleValue::None,
        Object::Bool(b) => PickleValue::Bool(b),
        Object::Int(i) => PickleValue::Int(i),
        Object::Float(f) => PickleValue::Float(f),
        Object::String(s) => PickleValue::String(s),
        // Persistent-id payloads are surfaced as raw bytes.
        Object::Persistent(data) => PickleValue::Bytes(data),
        // Lists and both tuple flavors share the same list representation.
        Object::PersistentTuple(items) | Object::List(items) | Object::Tuple(items) => {
            PickleValue::List(
                items
                    .into_iter()
                    .map(object_to_pickle_value)
                    .collect::<Result<Vec<_>>>()?,
            )
        }
        Object::Dict(dict) => PickleValue::Dict(
            dict.into_iter()
                .map(|(k, v)| Ok((k, object_to_pickle_value(v)?)))
                .collect::<Result<HashMap<_, _>>>()?,
        ),
        // No plain-data analogue exists for these; drop to None.
        Object::TorchParam(_)
        | Object::Class { .. }
        | Object::Build { .. }
        | Object::Reduce { .. } => PickleValue::None,
    };
    Ok(value)
}
/// Translate a [`PickleValue`] tree into burn's serde [`NestedValue`]
/// representation so it can feed the record deserializer.
fn convert_pickle_to_nested_value(value: PickleValue) -> Result<NestedValue> {
    let nested = match value {
        PickleValue::None => NestedValue::Default(None),
        PickleValue::Bool(b) => NestedValue::Bool(b),
        PickleValue::Int(i) => NestedValue::I64(i),
        PickleValue::Float(f) => NestedValue::F64(f),
        PickleValue::String(s) => NestedValue::String(s),
        PickleValue::List(items) => NestedValue::Vec(
            items
                .into_iter()
                .map(convert_pickle_to_nested_value)
                .collect::<Result<Vec<_>>>()?,
        ),
        PickleValue::Dict(entries) => {
            let mut map = HashMap::new();
            for (key, item) in entries {
                map.insert(key, convert_pickle_to_nested_value(item)?);
            }
            NestedValue::Map(map)
        }
        // Raw bytes are represented as a vector of U8 values.
        PickleValue::Bytes(bytes) => {
            NestedValue::Vec(bytes.into_iter().map(NestedValue::U8).collect())
        }
    };
    Ok(nested)
}