use std::collections::BTreeMap;
use std::io::Write;
use tempfile::NamedTempFile;
/// Description of one tensor entry to be written into a pickled state dict.
///
/// Each field is emitted verbatim by `build_pickle_state_dict`; nothing is
/// validated or cross-checked (for example, `numel` is not compared against
/// `shape`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PtTensorSpec {
    /// Dictionary key under which the tensor is stored.
    pub name: String,
    /// Name of the storage class, emitted as a global looked up in the
    /// `torch` module (presumably something like `FloatStorage`).
    pub storage_type: String,
    /// Storage identifier placed in the persistent-id tuple; presumably it
    /// matches a key of the storage map handed to `build_pytorch_zip`.
    pub storage_key: String,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Per-dimension strides (see `compute_strides` for the contiguous case).
    pub stride: Vec<usize>,
    /// Offset of the tensor's first element within its storage.
    pub storage_offset: usize,
    /// Element count recorded in the persistent-id tuple.
    pub numel: usize,
}
/// Returns the row-major strides (in elements) for a tensor of the given
/// `shape`.
///
/// The last dimension gets stride 1 and every earlier dimension gets the
/// product of all dimensions that follow it. An empty shape yields an empty
/// vector. A zero-sized dimension makes every stride to its left zero.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    let mut running = 1usize;

    // Walk from the second-to-last slot towards the front, pairing each slot
    // with the dimension immediately to its right. The zip stops once the
    // slots run out, so `shape[0]` is never multiplied in.
    let slots = strides.iter_mut().rev().skip(1);
    for (slot, &dim) in slots.zip(shape.iter().rev()) {
        running *= dim;
        *slot = running;
    }

    strides
}
/// Appends the pickle encoding of the non-negative integer `val` to `p`.
///
/// The narrowest opcode able to hold the value is used:
/// - `BININT1` (`0x4b`) followed by 1 byte for values up to 255,
/// - `BININT2` (`0x4d`) followed by 2 little-endian bytes for values up to
///   65535,
/// - `BININT` (`0x4a`) followed by 4 little-endian bytes for values up to
///   `i32::MAX`,
/// - `LONG1` (`0x8a`) followed by a length byte and that many little-endian
///   two's-complement bytes for anything larger.
///
/// `BININT` carries a signed 32-bit payload, so values above `i32::MAX`
/// cannot be written with it without changing their meaning; those go
/// through `LONG1` instead of being truncated.
fn push_int(p: &mut Vec<u8>, val: usize) {
    if val <= 0xff {
        // Range-checked above, the cast cannot truncate.
        p.extend_from_slice(&[0x4b, val as u8]);
    } else if val <= 0xffff {
        p.push(0x4d);
        p.extend_from_slice(&(val as u16).to_le_bytes());
    } else if val <= i32::MAX as usize {
        p.push(0x4a);
        p.extend_from_slice(&(val as i32).to_le_bytes());
    } else {
        let le = val.to_le_bytes();
        // Drop high-order zero bytes to get the minimal magnitude.
        let used = le.iter().rposition(|&b| b != 0).map_or(1, |i| i + 1);
        // LONG1 is signed: if the top bit of the magnitude is set, a trailing
        // zero byte keeps the value positive.
        let needs_sign_pad = le[used - 1] & 0x80 != 0;
        p.push(0x8a);
        // At most size_of::<usize>() + 1, which always fits in a byte.
        p.push((used + usize::from(needs_sign_pad)) as u8);
        p.extend_from_slice(&le[..used]);
        if needs_sign_pad {
            p.push(0x00);
        }
    }
}
fn push_int_tuple(p: &mut Vec<u8>, vals: &[usize]) {
match vals.len() {
0 => p.push(0x29), 1 => {
push_int(p, vals[0]);
p.push(0x85); }
2 => {
push_int(p, vals[0]);
push_int(p, vals[1]);
p.push(0x86); }
3 => {
push_int(p, vals[0]);
push_int(p, vals[1]);
push_int(p, vals[2]);
p.push(0x87); }
_ => {
p.push(0x28); for &v in vals {
push_int(p, v);
}
p.push(0x74); }
}
}
/// Appends a pickle `GLOBAL` reference (`0x63`) to `p`.
///
/// The opcode is followed by `module` and then `name`, each terminated by a
/// newline byte. Neither string is escaped or validated.
fn push_global(p: &mut Vec<u8>, module: &str, name: &str) {
    p.push(0x63);
    for part in [module, name] {
        p.extend_from_slice(part.as_bytes());
        p.push(b'\n');
    }
}
/// Appends a pickle `SHORT_BINUNICODE` string (`0x8c`) to `p`.
///
/// The opcode is followed by a single length byte and the UTF-8 bytes of `s`.
///
/// # Panics
///
/// Panics if `s` is longer than 255 bytes, since the length must fit in one
/// byte.
fn push_short_binunicode(p: &mut Vec<u8>, s: &str) {
    assert!(s.len() <= 255);
    let payload = s.as_bytes();
    // Opcode and length byte form the header; the assert above guarantees
    // the length cast is lossless.
    p.extend_from_slice(&[0x8c, payload.len() as u8]);
    p.extend_from_slice(payload);
}
/// Builds a pickle byte stream describing a dict that maps each spec's name
/// to a `torch._utils._rebuild_tensor_v2(...)` call.
///
/// The stream starts with a protocol-2 header and an empty dict, then emits
/// one key/value pair per entry in `specs` (in slice order), and finishes
/// with `SETITEMS` and `STOP`. For every tensor the call arguments are:
///
/// 1. a persistent id built from the tuple
///    `("storage", torch.<storage_type>, storage_key, "cpu", numel)`,
/// 2. the storage offset,
/// 3. the shape tuple,
/// 4. the stride tuple,
/// 5. `False`,
/// 6. an empty `collections.OrderedDict()`.
///
/// NOTE(review): strings are written with `SHORT_BINUNICODE`, an opcode newer
/// than the protocol-2 header declares; confirm the consumers accept this.
fn build_pickle_state_dict(specs: &[PtTensorSpec]) -> Vec<u8> {
    const PROTO: u8 = 0x80;
    const EMPTY_DICT: u8 = 0x7d;
    const MARK: u8 = 0x28;
    const TUPLE: u8 = 0x74;
    const BINPERSID: u8 = 0x51;
    const NEWFALSE: u8 = 0x89;
    const EMPTY_TUPLE: u8 = 0x29;
    const REDUCE: u8 = 0x52;
    const SETITEMS: u8 = 0x75;
    const STOP: u8 = 0x2e;

    // Header, the dict itself, and the MARK that SETITEMS later pops back to.
    let mut out = vec![PROTO, 2, EMPTY_DICT, MARK];

    for tensor in specs {
        // Dict key.
        push_short_binunicode(&mut out, &tensor.name);

        // Dict value: a callable plus an argument tuple, joined by REDUCE.
        push_global(&mut out, "torch._utils", "_rebuild_tensor_v2");
        // First MARK opens the argument tuple, second opens the
        // persistent-id tuple.
        out.extend_from_slice(&[MARK, MARK]);
        push_short_binunicode(&mut out, "storage");
        push_global(&mut out, "torch", &tensor.storage_type);
        push_short_binunicode(&mut out, &tensor.storage_key);
        push_short_binunicode(&mut out, "cpu");
        push_int(&mut out, tensor.numel);
        out.extend_from_slice(&[TUPLE, BINPERSID]);

        push_int(&mut out, tensor.storage_offset);
        push_int_tuple(&mut out, &tensor.shape);
        push_int_tuple(&mut out, &tensor.stride);
        out.push(NEWFALSE);

        // OrderedDict() with no arguments, then close the argument tuple and
        // apply the outer callable.
        push_global(&mut out, "collections", "OrderedDict");
        out.extend_from_slice(&[EMPTY_TUPLE, REDUCE, TUPLE, REDUCE]);
    }

    out.extend_from_slice(&[SETITEMS, STOP]);
    out
}
/// Writes a ZIP archive containing a pickled state dict and its raw storage
/// blobs into a fresh temporary file and returns that file.
///
/// Archive layout:
/// - `archive/data.pkl`: the pickle produced by `build_pickle_state_dict`
///   for `specs`,
/// - `archive/data/<key>`: one entry per item of `storage_data`, written in
///   the map's key order.
///
/// All entries use the `Stored` (uncompressed) method.
///
/// # Panics
///
/// Panics if the temporary file cannot be created or if any ZIP write fails.
pub fn build_pytorch_zip(
    specs: &[PtTensorSpec],
    storage_data: &BTreeMap<String, Vec<u8>>,
) -> NamedTempFile {
    let mut archive_file = NamedTempFile::new().unwrap();
    {
        // Scoped so the writer's mutable borrow of the temp file ends before
        // the file is returned.
        let mut writer = zip::ZipWriter::new(&mut archive_file);
        let stored = zip::write::SimpleFileOptions::default()
            .compression_method(zip::CompressionMethod::Stored);

        let pickle_bytes = build_pickle_state_dict(specs);
        writer.start_file("archive/data.pkl", stored).unwrap();
        writer.write_all(&pickle_bytes).unwrap();

        for (storage_key, payload) in storage_data.iter() {
            let entry_name = format!("archive/data/{}", storage_key);
            writer.start_file(entry_name, stored).unwrap();
            writer.write_all(payload).unwrap();
        }

        writer.finish().unwrap();
    }
    archive_file
}