pt-loader 0.1.4

Safe parser-based PyTorch checkpoint converter to safetensors
Documentation
use crate::{ExportFormat, ExportOptions, LoadOptions, PtCheckpoint, TensorData};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyType};
use std::collections::BTreeMap;
use std::path::PathBuf;

/// Converts any displayable error into a Python `RuntimeError`.
///
/// The exception message is the error's `Display` output.
fn into_py_error(err: impl std::fmt::Display) -> PyErr {
  let message = err.to_string();
  PyRuntimeError::new_err(message)
}

fn to_json<T: serde::Serialize>(value: &T) -> PyResult<String> {
  serde_json::to_string(value).map_err(into_py_error)
}

/// Builds the [`LoadOptions`] used to load a checkpoint.
///
/// Starts from `LoadOptions::default()` and overrides only the fields whose
/// corresponding argument is `Some`; every `None` argument leaves the default
/// value untouched.
///
/// The state-dict root keys come from `state_dict_root_keys` when it is given.
/// Otherwise a given `state_dict_root_key` is used as a one-element list. When
/// both are supplied, `state_dict_root_key` is ignored.
///
/// `input_pt` is only validated here; it is not stored in the returned options.
///
/// # Errors
///
/// Returns a Python `ValueError` when `input_pt` is the empty string.
fn build_load_options(
  input_pt: &str,
  max_archive_bytes: Option<u64>,
  max_tensor_count: Option<usize>,
  max_tensor_bytes: Option<usize>,
  max_pickle_bytes: Option<usize>,
  strict_contiguous: Option<bool>,
  state_dict_root_key: Option<String>,
  state_dict_root_keys: Option<Vec<String>>,
  state_dict_root_strict: Option<bool>,
) -> PyResult<LoadOptions> {
  if input_pt.is_empty() {
    return Err(PyValueError::new_err("input_pt must not be empty"));
  }

  let mut opts = LoadOptions::default();

  // Each override falls back to the value already present in the defaults.
  opts.max_archive_bytes = max_archive_bytes.unwrap_or(opts.max_archive_bytes);
  opts.max_tensor_count = max_tensor_count.unwrap_or(opts.max_tensor_count);
  opts.max_tensor_bytes = max_tensor_bytes.unwrap_or(opts.max_tensor_bytes);
  opts.max_pickle_bytes = max_pickle_bytes.unwrap_or(opts.max_pickle_bytes);
  opts.strict_contiguous = strict_contiguous.unwrap_or(opts.strict_contiguous);
  opts.state_dict_root_strict = state_dict_root_strict.unwrap_or(opts.state_dict_root_strict);

  // The plural argument wins; the singular key is only a fallback.
  let root_keys = state_dict_root_keys.or_else(|| state_dict_root_key.map(|key| vec![key]));
  if let Some(root_keys) = root_keys {
    opts.state_dict_root_keys = root_keys;
  }

  Ok(opts)
}

/// Converts a map of tensors into a Python dict of per-tensor dicts.
///
/// The result maps each tensor name to a dict with three entries:
/// `"dtype"` (the value of `dtype.as_safetensors()`), `"shape"` (a copy of
/// the tensor's shape) and `"data"` (a Python `bytes` object holding a copy
/// of the tensor's bytes).
///
/// # Errors
///
/// Propagates the first error raised while inserting into a Python dict.
fn tensors_to_py_dict<'py>(py: Python<'py>, tensors: &BTreeMap<String, TensorData>) -> PyResult<Bound<'py, PyDict>> {
  let output = PyDict::new(py);
  tensors.iter().try_for_each(|(name, tensor)| -> PyResult<()> {
    let entry = PyDict::new(py);
    entry.set_item("dtype", tensor.dtype.as_safetensors())?;
    entry.set_item("shape", tensor.shape.clone())?;
    entry.set_item("data", PyBytes::new(py, &tensor.bytes))?;
    output.set_item(name, entry)
  })?;
  Ok(output)
}

/// Wrapper around a loaded `PtCheckpoint`, exposed to Python under the class
/// name `PtCheckpointCore`.
#[pyclass(name = "PtCheckpointCore")]
pub struct PyPtCheckpoint {
  /// The wrapped Rust checkpoint.
  inner: PtCheckpoint,
}

#[pymethods]
impl PyPtCheckpoint {
  /// Loads the checkpoint at `input_pt` and wraps it.
  ///
  /// All keyword arguments are optional overrides for the default load
  /// options; an argument left as `None` keeps the default. When both
  /// `state_dict_root_keys` and `state_dict_root_key` are given, the plural
  /// form is used.
  ///
  /// # Errors
  ///
  /// Raises `ValueError` when `input_pt` is empty and `RuntimeError` when the
  /// underlying `PtCheckpoint::load` call fails.
  #[classmethod]
  #[pyo3(
    signature = (
      input_pt,
      *,
      max_archive_bytes = None,
      max_tensor_count = None,
      max_tensor_bytes = None,
      max_pickle_bytes = None,
      strict_contiguous = None,
      state_dict_root_key = None,
      state_dict_root_keys = None,
      state_dict_root_strict = None
    )
  )]
  fn load(
    _cls: &Bound<'_, PyType>,
    input_pt: &str,
    max_archive_bytes: Option<u64>,
    max_tensor_count: Option<usize>,
    max_tensor_bytes: Option<usize>,
    max_pickle_bytes: Option<usize>,
    strict_contiguous: Option<bool>,
    state_dict_root_key: Option<String>,
    state_dict_root_keys: Option<Vec<String>>,
    state_dict_root_strict: Option<bool>,
  ) -> PyResult<Self> {
    let load_options = build_load_options(
      input_pt,
      max_archive_bytes,
      max_tensor_count,
      max_tensor_bytes,
      max_pickle_bytes,
      strict_contiguous,
      state_dict_root_key,
      state_dict_root_keys,
      state_dict_root_strict,
    )?;

    PtCheckpoint::load(input_pt, load_options)
      .map(|inner| Self { inner })
      .map_err(into_py_error)
  }

  /// Exports the checkpoint into `out_dir` and returns the export result
  /// serialized as a JSON string.
  ///
  /// `format` must be `"safetensors"`. The remaining keyword arguments are
  /// copied onto the export options before the export runs.
  ///
  /// # Errors
  ///
  /// Raises `ValueError` for any other `format`, and `RuntimeError` when the
  /// export itself or the JSON serialization of its result fails.
  #[pyo3(
    signature = (
      out_dir,
      *,
      format = "safetensors",
      weights_filename = "model.safetensors",
      metadata_filename = "model.yaml",
      include_metadata = true,
      overwrite = false
    )
  )]
  fn export_json(
    &self,
    out_dir: &str,
    format: &str,
    weights_filename: &str,
    metadata_filename: &str,
    include_metadata: bool,
    overwrite: bool,
  ) -> PyResult<String> {
    // Resolve the format first so an unsupported name fails before any export work.
    let export_format = match format {
      "safetensors" => Ok(ExportFormat::Safetensors),
      other => Err(PyValueError::new_err(format!(
        "unsupported format {other}; expected 'safetensors'"
      ))),
    }?;

    let mut export_options = ExportOptions::new(export_format, None);
    export_options.include_metadata = include_metadata;
    export_options.overwrite = overwrite;
    export_options.weights_filename = PathBuf::from(weights_filename);
    export_options.metadata_filename = PathBuf::from(metadata_filename);

    self
      .inner
      .export(out_dir, export_options)
      .map_err(into_py_error)
      .and_then(|result| to_json(&result))
  }

  /// Returns the checkpoint metadata serialized as a JSON string.
  ///
  /// # Errors
  ///
  /// Raises `RuntimeError` when serialization fails.
  fn metadata_json(&self) -> PyResult<String> {
    let metadata = self.inner.metadata();
    to_json(metadata)
  }

  /// Returns the checkpoint's raw tensors as a dict keyed by tensor name.
  ///
  /// Each value is a dict with `"dtype"`, `"shape"` and `"data"` entries.
  fn state_dict_raw<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
    let tensors = self.inner.raw_tensors();
    tensors_to_py_dict(py, tensors)
  }

  /// Returns the value of the checkpoint's `source_sha256()` as a string.
  fn source_sha256(&self) -> String {
    let digest = self.inner.source_sha256();
    digest.to_string()
  }

  /// Returns the tensor count reported by the checkpoint.
  fn tensor_count(&self) -> usize {
    self.inner.tensor_count()
  }

  /// Builds the Python `repr()` text, showing the tensor count and source hash.
  fn __repr__(&self) -> String {
    let count = self.inner.tensor_count();
    let digest = self.inner.source_sha256();
    format!(
      "PtCheckpointCore(tensor_count={}, source_sha256='{}')",
      count, digest
    )
  }
}

/// Initializes the `_core` extension module.
///
/// Registers the `PtCheckpointCore` class and sets `__version__` to the
/// crate version captured at compile time.
#[pymodule]
fn _core(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
  const VERSION: &str = env!("CARGO_PKG_VERSION");

  module.add_class::<PyPtCheckpoint>()?;
  module.add("__version__", VERSION)
}