1use ndarray::ArrayD;
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::fmt;
5use std::path::{Path, PathBuf};
6
/// Limits and behavior switches handed to the checkpoint loader.
///
/// This type is a plain bag of settings: nothing in this definition enforces
/// them. The per-field notes describe intent as suggested by the field names
/// (enforcement lives in loader code — confirm there). The stated defaults are
/// the ones produced by the [`Default`] impl below.
#[derive(Debug, Clone)]
pub struct LoadOptions {
    /// Size ceiling for the input archive, in bytes. Default: 4 GiB.
    pub max_archive_bytes: u64,
    /// Ceiling on the number of tensors. Default: 4096.
    pub max_tensor_count: usize,
    /// Size ceiling for tensor data, in bytes. Default: 1 GiB.
    // NOTE(review): per-tensor vs. cumulative is not visible here — confirm.
    pub max_tensor_bytes: usize,
    /// Size ceiling for the pickle payload, in bytes. Default: 64 MiB.
    pub max_pickle_bytes: usize,
    /// Presumably rejects non-contiguous tensor layouts when set. Default: `true`.
    pub strict_contiguous: bool,
    /// Root keys used to locate state dicts. Default: empty.
    pub state_dict_root_keys: Vec<String>,
    /// Strictness toggle accompanying `state_dict_root_keys`. Default: `true`.
    pub state_dict_root_strict: bool,
}

impl Default for LoadOptions {
    /// Returns the stock configuration (see the field docs for each value).
    fn default() -> Self {
        Self {
            max_archive_bytes: 4 << 30, // 4 GiB
            max_tensor_count: 4096,
            max_tensor_bytes: 1 << 30, // 1 GiB
            max_pickle_bytes: 64 << 20, // 64 MiB
            strict_contiguous: true,
            state_dict_root_keys: Vec::new(),
            state_dict_root_strict: true,
        }
    }
}
31
/// Output formats an export can target. `Safetensors` is the only one defined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// The safetensors format.
    Safetensors,
}

impl ExportFormat {
    /// Returns the file extension for this format, without a leading dot.
    pub fn extension(self) -> &'static str {
        // Exhaustive on purpose: a new variant must pick its extension here.
        match self {
            Self::Safetensors => "safetensors",
        }
    }
}
44
45#[derive(Debug, Clone)]
46pub struct ExportOptions {
47 pub format: ExportFormat,
48 pub weights_filename: PathBuf,
49 pub metadata_filename: PathBuf,
50 pub include_metadata: bool,
51 pub overwrite: bool,
52}
53
54impl ExportOptions {
55 pub fn new(format: ExportFormat, input_path: Option<&Path>) -> Self {
56 let weights_filename = default_weights_filename(format, input_path);
57 let metadata_filename = weights_filename.with_extension("yaml");
58
59 Self {
60 format,
61 weights_filename,
62 metadata_filename,
63 include_metadata: true,
64 overwrite: false,
65 }
66 }
67}
68
69fn default_weights_filename(format: ExportFormat, input_path: Option<&Path>) -> PathBuf {
70 let ext = format.extension();
71 let Some(path) = input_path else {
72 return PathBuf::from(format!("model.{ext}"));
73 };
74
75 let Some(name) = path.file_name() else {
76 return PathBuf::from(format!("model.{ext}"));
77 };
78
79 Path::new(name).with_extension(ext)
80}
81
82#[derive(Debug, Clone, Serialize)]
83pub struct ExportResult {
84 pub weights_path: PathBuf,
85 #[serde(default)]
86 pub weights_paths: BTreeMap<String, PathBuf>,
87 pub metadata_path: Option<PathBuf>,
88 pub source_sha256: String,
89 pub tensor_count: usize,
90 pub total_tensor_bytes: usize,
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct CheckpointTensorMetadata {
95 pub name: String,
96 pub dtype: String,
97 pub shape: Vec<usize>,
98 pub sha256: String,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct CheckpointSecurity {
103 #[serde(default)]
104 pub objects: Vec<String>,
105 #[serde(default)]
106 pub calls: Vec<String>,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
110pub struct CheckpointMetadata {
111 pub format_version: usize,
112 pub source_file: String,
113 pub source_sha256: String,
114 pub safetensors_file: String,
115 #[serde(default)]
116 pub safetensors_files: BTreeMap<String, String>,
117 pub created_at_unix: u64,
118 pub tensor_count: usize,
119 pub total_tensor_bytes: usize,
120 #[serde(default)]
121 pub metadata: serde_yaml::Value,
122 pub security: CheckpointSecurity,
123 pub tensors: TensorManifest,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127#[serde(untagged)]
128pub enum TensorManifest {
129 List(Vec<CheckpointTensorMetadata>),
130 ByRoot(BTreeMap<String, Vec<CheckpointTensorMetadata>>),
131}
132
133#[derive(Debug, Clone)]
134pub enum ReconstructSource {
135 WeightsFile(PathBuf),
136 StateDict(BTreeMap<String, TensorData>),
137}
138
139#[derive(Debug, Clone)]
140pub enum TensorArray {
141 F32(ArrayD<f32>),
142 F64(ArrayD<f64>),
143 I8(ArrayD<i8>),
144 I16(ArrayD<i16>),
145 I32(ArrayD<i32>),
146 I64(ArrayD<i64>),
147 U8(ArrayD<u8>),
148 Bool(ArrayD<bool>),
149}
150
151#[derive(Debug)]
152pub enum ConvertError {
153 Io(std::io::Error),
154 Zip(zip::result::ZipError),
155 Json(serde_json::Error),
156 Yaml(serde_yaml::Error),
157 Ndarray(ndarray::ShapeError),
158 UnsupportedFormat(String),
159 UnsafeOpcode { opcode: u8, offset: usize },
160 InvalidStructure(String),
161 ResourceLimitExceeded(String),
162}
163
164impl fmt::Display for ConvertError {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 match self {
167 ConvertError::Io(err) => write!(f, "io error: {}", err),
168 ConvertError::Zip(err) => write!(f, "zip error: {}", err),
169 ConvertError::Json(err) => write!(f, "json error: {}", err),
170 ConvertError::Yaml(err) => write!(f, "yaml error: {}", err),
171 ConvertError::Ndarray(err) => write!(f, "ndarray error: {}", err),
172 ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
173 ConvertError::UnsafeOpcode { opcode, offset } => {
174 write!(f, "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}")
175 }
176 ConvertError::InvalidStructure(msg) => {
177 write!(f, "invalid checkpoint structure: {}", msg)
178 }
179 ConvertError::ResourceLimitExceeded(msg) => {
180 write!(f, "resource limit exceeded: {}", msg)
181 }
182 }
183 }
184}
185
186impl std::error::Error for ConvertError {}
187
188impl From<std::io::Error> for ConvertError {
189 fn from(value: std::io::Error) -> Self {
190 Self::Io(value)
191 }
192}
193
194impl From<zip::result::ZipError> for ConvertError {
195 fn from(value: zip::result::ZipError) -> Self {
196 Self::Zip(value)
197 }
198}
199
200impl From<serde_json::Error> for ConvertError {
201 fn from(value: serde_json::Error) -> Self {
202 Self::Json(value)
203 }
204}
205
206impl From<serde_yaml::Error> for ConvertError {
207 fn from(value: serde_yaml::Error) -> Self {
208 Self::Yaml(value)
209 }
210}
211
212impl From<ndarray::ShapeError> for ConvertError {
213 fn from(value: ndarray::ShapeError) -> Self {
214 Self::Ndarray(value)
215 }
216}
217
218pub type Result<T> = std::result::Result<T, ConvertError>;
219
/// Tensor element types known to the converter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    F16,
    BF16,
    F32,
    F64,
    I8,
    I16,
    I32,
    I64,
    U8,
    Bool,
}

impl DType {
    /// Width of a single element of this type, in bytes.
    pub fn elem_size(self) -> usize {
        // Arms are ordered by width, narrowest first.
        match self {
            Self::I8 | Self::U8 | Self::Bool => 1,
            Self::F16 | Self::BF16 | Self::I16 => 2,
            Self::F32 | Self::I32 => 4,
            Self::F64 | Self::I64 => 8,
        }
    }

    /// String tag for this type, e.g. `"F32"` or `"BOOL"`.
    ///
    /// [`DType::from_safetensors`] is the inverse of this mapping.
    pub fn as_safetensors(self) -> &'static str {
        match self {
            Self::F16 => "F16",
            Self::BF16 => "BF16",
            Self::F32 => "F32",
            Self::F64 => "F64",
            Self::I8 => "I8",
            Self::I16 => "I16",
            Self::I32 => "I32",
            Self::I64 => "I64",
            Self::U8 => "U8",
            Self::Bool => "BOOL",
        }
    }

    /// Parses a string tag back into a `DType`.
    ///
    /// Matching is exact and case-sensitive; any tag not produced by
    /// [`DType::as_safetensors`] yields `None`.
    pub fn from_safetensors(value: &str) -> Option<Self> {
        // Every variant, so the lookup is the exact inverse of `as_safetensors`.
        const KNOWN: [DType; 10] = [
            DType::F16,
            DType::BF16,
            DType::F32,
            DType::F64,
            DType::I8,
            DType::I16,
            DType::I32,
            DType::I64,
            DType::U8,
            DType::Bool,
        ];

        KNOWN
            .iter()
            .copied()
            .find(|dtype| dtype.as_safetensors() == value)
    }
}
275
/// A single scalar value tagged with its element type.
///
/// Only `PartialEq` is derived (not `Eq`) because of the float variants.
#[derive(Debug, Clone, PartialEq)]
pub enum NumpyScalarData {
    /// 32-bit float.
    F32(f32),
    /// 64-bit float.
    F64(f64),
    /// Signed 8-bit integer.
    I8(i8),
    /// Signed 16-bit integer.
    I16(i16),
    /// Signed 32-bit integer.
    I32(i32),
    /// Signed 64-bit integer.
    I64(i64),
    /// Unsigned 8-bit integer.
    U8(u8),
    /// Boolean.
    Bool(bool),
}
287
/// Byte-order marker attached to a NumPy scalar.
// NOTE(review): the variants look like NumPy's dtype byte-order flags; the
// exact mapping from flag characters is not visible here — confirm at the
// parsing site.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumpyEndian {
    /// Little-endian.
    Little,
    /// Big-endian.
    Big,
    /// Byte order does not apply to the value.
    NotApplicable,
    /// Native byte order.
    Native,
    /// No byte order was given.
    NotSpecified,
}
296
297#[derive(Debug, Clone)]
298pub struct StorageRef {
299 pub key: String,
300 pub dtype: DType,
301 pub size_elems: usize,
302}
303
304#[derive(Debug, Clone)]
305pub struct TensorRef {
306 pub storage: StorageRef,
307 pub offset_elems: usize,
308 pub shape: Vec<usize>,
309 pub stride: Vec<usize>,
310}
311
312#[derive(Debug, Clone)]
313pub struct TensorData {
314 pub dtype: DType,
315 pub shape: Vec<usize>,
316 pub bytes: Vec<u8>,
317}
318
319#[allow(dead_code)]
320#[derive(Debug, Clone)]
321pub enum Value {
322 Marker,
323 None,
324 Bool(bool),
325 Int(i64),
326 Float(f64),
327 String(String),
328 Bytes(Vec<u8>),
329 List(Vec<Value>),
330 Set(Vec<Value>),
331 Tuple(Vec<Value>),
332 Dict(Vec<(Value, Value)>),
333 Global {
334 module: String,
335 name: String,
336 },
337 StorageRef(StorageRef),
338 TensorRef(TensorRef),
339 OrderedDict(Vec<(String, Value)>),
340 NumpyScalar {
341 dtype: DType,
342 endian: NumpyEndian,
343 shape: Vec<usize>,
344 data: NumpyScalarData,
345 },
346 Call {
347 func: String,
348 args: Vec<Value>,
349 state: Option<Box<Value>>,
350 },
351 Object {
352 module: String,
353 name: String,
354 args: Vec<Value>,
355 state: Option<Box<Value>>,
356 },
357}
358
359impl Value {
360 pub(crate) fn as_usize(&self) -> Result<usize> {
361 match self {
362 Value::Int(v) if *v >= 0 => Ok(*v as usize),
363 _ => Err(ConvertError::InvalidStructure(
364 "expected non-negative integer".to_string(),
365 )),
366 }
367 }
368
369 pub(crate) fn as_string(&self) -> Result<String> {
370 match self {
371 Value::String(v) => Ok(v.clone()),
372 Value::Int(v) => Ok(v.to_string()),
373 _ => Err(ConvertError::InvalidStructure("expected string".to_string())),
374 }
375 }
376
377 pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
378 match self {
379 Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
380 _ => Err(ConvertError::InvalidStructure(
381 "expected tuple/list of integers".to_string(),
382 )),
383 }
384 }
385}
386
387pub struct ParsedCheckpoint {
388 pub source_sha256: String,
389 pub warnings: Vec<String>,
390 pub tensors: BTreeMap<String, TensorData>,
391 pub tensor_groups: BTreeMap<String, BTreeMap<String, TensorData>>,
392 pub metadata: serde_yaml::Value,
393 pub security: CheckpointSecurity,
394}