//! pt-loader 0.1.4
//!
//! Safe, parser-based converter from PyTorch checkpoints to safetensors.
//! This binary provides two subcommands: `inspect` and `convert`.
use pt_loader::{ExportFormat, ExportOptions, LoadOptions, PtCheckpoint};
use std::env;
use std::path::PathBuf;

/// Binary entry point: runs the CLI and maps any error to a message on
/// stderr plus exit status 1.
fn main() {
  match run() {
    Ok(()) => {}
    Err(error) => {
      eprintln!("error: {}", error);
      std::process::exit(1);
    }
  }
}

/// Parses the process arguments and dispatches to the requested subcommand.
///
/// The first argument selects the subcommand (`inspect` or `convert`); the
/// remaining arguments are handed to that subcommand.
///
/// # Errors
///
/// Prints usage and returns an error when no command or an unknown command is
/// given; otherwise propagates any error from the subcommand.
fn run() -> Result<(), Box<dyn std::error::Error>> {
  let args = env::args().skip(1).collect::<Vec<_>>();
  let Some((command, rest)) = args.split_first() else {
    print_usage();
    return Err("missing command".into());
  };

  match command.as_str() {
    "inspect" => run_inspect(rest),
    "convert" => run_convert(rest),
    _ => {
      print_usage();
      Err(format!("unknown command: {}", command).into())
    }
  }
}

/// `inspect <input.pt> [load flags]`: loads the checkpoint and prints its
/// metadata to stdout as pretty-printed JSON.
///
/// # Errors
///
/// Prints usage and errors when `<input.pt>` is missing; propagates flag
/// parsing, load and serialization errors.
fn run_inspect(args: &[String]) -> Result<(), Box<dyn std::error::Error>> {
  let Some((input, flags)) = args.split_first() else {
    print_usage();
    return Err("inspect requires <input.pt>".into());
  };
  let input = PathBuf::from(input);
  let load_opts = parse_load_options(flags)?;
  let checkpoint = PtCheckpoint::load(&input, load_opts)?;
  println!("{}", serde_json::to_string_pretty(checkpoint.metadata())?);
  Ok(())
}

/// `convert <input.pt> [--out-dir <dir>] [load flags]`: loads the checkpoint,
/// exports it as safetensors into `--out-dir` (default `out`), and prints the
/// export result to stdout as pretty-printed JSON.
///
/// # Errors
///
/// Prints usage and errors when `<input.pt>` is missing; errors when
/// `--out-dir` has no value; propagates flag parsing, load, export and
/// serialization errors.
fn run_convert(args: &[String]) -> Result<(), Box<dyn std::error::Error>> {
  let Some((input, flags)) = args.split_first() else {
    print_usage();
    return Err("convert requires <input.pt> and optional --out-dir <dir>".into());
  };
  let input = PathBuf::from(input);

  let mut out_dir = PathBuf::from("out");
  let mut load_args = Vec::with_capacity(flags.len());
  // Every flag this CLI accepts takes exactly one value, so in a well-formed
  // command line flags and values line up in chunks of two. `--out-dir` is
  // consumed here; every other chunk is forwarded verbatim to
  // `parse_load_options`, the single owner of load-flag parsing, which also
  // reports unknown or malformed arguments. (If several arguments are
  // malformed, a bad `--out-dir` may be reported before an earlier bad load
  // flag.)
  for pair in flags.chunks(2) {
    match pair {
      [flag, value] if flag == "--out-dir" => out_dir = PathBuf::from(value),
      [flag] if flag == "--out-dir" => return Err("--out-dir requires a value".into()),
      _ => load_args.extend_from_slice(pair),
    }
  }
  let load_opts = parse_load_options(&load_args)?;

  let checkpoint = PtCheckpoint::load(&input, load_opts)?;
  let result = checkpoint.export(&out_dir, ExportOptions::new(ExportFormat::Safetensors, Some(&input)))?;
  println!("{}", serde_json::to_string_pretty(&result)?);
  Ok(())
}

/// Writes the CLI usage summary to stderr, one line per entry.
fn print_usage() {
  const USAGE_LINES: [&str; 3] = [
    "Usage:",
    "  pt-loader inspect <input.pt> [--state-dict-root-key <key>]... [--state-dict-root-strict <true|false>]",
    "  pt-loader convert <input.pt> [--out-dir <dir>] [--state-dict-root-key <key>]... [--state-dict-root-strict <true|false>]",
  ];
  for line in USAGE_LINES {
    eprintln!("{}", line);
  }
}

/// Parses a command-line boolean value.
///
/// Accepts `true`/`1`/`yes`/`on` as `true` and `false`/`0`/`no`/`off` as
/// `false`. Matching is ASCII case-insensitive, so spellings such as `True` or
/// `FALSE` (common when copying values from Python) are accepted too.
/// Returns `None` for any other input, including the empty string.
fn parse_bool_flag(value: &str) -> Option<bool> {
  match value.to_ascii_lowercase().as_str() {
    "true" | "1" | "yes" | "on" => Some(true),
    "false" | "0" | "no" | "off" => Some(false),
    _ => None,
  }
}

/// Builds `LoadOptions` from a list of `--flag value` arguments.
///
/// Recognized flags:
/// - `--state-dict-root-key <key>`: may repeat; each key is appended in order.
/// - `--state-dict-root-strict <bool>`: parsed with `parse_bool_flag`; the last
///   occurrence wins.
///
/// # Errors
///
/// Returns an error for a flag with no following value, an unparsable boolean,
/// or any unrecognized argument.
fn parse_load_options(args: &[String]) -> Result<LoadOptions, Box<dyn std::error::Error>> {
  let mut options = LoadOptions::default();
  let mut remaining = args.iter();
  // Each flag pulls its own value from the same iterator, so a missing value
  // shows up as `None` from the inner `next()`.
  while let Some(flag) = remaining.next() {
    match flag.as_str() {
      "--state-dict-root-key" => {
        let key = remaining.next().ok_or("--state-dict-root-key requires a value")?;
        options.state_dict_root_keys.push(key.clone());
      }
      "--state-dict-root-strict" => {
        let raw = remaining.next().ok_or("--state-dict-root-strict requires a value")?;
        options.state_dict_root_strict =
          parse_bool_flag(raw).ok_or_else(|| "--state-dict-root-strict expects true|false".to_string())?;
      }
      unknown => return Err(format!("unknown argument: {}", unknown).into()),
    }
  }
  Ok(options)
}