Skip to main content

pt_loader/
types.rs

1use ndarray::ArrayD;
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::fmt;
5use std::path::{Path, PathBuf};
6
/// Limits and switches that govern how a checkpoint is loaded.
///
/// The `max_*` fields are resource caps. NOTE(review): they are enforced by the
/// loader outside this file — confirm the exact checks there.
#[derive(Clone, Debug)]
pub struct LoadOptions {
  /// Cap on the archive size, in bytes (default 4 GiB).
  pub max_archive_bytes: u64,
  /// Cap on the number of tensors (default 4096).
  pub max_tensor_count: usize,
  /// Cap on a single tensor's size, in bytes (default 1 GiB).
  pub max_tensor_bytes: usize,
  /// Cap on the pickle payload size, in bytes (default 64 MiB).
  pub max_pickle_bytes: usize,
  /// Presumably rejects non-contiguous tensor layouts when `true` (default) — confirm.
  pub strict_contiguous: bool,
  /// Candidate keys for the state-dict root (empty by default); interpreted by the loader.
  pub state_dict_root_keys: Vec<String>,
  /// Presumably makes root-key lookup fail hard instead of falling back (default `true`).
  pub state_dict_root_strict: bool,
}

impl Default for LoadOptions {
  fn default() -> Self {
    const MIB: usize = 1 << 20;
    const GIB: usize = 1 << 30;
    // Kept as u64 so the 4 GiB value also fits where `usize` is 32 bits wide.
    let archive_cap: u64 = 4 << 30;

    LoadOptions {
      max_archive_bytes: archive_cap,
      max_tensor_count: 4096,
      max_tensor_bytes: GIB,
      max_pickle_bytes: 64 * MIB,
      strict_contiguous: true,
      state_dict_root_keys: Vec::new(),
      state_dict_root_strict: true,
    }
  }
}
31
/// Output file formats the exporter can produce.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExportFormat {
  /// The `safetensors` format.
  Safetensors,
}

impl ExportFormat {
  /// File extension for this format, without the leading dot.
  pub fn extension(self) -> &'static str {
    match self {
      Self::Safetensors => "safetensors",
    }
  }
}
44
45#[derive(Debug, Clone)]
46pub struct ExportOptions {
47  pub format: ExportFormat,
48  pub weights_filename: PathBuf,
49  pub metadata_filename: PathBuf,
50  pub include_metadata: bool,
51  pub overwrite: bool,
52}
53
54impl ExportOptions {
55  pub fn new(format: ExportFormat, input_path: Option<&Path>) -> Self {
56    let weights_filename = default_weights_filename(format, input_path);
57    let metadata_filename = weights_filename.with_extension("yaml");
58
59    Self {
60      format,
61      weights_filename,
62      metadata_filename,
63      include_metadata: true,
64      overwrite: false,
65    }
66  }
67}
68
69fn default_weights_filename(format: ExportFormat, input_path: Option<&Path>) -> PathBuf {
70  let ext = format.extension();
71  let Some(path) = input_path else {
72    return PathBuf::from(format!("model.{ext}"));
73  };
74
75  let Some(name) = path.file_name() else {
76    return PathBuf::from(format!("model.{ext}"));
77  };
78
79  Path::new(name).with_extension(ext)
80}
81
82#[derive(Debug, Clone, Serialize)]
83pub struct ExportResult {
84  pub weights_path: PathBuf,
85  #[serde(default)]
86  pub weights_paths: BTreeMap<String, PathBuf>,
87  pub metadata_path: Option<PathBuf>,
88  pub source_sha256: String,
89  pub tensor_count: usize,
90  pub total_tensor_bytes: usize,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct CheckpointTensorMetadata {
95  pub name: String,
96  pub dtype: String,
97  pub shape: Vec<usize>,
98  pub sha256: String,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct CheckpointSecurity {
103  #[serde(default)]
104  pub objects: Vec<String>,
105  #[serde(default)]
106  pub calls: Vec<String>,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct CheckpointMetadata {
111  pub format_version: usize,
112  pub source_file: String,
113  pub source_sha256: String,
114  pub safetensors_file: String,
115  #[serde(default)]
116  pub safetensors_files: BTreeMap<String, String>,
117  pub created_at_unix: u64,
118  pub tensor_count: usize,
119  pub total_tensor_bytes: usize,
120  #[serde(default)]
121  pub metadata: serde_yaml::Value,
122  pub security: CheckpointSecurity,
123  pub tensors: TensorManifest,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(untagged)]
128pub enum TensorManifest {
129  List(Vec<CheckpointTensorMetadata>),
130  ByRoot(BTreeMap<String, Vec<CheckpointTensorMetadata>>),
131}
132
133#[derive(Debug, Clone)]
134pub enum ReconstructSource {
135  WeightsFile(PathBuf),
136  StateDict(BTreeMap<String, TensorData>),
137}
138
139#[derive(Debug, Clone)]
140pub enum TensorArray {
141  F32(ArrayD<f32>),
142  F64(ArrayD<f64>),
143  I8(ArrayD<i8>),
144  I16(ArrayD<i16>),
145  I32(ArrayD<i32>),
146  I64(ArrayD<i64>),
147  U8(ArrayD<u8>),
148  Bool(ArrayD<bool>),
149}
150
151#[derive(Debug)]
152pub enum ConvertError {
153  Io(std::io::Error),
154  Zip(zip::result::ZipError),
155  Json(serde_json::Error),
156  Yaml(serde_yaml::Error),
157  Ndarray(ndarray::ShapeError),
158  UnsupportedFormat(String),
159  UnsafeOpcode { opcode: u8, offset: usize },
160  InvalidStructure(String),
161  ResourceLimitExceeded(String),
162}
163
164impl fmt::Display for ConvertError {
165  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166    match self {
167      ConvertError::Io(err) => write!(f, "io error: {}", err),
168      ConvertError::Zip(err) => write!(f, "zip error: {}", err),
169      ConvertError::Json(err) => write!(f, "json error: {}", err),
170      ConvertError::Yaml(err) => write!(f, "yaml error: {}", err),
171      ConvertError::Ndarray(err) => write!(f, "ndarray error: {}", err),
172      ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
173      ConvertError::UnsafeOpcode { opcode, offset } => {
174        write!(f, "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}")
175      }
176      ConvertError::InvalidStructure(msg) => {
177        write!(f, "invalid checkpoint structure: {}", msg)
178      }
179      ConvertError::ResourceLimitExceeded(msg) => {
180        write!(f, "resource limit exceeded: {}", msg)
181      }
182    }
183  }
184}
185
186impl std::error::Error for ConvertError {}
187
188impl From<std::io::Error> for ConvertError {
189  fn from(value: std::io::Error) -> Self {
190    Self::Io(value)
191  }
192}
193
194impl From<zip::result::ZipError> for ConvertError {
195  fn from(value: zip::result::ZipError) -> Self {
196    Self::Zip(value)
197  }
198}
199
200impl From<serde_json::Error> for ConvertError {
201  fn from(value: serde_json::Error) -> Self {
202    Self::Json(value)
203  }
204}
205
206impl From<serde_yaml::Error> for ConvertError {
207  fn from(value: serde_yaml::Error) -> Self {
208    Self::Yaml(value)
209  }
210}
211
212impl From<ndarray::ShapeError> for ConvertError {
213  fn from(value: ndarray::ShapeError) -> Self {
214    Self::Ndarray(value)
215  }
216}
217
218pub type Result<T> = std::result::Result<T, ConvertError>;
219
/// Element type of a tensor.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DType {
  F16,
  BF16,
  F32,
  F64,
  I8,
  I16,
  I32,
  I64,
  U8,
  Bool,
}

impl DType {
  /// Every variant, in declaration order; backs the reverse lookup in `from_safetensors`.
  const ALL: [DType; 10] = [
    DType::F16,
    DType::BF16,
    DType::F32,
    DType::F64,
    DType::I8,
    DType::I16,
    DType::I32,
    DType::I64,
    DType::U8,
    DType::Bool,
  ];

  /// Size of a single element, in bytes.
  pub fn elem_size(self) -> usize {
    match self {
      DType::I8 | DType::U8 | DType::Bool => 1,
      DType::F16 | DType::BF16 | DType::I16 => 2,
      DType::F32 | DType::I32 => 4,
      DType::F64 | DType::I64 => 8,
    }
  }

  /// Safetensors dtype tag for this type, e.g. `"F32"` or `"BOOL"`.
  pub fn as_safetensors(self) -> &'static str {
    match self {
      DType::F16 => "F16",
      DType::BF16 => "BF16",
      DType::F32 => "F32",
      DType::F64 => "F64",
      DType::I8 => "I8",
      DType::I16 => "I16",
      DType::I32 => "I32",
      DType::I64 => "I64",
      DType::U8 => "U8",
      DType::Bool => "BOOL",
    }
  }

  /// Parses a safetensors dtype tag; the inverse of `as_safetensors`.
  ///
  /// Matching is exact and case-sensitive. Unknown tags yield `None`.
  pub fn from_safetensors(value: &str) -> Option<Self> {
    Self::ALL
      .iter()
      .copied()
      .find(|candidate| candidate.as_safetensors() == value)
  }
}
275
276#[derive(Debug, Clone)]
277pub struct StorageRef {
278  pub key: String,
279  pub dtype: DType,
280  pub size_elems: usize,
281}
282
283#[derive(Debug, Clone)]
284pub struct TensorRef {
285  pub storage: StorageRef,
286  pub offset_elems: usize,
287  pub shape: Vec<usize>,
288  pub stride: Vec<usize>,
289}
290
291#[derive(Debug, Clone)]
292pub struct TensorData {
293  pub dtype: DType,
294  pub shape: Vec<usize>,
295  pub bytes: Vec<u8>,
296}
297
298#[allow(dead_code)]
299#[derive(Debug, Clone)]
300pub enum Value {
301  Marker,
302  None,
303  Bool(bool),
304  Int(i64),
305  Float(f64),
306  String(String),
307  Bytes(Vec<u8>),
308  List(Vec<Value>),
309  Set(Vec<Value>),
310  Tuple(Vec<Value>),
311  Dict(Vec<(Value, Value)>),
312  Global {
313    module: String,
314    name: String,
315  },
316  StorageRef(StorageRef),
317  TensorRef(TensorRef),
318  OrderedDict(Vec<(String, Value)>),
319  Call {
320    func: String,
321    args: Vec<Value>,
322    state: Option<Box<Value>>,
323  },
324  Object {
325    module: String,
326    name: String,
327    args: Vec<Value>,
328    state: Option<Box<Value>>,
329  },
330}
331
332impl Value {
333  pub(crate) fn as_usize(&self) -> Result<usize> {
334    match self {
335      Value::Int(v) if *v >= 0 => Ok(*v as usize),
336      _ => Err(ConvertError::InvalidStructure(
337        "expected non-negative integer".to_string(),
338      )),
339    }
340  }
341
342  pub(crate) fn as_string(&self) -> Result<String> {
343    match self {
344      Value::String(v) => Ok(v.clone()),
345      Value::Int(v) => Ok(v.to_string()),
346      _ => Err(ConvertError::InvalidStructure("expected string".to_string())),
347    }
348  }
349
350  pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
351    match self {
352      Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
353      _ => Err(ConvertError::InvalidStructure(
354        "expected tuple/list of integers".to_string(),
355      )),
356    }
357  }
358}
359
360pub struct ParsedCheckpoint {
361  pub source_sha256: String,
362  pub warnings: Vec<String>,
363  pub tensors: BTreeMap<String, TensorData>,
364  pub tensor_groups: BTreeMap<String, BTreeMap<String, TensorData>>,
365  pub metadata: serde_yaml::Value,
366  pub security: CheckpointSecurity,
367}