Skip to main content

pt_loader/
lib.rs

1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8pub mod writer;
9
10pub use types::{
11  CheckpointMetadata, CheckpointSecurity, CheckpointTensorMetadata, ConvertError, DType, ExportFormat, ExportOptions,
12  ExportResult, LoadOptions, NumpyEndian, NumpyScalarData, ReconstructSource, Result, StorageRef, TensorArray,
13  TensorData, TensorManifest, TensorRef, Value,
14};
15
16use ndarray::{ArrayD, IxDyn};
17use serde::Deserialize;
18use std::collections::{BTreeMap, HashMap};
19use std::fs;
20use std::fs::File;
21use std::io::Read;
22use std::path::Path;
23use std::time::{SystemTime, UNIX_EPOCH};
24use zip::read::ZipArchive;
25
26use extract::{contiguous_stride, extract_state_dict_tensors, numel};
27use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
28use metadata::{collect_call_types, collect_constructor_types, project_root_metadata};
29use parser::parse_pickle;
30use types::ParsedCheckpoint;
31use writer::{write_metadata_yaml, write_safetensors};
32
/// A checkpoint held in memory: its tensors together with the metadata that
/// describes them.
///
/// Instances are created by [`PtCheckpoint::load`] (parsing a torch zip
/// checkpoint) or [`PtCheckpoint::from_metadata`] (pairing existing metadata
/// with tensors).
#[derive(Debug, Clone)]
pub struct PtCheckpoint {
  // SHA-256 of the source checkpoint; taken from the parsed file on `load`
  // and copied from the supplied metadata on `from_metadata`.
  source_sha256: String,
  // Warnings gathered while building the checkpoint, exposed via `warnings()`.
  warnings: Vec<String>,
  // Metadata record; `export` writes an updated copy of it.
  metadata: CheckpointMetadata,
  // Flat tensor map; `load` keys it by the merged root/tensor name.
  tensors: BTreeMap<String, TensorData>,
  // Tensors grouped by root key; `export` writes one weights file per group
  // when there are several groups or no "root" group.
  tensor_groups: BTreeMap<String, BTreeMap<String, TensorData>>,
}
41
42impl PtCheckpoint {
43  pub fn load(path: impl AsRef<Path>, opts: LoadOptions) -> Result<Self> {
44    let path = path.as_ref();
45    let parsed = parse_checkpoint(path, &opts)?;
46    let metadata = build_checkpoint_metadata(
47      path.display().to_string(),
48      parsed.source_sha256.clone(),
49      &parsed.metadata,
50      &parsed.security,
51      &parsed.tensors,
52      "model.safetensors".to_string(),
53    );
54
55    Ok(Self {
56      source_sha256: parsed.source_sha256,
57      warnings: parsed.warnings,
58      metadata,
59      tensors: parsed.tensors,
60      tensor_groups: parsed.tensor_groups,
61    })
62  }
63
64  pub fn from_metadata(metadata: CheckpointMetadata, source: ReconstructSource) -> Result<Self> {
65    let tensors = match source {
66      ReconstructSource::WeightsFile(path) => read_safetensors_tensors(&path)?,
67      ReconstructSource::StateDict(values) => values,
68    };
69
70    validate_metadata_against_tensors(&metadata, &tensors)?;
71    let mut tensor_groups = BTreeMap::new();
72    tensor_groups.insert("root".to_string(), tensors.clone());
73
74    Ok(Self {
75      source_sha256: metadata.source_sha256.clone(),
76      warnings: Vec::new(),
77      metadata,
78      tensors,
79      tensor_groups,
80    })
81  }
82
83  pub fn metadata(&self) -> &CheckpointMetadata {
84    &self.metadata
85  }
86
87  pub fn source_sha256(&self) -> &str {
88    &self.source_sha256
89  }
90
91  pub fn warnings(&self) -> &[String] {
92    &self.warnings
93  }
94
95  pub fn tensor_count(&self) -> usize {
96    self.tensors.len()
97  }
98
99  #[cfg(feature = "pyo3")]
100  pub(crate) fn raw_tensors(&self) -> &BTreeMap<String, TensorData> {
101    &self.tensors
102  }
103
104  pub fn state_dict(&self) -> Result<BTreeMap<String, TensorArray>> {
105    let mut out = BTreeMap::new();
106    for (name, tensor) in &self.tensors {
107      out.insert(name.clone(), tensor_data_to_array(tensor)?);
108    }
109    Ok(out)
110  }
111
112  pub fn export(&self, out_dir: impl AsRef<Path>, opts: ExportOptions) -> Result<ExportResult> {
113    match opts.format {
114      ExportFormat::Safetensors => {}
115    }
116
117    let out_dir = out_dir.as_ref();
118    fs::create_dir_all(out_dir)?;
119
120    let is_multi_root = self.tensor_groups.len() > 1 || !self.tensor_groups.contains_key("root");
121    let mut weights_path = out_dir.join(&opts.weights_filename);
122    let mut weights_paths = BTreeMap::new();
123    if is_multi_root {
124      for (root_key, tensors) in &self.tensor_groups {
125        let file_name = with_root_key_suffix(&opts.weights_filename, root_key)?;
126        let path = out_dir.join(&file_name);
127        if path.exists() && !opts.overwrite {
128          return Err(ConvertError::InvalidStructure(format!(
129            "output already exists: {}",
130            path.display()
131          )));
132        }
133        write_safetensors(&path, tensors, &self.source_sha256)?;
134        weights_paths.insert(root_key.clone(), path);
135      }
136      if let Some(preferred) = weights_paths
137        .get("model")
138        .or_else(|| weights_paths.get("root"))
139        .or_else(|| weights_paths.values().next())
140      {
141        weights_path = preferred.clone();
142      }
143    } else {
144      if weights_path.exists() && !opts.overwrite {
145        return Err(ConvertError::InvalidStructure(format!(
146          "output already exists: {}",
147          weights_path.display()
148        )));
149      }
150      write_safetensors(&weights_path, &self.tensors, &self.source_sha256)?;
151      weights_paths.insert("root".to_string(), weights_path.clone());
152    }
153
154    let metadata_path = if opts.include_metadata {
155      let metadata_path = out_dir.join(&opts.metadata_filename);
156      if metadata_path.exists() && !opts.overwrite {
157        return Err(ConvertError::InvalidStructure(format!(
158          "output already exists: {}",
159          metadata_path.display()
160        )));
161      }
162
163      let mut metadata = self.metadata.clone();
164      if is_multi_root {
165        metadata.safetensors_file.clear();
166        metadata.safetensors_files = weights_paths
167          .iter()
168          .map(|(key, path)| (key.clone(), file_name_or_path(path)))
169          .collect();
170        metadata.tensors = TensorManifest::ByRoot(
171          self
172            .tensor_groups
173            .iter()
174            .map(|(key, tensors)| (key.clone(), tensor_summaries_for_metadata(tensors)))
175            .collect(),
176        );
177      } else {
178        metadata.safetensors_file = opts.weights_filename.to_string_lossy().into_owned();
179        metadata.safetensors_files.clear();
180        metadata.tensors = TensorManifest::List(tensor_summaries_for_metadata(&self.tensors));
181      }
182      metadata.created_at_unix = now_unix_secs();
183      metadata.tensor_count = self.tensors.len();
184      metadata.total_tensor_bytes = total_tensor_bytes(&self.tensors);
185      write_metadata_yaml(&metadata_path, &metadata)?;
186      Some(metadata_path)
187    } else {
188      None
189    };
190
191    Ok(ExportResult {
192      weights_path,
193      weights_paths,
194      metadata_path,
195      source_sha256: self.source_sha256.clone(),
196      tensor_count: self.tensors.len(),
197      total_tensor_bytes: total_tensor_bytes(&self.tensors),
198    })
199  }
200}
201
202pub(crate) fn parse_checkpoint(path: &Path, opts: &LoadOptions) -> Result<ParsedCheckpoint> {
203  let file = File::open(path)?;
204  let metadata = file.metadata()?;
205  if metadata.len() > opts.max_archive_bytes {
206    return Err(ConvertError::ResourceLimitExceeded(format!(
207      "archive is {} bytes, limit is {}",
208      metadata.len(),
209      opts.max_archive_bytes
210    )));
211  }
212
213  let mut magic = [0u8; 4];
214  let mut fh = File::open(path)?;
215  fh.read_exact(&mut magic)?;
216  if magic != [0x50, 0x4b, 0x03, 0x04] {
217    return Err(ConvertError::UnsupportedFormat(
218      "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
219    ));
220  }
221
222  let source_sha256 = sha256_file(path)?;
223  let mut archive = ZipArchive::new(file)?;
224  let data_pkl_name = find_data_pkl_name(&mut archive)?;
225  let prefix = data_pkl_name
226    .strip_suffix("data.pkl")
227    .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
228    .to_string();
229  let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
230  if pickle_bytes.len() > opts.max_pickle_bytes {
231    return Err(ConvertError::ResourceLimitExceeded(format!(
232      "data.pkl is {} bytes, limit is {}",
233      pickle_bytes.len(),
234      opts.max_pickle_bytes
235    )));
236  }
237
238  let root = parse_pickle(&pickle_bytes, opts)?;
239  let metadata = project_root_metadata(&root);
240  let objects = collect_constructor_types(&root);
241  let calls = collect_call_types(&root);
242  let tensor_ref_groups = extract_state_dict_tensors(&root, opts)?;
243  if tensor_ref_groups.is_empty() {
244    return Err(ConvertError::InvalidStructure(
245      "no tensors found in checkpoint state_dict".to_string(),
246    ));
247  }
248  let tensor_ref_count = tensor_ref_groups.values().map(|group| group.len()).sum::<usize>();
249  if tensor_ref_count > opts.max_tensor_count {
250    return Err(ConvertError::ResourceLimitExceeded(format!(
251      "tensor count {} exceeds limit {}",
252      tensor_ref_count, opts.max_tensor_count
253    )));
254  }
255
256  let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
257  for tensor_refs in tensor_ref_groups.values() {
258    for tensor in tensor_refs.values() {
259      let key = &tensor.storage.key;
260      if storage_blobs.contains_key(key) {
261        continue;
262      }
263      let blob = read_storage_blob(&mut archive, &prefix, key)?;
264      let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
265      if blob.len() < required_bytes {
266        return Err(ConvertError::InvalidStructure(format!(
267          "storage {} has {} bytes, expected at least {}",
268          key,
269          blob.len(),
270          required_bytes
271        )));
272      }
273      storage_blobs.insert(key.clone(), blob);
274    }
275  }
276
277  let mut tensors = BTreeMap::new();
278  let mut tensor_groups = BTreeMap::new();
279  for (root_key, tensor_refs) in tensor_ref_groups {
280    let mut group_tensors = BTreeMap::new();
281    for (name, tensor_ref) in tensor_refs {
282      if opts.strict_contiguous {
283        let expected = contiguous_stride(&tensor_ref.shape);
284        if expected != tensor_ref.stride {
285          return Err(ConvertError::InvalidStructure(format!(
286            "tensor {} has non-contiguous stride {:?}, expected {:?}",
287            name, tensor_ref.stride, expected
288          )));
289        }
290      }
291
292      let elem_size = tensor_ref.storage.dtype.elem_size();
293      let numel = numel(&tensor_ref.shape)?;
294      let start = tensor_ref
295        .offset_elems
296        .checked_mul(elem_size)
297        .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
298      let byte_len = numel
299        .checked_mul(elem_size)
300        .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
301      if byte_len > opts.max_tensor_bytes {
302        return Err(ConvertError::ResourceLimitExceeded(format!(
303          "tensor {} is {} bytes, limit is {}",
304          name, byte_len, opts.max_tensor_bytes
305        )));
306      }
307      let end = start
308        .checked_add(byte_len)
309        .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
310
311      let storage = storage_blobs
312        .get(&tensor_ref.storage.key)
313        .ok_or_else(|| ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key)))?;
314      if end > storage.len() {
315        return Err(ConvertError::InvalidStructure(format!(
316          "tensor {} slice [{}, {}) is out of storage bounds {}",
317          name,
318          start,
319          end,
320          storage.len()
321        )));
322      }
323
324      let raw = storage[start..end].to_vec();
325      let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
326      group_tensors.insert(name.clone(), normalized.clone());
327      let merged_name = merge_root_tensor_name(&root_key, &name);
328      tensors.insert(merged_name, normalized);
329    }
330    tensor_groups.insert(root_key, group_tensors);
331  }
332
333  Ok(ParsedCheckpoint {
334    source_sha256,
335    warnings: Vec::new(),
336    tensors,
337    tensor_groups,
338    metadata,
339    security: CheckpointSecurity { objects, calls },
340  })
341}
342
343fn build_checkpoint_metadata(
344  source_file: String,
345  source_sha256: String,
346  metadata: &serde_yaml::Value,
347  security: &CheckpointSecurity,
348  tensors: &BTreeMap<String, TensorData>,
349  safetensors_file: String,
350) -> CheckpointMetadata {
351  CheckpointMetadata {
352    format_version: 1,
353    source_file,
354    source_sha256,
355    safetensors_file,
356    safetensors_files: BTreeMap::new(),
357    created_at_unix: now_unix_secs(),
358    tensor_count: tensors.len(),
359    total_tensor_bytes: total_tensor_bytes(tensors),
360    metadata: metadata.clone(),
361    security: security.clone(),
362    tensors: TensorManifest::List(tensor_summaries_for_metadata(tensors)),
363  }
364}
365
366fn tensor_summaries_for_metadata(tensors: &BTreeMap<String, TensorData>) -> Vec<CheckpointTensorMetadata> {
367  tensors
368    .iter()
369    .map(|(name, tensor)| CheckpointTensorMetadata {
370      name: name.clone(),
371      dtype: tensor.dtype.as_safetensors().to_string(),
372      shape: tensor.shape.clone(),
373      sha256: sha256_hex(&tensor.bytes),
374    })
375    .collect()
376}
377
378fn total_tensor_bytes(tensors: &BTreeMap<String, TensorData>) -> usize {
379  tensors.values().map(|tensor| tensor.bytes.len()).sum()
380}
381
/// Returns the final component of `path` as a (lossily converted) string, or
/// the whole displayed path when there is no final component.
fn file_name_or_path(path: &Path) -> String {
  match path.file_name() {
    Some(last) => last.to_string_lossy().into_owned(),
    None => path.display().to_string(),
  }
}
388
/// Produces the flat-map name for tensor `name` found under root key `root`.
///
/// The name is returned unchanged when the root is `"root"`, when the name
/// equals the root, or when the name already starts with `"<root>."`;
/// otherwise it becomes `"<root>.<name>"`.
fn merge_root_tensor_name(root: &str, name: &str) -> String {
  let already_qualified = match name.strip_prefix(root) {
    Some(rest) => rest.is_empty() || rest.starts_with('.'),
    None => false,
  };
  if root == "root" || already_qualified {
    return name.to_string();
  }
  format!("{root}.{name}")
}
396
397fn with_root_key_suffix(base: &Path, root_key: &str) -> Result<std::path::PathBuf> {
398  let ext = base
399    .extension()
400    .map(|value| value.to_string_lossy().into_owned())
401    .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no extension".to_string()))?;
402  let stem = base
403    .file_stem()
404    .map(|value| value.to_string_lossy().into_owned())
405    .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no stem".to_string()))?;
406  Ok(std::path::PathBuf::from(format!("{stem}.{root_key}.{ext}")))
407}
408
/// Returns the current time as whole seconds since the Unix epoch, or 0 when
/// the system clock reports a time before the epoch.
fn now_unix_secs() -> u64 {
  match SystemTime::now().duration_since(UNIX_EPOCH) {
    Ok(elapsed) => elapsed.as_secs(),
    Err(_) => 0,
  }
}
415
416fn validate_metadata_against_tensors(
417  metadata: &CheckpointMetadata,
418  tensors: &BTreeMap<String, TensorData>,
419) -> Result<()> {
420  if metadata.tensor_count != tensors.len() {
421    return Err(ConvertError::InvalidStructure(format!(
422      "metadata tensor_count={} does not match loaded tensor count={}",
423      metadata.tensor_count,
424      tensors.len()
425    )));
426  }
427
428  let tensor_bytes = total_tensor_bytes(tensors);
429  if metadata.total_tensor_bytes != tensor_bytes {
430    return Err(ConvertError::InvalidStructure(format!(
431      "metadata total_tensor_bytes={} does not match loaded tensor bytes={}",
432      metadata.total_tensor_bytes, tensor_bytes
433    )));
434  }
435
436  let flat_manifest = match &metadata.tensors {
437    TensorManifest::List(items) => items.iter().map(|item| (item.name.clone(), item)).collect::<Vec<_>>(),
438    TensorManifest::ByRoot(groups) => groups
439      .iter()
440      .flat_map(|(root, items)| {
441        items
442          .iter()
443          .map(move |item| (merge_root_tensor_name(root, &item.name), item))
444      })
445      .collect::<Vec<_>>(),
446  };
447  for (name, item) in flat_manifest {
448    let Some(tensor) = tensors.get(&name) else {
449      return Err(ConvertError::InvalidStructure(format!(
450        "metadata references missing tensor {}",
451        name
452      )));
453    };
454    if item.dtype != tensor.dtype.as_safetensors() {
455      return Err(ConvertError::InvalidStructure(format!(
456        "metadata dtype mismatch for {}: {} != {}",
457        name,
458        item.dtype,
459        tensor.dtype.as_safetensors()
460      )));
461    }
462    if item.shape != tensor.shape {
463      return Err(ConvertError::InvalidStructure(format!(
464        "metadata shape mismatch for {}",
465        name
466      )));
467    }
468    if item.sha256 != sha256_hex(&tensor.bytes) {
469      return Err(ConvertError::InvalidStructure(format!(
470        "metadata sha256 mismatch for {}",
471        name
472      )));
473    }
474  }
475
476  Ok(())
477}
478
/// One tensor entry of a safetensors JSON header, as deserialized by
/// `read_safetensors_tensors`.
#[derive(Debug, Deserialize)]
struct SafetensorHeaderEntry {
  // Dtype name; resolved through `DType::from_safetensors`.
  dtype: String,
  // Tensor dimensions; their product times the element size must equal the
  // byte range length.
  shape: Vec<usize>,
  // `[start, end)` byte range, relative to the data section that follows the
  // header.
  data_offsets: [usize; 2],
}
485
486fn read_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, TensorData>> {
487  let file_bytes = fs::read(path)?;
488  if file_bytes.len() < 8 {
489    return Err(ConvertError::InvalidStructure(
490      "safetensors file is too short".to_string(),
491    ));
492  }
493
494  let header_len = u64::from_le_bytes(file_bytes[0..8].try_into().expect("8-byte header"));
495  let header_len = header_len as usize;
496  if file_bytes.len() < 8 + header_len {
497    return Err(ConvertError::InvalidStructure(
498      "safetensors header is truncated".to_string(),
499    ));
500  }
501
502  let header_bytes = &file_bytes[8..8 + header_len];
503  let data = &file_bytes[8 + header_len..];
504  let header: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(header_bytes)?;
505
506  let mut tensors = BTreeMap::new();
507  for (name, value) in header {
508    if name == "__metadata__" {
509      continue;
510    }
511    let entry: SafetensorHeaderEntry = serde_json::from_value(value)?;
512    let Some(dtype) = DType::from_safetensors(&entry.dtype) else {
513      return Err(ConvertError::InvalidStructure(format!(
514        "unsupported safetensors dtype {}",
515        entry.dtype
516      )));
517    };
518
519    let start = entry.data_offsets[0];
520    let end = entry.data_offsets[1];
521    if end < start || end > data.len() {
522      return Err(ConvertError::InvalidStructure(format!(
523        "invalid data_offsets for tensor {}",
524        name
525      )));
526    }
527
528    let expected_size = numel(&entry.shape)?
529      .checked_mul(dtype.elem_size())
530      .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
531    if end - start != expected_size {
532      return Err(ConvertError::InvalidStructure(format!(
533        "tensor {} bytes mismatch: {} != {}",
534        name,
535        end - start,
536        expected_size
537      )));
538    }
539
540    tensors.insert(
541      name,
542      TensorData {
543        dtype,
544        shape: entry.shape,
545        bytes: data[start..end].to_vec(),
546      },
547    );
548  }
549
550  if tensors.is_empty() {
551    return Err(ConvertError::InvalidStructure(
552      "no tensors found in safetensors file".to_string(),
553    ));
554  }
555
556  Ok(tensors)
557}
558
559fn tensor_data_to_array(tensor: &TensorData) -> Result<TensorArray> {
560  let shape = IxDyn(&tensor.shape);
561  match tensor.dtype {
562    DType::F16 | DType::BF16 => Err(ConvertError::InvalidStructure(
563      "f16/bf16 should be normalized to f32 before state_dict()".to_string(),
564    )),
565    DType::F32 => {
566      let values = bytes_to_vec::<4, f32>(&tensor.bytes, f32::from_le_bytes)?;
567      Ok(TensorArray::F32(ArrayD::from_shape_vec(shape, values)?))
568    }
569    DType::F64 => {
570      let values = bytes_to_vec::<8, f64>(&tensor.bytes, f64::from_le_bytes)?;
571      Ok(TensorArray::F64(ArrayD::from_shape_vec(shape, values)?))
572    }
573    DType::I8 => {
574      let values = tensor.bytes.iter().map(|v| *v as i8).collect::<Vec<_>>();
575      Ok(TensorArray::I8(ArrayD::from_shape_vec(shape, values)?))
576    }
577    DType::I16 => {
578      let values = bytes_to_vec::<2, i16>(&tensor.bytes, i16::from_le_bytes)?;
579      Ok(TensorArray::I16(ArrayD::from_shape_vec(shape, values)?))
580    }
581    DType::I32 => {
582      let values = bytes_to_vec::<4, i32>(&tensor.bytes, i32::from_le_bytes)?;
583      Ok(TensorArray::I32(ArrayD::from_shape_vec(shape, values)?))
584    }
585    DType::I64 => {
586      let values = bytes_to_vec::<8, i64>(&tensor.bytes, i64::from_le_bytes)?;
587      Ok(TensorArray::I64(ArrayD::from_shape_vec(shape, values)?))
588    }
589    DType::U8 => Ok(TensorArray::U8(ArrayD::from_shape_vec(shape, tensor.bytes.clone())?)),
590    DType::Bool => {
591      let values = tensor.bytes.iter().map(|v| *v != 0).collect::<Vec<_>>();
592      Ok(TensorArray::Bool(ArrayD::from_shape_vec(shape, values)?))
593    }
594  }
595}
596
597fn bytes_to_vec<const N: usize, T>(bytes: &[u8], f: impl Fn([u8; N]) -> T) -> Result<Vec<T>> {
598  if bytes.len() % N != 0 {
599    return Err(ConvertError::InvalidStructure(format!(
600      "tensor bytes are not divisible by {}",
601      N
602    )));
603  }
604
605  Ok(
606    bytes
607      .chunks_exact(N)
608      .map(|chunk| {
609        let mut arr = [0u8; N];
610        arr.copy_from_slice(chunk);
611        f(arr)
612      })
613      .collect(),
614  )
615}
616
617fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
618  match dtype {
619    DType::F16 => Ok(TensorData {
620      dtype: DType::F32,
621      shape,
622      bytes: f16_bytes_to_f32_bytes(&bytes)?,
623    }),
624    DType::BF16 => Ok(TensorData {
625      dtype: DType::F32,
626      shape,
627      bytes: bf16_bytes_to_f32_bytes(&bytes)?,
628    }),
629    _ => Ok(TensorData { dtype, shape, bytes }),
630  }
631}
632
633fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
634  if input.len() % 2 != 0 {
635    return Err(ConvertError::InvalidStructure(
636      "f16 tensor bytes must be even-length".to_string(),
637    ));
638  }
639  let mut out = Vec::with_capacity(input.len() * 2);
640  for chunk in input.chunks_exact(2) {
641    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
642    let value = f16_bits_to_f32(bits);
643    out.extend_from_slice(&value.to_le_bytes());
644  }
645  Ok(out)
646}
647
648fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
649  if input.len() % 2 != 0 {
650    return Err(ConvertError::InvalidStructure(
651      "bf16 tensor bytes must be even-length".to_string(),
652    ));
653  }
654  let mut out = Vec::with_capacity(input.len() * 2);
655  for chunk in input.chunks_exact(2) {
656    let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
657    let value = f32::from_bits((bits as u32) << 16);
658    out.extend_from_slice(&value.to_le_bytes());
659  }
660  Ok(out)
661}
662
/// Expands an IEEE 754 half-precision bit pattern into the `f32` with the
/// same value, covering zeros, subnormals, normals, infinities and NaNs.
fn f16_bits_to_f32(bits: u16) -> f32 {
  let half = u32::from(bits);
  let sign_bit = (half >> 15) << 31;
  let exponent = (half >> 10) & 0x1f;
  let fraction = half & 0x03ff;

  let magnitude = match (exponent, fraction) {
    // Signed zero: only the sign bit survives.
    (0, 0) => 0,
    // Subnormal half: shift the leading one up to bit 10 (the implicit-one
    // position), lowering the exponent by one per shift from the -14 base.
    (0, _) => {
      let shift = fraction.leading_zeros() - 21;
      let normalized = (fraction << shift) & 0x03ff;
      let biased = (127 - 14) - shift;
      (biased << 23) | (normalized << 13)
    }
    // Infinity or NaN: all-ones exponent, fraction carried over.
    (0x1f, _) => (0xff << 23) | (fraction << 13),
    // Normal number: re-bias the exponent from 15 to 127.
    _ => ((exponent + 127 - 15) << 23) | (fraction << 13),
  };

  f32::from_bits(sign_bit | magnitude)
}
691
// Unit tests: an end-to-end load/export round trip on a hand-built zip
// checkpoint, plus checks of the metadata projection and type collection helpers.
#[cfg(test)]
mod tests {
  use super::*;
  use crate::metadata::{collect_call_types, collect_constructor_types, project_value_for_metadata};
  use crate::types::Value;
  use std::io::Write;
  use tempfile::tempdir;
  use zip::write::SimpleFileOptions;
  use zip::ZipWriter;

  /// Loads the safe fixture checkpoint, exports it and checks the written
  /// weights file and the contents of the YAML metadata.
  #[test]
  fn converts_simple_tensor_checkpoint() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("weights.pt");
    write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");

    let out_dir = tmp.path().join("export");
    let checkpoint = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect("checkpoint load should work");
    let result = checkpoint
      .export(&out_dir, ExportOptions::new(ExportFormat::Safetensors, Some(&pt_path)))
      .expect("export should work");

    assert!(result.weights_path.exists());
    assert!(result.metadata_path.as_ref().expect("metadata path").exists());
    assert_eq!(result.tensor_count, 1);

    let yaml = fs::read_to_string(result.metadata_path.expect("metadata path")).expect("yaml readable");
    assert!(yaml.contains("layer.weight"));
    // Accept both the unquoted and the single-quoted YAML scalar form.
    assert!(yaml.contains("dtype: F32") || yaml.contains("dtype: 'F32'"));
    assert!(yaml.contains("security:"));
    assert!(yaml.contains("objects: []"));
    assert!(yaml.contains("calls: []"));
  }

  /// Loading a checkpoint whose pickle reduces `os.system` must fail.
  #[test]
  fn rejects_unsafe_global_reduce() {
    let tmp = tempdir().expect("tmp dir");
    let pt_path = tmp.path().join("unsafe.pt");
    write_fixture_checkpoint(&pt_path, true).expect("fixture checkpoint");

    let err = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect_err("unsafe pickle should fail");
    let msg = err.to_string();
    // NOTE(review): this message is not produced anywhere in this file;
    // presumably it comes from the parser/extract modules — confirm.
    assert!(msg.contains("could not find a tensor state_dict"));
  }

  /// An object projects to a mapping with `$type`, `$class` and `$args`, with
  /// its state dict entries placed at the top level of that mapping.
  #[test]
  fn projects_object_metadata_with_type_args_and_flattened_state() {
    let value = Value::Object {
      module: "ultralytics.nn.tasks".to_string(),
      name: "DetectionModel".to_string(),
      args: vec![Value::String("arg0".to_string()), Value::Int(42)],
      state: Some(Box::new(Value::Dict(vec![(
        Value::String("training".to_string()),
        Value::Bool(false),
      )]))),
    };

    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let type_key = serde_yaml::Value::String("$type".to_string());
    let class_key = serde_yaml::Value::String("$class".to_string());
    let args_key = serde_yaml::Value::String("$args".to_string());
    let training_key = serde_yaml::Value::String("training".to_string());

    assert_eq!(
      mapping.get(&type_key),
      Some(&serde_yaml::Value::String("object".to_string()))
    );
    assert_eq!(
      mapping.get(&class_key),
      Some(&serde_yaml::Value::String(
        "ultralytics.nn.tasks.DetectionModel".to_string()
      ))
    );
    assert!(mapping.get(&args_key).is_some());
    assert_eq!(mapping.get(&training_key), Some(&serde_yaml::Value::Bool(false)));
  }

  /// An object without constructor arguments projects without an `$args` key.
  #[test]
  fn omits_empty_object_args() {
    let value = Value::Object {
      module: "a".to_string(),
      name: "B".to_string(),
      args: Vec::new(),
      state: None,
    };
    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let args_key = serde_yaml::Value::String("$args".to_string());
    assert!(!mapping.contains_key(&args_key));
  }

  /// Constructor types are reported once each, in the order first encountered,
  /// including objects nested inside dict values.
  #[test]
  fn collects_constructor_types_deduplicated_in_first_seen_order() {
    let tree = Value::List(vec![
      Value::Object {
        module: "a".to_string(),
        name: "One".to_string(),
        args: Vec::new(),
        state: None,
      },
      Value::Dict(vec![(
        Value::String("nested".to_string()),
        Value::Object {
          module: "b".to_string(),
          name: "Two".to_string(),
          args: Vec::new(),
          state: None,
        },
      )]),
      // Duplicate of the first object; must not appear twice in the result.
      Value::Object {
        module: "a".to_string(),
        name: "One".to_string(),
        args: Vec::new(),
        state: None,
      },
    ]);

    let objects = collect_constructor_types(&tree);
    assert_eq!(objects, vec!["a.One".to_string(), "b.Two".to_string()]);
  }

  /// Call targets are reported once each, in the order first encountered,
  /// including calls found in an object's args and state.
  #[test]
  fn collects_call_types_deduplicated_in_first_seen_order() {
    let tree = Value::List(vec![
      Value::Call {
        func: "a.fn".to_string(),
        args: vec![Value::String("x".to_string())],
        state: None,
      },
      Value::Object {
        module: "m".to_string(),
        name: "N".to_string(),
        args: vec![Value::Call {
          func: "b.fn".to_string(),
          args: Vec::new(),
          state: None,
        }],
        // Duplicate of the first call; must not appear twice in the result.
        state: Some(Box::new(Value::Call {
          func: "a.fn".to_string(),
          args: Vec::new(),
          state: None,
        })),
      },
    ]);

    let calls = collect_call_types(&tree);
    assert_eq!(calls, vec!["a.fn".to_string(), "b.fn".to_string()]);
  }

  /// A call projects to a mapping with `$type: call`, `$func` and an `$args`
  /// sequence holding every argument.
  #[test]
  fn projects_call_metadata() {
    let value = Value::Call {
      func: "ultralytics.utils.IterableSimpleNamespace".to_string(),
      args: vec![Value::String("x".to_string()), Value::Int(1)],
      state: None,
    };

    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let type_key = serde_yaml::Value::String("$type".to_string());
    let func_key = serde_yaml::Value::String("$func".to_string());
    let args_key = serde_yaml::Value::String("$args".to_string());
    assert_eq!(
      mapping.get(&type_key),
      Some(&serde_yaml::Value::String("call".to_string()))
    );
    assert_eq!(
      mapping.get(&func_key),
      Some(&serde_yaml::Value::String(
        "ultralytics.utils.IterableSimpleNamespace".to_string()
      ))
    );
    assert!(matches!(
      mapping.get(&args_key),
      Some(serde_yaml::Value::Sequence(items)) if items.len() == 2
    ));
  }

  /// A call carrying state projects that state as a mapping under `$state`.
  #[test]
  fn projects_call_metadata_with_state() {
    let value = Value::Call {
      func: "ultralytics.utils.IterableSimpleNamespace".to_string(),
      args: vec![Value::String("x".to_string())],
      state: Some(Box::new(Value::Dict(vec![(
        Value::String("k".to_string()),
        Value::String("v".to_string()),
      )]))),
    };

    let projected = project_value_for_metadata(&value);
    let mapping = match projected {
      serde_yaml::Value::Mapping(map) => map,
      other => panic!("expected mapping, got {:?}", other),
    };

    let state_key = serde_yaml::Value::String("$state".to_string());
    assert!(matches!(mapping.get(&state_key), Some(serde_yaml::Value::Mapping(_))));
  }

  /// Writes a minimal zip checkpoint to `path`: `archive/data.pkl` (safe or
  /// unsafe pickle, per `unsafe_payload`) and `archive/data/0` holding four
  /// little-endian f32 values.
  fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
    let file = File::create(path)?;
    let mut zip = ZipWriter::new(file);
    let options = SimpleFileOptions::default();

    let data_pkl = if unsafe_payload {
      build_unsafe_pickle()
    } else {
      build_safe_pickle()
    };

    zip.start_file("archive/data.pkl", options)?;
    zip.write_all(&data_pkl)?;

    let floats = [1.0f32, 2.0, 3.0, 4.0];
    let mut raw = Vec::new();
    for value in floats {
      raw.extend_from_slice(&value.to_le_bytes());
    }

    zip.start_file("archive/data/0", options)?;
    zip.write_all(&raw)?;
    zip.finish()?;
    Ok(())
  }

  /// Hand-assembles a protocol-2 pickle of a dict with one entry,
  /// `"layer.weight"`, built by `torch._utils._rebuild_tensor_v2` over
  /// storage `"0"`.
  fn build_safe_pickle() -> Vec<u8> {
    let mut out = Vec::new();
    // PROTO 2
    out.extend_from_slice(&[0x80, 0x02]);

    // EMPTY_DICT, then MARK opening the key/value run consumed by SETITEMS.
    out.push(b'}');
    out.push(b'(');

    push_binunicode(&mut out, "layer.weight");
    // GLOBAL torch._utils._rebuild_tensor_v2
    out.extend_from_slice(b"ctorch._utils\n_rebuild_tensor_v2\n");

    // MARK opening the argument tuple of the reduce call.
    out.push(b'(');

    // Persistent id tuple ("storage", torch.FloatStorage, "0", "cpu", 4),
    // closed by TUPLE and resolved with BINPERSID.
    out.push(b'(');
    push_binunicode(&mut out, "storage");
    out.extend_from_slice(b"ctorch\nFloatStorage\n");
    push_binunicode(&mut out, "0");
    push_binunicode(&mut out, "cpu");
    out.push(b'K');
    out.push(4);
    out.push(b't');
    out.push(b'Q');

    // BININT1 0 — presumably the storage offset argument.
    out.push(b'K');
    out.push(0);

    // Tuple (2, 2) — presumably the tensor shape.
    out.push(b'(');
    out.push(b'K');
    out.push(2);
    out.push(b'K');
    out.push(2);
    out.push(b't');

    // Tuple (2, 1) — presumably the stride.
    out.push(b'(');
    out.push(b'K');
    out.push(2);
    out.push(b'K');
    out.push(1);
    out.push(b't');

    // NEWFALSE and NONE: the two trailing arguments.
    out.push(0x89);
    out.push(b'N');

    // TUPLE closes the arguments; REDUCE calls the global with them.
    out.push(b't');
    out.push(b'R');

    // SETITEMS stores the pair into the dict; STOP ends the pickle.
    out.push(b'u');
    out.push(b'.');
    out
  }

  /// Hand-assembles a protocol-2 pickle that reduces `os.system` with the
  /// argument `"echo hacked"`.
  fn build_unsafe_pickle() -> Vec<u8> {
    let mut out = Vec::new();
    // PROTO 2, then GLOBAL os.system.
    out.extend_from_slice(&[0x80, 0x02]);
    out.extend_from_slice(b"cos\nsystem\n");
    // MARK, one string argument, TUPLE, REDUCE, STOP.
    out.push(b'(');
    push_binunicode(&mut out, "echo hacked");
    out.push(b't');
    out.push(b'R');
    out.push(b'.');
    out
  }

  /// Appends a BINUNICODE opcode: `X`, the 4-byte little-endian byte length,
  /// then the UTF-8 bytes of `value`.
  fn push_binunicode(out: &mut Vec<u8>, value: &str) {
    out.push(b'X');
    out.extend_from_slice(&(value.len() as u32).to_le_bytes());
    out.extend_from_slice(value.as_bytes());
  }
}