Skip to main content

pt_loader/
lib.rs

1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8mod writer;
9
10pub use types::{
11  ConvertError, ConvertOptions, ConvertResult, DType, InspectionReport, Result, StorageRef,
12  TensorRef, TensorSummary, Value,
13};
14
15use std::collections::{BTreeMap, HashMap};
16use std::fs;
17use std::fs::File;
18use std::io::Read;
19use std::path::Path;
20use zip::read::ZipArchive;
21
22use extract::{contiguous_stride, extract_state_dict_tensors, numel};
23use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
24use metadata::{collect_constructor_types, project_root_metadata};
25use parser::parse_pickle;
26use types::{ParsedCheckpoint, TensorData};
27use writer::{write_model_yaml, write_safetensors};
28
29pub fn inspect_pt(input_pt: &Path) -> Result<InspectionReport> {
30  let parsed = parse_checkpoint(input_pt, &ConvertOptions::default())?;
31  let mut tensors = Vec::with_capacity(parsed.tensors.len());
32  let mut total_tensor_bytes = 0usize;
33  for (name, tensor) in &parsed.tensors {
34    let nbytes = tensor.bytes.len();
35    total_tensor_bytes += nbytes;
36    tensors.push(TensorSummary {
37      name: name.clone(),
38      dtype: tensor.dtype.as_safetensors().to_string(),
39      shape: tensor.shape.clone(),
40      nbytes,
41    });
42  }
43
44  Ok(InspectionReport {
45    detected_format: "torch_zip_pickle".to_string(),
46    source_file: input_pt.display().to_string(),
47    source_sha256: parsed.source_sha256,
48    tensor_count: tensors.len(),
49    total_tensor_bytes,
50    tensors,
51    warnings: parsed.warnings,
52  })
53}
54
55pub fn convert_pt_to_safetensors(
56  input_pt: &Path,
57  out_dir: &Path,
58  opts: ConvertOptions,
59) -> Result<ConvertResult> {
60  let parsed = parse_checkpoint(input_pt, &opts)?;
61  fs::create_dir_all(out_dir)?;
62
63  let safetensors_path = out_dir.join("model.safetensors");
64  write_safetensors(&safetensors_path, &parsed.tensors, &parsed.source_sha256)?;
65
66  let mut total_tensor_bytes = 0usize;
67  let mut tensor_summaries = Vec::new();
68  for (name, tensor) in &parsed.tensors {
69    total_tensor_bytes += tensor.bytes.len();
70    tensor_summaries.push((
71      name.clone(),
72      tensor.dtype.as_safetensors().to_string(),
73      tensor.shape.clone(),
74      sha256_hex(&tensor.bytes),
75    ));
76  }
77
78  let model_yaml_path = out_dir.join("model.yaml");
79  write_model_yaml(
80    &model_yaml_path,
81    input_pt,
82    &parsed.source_sha256,
83    parsed.tensors.len(),
84    total_tensor_bytes,
85    &parsed.metadata,
86    &parsed.objects,
87    &tensor_summaries,
88  )?;
89
90  Ok(ConvertResult {
91    safetensors_path,
92    model_yaml_path,
93    source_file: input_pt.to_path_buf(),
94    source_sha256: parsed.source_sha256,
95    tensor_count: parsed.tensors.len(),
96    total_tensor_bytes,
97  })
98}
99
100pub fn parse_checkpoint(path: &Path, opts: &ConvertOptions) -> Result<ParsedCheckpoint> {
101  let file = File::open(path)?;
102  let metadata = file.metadata()?;
103  if metadata.len() > opts.max_archive_bytes {
104    return Err(ConvertError::ResourceLimitExceeded(format!(
105      "archive is {} bytes, limit is {}",
106      metadata.len(),
107      opts.max_archive_bytes
108    )));
109  }
110
111  let mut magic = [0u8; 4];
112  let mut fh = File::open(path)?;
113  fh.read_exact(&mut magic)?;
114  if magic != [0x50, 0x4b, 0x03, 0x04] {
115    return Err(ConvertError::UnsupportedFormat(
116      "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
117    ));
118  }
119
120  let source_sha256 = sha256_file(path)?;
121  let mut archive = ZipArchive::new(file)?;
122  let data_pkl_name = find_data_pkl_name(&mut archive)?;
123  let prefix = data_pkl_name
124    .strip_suffix("data.pkl")
125    .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
126    .to_string();
127  let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
128  if pickle_bytes.len() > opts.max_pickle_bytes {
129    return Err(ConvertError::ResourceLimitExceeded(format!(
130      "data.pkl is {} bytes, limit is {}",
131      pickle_bytes.len(),
132      opts.max_pickle_bytes
133    )));
134  }
135
136  let root = parse_pickle(&pickle_bytes, opts)?;
137  let metadata = project_root_metadata(&root);
138  let objects = collect_constructor_types(&root);
139  let tensor_refs = extract_state_dict_tensors(&root)?;
140  if tensor_refs.is_empty() {
141    return Err(ConvertError::InvalidStructure(
142      "no tensors found in checkpoint state_dict".to_string(),
143    ));
144  }
145  if tensor_refs.len() > opts.max_tensor_count {
146    return Err(ConvertError::ResourceLimitExceeded(format!(
147      "tensor count {} exceeds limit {}",
148      tensor_refs.len(),
149      opts.max_tensor_count
150    )));
151  }
152
153  let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
154  for tensor in tensor_refs.values() {
155    let key = &tensor.storage.key;
156    if storage_blobs.contains_key(key) {
157      continue;
158    }
159    let blob = read_storage_blob(&mut archive, &prefix, key)?;
160    let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
161    if blob.len() < required_bytes {
162      return Err(ConvertError::InvalidStructure(format!(
163        "storage {} has {} bytes, expected at least {}",
164        key,
165        blob.len(),
166        required_bytes
167      )));
168    }
169    storage_blobs.insert(key.clone(), blob);
170  }
171
172  let mut tensors = BTreeMap::new();
173  for (name, tensor_ref) in tensor_refs {
174    if opts.strict_contiguous {
175      let expected = contiguous_stride(&tensor_ref.shape);
176      if expected != tensor_ref.stride {
177        return Err(ConvertError::InvalidStructure(format!(
178          "tensor {} has non-contiguous stride {:?}, expected {:?}",
179          name, tensor_ref.stride, expected
180        )));
181      }
182    }
183
184    let elem_size = tensor_ref.storage.dtype.elem_size();
185    let numel = numel(&tensor_ref.shape)?;
186    let start = tensor_ref
187      .offset_elems
188      .checked_mul(elem_size)
189      .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
190    let byte_len = numel
191      .checked_mul(elem_size)
192      .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
193    if byte_len > opts.max_tensor_bytes {
194      return Err(ConvertError::ResourceLimitExceeded(format!(
195        "tensor {} is {} bytes, limit is {}",
196        name, byte_len, opts.max_tensor_bytes
197      )));
198    }
199    let end = start
200      .checked_add(byte_len)
201      .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
202
203    let storage = storage_blobs.get(&tensor_ref.storage.key).ok_or_else(|| {
204      ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key))
205    })?;
206    if end > storage.len() {
207      return Err(ConvertError::InvalidStructure(format!(
208        "tensor {} slice [{}, {}) is out of storage bounds {}",
209        name,
210        start,
211        end,
212        storage.len()
213      )));
214    }
215
216    let raw = storage[start..end].to_vec();
217    let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
218    tensors.insert(name, normalized);
219  }
220
221  Ok(ParsedCheckpoint {
222    source_sha256,
223    warnings: Vec::new(),
224    tensors,
225    metadata,
226    objects,
227  })
228}
229
230fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
231  match dtype {
232    DType::F16 => Ok(TensorData {
233      dtype: DType::F32,
234      shape,
235      bytes: f16_bytes_to_f32_bytes(&bytes)?,
236    }),
237    DType::BF16 => Ok(TensorData {
238      dtype: DType::F32,
239      shape,
240      bytes: bf16_bytes_to_f32_bytes(&bytes)?,
241    }),
242    _ => Ok(TensorData { dtype, shape, bytes }),
243  }
244}
245
246fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
247  if input.len() % 2 != 0 {
248    return Err(ConvertError::InvalidStructure(
249      "f16 tensor bytes must be even-length".to_string(),
250    ));
251  }
252  let mut out = Vec::with_capacity(input.len() * 2);
253  for chunk in input.chunks_exact(2) {
254    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
255    let value = f16_bits_to_f32(bits);
256    out.extend_from_slice(&value.to_le_bytes());
257  }
258  Ok(out)
259}
260
261fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
262  if input.len() % 2 != 0 {
263    return Err(ConvertError::InvalidStructure(
264      "bf16 tensor bytes must be even-length".to_string(),
265    ));
266  }
267  let mut out = Vec::with_capacity(input.len() * 2);
268  for chunk in input.chunks_exact(2) {
269    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
270    let value = f32::from_bits((bits as u32) << 16);
271    out.extend_from_slice(&value.to_le_bytes());
272  }
273  Ok(out)
274}
275
/// Widens one IEEE 754 binary16 value, given as raw bits, to an `f32`.
///
/// Every binary16 value is exactly representable in binary32, so the result is
/// exact. Signed zeros, infinities and NaN payload bits are carried over, and
/// subnormal halves are renormalised into normal `f32`s.
fn f16_bits_to_f32(bits: u16) -> f32 {
  let half = u32::from(bits);
  let sign = (half & 0x8000) << 16;
  let exponent = (half >> 10) & 0x1f;
  let mantissa = half & 0x03ff;

  let magnitude = match (exponent, mantissa) {
    // +/-0 keeps only the sign bit.
    (0, 0) => 0,
    // Subnormal (value = mantissa * 2^-24). Shift the highest set mantissa bit up
    // to bit 10, the implicit leading one, and lower the exponent to match.
    (0, m) => {
      let shift = m.leading_zeros() - 21;
      // 127 (f32 bias) - 14 (f16 subnormal exponent) - shift.
      let exp32 = 113 - shift;
      (exp32 << 23) | (((m << shift) & 0x03ff) << 13)
    }
    // Infinity / NaN: all-ones exponent; the payload moves to the top of the f32 mantissa.
    (0x1f, m) => (0xff << 23) | (m << 13),
    // Normal: rebias the exponent from 15 to 127.
    (e, m) => ((e + 112) << 23) | (m << 13),
  };

  f32::from_bits(sign | magnitude)
}
304
305
#[cfg(test)]
mod tests {
  use super::*;
  use crate::metadata::{collect_constructor_types, project_value_for_metadata};
  use crate::types::Value;
  use std::io::Write;
  use tempfile::tempdir;
  use zip::write::SimpleFileOptions;
  use zip::ZipWriter;

  /// End to end: a single 2x2 F32 tensor checkpoint converts into both output files.
  #[test]
  fn converts_simple_tensor_checkpoint() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("weights.pt");
    write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");

    let out_dir = tmp.path().join("export");
    let result = convert_pt_to_safetensors(&pt_path, &out_dir, ConvertOptions::default())
      .expect("conversion should work");

    assert!(result.safetensors_path.exists());
    assert!(result.model_yaml_path.exists());
    assert_eq!(result.tensor_count, 1);

    let yaml = fs::read_to_string(result.model_yaml_path).expect("yaml readable");
    assert!(yaml.contains("layer.weight"));
    assert!(yaml.contains("dtype: 'F32'"));
  }

  /// Inspection reports the same tensor without writing anything.
  #[test]
  fn inspects_simple_tensor_checkpoint() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("weights.pt");
    write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");

    let report = inspect_pt(&pt_path).expect("inspection should work");
    assert_eq!(report.detected_format, "torch_zip_pickle");
    assert_eq!(report.tensor_count, 1);
    // Four little-endian f32 values.
    assert_eq!(report.total_tensor_bytes, 16);

    let summary = &report.tensors[0];
    assert_eq!(summary.name, "layer.weight");
    assert_eq!(summary.dtype, "F32");
    assert_eq!(summary.shape, vec![2, 2]);
    assert_eq!(summary.nbytes, 16);
  }

  #[test]
  fn rejects_unsafe_global_reduce() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("unsafe.pt");
    write_fixture_checkpoint(&pt_path, true).expect("fixture checkpoint");

    let err = convert_pt_to_safetensors(&pt_path, tmp.path(), ConvertOptions::default())
      .expect_err("unsafe pickle should fail");
    let msg = err.to_string();
    assert!(msg.contains("unsupported GLOBAL") || msg.contains("unsafe/unsupported pickle"));
  }

  /// A bare pickle stream (no zip container) must be refused before any parsing.
  #[test]
  fn rejects_legacy_raw_pickle() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("legacy.pt");
    fs::write(&pt_path, build_safe_pickle()).expect("legacy checkpoint written");

    let err = convert_pt_to_safetensors(
      &pt_path,
      &tmp.path().join("export"),
      ConvertOptions::default(),
    )
    .expect_err("raw pickle should be rejected");
    assert!(
      matches!(err, ConvertError::UnsupportedFormat(_)),
      "unexpected error: {}",
      err
    );
  }

  #[test]
  fn projects_object_metadata_with_type_args_and_flattened_state() {
    let value = Value::Object {
      module: "ultralytics.nn.tasks".to_string(),
      name: "DetectionModel".to_string(),
      args: Some(Box::new(Value::Tuple(vec![
        Value::String("arg0".to_string()),
        Value::Int(42),
      ]))),
      state: Some(Box::new(Value::Dict(vec![(
        Value::String("training".to_string()),
        Value::Bool(false),
      )]))),
    };

    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let type_key = serde_yaml::Value::String("$type".to_string());
    let class_key = serde_yaml::Value::String("$class".to_string());
    let args_key = serde_yaml::Value::String("$args".to_string());
    let training_key = serde_yaml::Value::String("training".to_string());

    assert_eq!(
      mapping.get(&type_key),
      Some(&serde_yaml::Value::String("object".to_string()))
    );
    assert_eq!(
      mapping.get(&class_key),
      Some(&serde_yaml::Value::String(
        "ultralytics.nn.tasks.DetectionModel".to_string()
      ))
    );
    assert!(mapping.get(&args_key).is_some());
    assert_eq!(mapping.get(&training_key), Some(&serde_yaml::Value::Bool(false)));
  }

  #[test]
  fn omits_empty_object_args() {
    let value = Value::Object {
      module: "a".to_string(),
      name: "B".to_string(),
      args: Some(Box::new(Value::Tuple(Vec::new()))),
      state: None,
    };
    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let args_key = serde_yaml::Value::String("$args".to_string());
    assert!(!mapping.contains_key(&args_key));
  }

  #[test]
  fn collects_constructor_types_deduplicated_in_first_seen_order() {
    let tree = Value::List(vec![
      Value::Object {
        module: "a".to_string(),
        name: "One".to_string(),
        args: None,
        state: None,
      },
      Value::Dict(vec![(
        Value::String("nested".to_string()),
        Value::Object {
          module: "b".to_string(),
          name: "Two".to_string(),
          args: None,
          state: None,
        },
      )]),
      Value::Object {
        module: "a".to_string(),
        name: "One".to_string(),
        args: None,
        state: None,
      },
    ]);

    let objects = collect_constructor_types(&tree);
    assert_eq!(
      objects,
      vec!["a.One".to_string(), "b.Two".to_string()]
    );
  }

  /// Exact widening of representative binary16 bit patterns, including edge cases.
  #[test]
  fn widens_f16_bits_exactly() {
    assert_eq!(f16_bits_to_f32(0x3c00), 1.0);
    assert_eq!(f16_bits_to_f32(0xc000), -2.0);
    // Largest finite half.
    assert_eq!(f16_bits_to_f32(0x7bff), 65504.0);
    // Smallest subnormal half, 2^-24.
    assert_eq!(f16_bits_to_f32(0x0001), 1.0 / 16_777_216.0);
    assert_eq!(f16_bits_to_f32(0x8000).to_bits(), 0x8000_0000);
    assert_eq!(f16_bits_to_f32(0x7c00), f32::INFINITY);
    assert!(f16_bits_to_f32(0x7e00).is_nan());
  }

  /// Byte-buffer widening for both half formats, plus odd-length rejection.
  #[test]
  fn widens_half_precision_byte_buffers() {
    let f16 = f16_bytes_to_f32_bytes(&[0x00, 0x3c, 0x00, 0xc0]).expect("even-length f16");
    assert_eq!(f16, [1.0f32.to_le_bytes(), (-2.0f32).to_le_bytes()].concat());

    let bf16 = bf16_bytes_to_f32_bytes(&[0x80, 0x3f]).expect("even-length bf16");
    assert_eq!(bf16, 1.0f32.to_le_bytes().to_vec());

    assert!(f16_bytes_to_f32_bytes(&[0x00]).is_err());
    assert!(bf16_bytes_to_f32_bytes(&[0x00, 0x3c, 0x00]).is_err());
  }

  /// Writes a minimal torch-style zip checkpoint to `path`: `archive/data.pkl`
  /// plus one storage `archive/data/0` holding the f32 values 1.0..=4.0.
  /// With `unsafe_payload`, data.pkl is replaced by an `os.system` REDUCE pickle.
  fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
    let file = File::create(path)?;
    let mut zip = ZipWriter::new(file);
    let options = SimpleFileOptions::default();

    let data_pkl = if unsafe_payload {
      build_unsafe_pickle()
    } else {
      build_safe_pickle()
    };

    zip.start_file("archive/data.pkl", options)?;
    zip.write_all(&data_pkl)?;

    let floats = [1.0f32, 2.0, 3.0, 4.0];
    let mut raw = Vec::new();
    for value in floats {
      raw.extend_from_slice(&value.to_le_bytes());
    }

    zip.start_file("archive/data/0", options)?;
    zip.write_all(&raw)?;
    zip.finish()?;
    Ok(())
  }

  /// Protocol-2 pickle for `{"layer.weight": _rebuild_tensor_v2(...)}` describing
  /// a contiguous 2x2 tensor backed by FloatStorage key "0" with 4 elements.
  fn build_safe_pickle() -> Vec<u8> {
    let mut out = Vec::new();
    // PROTO 2
    out.extend_from_slice(&[0x80, 0x02]);

    // EMPTY_DICT, MARK (start of SETITEMS pairs)
    out.push(b'}');
    out.push(b'(');

    // key, then GLOBAL torch._utils._rebuild_tensor_v2
    push_binunicode(&mut out, "layer.weight");
    out.extend_from_slice(b"ctorch._utils\n_rebuild_tensor_v2\n");

    // MARK: start of the rebuild args tuple
    out.push(b'(');

    // Persistent id ("storage", FloatStorage, key, location, numel) -> BINPERSID
    out.push(b'(');
    push_binunicode(&mut out, "storage");
    out.extend_from_slice(b"ctorch\nFloatStorage\n");
    push_binunicode(&mut out, "0");
    push_binunicode(&mut out, "cpu");
    out.push(b'K');
    out.push(4);
    out.push(b't');
    out.push(b'Q');

    // storage_offset = 0 (BININT1)
    out.push(b'K');
    out.push(0);

    // size = (2, 2)
    out.push(b'(');
    out.push(b'K');
    out.push(2);
    out.push(b'K');
    out.push(2);
    out.push(b't');

    // stride = (2, 1)
    out.push(b'(');
    out.push(b'K');
    out.push(2);
    out.push(b'K');
    out.push(1);
    out.push(b't');

    // requires_grad = False (NEWFALSE), backward_hooks = None
    out.push(0x89);
    out.push(b'N');

    // TUPLE the args, REDUCE into the tensor
    out.push(b't');
    out.push(b'R');

    // SETITEMS, STOP
    out.push(b'u');
    out.push(b'.');
    out
  }

  /// Protocol-2 pickle equivalent to `os.system("echo hacked")`; the parser must refuse it.
  fn build_unsafe_pickle() -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&[0x80, 0x02]);
    out.extend_from_slice(b"cos\nsystem\n");
    out.push(b'(');
    push_binunicode(&mut out, "echo hacked");
    out.push(b't');
    out.push(b'R');
    out.push(b'.');
    out
  }

  /// Appends a BINUNICODE opcode: `X`, u32 little-endian byte length, UTF-8 bytes.
  fn push_binunicode(out: &mut Vec<u8>, value: &str) {
    out.push(b'X');
    out.extend_from_slice(&(value.len() as u32).to_le_bytes());
    out.extend_from_slice(value.as_bytes());
  }
}