use serde::Serialize;
use std::collections::BTreeMap;
use std::fmt;
use std::path::PathBuf;
/// Tunable limits and flags for a conversion run.
///
/// This block only declares the fields and their defaults; how each limit is
/// enforced is decided by the code that consumes these options.
#[derive(Debug, Clone)]
pub struct ConvertOptions {
    /// Archive size limit in bytes (default: 4 GiB).
    pub max_archive_bytes: u64,
    /// Tensor count limit (default: 4096).
    pub max_tensor_count: usize,
    /// Tensor size limit in bytes (default: 1 GiB).
    pub max_tensor_bytes: usize,
    /// Pickle size limit in bytes (default: 64 MiB).
    pub max_pickle_bytes: usize,
    /// Strict-contiguity flag (default: `true`).
    pub strict_contiguous: bool,
}

impl Default for ConvertOptions {
    /// Returns the stock option set: 4 GiB archive, 4096 tensors, 1 GiB per
    /// tensor, 64 MiB pickle, strict contiguity enabled.
    fn default() -> Self {
        // Binary size units, kept local so the values below read as sizes.
        const MIB: usize = 1 << 20;
        const GIB: usize = 1 << 30;

        ConvertOptions {
            strict_contiguous: true,
            max_pickle_bytes: 64 * MIB,
            max_tensor_bytes: GIB,
            max_tensor_count: 4096,
            // Widen first so the 4 GiB product is computed in `u64`.
            max_archive_bytes: 4 * (GIB as u64),
        }
    }
}
/// Serializable summary of a single tensor: name, dtype label, shape and byte count.
#[derive(Debug, Clone, Serialize)]
pub struct TensorSummary {
/// Tensor name.
pub name: String,
/// Element type, stored as a string; the label set is chosen by whoever builds this value.
pub dtype: String,
/// Dimension sizes.
pub shape: Vec<usize>,
/// Byte count for the tensor.
/// NOTE(review): presumably the size of the tensor's data payload — confirm with the producer.
pub nbytes: usize,
}
/// Serializable report describing a source file and the tensors found in it.
#[derive(Debug, Clone, Serialize)]
pub struct InspectionReport {
/// Name of the detected format, as a free-form string.
pub detected_format: String,
/// The source file, stored as a string.
pub source_file: String,
/// SHA-256 of the source.
/// NOTE(review): presumably a hex digest of the file bytes — confirm with the producer.
pub source_sha256: String,
/// Number of tensors.
/// NOTE(review): presumably equal to `tensors.len()`; nothing in this type enforces it.
pub tensor_count: usize,
/// Total tensor size in bytes.
/// NOTE(review): presumably the sum of `nbytes` over `tensors`; nothing in this type enforces it.
pub total_tensor_bytes: usize,
/// Per-tensor summaries.
pub tensors: Vec<TensorSummary>,
/// Warning messages attached to this report.
pub warnings: Vec<String>,
}
/// Serializable record of a conversion: the output paths, the source file and tensor totals.
#[derive(Debug, Clone, Serialize)]
pub struct ConvertResult {
/// Path of the safetensors output.
pub safetensors_path: PathBuf,
/// Path of the model YAML output.
pub model_yaml_path: PathBuf,
/// Path of the source file.
pub source_file: PathBuf,
/// SHA-256 of the source.
/// NOTE(review): presumably a hex digest of the file bytes — confirm with the producer.
pub source_sha256: String,
/// Number of tensors.
pub tensor_count: usize,
/// Total tensor size in bytes.
pub total_tensor_bytes: usize,
}
/// Error type of this module; it is the error half of the module's `Result` alias.
///
/// The `Display` impl prefixes each variant's payload with a short category label.
#[derive(Debug)]
pub enum ConvertError {
/// Wraps a `std::io::Error` (displayed as "io error: …").
Io(std::io::Error),
/// Wraps a `zip::result::ZipError` (displayed as "zip error: …").
Zip(zip::result::ZipError),
/// Wraps a `serde_json::Error` (displayed as "json error: …").
Json(serde_json::Error),
/// Free-form message (displayed as "unsupported format: …").
UnsupportedFormat(String),
/// A pickle opcode reported as unsafe/unsupported, together with the offset it was found at.
UnsafeOpcode { opcode: u8, offset: usize },
/// Free-form message (displayed as "invalid checkpoint structure: …").
InvalidStructure(String),
/// Free-form message (displayed as "resource limit exceeded: …").
ResourceLimitExceeded(String),
}
impl fmt::Display for ConvertError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ConvertError::Io(err) => write!(f, "io error: {}", err),
ConvertError::Zip(err) => write!(f, "zip error: {}", err),
ConvertError::Json(err) => write!(f, "json error: {}", err),
ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
ConvertError::UnsafeOpcode { opcode, offset } => {
write!(
f,
"unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}"
)
}
ConvertError::InvalidStructure(msg) => write!(f, "invalid checkpoint structure: {}", msg),
ConvertError::ResourceLimitExceeded(msg) => write!(f, "resource limit exceeded: {}", msg),
}
}
}
// No methods are overridden, so `source()` keeps its default (`None`): the
// wrapped `Io`/`Zip`/`Json` errors are surfaced only through the `Display` text.
impl std::error::Error for ConvertError {}
impl From<std::io::Error> for ConvertError {
fn from(value: std::io::Error) -> Self {
Self::Io(value)
}
}
impl From<zip::result::ZipError> for ConvertError {
fn from(value: zip::result::ZipError) -> Self {
Self::Zip(value)
}
}
impl From<serde_json::Error> for ConvertError {
fn from(value: serde_json::Error) -> Self {
Self::Json(value)
}
}
/// Shorthand for `std::result::Result` with [`ConvertError`] as the error type.
pub type Result<T> = std::result::Result<T, ConvertError>;
/// Tensor element types known to this module.
///
/// Each variant has a fixed per-element byte width ([`DType::elem_size`]) and
/// an upper-case string label ([`DType::as_safetensors`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    /// 2 bytes per element, label `"F16"`.
    F16,
    /// 2 bytes per element, label `"BF16"`.
    BF16,
    /// 4 bytes per element, label `"F32"`.
    F32,
    /// 8 bytes per element, label `"F64"`.
    F64,
    /// 1 byte per element, label `"I8"`.
    I8,
    /// 2 bytes per element, label `"I16"`.
    I16,
    /// 4 bytes per element, label `"I32"`.
    I32,
    /// 8 bytes per element, label `"I64"`.
    I64,
    /// 1 byte per element, label `"U8"`.
    U8,
    /// 1 byte per element, label `"BOOL"`.
    Bool,
}

impl DType {
    /// Width of a single element of this type, in bytes.
    pub fn elem_size(self) -> usize {
        // Arms are grouped by width, smallest first.
        match self {
            Self::I8 | Self::U8 | Self::Bool => 1,
            Self::I16 | Self::F16 | Self::BF16 => 2,
            Self::I32 | Self::F32 => 4,
            Self::I64 | Self::F64 => 8,
        }
    }

    /// Upper-case string label for this variant; per the method name these
    /// are the dtype strings used for safetensors.
    pub fn as_safetensors(self) -> &'static str {
        match self {
            Self::Bool => "BOOL",
            Self::U8 => "U8",
            Self::I8 => "I8",
            Self::I16 => "I16",
            Self::I32 => "I32",
            Self::I64 => "I64",
            Self::F16 => "F16",
            Self::BF16 => "BF16",
            Self::F32 => "F32",
            Self::F64 => "F64",
        }
    }
}
/// Reference to a storage: a key, an element type and a length counted in elements.
#[derive(Debug, Clone)]
pub struct StorageRef {
/// Key identifying the storage.
/// NOTE(review): presumably the name of the archive entry holding the raw data — confirm with the parser.
pub key: String,
/// Element type of the storage.
pub dtype: DType,
/// Storage size counted in elements (per the field name), not bytes.
pub size_elems: usize,
}
/// Describes a tensor as a region of a [`StorageRef`]: start offset, shape and stride.
#[derive(Debug, Clone)]
pub struct TensorRef {
/// The storage this tensor refers to.
pub storage: StorageRef,
/// Start offset within the storage, counted in elements (per the field name).
pub offset_elems: usize,
/// Dimension sizes.
pub shape: Vec<usize>,
/// Stride values.
/// NOTE(review): presumably one stride per dimension, in elements; the length is not tied to `shape` by this type.
pub stride: Vec<usize>,
}
/// An owned tensor: element type, shape and a raw byte buffer.
#[derive(Debug, Clone)]
pub struct TensorData {
/// Element type of the tensor.
pub dtype: DType,
/// Dimension sizes.
pub shape: Vec<usize>,
/// Raw tensor bytes.
/// NOTE(review): presumably `product(shape) * dtype.elem_size()` bytes long; this type does not enforce it.
pub bytes: Vec<u8>,
}
/// Dynamically-typed value tree.
/// NOTE(review): presumably the result of decoding a checkpoint's pickle stream — confirm with the parser.
// NOTE(review): `dead_code` is allowed for the whole enum, presumably because some
// variants or fields are constructed but never read; confirm and narrow if possible.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Value {
/// Sentinel without payload.
/// NOTE(review): presumably the pickle stack MARK — confirm with the parser.
Marker,
/// The "none" value (no payload).
None,
/// Boolean.
Bool(bool),
/// Signed 64-bit integer.
Int(i64),
/// 64-bit float.
Float(f64),
/// Text string.
String(String),
/// Byte string.
Bytes(Vec<u8>),
/// Ordered sequence of values.
List(Vec<Value>),
/// Set, stored as a plain vector; this type does not enforce uniqueness.
Set(Vec<Value>),
/// Tuple of values.
Tuple(Vec<Value>),
/// Dictionary stored as key/value pairs in a vector.
Dict(Vec<(Value, Value)>),
/// A `module` / `name` pair naming a global.
Global { module: String, name: String },
/// A storage reference.
StorageRef(StorageRef),
/// A tensor reference.
TensorRef(TensorRef),
/// Dictionary with string keys, stored as pairs in a vector.
OrderedDict(Vec<(String, Value)>),
/// An object identified by `module` / `name`, with optional `args` and `state` payloads.
Object {
module: String,
name: String,
args: Option<Box<Value>>,
state: Option<Box<Value>>,
},
}
impl Value {
    /// Interprets this value as a `usize`.
    ///
    /// # Errors
    ///
    /// Returns [`ConvertError::InvalidStructure`] when the value is not a
    /// `Value::Int`, is negative, or does not fit in `usize` on the current
    /// target.
    pub(crate) fn as_usize(&self) -> Result<usize> {
        match self {
            // Checked conversion: a plain `as usize` cast would silently
            // truncate large values on targets where `usize` is narrower than
            // 64 bits. The fully-qualified path avoids depending on `TryFrom`
            // being in the prelude.
            Value::Int(v) if *v >= 0 => {
                <usize as std::convert::TryFrom<i64>>::try_from(*v).map_err(|_| {
                    ConvertError::InvalidStructure(format!(
                        "integer {} does not fit in usize",
                        v
                    ))
                })
            }
            _ => Err(ConvertError::InvalidStructure(
                "expected non-negative integer".to_string(),
            )),
        }
    }

    /// Interprets this value as a string.
    ///
    /// `Value::String` is cloned; `Value::Int` is also accepted and rendered
    /// in decimal.
    ///
    /// # Errors
    ///
    /// Returns [`ConvertError::InvalidStructure`] for every other variant.
    pub(crate) fn as_string(&self) -> Result<String> {
        match self {
            Value::String(v) => Ok(v.clone()),
            Value::Int(v) => Ok(v.to_string()),
            _ => Err(ConvertError::InvalidStructure(
                "expected string".to_string(),
            )),
        }
    }

    /// Interprets a `Value::Tuple` or `Value::List` as a vector of `usize`.
    ///
    /// # Errors
    ///
    /// Returns [`ConvertError::InvalidStructure`] when this value is neither a
    /// tuple nor a list, or when any element is rejected by
    /// [`Value::as_usize`]; conversion stops at the first failing element.
    pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
        match self {
            Value::Tuple(items) | Value::List(items) => {
                items.iter().map(Value::as_usize).collect()
            }
            _ => Err(ConvertError::InvalidStructure(
                "expected tuple/list of integers".to_string(),
            )),
        }
    }
}
/// Aggregate of everything held for a parsed checkpoint: source digest, warnings,
/// tensors, metadata and an object list.
pub struct ParsedCheckpoint {
/// SHA-256 of the source.
/// NOTE(review): presumably a hex digest of the file bytes — confirm with the producer.
pub source_sha256: String,
/// Warning messages.
pub warnings: Vec<String>,
/// Tensors keyed by a string; a `BTreeMap` iterates in sorted key order.
pub tensors: BTreeMap<String, TensorData>,
/// Metadata held as a YAML value.
pub metadata: serde_yaml::Value,
/// List of object strings.
/// NOTE(review): presumably names/descriptions of non-tensor objects found in the checkpoint — confirm with the parser.
pub objects: Vec<String>,
}