1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8mod writer;
9
10pub use types::{
11 ConvertError, ConvertOptions, ConvertResult, DType, InspectionReport, Result, StorageRef,
12 TensorRef, TensorSummary, Value,
13};
14
15use std::collections::{BTreeMap, HashMap};
16use std::fs;
17use std::fs::File;
18use std::io::Read;
19use std::path::Path;
20use zip::read::ZipArchive;
21
22use extract::{contiguous_stride, extract_state_dict_tensors, numel};
23use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
24use metadata::{collect_constructor_types, project_root_metadata};
25use parser::parse_pickle;
26use types::{ParsedCheckpoint, TensorData};
27use writer::{write_model_yaml, write_safetensors};
28
29pub fn inspect_pt(input_pt: &Path) -> Result<InspectionReport> {
30 let parsed = parse_checkpoint(input_pt, &ConvertOptions::default())?;
31 let mut tensors = Vec::with_capacity(parsed.tensors.len());
32 let mut total_tensor_bytes = 0usize;
33 for (name, tensor) in &parsed.tensors {
34 let nbytes = tensor.bytes.len();
35 total_tensor_bytes += nbytes;
36 tensors.push(TensorSummary {
37 name: name.clone(),
38 dtype: tensor.dtype.as_safetensors().to_string(),
39 shape: tensor.shape.clone(),
40 nbytes,
41 });
42 }
43
44 Ok(InspectionReport {
45 detected_format: "torch_zip_pickle".to_string(),
46 source_file: input_pt.display().to_string(),
47 source_sha256: parsed.source_sha256,
48 tensor_count: tensors.len(),
49 total_tensor_bytes,
50 tensors,
51 warnings: parsed.warnings,
52 })
53}
54
55pub fn convert_pt_to_safetensors(
56 input_pt: &Path,
57 out_dir: &Path,
58 opts: ConvertOptions,
59) -> Result<ConvertResult> {
60 let parsed = parse_checkpoint(input_pt, &opts)?;
61 fs::create_dir_all(out_dir)?;
62
63 let safetensors_path = out_dir.join("model.safetensors");
64 write_safetensors(&safetensors_path, &parsed.tensors, &parsed.source_sha256)?;
65
66 let mut total_tensor_bytes = 0usize;
67 let mut tensor_summaries = Vec::new();
68 for (name, tensor) in &parsed.tensors {
69 total_tensor_bytes += tensor.bytes.len();
70 tensor_summaries.push((
71 name.clone(),
72 tensor.dtype.as_safetensors().to_string(),
73 tensor.shape.clone(),
74 sha256_hex(&tensor.bytes),
75 ));
76 }
77
78 let model_yaml_path = out_dir.join("model.yaml");
79 write_model_yaml(
80 &model_yaml_path,
81 input_pt,
82 &parsed.source_sha256,
83 parsed.tensors.len(),
84 total_tensor_bytes,
85 &parsed.metadata,
86 &parsed.objects,
87 &tensor_summaries,
88 )?;
89
90 Ok(ConvertResult {
91 safetensors_path,
92 model_yaml_path,
93 source_file: input_pt.to_path_buf(),
94 source_sha256: parsed.source_sha256,
95 tensor_count: parsed.tensors.len(),
96 total_tensor_bytes,
97 })
98}
99
100pub fn parse_checkpoint(path: &Path, opts: &ConvertOptions) -> Result<ParsedCheckpoint> {
101 let file = File::open(path)?;
102 let metadata = file.metadata()?;
103 if metadata.len() > opts.max_archive_bytes {
104 return Err(ConvertError::ResourceLimitExceeded(format!(
105 "archive is {} bytes, limit is {}",
106 metadata.len(),
107 opts.max_archive_bytes
108 )));
109 }
110
111 let mut magic = [0u8; 4];
112 let mut fh = File::open(path)?;
113 fh.read_exact(&mut magic)?;
114 if magic != [0x50, 0x4b, 0x03, 0x04] {
115 return Err(ConvertError::UnsupportedFormat(
116 "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
117 ));
118 }
119
120 let source_sha256 = sha256_file(path)?;
121 let mut archive = ZipArchive::new(file)?;
122 let data_pkl_name = find_data_pkl_name(&mut archive)?;
123 let prefix = data_pkl_name
124 .strip_suffix("data.pkl")
125 .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
126 .to_string();
127 let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
128 if pickle_bytes.len() > opts.max_pickle_bytes {
129 return Err(ConvertError::ResourceLimitExceeded(format!(
130 "data.pkl is {} bytes, limit is {}",
131 pickle_bytes.len(),
132 opts.max_pickle_bytes
133 )));
134 }
135
136 let root = parse_pickle(&pickle_bytes, opts)?;
137 let metadata = project_root_metadata(&root);
138 let objects = collect_constructor_types(&root);
139 let tensor_refs = extract_state_dict_tensors(&root)?;
140 if tensor_refs.is_empty() {
141 return Err(ConvertError::InvalidStructure(
142 "no tensors found in checkpoint state_dict".to_string(),
143 ));
144 }
145 if tensor_refs.len() > opts.max_tensor_count {
146 return Err(ConvertError::ResourceLimitExceeded(format!(
147 "tensor count {} exceeds limit {}",
148 tensor_refs.len(),
149 opts.max_tensor_count
150 )));
151 }
152
153 let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
154 for tensor in tensor_refs.values() {
155 let key = &tensor.storage.key;
156 if storage_blobs.contains_key(key) {
157 continue;
158 }
159 let blob = read_storage_blob(&mut archive, &prefix, key)?;
160 let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
161 if blob.len() < required_bytes {
162 return Err(ConvertError::InvalidStructure(format!(
163 "storage {} has {} bytes, expected at least {}",
164 key,
165 blob.len(),
166 required_bytes
167 )));
168 }
169 storage_blobs.insert(key.clone(), blob);
170 }
171
172 let mut tensors = BTreeMap::new();
173 for (name, tensor_ref) in tensor_refs {
174 if opts.strict_contiguous {
175 let expected = contiguous_stride(&tensor_ref.shape);
176 if expected != tensor_ref.stride {
177 return Err(ConvertError::InvalidStructure(format!(
178 "tensor {} has non-contiguous stride {:?}, expected {:?}",
179 name, tensor_ref.stride, expected
180 )));
181 }
182 }
183
184 let elem_size = tensor_ref.storage.dtype.elem_size();
185 let numel = numel(&tensor_ref.shape)?;
186 let start = tensor_ref
187 .offset_elems
188 .checked_mul(elem_size)
189 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
190 let byte_len = numel
191 .checked_mul(elem_size)
192 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
193 if byte_len > opts.max_tensor_bytes {
194 return Err(ConvertError::ResourceLimitExceeded(format!(
195 "tensor {} is {} bytes, limit is {}",
196 name, byte_len, opts.max_tensor_bytes
197 )));
198 }
199 let end = start
200 .checked_add(byte_len)
201 .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
202
203 let storage = storage_blobs.get(&tensor_ref.storage.key).ok_or_else(|| {
204 ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key))
205 })?;
206 if end > storage.len() {
207 return Err(ConvertError::InvalidStructure(format!(
208 "tensor {} slice [{}, {}) is out of storage bounds {}",
209 name,
210 start,
211 end,
212 storage.len()
213 )));
214 }
215
216 let raw = storage[start..end].to_vec();
217 let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
218 tensors.insert(name, normalized);
219 }
220
221 Ok(ParsedCheckpoint {
222 source_sha256,
223 warnings: Vec::new(),
224 tensors,
225 metadata,
226 objects,
227 })
228}
229
230fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
231 match dtype {
232 DType::F16 => Ok(TensorData {
233 dtype: DType::F32,
234 shape,
235 bytes: f16_bytes_to_f32_bytes(&bytes)?,
236 }),
237 DType::BF16 => Ok(TensorData {
238 dtype: DType::F32,
239 shape,
240 bytes: bf16_bytes_to_f32_bytes(&bytes)?,
241 }),
242 _ => Ok(TensorData { dtype, shape, bytes }),
243 }
244}
245
246fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
247 if input.len() % 2 != 0 {
248 return Err(ConvertError::InvalidStructure(
249 "f16 tensor bytes must be even-length".to_string(),
250 ));
251 }
252 let mut out = Vec::with_capacity(input.len() * 2);
253 for chunk in input.chunks_exact(2) {
254 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
255 let value = f16_bits_to_f32(bits);
256 out.extend_from_slice(&value.to_le_bytes());
257 }
258 Ok(out)
259}
260
261fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
262 if input.len() % 2 != 0 {
263 return Err(ConvertError::InvalidStructure(
264 "bf16 tensor bytes must be even-length".to_string(),
265 ));
266 }
267 let mut out = Vec::with_capacity(input.len() * 2);
268 for chunk in input.chunks_exact(2) {
269 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
270 let value = f32::from_bits((bits as u32) << 16);
271 out.extend_from_slice(&value.to_le_bytes());
272 }
273 Ok(out)
274}
275
/// Decodes an IEEE 754 binary16 bit pattern into the equivalent `f32`.
///
/// Handles signed zero, subnormals, normal numbers, infinities and NaNs.
/// NaN payload bits are shifted into the top of the `f32` mantissa.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign_bit = ((bits >> 15) as u32 & 0x1) << 31;
    let exponent = ((bits >> 10) & 0x1f) as u32;
    let mantissa = (bits & 0x03ff) as u32;

    let magnitude = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: renormalize so the leading 1 lands on the implicit bit
        // (bit 10), lowering the exponent by one per shifted position.
        (0, _) => {
            // A u32 whose highest set bit is bit 10 has 21 leading zeros.
            let shift = mantissa.leading_zeros() - 21;
            let fraction = (mantissa << shift) & 0x03ff;
            let exp32 = (-14i32 - shift as i32 + 127) as u32;
            (exp32 << 23) | (fraction << 13)
        }
        // Infinity (mantissa == 0) or NaN (mantissa != 0).
        (0x1f, _) => (0xff << 23) | (mantissa << 13),
        // Normal number: rebias the exponent from 15 to 127.
        _ => {
            let exp32 = (exponent as i32 - 15 + 127) as u32;
            (exp32 << 23) | (mantissa << 13)
        }
    };

    f32::from_bits(sign_bit | magnitude)
}
304
305
#[cfg(test)]
mod tests {
    //! End-to-end conversion tests on hand-assembled checkpoints, plus unit
    //! tests for metadata projection and half-precision widening.

    use super::*;
    use crate::metadata::{collect_constructor_types, project_value_for_metadata};
    use crate::types::Value;
    use std::io::Write;
    use tempfile::tempdir;
    use zip::write::SimpleFileOptions;
    use zip::ZipWriter;

    /// A minimal, well-formed checkpoint converts to both output files.
    #[test]
    fn converts_simple_tensor_checkpoint() {
        let tmp = tempdir().expect("tmp dir");
        let pt_path = tmp.path().join("weights.pt");
        write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");

        let out_dir = tmp.path().join("export");
        let result = convert_pt_to_safetensors(&pt_path, &out_dir, ConvertOptions::default())
            .expect("conversion should work");

        assert!(result.safetensors_path.exists());
        assert!(result.model_yaml_path.exists());
        assert_eq!(result.tensor_count, 1);

        let yaml = fs::read_to_string(result.model_yaml_path).expect("yaml readable");
        assert!(yaml.contains("layer.weight"));
        assert!(yaml.contains("dtype: 'F32'"));
    }

    /// A pickle that tries to call `os.system` must be refused.
    #[test]
    fn rejects_unsafe_global_reduce() {
        let tmp = tempdir().expect("tmp dir");
        let pt_path = tmp.path().join("unsafe.pt");
        write_fixture_checkpoint(&pt_path, true).expect("fixture checkpoint");

        let err = convert_pt_to_safetensors(&pt_path, tmp.path(), ConvertOptions::default())
            .expect_err("unsafe pickle should fail");
        let msg = err.to_string();
        assert!(msg.contains("unsupported GLOBAL") || msg.contains("unsafe/unsupported pickle"));
    }

    /// A file that does not start with the zip magic is reported as an
    /// unsupported format rather than as a zip or pickle parse failure.
    #[test]
    fn rejects_non_zip_input() {
        let tmp = tempdir().expect("tmp dir");
        let pt_path = tmp.path().join("legacy.pt");
        fs::write(&pt_path, b"not a zip archive").expect("fixture file");

        let err = convert_pt_to_safetensors(&pt_path, tmp.path(), ConvertOptions::default())
            .expect_err("non-zip input should fail");
        assert!(matches!(err, ConvertError::UnsupportedFormat(_)));
    }

    /// Covers every branch of the half-precision decoder: zero, subnormal,
    /// normal, infinity and NaN.
    #[test]
    fn decodes_f16_bit_patterns() {
        assert_eq!(f16_bits_to_f32(0x3c00), 1.0);
        assert_eq!(f16_bits_to_f32(0xc000), -2.0);
        assert_eq!(f16_bits_to_f32(0x7bff), 65504.0);
        assert_eq!(f16_bits_to_f32(0x0000).to_bits(), 0x0000_0000);
        assert_eq!(f16_bits_to_f32(0x8000).to_bits(), 0x8000_0000);
        // Smallest subnormal (2^-24) and largest subnormal.
        assert_eq!(f16_bits_to_f32(0x0001).to_bits(), 0x3380_0000);
        assert_eq!(f16_bits_to_f32(0x03ff).to_bits(), 0x387f_c000);
        assert_eq!(f16_bits_to_f32(0x7c00), f32::INFINITY);
        assert_eq!(f16_bits_to_f32(0xfc00), f32::NEG_INFINITY);
        assert!(f16_bits_to_f32(0x7e00).is_nan());
    }

    /// f16 and bf16 encodings of the same values widen to the same f32 bytes.
    #[test]
    fn widens_half_precision_bytes_to_f32() {
        let mut expected = Vec::new();
        expected.extend_from_slice(&1.0f32.to_le_bytes());
        expected.extend_from_slice(&(-2.0f32).to_le_bytes());

        let from_f16 =
            f16_bytes_to_f32_bytes(&[0x00, 0x3c, 0x00, 0xc0]).expect("even-length f16 input");
        assert_eq!(from_f16, expected);

        let from_bf16 =
            bf16_bytes_to_f32_bytes(&[0x80, 0x3f, 0x00, 0xc0]).expect("even-length bf16 input");
        assert_eq!(from_bf16, expected);
    }

    /// A trailing odd byte cannot form a 16-bit value and must be rejected.
    #[test]
    fn rejects_odd_length_half_precision_bytes() {
        assert!(f16_bytes_to_f32_bytes(&[0x00]).is_err());
        assert!(bf16_bytes_to_f32_bytes(&[0x00, 0x3c, 0x00]).is_err());
    }

    #[test]
    fn projects_object_metadata_with_type_args_and_flattened_state() {
        let value = Value::Object {
            module: "ultralytics.nn.tasks".to_string(),
            name: "DetectionModel".to_string(),
            args: Some(Box::new(Value::Tuple(vec![
                Value::String("arg0".to_string()),
                Value::Int(42),
            ]))),
            state: Some(Box::new(Value::Dict(vec![(
                Value::String("training".to_string()),
                Value::Bool(false),
            )]))),
        };

        let projected = project_value_for_metadata(&value);
        let mapping = match projected {
            serde_yaml::Value::Mapping(map) => map,
            other => panic!("expected mapping, got {:?}", other),
        };

        let type_key = serde_yaml::Value::String("$type".to_string());
        let class_key = serde_yaml::Value::String("$class".to_string());
        let args_key = serde_yaml::Value::String("$args".to_string());
        let training_key = serde_yaml::Value::String("training".to_string());

        assert_eq!(
            mapping.get(&type_key),
            Some(&serde_yaml::Value::String("object".to_string()))
        );
        assert_eq!(
            mapping.get(&class_key),
            Some(&serde_yaml::Value::String(
                "ultralytics.nn.tasks.DetectionModel".to_string()
            ))
        );
        assert!(mapping.get(&args_key).is_some());
        // State entries are flattened into the same mapping as the `$`-keys.
        assert_eq!(mapping.get(&training_key), Some(&serde_yaml::Value::Bool(false)));
    }

    #[test]
    fn omits_empty_object_args() {
        let value = Value::Object {
            module: "a".to_string(),
            name: "B".to_string(),
            args: Some(Box::new(Value::Tuple(Vec::new()))),
            state: None,
        };
        let projected = project_value_for_metadata(&value);
        let mapping = match projected {
            serde_yaml::Value::Mapping(map) => map,
            other => panic!("expected mapping, got {:?}", other),
        };

        let args_key = serde_yaml::Value::String("$args".to_string());
        assert!(!mapping.contains_key(&args_key));
    }

    #[test]
    fn collects_constructor_types_deduplicated_in_first_seen_order() {
        let tree = Value::List(vec![
            Value::Object {
                module: "a".to_string(),
                name: "One".to_string(),
                args: None,
                state: None,
            },
            Value::Dict(vec![(
                Value::String("nested".to_string()),
                Value::Object {
                    module: "b".to_string(),
                    name: "Two".to_string(),
                    args: None,
                    state: None,
                },
            )]),
            Value::Object {
                module: "a".to_string(),
                name: "One".to_string(),
                args: None,
                state: None,
            },
        ]);

        let objects = collect_constructor_types(&tree);
        assert_eq!(
            objects,
            vec!["a.One".to_string(), "b.Two".to_string()]
        );
    }

    /// Writes a zip checkpoint with `archive/data.pkl` and one storage blob
    /// `archive/data/0` holding four little-endian f32 values.
    ///
    /// `unsafe_payload` selects the malicious pickle instead of the benign one.
    fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
        let file = File::create(path)?;
        let mut zip = ZipWriter::new(file);
        let options = SimpleFileOptions::default();

        let data_pkl = if unsafe_payload {
            build_unsafe_pickle()
        } else {
            build_safe_pickle()
        };

        zip.start_file("archive/data.pkl", options)?;
        zip.write_all(&data_pkl)?;

        let floats = [1.0f32, 2.0, 3.0, 4.0];
        let mut raw = Vec::new();
        for value in floats {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        zip.start_file("archive/data/0", options)?;
        zip.write_all(&raw)?;
        zip.finish()?;
        Ok(())
    }

    /// Hand-assembles a protocol-2 pickle for
    /// `{"layer.weight": _rebuild_tensor_v2(<storage "0">, 0, (2, 2), (2, 1), False, None)}`.
    fn build_safe_pickle() -> Vec<u8> {
        let mut out = Vec::new();
        out.extend_from_slice(&[0x80, 0x02]); // PROTO 2

        out.push(b'}'); // EMPTY_DICT
        out.push(b'('); // MARK (start of SETITEMS pairs)

        push_binunicode(&mut out, "layer.weight");
        out.extend_from_slice(b"ctorch._utils\n_rebuild_tensor_v2\n"); // GLOBAL

        out.push(b'('); // MARK (argument tuple for REDUCE)

        // Persistent id tuple: ("storage", FloatStorage, "0", "cpu", 4).
        out.push(b'(');
        push_binunicode(&mut out, "storage");
        out.extend_from_slice(b"ctorch\nFloatStorage\n");
        push_binunicode(&mut out, "0");
        push_binunicode(&mut out, "cpu");
        out.push(b'K'); // BININT1
        out.push(4);
        out.push(b't'); // TUPLE
        out.push(b'Q'); // BINPERSID

        // Storage offset 0.
        out.push(b'K');
        out.push(0);

        // Shape (2, 2).
        out.push(b'(');
        out.push(b'K');
        out.push(2);
        out.push(b'K');
        out.push(2);
        out.push(b't');

        // Stride (2, 1).
        out.push(b'(');
        out.push(b'K');
        out.push(2);
        out.push(b'K');
        out.push(1);
        out.push(b't');

        out.push(0x89); // NEWFALSE
        out.push(b'N'); // NONE

        out.push(b't'); // TUPLE (closes the argument tuple)
        out.push(b'R'); // REDUCE

        out.push(b'u'); // SETITEMS
        out.push(b'.'); // STOP
        out
    }

    /// Hand-assembles a protocol-2 pickle equivalent to `os.system("echo hacked")`.
    fn build_unsafe_pickle() -> Vec<u8> {
        let mut out = Vec::new();
        out.extend_from_slice(&[0x80, 0x02]); // PROTO 2
        out.extend_from_slice(b"cos\nsystem\n"); // GLOBAL os.system
        out.push(b'('); // MARK
        push_binunicode(&mut out, "echo hacked");
        out.push(b't'); // TUPLE
        out.push(b'R'); // REDUCE
        out.push(b'.'); // STOP
        out
    }

    /// Appends a BINUNICODE opcode: `X`, a 4-byte little-endian length, then
    /// the UTF-8 bytes of `value`.
    fn push_binunicode(out: &mut Vec<u8>, value: &str) {
        out.push(b'X');
        out.extend_from_slice(&(value.len() as u32).to_le_bytes());
        out.extend_from_slice(value.as_bytes());
    }
}