pt-loader 0.1.2

Safe parser-based PyTorch checkpoint converter to safetensors
Documentation
use ndarray::ArrayD;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt;
use std::path::{Path, PathBuf};

#[derive(Debug, Clone)]
pub struct LoadOptions {
  pub max_archive_bytes: u64,
  pub max_tensor_count: usize,
  pub max_tensor_bytes: usize,
  pub max_pickle_bytes: usize,
  pub strict_contiguous: bool,
}

impl Default for LoadOptions {
  fn default() -> Self {
    Self {
      max_archive_bytes: 4 * 1024 * 1024 * 1024,
      max_tensor_count: 4096,
      max_tensor_bytes: 1024 * 1024 * 1024,
      max_pickle_bytes: 64 * 1024 * 1024,
      strict_contiguous: true,
    }
  }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
  Safetensors,
}

impl ExportFormat {
  pub fn extension(self) -> &'static str {
    match self {
      ExportFormat::Safetensors => "safetensors",
    }
  }
}

#[derive(Debug, Clone)]
pub struct ExportOptions {
  pub format: ExportFormat,
  pub weights_filename: PathBuf,
  pub metadata_filename: PathBuf,
  pub include_metadata: bool,
  pub overwrite: bool,
}

impl ExportOptions {
  pub fn new(format: ExportFormat, input_path: Option<&Path>) -> Self {
    let weights_filename = default_weights_filename(format, input_path);
    let metadata_filename = weights_filename.with_extension("yaml");

    Self {
      format,
      weights_filename,
      metadata_filename,
      include_metadata: true,
      overwrite: false,
    }
  }
}

fn default_weights_filename(format: ExportFormat, input_path: Option<&Path>) -> PathBuf {
  let ext = format.extension();
  let Some(path) = input_path else {
    return PathBuf::from(format!("model.{ext}"));
  };

  let Some(name) = path.file_name() else {
    return PathBuf::from(format!("model.{ext}"));
  };

  Path::new(name).with_extension(ext)
}

#[derive(Debug, Clone, Serialize)]
pub struct ExportResult {
  pub weights_path: PathBuf,
  pub metadata_path: Option<PathBuf>,
  pub source_sha256: String,
  pub tensor_count: usize,
  pub total_tensor_bytes: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointTensorMetadata {
  pub name: String,
  pub dtype: String,
  pub shape: Vec<usize>,
  pub sha256: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
  pub format_version: usize,
  pub source_file: String,
  pub source_sha256: String,
  pub safetensors_file: String,
  pub created_at_unix: u64,
  pub tensor_count: usize,
  pub total_tensor_bytes: usize,
  #[serde(default)]
  pub metadata: serde_yaml::Value,
  #[serde(default)]
  pub objects: Vec<String>,
  pub tensors: Vec<CheckpointTensorMetadata>,
}

#[derive(Debug, Clone)]
pub enum ReconstructSource {
  WeightsFile(PathBuf),
  StateDict(BTreeMap<String, TensorData>),
}

#[derive(Debug, Clone)]
pub enum TensorArray {
  F32(ArrayD<f32>),
  F64(ArrayD<f64>),
  I8(ArrayD<i8>),
  I16(ArrayD<i16>),
  I32(ArrayD<i32>),
  I64(ArrayD<i64>),
  U8(ArrayD<u8>),
  Bool(ArrayD<bool>),
}

#[derive(Debug)]
pub enum ConvertError {
  Io(std::io::Error),
  Zip(zip::result::ZipError),
  Json(serde_json::Error),
  Yaml(serde_yaml::Error),
  Ndarray(ndarray::ShapeError),
  UnsupportedFormat(String),
  UnsafeOpcode { opcode: u8, offset: usize },
  InvalidStructure(String),
  ResourceLimitExceeded(String),
}

impl fmt::Display for ConvertError {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    match self {
      ConvertError::Io(err) => write!(f, "io error: {}", err),
      ConvertError::Zip(err) => write!(f, "zip error: {}", err),
      ConvertError::Json(err) => write!(f, "json error: {}", err),
      ConvertError::Yaml(err) => write!(f, "yaml error: {}", err),
      ConvertError::Ndarray(err) => write!(f, "ndarray error: {}", err),
      ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
      ConvertError::UnsafeOpcode { opcode, offset } => {
        write!(f, "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}")
      }
      ConvertError::InvalidStructure(msg) => {
        write!(f, "invalid checkpoint structure: {}", msg)
      }
      ConvertError::ResourceLimitExceeded(msg) => {
        write!(f, "resource limit exceeded: {}", msg)
      }
    }
  }
}

impl std::error::Error for ConvertError {}

impl From<std::io::Error> for ConvertError {
  fn from(value: std::io::Error) -> Self {
    Self::Io(value)
  }
}

impl From<zip::result::ZipError> for ConvertError {
  fn from(value: zip::result::ZipError) -> Self {
    Self::Zip(value)
  }
}

impl From<serde_json::Error> for ConvertError {
  fn from(value: serde_json::Error) -> Self {
    Self::Json(value)
  }
}

impl From<serde_yaml::Error> for ConvertError {
  fn from(value: serde_yaml::Error) -> Self {
    Self::Yaml(value)
  }
}

impl From<ndarray::ShapeError> for ConvertError {
  fn from(value: ndarray::ShapeError) -> Self {
    Self::Ndarray(value)
  }
}

pub type Result<T> = std::result::Result<T, ConvertError>;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
  F16,
  BF16,
  F32,
  F64,
  I8,
  I16,
  I32,
  I64,
  U8,
  Bool,
}

impl DType {
  pub fn elem_size(self) -> usize {
    match self {
      DType::F16 | DType::BF16 | DType::I16 => 2,
      DType::F32 | DType::I32 => 4,
      DType::F64 | DType::I64 => 8,
      DType::I8 | DType::U8 | DType::Bool => 1,
    }
  }

  pub fn as_safetensors(self) -> &'static str {
    match self {
      DType::F16 => "F16",
      DType::BF16 => "BF16",
      DType::F32 => "F32",
      DType::F64 => "F64",
      DType::I8 => "I8",
      DType::I16 => "I16",
      DType::I32 => "I32",
      DType::I64 => "I64",
      DType::U8 => "U8",
      DType::Bool => "BOOL",
    }
  }

  pub fn from_safetensors(value: &str) -> Option<Self> {
    match value {
      "F16" => Some(DType::F16),
      "BF16" => Some(DType::BF16),
      "F32" => Some(DType::F32),
      "F64" => Some(DType::F64),
      "I8" => Some(DType::I8),
      "I16" => Some(DType::I16),
      "I32" => Some(DType::I32),
      "I64" => Some(DType::I64),
      "U8" => Some(DType::U8),
      "BOOL" => Some(DType::Bool),
      _ => None,
    }
  }
}

#[derive(Debug, Clone)]
pub struct StorageRef {
  pub key: String,
  pub dtype: DType,
  pub size_elems: usize,
}

#[derive(Debug, Clone)]
pub struct TensorRef {
  pub storage: StorageRef,
  pub offset_elems: usize,
  pub shape: Vec<usize>,
  pub stride: Vec<usize>,
}

#[derive(Debug, Clone)]
pub struct TensorData {
  pub dtype: DType,
  pub shape: Vec<usize>,
  pub bytes: Vec<u8>,
}

#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Value {
  Marker,
  None,
  Bool(bool),
  Int(i64),
  Float(f64),
  String(String),
  Bytes(Vec<u8>),
  List(Vec<Value>),
  Set(Vec<Value>),
  Tuple(Vec<Value>),
  Dict(Vec<(Value, Value)>),
  Global {
    module: String,
    name: String,
  },
  StorageRef(StorageRef),
  TensorRef(TensorRef),
  OrderedDict(Vec<(String, Value)>),
  Object {
    module: String,
    name: String,
    args: Option<Box<Value>>,
    state: Option<Box<Value>>,
  },
}

impl Value {
  pub(crate) fn as_usize(&self) -> Result<usize> {
    match self {
      Value::Int(v) if *v >= 0 => Ok(*v as usize),
      _ => Err(ConvertError::InvalidStructure(
        "expected non-negative integer".to_string(),
      )),
    }
  }

  pub(crate) fn as_string(&self) -> Result<String> {
    match self {
      Value::String(v) => Ok(v.clone()),
      Value::Int(v) => Ok(v.to_string()),
      _ => Err(ConvertError::InvalidStructure("expected string".to_string())),
    }
  }

  pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
    match self {
      Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
      _ => Err(ConvertError::InvalidStructure(
        "expected tuple/list of integers".to_string(),
      )),
    }
  }
}

pub struct ParsedCheckpoint {
  pub source_sha256: String,
  pub warnings: Vec<String>,
  pub tensors: BTreeMap<String, TensorData>,
  pub metadata: serde_yaml::Value,
  pub objects: Vec<String>,
}