//! pt-loader 0.1.1
//!
//! Safe, parser-based converter from PyTorch checkpoints to the safetensors format.

use serde::Serialize;
use std::collections::BTreeMap;
use std::convert::TryFrom;
use std::fmt;
use std::path::PathBuf;

/// Resource limits and strictness switches used when converting a checkpoint.
///
/// The `Default` impl sets the values noted on each field.
// NOTE(review): the limits are not enforced in this section; the field meanings below follow
// their names — confirm against the parser/converter that reads them.
#[derive(Debug, Clone)]
pub struct ConvertOptions {
  /// Maximum accepted size of the input archive, in bytes (default: 4 GiB).
  pub max_archive_bytes: u64,
  /// Maximum number of tensors accepted from one checkpoint (default: 4096).
  pub max_tensor_count: usize,
  /// Tensor size limit in bytes (default: 1 GiB).
  // NOTE(review): per-tensor vs. total-across-tensors is not visible here — confirm.
  pub max_tensor_bytes: usize,
  /// Maximum size of the embedded pickle stream, in bytes (default: 64 MiB).
  pub max_pickle_bytes: usize,
  /// Strict contiguity switch (default: `true`).
  // NOTE(review): presumably `true` rejects non-contiguous tensor views rather than
  // materializing them — confirm.
  pub strict_contiguous: bool,
}

impl Default for ConvertOptions {
  fn default() -> Self {
    Self {
      max_archive_bytes: 4 * 1024 * 1024 * 1024,
      max_tensor_count: 4096,
      max_tensor_bytes: 1024 * 1024 * 1024,
      max_pickle_bytes: 64 * 1024 * 1024,
      strict_contiguous: true,
    }
  }
}

/// Per-tensor entry in an [`InspectionReport`]; serializable via `serde`.
#[derive(Debug, Clone, Serialize)]
pub struct TensorSummary {
  /// Tensor name (presumably its state-dict key).
  pub name: String,
  /// Dtype name — presumably the safetensors tag produced by [`DType::as_safetensors`].
  pub dtype: String,
  /// Size of each dimension.
  pub shape: Vec<usize>,
  /// Size of the tensor's data, in bytes.
  pub nbytes: usize,
}

/// Summary describing a checkpoint and its tensors; serializable via `serde`.
#[derive(Debug, Clone, Serialize)]
pub struct InspectionReport {
  /// Name of the detected checkpoint format.
  pub detected_format: String,
  /// The source file this report describes.
  pub source_file: String,
  /// SHA-256 digest of the source file.
  pub source_sha256: String,
  /// Number of tensors — presumably equal to `tensors.len()`.
  pub tensor_count: usize,
  /// Total tensor payload in bytes — presumably the sum of every `TensorSummary::nbytes`.
  pub total_tensor_bytes: usize,
  /// One entry per tensor.
  pub tensors: Vec<TensorSummary>,
  /// Non-fatal issues noticed while inspecting.
  pub warnings: Vec<String>,
}

/// Outcome of a conversion: output locations plus provenance of the input; serializable via `serde`.
#[derive(Debug, Clone, Serialize)]
pub struct ConvertResult {
  /// Path of the safetensors output.
  pub safetensors_path: PathBuf,
  /// Path of the model YAML output (presumably a metadata sidecar).
  pub model_yaml_path: PathBuf,
  /// Path of the input checkpoint.
  pub source_file: PathBuf,
  /// SHA-256 digest of the input checkpoint.
  pub source_sha256: String,
  /// Number of tensors converted.
  pub tensor_count: usize,
  /// Total size of the converted tensor data, in bytes.
  pub total_tensor_bytes: usize,
}

/// Errors produced while inspecting or converting a checkpoint.
///
/// `Io`, `Zip`, and `Json` wrap errors from the underlying libraries and have `From` impls,
/// so they can be propagated with `?`. The remaining variants carry a human-readable detail.
#[derive(Debug)]
pub enum ConvertError {
  /// An I/O failure from the standard library.
  Io(std::io::Error),
  /// An error reported by the `zip` crate.
  Zip(zip::result::ZipError),
  /// An error reported by `serde_json`.
  Json(serde_json::Error),
  /// The input is not in a format this crate handles; the string explains why.
  UnsupportedFormat(String),
  /// The pickle stream contains an opcode that is rejected (unsafe or not implemented).
  /// `offset` is where it was found — presumably a byte offset into the pickle stream.
  UnsafeOpcode { opcode: u8, offset: usize },
  /// The checkpoint's contents do not have the expected structure (for example, the
  /// [`Value`] accessors return this on a type mismatch).
  InvalidStructure(String),
  /// A resource limit was exceeded (presumably one of the [`ConvertOptions`] limits).
  ResourceLimitExceeded(String),
}

impl fmt::Display for ConvertError {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    match self {
      ConvertError::Io(err) => write!(f, "io error: {}", err),
      ConvertError::Zip(err) => write!(f, "zip error: {}", err),
      ConvertError::Json(err) => write!(f, "json error: {}", err),
      ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
      ConvertError::UnsafeOpcode { opcode, offset } => {
        write!(
          f,
          "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}"
        )
      }
      ConvertError::InvalidStructure(msg) => write!(f, "invalid checkpoint structure: {}", msg),
      ConvertError::ResourceLimitExceeded(msg) => write!(f, "resource limit exceeded: {}", msg),
    }
  }
}

impl std::error::Error for ConvertError {}

impl From<std::io::Error> for ConvertError {
  fn from(value: std::io::Error) -> Self {
    Self::Io(value)
  }
}

impl From<zip::result::ZipError> for ConvertError {
  fn from(value: zip::result::ZipError) -> Self {
    Self::Zip(value)
  }
}

impl From<serde_json::Error> for ConvertError {
  fn from(value: serde_json::Error) -> Self {
    Self::Json(value)
  }
}

/// Shorthand for `std::result::Result<T, ConvertError>`.
pub type Result<T> = std::result::Result<T, ConvertError>;

/// Element types supported for tensor data.
///
/// [`DType::elem_size`] gives the byte width of one element, and [`DType::as_safetensors`]
/// gives the dtype tag used in the safetensors output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
  /// 16-bit floating point (half precision); 2 bytes per element.
  F16,
  /// 16-bit bfloat16; 2 bytes per element.
  BF16,
  /// 32-bit floating point; 4 bytes per element.
  F32,
  /// 64-bit floating point; 8 bytes per element.
  F64,
  /// Signed 8-bit integer; 1 byte per element.
  I8,
  /// Signed 16-bit integer; 2 bytes per element.
  I16,
  /// Signed 32-bit integer; 4 bytes per element.
  I32,
  /// Signed 64-bit integer; 8 bytes per element.
  I64,
  /// Unsigned 8-bit integer; 1 byte per element.
  U8,
  /// Boolean stored as one byte per element (safetensors tag `BOOL`).
  Bool,
}

impl DType {
  pub fn elem_size(self) -> usize {
    match self {
      DType::F16 | DType::BF16 | DType::I16 => 2,
      DType::F32 | DType::I32 => 4,
      DType::F64 | DType::I64 => 8,
      DType::I8 | DType::U8 | DType::Bool => 1,
    }
  }

  pub fn as_safetensors(self) -> &'static str {
    match self {
      DType::F16 => "F16",
      DType::BF16 => "BF16",
      DType::F32 => "F32",
      DType::F64 => "F64",
      DType::I8 => "I8",
      DType::I16 => "I16",
      DType::I32 => "I32",
      DType::I64 => "I64",
      DType::U8 => "U8",
      DType::Bool => "BOOL",
    }
  }
}

/// Reference to a raw tensor storage: its identifier, element type, and length.
#[derive(Debug, Clone)]
pub struct StorageRef {
  /// Storage identifier — presumably the key of its data entry inside the checkpoint archive.
  pub key: String,
  /// Element type of the storage.
  pub dtype: DType,
  /// Storage length in elements (not bytes); the byte length is presumably
  /// `size_elems * dtype.elem_size()`.
  pub size_elems: usize,
}

/// A tensor view into a [`StorageRef`], described by an element offset, a shape, and strides.
#[derive(Debug, Clone)]
pub struct TensorRef {
  /// Storage the view reads from.
  pub storage: StorageRef,
  /// Offset of the view's first element, counted in elements of `storage.dtype`.
  pub offset_elems: usize,
  /// Size of each dimension.
  pub shape: Vec<usize>,
  /// Per-dimension strides — presumably in elements, with one entry per `shape` entry.
  pub stride: Vec<usize>,
}

/// A decoded tensor: element type, shape, and raw byte payload.
// NOTE(review): byte order and layout of `bytes` are not visible here — presumably
// little-endian, row-major contiguous data as safetensors expects; confirm.
#[derive(Debug, Clone)]
pub struct TensorData {
  /// Element type of the payload.
  pub dtype: DType,
  /// Size of each dimension.
  pub shape: Vec<usize>,
  /// Raw tensor bytes.
  pub bytes: Vec<u8>,
}

/// A value produced while interpreting a checkpoint's pickle stream.
// NOTE(review): presumably some variants or fields are never constructed or read in every
// build, so `dead_code` is silenced for the whole enum — confirm it is still needed.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Value {
  /// Stack marker (presumably pushed by the pickle MARK opcode).
  Marker,
  /// Python `None`.
  None,
  /// Python `bool`.
  Bool(bool),
  /// Integer, limited to the `i64` range.
  Int(i64),
  /// Floating-point number.
  Float(f64),
  /// Text string.
  String(String),
  /// Byte string.
  Bytes(Vec<u8>),
  /// List of values.
  List(Vec<Value>),
  /// Set of values, stored as a `Vec` (the type does not enforce uniqueness).
  Set(Vec<Value>),
  /// Tuple of values.
  Tuple(Vec<Value>),
  /// Dictionary stored as key/value pairs in a `Vec`, so insertion order is kept and keys
  /// need not be hashable.
  Dict(Vec<(Value, Value)>),
  /// Reference to a global `module.name` (presumably from the pickle GLOBAL opcode).
  Global { module: String, name: String },
  /// A resolved tensor storage reference.
  StorageRef(StorageRef),
  /// A resolved tensor view.
  TensorRef(TensorRef),
  /// Ordered mapping with string keys, stored as a `Vec` of pairs.
  OrderedDict(Vec<(String, Value)>),
  /// An instance of the global class `module.name`, with optional constructor arguments
  /// and optional state (presumably from REDUCE/BUILD-style opcodes).
  Object {
    module: String,
    name: String,
    args: Option<Box<Value>>,
    state: Option<Box<Value>>,
  },
}

impl Value {
  pub(crate) fn as_usize(&self) -> Result<usize> {
    match self {
      Value::Int(v) if *v >= 0 => Ok(*v as usize),
      _ => Err(ConvertError::InvalidStructure(
        "expected non-negative integer".to_string(),
      )),
    }
  }

  pub(crate) fn as_string(&self) -> Result<String> {
    match self {
      Value::String(v) => Ok(v.clone()),
      Value::Int(v) => Ok(v.to_string()),
      _ => Err(ConvertError::InvalidStructure(
        "expected string".to_string(),
      )),
    }
  }

  pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
    match self {
      Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
      _ => Err(ConvertError::InvalidStructure(
        "expected tuple/list of integers".to_string(),
      )),
    }
  }
}

/// A parsed checkpoint: decoded tensors plus provenance, warnings, and metadata.
pub struct ParsedCheckpoint {
  /// SHA-256 digest of the source file (string encoding not visible here — presumably hex).
  pub source_sha256: String,
  /// Non-fatal issues noticed while parsing.
  pub warnings: Vec<String>,
  /// Decoded tensors keyed by name; the `BTreeMap` keeps them sorted by name.
  pub tensors: BTreeMap<String, TensorData>,
  /// Additional metadata as a YAML value (presumably what ends up in the model YAML output).
  pub metadata: serde_yaml::Value,
  /// NOTE(review): contents not visible here — presumably descriptions of non-tensor objects
  /// encountered in the pickle stream; confirm.
  pub objects: Vec<String>,
}