Skip to main content

pt_loader/
lib.rs

1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8pub mod writer;
9
10pub use types::{
11  CheckpointMetadata, CheckpointTensorMetadata, ConvertError, DType, ExportFormat, ExportOptions, ExportResult,
12  LoadOptions, ReconstructSource, Result, StorageRef, TensorArray, TensorData, TensorRef, Value,
13};
14
15use ndarray::{ArrayD, IxDyn};
16use serde::Deserialize;
17use std::collections::{BTreeMap, HashMap};
18use std::fs;
19use std::fs::File;
20use std::io::Read;
21use std::path::Path;
22use std::time::{SystemTime, UNIX_EPOCH};
23use zip::read::ZipArchive;
24
25use extract::{contiguous_stride, extract_state_dict_tensors, numel};
26use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
27use metadata::{collect_constructor_types, project_root_metadata};
28use parser::parse_pickle;
29use types::ParsedCheckpoint;
30use writer::{write_metadata_yaml, write_safetensors};
31
32#[derive(Debug, Clone)]
33pub struct PtCheckpoint {
34  source_sha256: String,
35  warnings: Vec<String>,
36  metadata: CheckpointMetadata,
37  tensors: BTreeMap<String, TensorData>,
38}
39
40impl PtCheckpoint {
41  pub fn load(path: impl AsRef<Path>, opts: LoadOptions) -> Result<Self> {
42    let path = path.as_ref();
43    let parsed = parse_checkpoint(path, &opts)?;
44    let metadata = build_checkpoint_metadata(
45      path.display().to_string(),
46      parsed.source_sha256.clone(),
47      &parsed.metadata,
48      &parsed.objects,
49      &parsed.tensors,
50      "model.safetensors".to_string(),
51    );
52
53    Ok(Self {
54      source_sha256: parsed.source_sha256,
55      warnings: parsed.warnings,
56      metadata,
57      tensors: parsed.tensors,
58    })
59  }
60
61  pub fn from_metadata(metadata: CheckpointMetadata, source: ReconstructSource) -> Result<Self> {
62    let tensors = match source {
63      ReconstructSource::WeightsFile(path) => read_safetensors_tensors(&path)?,
64      ReconstructSource::StateDict(values) => values,
65    };
66
67    validate_metadata_against_tensors(&metadata, &tensors)?;
68
69    Ok(Self {
70      source_sha256: metadata.source_sha256.clone(),
71      warnings: Vec::new(),
72      metadata,
73      tensors,
74    })
75  }
76
77  pub fn metadata(&self) -> &CheckpointMetadata {
78    &self.metadata
79  }
80
81  pub fn source_sha256(&self) -> &str {
82    &self.source_sha256
83  }
84
85  pub fn warnings(&self) -> &[String] {
86    &self.warnings
87  }
88
89  pub fn tensor_count(&self) -> usize {
90    self.tensors.len()
91  }
92
93  #[cfg(feature = "pyo3")]
94  pub(crate) fn raw_tensors(&self) -> &BTreeMap<String, TensorData> {
95    &self.tensors
96  }
97
98  pub fn state_dict(&self) -> Result<BTreeMap<String, TensorArray>> {
99    let mut out = BTreeMap::new();
100    for (name, tensor) in &self.tensors {
101      out.insert(name.clone(), tensor_data_to_array(tensor)?);
102    }
103    Ok(out)
104  }
105
106  pub fn export(&self, out_dir: impl AsRef<Path>, opts: ExportOptions) -> Result<ExportResult> {
107    match opts.format {
108      ExportFormat::Safetensors => {}
109    }
110
111    let out_dir = out_dir.as_ref();
112    fs::create_dir_all(out_dir)?;
113
114    let weights_path = out_dir.join(&opts.weights_filename);
115    if weights_path.exists() && !opts.overwrite {
116      return Err(ConvertError::InvalidStructure(format!(
117        "output already exists: {}",
118        weights_path.display()
119      )));
120    }
121
122    write_safetensors(&weights_path, &self.tensors, &self.source_sha256)?;
123
124    let metadata_path = if opts.include_metadata {
125      let metadata_path = out_dir.join(&opts.metadata_filename);
126      if metadata_path.exists() && !opts.overwrite {
127        return Err(ConvertError::InvalidStructure(format!(
128          "output already exists: {}",
129          metadata_path.display()
130        )));
131      }
132
133      let mut metadata = self.metadata.clone();
134      metadata.safetensors_file = opts.weights_filename.to_string_lossy().into_owned();
135      metadata.created_at_unix = now_unix_secs();
136      metadata.tensor_count = self.tensors.len();
137      metadata.total_tensor_bytes = total_tensor_bytes(&self.tensors);
138      metadata.tensors = tensor_summaries_for_metadata(&self.tensors);
139      write_metadata_yaml(&metadata_path, &metadata)?;
140      Some(metadata_path)
141    } else {
142      None
143    };
144
145    Ok(ExportResult {
146      weights_path,
147      metadata_path,
148      source_sha256: self.source_sha256.clone(),
149      tensor_count: self.tensors.len(),
150      total_tensor_bytes: total_tensor_bytes(&self.tensors),
151    })
152  }
153}
154
155pub(crate) fn parse_checkpoint(path: &Path, opts: &LoadOptions) -> Result<ParsedCheckpoint> {
156  let file = File::open(path)?;
157  let metadata = file.metadata()?;
158  if metadata.len() > opts.max_archive_bytes {
159    return Err(ConvertError::ResourceLimitExceeded(format!(
160      "archive is {} bytes, limit is {}",
161      metadata.len(),
162      opts.max_archive_bytes
163    )));
164  }
165
166  let mut magic = [0u8; 4];
167  let mut fh = File::open(path)?;
168  fh.read_exact(&mut magic)?;
169  if magic != [0x50, 0x4b, 0x03, 0x04] {
170    return Err(ConvertError::UnsupportedFormat(
171      "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
172    ));
173  }
174
175  let source_sha256 = sha256_file(path)?;
176  let mut archive = ZipArchive::new(file)?;
177  let data_pkl_name = find_data_pkl_name(&mut archive)?;
178  let prefix = data_pkl_name
179    .strip_suffix("data.pkl")
180    .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
181    .to_string();
182  let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
183  if pickle_bytes.len() > opts.max_pickle_bytes {
184    return Err(ConvertError::ResourceLimitExceeded(format!(
185      "data.pkl is {} bytes, limit is {}",
186      pickle_bytes.len(),
187      opts.max_pickle_bytes
188    )));
189  }
190
191  let root = parse_pickle(&pickle_bytes, opts)?;
192  let metadata = project_root_metadata(&root);
193  let objects = collect_constructor_types(&root);
194  let tensor_refs = extract_state_dict_tensors(&root)?;
195  if tensor_refs.is_empty() {
196    return Err(ConvertError::InvalidStructure(
197      "no tensors found in checkpoint state_dict".to_string(),
198    ));
199  }
200  if tensor_refs.len() > opts.max_tensor_count {
201    return Err(ConvertError::ResourceLimitExceeded(format!(
202      "tensor count {} exceeds limit {}",
203      tensor_refs.len(),
204      opts.max_tensor_count
205    )));
206  }
207
208  let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
209  for tensor in tensor_refs.values() {
210    let key = &tensor.storage.key;
211    if storage_blobs.contains_key(key) {
212      continue;
213    }
214    let blob = read_storage_blob(&mut archive, &prefix, key)?;
215    let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
216    if blob.len() < required_bytes {
217      return Err(ConvertError::InvalidStructure(format!(
218        "storage {} has {} bytes, expected at least {}",
219        key,
220        blob.len(),
221        required_bytes
222      )));
223    }
224    storage_blobs.insert(key.clone(), blob);
225  }
226
227  let mut tensors = BTreeMap::new();
228  for (name, tensor_ref) in tensor_refs {
229    if opts.strict_contiguous {
230      let expected = contiguous_stride(&tensor_ref.shape);
231      if expected != tensor_ref.stride {
232        return Err(ConvertError::InvalidStructure(format!(
233          "tensor {} has non-contiguous stride {:?}, expected {:?}",
234          name, tensor_ref.stride, expected
235        )));
236      }
237    }
238
239    let elem_size = tensor_ref.storage.dtype.elem_size();
240    let numel = numel(&tensor_ref.shape)?;
241    let start = tensor_ref
242      .offset_elems
243      .checked_mul(elem_size)
244      .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
245    let byte_len = numel
246      .checked_mul(elem_size)
247      .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
248    if byte_len > opts.max_tensor_bytes {
249      return Err(ConvertError::ResourceLimitExceeded(format!(
250        "tensor {} is {} bytes, limit is {}",
251        name, byte_len, opts.max_tensor_bytes
252      )));
253    }
254    let end = start
255      .checked_add(byte_len)
256      .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
257
258    let storage = storage_blobs
259      .get(&tensor_ref.storage.key)
260      .ok_or_else(|| ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key)))?;
261    if end > storage.len() {
262      return Err(ConvertError::InvalidStructure(format!(
263        "tensor {} slice [{}, {}) is out of storage bounds {}",
264        name,
265        start,
266        end,
267        storage.len()
268      )));
269    }
270
271    let raw = storage[start..end].to_vec();
272    let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
273    tensors.insert(name, normalized);
274  }
275
276  Ok(ParsedCheckpoint {
277    source_sha256,
278    warnings: Vec::new(),
279    tensors,
280    metadata,
281    objects,
282  })
283}
284
285fn build_checkpoint_metadata(
286  source_file: String,
287  source_sha256: String,
288  metadata: &serde_yaml::Value,
289  objects: &[String],
290  tensors: &BTreeMap<String, TensorData>,
291  safetensors_file: String,
292) -> CheckpointMetadata {
293  CheckpointMetadata {
294    format_version: 1,
295    source_file,
296    source_sha256,
297    safetensors_file,
298    created_at_unix: now_unix_secs(),
299    tensor_count: tensors.len(),
300    total_tensor_bytes: total_tensor_bytes(tensors),
301    metadata: metadata.clone(),
302    objects: objects.to_vec(),
303    tensors: tensor_summaries_for_metadata(tensors),
304  }
305}
306
307fn tensor_summaries_for_metadata(tensors: &BTreeMap<String, TensorData>) -> Vec<CheckpointTensorMetadata> {
308  tensors
309    .iter()
310    .map(|(name, tensor)| CheckpointTensorMetadata {
311      name: name.clone(),
312      dtype: tensor.dtype.as_safetensors().to_string(),
313      shape: tensor.shape.clone(),
314      sha256: sha256_hex(&tensor.bytes),
315    })
316    .collect()
317}
318
319fn total_tensor_bytes(tensors: &BTreeMap<String, TensorData>) -> usize {
320  tensors.values().map(|tensor| tensor.bytes.len()).sum()
321}
322
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_unix_secs() -> u64 {
  match SystemTime::now().duration_since(UNIX_EPOCH) {
    Ok(elapsed) => elapsed.as_secs(),
    Err(_) => 0,
  }
}
329
330fn validate_metadata_against_tensors(
331  metadata: &CheckpointMetadata,
332  tensors: &BTreeMap<String, TensorData>,
333) -> Result<()> {
334  if metadata.tensor_count != tensors.len() {
335    return Err(ConvertError::InvalidStructure(format!(
336      "metadata tensor_count={} does not match loaded tensor count={}",
337      metadata.tensor_count,
338      tensors.len()
339    )));
340  }
341
342  let tensor_bytes = total_tensor_bytes(tensors);
343  if metadata.total_tensor_bytes != tensor_bytes {
344    return Err(ConvertError::InvalidStructure(format!(
345      "metadata total_tensor_bytes={} does not match loaded tensor bytes={}",
346      metadata.total_tensor_bytes, tensor_bytes
347    )));
348  }
349
350  for item in &metadata.tensors {
351    let Some(tensor) = tensors.get(&item.name) else {
352      return Err(ConvertError::InvalidStructure(format!(
353        "metadata references missing tensor {}",
354        item.name
355      )));
356    };
357    if item.dtype != tensor.dtype.as_safetensors() {
358      return Err(ConvertError::InvalidStructure(format!(
359        "metadata dtype mismatch for {}: {} != {}",
360        item.name,
361        item.dtype,
362        tensor.dtype.as_safetensors()
363      )));
364    }
365    if item.shape != tensor.shape {
366      return Err(ConvertError::InvalidStructure(format!(
367        "metadata shape mismatch for {}",
368        item.name
369      )));
370    }
371    if item.sha256 != sha256_hex(&tensor.bytes) {
372      return Err(ConvertError::InvalidStructure(format!(
373        "metadata sha256 mismatch for {}",
374        item.name
375      )));
376    }
377  }
378
379  Ok(())
380}
381
382#[derive(Debug, Deserialize)]
383struct SafetensorHeaderEntry {
384  dtype: String,
385  shape: Vec<usize>,
386  data_offsets: [usize; 2],
387}
388
389fn read_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, TensorData>> {
390  let file_bytes = fs::read(path)?;
391  if file_bytes.len() < 8 {
392    return Err(ConvertError::InvalidStructure(
393      "safetensors file is too short".to_string(),
394    ));
395  }
396
397  let header_len = u64::from_le_bytes(file_bytes[0..8].try_into().expect("8-byte header"));
398  let header_len = header_len as usize;
399  if file_bytes.len() < 8 + header_len {
400    return Err(ConvertError::InvalidStructure(
401      "safetensors header is truncated".to_string(),
402    ));
403  }
404
405  let header_bytes = &file_bytes[8..8 + header_len];
406  let data = &file_bytes[8 + header_len..];
407  let header: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(header_bytes)?;
408
409  let mut tensors = BTreeMap::new();
410  for (name, value) in header {
411    if name == "__metadata__" {
412      continue;
413    }
414    let entry: SafetensorHeaderEntry = serde_json::from_value(value)?;
415    let Some(dtype) = DType::from_safetensors(&entry.dtype) else {
416      return Err(ConvertError::InvalidStructure(format!(
417        "unsupported safetensors dtype {}",
418        entry.dtype
419      )));
420    };
421
422    let start = entry.data_offsets[0];
423    let end = entry.data_offsets[1];
424    if end < start || end > data.len() {
425      return Err(ConvertError::InvalidStructure(format!(
426        "invalid data_offsets for tensor {}",
427        name
428      )));
429    }
430
431    let expected_size = numel(&entry.shape)?
432      .checked_mul(dtype.elem_size())
433      .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
434    if end - start != expected_size {
435      return Err(ConvertError::InvalidStructure(format!(
436        "tensor {} bytes mismatch: {} != {}",
437        name,
438        end - start,
439        expected_size
440      )));
441    }
442
443    tensors.insert(
444      name,
445      TensorData {
446        dtype,
447        shape: entry.shape,
448        bytes: data[start..end].to_vec(),
449      },
450    );
451  }
452
453  if tensors.is_empty() {
454    return Err(ConvertError::InvalidStructure(
455      "no tensors found in safetensors file".to_string(),
456    ));
457  }
458
459  Ok(tensors)
460}
461
462fn tensor_data_to_array(tensor: &TensorData) -> Result<TensorArray> {
463  let shape = IxDyn(&tensor.shape);
464  match tensor.dtype {
465    DType::F16 | DType::BF16 => Err(ConvertError::InvalidStructure(
466      "f16/bf16 should be normalized to f32 before state_dict()".to_string(),
467    )),
468    DType::F32 => {
469      let values = bytes_to_vec::<4, f32>(&tensor.bytes, f32::from_le_bytes)?;
470      Ok(TensorArray::F32(ArrayD::from_shape_vec(shape, values)?))
471    }
472    DType::F64 => {
473      let values = bytes_to_vec::<8, f64>(&tensor.bytes, f64::from_le_bytes)?;
474      Ok(TensorArray::F64(ArrayD::from_shape_vec(shape, values)?))
475    }
476    DType::I8 => {
477      let values = tensor.bytes.iter().map(|v| *v as i8).collect::<Vec<_>>();
478      Ok(TensorArray::I8(ArrayD::from_shape_vec(shape, values)?))
479    }
480    DType::I16 => {
481      let values = bytes_to_vec::<2, i16>(&tensor.bytes, i16::from_le_bytes)?;
482      Ok(TensorArray::I16(ArrayD::from_shape_vec(shape, values)?))
483    }
484    DType::I32 => {
485      let values = bytes_to_vec::<4, i32>(&tensor.bytes, i32::from_le_bytes)?;
486      Ok(TensorArray::I32(ArrayD::from_shape_vec(shape, values)?))
487    }
488    DType::I64 => {
489      let values = bytes_to_vec::<8, i64>(&tensor.bytes, i64::from_le_bytes)?;
490      Ok(TensorArray::I64(ArrayD::from_shape_vec(shape, values)?))
491    }
492    DType::U8 => Ok(TensorArray::U8(ArrayD::from_shape_vec(shape, tensor.bytes.clone())?)),
493    DType::Bool => {
494      let values = tensor.bytes.iter().map(|v| *v != 0).collect::<Vec<_>>();
495      Ok(TensorArray::Bool(ArrayD::from_shape_vec(shape, values)?))
496    }
497  }
498}
499
500fn bytes_to_vec<const N: usize, T>(bytes: &[u8], f: impl Fn([u8; N]) -> T) -> Result<Vec<T>> {
501  if bytes.len() % N != 0 {
502    return Err(ConvertError::InvalidStructure(format!(
503      "tensor bytes are not divisible by {}",
504      N
505    )));
506  }
507
508  Ok(
509    bytes
510      .chunks_exact(N)
511      .map(|chunk| {
512        let mut arr = [0u8; N];
513        arr.copy_from_slice(chunk);
514        f(arr)
515      })
516      .collect(),
517  )
518}
519
520fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
521  match dtype {
522    DType::F16 => Ok(TensorData {
523      dtype: DType::F32,
524      shape,
525      bytes: f16_bytes_to_f32_bytes(&bytes)?,
526    }),
527    DType::BF16 => Ok(TensorData {
528      dtype: DType::F32,
529      shape,
530      bytes: bf16_bytes_to_f32_bytes(&bytes)?,
531    }),
532    _ => Ok(TensorData { dtype, shape, bytes }),
533  }
534}
535
536fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
537  if input.len() % 2 != 0 {
538    return Err(ConvertError::InvalidStructure(
539      "f16 tensor bytes must be even-length".to_string(),
540    ));
541  }
542  let mut out = Vec::with_capacity(input.len() * 2);
543  for chunk in input.chunks_exact(2) {
544    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
545    let value = f16_bits_to_f32(bits);
546    out.extend_from_slice(&value.to_le_bytes());
547  }
548  Ok(out)
549}
550
551fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
552  if input.len() % 2 != 0 {
553    return Err(ConvertError::InvalidStructure(
554      "bf16 tensor bytes must be even-length".to_string(),
555    ));
556  }
557  let mut out = Vec::with_capacity(input.len() * 2);
558  for chunk in input.chunks_exact(2) {
559    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
560    let value = f32::from_bits((bits as u32) << 16);
561    out.extend_from_slice(&value.to_le_bytes());
562  }
563  Ok(out)
564}
565
/// Converts one IEEE 754 binary16 bit pattern to the `f32` with the same value.
///
/// Signed zeros keep their sign, subnormals are renormalized into `f32` normals,
/// infinities stay infinite, and NaN mantissa bits are shifted into the wider mantissa.
fn f16_bits_to_f32(bits: u16) -> f32 {
  let sign_bit = u32::from(bits >> 15) << 31;
  let biased_exp = u32::from((bits >> 10) & 0x1f);
  let mantissa = u32::from(bits & 0x03ff);

  let widened = match (biased_exp, mantissa) {
    // Signed zero.
    (0, 0) => sign_bit,
    // Subnormal: shift the highest set mantissa bit up to the implicit-one position (bit 10),
    // lowering the exponent by one per shift from the subnormal exponent of -14.
    (0, _) => {
      let shift = mantissa.leading_zeros() - 21;
      let normalized = (mantissa << shift) & 0x03ff;
      let exp32 = 127 - 14 - shift;
      sign_bit | (exp32 << 23) | (normalized << 13)
    }
    // Infinity or NaN.
    (0x1f, _) => sign_bit | (0xff << 23) | (mantissa << 13),
    // Normal number: rebias the exponent from 15 to 127.
    _ => sign_bit | ((biased_exp + 127 - 15) << 23) | (mantissa << 13),
  };

  f32::from_bits(widened)
}
594
// Unit tests. The loader tests build a minimal torch-style zip checkpoint by hand
// (`archive/data.pkl` plus one storage blob `archive/data/0`) so no real PyTorch file is needed.
#[cfg(test)]
mod tests {
  use super::*;
  use crate::metadata::{collect_constructor_types, project_value_for_metadata};
  use crate::types::Value;
  use std::io::Write;
  use tempfile::tempdir;
  use zip::write::SimpleFileOptions;
  use zip::ZipWriter;

  // Loads the safe fixture, exports it, and checks both output files plus the YAML summary.
  #[test]
  fn converts_simple_tensor_checkpoint() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("weights.pt");
    write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");

    let out_dir = tmp.path().join("export");
    let checkpoint = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect("checkpoint load should work");
    let result = checkpoint
      .export(&out_dir, ExportOptions::new(ExportFormat::Safetensors, Some(&pt_path)))
      .expect("export should work");

    assert!(result.weights_path.exists());
    assert!(result.metadata_path.as_ref().expect("metadata path").exists());
    assert_eq!(result.tensor_count, 1);

    let yaml = fs::read_to_string(result.metadata_path.expect("metadata path")).expect("yaml readable");
    assert!(yaml.contains("layer.weight"));
    // The YAML serializer may or may not quote the dtype tag.
    assert!(yaml.contains("dtype: F32") || yaml.contains("dtype: 'F32'"));
  }

  // A pickle that REDUCEs `os.system` must make loading fail with an unsupported/unsafe error.
  #[test]
  fn rejects_unsafe_global_reduce() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("unsafe.pt");
    write_fixture_checkpoint(&pt_path, true).expect("fixture checkpoint");

    let err = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect_err("unsafe pickle should fail");
    let msg = err.to_string();
    // Accept either wording the loader may use for the rejection.
    assert!(msg.contains("unsupported GLOBAL") || msg.contains("unsafe/unsupported pickle"));
  }

  // An object projects to a mapping with `$type`, a module-qualified `$class`, `$args`,
  // and its dict state flattened in as plain keys.
  #[test]
  fn projects_object_metadata_with_type_args_and_flattened_state() {
    let value = Value::Object {
      module: "ultralytics.nn.tasks".to_string(),
      name: "DetectionModel".to_string(),
      args: Some(Box::new(Value::Tuple(vec![
        Value::String("arg0".to_string()),
        Value::Int(42),
      ]))),
      state: Some(Box::new(Value::Dict(vec![(
        Value::String("training".to_string()),
        Value::Bool(false),
      )]))),
    };

    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let type_key = serde_yaml::Value::String("$type".to_string());
    let class_key = serde_yaml::Value::String("$class".to_string());
    let args_key = serde_yaml::Value::String("$args".to_string());
    let training_key = serde_yaml::Value::String("training".to_string());

    assert_eq!(
      mapping.get(&type_key),
      Some(&serde_yaml::Value::String("object".to_string()))
    );
    assert_eq!(
      mapping.get(&class_key),
      Some(&serde_yaml::Value::String(
        "ultralytics.nn.tasks.DetectionModel".to_string()
      ))
    );
    assert!(mapping.get(&args_key).is_some());
    assert_eq!(mapping.get(&training_key), Some(&serde_yaml::Value::Bool(false)));
  }

  // An empty args tuple is omitted from the projection rather than emitted as `$args`.
  #[test]
  fn omits_empty_object_args() {
    let value = Value::Object {
      module: "a".to_string(),
      name: "B".to_string(),
      args: Some(Box::new(Value::Tuple(Vec::new()))),
      state: None,
    };
    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let args_key = serde_yaml::Value::String("$args".to_string());
    assert!(!mapping.contains_key(&args_key));
  }

  // Constructor types are reported once each, in first-encounter order, including ones
  // nested inside lists and dict values.
  #[test]
  fn collects_constructor_types_deduplicated_in_first_seen_order() {
    let tree = Value::List(vec![
      Value::Object {
        module: "a".to_string(),
        name: "One".to_string(),
        args: None,
        state: None,
      },
      Value::Dict(vec![(
        Value::String("nested".to_string()),
        Value::Object {
          module: "b".to_string(),
          name: "Two".to_string(),
          args: None,
          state: None,
        },
      )]),
      Value::Object {
        module: "a".to_string(),
        name: "One".to_string(),
        args: None,
        state: None,
      },
    ]);

    let objects = collect_constructor_types(&tree);
    assert_eq!(objects, vec!["a.One".to_string(), "b.Two".to_string()]);
  }

  // Writes a zip checkpoint at `path` containing `archive/data.pkl` (the safe or unsafe
  // pickle) and one storage blob `archive/data/0` holding four little-endian f32 values.
  fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
    let file = File::create(path)?;
    let mut zip = ZipWriter::new(file);
    let options = SimpleFileOptions::default();

    let data_pkl = if unsafe_payload {
      build_unsafe_pickle()
    } else {
      build_safe_pickle()
    };

    zip.start_file("archive/data.pkl", options)?;
    zip.write_all(&data_pkl)?;

    let floats = [1.0f32, 2.0, 3.0, 4.0];
    let mut raw = Vec::new();
    for value in floats {
      raw.extend_from_slice(&value.to_le_bytes());
    }

    // Storage key "0" in the pickle below resolves to this entry under the "archive/" prefix.
    zip.start_file("archive/data/0", options)?;
    zip.write_all(&raw)?;
    zip.finish()?;
    Ok(())
  }

  // Hand-assembles a protocol-2 pickle for `{"layer.weight": _rebuild_tensor_v2(...)}`
  // describing a 2x2 contiguous f32 tensor backed by storage "0". Opcode order matters.
  fn build_safe_pickle() -> Vec<u8> {
    let mut out = Vec::new();
    // PROTO 2.
    out.extend_from_slice(&[0x80, 0x02]);

    // EMPTY_DICT (the state_dict), then MARK for the key/value pairs consumed by SETITEMS.
    out.push(b'}');
    out.push(b'(');

    // Key, then GLOBAL torch._utils._rebuild_tensor_v2 (the callable for REDUCE below).
    push_binunicode(&mut out, "layer.weight");
    out.extend_from_slice(b"ctorch._utils\n_rebuild_tensor_v2\n");

    // MARK: start of the rebuild argument tuple.
    out.push(b'(');

    // Persistent-id tuple ("storage", FloatStorage, key "0", location "cpu", 4 elements),
    // closed with TUPLE and loaded with BINPERSID.
    out.push(b'(');
    push_binunicode(&mut out, "storage");
    out.extend_from_slice(b"ctorch\nFloatStorage\n");
    push_binunicode(&mut out, "0");
    push_binunicode(&mut out, "cpu");
    out.push(b'K');
    out.push(4);
    out.push(b't');
    out.push(b'Q');

    // BININT1 0: storage offset in elements.
    out.push(b'K');
    out.push(0);

    // Size tuple (2, 2).
    out.push(b'(');
    out.push(b'K');
    out.push(2);
    out.push(b'K');
    out.push(2);
    out.push(b't');

    // Stride tuple (2, 1): row-major contiguous for a 2x2 tensor.
    out.push(b'(');
    out.push(b'K');
    out.push(2);
    out.push(b'K');
    out.push(1);
    out.push(b't');

    // NEWFALSE (requires_grad) and NONE (backward hooks).
    out.push(0x89);
    out.push(b'N');

    // TUPLE closes the argument list; REDUCE applies _rebuild_tensor_v2 to it.
    out.push(b't');
    out.push(b'R');

    // SETITEMS into the dict, then STOP.
    out.push(b'u');
    out.push(b'.');
    out
  }

  // Classic code-execution payload: GLOBAL os.system, REDUCE with ("echo hacked",).
  // The loader must reject this rather than evaluate it.
  fn build_unsafe_pickle() -> Vec<u8> {
    let mut out = Vec::new();
    out.extend_from_slice(&[0x80, 0x02]);
    out.extend_from_slice(b"cos\nsystem\n");
    out.push(b'(');
    push_binunicode(&mut out, "echo hacked");
    out.push(b't');
    out.push(b'R');
    out.push(b'.');
    out
  }

  // Emits BINUNICODE: opcode 'X', a u32 little-endian byte length, then the UTF-8 bytes.
  fn push_binunicode(out: &mut Vec<u8>, value: &str) {
    out.push(b'X');
    out.extend_from_slice(&(value.len() as u32).to_le_bytes());
    out.extend_from_slice(value.as_bytes());
  }
}