pt-loader 0.1.1

Safe parser-based PyTorch checkpoint converter to safetensors
Documentation
use crate::{
  convert_pt_to_safetensors, inspect_pt, parse_checkpoint, ConvertOptions, ConvertResult,
  InspectionReport,
};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use std::path::Path;

fn into_py_error(err: impl std::fmt::Display) -> PyErr {
  PyRuntimeError::new_err(err.to_string())
}

/// Serializes `value` into a compact (non-pretty) JSON string.
///
/// # Errors
/// A `serde_json` serialization failure is surfaced to Python as a
/// `RuntimeError` via [`into_py_error`].
fn to_json<T>(value: &T) -> PyResult<String>
where
  T: serde::Serialize,
{
  let json = serde_json::to_string(value).map_err(into_py_error)?;
  Ok(json)
}

/// Inspects the checkpoint at `input_pt` and returns the inspection report
/// serialized as a JSON string.
///
/// Raises `RuntimeError` if inspection fails or if the report cannot be
/// serialized.
#[pyfunction]
fn inspect_json(input_pt: &str) -> PyResult<String> {
  inspect_pt(Path::new(input_pt))
    .map_err(into_py_error)
    .and_then(|report: InspectionReport| to_json(&report))
}

/// Converts the checkpoint at `input_pt` to safetensors and returns the
/// conversion result serialized as a JSON string.
///
/// `out_dir` (default `"out"`) is forwarded as the output directory. The
/// keyword-only limits override the matching `ConvertOptions` defaults; passing
/// `None` keeps the default.
///
/// Raises `ValueError` if `input_pt` is empty, and `RuntimeError` if the
/// conversion fails or the result cannot be serialized.
#[pyfunction(
  signature = (
    input_pt,
    out_dir = "out",
    *,
    max_archive_bytes = None,
    max_tensor_count = None,
    max_tensor_bytes = None,
    max_pickle_bytes = None,
    strict_contiguous = None
  )
)]
fn convert_json(
  input_pt: &str,
  out_dir: &str,
  max_archive_bytes: Option<u64>,
  max_tensor_count: Option<usize>,
  max_tensor_bytes: Option<usize>,
  max_pickle_bytes: Option<usize>,
  strict_contiguous: Option<bool>,
) -> PyResult<String> {
  // Validation (including the empty-path check) happens before any I/O.
  let options = build_options(
    input_pt,
    max_archive_bytes,
    max_tensor_count,
    max_tensor_bytes,
    max_pickle_bytes,
    strict_contiguous,
  )?;

  let source = Path::new(input_pt);
  let destination = Path::new(out_dir);

  let outcome: ConvertResult =
    convert_pt_to_safetensors(source, destination, options).map_err(into_py_error)?;
  to_json(&outcome)
}

/// Parses the checkpoint at `input_pt` and returns its tensors as a dict.
///
/// The result maps each tensor name to a dict with the keys `"dtype"` (from
/// the tensor's `as_safetensors()` dtype), `"shape"` and `"data"`. The value
/// under `"data"` is a new `bytes` object holding a copy of the tensor's raw
/// bytes.
///
/// The keyword-only limits override the matching `ConvertOptions` defaults;
/// passing `None` keeps the default.
///
/// Raises `ValueError` if `input_pt` is empty, and `RuntimeError` if parsing
/// fails.
#[pyfunction(
  signature = (
    input_pt,
    *,
    max_archive_bytes = None,
    max_tensor_count = None,
    max_tensor_bytes = None,
    max_pickle_bytes = None,
    strict_contiguous = None
  )
)]
fn load_pt_raw<'py>(
  py: Python<'py>,
  input_pt: &str,
  max_archive_bytes: Option<u64>,
  max_tensor_count: Option<usize>,
  max_tensor_bytes: Option<usize>,
  max_pickle_bytes: Option<usize>,
  strict_contiguous: Option<bool>,
) -> PyResult<Bound<'py, PyDict>> {
  let options = build_options(
    input_pt,
    max_archive_bytes,
    max_tensor_count,
    max_tensor_bytes,
    max_pickle_bytes,
    strict_contiguous,
  )?;

  let checkpoint = parse_checkpoint(Path::new(input_pt), &options).map_err(into_py_error)?;

  let raw = PyDict::new(py);
  // Tensors are consumed one by one; key order inside each entry is
  // dtype -> shape -> data.
  checkpoint
    .tensors
    .into_iter()
    .try_for_each(|(name, tensor)| -> PyResult<()> {
      let entry = PyDict::new(py);
      entry.set_item("dtype", tensor.dtype.as_safetensors())?;
      entry.set_item("shape", tensor.shape)?;
      entry.set_item("data", PyBytes::new(py, &tensor.bytes))?;
      raw.set_item(name, entry)
    })?;

  Ok(raw)
}

/// Builds a `ConvertOptions` from optional Python keyword overrides.
///
/// Starts from `ConvertOptions::default()`. Each field for which a value was
/// supplied is replaced; `None` leaves the default untouched.
///
/// # Errors
/// Returns a Python `ValueError` when `input_pt` is empty, so callers bail out
/// before any parsing or conversion is attempted.
fn build_options(
  input_pt: &str,
  max_archive_bytes: Option<u64>,
  max_tensor_count: Option<usize>,
  max_tensor_bytes: Option<usize>,
  max_pickle_bytes: Option<usize>,
  strict_contiguous: Option<bool>,
) -> PyResult<ConvertOptions> {
  if input_pt.is_empty() {
    return Err(PyValueError::new_err("input_pt must not be empty"));
  }

  let mut options = ConvertOptions::default();
  // Each override falls back to the field's current (default) value.
  options.max_archive_bytes = max_archive_bytes.unwrap_or(options.max_archive_bytes);
  options.max_tensor_count = max_tensor_count.unwrap_or(options.max_tensor_count);
  options.max_tensor_bytes = max_tensor_bytes.unwrap_or(options.max_tensor_bytes);
  options.max_pickle_bytes = max_pickle_bytes.unwrap_or(options.max_pickle_bytes);
  options.strict_contiguous = strict_contiguous.unwrap_or(options.strict_contiguous);

  Ok(options)
}

/// Native extension module: exposes `inspect_json`, `convert_json` and
/// `load_pt_raw`, plus the crate version as `__version__`.
#[pymodule]
fn _core(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
  let exported = [
    wrap_pyfunction!(inspect_json, module)?,
    wrap_pyfunction!(convert_json, module)?,
    wrap_pyfunction!(load_pt_raw, module)?,
  ];
  for function in exported {
    module.add_function(function)?;
  }

  module.add("__version__", env!("CARGO_PKG_VERSION"))?;
  Ok(())
}