Skip to main content

pt_loader/
types.rs

1use serde::Serialize;
2use std::collections::BTreeMap;
3use std::fmt;
4use std::path::PathBuf;
5
/// Tunable limits and switches for a conversion run.
///
/// This file only declares the knobs and their defaults; where and how each
/// one is enforced is decided by code outside this file.
#[derive(Debug, Clone)]
pub struct ConvertOptions {
  /// Archive size limit, in bytes. Defaults to 4 GiB.
  pub max_archive_bytes: u64,
  /// Limit on the number of tensors. Defaults to 4096.
  pub max_tensor_count: usize,
  /// Tensor size limit, in bytes. Defaults to 1 GiB.
  // NOTE(review): per-tensor or cumulative? Confirm at the enforcement site.
  pub max_tensor_bytes: usize,
  /// Pickle size limit, in bytes. Defaults to 64 MiB.
  pub max_pickle_bytes: usize,
  /// Contiguity strictness switch. Defaults to `true`.
  // NOTE(review): presumably rejects non-contiguous tensors when set — confirm.
  pub strict_contiguous: bool,
}

impl Default for ConvertOptions {
  /// Returns the stock configuration: 4 GiB archive, 4096 tensors,
  /// 1 GiB tensor bytes, 64 MiB pickle, strict contiguity enabled.
  fn default() -> Self {
    ConvertOptions {
      strict_contiguous: true,
      max_pickle_bytes: 64_usize << 20,  // 64 MiB
      max_tensor_bytes: 1_usize << 30,   // 1 GiB
      max_tensor_count: 4096,
      max_archive_bytes: 4_u64 << 30,    // 4 GiB (needs the u64 field width)
    }
  }
}
26
27#[derive(Debug, Clone, Serialize)]
28pub struct TensorSummary {
29  pub name: String,
30  pub dtype: String,
31  pub shape: Vec<usize>,
32  pub nbytes: usize,
33}
34
35#[derive(Debug, Clone, Serialize)]
36pub struct InspectionReport {
37  pub detected_format: String,
38  pub source_file: String,
39  pub source_sha256: String,
40  pub tensor_count: usize,
41  pub total_tensor_bytes: usize,
42  pub tensors: Vec<TensorSummary>,
43  pub warnings: Vec<String>,
44}
45
46#[derive(Debug, Clone, Serialize)]
47pub struct ConvertResult {
48  pub safetensors_path: PathBuf,
49  pub model_yaml_path: PathBuf,
50  pub source_file: PathBuf,
51  pub source_sha256: String,
52  pub tensor_count: usize,
53  pub total_tensor_bytes: usize,
54}
55
56#[derive(Debug)]
57pub enum ConvertError {
58  Io(std::io::Error),
59  Zip(zip::result::ZipError),
60  Json(serde_json::Error),
61  UnsupportedFormat(String),
62  UnsafeOpcode { opcode: u8, offset: usize },
63  InvalidStructure(String),
64  ResourceLimitExceeded(String),
65}
66
67impl fmt::Display for ConvertError {
68  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69    match self {
70      ConvertError::Io(err) => write!(f, "io error: {}", err),
71      ConvertError::Zip(err) => write!(f, "zip error: {}", err),
72      ConvertError::Json(err) => write!(f, "json error: {}", err),
73      ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
74      ConvertError::UnsafeOpcode { opcode, offset } => {
75        write!(
76          f,
77          "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}"
78        )
79      }
80      ConvertError::InvalidStructure(msg) => write!(f, "invalid checkpoint structure: {}", msg),
81      ConvertError::ResourceLimitExceeded(msg) => write!(f, "resource limit exceeded: {}", msg),
82    }
83  }
84}
85
86impl std::error::Error for ConvertError {}
87
88impl From<std::io::Error> for ConvertError {
89  fn from(value: std::io::Error) -> Self {
90    Self::Io(value)
91  }
92}
93
94impl From<zip::result::ZipError> for ConvertError {
95  fn from(value: zip::result::ZipError) -> Self {
96    Self::Zip(value)
97  }
98}
99
100impl From<serde_json::Error> for ConvertError {
101  fn from(value: serde_json::Error) -> Self {
102    Self::Json(value)
103  }
104}
105
106pub type Result<T> = std::result::Result<T, ConvertError>;
107
/// Element types known to this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
  F16,
  BF16,
  F32,
  F64,
  I8,
  I16,
  I32,
  I64,
  U8,
  Bool,
}

impl DType {
  /// Width of one element in bytes: 1, 2, 4 or 8 depending on the variant.
  pub fn elem_size(self) -> usize {
    // Arms are ordered by ascending width.
    match self {
      Self::Bool | Self::U8 | Self::I8 => 1,
      Self::I16 | Self::F16 | Self::BF16 => 2,
      Self::I32 | Self::F32 => 4,
      Self::I64 | Self::F64 => 8,
    }
  }

  /// Upper-case label for this element type (for example `"F32"`).
  ///
  /// Every label equals the variant name, except `Bool`, which maps to `"BOOL"`.
  pub fn as_safetensors(self) -> &'static str {
    match self {
      // Integer and boolean types.
      Self::Bool => "BOOL",
      Self::U8 => "U8",
      Self::I8 => "I8",
      Self::I16 => "I16",
      Self::I32 => "I32",
      Self::I64 => "I64",
      // Floating-point types.
      Self::F16 => "F16",
      Self::BF16 => "BF16",
      Self::F32 => "F32",
      Self::F64 => "F64",
    }
  }
}
147
148#[derive(Debug, Clone)]
149pub struct StorageRef {
150  pub key: String,
151  pub dtype: DType,
152  pub size_elems: usize,
153}
154
155#[derive(Debug, Clone)]
156pub struct TensorRef {
157  pub storage: StorageRef,
158  pub offset_elems: usize,
159  pub shape: Vec<usize>,
160  pub stride: Vec<usize>,
161}
162
163#[derive(Debug, Clone)]
164pub struct TensorData {
165  pub dtype: DType,
166  pub shape: Vec<usize>,
167  pub bytes: Vec<u8>,
168}
169
170#[allow(dead_code)]
171#[derive(Debug, Clone)]
172pub enum Value {
173  Marker,
174  None,
175  Bool(bool),
176  Int(i64),
177  Float(f64),
178  String(String),
179  Bytes(Vec<u8>),
180  List(Vec<Value>),
181  Set(Vec<Value>),
182  Tuple(Vec<Value>),
183  Dict(Vec<(Value, Value)>),
184  Global { module: String, name: String },
185  StorageRef(StorageRef),
186  TensorRef(TensorRef),
187  OrderedDict(Vec<(String, Value)>),
188  Object {
189    module: String,
190    name: String,
191    args: Option<Box<Value>>,
192    state: Option<Box<Value>>,
193  },
194}
195
196impl Value {
197  pub(crate) fn as_usize(&self) -> Result<usize> {
198    match self {
199      Value::Int(v) if *v >= 0 => Ok(*v as usize),
200      _ => Err(ConvertError::InvalidStructure(
201        "expected non-negative integer".to_string(),
202      )),
203    }
204  }
205
206  pub(crate) fn as_string(&self) -> Result<String> {
207    match self {
208      Value::String(v) => Ok(v.clone()),
209      Value::Int(v) => Ok(v.to_string()),
210      _ => Err(ConvertError::InvalidStructure(
211        "expected string".to_string(),
212      )),
213    }
214  }
215
216  pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
217    match self {
218      Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
219      _ => Err(ConvertError::InvalidStructure(
220        "expected tuple/list of integers".to_string(),
221      )),
222    }
223  }
224}
225
226pub struct ParsedCheckpoint {
227  pub source_sha256: String,
228  pub warnings: Vec<String>,
229  pub tensors: BTreeMap<String, TensorData>,
230  pub metadata: serde_yaml::Value,
231  pub objects: Vec<String>,
232}