1use serde::Serialize;
2use std::collections::BTreeMap;
3use std::fmt;
4use std::path::PathBuf;
5
/// Tunable limits and switches for a conversion run.
///
/// The fields are plain data; the code that enforces them is not part of
/// this definition.
#[derive(Debug, Clone)]
pub struct ConvertOptions {
    /// Byte limit for the input archive.
    pub max_archive_bytes: u64,
    /// Limit on the number of tensors.
    pub max_tensor_count: usize,
    /// Byte limit for tensor data.
    // NOTE(review): per-tensor or cumulative? Confirm against the enforcing code.
    pub max_tensor_bytes: usize,
    /// Byte limit for the pickle payload.
    pub max_pickle_bytes: usize,
    /// Contiguity strictness switch.
    // NOTE(review): presumably rejects non-contiguous tensors when `true` - confirm.
    pub strict_contiguous: bool,
}
14
15impl Default for ConvertOptions {
16 fn default() -> Self {
17 Self {
18 max_archive_bytes: 4 * 1024 * 1024 * 1024,
19 max_tensor_count: 4096,
20 max_tensor_bytes: 1024 * 1024 * 1024,
21 max_pickle_bytes: 64 * 1024 * 1024,
22 strict_contiguous: true,
23 }
24 }
25}
26
27#[derive(Debug, Clone, Serialize)]
28pub struct TensorSummary {
29 pub name: String,
30 pub dtype: String,
31 pub shape: Vec<usize>,
32 pub nbytes: usize,
33}
34
35#[derive(Debug, Clone, Serialize)]
36pub struct InspectionReport {
37 pub detected_format: String,
38 pub source_file: String,
39 pub source_sha256: String,
40 pub tensor_count: usize,
41 pub total_tensor_bytes: usize,
42 pub tensors: Vec<TensorSummary>,
43 pub warnings: Vec<String>,
44}
45
46#[derive(Debug, Clone, Serialize)]
47pub struct ConvertResult {
48 pub safetensors_path: PathBuf,
49 pub model_yaml_path: PathBuf,
50 pub source_file: PathBuf,
51 pub source_sha256: String,
52 pub tensor_count: usize,
53 pub total_tensor_bytes: usize,
54}
55
56#[derive(Debug)]
57pub enum ConvertError {
58 Io(std::io::Error),
59 Zip(zip::result::ZipError),
60 Json(serde_json::Error),
61 UnsupportedFormat(String),
62 UnsafeOpcode { opcode: u8, offset: usize },
63 InvalidStructure(String),
64 ResourceLimitExceeded(String),
65}
66
67impl fmt::Display for ConvertError {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 match self {
70 ConvertError::Io(err) => write!(f, "io error: {}", err),
71 ConvertError::Zip(err) => write!(f, "zip error: {}", err),
72 ConvertError::Json(err) => write!(f, "json error: {}", err),
73 ConvertError::UnsupportedFormat(msg) => write!(f, "unsupported format: {}", msg),
74 ConvertError::UnsafeOpcode { opcode, offset } => {
75 write!(
76 f,
77 "unsafe/unsupported pickle opcode 0x{opcode:02x} at offset {offset}"
78 )
79 }
80 ConvertError::InvalidStructure(msg) => write!(f, "invalid checkpoint structure: {}", msg),
81 ConvertError::ResourceLimitExceeded(msg) => write!(f, "resource limit exceeded: {}", msg),
82 }
83 }
84}
85
86impl std::error::Error for ConvertError {}
87
88impl From<std::io::Error> for ConvertError {
89 fn from(value: std::io::Error) -> Self {
90 Self::Io(value)
91 }
92}
93
94impl From<zip::result::ZipError> for ConvertError {
95 fn from(value: zip::result::ZipError) -> Self {
96 Self::Zip(value)
97 }
98}
99
100impl From<serde_json::Error> for ConvertError {
101 fn from(value: serde_json::Error) -> Self {
102 Self::Json(value)
103 }
104}
105
106pub type Result<T> = std::result::Result<T, ConvertError>;
107
/// Element types recognised by this module.
///
/// See `elem_size` for per-element byte widths and `as_safetensors` for the
/// string label of each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    /// 2-byte `F16`.
    F16,
    /// 2-byte `BF16`.
    BF16,
    /// 4-byte `F32`.
    F32,
    /// 8-byte `F64`.
    F64,
    /// 1-byte `I8`.
    I8,
    /// 2-byte `I16`.
    I16,
    /// 4-byte `I32`.
    I32,
    /// 8-byte `I64`.
    I64,
    /// 1-byte `U8`.
    U8,
    /// 1-byte boolean.
    Bool,
}
121
122impl DType {
123 pub fn elem_size(self) -> usize {
124 match self {
125 DType::F16 | DType::BF16 | DType::I16 => 2,
126 DType::F32 | DType::I32 => 4,
127 DType::F64 | DType::I64 => 8,
128 DType::I8 | DType::U8 | DType::Bool => 1,
129 }
130 }
131
132 pub fn as_safetensors(self) -> &'static str {
133 match self {
134 DType::F16 => "F16",
135 DType::BF16 => "BF16",
136 DType::F32 => "F32",
137 DType::F64 => "F64",
138 DType::I8 => "I8",
139 DType::I16 => "I16",
140 DType::I32 => "I32",
141 DType::I64 => "I64",
142 DType::U8 => "U8",
143 DType::Bool => "BOOL",
144 }
145 }
146}
147
148#[derive(Debug, Clone)]
149pub struct StorageRef {
150 pub key: String,
151 pub dtype: DType,
152 pub size_elems: usize,
153}
154
155#[derive(Debug, Clone)]
156pub struct TensorRef {
157 pub storage: StorageRef,
158 pub offset_elems: usize,
159 pub shape: Vec<usize>,
160 pub stride: Vec<usize>,
161}
162
163#[derive(Debug, Clone)]
164pub struct TensorData {
165 pub dtype: DType,
166 pub shape: Vec<usize>,
167 pub bytes: Vec<u8>,
168}
169
170#[allow(dead_code)]
171#[derive(Debug, Clone)]
172pub enum Value {
173 Marker,
174 None,
175 Bool(bool),
176 Int(i64),
177 Float(f64),
178 String(String),
179 Bytes(Vec<u8>),
180 List(Vec<Value>),
181 Set(Vec<Value>),
182 Tuple(Vec<Value>),
183 Dict(Vec<(Value, Value)>),
184 Global { module: String, name: String },
185 StorageRef(StorageRef),
186 TensorRef(TensorRef),
187 OrderedDict(Vec<(String, Value)>),
188 Object {
189 module: String,
190 name: String,
191 args: Option<Box<Value>>,
192 state: Option<Box<Value>>,
193 },
194}
195
196impl Value {
197 pub(crate) fn as_usize(&self) -> Result<usize> {
198 match self {
199 Value::Int(v) if *v >= 0 => Ok(*v as usize),
200 _ => Err(ConvertError::InvalidStructure(
201 "expected non-negative integer".to_string(),
202 )),
203 }
204 }
205
206 pub(crate) fn as_string(&self) -> Result<String> {
207 match self {
208 Value::String(v) => Ok(v.clone()),
209 Value::Int(v) => Ok(v.to_string()),
210 _ => Err(ConvertError::InvalidStructure(
211 "expected string".to_string(),
212 )),
213 }
214 }
215
216 pub(crate) fn as_usize_vec(&self) -> Result<Vec<usize>> {
217 match self {
218 Value::Tuple(items) | Value::List(items) => items.iter().map(Value::as_usize).collect(),
219 _ => Err(ConvertError::InvalidStructure(
220 "expected tuple/list of integers".to_string(),
221 )),
222 }
223 }
224}
225
226pub struct ParsedCheckpoint {
227 pub source_sha256: String,
228 pub warnings: Vec<String>,
229 pub tensors: BTreeMap<String, TensorData>,
230 pub metadata: serde_yaml::Value,
231 pub objects: Vec<String>,
232}