1use ndarray::ArrayD;
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4use std::fmt;
5use std::path::{Path, PathBuf};
6
/// Limits and flags consulted while loading a checkpoint.
///
/// This type only carries the values; how each limit is enforced is not
/// visible in this part of the file.
#[derive(Debug, Clone)]
pub struct LoadOptions {
    /// Byte ceiling for the archive (default: 4 GiB).
    pub max_archive_bytes: u64,
    /// Ceiling on the number of tensors (default: 4096).
    pub max_tensor_count: usize,
    /// Byte ceiling for tensor data (default: 1 GiB).
    // NOTE(review): per-tensor or cumulative? Not determinable from here — confirm.
    pub max_tensor_bytes: usize,
    /// Byte ceiling for the pickle payload (default: 64 MiB).
    pub max_pickle_bytes: usize,
    /// Strictness flag for contiguity handling (default: `true`).
    // NOTE(review): presumably rejects non-contiguous tensors when set — confirm.
    pub strict_contiguous: bool,
}

impl Default for LoadOptions {
    /// Builds the default limits: 4 GiB archive, 4096 tensors, 1 GiB of
    /// tensor bytes, 64 MiB of pickle bytes, and strict contiguity enabled.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        const GIB: usize = 1 << 30;

        LoadOptions {
            // Widen before multiplying so the product is computed as `u64`.
            max_archive_bytes: 4 * (GIB as u64),
            max_tensor_count: 4096,
            max_tensor_bytes: GIB,
            max_pickle_bytes: 64 * MIB,
            strict_contiguous: true,
        }
    }
}
27
/// Output formats an export can target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// The safetensors container format.
    Safetensors,
}

impl ExportFormat {
    /// Returns the file extension for this format, without a leading dot.
    pub fn extension(self) -> &'static str {
        match self {
            Self::Safetensors => "safetensors",
        }
    }
}

/// Settings describing what an export writes and under which file names.
#[derive(Debug, Clone)]
pub struct ExportOptions {
    /// Target format of the weights file.
    pub format: ExportFormat,
    /// File name for the weights output.
    pub weights_filename: PathBuf,
    /// File name for the metadata output.
    pub metadata_filename: PathBuf,
    /// Flag controlling metadata output; `new` sets it to `true`.
    pub include_metadata: bool,
    /// Flag controlling overwriting; `new` sets it to `false`.
    pub overwrite: bool,
}

impl ExportOptions {
    /// Creates options with file names derived from `input_path`.
    ///
    /// The weights file name comes from `default_weights_filename`; the
    /// metadata file name is the same name with its extension replaced by
    /// `yaml`. Metadata is enabled and overwriting is disabled.
    pub fn new(format: ExportFormat, input_path: Option<&Path>) -> Self {
        let weights = default_weights_filename(format, input_path);

        ExportOptions {
            format,
            // Borrow `weights` here before it is moved into the next field.
            metadata_filename: weights.with_extension("yaml"),
            weights_filename: weights,
            include_metadata: true,
            overwrite: false,
        }
    }
}

/// Derives the default weights file name for `format`.
///
/// Uses the final component of `input_path` with its extension swapped for
/// the format's extension. Falls back to `model.<ext>` when no path is given
/// or the path has no final file-name component (for example `..`).
fn default_weights_filename(format: ExportFormat, input_path: Option<&Path>) -> PathBuf {
    let ext = format.extension();

    match input_path.and_then(Path::file_name) {
        Some(file_name) => Path::new(file_name).with_extension(ext),
        None => PathBuf::from(format!("model.{ext}")),
    }
}
77
/// Serializable record describing the outcome of an export.
#[derive(Debug, Clone, Serialize)]
pub struct ExportResult {
    /// Path of the weights file.
    pub weights_path: PathBuf,
    /// Path of the metadata file, if any.
    // NOTE(review): presumably `None` when metadata output is disabled — confirm.
    pub metadata_path: Option<PathBuf>,
    /// SHA-256 of the source, held as a string.
    // NOTE(review): text encoding (hex?) is not visible here — confirm.
    pub source_sha256: String,
    /// Number of tensors.
    pub tensor_count: usize,
    /// Total size of the tensor data in bytes.
    pub total_tensor_bytes: usize,
}
86
/// Per-tensor metadata entry; serializable and deserializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointTensorMetadata {
    /// Tensor name.
    pub name: String,
    /// Element type, held as a string.
    // NOTE(review): presumably a safetensors dtype name such as "F32" — confirm.
    pub dtype: String,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// SHA-256 associated with this tensor, held as a string.
    // NOTE(review): what exactly is hashed, and the encoding, is not visible here.
    pub sha256: String,
}
94
/// Checkpoint-level metadata document; serializable and deserializable via
/// serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Version number of this metadata layout.
    pub format_version: usize,
    /// Source file, held as a string.
    pub source_file: String,
    /// SHA-256 of the source, held as a string.
    pub source_sha256: String,
    /// Safetensors file this metadata refers to, held as a string.
    pub safetensors_file: String,
    /// Creation time.
    // NOTE(review): presumably seconds since the Unix epoch — confirm.
    pub created_at_unix: u64,
    /// Number of tensors.
    pub tensor_count: usize,
    /// Total size of the tensor data in bytes.
    pub total_tensor_bytes: usize,
    /// Free-form metadata value; takes its `Default` when the field is
    /// absent during deserialization.
    #[serde(default)]
    pub metadata: serde_yaml::Value,
    /// List of object descriptions; empty when the field is absent during
    /// deserialization.
    #[serde(default)]
    pub objects: Vec<String>,
    /// One entry per tensor.
    pub tensors: Vec<CheckpointTensorMetadata>,
}
110
/// Where tensor data for a reconstruction comes from.
#[derive(Debug, Clone)]
pub enum ReconstructSource {
    /// A weights file identified by its path.
    WeightsFile(PathBuf),
    /// Tensors already held in memory, keyed by name.
    StateDict(BTreeMap<String, TensorData>),
}
116
/// A dynamically-dimensioned `ndarray` array tagged by element type.
///
/// There is one variant per supported element type. Unlike `DType`, there
/// are no `F16`/`BF16` variants here.
#[derive(Debug, Clone)]
pub enum TensorArray {
    /// 32-bit float elements.
    F32(ArrayD<f32>),
    /// 64-bit float elements.
    F64(ArrayD<f64>),
    /// Signed 8-bit integer elements.
    I8(ArrayD<i8>),
    /// Signed 16-bit integer elements.
    I16(ArrayD<i16>),
    /// Signed 32-bit integer elements.
    I32(ArrayD<i32>),
    /// Signed 64-bit integer elements.
    I64(ArrayD<i64>),
    /// Unsigned 8-bit integer elements.
    U8(ArrayD<u8>),
    /// Boolean elements.
    Bool(ArrayD<bool>),
}
128
/// Error type for conversion operations.
///
/// The first five variants wrap errors from other libraries; the rest carry
/// a message or structured details of their own.
#[derive(Debug)]
pub enum ConvertError {
    /// Wrapped `std::io::Error`.
    Io(std::io::Error),
    /// Wrapped `zip` error.
    Zip(zip::result::ZipError),
    /// Wrapped `serde_json` error.
    Json(serde_json::Error),
    /// Wrapped `serde_yaml` error.
    Yaml(serde_yaml::Error),
    /// Wrapped `ndarray` shape error.
    Ndarray(ndarray::ShapeError),
    /// The input uses a format that is not supported; the string describes it.
    UnsupportedFormat(String),
    /// A pickle opcode that is unsafe or unsupported was found; `opcode` is
    /// the opcode byte and `offset` the position reported with it.
    UnsafeOpcode { opcode: u8, offset: usize },
    /// The checkpoint structure is invalid; the string describes the problem.
    InvalidStructure(String),
    /// A resource limit was exceeded; the string describes which one.
    ResourceLimitExceeded(String),
}
141
142impl fmt::Display for ConvertError {
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 match self {
145 ConvertError::Io(err) => write!(f, "io error: {}", err),
146 ConvertError::Zip(err) => write!(f, "zip error: {}", err),
147 ConvertError::Json(err) => write!(f, "json error: {}", err),
148 ConvertError::Yaml(err) => write!(f, "yaml error: {}", err),
149 ConvertError::Ndarray(err) => write!(f, "ndarray error: {}", err),
150 ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
151 ConvertError::UnsafeOpcode { opcode, offset } => {
152 write!(f, "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}")
153 }
154 ConvertError::InvalidStructure(msg) => {
155 write!(f, "invalid checkpoint structure: {}", msg)
156 }
157 ConvertError::ResourceLimitExceeded(msg) => {
158 write!(f, "resource limit exceeded: {}", msg)
159 }
160 }
161 }
162}
163
// Marks `ConvertError` as a standard error type. `source()` is not
// overridden, so it returns `None`; the `Display` impl already includes the
// wrapped error's message for the wrapper variants.
impl std::error::Error for ConvertError {}
165
166impl From<std::io::Error> for ConvertError {
167 fn from(value: std::io::Error) -> Self {
168 Self::Io(value)
169 }
170}
171
172impl From<zip::result::ZipError> for ConvertError {
173 fn from(value: zip::result::ZipError) -> Self {
174 Self::Zip(value)
175 }
176}
177
178impl From<serde_json::Error> for ConvertError {
179 fn from(value: serde_json::Error) -> Self {
180 Self::Json(value)
181 }
182}
183
184impl From<serde_yaml::Error> for ConvertError {
185 fn from(value: serde_yaml::Error) -> Self {
186 Self::Yaml(value)
187 }
188}
189
190impl From<ndarray::ShapeError> for ConvertError {
191 fn from(value: ndarray::ShapeError) -> Self {
192 Self::Ndarray(value)
193 }
194}
195
/// Shorthand for `std::result::Result` with the error type fixed to
/// [`ConvertError`].
pub type Result<T> = std::result::Result<T, ConvertError>;
197
/// Tensor element types, with conversions to and from safetensors dtype names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    F16,
    BF16,
    F32,
    F64,
    I8,
    I16,
    I32,
    I64,
    U8,
    Bool,
}

impl DType {
    /// Returns the size of one element of this type, in bytes.
    pub fn elem_size(self) -> usize {
        // Arms are grouped by width, narrowest first.
        match self {
            Self::I8 | Self::U8 | Self::Bool => 1,
            Self::I16 | Self::F16 | Self::BF16 => 2,
            Self::I32 | Self::F32 => 4,
            Self::I64 | Self::F64 => 8,
        }
    }

    /// Returns the safetensors dtype name for this type (for example `"F32"`
    /// or `"BOOL"`).
    pub fn as_safetensors(self) -> &'static str {
        match self {
            Self::F16 => "F16",
            Self::BF16 => "BF16",
            Self::F32 => "F32",
            Self::F64 => "F64",
            Self::I8 => "I8",
            Self::I16 => "I16",
            Self::I32 => "I32",
            Self::I64 => "I64",
            Self::U8 => "U8",
            Self::Bool => "BOOL",
        }
    }

    /// Parses a safetensors dtype name, returning `None` for anything that is
    /// not exactly one of the names produced by [`DType::as_safetensors`].
    ///
    /// Matching is exact and case-sensitive.
    pub fn from_safetensors(value: &str) -> Option<Self> {
        // Every variant, so the lookup is the exact inverse of
        // `as_safetensors`. Keep this list in sync with the enum.
        const ALL: [DType; 10] = [
            DType::F16,
            DType::BF16,
            DType::F32,
            DType::F64,
            DType::I8,
            DType::I16,
            DType::I32,
            DType::I64,
            DType::U8,
            DType::Bool,
        ];

        ALL.iter()
            .copied()
            .find(|dtype| dtype.as_safetensors() == value)
    }
}
253
/// Reference to a storage buffer: its key, element type, and length.
#[derive(Debug, Clone)]
pub struct StorageRef {
    /// Key identifying the storage.
    // NOTE(review): presumably the archive entry name for the raw data — confirm.
    pub key: String,
    /// Element type of the storage.
    pub dtype: DType,
    /// Storage length, counted in elements (not bytes).
    pub size_elems: usize,
}
260
/// Reference to a tensor described as a view over a storage buffer.
#[derive(Debug, Clone)]
pub struct TensorRef {
    /// Storage the tensor refers to.
    pub storage: StorageRef,
    /// Starting offset within the storage, counted in elements.
    pub offset_elems: usize,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Per-dimension strides.
    // NOTE(review): presumably in elements, like `offset_elems` — confirm.
    pub stride: Vec<usize>,
}
268
/// A tensor held in memory: element type, dimensions, and raw bytes.
#[derive(Debug, Clone)]
pub struct TensorData {
    /// Element type of the data.
    pub dtype: DType,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Raw tensor bytes.
    // NOTE(review): byte order and layout are not visible here — confirm.
    pub bytes: Vec<u8>,
}
275
/// Dynamically-typed value tree.
///
/// NOTE(review): the variant set (marker, globals, objects with args/state)
/// suggests values produced while decoding a pickle stream — confirm against
/// the parser.
// `dead_code` is allowed because not every variant or field is necessarily
// read within this crate.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum Value {
    /// Marker value.
    Marker,
    /// The "none" value.
    None,
    /// Boolean.
    Bool(bool),
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Text string.
    String(String),
    /// Byte string.
    Bytes(Vec<u8>),
    /// List of values.
    List(Vec<Value>),
    /// Set of values, stored as a vector.
    Set(Vec<Value>),
    /// Tuple of values.
    Tuple(Vec<Value>),
    /// Dictionary stored as key/value pairs.
    Dict(Vec<(Value, Value)>),
    /// Reference to a global identified by module and name.
    Global {
        module: String,
        name: String,
    },
    /// Reference to a storage buffer.
    StorageRef(StorageRef),
    /// Reference to a tensor.
    TensorRef(TensorRef),
    /// Dictionary with string keys, stored as key/value pairs.
    OrderedDict(Vec<(String, Value)>),
    /// Object identified by module and name, with optional constructor
    /// arguments and optional state.
    Object {
        module: String,
        name: String,
        args: Option<Box<Value>>,
        state: Option<Box<Value>>,
    },
}
304
305impl Value {
306 pub(crate) fn as_usize(&self) -> Result<usize> {
307 match self {
308 Value::Int(v) if *v >= 0 => Ok(*v as usize),
309 _ => Err(ConvertError::InvalidStructure(
310 "expected non-negative integer".to_string(),
311 )),
312 }
313 }
314
315 pub(crate) fn as_string(&self) -> Result<String> {
316 match self {
317 Value::String(v) => Ok(v.clone()),
318 Value::Int(v) => Ok(v.to_string()),
319 _ => Err(ConvertError::InvalidStructure("expected string".to_string())),
320 }
321 }
322
323 pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
324 match self {
325 Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
326 _ => Err(ConvertError::InvalidStructure(
327 "expected tuple/list of integers".to_string(),
328 )),
329 }
330 }
331}
332
/// In-memory result of parsing a checkpoint.
pub struct ParsedCheckpoint {
    /// SHA-256 of the source, held as a string.
    pub source_sha256: String,
    /// Warning messages collected alongside the parsed data.
    pub warnings: Vec<String>,
    /// Tensors keyed by name; `BTreeMap` iterates in sorted key order.
    pub tensors: BTreeMap<String, TensorData>,
    /// Free-form metadata value.
    pub metadata: serde_yaml::Value,
    /// List of object descriptions.
    pub objects: Vec<String>,
}