Skip to main content

pt_loader/
lib.rs

1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8pub mod writer;
9
10pub use types::{
11  CheckpointMetadata, CheckpointSecurity, CheckpointTensorMetadata, ConvertError, DType, ExportFormat,
12  ExportOptions, ExportResult, TensorManifest,
13  LoadOptions, ReconstructSource, Result, StorageRef, TensorArray, TensorData, TensorRef, Value,
14};
15
16use ndarray::{ArrayD, IxDyn};
17use serde::Deserialize;
18use std::collections::{BTreeMap, HashMap};
19use std::fs;
20use std::fs::File;
21use std::io::Read;
22use std::path::Path;
23use std::time::{SystemTime, UNIX_EPOCH};
24use zip::read::ZipArchive;
25
26use extract::{contiguous_stride, extract_state_dict_tensors, numel};
27use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
28use metadata::{collect_call_types, collect_constructor_types, project_root_metadata};
29use parser::parse_pickle;
30use types::ParsedCheckpoint;
31use writer::{write_metadata_yaml, write_safetensors};
32
33#[derive(Debug, Clone)]
34pub struct PtCheckpoint {
35  source_sha256: String,
36  warnings: Vec<String>,
37  metadata: CheckpointMetadata,
38  tensors: BTreeMap<String, TensorData>,
39  tensor_groups: BTreeMap<String, BTreeMap<String, TensorData>>,
40}
41
42impl PtCheckpoint {
43  pub fn load(path: impl AsRef<Path>, opts: LoadOptions) -> Result<Self> {
44    let path = path.as_ref();
45    let parsed = parse_checkpoint(path, &opts)?;
46    let metadata = build_checkpoint_metadata(
47      path.display().to_string(),
48      parsed.source_sha256.clone(),
49      &parsed.metadata,
50      &parsed.security,
51      &parsed.tensors,
52      "model.safetensors".to_string(),
53    );
54
55    Ok(Self {
56      source_sha256: parsed.source_sha256,
57      warnings: parsed.warnings,
58      metadata,
59      tensors: parsed.tensors,
60      tensor_groups: parsed.tensor_groups,
61    })
62  }
63
64  pub fn from_metadata(metadata: CheckpointMetadata, source: ReconstructSource) -> Result<Self> {
65    let tensors = match source {
66      ReconstructSource::WeightsFile(path) => read_safetensors_tensors(&path)?,
67      ReconstructSource::StateDict(values) => values,
68    };
69
70    validate_metadata_against_tensors(&metadata, &tensors)?;
71    let mut tensor_groups = BTreeMap::new();
72    tensor_groups.insert("root".to_string(), tensors.clone());
73
74    Ok(Self {
75      source_sha256: metadata.source_sha256.clone(),
76      warnings: Vec::new(),
77      metadata,
78      tensors,
79      tensor_groups,
80    })
81  }
82
83  pub fn metadata(&self) -> &CheckpointMetadata {
84    &self.metadata
85  }
86
87  pub fn source_sha256(&self) -> &str {
88    &self.source_sha256
89  }
90
91  pub fn warnings(&self) -> &[String] {
92    &self.warnings
93  }
94
95  pub fn tensor_count(&self) -> usize {
96    self.tensors.len()
97  }
98
99  #[cfg(feature = "pyo3")]
100  pub(crate) fn raw_tensors(&self) -> &BTreeMap<String, TensorData> {
101    &self.tensors
102  }
103
104  pub fn state_dict(&self) -> Result<BTreeMap<String, TensorArray>> {
105    let mut out = BTreeMap::new();
106    for (name, tensor) in &self.tensors {
107      out.insert(name.clone(), tensor_data_to_array(tensor)?);
108    }
109    Ok(out)
110  }
111
112  pub fn export(&self, out_dir: impl AsRef<Path>, opts: ExportOptions) -> Result<ExportResult> {
113    match opts.format {
114      ExportFormat::Safetensors => {}
115    }
116
117    let out_dir = out_dir.as_ref();
118    fs::create_dir_all(out_dir)?;
119
120    let is_multi_root = self.tensor_groups.len() > 1 || !self.tensor_groups.contains_key("root");
121    let weights_path = out_dir.join(&opts.weights_filename);
122    let mut weights_paths = BTreeMap::new();
123    if is_multi_root {
124      for (root_key, tensors) in &self.tensor_groups {
125        let file_name = with_root_key_suffix(&opts.weights_filename, root_key)?;
126        let path = out_dir.join(&file_name);
127        if path.exists() && !opts.overwrite {
128          return Err(ConvertError::InvalidStructure(format!(
129            "output already exists: {}",
130            path.display()
131          )));
132        }
133        write_safetensors(&path, tensors, &self.source_sha256)?;
134        weights_paths.insert(root_key.clone(), path);
135      }
136    } else {
137      if weights_path.exists() && !opts.overwrite {
138        return Err(ConvertError::InvalidStructure(format!(
139          "output already exists: {}",
140          weights_path.display()
141        )));
142      }
143      write_safetensors(&weights_path, &self.tensors, &self.source_sha256)?;
144      weights_paths.insert("root".to_string(), weights_path.clone());
145    }
146
147    let metadata_path = if opts.include_metadata {
148      let metadata_path = out_dir.join(&opts.metadata_filename);
149      if metadata_path.exists() && !opts.overwrite {
150        return Err(ConvertError::InvalidStructure(format!(
151          "output already exists: {}",
152          metadata_path.display()
153        )));
154      }
155
156      let mut metadata = self.metadata.clone();
157      if is_multi_root {
158        metadata.safetensors_file.clear();
159        metadata.safetensors_files = weights_paths
160          .iter()
161          .map(|(key, path)| (key.clone(), file_name_or_path(path)))
162          .collect();
163        metadata.tensors = TensorManifest::ByRoot(
164          self
165            .tensor_groups
166            .iter()
167            .map(|(key, tensors)| (key.clone(), tensor_summaries_for_metadata(tensors)))
168            .collect(),
169        );
170      } else {
171        metadata.safetensors_file = opts.weights_filename.to_string_lossy().into_owned();
172        metadata.safetensors_files.clear();
173        metadata.tensors = TensorManifest::List(tensor_summaries_for_metadata(&self.tensors));
174      }
175      metadata.created_at_unix = now_unix_secs();
176      metadata.tensor_count = self.tensors.len();
177      metadata.total_tensor_bytes = total_tensor_bytes(&self.tensors);
178      write_metadata_yaml(&metadata_path, &metadata)?;
179      Some(metadata_path)
180    } else {
181      None
182    };
183
184    Ok(ExportResult {
185      weights_path,
186      weights_paths,
187      metadata_path,
188      source_sha256: self.source_sha256.clone(),
189      tensor_count: self.tensors.len(),
190      total_tensor_bytes: total_tensor_bytes(&self.tensors),
191    })
192  }
193}
194
195pub(crate) fn parse_checkpoint(path: &Path, opts: &LoadOptions) -> Result<ParsedCheckpoint> {
196  let file = File::open(path)?;
197  let metadata = file.metadata()?;
198  if metadata.len() > opts.max_archive_bytes {
199    return Err(ConvertError::ResourceLimitExceeded(format!(
200      "archive is {} bytes, limit is {}",
201      metadata.len(),
202      opts.max_archive_bytes
203    )));
204  }
205
206  let mut magic = [0u8; 4];
207  let mut fh = File::open(path)?;
208  fh.read_exact(&mut magic)?;
209  if magic != [0x50, 0x4b, 0x03, 0x04] {
210    return Err(ConvertError::UnsupportedFormat(
211      "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
212    ));
213  }
214
215  let source_sha256 = sha256_file(path)?;
216  let mut archive = ZipArchive::new(file)?;
217  let data_pkl_name = find_data_pkl_name(&mut archive)?;
218  let prefix = data_pkl_name
219    .strip_suffix("data.pkl")
220    .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
221    .to_string();
222  let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
223  if pickle_bytes.len() > opts.max_pickle_bytes {
224    return Err(ConvertError::ResourceLimitExceeded(format!(
225      "data.pkl is {} bytes, limit is {}",
226      pickle_bytes.len(),
227      opts.max_pickle_bytes
228    )));
229  }
230
231  let root = parse_pickle(&pickle_bytes, opts)?;
232  let metadata = project_root_metadata(&root);
233  let objects = collect_constructor_types(&root);
234  let calls = collect_call_types(&root);
235  let tensor_ref_groups = extract_state_dict_tensors(&root, opts)?;
236  if tensor_ref_groups.is_empty() {
237    return Err(ConvertError::InvalidStructure(
238      "no tensors found in checkpoint state_dict".to_string(),
239    ));
240  }
241  let tensor_ref_count = tensor_ref_groups.values().map(|group| group.len()).sum::<usize>();
242  if tensor_ref_count > opts.max_tensor_count {
243    return Err(ConvertError::ResourceLimitExceeded(format!(
244      "tensor count {} exceeds limit {}",
245      tensor_ref_count,
246      opts.max_tensor_count
247    )));
248  }
249
250  let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
251  for tensor_refs in tensor_ref_groups.values() {
252    for tensor in tensor_refs.values() {
253      let key = &tensor.storage.key;
254      if storage_blobs.contains_key(key) {
255        continue;
256      }
257      let blob = read_storage_blob(&mut archive, &prefix, key)?;
258      let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
259      if blob.len() < required_bytes {
260        return Err(ConvertError::InvalidStructure(format!(
261          "storage {} has {} bytes, expected at least {}",
262          key,
263          blob.len(),
264          required_bytes
265        )));
266      }
267      storage_blobs.insert(key.clone(), blob);
268    }
269  }
270
271  let mut tensors = BTreeMap::new();
272  let mut tensor_groups = BTreeMap::new();
273  for (root_key, tensor_refs) in tensor_ref_groups {
274    let mut group_tensors = BTreeMap::new();
275    for (name, tensor_ref) in tensor_refs {
276      if opts.strict_contiguous {
277        let expected = contiguous_stride(&tensor_ref.shape);
278        if expected != tensor_ref.stride {
279          return Err(ConvertError::InvalidStructure(format!(
280            "tensor {} has non-contiguous stride {:?}, expected {:?}",
281            name, tensor_ref.stride, expected
282          )));
283        }
284      }
285
286      let elem_size = tensor_ref.storage.dtype.elem_size();
287      let numel = numel(&tensor_ref.shape)?;
288      let start = tensor_ref
289        .offset_elems
290        .checked_mul(elem_size)
291        .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
292      let byte_len = numel
293        .checked_mul(elem_size)
294        .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
295      if byte_len > opts.max_tensor_bytes {
296        return Err(ConvertError::ResourceLimitExceeded(format!(
297          "tensor {} is {} bytes, limit is {}",
298          name, byte_len, opts.max_tensor_bytes
299        )));
300      }
301      let end = start
302        .checked_add(byte_len)
303        .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
304
305      let storage = storage_blobs
306        .get(&tensor_ref.storage.key)
307        .ok_or_else(|| ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key)))?;
308      if end > storage.len() {
309        return Err(ConvertError::InvalidStructure(format!(
310          "tensor {} slice [{}, {}) is out of storage bounds {}",
311          name,
312          start,
313          end,
314          storage.len()
315        )));
316      }
317
318      let raw = storage[start..end].to_vec();
319      let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
320      group_tensors.insert(name.clone(), normalized.clone());
321      let merged_name = merge_root_tensor_name(&root_key, &name);
322      tensors.insert(merged_name, normalized);
323    }
324    tensor_groups.insert(root_key, group_tensors);
325  }
326
327  Ok(ParsedCheckpoint {
328    source_sha256,
329    warnings: Vec::new(),
330    tensors,
331    tensor_groups,
332    metadata,
333    security: CheckpointSecurity { objects, calls },
334  })
335}
336
337fn build_checkpoint_metadata(
338  source_file: String,
339  source_sha256: String,
340  metadata: &serde_yaml::Value,
341  security: &CheckpointSecurity,
342  tensors: &BTreeMap<String, TensorData>,
343  safetensors_file: String,
344) -> CheckpointMetadata {
345  CheckpointMetadata {
346    format_version: 1,
347    source_file,
348    source_sha256,
349    safetensors_file,
350    safetensors_files: BTreeMap::new(),
351    created_at_unix: now_unix_secs(),
352    tensor_count: tensors.len(),
353    total_tensor_bytes: total_tensor_bytes(tensors),
354    metadata: metadata.clone(),
355    security: security.clone(),
356    tensors: TensorManifest::List(tensor_summaries_for_metadata(tensors)),
357  }
358}
359
360fn tensor_summaries_for_metadata(tensors: &BTreeMap<String, TensorData>) -> Vec<CheckpointTensorMetadata> {
361  tensors
362    .iter()
363    .map(|(name, tensor)| CheckpointTensorMetadata {
364      name: name.clone(),
365      dtype: tensor.dtype.as_safetensors().to_string(),
366      shape: tensor.shape.clone(),
367      sha256: sha256_hex(&tensor.bytes),
368    })
369    .collect()
370}
371
372fn total_tensor_bytes(tensors: &BTreeMap<String, TensorData>) -> usize {
373  tensors.values().map(|tensor| tensor.bytes.len()).sum()
374}
375
/// Returns the final path component as text, or the whole path when it has no file name.
fn file_name_or_path(path: &Path) -> String {
  match path.file_name() {
    Some(component) => component.to_string_lossy().into_owned(),
    None => path.display().to_string(),
  }
}
382
/// Qualifies `name` with its `root` key for the merged tensor map.
///
/// The name is left untouched when the root is the literal `"root"`, when the name equals the
/// root, or when the name already starts with `"<root>."`.
fn merge_root_tensor_name(root: &str, name: &str) -> String {
  // `name == root` leaves an empty remainder; `"<root>."` leaves one starting with a dot.
  let already_qualified = match name.strip_prefix(root) {
    Some(rest) => rest.is_empty() || rest.starts_with('.'),
    None => false,
  };

  if root == "root" || already_qualified {
    return name.to_string();
  }
  format!("{root}.{name}")
}
390
391fn with_root_key_suffix(base: &Path, root_key: &str) -> Result<std::path::PathBuf> {
392  let ext = base
393    .extension()
394    .map(|value| value.to_string_lossy().into_owned())
395    .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no extension".to_string()))?;
396  let stem = base
397    .file_stem()
398    .map(|value| value.to_string_lossy().into_owned())
399    .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no stem".to_string()))?;
400  Ok(std::path::PathBuf::from(format!("{stem}.{root_key}.{ext}")))
401}
402
/// Current wall-clock time as whole seconds since the Unix epoch, or 0 when the system clock
/// reports a time before the epoch.
fn now_unix_secs() -> u64 {
  match SystemTime::now().duration_since(UNIX_EPOCH) {
    Ok(elapsed) => elapsed.as_secs(),
    Err(_) => 0,
  }
}
409
410fn validate_metadata_against_tensors(
411  metadata: &CheckpointMetadata,
412  tensors: &BTreeMap<String, TensorData>,
413) -> Result<()> {
414  if metadata.tensor_count != tensors.len() {
415    return Err(ConvertError::InvalidStructure(format!(
416      "metadata tensor_count={} does not match loaded tensor count={}",
417      metadata.tensor_count,
418      tensors.len()
419    )));
420  }
421
422  let tensor_bytes = total_tensor_bytes(tensors);
423  if metadata.total_tensor_bytes != tensor_bytes {
424    return Err(ConvertError::InvalidStructure(format!(
425      "metadata total_tensor_bytes={} does not match loaded tensor bytes={}",
426      metadata.total_tensor_bytes, tensor_bytes
427    )));
428  }
429
430  let flat_manifest = match &metadata.tensors {
431    TensorManifest::List(items) => items.iter().map(|item| (item.name.clone(), item)).collect::<Vec<_>>(),
432    TensorManifest::ByRoot(groups) => groups
433      .iter()
434      .flat_map(|(root, items)| {
435        items
436          .iter()
437          .map(move |item| (merge_root_tensor_name(root, &item.name), item))
438      })
439      .collect::<Vec<_>>(),
440  };
441  for (name, item) in flat_manifest {
442    let Some(tensor) = tensors.get(&name) else {
443      return Err(ConvertError::InvalidStructure(format!(
444        "metadata references missing tensor {}",
445        name
446      )));
447    };
448    if item.dtype != tensor.dtype.as_safetensors() {
449      return Err(ConvertError::InvalidStructure(format!(
450        "metadata dtype mismatch for {}: {} != {}",
451        name,
452        item.dtype,
453        tensor.dtype.as_safetensors()
454      )));
455    }
456    if item.shape != tensor.shape {
457      return Err(ConvertError::InvalidStructure(format!(
458        "metadata shape mismatch for {}",
459        name
460      )));
461    }
462    if item.sha256 != sha256_hex(&tensor.bytes) {
463      return Err(ConvertError::InvalidStructure(format!(
464        "metadata sha256 mismatch for {}",
465        name
466      )));
467    }
468  }
469
470  Ok(())
471}
472
473#[derive(Debug, Deserialize)]
474struct SafetensorHeaderEntry {
475  dtype: String,
476  shape: Vec<usize>,
477  data_offsets: [usize; 2],
478}
479
480fn read_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, TensorData>> {
481  let file_bytes = fs::read(path)?;
482  if file_bytes.len() < 8 {
483    return Err(ConvertError::InvalidStructure(
484      "safetensors file is too short".to_string(),
485    ));
486  }
487
488  let header_len = u64::from_le_bytes(file_bytes[0..8].try_into().expect("8-byte header"));
489  let header_len = header_len as usize;
490  if file_bytes.len() < 8 + header_len {
491    return Err(ConvertError::InvalidStructure(
492      "safetensors header is truncated".to_string(),
493    ));
494  }
495
496  let header_bytes = &file_bytes[8..8 + header_len];
497  let data = &file_bytes[8 + header_len..];
498  let header: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(header_bytes)?;
499
500  let mut tensors = BTreeMap::new();
501  for (name, value) in header {
502    if name == "__metadata__" {
503      continue;
504    }
505    let entry: SafetensorHeaderEntry = serde_json::from_value(value)?;
506    let Some(dtype) = DType::from_safetensors(&entry.dtype) else {
507      return Err(ConvertError::InvalidStructure(format!(
508        "unsupported safetensors dtype {}",
509        entry.dtype
510      )));
511    };
512
513    let start = entry.data_offsets[0];
514    let end = entry.data_offsets[1];
515    if end < start || end > data.len() {
516      return Err(ConvertError::InvalidStructure(format!(
517        "invalid data_offsets for tensor {}",
518        name
519      )));
520    }
521
522    let expected_size = numel(&entry.shape)?
523      .checked_mul(dtype.elem_size())
524      .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
525    if end - start != expected_size {
526      return Err(ConvertError::InvalidStructure(format!(
527        "tensor {} bytes mismatch: {} != {}",
528        name,
529        end - start,
530        expected_size
531      )));
532    }
533
534    tensors.insert(
535      name,
536      TensorData {
537        dtype,
538        shape: entry.shape,
539        bytes: data[start..end].to_vec(),
540      },
541    );
542  }
543
544  if tensors.is_empty() {
545    return Err(ConvertError::InvalidStructure(
546      "no tensors found in safetensors file".to_string(),
547    ));
548  }
549
550  Ok(tensors)
551}
552
553fn tensor_data_to_array(tensor: &TensorData) -> Result<TensorArray> {
554  let shape = IxDyn(&tensor.shape);
555  match tensor.dtype {
556    DType::F16 | DType::BF16 => Err(ConvertError::InvalidStructure(
557      "f16/bf16 should be normalized to f32 before state_dict()".to_string(),
558    )),
559    DType::F32 => {
560      let values = bytes_to_vec::<4, f32>(&tensor.bytes, f32::from_le_bytes)?;
561      Ok(TensorArray::F32(ArrayD::from_shape_vec(shape, values)?))
562    }
563    DType::F64 => {
564      let values = bytes_to_vec::<8, f64>(&tensor.bytes, f64::from_le_bytes)?;
565      Ok(TensorArray::F64(ArrayD::from_shape_vec(shape, values)?))
566    }
567    DType::I8 => {
568      let values = tensor.bytes.iter().map(|v| *v as i8).collect::<Vec<_>>();
569      Ok(TensorArray::I8(ArrayD::from_shape_vec(shape, values)?))
570    }
571    DType::I16 => {
572      let values = bytes_to_vec::<2, i16>(&tensor.bytes, i16::from_le_bytes)?;
573      Ok(TensorArray::I16(ArrayD::from_shape_vec(shape, values)?))
574    }
575    DType::I32 => {
576      let values = bytes_to_vec::<4, i32>(&tensor.bytes, i32::from_le_bytes)?;
577      Ok(TensorArray::I32(ArrayD::from_shape_vec(shape, values)?))
578    }
579    DType::I64 => {
580      let values = bytes_to_vec::<8, i64>(&tensor.bytes, i64::from_le_bytes)?;
581      Ok(TensorArray::I64(ArrayD::from_shape_vec(shape, values)?))
582    }
583    DType::U8 => Ok(TensorArray::U8(ArrayD::from_shape_vec(shape, tensor.bytes.clone())?)),
584    DType::Bool => {
585      let values = tensor.bytes.iter().map(|v| *v != 0).collect::<Vec<_>>();
586      Ok(TensorArray::Bool(ArrayD::from_shape_vec(shape, values)?))
587    }
588  }
589}
590
591fn bytes_to_vec<const N: usize, T>(bytes: &[u8], f: impl Fn([u8; N]) -> T) -> Result<Vec<T>> {
592  if bytes.len() % N != 0 {
593    return Err(ConvertError::InvalidStructure(format!(
594      "tensor bytes are not divisible by {}",
595      N
596    )));
597  }
598
599  Ok(
600    bytes
601      .chunks_exact(N)
602      .map(|chunk| {
603        let mut arr = [0u8; N];
604        arr.copy_from_slice(chunk);
605        f(arr)
606      })
607      .collect(),
608  )
609}
610
611fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
612  match dtype {
613    DType::F16 => Ok(TensorData {
614      dtype: DType::F32,
615      shape,
616      bytes: f16_bytes_to_f32_bytes(&bytes)?,
617    }),
618    DType::BF16 => Ok(TensorData {
619      dtype: DType::F32,
620      shape,
621      bytes: bf16_bytes_to_f32_bytes(&bytes)?,
622    }),
623    _ => Ok(TensorData { dtype, shape, bytes }),
624  }
625}
626
627fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
628  if input.len() % 2 != 0 {
629    return Err(ConvertError::InvalidStructure(
630      "f16 tensor bytes must be even-length".to_string(),
631    ));
632  }
633  let mut out = Vec::with_capacity(input.len() * 2);
634  for chunk in input.chunks_exact(2) {
635    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
636    let value = f16_bits_to_f32(bits);
637    out.extend_from_slice(&value.to_le_bytes());
638  }
639  Ok(out)
640}
641
642fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
643  if input.len() % 2 != 0 {
644    return Err(ConvertError::InvalidStructure(
645      "bf16 tensor bytes must be even-length".to_string(),
646    ));
647  }
648  let mut out = Vec::with_capacity(input.len() * 2);
649  for chunk in input.chunks_exact(2) {
650    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
651    let value = f32::from_bits((bits as u32) << 16);
652    out.extend_from_slice(&value.to_le_bytes());
653  }
654  Ok(out)
655}
656
/// Converts an IEEE 754 binary16 bit pattern to the `f32` with the same value.
///
/// Signed zeros, subnormals, infinities and NaNs (with their fraction bits shifted into place)
/// are handled explicitly.
fn f16_bits_to_f32(bits: u16) -> f32 {
  let half = u32::from(bits);
  let sign_bit = (half & 0x8000) << 16;
  let exponent = (half >> 10) & 0x1f;
  let fraction = half & 0x03ff;

  let magnitude = match (exponent, fraction) {
    // Signed zero.
    (0, 0) => 0,
    // Subnormal: shift the leading one up to the implicit-bit position (bit 10) and lower the
    // exponent by the same amount.
    (0, _) => {
      let shift = fraction.leading_zeros() - 21;
      let normalized = (fraction << shift) & 0x03ff;
      // Unbiased exponent is -14 - shift; the f32 bias is 127.
      let biased = 113 - shift;
      (biased << 23) | (normalized << 13)
    }
    // Infinity or NaN.
    (0x1f, _) => (0xff << 23) | (fraction << 13),
    // Normal number: re-bias the exponent from 15 to 127.
    _ => ((exponent + 112) << 23) | (fraction << 13),
  };

  f32::from_bits(sign_bit | magnitude)
}
685
#[cfg(test)]
mod tests {
  use super::*;
  use crate::metadata::{collect_call_types, collect_constructor_types, project_value_for_metadata};
  use crate::types::Value;
  use std::io::Write;
  use tempfile::tempdir;
  use zip::write::SimpleFileOptions;
  use zip::ZipWriter;

  /// Wraps `text` in a YAML string value, usable as a mapping key or an expected value.
  fn yaml_str(text: &str) -> serde_yaml::Value {
    serde_yaml::Value::String(text.to_string())
  }

  /// Projects `value` for metadata and unwraps the result as a YAML mapping.
  fn projected_mapping(value: &Value) -> serde_yaml::Mapping {
    match project_value_for_metadata(value) {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    }
  }

  /// An object value with no constructor arguments and no state.
  fn bare_object(module: &str, name: &str) -> Value {
    Value::Object {
      module: module.to_string(),
      name: name.to_string(),
      args: Vec::new(),
      state: None,
    }
  }

  /// A call value without state.
  fn stateless_call(func: &str, args: Vec<Value>) -> Value {
    Value::Call {
      func: func.to_string(),
      args,
      state: None,
    }
  }

  #[test]
  fn converts_simple_tensor_checkpoint() {
    let workdir = tempdir().expect("tmp dir");
    let pt_path = workdir.path().join("weights.pt");
    write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");

    let export_dir = workdir.path().join("export");
    let checkpoint = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect("checkpoint load should work");
    let export_opts = ExportOptions::new(ExportFormat::Safetensors, Some(&pt_path));
    let result = checkpoint.export(&export_dir, export_opts).expect("export should work");

    assert!(result.weights_path.exists());
    let metadata_path = result.metadata_path.as_deref().expect("metadata path");
    assert!(metadata_path.exists());
    assert_eq!(result.tensor_count, 1);

    let yaml = fs::read_to_string(metadata_path).expect("yaml readable");
    assert!(yaml.contains("layer.weight"));
    assert!(yaml.contains("dtype: F32") || yaml.contains("dtype: 'F32'"));
    for needle in ["security:", "objects: []", "calls: []"] {
      assert!(yaml.contains(needle));
    }
  }

  #[test]
  fn rejects_unsafe_global_reduce() {
    let workdir = tempdir().expect("tmp dir");
    let pt_path = workdir.path().join("unsafe.pt");
    write_fixture_checkpoint(&pt_path, true).expect("fixture checkpoint");

    let message = PtCheckpoint::load(&pt_path, LoadOptions::default())
      .expect_err("unsafe pickle should fail")
      .to_string();
    assert!(message.contains("could not find a tensor state_dict"));
  }

  #[test]
  fn projects_object_metadata_with_type_args_and_flattened_state() {
    let state = Value::Dict(vec![(Value::String("training".to_string()), Value::Bool(false))]);
    let value = Value::Object {
      module: "ultralytics.nn.tasks".to_string(),
      name: "DetectionModel".to_string(),
      args: vec![Value::String("arg0".to_string()), Value::Int(42)],
      state: Some(Box::new(state)),
    };

    let mapping = projected_mapping(&value);

    assert_eq!(mapping.get(&yaml_str("$type")), Some(&yaml_str("object")));
    assert_eq!(
      mapping.get(&yaml_str("$class")),
      Some(&yaml_str("ultralytics.nn.tasks.DetectionModel"))
    );
    assert!(mapping.get(&yaml_str("$args")).is_some());
    assert_eq!(mapping.get(&yaml_str("training")), Some(&serde_yaml::Value::Bool(false)));
  }

  #[test]
  fn omits_empty_object_args() {
    let mapping = projected_mapping(&bare_object("a", "B"));
    assert!(!mapping.contains_key(&yaml_str("$args")));
  }

  #[test]
  fn collects_constructor_types_deduplicated_in_first_seen_order() {
    let nested = Value::Dict(vec![(Value::String("nested".to_string()), bare_object("b", "Two"))]);
    let tree = Value::List(vec![bare_object("a", "One"), nested, bare_object("a", "One")]);

    let objects = collect_constructor_types(&tree);
    assert_eq!(objects, vec!["a.One".to_string(), "b.Two".to_string()]);
  }

  #[test]
  fn collects_call_types_deduplicated_in_first_seen_order() {
    let tree = Value::List(vec![
      stateless_call("a.fn", vec![Value::String("x".to_string())]),
      Value::Object {
        module: "m".to_string(),
        name: "N".to_string(),
        args: vec![stateless_call("b.fn", Vec::new())],
        state: Some(Box::new(stateless_call("a.fn", Vec::new()))),
      },
    ]);

    let calls = collect_call_types(&tree);
    assert_eq!(calls, vec!["a.fn".to_string(), "b.fn".to_string()]);
  }

  #[test]
  fn projects_call_metadata() {
    let value = stateless_call(
      "ultralytics.utils.IterableSimpleNamespace",
      vec![Value::String("x".to_string()), Value::Int(1)],
    );

    let mapping = projected_mapping(&value);

    assert_eq!(mapping.get(&yaml_str("$type")), Some(&yaml_str("call")));
    assert_eq!(
      mapping.get(&yaml_str("$func")),
      Some(&yaml_str("ultralytics.utils.IterableSimpleNamespace"))
    );
    assert!(matches!(
      mapping.get(&yaml_str("$args")),
      Some(serde_yaml::Value::Sequence(items)) if items.len() == 2
    ));
  }

  #[test]
  fn projects_call_metadata_with_state() {
    let state = Value::Dict(vec![(Value::String("k".to_string()), Value::String("v".to_string()))]);
    let value = Value::Call {
      func: "ultralytics.utils.IterableSimpleNamespace".to_string(),
      args: vec![Value::String("x".to_string())],
      state: Some(Box::new(state)),
    };

    let mapping = projected_mapping(&value);
    assert!(matches!(
      mapping.get(&yaml_str("$state")),
      Some(serde_yaml::Value::Mapping(_))
    ));
  }

  /// Writes a minimal zip checkpoint: `archive/data.pkl` plus one storage blob holding four f32s.
  fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
    let mut zip = ZipWriter::new(File::create(path)?);
    let options = SimpleFileOptions::default();

    let data_pkl = if unsafe_payload {
      build_unsafe_pickle()
    } else {
      build_safe_pickle()
    };
    zip.start_file("archive/data.pkl", options)?;
    zip.write_all(&data_pkl)?;

    let raw: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
      .iter()
      .flat_map(|value| value.to_le_bytes())
      .collect();
    zip.start_file("archive/data/0", options)?;
    zip.write_all(&raw)?;

    zip.finish()?;
    Ok(())
  }

  /// Hand-assembled protocol-2 pickle: a dict mapping "layer.weight" to a rebuilt 2x2 float
  /// tensor backed by storage "0".
  fn build_safe_pickle() -> Vec<u8> {
    let mut out = vec![0x80, 0x02];

    // Empty dict, then a mark for the key/value pair.
    out.extend_from_slice(b"}(");

    push_binunicode(&mut out, "layer.weight");
    out.extend_from_slice(b"ctorch._utils\n_rebuild_tensor_v2\n");

    // Mark for the argument tuple of the rebuild call.
    out.push(b'(');

    // Persistent-id tuple: ("storage", FloatStorage, "0", "cpu", 4), then the persistent load.
    out.push(b'(');
    push_binunicode(&mut out, "storage");
    out.extend_from_slice(b"ctorch\nFloatStorage\n");
    push_binunicode(&mut out, "0");
    push_binunicode(&mut out, "cpu");
    out.extend_from_slice(&[b'K', 4, b't', b'Q']);

    // Storage offset 0.
    out.extend_from_slice(&[b'K', 0]);

    // Shape (2, 2).
    out.extend_from_slice(&[b'(', b'K', 2, b'K', 2, b't']);

    // Stride (2, 1).
    out.extend_from_slice(&[b'(', b'K', 2, b'K', 1, b't']);

    // False, then None.
    out.extend_from_slice(&[0x89, b'N']);

    // Close the argument tuple and reduce.
    out.extend_from_slice(b"tR");

    // Store the pair into the dict and stop.
    out.extend_from_slice(b"u.");
    out
  }

  /// Pickle that would call `os.system("echo hacked")` if executed by a real unpickler.
  fn build_unsafe_pickle() -> Vec<u8> {
    let mut out = vec![0x80, 0x02];
    out.extend_from_slice(b"cos\nsystem\n");
    out.push(b'(');
    push_binunicode(&mut out, "echo hacked");
    out.extend_from_slice(b"tR.");
    out
  }

  /// Appends a BINUNICODE opcode: `X`, a 4-byte little-endian length, then the UTF-8 bytes.
  fn push_binunicode(out: &mut Vec<u8>, value: &str) {
    let byte_len = value.len() as u32;
    out.push(b'X');
    out.extend_from_slice(&byte_len.to_le_bytes());
    out.extend_from_slice(value.as_bytes());
  }
}