Skip to main content

pt_loader/
types.rs

1use ndarray::ArrayD;
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::fmt;
5use std::path::{Path, PathBuf};
6
/// Tunable limits and strictness switches used when loading a checkpoint.
///
/// NOTE(review): this file only declares and defaults these fields; the
/// per-field descriptions follow the field names and should be confirmed
/// against the code that enforces them.
#[derive(Clone, Debug)]
pub struct LoadOptions {
  /// Largest archive size accepted, in bytes.
  pub max_archive_bytes: u64,
  /// Largest number of tensors accepted.
  pub max_tensor_count: usize,
  /// Largest tensor size accepted, in bytes.
  pub max_tensor_bytes: usize,
  /// Largest pickle payload accepted, in bytes.
  pub max_pickle_bytes: usize,
  /// Contiguity strictness switch (presumably rejects non-contiguous tensors; confirm).
  pub strict_contiguous: bool,
}

impl Default for LoadOptions {
  /// Stock settings: 4 GiB archive, 4096 tensors, 1 GiB of tensor bytes,
  /// 64 MiB of pickle data, and `strict_contiguous` switched on.
  fn default() -> Self {
    // Binary size units written as shifts: `<< 20` is MiB, `<< 30` is GiB.
    let archive_cap: u64 = 4 << 30;
    let tensor_cap: usize = 1 << 30;
    let pickle_cap: usize = 64 << 20;

    LoadOptions {
      max_archive_bytes: archive_cap,
      max_tensor_count: 4096,
      max_tensor_bytes: tensor_cap,
      max_pickle_bytes: pickle_cap,
      strict_contiguous: true,
    }
  }
}
27
/// Output formats the exporter knows about.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExportFormat {
  /// Files carrying the `safetensors` extension.
  Safetensors,
}

impl ExportFormat {
  /// File extension for this format, without a leading dot.
  pub fn extension(self) -> &'static str {
    match self {
      Self::Safetensors => "safetensors",
    }
  }
}

/// Settings describing what an export writes and under which file names.
#[derive(Clone, Debug)]
pub struct ExportOptions {
  /// Format of the weights file.
  pub format: ExportFormat,
  /// File name for the weights output.
  pub weights_filename: PathBuf,
  /// File name for the metadata output.
  pub metadata_filename: PathBuf,
  /// Metadata switch; `new` turns it on.
  pub include_metadata: bool,
  /// Overwrite switch; `new` turns it off.
  pub overwrite: bool,
}

impl ExportOptions {
  /// Builds options whose file names are derived from `input_path`.
  ///
  /// The weights name is the input's final path component with its extension
  /// replaced by the format's extension (`model.<ext>` when there is no usable
  /// component); the metadata name is that same name with a `yaml` extension.
  pub fn new(format: ExportFormat, input_path: Option<&Path>) -> Self {
    let weights = default_weights_filename(format, input_path);
    let sidecar = weights.with_extension("yaml");

    ExportOptions {
      format,
      weights_filename: weights,
      metadata_filename: sidecar,
      include_metadata: true,
      overwrite: false,
    }
  }
}

/// Picks the default weights file name for `format`.
///
/// Uses the last component of `input_path` with its extension swapped for the
/// format's extension. Falls back to `model.<ext>` when no path is given or the
/// path has no final component (for example when it ends in `..`).
fn default_weights_filename(format: ExportFormat, input_path: Option<&Path>) -> PathBuf {
  let ext = format.extension();

  match input_path.and_then(|path| path.file_name()) {
    Some(name) => Path::new(name).with_extension(ext),
    None => PathBuf::from(format!("model.{ext}")),
  }
}
77
/// Serializable summary describing an export.
///
/// NOTE(review): field descriptions follow the field names; the code that
/// fills them in is not part of this file.
#[derive(Debug, Clone, Serialize)]
pub struct ExportResult {
  /// Path of the weights file.
  pub weights_path: PathBuf,
  /// Path of the metadata file, when there is one.
  pub metadata_path: Option<PathBuf>,
  /// SHA-256 of the source, stored as a string.
  pub source_sha256: String,
  /// Number of tensors.
  pub tensor_count: usize,
  /// Total tensor size in bytes.
  pub total_tensor_bytes: usize,
}
86
/// Per-tensor metadata record; serializable and deserializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointTensorMetadata {
  /// Tensor name.
  pub name: String,
  /// Element type, stored as a string (presumably a safetensors dtype tag; confirm).
  pub dtype: String,
  /// Tensor dimensions.
  pub shape: Vec<usize>,
  /// SHA-256 stored as a string (presumably of the tensor's bytes; confirm).
  pub sha256: String,
}
94
/// Checkpoint-level metadata document; serializable and deserializable via serde.
///
/// NOTE(review): field descriptions follow the field names; confirm against
/// the code that writes and reads this document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
  /// Version number of this metadata layout.
  pub format_version: usize,
  /// Name of the source file.
  pub source_file: String,
  /// SHA-256 of the source, stored as a string.
  pub source_sha256: String,
  /// Name of the safetensors file.
  pub safetensors_file: String,
  /// Creation time as a Unix timestamp (unit presumably seconds; confirm).
  pub created_at_unix: u64,
  /// Number of tensors.
  pub tensor_count: usize,
  /// Total tensor size in bytes.
  pub total_tensor_bytes: usize,
  /// Free-form YAML value; falls back to its `Default` when the key is absent on deserialization.
  #[serde(default)]
  pub metadata: serde_yaml::Value,
  /// List of strings; falls back to an empty list when the key is absent on deserialization.
  #[serde(default)]
  pub objects: Vec<String>,
  /// One record per tensor.
  pub tensors: Vec<CheckpointTensorMetadata>,
}
110
/// Input for reconstruction: either a path to a weights file or an in-memory
/// map from names to tensor data.
#[derive(Debug, Clone)]
pub enum ReconstructSource {
  /// Path to a weights file.
  WeightsFile(PathBuf),
  /// Tensors keyed by name, ordered by key (`BTreeMap`).
  StateDict(BTreeMap<String, TensorData>),
}
116
/// Dynamically-dimensioned `ndarray` array, with one variant per element type.
///
/// Note that there are no `F16`/`BF16` variants here, although `DType` in this
/// file lists both.
#[derive(Debug, Clone)]
pub enum TensorArray {
  /// 32-bit floating point elements.
  F32(ArrayD<f32>),
  /// 64-bit floating point elements.
  F64(ArrayD<f64>),
  /// Signed 8-bit integer elements.
  I8(ArrayD<i8>),
  /// Signed 16-bit integer elements.
  I16(ArrayD<i16>),
  /// Signed 32-bit integer elements.
  I32(ArrayD<i32>),
  /// Signed 64-bit integer elements.
  I64(ArrayD<i64>),
  /// Unsigned 8-bit integer elements.
  U8(ArrayD<u8>),
  /// Boolean elements.
  Bool(ArrayD<bool>),
}
128
/// Error type for this module: wraps errors from the I/O, zip, JSON, YAML and
/// ndarray layers and adds checkpoint-specific failure kinds.
#[derive(Debug)]
pub enum ConvertError {
  /// Wrapped `std::io::Error`.
  Io(std::io::Error),
  /// Wrapped zip error.
  Zip(zip::result::ZipError),
  /// Wrapped `serde_json` error.
  Json(serde_json::Error),
  /// Wrapped `serde_yaml` error.
  Yaml(serde_yaml::Error),
  /// Wrapped `ndarray` shape error.
  Ndarray(ndarray::ShapeError),
  /// Unsupported format; the payload is the detail message.
  UnsupportedFormat(String),
  /// Unsafe or unsupported pickle opcode: the opcode byte and the offset at which it was found.
  UnsafeOpcode { opcode: u8, offset: usize },
  /// Invalid checkpoint structure; the payload is the detail message.
  InvalidStructure(String),
  /// A resource limit was exceeded; the payload is the detail message.
  ResourceLimitExceeded(String),
}
141
142impl fmt::Display for ConvertError {
143  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144    match self {
145      ConvertError::Io(err) => write!(f, "io error: {}", err),
146      ConvertError::Zip(err) => write!(f, "zip error: {}", err),
147      ConvertError::Json(err) => write!(f, "json error: {}", err),
148      ConvertError::Yaml(err) => write!(f, "yaml error: {}", err),
149      ConvertError::Ndarray(err) => write!(f, "ndarray error: {}", err),
150      ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
151      ConvertError::UnsafeOpcode { opcode, offset } => {
152        write!(f, "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}")
153      }
154      ConvertError::InvalidStructure(msg) => {
155        write!(f, "invalid checkpoint structure: {}", msg)
156      }
157      ConvertError::ResourceLimitExceeded(msg) => {
158        write!(f, "resource limit exceeded: {}", msg)
159      }
160    }
161  }
162}
163
164impl std::error::Error for ConvertError {}
165
166impl From<std::io::Error> for ConvertError {
167  fn from(value: std::io::Error) -> Self {
168    Self::Io(value)
169  }
170}
171
172impl From<zip::result::ZipError> for ConvertError {
173  fn from(value: zip::result::ZipError) -> Self {
174    Self::Zip(value)
175  }
176}
177
178impl From<serde_json::Error> for ConvertError {
179  fn from(value: serde_json::Error) -> Self {
180    Self::Json(value)
181  }
182}
183
184impl From<serde_yaml::Error> for ConvertError {
185  fn from(value: serde_yaml::Error) -> Self {
186    Self::Yaml(value)
187  }
188}
189
190impl From<ndarray::ShapeError> for ConvertError {
191  fn from(value: ndarray::ShapeError) -> Self {
192    Self::Ndarray(value)
193  }
194}
195
196pub type Result<T> = std::result::Result<T, ConvertError>;
197
/// Element types recognised by this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DType {
  F16,
  BF16,
  F32,
  F64,
  I8,
  I16,
  I32,
  I64,
  U8,
  Bool,
}

impl DType {
  /// Every variant, used to derive the tag lookup from `as_safetensors`.
  const ALL: [DType; 10] = [
    DType::F16,
    DType::BF16,
    DType::F32,
    DType::F64,
    DType::I8,
    DType::I16,
    DType::I32,
    DType::I64,
    DType::U8,
    DType::Bool,
  ];

  /// Width of one element in bytes.
  pub fn elem_size(self) -> usize {
    match self {
      Self::I8 | Self::U8 | Self::Bool => 1,
      Self::F16 | Self::BF16 | Self::I16 => 2,
      Self::F32 | Self::I32 => 4,
      Self::F64 | Self::I64 => 8,
    }
  }

  /// The safetensors dtype tag for this element type.
  pub fn as_safetensors(self) -> &'static str {
    match self {
      Self::F16 => "F16",
      Self::BF16 => "BF16",
      Self::F32 => "F32",
      Self::F64 => "F64",
      Self::I8 => "I8",
      Self::I16 => "I16",
      Self::I32 => "I32",
      Self::I64 => "I64",
      Self::U8 => "U8",
      Self::Bool => "BOOL",
    }
  }

  /// Parses a safetensors dtype tag; the match is exact and case-sensitive.
  ///
  /// Returns `None` for any tag that `as_safetensors` never produces.
  pub fn from_safetensors(value: &str) -> Option<Self> {
    Self::ALL
      .iter()
      .copied()
      .find(|candidate| candidate.as_safetensors() == value)
  }
}
253
/// Reference to a storage, identified by key, with its element type and size.
#[derive(Debug, Clone)]
pub struct StorageRef {
  /// Key identifying the storage.
  pub key: String,
  /// Element type of the storage.
  pub dtype: DType,
  /// Size of the storage, counted in elements rather than bytes (per the `_elems` suffix).
  pub size_elems: usize,
}
260
/// Description of a tensor as a view into a storage: which storage, where it
/// starts, and its shape and strides.
#[derive(Debug, Clone)]
pub struct TensorRef {
  /// The storage this tensor refers to.
  pub storage: StorageRef,
  /// Start offset within the storage, counted in elements (per the `_elems` suffix).
  pub offset_elems: usize,
  /// Tensor dimensions.
  pub shape: Vec<usize>,
  /// Strides (presumably one per dimension, in elements; confirm).
  pub stride: Vec<usize>,
}
268
/// Tensor contents held as raw bytes, together with element type and shape.
#[derive(Debug, Clone)]
pub struct TensorData {
  /// Element type of the tensor.
  pub dtype: DType,
  /// Tensor dimensions.
  pub shape: Vec<usize>,
  /// Raw tensor bytes (byte order and layout are not established in this file).
  pub bytes: Vec<u8>,
}
275
/// Dynamically-typed value tree.
///
/// NOTE(review): presumably the result of decoding a pickle stream (the error
/// type in this file reports pickle opcodes); confirm against the parser.
// NOTE(review): `dead_code` is allowed without a stated reason; presumably some
// variants or fields are never read. Confirm it is still needed.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Value {
  /// Marker value (presumably the pickle stack mark; confirm).
  Marker,
  /// Absent value.
  None,
  /// Boolean.
  Bool(bool),
  /// 64-bit signed integer.
  Int(i64),
  /// 64-bit float.
  Float(f64),
  /// Text string.
  String(String),
  /// Byte string.
  Bytes(Vec<u8>),
  /// List of values.
  List(Vec<Value>),
  /// Set, stored as a vector of values.
  Set(Vec<Value>),
  /// Tuple of values.
  Tuple(Vec<Value>),
  /// Dictionary stored as a vector of key/value pairs.
  Dict(Vec<(Value, Value)>),
  /// Named global, identified by module and name.
  Global {
    module: String,
    name: String,
  },
  /// Reference to a storage.
  StorageRef(StorageRef),
  /// Reference to a tensor.
  TensorRef(TensorRef),
  /// Ordered dictionary with string keys, stored as a vector of pairs.
  OrderedDict(Vec<(String, Value)>),
  /// Object identified by module and name, with optional `args` and `state` values.
  Object {
    module: String,
    name: String,
    args: Option<Box<Value>>,
    state: Option<Box<Value>>,
  },
}
304
305impl Value {
306  pub(crate) fn as_usize(&self) -> Result<usize> {
307    match self {
308      Value::Int(v) if *v >= 0 => Ok(*v as usize),
309      _ => Err(ConvertError::InvalidStructure(
310        "expected non-negative integer".to_string(),
311      )),
312    }
313  }
314
315  pub(crate) fn as_string(&self) -> Result<String> {
316    match self {
317      Value::String(v) => Ok(v.clone()),
318      Value::Int(v) => Ok(v.to_string()),
319      _ => Err(ConvertError::InvalidStructure("expected string".to_string())),
320    }
321  }
322
323  pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
324    match self {
325      Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
326      _ => Err(ConvertError::InvalidStructure(
327        "expected tuple/list of integers".to_string(),
328      )),
329    }
330  }
331}
332
/// Contents of a parsed checkpoint: tensors plus accompanying metadata.
///
/// NOTE(review): unlike the other types in this file, this struct derives
/// nothing (not even `Debug`); confirm whether that is deliberate.
pub struct ParsedCheckpoint {
  /// SHA-256 of the source, stored as a string.
  pub source_sha256: String,
  /// Warning messages (presumably collected while parsing; confirm).
  pub warnings: Vec<String>,
  /// Tensors keyed by name, ordered by key (`BTreeMap`).
  pub tensors: BTreeMap<String, TensorData>,
  /// Free-form YAML metadata value.
  pub metadata: serde_yaml::Value,
  /// List of strings describing objects (exact contents not established in this file).
  pub objects: Vec<String>,
}