pt-loader 0.1.0

Safe parser-based PyTorch checkpoint converter to safetensors
Documentation
from __future__ import annotations

import json
from typing import Any

from ._core import __version__, convert_json, inspect_json, load_pt_raw

_DTYPE_TO_NUMPY: dict[str, str] = {
  "F16": "float16",
  "BF16": "float32",
  "F32": "float32",
  "F64": "float64",
  "I8": "int8",
  "I16": "int16",
  "I32": "int32",
  "I64": "int64",
  "U8": "uint8",
  "BOOL": "bool",
}


def inspect(input_pt: str) -> dict[str, Any]:
  return json.loads(inspect_json(input_pt))


def convert(
  input_pt: str,
  out_dir: str = "out",
  *,
  max_archive_bytes: int | None = None,
  max_tensor_count: int | None = None,
  max_tensor_bytes: int | None = None,
  max_pickle_bytes: int | None = None,
  strict_contiguous: bool | None = None,
) -> dict[str, Any]:
  return json.loads(
    convert_json(
      input_pt,
      out_dir,
      max_archive_bytes=max_archive_bytes,
      max_tensor_count=max_tensor_count,
      max_tensor_bytes=max_tensor_bytes,
      max_pickle_bytes=max_pickle_bytes,
      strict_contiguous=strict_contiguous,
    )
  )


def load_pt(
  input_pt: str,
  *,
  max_archive_bytes: int | None = None,
  max_tensor_count: int | None = None,
  max_tensor_bytes: int | None = None,
  max_pickle_bytes: int | None = None,
  strict_contiguous: bool | None = None,
) -> dict[str, "Any"]:
  import numpy as np

  raw = load_pt_raw(
    input_pt,
    max_archive_bytes=max_archive_bytes,
    max_tensor_count=max_tensor_count,
    max_tensor_bytes=max_tensor_bytes,
    max_pickle_bytes=max_pickle_bytes,
    strict_contiguous=strict_contiguous,
  )

  tensors: dict[str, Any] = {}
  for name, payload in raw.items():
    dtype_name = _DTYPE_TO_NUMPY[payload["dtype"]]
    shape = tuple(payload["shape"])
    data = payload["data"]
    array = np.frombuffer(data, dtype=dtype_name).reshape(shape)
    tensors[name] = array.copy()

  return tensors


__all__ = ["__version__", "inspect", "convert", "load_pt"]