pt-loader 0.1.4

Safe parser-based PyTorch checkpoint converter to safetensors
Documentation
use pt_loader::{LoadOptions, PtCheckpoint};
use std::fs::File;
use std::io::Write;
use tempfile::tempdir;
use zip::write::SimpleFileOptions;
use zip::ZipWriter;

#[test]
fn rejects_non_zip_pt() {
  let dir = tempdir().expect("tempdir");
  let pt = dir.path().join("legacy.pt");
  std::fs::write(&pt, b"not a zip checkpoint").expect("write");

  let err = PtCheckpoint::load(&pt, LoadOptions::default()).expect_err("should fail");
  assert!(err.to_string().contains("torch zip checkpoints"));
}

#[test]
fn enforces_pickle_size_limit() {
  let dir = tempdir().expect("tempdir");
  let pt = dir.path().join("large.pt");

  let file = File::create(&pt).expect("create");
  let mut zip = ZipWriter::new(file);
  zip
    .start_file("archive/data.pkl", SimpleFileOptions::default())
    .expect("start");
  zip.write_all(&vec![0u8; 1024]).expect("write");
  zip
    .start_file("archive/data/0", SimpleFileOptions::default())
    .expect("start data");
  zip.write_all(&[0u8; 16]).expect("write data");
  zip.finish().expect("finish");

  let opts = LoadOptions {
    max_pickle_bytes: 32,
    ..LoadOptions::default()
  };

  let err = PtCheckpoint::load(&pt, opts).expect_err("should fail");
  assert!(err.to_string().contains("data.pkl"));
}

#[test]
fn sample_last_pt_no_longer_fails_on_call_build_state() {
  let crate_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
  let path = crate_root.join("samples").join("last.pt");
  if !path.exists() {
    return;
  }

  let err = match PtCheckpoint::load(&path, LoadOptions::default()) {
    Ok(_) => return,
    Err(err) => err,
  };
  assert!(
    !err.to_string().contains("BUILD with non-empty state is not supported"),
    "regression: parser still fails with unsupported call BUILD state"
  );
}