use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::path::Path;
use crate::error::AnamnesisError;
use crate::parse::safetensors::Dtype;
use crate::parse::utils::byteswap_inplace;
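/// Tensor element types recognized in PyTorch `.pth` checkpoints, mapped from
/// the `torch.*Storage` classes referenced by the embedded pickle.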
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum PthDtype {
F16,
BF16,
F32,
F64,
U8,
I8,
I16,
I32,
I64,
Bool,
}
impl PthDtype {
#[must_use]
pub const fn byte_size(self) -> usize {
match self {
Self::Bool | Self::U8 | Self::I8 => 1,
Self::F16 | Self::BF16 | Self::I16 => 2,
Self::F32 | Self::I32 => 4,
Self::F64 | Self::I64 => 8,
}
}
pub fn to_dtype(self) -> crate::Result<Dtype> {
match self {
Self::F16 => Ok(Dtype::F16),
Self::BF16 => Ok(Dtype::BF16),
Self::F32 => Ok(Dtype::F32),
Self::F64 => Ok(Dtype::F64),
Self::U8 => Ok(Dtype::U8),
Self::I8 => Ok(Dtype::I8),
Self::I16 => Ok(Dtype::I16),
Self::I32 => Ok(Dtype::I32),
Self::I64 => Ok(Dtype::I64),
Self::Bool => Ok(Dtype::Bool),
}
}
pub fn to_safetensors_dtype(self) -> crate::Result<safetensors::Dtype> {
match self {
Self::F16 => Ok(safetensors::Dtype::F16),
Self::BF16 => Ok(safetensors::Dtype::BF16),
Self::F32 => Ok(safetensors::Dtype::F32),
Self::F64 => Ok(safetensors::Dtype::F64),
Self::U8 => Ok(safetensors::Dtype::U8),
Self::I8 => Ok(safetensors::Dtype::I8),
Self::I16 => Ok(safetensors::Dtype::I16),
Self::I32 => Ok(safetensors::Dtype::I32),
Self::I64 => Ok(safetensors::Dtype::I64),
Self::Bool => Ok(safetensors::Dtype::BOOL),
}
}
fn from_storage_class(module: &str, name: &str) -> crate::Result<Self> {
if module != "torch" {
return Err(AnamnesisError::Parse {
reason: format!("unknown storage module `{module}.{name}`"),
});
}
match name {
"FloatStorage" => Ok(Self::F32),
"DoubleStorage" => Ok(Self::F64),
"HalfStorage" => Ok(Self::F16),
"BFloat16Storage" => Ok(Self::BF16),
"LongStorage" => Ok(Self::I64),
"IntStorage" => Ok(Self::I32),
"ShortStorage" => Ok(Self::I16),
"CharStorage" => Ok(Self::I8),
"ByteStorage" => Ok(Self::U8),
"BoolStorage" => Ok(Self::Bool),
_ => Err(AnamnesisError::Parse {
reason: format!("unknown storage class `torch.{name}`"),
}),
}
}
}
impl fmt::Display for PthDtype {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
Self::F16 => "F16",
Self::BF16 => "BF16",
Self::F32 => "F32",
Self::F64 => "F64",
Self::U8 => "U8",
Self::I8 => "I8",
Self::I16 => "I16",
Self::I32 => "I32",
Self::I64 => "I64",
Self::Bool => "BOOL",
};
f.write_str(s)
}
}
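/// A single tensor extracted from a `.pth` state_dict. `data` is borrowed from
/// the memory-mapped archive when the storage is contiguous and little-endian,
/// and owned (copied) otherwise.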
#[derive(Debug, Clone)]
pub struct PthTensor<'a> {
pub name: String,
pub shape: Vec<usize>,
pub dtype: PthDtype,
pub data: Cow<'a, [u8]>,
}
#[derive(Debug)]
struct TensorMeta {
name: String,
shape: Vec<usize>,
dtype: PthDtype,
storage_key: String,
storage_offset: usize,
strides: Vec<usize>,
}
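/// A parsed `.pth` archive: the memory-mapped file, per-tensor metadata
/// recovered from `data.pkl`, and an index from ZIP entry suffix to the
/// (offset, length) of its stored (uncompressed) payload.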
#[derive(Debug)]
pub struct ParsedPth {
mmap: memmap2::Mmap,
meta: Vec<TensorMeta>,
entry_index: HashMap<String, (usize, usize)>,
big_endian: bool,
}
impl ParsedPth {
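    /// Materialize every tensor in the state_dict, borrowing bytes from the
    /// memory map where possible and copying only for big-endian or
    /// non-contiguous storages.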
pub fn tensors(&self) -> crate::Result<Vec<PthTensor<'_>>> {
let mut tensors = Vec::with_capacity(self.meta.len());
for m in &self.meta {
let storage_suffix = format!("data/{}", m.storage_key);
let &(storage_start, storage_len) = self
.entry_index
.get(storage_suffix.as_str())
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("ZIP entry `{storage_suffix}` not found"),
})?;
let storage = self
.mmap
.get(storage_start..storage_start + storage_len)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("storage `{}`: mmap slice out of bounds", m.storage_key),
})?;
let elem_size = m.dtype.byte_size();
            // Fast path: contiguous little-endian storage can be borrowed
            // directly from the mmap without copying.
            let data: Cow<'_, [u8]> = if is_contiguous(&m.shape, &m.strides) && !self.big_endian {
let n_elements: usize = m
.shape
.iter()
.try_fold(1usize, |acc, &d| acc.checked_mul(d))
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{}`: element count overflow", m.name),
})?;
let n_bytes =
n_elements
.checked_mul(elem_size)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{}`: byte count overflow", m.name),
})?;
let end =
m.storage_offset
.checked_add(n_bytes)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{}`: storage end offset overflow", m.name),
})?;
Cow::Borrowed(storage.get(m.storage_offset..end).ok_or_else(|| {
AnamnesisError::Parse {
reason: format!(
"tensor `{}`: storage read out of bounds \
([{}..{}], storage len = {})",
m.name,
m.storage_offset,
end,
storage.len()
),
}
})?)
            } else if is_contiguous(&m.shape, &m.strides) {
                // Contiguous but big-endian: copy the bytes, then swap each
                // element to little-endian.
let n_elements: usize = m
.shape
.iter()
.try_fold(1usize, |acc, &d| acc.checked_mul(d))
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{}`: element count overflow", m.name),
})?;
let n_bytes =
n_elements
.checked_mul(elem_size)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{}`: byte count overflow", m.name),
})?;
let end =
m.storage_offset
.checked_add(n_bytes)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{}`: storage end offset overflow", m.name),
})?;
let mut buf = storage
.get(m.storage_offset..end)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{}`: storage read out of bounds", m.name),
})?
.to_vec();
byteswap_inplace(&mut buf, elem_size);
Cow::Owned(buf)
            } else {
                // Non-contiguous (strided) view: gather the elements into a
                // contiguous buffer, then byteswap if the file is big-endian.
let mut buf =
copy_to_contiguous(storage, m.storage_offset, &m.shape, &m.strides, elem_size)?;
if self.big_endian && elem_size > 1 {
byteswap_inplace(&mut buf, elem_size);
}
Cow::Owned(buf)
};
tensors.push(PthTensor {
name: m.name.clone(),
shape: m.shape.clone(),
dtype: m.dtype,
data,
});
}
Ok(tensors)
}
#[must_use]
pub const fn len(&self) -> usize {
self.meta.len()
}
#[must_use]
pub const fn is_empty(&self) -> bool {
self.meta.is_empty()
}
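    /// Summarize the archive: tensor count, total tensor byte size, the
    /// distinct dtypes present, and the declared byte order.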
pub fn inspect(&self) -> PthInspectInfo {
let mut total_bytes: u64 = 0;
let mut dtypes: Vec<PthDtype> = Vec::new();
for m in &self.meta {
#[allow(clippy::as_conversions)]
let n_elements: u64 = m
.shape
.iter()
.copied()
.fold(1u64, |acc, d| acc.saturating_mul(d as u64));
#[allow(clippy::as_conversions)]
let byte_size = m.dtype.byte_size() as u64;
total_bytes = total_bytes.saturating_add(n_elements.saturating_mul(byte_size));
if !dtypes.contains(&m.dtype) {
dtypes.push(m.dtype);
}
}
PthInspectInfo {
tensor_count: self.meta.len(),
total_bytes,
dtypes,
big_endian: self.big_endian,
}
}
pub fn to_safetensors(&self, output: impl AsRef<std::path::Path>) -> crate::Result<()> {
let tensors = self.tensors()?;
crate::remember::pth::pth_to_safetensors(&tensors, output)
}
pub fn to_safetensors_bytes(&self) -> crate::Result<Vec<u8>> {
let tensors = self.tensors()?;
crate::remember::pth::pth_to_safetensors_bytes(&tensors)
}
#[must_use]
pub fn tensor_info(&self) -> Vec<PthTensorInfo> {
self.meta
.iter()
.map(|m| {
let n_elements: usize = m
.shape
.iter()
.try_fold(1usize, |acc, &d| acc.checked_mul(d))
.unwrap_or(usize::MAX);
PthTensorInfo {
name: m.name.clone(),
shape: m.shape.clone(),
dtype: m.dtype,
byte_len: n_elements.saturating_mul(m.dtype.byte_size()),
}
})
.collect()
}
}
#[derive(Debug, Clone)]
pub struct PthTensorInfo {
pub name: String,
pub shape: Vec<usize>,
pub dtype: PthDtype,
pub byte_len: usize,
}
#[derive(Debug, Clone)]
#[must_use]
pub struct PthInspectInfo {
pub tensor_count: usize,
pub total_bytes: u64,
pub dtypes: Vec<PthDtype>,
pub big_endian: bool,
}
impl fmt::Display for PthInspectInfo {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Format: PyTorch state_dict (.pth)")?;
write!(f, "\nTensors: {}", self.tensor_count)?;
write!(
f,
"\nTotal size: {}",
crate::inspect::format_bytes(self.total_bytes)
)?;
let dtype_list: String = self
.dtypes
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(", ");
write!(f, "\nDtypes: {dtype_list}")?;
let endian = if self.big_endian {
"big-endian"
} else {
"little-endian"
};
write!(f, "\nByte order: {endian}")?;
Ok(())
}
}
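/// Minimal pickle object model: just enough of the pickle data model to
/// represent a PyTorch state_dict, with no arbitrary object construction.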
#[derive(Debug, Clone)]
#[allow(dead_code)]
#[allow(clippy::wildcard_enum_match_arm)]
enum PickleValue {
None,
Bool(bool),
Int(i64),
String(String),
Bytes(Vec<u8>),
Tuple(Vec<PickleValue>),
List(Vec<PickleValue>),
Dict(Vec<(PickleValue, PickleValue)>),
Global {
module: String,
name: String,
},
PersistentId(Box<PickleValue>),
Reduced {
callable: Box<PickleValue>,
args: Box<PickleValue>,
},
Built {
obj: Box<PickleValue>,
state: Box<PickleValue>,
},
}
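/// Allowlist of pickle globals. Anything outside this set is rejected so a
/// crafted checkpoint cannot name arbitrary callables.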
fn is_allowed_global(module: &str, name: &str) -> bool {
matches!(
(module, name),
(
"torch._utils",
"_rebuild_tensor_v2" | "_rebuild_parameter" | "_rebuild_parameter_with_state"
) | (
"torch",
"FloatStorage"
| "DoubleStorage"
| "HalfStorage"
| "BFloat16Storage"
| "LongStorage"
| "IntStorage"
| "ShortStorage"
| "CharStorage"
| "ByteStorage"
| "BoolStorage"
) | ("collections", "OrderedDict")
| ("torch.nn.parameter", "Parameter")
)
}
fn is_ordered_dict_constructor(callable: &PickleValue, args: &PickleValue) -> bool {
if let PickleValue::Global { module, name } = callable {
if module == "collections" && name == "OrderedDict" {
if let PickleValue::Tuple(items) = args {
return items.is_empty();
}
}
}
false
}
struct PickleVm<'a> {
data: &'a [u8],
pos: usize,
stack: Vec<PickleValue>,
mark_stack: Vec<usize>,
memo: HashMap<u32, PickleValue>,
next_memo_id: u32,
}
impl<'a> PickleVm<'a> {
fn new(data: &'a [u8]) -> Self {
Self {
data,
pos: 0,
stack: Vec::new(),
mark_stack: Vec::new(),
memo: HashMap::new(),
next_memo_id: 0,
}
}
fn read_u8(&mut self) -> crate::Result<u8> {
let b = self
.data
.get(self.pos)
.copied()
.ok_or_else(|| AnamnesisError::Parse {
reason: "unexpected end of pickle stream".into(),
})?;
self.pos += 1;
Ok(b)
}
fn read_u16_le(&mut self) -> crate::Result<u16> {
let bytes: [u8; 2] = self.read_fixed()?;
Ok(u16::from_le_bytes(bytes))
}
fn read_i32_le(&mut self) -> crate::Result<i32> {
let bytes: [u8; 4] = self.read_fixed()?;
Ok(i32::from_le_bytes(bytes))
}
fn read_u32_le(&mut self) -> crate::Result<u32> {
let bytes: [u8; 4] = self.read_fixed()?;
Ok(u32::from_le_bytes(bytes))
}
fn read_u64_le(&mut self) -> crate::Result<u64> {
let bytes: [u8; 8] = self.read_fixed()?;
Ok(u64::from_le_bytes(bytes))
}
fn read_fixed<const N: usize>(&mut self) -> crate::Result<[u8; N]> {
let hi = self
.pos
.checked_add(N)
.ok_or_else(|| AnamnesisError::Parse {
reason: "pickle offset overflow".into(),
})?;
let slice = self
.data
.get(self.pos..hi)
.ok_or_else(|| AnamnesisError::Parse {
reason: "unexpected end of pickle stream".into(),
})?;
self.pos = hi;
let arr: [u8; N] = slice.try_into().map_err(|_| AnamnesisError::Parse {
reason: "internal: slice-to-array conversion failed".into(),
})?;
Ok(arr)
}
fn read_bytes(&mut self, n: usize) -> crate::Result<&'a [u8]> {
let hi = self
.pos
.checked_add(n)
.ok_or_else(|| AnamnesisError::Parse {
reason: "pickle offset overflow".into(),
})?;
let slice = self
.data
.get(self.pos..hi)
.ok_or_else(|| AnamnesisError::Parse {
reason: "unexpected end of pickle stream".into(),
})?;
self.pos = hi;
Ok(slice)
}
fn read_line(&mut self) -> crate::Result<&'a str> {
let start = self.pos;
loop {
let b = self.read_u8()?;
if b == b'\n' {
let line =
self.data
.get(start..self.pos - 1)
.ok_or_else(|| AnamnesisError::Parse {
reason: "pickle line read out of bounds".into(),
})?;
return std::str::from_utf8(line).map_err(|e| AnamnesisError::Parse {
reason: format!("non-UTF-8 pickle string: {e}"),
});
}
}
}
fn pop(&mut self) -> crate::Result<PickleValue> {
self.stack.pop().ok_or_else(|| AnamnesisError::Parse {
reason: "pickle stack underflow".into(),
})
}
fn pop_mark(&mut self) -> crate::Result<Vec<PickleValue>> {
let mark_pos = self.mark_stack.pop().ok_or_else(|| AnamnesisError::Parse {
reason: "pickle mark stack underflow".into(),
})?;
let items = self.stack.split_off(mark_pos);
Ok(items)
}
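    /// Run the opcode loop until STOP (`.`) and return the value left on top
    /// of the stack.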
fn execute(&mut self) -> crate::Result<PickleValue> {
loop {
let opcode = self.read_u8()?;
match opcode {
                // PROTO: the protocol version byte is read and ignored.
                0x80 => {
                    let _version = self.read_u8()?;
                }
                // FRAME (protocol 4+): the frame length is read and ignored.
                0x95 => {
                    let _frame_len = self.read_u64_le()?;
                }
b'.' => return self.pop(),
b'N' => self.stack.push(PickleValue::None),
                // NEWTRUE / NEWFALSE.
                0x88 => self.stack.push(PickleValue::Bool(true)),
                0x89 => self.stack.push(PickleValue::Bool(false)),
b'J' => {
let v = self.read_i32_le()?;
self.stack.push(PickleValue::Int(i64::from(v)));
}
b'K' => {
let v = self.read_u8()?;
self.stack.push(PickleValue::Int(i64::from(v)));
}
b'M' => {
let v = self.read_u16_le()?;
self.stack.push(PickleValue::Int(i64::from(v)));
}
                // LONG1: little-endian two's-complement integer of `n` bytes.
                0x8a => {
let n = self.read_u8()?;
let bytes = self.read_bytes(usize::from(n))?;
let val = long1_to_i64(bytes)?;
self.stack.push(PickleValue::Int(val));
}
                // SHORT_BINUNICODE: UTF-8 string with a 1-byte length.
                0x8c => {
let n = self.read_u8()?;
let bytes = self.read_bytes(usize::from(n))?;
let s = std::str::from_utf8(bytes).map_err(|e| AnamnesisError::Parse {
reason: format!("non-UTF-8 pickle string: {e}"),
})?;
self.stack.push(PickleValue::String(s.to_owned()));
}
b'X' => {
let n = self.read_u32_le()?;
let len = usize::try_from(n).map_err(|_| AnamnesisError::Parse {
reason: "BINUNICODE length overflow".into(),
})?;
let bytes = self.read_bytes(len)?;
let s = std::str::from_utf8(bytes).map_err(|e| AnamnesisError::Parse {
reason: format!("non-UTF-8 pickle string: {e}"),
})?;
self.stack.push(PickleValue::String(s.to_owned()));
}
b'U' => {
let n = self.read_u8()?;
let bytes = self.read_bytes(usize::from(n))?;
match std::str::from_utf8(bytes) {
Ok(s) => self.stack.push(PickleValue::String(s.to_owned())),
Err(_) => self.stack.push(PickleValue::Bytes(bytes.to_vec())),
}
}
b'T' => {
let n = self.read_i32_le()?;
if n < 0 {
return Err(AnamnesisError::Parse {
reason: "negative BINSTRING length".into(),
});
}
let len = usize::try_from(n).map_err(|_| AnamnesisError::Parse {
reason: "BINSTRING length overflow".into(),
})?;
let bytes = self.read_bytes(len)?;
match std::str::from_utf8(bytes) {
Ok(s) => self.stack.push(PickleValue::String(s.to_owned())),
Err(_) => self.stack.push(PickleValue::Bytes(bytes.to_vec())),
}
}
b'B' => {
let n = self.read_u32_le()?;
let len = usize::try_from(n).map_err(|_| AnamnesisError::Parse {
reason: "BINBYTES length overflow".into(),
})?;
let bytes = self.read_bytes(len)?;
self.stack.push(PickleValue::Bytes(bytes.to_vec()));
}
b'C' => {
let n = self.read_u8()?;
let bytes = self.read_bytes(usize::from(n))?;
self.stack.push(PickleValue::Bytes(bytes.to_vec()));
}
b'}' => self.stack.push(PickleValue::Dict(Vec::new())),
b']' => self.stack.push(PickleValue::List(Vec::new())),
b')' => self.stack.push(PickleValue::Tuple(Vec::new())),
b'(' => self.mark_stack.push(self.stack.len()),
b't' => {
let items = self.pop_mark()?;
self.stack.push(PickleValue::Tuple(items));
}
                // TUPLE1 / TUPLE2 / TUPLE3.
                0x85 => {
let a = self.pop()?;
self.stack.push(PickleValue::Tuple(vec![a]));
}
0x86 => {
let b = self.pop()?;
let a = self.pop()?;
self.stack.push(PickleValue::Tuple(vec![a, b]));
}
0x87 => {
let c = self.pop()?;
let b = self.pop()?;
let a = self.pop()?;
self.stack.push(PickleValue::Tuple(vec![a, b, c]));
}
b'u' => {
let items = self.pop_mark()?;
if items.len() % 2 != 0 {
return Err(AnamnesisError::Parse {
reason: "SETITEMS: odd number of items on stack".into(),
});
}
let dict = self.stack.last_mut().ok_or_else(|| AnamnesisError::Parse {
reason: "SETITEMS: empty stack (no dict)".into(),
})?;
if let PickleValue::Dict(ref mut pairs) = *dict {
let mut iter = items.into_iter();
while let Some(key) = iter.next() {
let val = iter.next().ok_or_else(|| AnamnesisError::Parse {
reason: "SETITEMS: missing value for key".into(),
})?;
pairs.push((key, val));
}
} else {
return Err(AnamnesisError::Parse {
reason: "SETITEMS: top of stack is not a dict".into(),
});
}
}
b's' => {
let value = self.pop()?;
let key = self.pop()?;
let dict = self.stack.last_mut().ok_or_else(|| AnamnesisError::Parse {
reason: "SETITEM: empty stack (no dict)".into(),
})?;
if let PickleValue::Dict(ref mut pairs) = *dict {
pairs.push((key, value));
} else {
return Err(AnamnesisError::Parse {
reason: "SETITEM: top of stack is not a dict".into(),
});
}
}
b'a' => {
let item = self.pop()?;
let list = self.stack.last_mut().ok_or_else(|| AnamnesisError::Parse {
reason: "APPEND: empty stack (no list)".into(),
})?;
if let PickleValue::List(ref mut items) = *list {
items.push(item);
} else {
return Err(AnamnesisError::Parse {
reason: "APPEND: top of stack is not a list".into(),
});
}
}
b'e' => {
let new_items = self.pop_mark()?;
let list = self.stack.last_mut().ok_or_else(|| AnamnesisError::Parse {
reason: "APPENDS: empty stack (no list)".into(),
})?;
if let PickleValue::List(ref mut items) = *list {
items.extend(new_items);
} else {
return Err(AnamnesisError::Parse {
reason: "APPENDS: top of stack is not a list".into(),
});
}
}
                // GLOBAL: module and name follow as newline-terminated lines.
                b'c' => {
let module = self.read_line()?.to_owned();
let name = self.read_line()?.to_owned();
if !is_allowed_global(&module, &name) {
return Err(AnamnesisError::Parse {
reason: format!(
"disallowed pickle global `{module}.{name}` \
(potential code execution)"
),
});
}
self.stack.push(PickleValue::Global { module, name });
}
                // STACK_GLOBAL: module and name are taken from the stack.
                0x93 => {
let name_val = self.pop()?;
let module_val = self.pop()?;
let (module, name) = match (&module_val, &name_val) {
(PickleValue::String(m), PickleValue::String(n)) => {
(m.as_str(), n.as_str())
}
_ => {
return Err(AnamnesisError::Parse {
reason: "STACK_GLOBAL: module/name are not strings".into(),
})
}
};
if !is_allowed_global(module, name) {
return Err(AnamnesisError::Parse {
reason: format!(
"disallowed pickle global `{module}.{name}` \
(potential code execution)"
),
});
}
self.stack.push(PickleValue::Global {
module: module.to_owned(),
name: name.to_owned(),
});
}
                // REDUCE / NEWOBJ: an empty `collections.OrderedDict()` call is
                // rewritten to a plain Dict so SETITEMS can fill it in; every
                // other call is kept as an unevaluated `Reduced` node.
                b'R' | 0x81 => {
let args = self.pop()?;
let callable = self.pop()?;
if is_ordered_dict_constructor(&callable, &args) {
self.stack.push(PickleValue::Dict(Vec::new()));
} else {
self.stack.push(PickleValue::Reduced {
callable: Box::new(callable),
args: Box::new(args),
});
}
}
b'b' => {
let state = self.pop()?;
let obj = self.pop()?;
self.stack.push(PickleValue::Built {
obj: Box::new(obj),
state: Box::new(state),
});
}
b'Q' => {
let pid = self.pop()?;
self.stack.push(PickleValue::PersistentId(Box::new(pid)));
}
b'q' => {
let key = self.read_u8()?;
let val = self
.stack
.last()
.ok_or_else(|| AnamnesisError::Parse {
reason: "BINPUT: empty stack".into(),
})?
.clone();
self.memo.insert(u32::from(key), val);
}
b'r' => {
let key = self.read_u32_le()?;
let val = self
.stack
.last()
.ok_or_else(|| AnamnesisError::Parse {
reason: "LONG_BINPUT: empty stack".into(),
})?
.clone();
self.memo.insert(key, val);
}
b'h' => {
let key = self.read_u8()?;
let val = self
.memo
.get(&u32::from(key))
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("BINGET: memo key {key} not found"),
})?
.clone();
self.stack.push(val);
}
b'j' => {
let key = self.read_u32_le()?;
let val = self
.memo
.get(&key)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("LONG_BINGET: memo key {key} not found"),
})?
.clone();
self.stack.push(val);
}
                // MEMOIZE (protocol 4+): store the top of stack under the next
                // sequential memo id.
                0x94 => {
let val = self
.stack
.last()
.ok_or_else(|| AnamnesisError::Parse {
reason: "MEMOIZE: empty stack".into(),
})?
.clone();
self.memo.insert(self.next_memo_id, val);
self.next_memo_id =
self.next_memo_id
.checked_add(1)
.ok_or_else(|| AnamnesisError::Parse {
reason: "pickle memo table overflow (>2^32 MEMOIZE opcodes)".into(),
})?;
}
_ => {
return Err(AnamnesisError::Parse {
reason: format!("unsupported pickle opcode 0x{opcode:02x}"),
});
}
}
}
}
}
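/// Decode a LONG1 payload: a little-endian two's-complement integer of up to
/// 8 bytes (an empty payload means zero).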
fn long1_to_i64(bytes: &[u8]) -> crate::Result<i64> {
if bytes.is_empty() {
return Ok(0);
}
if bytes.len() > 8 {
return Err(AnamnesisError::Parse {
reason: format!(
"LONG1 value too large ({} bytes, max 8 for i64)",
bytes.len()
),
});
}
let last = bytes.last().copied().ok_or_else(|| AnamnesisError::Parse {
reason: "LONG1 empty bytes".into(),
})?;
let pad = if last & 0x80 != 0 { 0xFF } else { 0x00 };
let mut buf = [pad; 8];
let dest = buf
.get_mut(..bytes.len())
.ok_or_else(|| AnamnesisError::Parse {
reason: "LONG1 internal: slice bounds exceeded".into(),
})?;
dest.copy_from_slice(bytes);
Ok(i64::from_le_bytes(buf))
}
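/// Tensor description recovered from a `_rebuild_tensor_v2` call, before it is
/// resolved against the ZIP storage entries.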
struct TensorRef {
name: String,
storage_key: String,
dtype: PthDtype,
storage_offset: usize,
shape: Vec<usize>,
strides: Vec<usize>,
}
fn as_i64(val: &PickleValue) -> crate::Result<i64> {
if let PickleValue::Int(v) = val {
Ok(*v)
} else {
Err(AnamnesisError::Parse {
reason: format!("expected int, got {val:?}"),
})
}
}
fn as_usize(val: &PickleValue) -> crate::Result<usize> {
let v = as_i64(val)?;
usize::try_from(v).map_err(|_| AnamnesisError::Parse {
reason: format!("integer {v} does not fit in usize"),
})
}
fn as_str(val: &PickleValue) -> crate::Result<&str> {
if let PickleValue::String(s) = val {
Ok(s.as_str())
} else {
Err(AnamnesisError::Parse {
reason: format!("expected string, got {val:?}"),
})
}
}
fn tuple_to_usize_vec(val: &PickleValue) -> crate::Result<Vec<usize>> {
if let PickleValue::Tuple(items) = val {
items.iter().map(as_usize).collect()
} else {
Err(AnamnesisError::Parse {
reason: format!("expected tuple, got {val:?}"),
})
}
}
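/// Decode the argument tuple of `torch._utils._rebuild_tensor_v2`:
/// `(storage persistent id, storage_offset, shape, strides, ...)`. The storage
/// offset is converted from elements to bytes here.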
#[allow(clippy::wildcard_enum_match_arm)]
fn parse_rebuild_args(name: &str, args: &PickleValue) -> crate::Result<TensorRef> {
let PickleValue::Tuple(items) = args else {
return Err(AnamnesisError::Parse {
reason: format!("tensor `{name}`: expected tuple args for _rebuild_tensor_v2"),
});
};
if items.len() < 4 {
return Err(AnamnesisError::Parse {
reason: format!(
"tensor `{name}`: _rebuild_tensor_v2 needs ≥4 args, got {}",
items.len()
),
});
}
let persistent_id = items.first().ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{name}`: missing args[0]"),
})?;
let storage_tuple = match persistent_id {
PickleValue::PersistentId(inner) => match inner.as_ref() {
PickleValue::Tuple(t) => t,
other => {
return Err(AnamnesisError::Parse {
reason: format!(
"tensor `{name}`: PersistentId payload is not a tuple: {other:?}"
),
})
}
},
other => {
return Err(AnamnesisError::Parse {
reason: format!("tensor `{name}`: expected PersistentId, got {other:?}"),
})
}
};
if storage_tuple.len() < 5 {
return Err(AnamnesisError::Parse {
reason: format!(
"tensor `{name}`: storage tuple needs ≥5 items, got {}",
storage_tuple.len()
),
});
}
let st1 = storage_tuple.get(1).ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{name}`: missing storage_tuple[1]"),
})?;
let dtype = match st1 {
PickleValue::Global { module, name: cls } => PthDtype::from_storage_class(module, cls)?,
other => {
return Err(AnamnesisError::Parse {
reason: format!("tensor `{name}`: expected storage Global, got {other:?}"),
})
}
};
let st2 = storage_tuple.get(2).ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{name}`: missing storage_tuple[2]"),
})?;
let storage_key = as_str(st2)?.to_owned();
let it1 = items.get(1).ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{name}`: missing args[1]"),
})?;
let storage_offset_elements = as_usize(it1)?;
let storage_offset = storage_offset_elements
.checked_mul(dtype.byte_size())
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{name}`: storage offset overflow"),
})?;
let it2 = items.get(2).ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{name}`: missing args[2]"),
})?;
let it3 = items.get(3).ok_or_else(|| AnamnesisError::Parse {
reason: format!("tensor `{name}`: missing args[3]"),
})?;
let shape = tuple_to_usize_vec(it2)?;
let strides = tuple_to_usize_vec(it3)?;
if shape.len() != strides.len() {
return Err(AnamnesisError::Parse {
reason: format!(
"tensor `{name}`: shape ndim {} != strides ndim {}",
shape.len(),
strides.len()
),
});
}
Ok(TensorRef {
name: name.to_owned(),
storage_key,
dtype,
storage_offset,
shape,
strides,
})
}
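/// Recursion guard for walking nested `Reduced`/`Built` wrappers.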
const MAX_PICKLE_NESTING: u32 = 32;
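/// Walk `Reduced`/`Built` wrappers (e.g. `_rebuild_parameter`) down to the
/// underlying `_rebuild_tensor_v2` call, returning its callable and args.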
#[allow(clippy::wildcard_enum_match_arm)]
fn unwrap_to_rebuild(val: &PickleValue, depth: u32) -> Option<(&PickleValue, &PickleValue)> {
if depth > MAX_PICKLE_NESTING {
return None;
}
match val {
PickleValue::Reduced { callable, args, .. } => {
if let PickleValue::Global { module, name } = callable.as_ref() {
if module == "torch._utils" && name == "_rebuild_tensor_v2" {
return Some((callable, args));
}
if module == "torch._utils"
&& (name == "_rebuild_parameter" || name == "_rebuild_parameter_with_state")
{
if let PickleValue::Tuple(items) = args.as_ref() {
if let Some(first) = items.first() {
return unwrap_to_rebuild(first, depth + 1);
}
}
}
}
None
}
PickleValue::Built { obj, .. } => unwrap_to_rebuild(obj, depth + 1),
_ => None,
}
}
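/// Resolve the pickle root to the state_dict key/value pairs, looking through
/// `Built` wrappers up to the nesting limit.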
#[allow(clippy::wildcard_enum_match_arm)]
fn extract_dict_pairs(
root: &PickleValue,
depth: u32,
) -> crate::Result<&[(PickleValue, PickleValue)]> {
if depth > MAX_PICKLE_NESTING {
return Err(AnamnesisError::Parse {
reason: "pickle nesting limit exceeded in extract_dict_pairs".into(),
});
}
match root {
PickleValue::Dict(pairs) => Ok(pairs),
PickleValue::Reduced { callable, args: _ } => {
if let PickleValue::Global { module, name } = callable.as_ref() {
if module == "collections" && name == "OrderedDict" {
return Err(AnamnesisError::Parse {
reason: "OrderedDict arrived as Reduced (expected Dict \
after REDUCE rewrite); possible pickle VM bug"
.into(),
});
}
}
Err(AnamnesisError::Parse {
reason: format!("top-level pickle value is not a dict: {root:?}"),
})
}
PickleValue::Built { obj, state: _ } => {
extract_dict_pairs(obj, depth + 1)
}
_ => Err(AnamnesisError::Parse {
reason: format!("top-level pickle value is not a dict or OrderedDict: {root:?}"),
}),
}
}
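/// Row-major (C-contiguous) strides for a shape, in elements.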
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
let ndim = shape.len();
let mut strides = vec![1usize; ndim];
for i in (0..ndim.saturating_sub(1)).rev() {
if let (Some(&prev), Some(&dim)) = (strides.get(i + 1), shape.get(i + 1)) {
if let Some(s) = strides.get_mut(i) {
*s = prev.saturating_mul(dim);
}
}
}
strides
}
fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
if shape.len() != strides.len() {
return false;
}
let expected = contiguous_strides(shape);
strides == expected
}
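/// Gather a strided tensor into a freshly allocated row-major byte buffer.
/// Bounds are validated up front from the maximum element offset, so the copy
/// loop itself cannot read past the storage slice.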
fn copy_to_contiguous(
storage: &[u8],
offset: usize,
shape: &[usize],
strides: &[usize],
elem_size: usize,
) -> crate::Result<Vec<u8>> {
if shape.len() != strides.len() {
return Err(AnamnesisError::Parse {
reason: format!(
"shape ndim {} != strides ndim {}",
shape.len(),
strides.len()
),
});
}
let n_elements: usize = shape
.iter()
.try_fold(1usize, |acc, &d| acc.checked_mul(d))
.ok_or_else(|| AnamnesisError::Parse {
reason: "element count overflow".into(),
})?;
let out_bytes = n_elements
.checked_mul(elem_size)
.ok_or_else(|| AnamnesisError::Parse {
reason: "output size overflow".into(),
})?;
let max_elem_offset: usize = shape
.iter()
.zip(strides.iter())
.try_fold(0usize, |acc, (&dim, &stride)| {
dim.checked_sub(1)
.and_then(|d| d.checked_mul(stride))
.and_then(|ds| acc.checked_add(ds))
})
.ok_or_else(|| AnamnesisError::Parse {
reason: "max stride offset overflow".into(),
})?;
let max_src_end = offset
.checked_add(max_elem_offset.checked_mul(elem_size).ok_or_else(|| {
AnamnesisError::Parse {
reason: "max source byte offset overflow".into(),
}
})?)
.and_then(|b| b.checked_add(elem_size))
.ok_or_else(|| AnamnesisError::Parse {
reason: "max source end offset overflow".into(),
})?;
if max_src_end > storage.len() {
return Err(AnamnesisError::Parse {
reason: format!(
"non-contiguous tensor: max source byte [{max_src_end}] \
exceeds storage len {}",
storage.len()
),
});
}
let mut out = vec![0u8; out_bytes];
let ndim = shape.len();
let mut coords = vec![0usize; ndim];
for flat_idx in 0..n_elements {
let src_elem_offset: usize = coords
.iter()
.zip(strides.iter())
.map(|(&c, &s)| c * s)
.sum();
let src_byte = offset + src_elem_offset * elem_size;
let dst_byte = flat_idx * elem_size;
#[allow(clippy::indexing_slicing)]
out[dst_byte..dst_byte + elem_size]
.copy_from_slice(&storage[src_byte..src_byte + elem_size]);
for d in (0..ndim).rev() {
if let (Some(c), Some(&s)) = (coords.get_mut(d), shape.get(d)) {
*c += 1;
if *c < s {
break;
}
*c = 0;
}
}
}
Ok(out)
}
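/// Parse a PyTorch ZIP-format `.pth` checkpoint.
///
/// The file is memory-mapped, the stored (uncompressed) ZIP entries are
/// indexed, the `byteorder` marker is read if present, and `data.pkl` is
/// interpreted with the restricted pickle VM above. Legacy (pre-1.6, non-ZIP)
/// checkpoints are rejected with an `Unsupported` error.
///
/// A minimal usage sketch (the import path is assumed; adjust it to this
/// crate's actual re-exports):
///
/// ```ignore
/// let parsed = parse_pth("model.pth")?;
/// for t in parsed.tensor_info() {
///     println!("{} {:?} {}", t.name, t.shape, t.dtype);
/// }
/// parsed.to_safetensors("model.safetensors")?;
/// ```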
#[allow(unsafe_code)]
pub fn parse_pth(path: impl AsRef<Path>) -> crate::Result<ParsedPth> {
let file = std::fs::File::open(path.as_ref())?;
    // SAFETY: the mapping is read-only; as with any file-backed mmap, it is
    // only sound while the underlying file is not truncated or modified.
    let raw =
        unsafe { memmap2::MmapOptions::new().populate().map(&file) }.map_err(AnamnesisError::Io)?;
let magic = raw.get(..4).ok_or_else(|| AnamnesisError::Parse {
reason: "file too small to be a .pth archive".into(),
})?;
if magic.first() == Some(&0x80) && magic.get(1).is_some_and(|&b| b <= 0x05) {
return Err(AnamnesisError::Unsupported {
format: "pth".into(),
detail: "legacy .pth format (pre-PyTorch 1.6) is not supported; \
re-save with torch.save()"
.into(),
});
}
if magic != b"PK\x03\x04" {
return Err(AnamnesisError::Parse {
reason: "file is not a ZIP archive (missing PK\\x03\\x04 magic)".into(),
});
}
let cursor = std::io::Cursor::new(&raw[..]);
let mut archive = zip::ZipArchive::new(cursor)?;
let entry_index = build_entry_index(&mut archive, &raw)?;
let big_endian = match entry_index.get("byteorder") {
Some(&(start, len)) => {
let bytes = raw
.get(start..start + len)
.ok_or_else(|| AnamnesisError::Parse {
reason: "byteorder entry out of bounds".into(),
})?;
let text = std::str::from_utf8(bytes).map_err(|e| AnamnesisError::Parse {
reason: format!("byteorder entry is not UTF-8: {e}"),
})?;
match text.trim() {
"little" => false,
"big" => true,
other => {
return Err(AnamnesisError::Parse {
reason: format!(
"unknown byte order `{other}` (expected `little` or `big`)"
),
})
}
}
}
        None => false,
    };
let &(pkl_start, pkl_len) =
entry_index
.get("data.pkl")
.ok_or_else(|| AnamnesisError::Parse {
reason: "ZIP entry `data.pkl` not found".into(),
})?;
let pkl_data =
raw.get(pkl_start..pkl_start + pkl_len)
.ok_or_else(|| AnamnesisError::Parse {
reason: "data.pkl slice out of bounds".into(),
})?;
let mut vm = PickleVm::new(pkl_data);
let root = vm.execute()?;
let dict_pairs = extract_dict_pairs(&root, 0)?;
let mut meta = Vec::new();
for (key, value) in dict_pairs {
let name = as_str(key)?;
if let Some((_callable, args)) = unwrap_to_rebuild(value, 0) {
let tref = parse_rebuild_args(name, args)?;
meta.push(TensorMeta {
name: tref.name,
shape: tref.shape,
dtype: tref.dtype,
storage_key: tref.storage_key,
storage_offset: tref.storage_offset,
strides: tref.strides,
});
}
}
Ok(ParsedPth {
mmap: raw,
meta,
entry_index,
big_endian,
})
}
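/// Index the archive's `Stored` (uncompressed) entries by name suffix (the
/// part after the leading `archive/` directory), mapping each to its byte
/// offset and length within the raw file. Compressed entries are skipped,
/// since tensor data must be readable in place from the mmap.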
fn build_entry_index(
archive: &mut zip::ZipArchive<std::io::Cursor<&[u8]>>,
raw: &[u8],
) -> crate::Result<HashMap<String, (usize, usize)>> {
let mut index = HashMap::with_capacity(archive.len());
for i in 0..archive.len() {
let entry = archive.by_index(i).map_err(|e| AnamnesisError::Parse {
reason: format!("failed to read ZIP entry {i}: {e}"),
})?;
if entry.compression() != zip::CompressionMethod::Stored {
continue;
}
let full_name = entry.name().to_owned();
let data_start =
usize::try_from(entry.data_start()).map_err(|_| AnamnesisError::Parse {
reason: format!("ZIP entry `{full_name}`: data_start overflows usize"),
})?;
let data_len = usize::try_from(entry.size()).map_err(|_| AnamnesisError::Parse {
reason: format!("ZIP entry `{full_name}`: size overflows usize"),
})?;
let data_end = data_start
.checked_add(data_len)
.ok_or_else(|| AnamnesisError::Parse {
reason: format!("ZIP entry `{full_name}`: data range overflow"),
})?;
if data_end > raw.len() {
return Err(AnamnesisError::Parse {
reason: format!(
"ZIP entry `{full_name}`: data range [{data_start}..{data_end}] \
exceeds file size {}",
raw.len()
),
});
}
let suffix = full_name
.find('/')
.map_or(full_name.as_str(), |pos| {
full_name.get(pos + 1..).unwrap_or(&full_name)
})
.to_owned();
if !suffix.is_empty() {
index.insert(suffix, (data_start, data_len));
}
}
Ok(index)
}
#[cfg(test)]
#[allow(
clippy::panic,
clippy::indexing_slicing,
clippy::unwrap_used,
clippy::as_conversions,
clippy::wildcard_enum_match_arm
)]
mod tests {
use std::io::Write;
use super::*;
#[test]
fn dtype_byte_sizes() {
assert_eq!(PthDtype::Bool.byte_size(), 1);
assert_eq!(PthDtype::U8.byte_size(), 1);
assert_eq!(PthDtype::I8.byte_size(), 1);
assert_eq!(PthDtype::F16.byte_size(), 2);
assert_eq!(PthDtype::BF16.byte_size(), 2);
assert_eq!(PthDtype::I16.byte_size(), 2);
assert_eq!(PthDtype::F32.byte_size(), 4);
assert_eq!(PthDtype::I32.byte_size(), 4);
assert_eq!(PthDtype::F64.byte_size(), 8);
assert_eq!(PthDtype::I64.byte_size(), 8);
}
#[test]
fn dtype_display() {
assert_eq!(PthDtype::F32.to_string(), "F32");
assert_eq!(PthDtype::BF16.to_string(), "BF16");
assert_eq!(PthDtype::Bool.to_string(), "BOOL");
}
#[test]
fn dtype_to_dtype_roundtrip() {
assert_eq!(PthDtype::F32.to_dtype().unwrap(), Dtype::F32);
assert_eq!(PthDtype::F16.to_dtype().unwrap(), Dtype::F16);
assert_eq!(PthDtype::BF16.to_dtype().unwrap(), Dtype::BF16);
assert_eq!(PthDtype::I64.to_dtype().unwrap(), Dtype::I64);
assert_eq!(PthDtype::Bool.to_dtype().unwrap(), Dtype::Bool);
}
#[test]
fn dtype_from_storage_class() {
assert_eq!(
PthDtype::from_storage_class("torch", "FloatStorage").unwrap(),
PthDtype::F32
);
assert_eq!(
PthDtype::from_storage_class("torch", "BFloat16Storage").unwrap(),
PthDtype::BF16
);
assert!(PthDtype::from_storage_class("torch", "UnknownStorage").is_err());
assert!(PthDtype::from_storage_class("numpy", "FloatStorage").is_err());
}
#[test]
fn long1_zero() {
assert_eq!(long1_to_i64(&[]).unwrap(), 0);
}
#[test]
fn long1_positive() {
assert_eq!(long1_to_i64(&[0xFF, 0x00]).unwrap(), 255);
assert_eq!(long1_to_i64(&[0x01]).unwrap(), 1);
assert_eq!(long1_to_i64(&[0x80, 0x00]).unwrap(), 128);
}
#[test]
fn long1_negative() {
assert_eq!(long1_to_i64(&[0xFF]).unwrap(), -1);
assert_eq!(long1_to_i64(&[0x80]).unwrap(), -128);
}
#[test]
fn long1_too_large() {
        let big = vec![0x01; 9];
        assert!(long1_to_i64(&big).is_err());
}
#[test]
fn contiguous_strides_2d() {
assert_eq!(contiguous_strides(&[3, 4]), vec![4, 1]);
assert_eq!(contiguous_strides(&[16, 10]), vec![10, 1]);
}
#[test]
fn contiguous_strides_1d() {
assert_eq!(contiguous_strides(&[5]), vec![1]);
}
#[test]
fn contiguous_strides_scalar() {
assert_eq!(contiguous_strides(&[]), Vec::<usize>::new());
}
#[test]
fn is_contiguous_true() {
assert!(is_contiguous(&[3, 4], &[4, 1]));
assert!(is_contiguous(&[5], &[1]));
}
#[test]
fn is_contiguous_transposed() {
assert!(!is_contiguous(&[3, 4], &[1, 3]));
}
#[test]
fn vm_simple_int() {
let pkl = &[0x80, 0x02, b'K', 42, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::Int(42)));
}
#[test]
fn vm_string() {
let pkl = &[0x80, 0x02, 0x8c, 0x02, b'h', b'i', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::String(ref s) if s == "hi"));
}
#[test]
fn vm_tuple2() {
let pkl = &[0x80, 0x02, b'K', 1, b'K', 2, 0x86, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert_eq!(items.len(), 2);
} else {
panic!("expected Tuple");
}
}
#[test]
fn vm_empty_dict() {
let pkl = &[0x80, 0x02, b'}', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::Dict(ref d) if d.is_empty()));
}
#[test]
fn vm_dict_with_setitem() {
let pkl = &[0x80, 0x02, b'}', 0x8c, 1, b'k', b'K', 7, b's', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Dict(pairs) = result {
assert_eq!(pairs.len(), 1);
} else {
panic!("expected Dict");
}
}
#[test]
fn vm_memo_roundtrip() {
let pkl = &[
            0x80, 0x02, b'K', 99, b'q', 0, b'K', 0, b'h', 0, 0x86, b'.',
        ];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert_eq!(items.len(), 2);
assert!(matches!(&items[0], PickleValue::Int(0)));
assert!(matches!(&items[1], PickleValue::Int(99)));
} else {
panic!("expected Tuple");
}
}
#[test]
fn vm_rejects_disallowed_global() {
let pkl = b"\x80\x02cos\nsystem\n.";
let mut vm = PickleVm::new(pkl);
let err = vm.execute().unwrap_err();
let msg = err.to_string();
assert!(msg.contains("disallowed pickle global"), "got: {msg}");
assert!(msg.contains("os.system"), "got: {msg}");
}
#[test]
fn vm_rejects_unknown_opcode() {
let pkl = &[0x80, 0x02, 0xFF, b'.'];
let mut vm = PickleVm::new(pkl);
let err = vm.execute().unwrap_err();
assert!(err.to_string().contains("unsupported pickle opcode 0xff"));
}
#[test]
fn vm_allows_torch_global() {
let pkl = b"\x80\x02ctorch._utils\n_rebuild_tensor_v2\n.";
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(
result,
PickleValue::Global { ref module, ref name }
if module == "torch._utils" && name == "_rebuild_tensor_v2"
));
}
#[test]
fn reject_legacy_pth() {
let data = vec![0x80, 0x02, 0x00, 0x00, 0x00];
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(tmp.path(), &data).unwrap();
let err = parse_pth(tmp.path()).unwrap_err();
assert!(err.to_string().contains("legacy .pth format"));
}
#[test]
fn reject_non_zip() {
let data = vec![0x00, 0x01, 0x02, 0x03, 0x04];
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(tmp.path(), &data).unwrap();
let err = parse_pth(tmp.path()).unwrap_err();
assert!(err.to_string().contains("not a ZIP archive"));
}
#[test]
fn reject_too_small() {
        let data = vec![0x50, 0x4B];
        let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(tmp.path(), &data).unwrap();
let err = parse_pth(tmp.path()).unwrap_err();
assert!(err.to_string().contains("too small"));
}
#[test]
fn vm_frame_opcode() {
let pkl: &[u8] = &[
            0x80, 0x04, 0x95, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, b'K', 42, b'.',
        ];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::Int(42)));
}
#[test]
fn vm_none() {
let pkl: &[u8] = &[0x80, 0x02, b'N', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::None));
}
#[test]
fn vm_newtrue_newfalse() {
let pkl: &[u8] = &[0x80, 0x02, 0x88, 0x89, 0x86, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert!(matches!(items[0], PickleValue::Bool(true)));
assert!(matches!(items[1], PickleValue::Bool(false)));
} else {
panic!("expected Tuple");
}
}
#[test]
fn vm_binint() {
let pkl: &[u8] = &[0x80, 0x02, b'J', 0x04, 0x03, 0x02, 0x01, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::Int(0x0102_0304)));
}
#[test]
fn vm_binint_negative() {
let pkl: &[u8] = &[0x80, 0x02, b'J', 0xFF, 0xFF, 0xFF, 0xFF, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::Int(-1)));
}
#[test]
fn vm_binint2() {
let pkl: &[u8] = &[0x80, 0x02, b'M', 0x00, 0x01, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(result, PickleValue::Int(256)));
}
#[test]
fn vm_binunicode() {
let pkl: &[u8] = &[
0x80, 0x02, b'X', 0x03, 0x00, 0x00, 0x00, b'a', b'b', b'c', b'.',
];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::String(s) = result {
assert_eq!(s, "abc");
} else {
panic!("expected String, got {result:?}");
}
}
#[test]
fn vm_short_binstring() {
let pkl: &[u8] = &[0x80, 0x02, b'U', 0x02, b'x', b'y', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::String(s) = result {
assert_eq!(s, "xy");
} else {
panic!("expected String, got {result:?}");
}
}
#[test]
fn vm_short_binbytes() {
let pkl: &[u8] = &[0x80, 0x02, b'C', 0x02, 0xDE, 0xAD, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Bytes(b) = result {
assert_eq!(b, vec![0xDE, 0xAD]);
} else {
panic!("expected Bytes, got {result:?}");
}
}
#[test]
fn vm_empty_list() {
let pkl: &[u8] = &[0x80, 0x02, b']', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::List(items) = result {
assert!(items.is_empty());
} else {
panic!("expected List, got {result:?}");
}
}
#[test]
fn vm_empty_tuple() {
let pkl: &[u8] = &[0x80, 0x02, b')', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert!(items.is_empty());
} else {
panic!("expected Tuple, got {result:?}");
}
}
#[test]
fn vm_tuple1() {
let pkl: &[u8] = &[0x80, 0x02, b'K', 7, 0x85, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert_eq!(items.len(), 1);
assert!(matches!(items[0], PickleValue::Int(7)));
} else {
panic!("expected Tuple, got {result:?}");
}
}
#[test]
fn vm_tuple3() {
let pkl: &[u8] = &[0x80, 0x02, b'K', 1, b'K', 2, b'K', 3, 0x87, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert_eq!(items.len(), 3);
assert!(matches!(items[0], PickleValue::Int(1)));
assert!(matches!(items[1], PickleValue::Int(2)));
assert!(matches!(items[2], PickleValue::Int(3)));
} else {
panic!("expected Tuple, got {result:?}");
}
}
#[test]
fn vm_setitems() {
let pkl: &[u8] = &[
0x80, 0x02, b'}', b'(', 0x8C, 0x01, b'a', b'K', 1, 0x8C, 0x01, b'b', b'K', 2, b'u',
b'.',
];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Dict(pairs) = result {
assert_eq!(pairs.len(), 2);
} else {
panic!("expected Dict, got {result:?}");
}
}
#[test]
fn vm_append() {
let pkl: &[u8] = &[0x80, 0x02, b']', b'K', 42, b'a', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::List(items) = result {
assert_eq!(items.len(), 1);
assert!(matches!(items[0], PickleValue::Int(42)));
} else {
panic!("expected List, got {result:?}");
}
}
#[test]
fn vm_appends() {
let pkl: &[u8] = &[0x80, 0x02, b']', b'(', b'K', 1, b'K', 2, b'e', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::List(items) = result {
assert_eq!(items.len(), 2);
assert!(matches!(items[0], PickleValue::Int(1)));
assert!(matches!(items[1], PickleValue::Int(2)));
} else {
panic!("expected List, got {result:?}");
}
}
#[test]
fn vm_long_memo_roundtrip() {
let pkl: &[u8] = &[
            0x80, 0x02, b'K', 77, b'r', 0x01, 0x00, 0x00, 0x00, b'K', 0, b'j', 0x01, 0x00,
            0x00, 0x00, 0x86, b'.',
        ];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert_eq!(items.len(), 2);
assert!(matches!(items[0], PickleValue::Int(0)));
assert!(matches!(items[1], PickleValue::Int(77)));
} else {
panic!("expected Tuple, got {result:?}");
}
}
#[test]
fn vm_memoize() {
let pkl: &[u8] = &[
            0x80, 0x04, b'K', 99, 0x94, b'K', 0, b'h', 0x00, 0x86, b'.',
        ];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Tuple(items) = result {
assert_eq!(items.len(), 2);
assert!(matches!(items[0], PickleValue::Int(0)));
assert!(matches!(items[1], PickleValue::Int(99)));
} else {
panic!("expected Tuple, got {result:?}");
}
}
#[test]
fn long1_8byte_negative() {
let result = long1_to_i64(&[0xFF; 8]).unwrap();
assert_eq!(result, -1);
}
#[test]
fn long1_8byte_max_positive() {
let result = long1_to_i64(&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]).unwrap();
assert_eq!(result, i64::MAX);
}
#[test]
fn long1_8byte_min_negative() {
let result = long1_to_i64(&[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80]).unwrap();
assert_eq!(result, i64::MIN);
}
#[test]
fn unwrap_to_rebuild_rejects_deep_nesting() {
let leaf = PickleValue::Int(0);
let result = unwrap_to_rebuild(&leaf, MAX_PICKLE_NESTING + 1);
assert!(result.is_none(), "should reject nesting beyond limit");
}
#[test]
fn copy_to_contiguous_transposed_2x3() {
let values: [f32; 6] = [0.0, 3.0, 1.0, 4.0, 2.0, 5.0];
let mut storage = Vec::new();
for v in &values {
storage.extend_from_slice(&v.to_le_bytes());
}
let result = copy_to_contiguous(&storage, 0, &[2, 3], &[1, 2], 4).unwrap();
let expected: [f32; 6] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0];
let mut expected_bytes = Vec::new();
for v in &expected {
expected_bytes.extend_from_slice(&v.to_le_bytes());
}
assert_eq!(result, expected_bytes);
}
#[test]
fn vm_binstring() {
let pkl: &[u8] = &[0x80, 0x02, b'T', 0x02, 0x00, 0x00, 0x00, b'a', b'b', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::String(s) = result {
assert_eq!(s, "ab");
} else {
panic!("expected String, got {result:?}");
}
}
#[test]
fn vm_binbytes() {
let pkl: &[u8] = &[0x80, 0x03, b'B', 0x02, 0x00, 0x00, 0x00, 0xCA, 0xFE, b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Bytes(b) = result {
assert_eq!(b, vec![0xCA, 0xFE]);
} else {
panic!("expected Bytes, got {result:?}");
}
}
#[test]
fn vm_stack_global() {
let pkl: &[u8] = &[
0x80, 0x04, 0x8C, 0x0C, b't', b'o', b'r', b'c', b'h', b'.', b'_', b'u', b't', b'i', b'l', b's', 0x8C,
0x12, b'_', b'r', b'e', b'b', b'u', b'i', b'l', b'd', b'_', b't', b'e', b'n', b's', b'o',
b'r', b'_', b'v', b'2', 0x93, b'.',
];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(matches!(
result,
PickleValue::Global { ref module, ref name }
if module == "torch._utils" && name == "_rebuild_tensor_v2"
));
}
#[test]
fn vm_reduce() {
let pkl = b"\x80\x02ctorch._utils\n_rebuild_tensor_v2\n)R.";
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(
matches!(result, PickleValue::Reduced { .. }),
"expected Reduced, got {result:?}"
);
}
#[test]
fn vm_newobj() {
let pkl: &[u8] = b"\x80\x02ctorch._utils\n_rebuild_tensor_v2\n)\x81.";
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
assert!(
matches!(result, PickleValue::Reduced { .. }),
"expected Reduced, got {result:?}"
);
}
#[test]
fn vm_build() {
let pkl: &[u8] = &[0x80, 0x02, b'K', 1, b'K', 2, b'b', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::Built { obj, state } = result {
assert!(matches!(*obj, PickleValue::Int(1)));
assert!(matches!(*state, PickleValue::Int(2)));
} else {
panic!("expected Built, got {result:?}");
}
}
#[test]
fn vm_binpersid() {
let pkl: &[u8] = &[0x80, 0x02, b'K', 42, b'Q', b'.'];
let mut vm = PickleVm::new(pkl);
let result = vm.execute().unwrap();
if let PickleValue::PersistentId(inner) = result {
assert!(matches!(*inner, PickleValue::Int(42)));
} else {
panic!("expected PersistentId, got {result:?}");
}
}
#[test]
fn vm_memoize_overflow() {
let pkl: &[u8] = &[0x80, 0x04, b'K', 0, 0x94, b'.'];
let mut vm = PickleVm::new(pkl);
vm.next_memo_id = u32::MAX;
let err = vm.execute().unwrap_err();
assert!(
err.to_string().contains("memo table overflow"),
"got: {err}"
);
}
#[test]
fn copy_to_contiguous_zero_elements_errors() {
let storage = vec![0u8; 16];
let result = copy_to_contiguous(&storage, 0, &[0, 4], &[4, 1], 4);
assert!(
result.is_err(),
"zero-dim in shape should error in max_elem_offset"
);
}
#[test]
fn copy_to_contiguous_element_count_overflow() {
let storage = vec![0u8; 8];
let result = copy_to_contiguous(&storage, 0, &[usize::MAX, 2], &[2, 1], 1);
assert!(result.is_err());
assert!(
result.unwrap_err().to_string().contains("overflow"),
"expected overflow error"
);
}
#[test]
fn copy_to_contiguous_offset_overflow() {
let storage = vec![0u8; 8];
let result = copy_to_contiguous(&storage, usize::MAX, &[1], &[1], 1);
assert!(result.is_err());
}
#[test]
fn copy_to_contiguous_offset_at_boundary() {
let storage = vec![0x01, 0x02, 0x03, 0x04];
let result = copy_to_contiguous(&storage, 0, &[1], &[1], 4).unwrap();
assert_eq!(result, vec![0x01, 0x02, 0x03, 0x04]);
}
#[test]
fn copy_to_contiguous_one_past_boundary() {
let storage = vec![0x01, 0x02, 0x03, 0x04];
let result = copy_to_contiguous(&storage, 1, &[1], &[1], 4);
assert!(result.is_err());
}
#[test]
fn is_contiguous_mismatched_dims() {
assert!(!is_contiguous(&[2, 3, 4], &[12, 4]));
}
#[test]
fn copy_to_contiguous_mismatched_ndim() {
        let storage = vec![0u8; 96];
        let result = copy_to_contiguous(&storage, 0, &[2, 3, 4], &[12, 4], 4);
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("ndim"),
"expected ndim mismatch error, got: {msg}"
);
}
#[test]
fn copy_to_contiguous_zero_stride_broadcast() {
let storage: Vec<u8> = vec![10, 20, 30];
let result = copy_to_contiguous(&storage, 0, &[2, 3], &[0, 1], 1).unwrap();
assert_eq!(result, vec![10, 20, 30, 10, 20, 30]);
}
#[test]
fn reject_zip_missing_data_pkl() {
let tmp = tempfile::NamedTempFile::new().unwrap();
{
let file = std::fs::File::create(tmp.path()).unwrap();
let mut zip = zip::ZipWriter::new(file);
let options = zip::write::SimpleFileOptions::default()
.compression_method(zip::CompressionMethod::Stored);
zip.start_file("archive/not_a_pkl.txt", options).unwrap();
zip.write_all(b"hello").unwrap();
zip.finish().unwrap();
}
let err = parse_pth(tmp.path()).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("data.pkl") && msg.contains("not found"),
"expected 'data.pkl not found', got: {msg}"
);
}
#[test]
fn reject_compressed_data_pkl() {
let tmp = tempfile::NamedTempFile::new().unwrap();
{
let file = std::fs::File::create(tmp.path()).unwrap();
let mut zip = zip::ZipWriter::new(file);
let options = zip::write::SimpleFileOptions::default()
.compression_method(zip::CompressionMethod::Deflated);
zip.start_file("archive/data.pkl", options).unwrap();
            zip.write_all(b"\x80\x02}.").unwrap();
            zip.finish().unwrap();
}
let err = parse_pth(tmp.path()).unwrap_err();
let msg = err.to_string();
assert!(
msg.contains("data.pkl") && msg.contains("not found"),
"expected 'data.pkl not found' (compressed entries are skipped), got: {msg}"
);
}
#[test]
fn zip_zero_length_entry_accepted() {
let tmp = tempfile::NamedTempFile::new().unwrap();
{
let file = std::fs::File::create(tmp.path()).unwrap();
let mut zip = zip::ZipWriter::new(file);
let opts = zip::write::SimpleFileOptions::default()
.compression_method(zip::CompressionMethod::Stored);
zip.start_file("archive/data.pkl", opts).unwrap();
zip.write_all(b"\x80\x02}.").unwrap();
zip.start_file("archive/data/0", opts).unwrap();
zip.finish().unwrap();
}
let parsed = parse_pth(tmp.path()).unwrap();
assert!(
parsed.tensors().unwrap().is_empty(),
"empty state_dict should produce no tensors"
);
}
#[test]
fn extract_dict_pairs_rejects_deep_nesting() {
let dict = PickleValue::Dict(Vec::new());
let result = extract_dict_pairs(&dict, MAX_PICKLE_NESTING + 1);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("nesting limit exceeded"),
"expected nesting limit error"
);
}
}