use super::reader::{
from_utf8, memo_get, pop, read_exact, read_i32, read_line, read_signed_long, read_u16,
read_u32, read_u64,
};
use super::reduce::{apply_reduce, finalize_root, resolve_persistent_id};
use super::types::{PtContents, PtTensorMeta};
use crate::error::{Error, Result};
use numr::dtype::DType;
use std::collections::HashMap;
use std::io::{Cursor, Read};
/// Value model for the subset of Python pickle objects that `parse_pickle`
/// places on its interpreter stack.
///
/// Values are cloned when they are stored in the memo table, hence the
/// `Clone` derive.
#[derive(Debug, Clone)]
pub(super) enum PValue {
/// Python `None` (pushed for the `N` opcode).
None,
/// Python `True` / `False` (pushed for opcodes `0x88` / `0x89`).
Bool(bool),
/// Integer produced by any of the integer opcodes, widened to `i64`.
Int(i64),
/// String decoded from a length-prefixed payload via `from_utf8`.
Str(String),
/// Tuple built from popped stack items.
Tuple(Vec<PValue>),
/// List; extended in place by the APPENDS opcode.
List(Vec<PValue>),
/// Dictionary kept as key/value pairs in the order they were pushed.
Dict(Vec<(PValue, PValue)>),
/// Global reference as `(module, name)`, created by the `c` and `0x93` opcodes.
Global(String, String),
/// Not constructed in this file; presumably the result of resolving a
/// persistent id into a storage reference — confirm in the `reduce` module.
PersistentRef {
dtype: DType,
storage_id: String,
numel: usize,
},
/// Not constructed in this file; presumably a rebuilt tensor description
/// produced by the reduce step — confirm in the `reduce` module.
Tensor(PtTensorMeta),
/// Not constructed in this file (hence the `dead_code` allowance);
/// NOTE(review): looks like a placeholder for unrecognised `(module, name)`
/// objects — confirm against the `reduce` module.
#[allow(dead_code)]
Opaque(String, String),
}
pub(super) fn parse_pickle(bytes: &[u8]) -> Result<PtContents> {
let mut cur = Cursor::new(bytes);
let mut stack: Vec<PValue> = Vec::new();
let mut memo: HashMap<u32, PValue> = HashMap::new();
let mut marks: Vec<usize> = Vec::new();
loop {
let mut op = [0u8; 1];
cur.read_exact(&mut op).map_err(|e| Error::ModelError {
reason: format!("truncated pickle: {e}"),
})?;
match op[0] {
0x80 => {
let mut v = [0u8; 1];
read_exact(&mut cur, &mut v)?;
}
0x95 => {
let mut v = [0u8; 8];
read_exact(&mut cur, &mut v)?;
}
b'}' => stack.push(PValue::Dict(Vec::new())),
b']' => stack.push(PValue::List(Vec::new())),
b')' => stack.push(PValue::Tuple(Vec::new())),
b'(' => marks.push(stack.len()),
0x8c => {
let mut len = [0u8; 1];
read_exact(&mut cur, &mut len)?;
let mut buf = vec![0u8; len[0] as usize];
read_exact(&mut cur, &mut buf)?;
stack.push(PValue::Str(from_utf8(buf)?));
}
0x8d | b'X' => {
let len = if op[0] == 0x8d {
read_u64(&mut cur)? as usize
} else {
read_u32(&mut cur)? as usize
};
let mut buf = vec![0u8; len];
read_exact(&mut cur, &mut buf)?;
stack.push(PValue::Str(from_utf8(buf)?));
}
b'K' => {
let mut v = [0u8; 1];
read_exact(&mut cur, &mut v)?;
stack.push(PValue::Int(v[0] as i64));
}
b'M' => {
let v = read_u16(&mut cur)?;
stack.push(PValue::Int(v as i64));
}
b'J' => {
let v = read_i32(&mut cur)?;
stack.push(PValue::Int(v as i64));
}
0x8a => {
let mut len = [0u8; 1];
read_exact(&mut cur, &mut len)?;
stack.push(PValue::Int(read_signed_long(&mut cur, len[0] as usize)?));
}
0x8b => {
let len = read_u32(&mut cur)? as usize;
stack.push(PValue::Int(read_signed_long(&mut cur, len)?));
}
0x85 => {
let a = pop(&mut stack)?;
stack.push(PValue::Tuple(vec![a]));
}
0x86 => {
let b = pop(&mut stack)?;
let a = pop(&mut stack)?;
stack.push(PValue::Tuple(vec![a, b]));
}
0x87 => {
let c = pop(&mut stack)?;
let b = pop(&mut stack)?;
let a = pop(&mut stack)?;
stack.push(PValue::Tuple(vec![a, b, c]));
}
b't' => {
let mark = marks.pop().ok_or_else(|| Error::ModelError {
reason: "TUPLE opcode without MARK".into(),
})?;
let items = stack.split_off(mark);
stack.push(PValue::Tuple(items));
}
b'l' => {
let mark = marks.pop().ok_or_else(|| Error::ModelError {
reason: "LIST opcode without MARK".into(),
})?;
let items = stack.split_off(mark);
stack.push(PValue::List(items));
}
b'u' => {
let mark = marks.pop().ok_or_else(|| Error::ModelError {
reason: "SETITEMS without MARK".into(),
})?;
let items = stack.split_off(mark);
let dict = stack.last_mut().ok_or_else(|| Error::ModelError {
reason: "SETITEMS without dict on stack".into(),
})?;
let mut iter = items.into_iter();
if let PValue::Dict(entries) = dict {
while let (Some(k), Some(v)) = (iter.next(), iter.next()) {
entries.push((k, v));
}
} else {
return Err(Error::ModelError {
reason: "SETITEMS target is not a dict".into(),
});
}
}
b's' => {
let v = pop(&mut stack)?;
let k = pop(&mut stack)?;
let dict = stack.last_mut().ok_or_else(|| Error::ModelError {
reason: "SETITEM without dict on stack".into(),
})?;
if let PValue::Dict(entries) = dict {
entries.push((k, v));
} else {
return Err(Error::ModelError {
reason: "SETITEM target is not a dict".into(),
});
}
}
b'e' => {
let mark = marks.pop().ok_or_else(|| Error::ModelError {
reason: "APPENDS without MARK".into(),
})?;
let items = stack.split_off(mark);
if let Some(PValue::List(l)) = stack.last_mut() {
l.extend(items);
} else {
return Err(Error::ModelError {
reason: "APPENDS target is not a list".into(),
});
}
}
b'c' => {
let module = read_line(&mut cur)?;
let name = read_line(&mut cur)?;
stack.push(PValue::Global(module, name));
}
0x93 => {
let name = match pop(&mut stack)? {
PValue::Str(s) => s,
_ => {
return Err(Error::ModelError {
reason: "STACK_GLOBAL: name is not a string".into(),
});
}
};
let module = match pop(&mut stack)? {
PValue::Str(s) => s,
_ => {
return Err(Error::ModelError {
reason: "STACK_GLOBAL: module is not a string".into(),
});
}
};
stack.push(PValue::Global(module, name));
}
b'R' => {
let args = match pop(&mut stack)? {
PValue::Tuple(a) => a,
other => vec![other],
};
let func = pop(&mut stack)?;
let result = apply_reduce(func, args)?;
stack.push(result);
}
b'b' => {
let _state = pop(&mut stack)?;
}
b'Q' => {
let pid = pop(&mut stack)?;
stack.push(resolve_persistent_id(pid)?);
}
0x94 => {
let top = stack.last().ok_or_else(|| Error::ModelError {
reason: "MEMOIZE with empty stack".into(),
})?;
memo.insert(memo.len() as u32, top.clone());
}
b'q' => {
let mut idx = [0u8; 1];
read_exact(&mut cur, &mut idx)?;
if let Some(top) = stack.last() {
memo.insert(idx[0] as u32, top.clone());
}
}
b'r' => {
let idx = read_u32(&mut cur)?;
if let Some(top) = stack.last() {
memo.insert(idx, top.clone());
}
}
b'h' => {
let mut idx = [0u8; 1];
read_exact(&mut cur, &mut idx)?;
let v = memo_get(&memo, idx[0] as u32)?;
stack.push(v);
}
b'j' => {
let idx = read_u32(&mut cur)?;
let v = memo_get(&memo, idx)?;
stack.push(v);
}
0x89 => stack.push(PValue::Bool(false)),
0x88 => stack.push(PValue::Bool(true)),
b'N' => stack.push(PValue::None),
b'.' => break, other => {
return Err(Error::ModelError {
reason: format!("unsupported pickle opcode {:#04x}", other),
});
}
}
}
let root = stack.pop().ok_or_else(|| Error::ModelError {
reason: "pickle ended with empty stack".into(),
})?;
finalize_root(root)
}
#[cfg(test)]
mod tests {
    use super::super::api::load_tensor_pt;
    use super::super::reduce::{as_usize_tuple, dtype_from_global};
    use super::*;

    /// Loading a file whose contents are not a zip archive must fail.
    #[test]
    fn rejects_non_zip() {
        use numr::runtime::cpu::{CpuDevice, CpuRuntime};
        // The process id keeps concurrent test processes from sharing one fixture file.
        let file_name = format!("boostr_torch_pt_non_zip_{}.pt", std::process::id());
        let tmp = std::env::temp_dir().join(file_name);
        std::fs::write(&tmp, b"not a zip").unwrap();
        let device = CpuDevice::new();
        let res = load_tensor_pt::<CpuRuntime>(&tmp, None, &device);
        // Clean up before asserting so a failing assertion does not leave the fixture behind.
        let _ = std::fs::remove_file(&tmp);
        assert!(res.is_err());
    }

    /// Loading a path that does not exist must fail.
    #[test]
    fn rejects_missing_file() {
        use numr::runtime::cpu::{CpuDevice, CpuRuntime};
        let device = CpuDevice::new();
        let res = load_tensor_pt::<CpuRuntime>("/nonexistent-kokoro-voice-xyz.pt", None, &device);
        assert!(res.is_err());
    }

    /// Known torch storage globals map to dtypes; unknown ones are rejected.
    #[test]
    fn dtype_from_global_maps_common_torch_types() {
        let f32 = PValue::Global("torch".into(), "FloatStorage".into());
        assert_eq!(dtype_from_global(&f32).unwrap(), DType::F32);
        let f64 = PValue::Global("torch".into(), "DoubleStorage".into());
        assert_eq!(dtype_from_global(&f64).unwrap(), DType::F64);
        let bad = PValue::Global("torch".into(), "NewFangledStorage".into());
        assert!(dtype_from_global(&bad).is_err());
    }

    /// A bare integer is not accepted where a tuple is expected.
    #[test]
    fn as_usize_tuple_rejects_non_tuple() {
        assert!(as_usize_tuple(&PValue::Int(3)).is_err());
    }

    /// Popping from an empty interpreter stack reports an error.
    #[test]
    fn pickle_stack_underflow_is_an_error() {
        let mut stack: Vec<PValue> = Vec::new();
        assert!(pop(&mut stack).is_err());
    }
}