1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8pub mod writer;
9
10pub use types::{
11 CheckpointMetadata, CheckpointSecurity, CheckpointTensorMetadata, ConvertError, DType, ExportFormat, ExportOptions,
12 ExportResult, LoadOptions, NumpyEndian, NumpyScalarData, ReconstructSource, Result, StorageRef, TensorArray,
13 TensorData, TensorManifest, TensorRef, Value,
14};
15
16use ndarray::{ArrayD, IxDyn};
17use serde::Deserialize;
18use std::collections::{BTreeMap, HashMap};
19use std::fs;
20use std::fs::File;
21use std::io::Read;
22use std::path::Path;
23use std::time::{SystemTime, UNIX_EPOCH};
24use zip::read::ZipArchive;
25
26use extract::{contiguous_stride, extract_state_dict_tensors, numel};
27use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
28use metadata::{collect_call_types, collect_constructor_types, project_root_metadata};
29use parser::parse_pickle;
30use types::ParsedCheckpoint;
31use writer::{write_metadata_yaml, write_safetensors};
32
/// An in-memory checkpoint: tensor data plus the metadata describing it.
///
/// Built either from a torch zip checkpoint ([`PtCheckpoint::load`]) or from previously
/// exported metadata plus its tensors ([`PtCheckpoint::from_metadata`]).
#[derive(Debug, Clone)]
pub struct PtCheckpoint {
    /// SHA-256 of the source: the checkpoint file hash for `load`, or the value copied
    /// from the metadata for `from_metadata`.
    source_sha256: String,
    /// Non-fatal issues collected while loading.
    warnings: Vec<String>,
    /// Metadata describing the checkpoint and its tensors.
    metadata: CheckpointMetadata,
    /// All tensors in one flat map, keyed by root-qualified name (`merge_root_tensor_name`).
    tensors: BTreeMap<String, TensorData>,
    /// Tensors grouped by state_dict root key; names inside a group are not root-prefixed.
    tensor_groups: BTreeMap<String, BTreeMap<String, TensorData>>,
}
41
42impl PtCheckpoint {
43 pub fn load(path: impl AsRef<Path>, opts: LoadOptions) -> Result<Self> {
44 let path = path.as_ref();
45 let parsed = parse_checkpoint(path, &opts)?;
46 let metadata = build_checkpoint_metadata(
47 path.display().to_string(),
48 parsed.source_sha256.clone(),
49 &parsed.metadata,
50 &parsed.security,
51 &parsed.tensors,
52 "model.safetensors".to_string(),
53 );
54
55 Ok(Self {
56 source_sha256: parsed.source_sha256,
57 warnings: parsed.warnings,
58 metadata,
59 tensors: parsed.tensors,
60 tensor_groups: parsed.tensor_groups,
61 })
62 }
63
64 pub fn from_metadata(metadata: CheckpointMetadata, source: ReconstructSource) -> Result<Self> {
65 let tensors = match source {
66 ReconstructSource::WeightsFile(path) => read_safetensors_tensors(&path)?,
67 ReconstructSource::StateDict(values) => values,
68 };
69
70 validate_metadata_against_tensors(&metadata, &tensors)?;
71 let mut tensor_groups = BTreeMap::new();
72 tensor_groups.insert("root".to_string(), tensors.clone());
73
74 Ok(Self {
75 source_sha256: metadata.source_sha256.clone(),
76 warnings: Vec::new(),
77 metadata,
78 tensors,
79 tensor_groups,
80 })
81 }
82
83 pub fn metadata(&self) -> &CheckpointMetadata {
84 &self.metadata
85 }
86
87 pub fn source_sha256(&self) -> &str {
88 &self.source_sha256
89 }
90
91 pub fn warnings(&self) -> &[String] {
92 &self.warnings
93 }
94
95 pub fn tensor_count(&self) -> usize {
96 self.tensors.len()
97 }
98
99 #[cfg(feature = "pyo3")]
100 pub(crate) fn raw_tensors(&self) -> &BTreeMap<String, TensorData> {
101 &self.tensors
102 }
103
104 pub fn state_dict(&self) -> Result<BTreeMap<String, TensorArray>> {
105 let mut out = BTreeMap::new();
106 for (name, tensor) in &self.tensors {
107 out.insert(name.clone(), tensor_data_to_array(tensor)?);
108 }
109 Ok(out)
110 }
111
112 pub fn export(&self, out_dir: impl AsRef<Path>, opts: ExportOptions) -> Result<ExportResult> {
113 match opts.format {
114 ExportFormat::Safetensors => {}
115 }
116
117 let out_dir = out_dir.as_ref();
118 fs::create_dir_all(out_dir)?;
119
120 let is_multi_root = self.tensor_groups.len() > 1 || !self.tensor_groups.contains_key("root");
121 let mut weights_path = out_dir.join(&opts.weights_filename);
122 let mut weights_paths = BTreeMap::new();
123 if is_multi_root {
124 for (root_key, tensors) in &self.tensor_groups {
125 let file_name = with_root_key_suffix(&opts.weights_filename, root_key)?;
126 let path = out_dir.join(&file_name);
127 if path.exists() && !opts.overwrite {
128 return Err(ConvertError::InvalidStructure(format!(
129 "output already exists: {}",
130 path.display()
131 )));
132 }
133 write_safetensors(&path, tensors, &self.source_sha256)?;
134 weights_paths.insert(root_key.clone(), path);
135 }
136 if let Some(preferred) = weights_paths
137 .get("model")
138 .or_else(|| weights_paths.get("root"))
139 .or_else(|| weights_paths.values().next())
140 {
141 weights_path = preferred.clone();
142 }
143 } else {
144 if weights_path.exists() && !opts.overwrite {
145 return Err(ConvertError::InvalidStructure(format!(
146 "output already exists: {}",
147 weights_path.display()
148 )));
149 }
150 write_safetensors(&weights_path, &self.tensors, &self.source_sha256)?;
151 weights_paths.insert("root".to_string(), weights_path.clone());
152 }
153
154 let metadata_path = if opts.include_metadata {
155 let metadata_path = out_dir.join(&opts.metadata_filename);
156 if metadata_path.exists() && !opts.overwrite {
157 return Err(ConvertError::InvalidStructure(format!(
158 "output already exists: {}",
159 metadata_path.display()
160 )));
161 }
162
163 let mut metadata = self.metadata.clone();
164 if is_multi_root {
165 metadata.safetensors_file.clear();
166 metadata.safetensors_files = weights_paths
167 .iter()
168 .map(|(key, path)| (key.clone(), file_name_or_path(path)))
169 .collect();
170 metadata.tensors = TensorManifest::ByRoot(
171 self
172 .tensor_groups
173 .iter()
174 .map(|(key, tensors)| (key.clone(), tensor_summaries_for_metadata(tensors)))
175 .collect(),
176 );
177 } else {
178 metadata.safetensors_file = opts.weights_filename.to_string_lossy().into_owned();
179 metadata.safetensors_files.clear();
180 metadata.tensors = TensorManifest::List(tensor_summaries_for_metadata(&self.tensors));
181 }
182 metadata.created_at_unix = now_unix_secs();
183 metadata.tensor_count = self.tensors.len();
184 metadata.total_tensor_bytes = total_tensor_bytes(&self.tensors);
185 write_metadata_yaml(&metadata_path, &metadata)?;
186 Some(metadata_path)
187 } else {
188 None
189 };
190
191 Ok(ExportResult {
192 weights_path,
193 weights_paths,
194 metadata_path,
195 source_sha256: self.source_sha256.clone(),
196 tensor_count: self.tensors.len(),
197 total_tensor_bytes: total_tensor_bytes(&self.tensors),
198 })
199 }
200}
201
202pub(crate) fn parse_checkpoint(path: &Path, opts: &LoadOptions) -> Result<ParsedCheckpoint> {
203 let file = File::open(path)?;
204 let metadata = file.metadata()?;
205 if metadata.len() > opts.max_archive_bytes {
206 return Err(ConvertError::ResourceLimitExceeded(format!(
207 "archive is {} bytes, limit is {}",
208 metadata.len(),
209 opts.max_archive_bytes
210 )));
211 }
212
213 let mut magic = [0u8; 4];
214 let mut fh = File::open(path)?;
215 fh.read_exact(&mut magic)?;
216 if magic != [0x50, 0x4b, 0x03, 0x04] {
217 return Err(ConvertError::UnsupportedFormat(
218 "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
219 ));
220 }
221
222 let source_sha256 = sha256_file(path)?;
223 let mut archive = ZipArchive::new(file)?;
224 let data_pkl_name = find_data_pkl_name(&mut archive)?;
225 let prefix = data_pkl_name
226 .strip_suffix("data.pkl")
227 .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
228 .to_string();
229 let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
230 if pickle_bytes.len() > opts.max_pickle_bytes {
231 return Err(ConvertError::ResourceLimitExceeded(format!(
232 "data.pkl is {} bytes, limit is {}",
233 pickle_bytes.len(),
234 opts.max_pickle_bytes
235 )));
236 }
237
238 let root = parse_pickle(&pickle_bytes, opts)?;
239 let metadata = project_root_metadata(&root);
240 let objects = collect_constructor_types(&root);
241 let calls = collect_call_types(&root);
242 let tensor_ref_groups = extract_state_dict_tensors(&root, opts)?;
243 if tensor_ref_groups.is_empty() {
244 return Err(ConvertError::InvalidStructure(
245 "no tensors found in checkpoint state_dict".to_string(),
246 ));
247 }
248 let tensor_ref_count = tensor_ref_groups.values().map(|group| group.len()).sum::<usize>();
249 if tensor_ref_count > opts.max_tensor_count {
250 return Err(ConvertError::ResourceLimitExceeded(format!(
251 "tensor count {} exceeds limit {}",
252 tensor_ref_count, opts.max_tensor_count
253 )));
254 }
255
256 let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
257 for tensor_refs in tensor_ref_groups.values() {
258 for tensor in tensor_refs.values() {
259 let key = &tensor.storage.key;
260 if storage_blobs.contains_key(key) {
261 continue;
262 }
263 let blob = read_storage_blob(&mut archive, &prefix, key)?;
264 let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
265 if blob.len() < required_bytes {
266 return Err(ConvertError::InvalidStructure(format!(
267 "storage {} has {} bytes, expected at least {}",
268 key,
269 blob.len(),
270 required_bytes
271 )));
272 }
273 storage_blobs.insert(key.clone(), blob);
274 }
275 }
276
277 let mut tensors = BTreeMap::new();
278 let mut tensor_groups = BTreeMap::new();
279 for (root_key, tensor_refs) in tensor_ref_groups {
280 let mut group_tensors = BTreeMap::new();
281 for (name, tensor_ref) in tensor_refs {
282 if opts.strict_contiguous {
283 let expected = contiguous_stride(&tensor_ref.shape);
284 if expected != tensor_ref.stride {
285 return Err(ConvertError::InvalidStructure(format!(
286 "tensor {} has non-contiguous stride {:?}, expected {:?}",
287 name, tensor_ref.stride, expected
288 )));
289 }
290 }
291
292 let elem_size = tensor_ref.storage.dtype.elem_size();
293 let numel = numel(&tensor_ref.shape)?;
294 let start = tensor_ref
295 .offset_elems
296 .checked_mul(elem_size)
297 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
298 let byte_len = numel
299 .checked_mul(elem_size)
300 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
301 if byte_len > opts.max_tensor_bytes {
302 return Err(ConvertError::ResourceLimitExceeded(format!(
303 "tensor {} is {} bytes, limit is {}",
304 name, byte_len, opts.max_tensor_bytes
305 )));
306 }
307 let end = start
308 .checked_add(byte_len)
309 .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
310
311 let storage = storage_blobs
312 .get(&tensor_ref.storage.key)
313 .ok_or_else(|| ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key)))?;
314 if end > storage.len() {
315 return Err(ConvertError::InvalidStructure(format!(
316 "tensor {} slice [{}, {}) is out of storage bounds {}",
317 name,
318 start,
319 end,
320 storage.len()
321 )));
322 }
323
324 let raw = storage[start..end].to_vec();
325 let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
326 group_tensors.insert(name.clone(), normalized.clone());
327 let merged_name = merge_root_tensor_name(&root_key, &name);
328 tensors.insert(merged_name, normalized);
329 }
330 tensor_groups.insert(root_key, group_tensors);
331 }
332
333 Ok(ParsedCheckpoint {
334 source_sha256,
335 warnings: Vec::new(),
336 tensors,
337 tensor_groups,
338 metadata,
339 security: CheckpointSecurity { objects, calls },
340 })
341}
342
343fn build_checkpoint_metadata(
344 source_file: String,
345 source_sha256: String,
346 metadata: &serde_yaml::Value,
347 security: &CheckpointSecurity,
348 tensors: &BTreeMap<String, TensorData>,
349 safetensors_file: String,
350) -> CheckpointMetadata {
351 CheckpointMetadata {
352 format_version: 1,
353 source_file,
354 source_sha256,
355 safetensors_file,
356 safetensors_files: BTreeMap::new(),
357 created_at_unix: now_unix_secs(),
358 tensor_count: tensors.len(),
359 total_tensor_bytes: total_tensor_bytes(tensors),
360 metadata: metadata.clone(),
361 security: security.clone(),
362 tensors: TensorManifest::List(tensor_summaries_for_metadata(tensors)),
363 }
364}
365
366fn tensor_summaries_for_metadata(tensors: &BTreeMap<String, TensorData>) -> Vec<CheckpointTensorMetadata> {
367 tensors
368 .iter()
369 .map(|(name, tensor)| CheckpointTensorMetadata {
370 name: name.clone(),
371 dtype: tensor.dtype.as_safetensors().to_string(),
372 shape: tensor.shape.clone(),
373 sha256: sha256_hex(&tensor.bytes),
374 })
375 .collect()
376}
377
378fn total_tensor_bytes(tensors: &BTreeMap<String, TensorData>) -> usize {
379 tensors.values().map(|tensor| tensor.bytes.len()).sum()
380}
381
/// Returns the final component of `path`, or the whole path when it has no file name
/// (e.g. `/` or a path ending in `..`).
fn file_name_or_path(path: &Path) -> String {
    match path.file_name() {
        Some(name) => name.to_string_lossy().into_owned(),
        None => path.display().to_string(),
    }
}
388
/// Qualifies a group-local tensor `name` with its state_dict `root` key.
///
/// The name is kept unchanged for the `"root"` group, or when it already equals `root` or
/// starts with `"{root}."`. Otherwise it becomes `"{root}.{name}"`.
fn merge_root_tensor_name(root: &str, name: &str) -> String {
    let already_qualified = match name.strip_prefix(root) {
        Some(rest) => rest.is_empty() || rest.starts_with('.'),
        None => false,
    };
    if root == "root" || already_qualified {
        return name.to_string();
    }
    format!("{root}.{name}")
}
396
397fn with_root_key_suffix(base: &Path, root_key: &str) -> Result<std::path::PathBuf> {
398 let ext = base
399 .extension()
400 .map(|value| value.to_string_lossy().into_owned())
401 .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no extension".to_string()))?;
402 let stem = base
403 .file_stem()
404 .map(|value| value.to_string_lossy().into_owned())
405 .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no stem".to_string()))?;
406 Ok(std::path::PathBuf::from(format!("{stem}.{root_key}.{ext}")))
407}
408
/// Current wall-clock time in whole seconds since the Unix epoch; 0 if the clock reads
/// earlier than the epoch.
fn now_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
415
416fn validate_metadata_against_tensors(
417 metadata: &CheckpointMetadata,
418 tensors: &BTreeMap<String, TensorData>,
419) -> Result<()> {
420 if metadata.tensor_count != tensors.len() {
421 return Err(ConvertError::InvalidStructure(format!(
422 "metadata tensor_count={} does not match loaded tensor count={}",
423 metadata.tensor_count,
424 tensors.len()
425 )));
426 }
427
428 let tensor_bytes = total_tensor_bytes(tensors);
429 if metadata.total_tensor_bytes != tensor_bytes {
430 return Err(ConvertError::InvalidStructure(format!(
431 "metadata total_tensor_bytes={} does not match loaded tensor bytes={}",
432 metadata.total_tensor_bytes, tensor_bytes
433 )));
434 }
435
436 let flat_manifest = match &metadata.tensors {
437 TensorManifest::List(items) => items.iter().map(|item| (item.name.clone(), item)).collect::<Vec<_>>(),
438 TensorManifest::ByRoot(groups) => groups
439 .iter()
440 .flat_map(|(root, items)| {
441 items
442 .iter()
443 .map(move |item| (merge_root_tensor_name(root, &item.name), item))
444 })
445 .collect::<Vec<_>>(),
446 };
447 for (name, item) in flat_manifest {
448 let Some(tensor) = tensors.get(&name) else {
449 return Err(ConvertError::InvalidStructure(format!(
450 "metadata references missing tensor {}",
451 name
452 )));
453 };
454 if item.dtype != tensor.dtype.as_safetensors() {
455 return Err(ConvertError::InvalidStructure(format!(
456 "metadata dtype mismatch for {}: {} != {}",
457 name,
458 item.dtype,
459 tensor.dtype.as_safetensors()
460 )));
461 }
462 if item.shape != tensor.shape {
463 return Err(ConvertError::InvalidStructure(format!(
464 "metadata shape mismatch for {}",
465 name
466 )));
467 }
468 if item.sha256 != sha256_hex(&tensor.bytes) {
469 return Err(ConvertError::InvalidStructure(format!(
470 "metadata sha256 mismatch for {}",
471 name
472 )));
473 }
474 }
475
476 Ok(())
477}
478
/// One tensor entry of a safetensors JSON header (every key except `__metadata__`).
#[derive(Debug, Deserialize)]
struct SafetensorHeaderEntry {
    /// Safetensors dtype name, mapped through `DType::from_safetensors`.
    dtype: String,
    /// Tensor dimensions.
    shape: Vec<usize>,
    /// `[start, end)` byte offsets into the data section that follows the header.
    data_offsets: [usize; 2],
}
485
486fn read_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, TensorData>> {
487 let file_bytes = fs::read(path)?;
488 if file_bytes.len() < 8 {
489 return Err(ConvertError::InvalidStructure(
490 "safetensors file is too short".to_string(),
491 ));
492 }
493
494 let header_len = u64::from_le_bytes(file_bytes[0..8].try_into().expect("8-byte header"));
495 let header_len = header_len as usize;
496 if file_bytes.len() < 8 + header_len {
497 return Err(ConvertError::InvalidStructure(
498 "safetensors header is truncated".to_string(),
499 ));
500 }
501
502 let header_bytes = &file_bytes[8..8 + header_len];
503 let data = &file_bytes[8 + header_len..];
504 let header: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(header_bytes)?;
505
506 let mut tensors = BTreeMap::new();
507 for (name, value) in header {
508 if name == "__metadata__" {
509 continue;
510 }
511 let entry: SafetensorHeaderEntry = serde_json::from_value(value)?;
512 let Some(dtype) = DType::from_safetensors(&entry.dtype) else {
513 return Err(ConvertError::InvalidStructure(format!(
514 "unsupported safetensors dtype {}",
515 entry.dtype
516 )));
517 };
518
519 let start = entry.data_offsets[0];
520 let end = entry.data_offsets[1];
521 if end < start || end > data.len() {
522 return Err(ConvertError::InvalidStructure(format!(
523 "invalid data_offsets for tensor {}",
524 name
525 )));
526 }
527
528 let expected_size = numel(&entry.shape)?
529 .checked_mul(dtype.elem_size())
530 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
531 if end - start != expected_size {
532 return Err(ConvertError::InvalidStructure(format!(
533 "tensor {} bytes mismatch: {} != {}",
534 name,
535 end - start,
536 expected_size
537 )));
538 }
539
540 tensors.insert(
541 name,
542 TensorData {
543 dtype,
544 shape: entry.shape,
545 bytes: data[start..end].to_vec(),
546 },
547 );
548 }
549
550 if tensors.is_empty() {
551 return Err(ConvertError::InvalidStructure(
552 "no tensors found in safetensors file".to_string(),
553 ));
554 }
555
556 Ok(tensors)
557}
558
559fn tensor_data_to_array(tensor: &TensorData) -> Result<TensorArray> {
560 let shape = IxDyn(&tensor.shape);
561 match tensor.dtype {
562 DType::F16 | DType::BF16 => Err(ConvertError::InvalidStructure(
563 "f16/bf16 should be normalized to f32 before state_dict()".to_string(),
564 )),
565 DType::F32 => {
566 let values = bytes_to_vec::<4, f32>(&tensor.bytes, f32::from_le_bytes)?;
567 Ok(TensorArray::F32(ArrayD::from_shape_vec(shape, values)?))
568 }
569 DType::F64 => {
570 let values = bytes_to_vec::<8, f64>(&tensor.bytes, f64::from_le_bytes)?;
571 Ok(TensorArray::F64(ArrayD::from_shape_vec(shape, values)?))
572 }
573 DType::I8 => {
574 let values = tensor.bytes.iter().map(|v| *v as i8).collect::<Vec<_>>();
575 Ok(TensorArray::I8(ArrayD::from_shape_vec(shape, values)?))
576 }
577 DType::I16 => {
578 let values = bytes_to_vec::<2, i16>(&tensor.bytes, i16::from_le_bytes)?;
579 Ok(TensorArray::I16(ArrayD::from_shape_vec(shape, values)?))
580 }
581 DType::I32 => {
582 let values = bytes_to_vec::<4, i32>(&tensor.bytes, i32::from_le_bytes)?;
583 Ok(TensorArray::I32(ArrayD::from_shape_vec(shape, values)?))
584 }
585 DType::I64 => {
586 let values = bytes_to_vec::<8, i64>(&tensor.bytes, i64::from_le_bytes)?;
587 Ok(TensorArray::I64(ArrayD::from_shape_vec(shape, values)?))
588 }
589 DType::U8 => Ok(TensorArray::U8(ArrayD::from_shape_vec(shape, tensor.bytes.clone())?)),
590 DType::Bool => {
591 let values = tensor.bytes.iter().map(|v| *v != 0).collect::<Vec<_>>();
592 Ok(TensorArray::Bool(ArrayD::from_shape_vec(shape, values)?))
593 }
594 }
595}
596
597fn bytes_to_vec<const N: usize, T>(bytes: &[u8], f: impl Fn([u8; N]) -> T) -> Result<Vec<T>> {
598 if bytes.len() % N != 0 {
599 return Err(ConvertError::InvalidStructure(format!(
600 "tensor bytes are not divisible by {}",
601 N
602 )));
603 }
604
605 Ok(
606 bytes
607 .chunks_exact(N)
608 .map(|chunk| {
609 let mut arr = [0u8; N];
610 arr.copy_from_slice(chunk);
611 f(arr)
612 })
613 .collect(),
614 )
615}
616
617fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
618 match dtype {
619 DType::F16 => Ok(TensorData {
620 dtype: DType::F32,
621 shape,
622 bytes: f16_bytes_to_f32_bytes(&bytes)?,
623 }),
624 DType::BF16 => Ok(TensorData {
625 dtype: DType::F32,
626 shape,
627 bytes: bf16_bytes_to_f32_bytes(&bytes)?,
628 }),
629 _ => Ok(TensorData { dtype, shape, bytes }),
630 }
631}
632
633fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
634 if input.len() % 2 != 0 {
635 return Err(ConvertError::InvalidStructure(
636 "f16 tensor bytes must be even-length".to_string(),
637 ));
638 }
639 let mut out = Vec::with_capacity(input.len() * 2);
640 for chunk in input.chunks_exact(2) {
641 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
642 let value = f16_bits_to_f32(bits);
643 out.extend_from_slice(&value.to_le_bytes());
644 }
645 Ok(out)
646}
647
648fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
649 if input.len() % 2 != 0 {
650 return Err(ConvertError::InvalidStructure(
651 "bf16 tensor bytes must be even-length".to_string(),
652 ));
653 }
654 let mut out = Vec::with_capacity(input.len() * 2);
655 for chunk in input.chunks_exact(2) {
656 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
657 let value = f32::from_bits((bits as u32) << 16);
658 out.extend_from_slice(&value.to_le_bytes());
659 }
660 Ok(out)
661}
662
/// Widens one IEEE 754 half-precision value (given as raw bits) to `f32` exactly.
///
/// Signed zeros, subnormals, infinities and NaN payloads are all preserved.
fn f16_bits_to_f32(bits: u16) -> f32 {
    let bits = u32::from(bits);
    let sign = (bits & 0x8000) << 16;
    let exponent = (bits >> 10) & 0x1f;
    let mantissa = bits & 0x03ff;

    let magnitude = match (exponent, mantissa) {
        // Signed zero: only the sign bit survives.
        (0, 0) => 0,
        // Subnormal: move the leading one up to the implicit-bit position (bit 10) and
        // lower the exponent by the same amount.
        (0, m) => {
            let shift = m.leading_zeros() - 21;
            let exp32 = (127 - 14) - shift;
            (exp32 << 23) | (((m << shift) & 0x03ff) << 13)
        }
        // Infinity / NaN: saturate the exponent and keep the payload bits.
        (0x1f, m) => (0xff << 23) | (m << 13),
        // Normal: rebias the exponent from 15 to 127.
        (e, m) => ((e + 112) << 23) | (m << 13),
    };

    f32::from_bits(sign | magnitude)
}
691
// Unit tests: end-to-end conversion of hand-built zip checkpoints, plus the metadata
// projection and constructor/call type collection helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{collect_call_types, collect_constructor_types, project_value_for_metadata};
    use crate::types::Value;
    use std::io::Write;
    use tempfile::tempdir;
    use zip::write::SimpleFileOptions;
    use zip::ZipWriter;

    // Loads a one-tensor fixture checkpoint, exports it, then checks the weights file, the
    // metadata file and a few YAML fields (tensor name, dtype, empty security lists).
    #[test]
    fn converts_simple_tensor_checkpoint() {
        let tmp = tempdir().expect("tmp dir");
        let pt_path = tmp.path().join("weights.pt");
        write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");

        let out_dir = tmp.path().join("export");
        let checkpoint = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect("checkpoint load should work");
        let result = checkpoint
            .export(&out_dir, ExportOptions::new(ExportFormat::Safetensors, Some(&pt_path)))
            .expect("export should work");

        assert!(result.weights_path.exists());
        assert!(result.metadata_path.as_ref().expect("metadata path").exists());
        assert_eq!(result.tensor_count, 1);

        let yaml = fs::read_to_string(result.metadata_path.expect("metadata path")).expect("yaml readable");
        assert!(yaml.contains("layer.weight"));
        // Accept either plain or single-quoted scalar style from the YAML serializer.
        assert!(yaml.contains("dtype: F32") || yaml.contains("dtype: 'F32'"));
        assert!(yaml.contains("security:"));
        assert!(yaml.contains("objects: []"));
        assert!(yaml.contains("calls: []"));
    }

    // A pickle whose root is an `os.system(...)` call yields no tensor state_dict, so load
    // must fail with the extractor's "could not find" error.
    #[test]
    fn rejects_unsafe_global_reduce() {
        let tmp = tempdir().expect("tmp dir");
        let pt_path = tmp.path().join("unsafe.pt");
        write_fixture_checkpoint(&pt_path, true).expect("fixture checkpoint");

        let err = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect_err("unsafe pickle should fail");
        let msg = err.to_string();
        assert!(msg.contains("could not find a tensor state_dict"));
    }

    // Object projection emits `$type`, `$class` and `$args`, and flattens dict state
    // entries into the mapping.
    #[test]
    fn projects_object_metadata_with_type_args_and_flattened_state() {
        let value = Value::Object {
            module: "ultralytics.nn.tasks".to_string(),
            name: "DetectionModel".to_string(),
            args: vec![Value::String("arg0".to_string()), Value::Int(42)],
            state: Some(Box::new(Value::Dict(vec![(
                Value::String("training".to_string()),
                Value::Bool(false),
            )]))),
        };

        let projected = project_value_for_metadata(&value);
        let mapping = match projected {
            serde_yaml::Value::Mapping(map) => map,
            other => panic!("expected mapping, got {:?}", other),
        };

        let type_key = serde_yaml::Value::String("$type".to_string());
        let class_key = serde_yaml::Value::String("$class".to_string());
        let args_key = serde_yaml::Value::String("$args".to_string());
        let training_key = serde_yaml::Value::String("training".to_string());

        assert_eq!(
            mapping.get(&type_key),
            Some(&serde_yaml::Value::String("object".to_string()))
        );
        assert_eq!(
            mapping.get(&class_key),
            Some(&serde_yaml::Value::String(
                "ultralytics.nn.tasks.DetectionModel".to_string()
            ))
        );
        assert!(mapping.get(&args_key).is_some());
        assert_eq!(mapping.get(&training_key), Some(&serde_yaml::Value::Bool(false)));
    }

    // An object with no constructor args must not emit an `$args` key.
    #[test]
    fn omits_empty_object_args() {
        let value = Value::Object {
            module: "a".to_string(),
            name: "B".to_string(),
            args: Vec::new(),
            state: None,
        };
        let projected = project_value_for_metadata(&value);
        let mapping = match projected {
            serde_yaml::Value::Mapping(map) => map,
            other => panic!("expected mapping, got {:?}", other),
        };

        let args_key = serde_yaml::Value::String("$args".to_string());
        assert!(!mapping.contains_key(&args_key));
    }

    // Constructor types are found inside nested containers and reported once each, in
    // first-seen order.
    #[test]
    fn collects_constructor_types_deduplicated_in_first_seen_order() {
        let tree = Value::List(vec![
            Value::Object {
                module: "a".to_string(),
                name: "One".to_string(),
                args: Vec::new(),
                state: None,
            },
            Value::Dict(vec![(
                Value::String("nested".to_string()),
                Value::Object {
                    module: "b".to_string(),
                    name: "Two".to_string(),
                    args: Vec::new(),
                    state: None,
                },
            )]),
            Value::Object {
                module: "a".to_string(),
                name: "One".to_string(),
                args: Vec::new(),
                state: None,
            },
        ]);

        let objects = collect_constructor_types(&tree);
        assert_eq!(objects, vec!["a.One".to_string(), "b.Two".to_string()]);
    }

    // Call targets are found in object args and state as well as at top level, and are
    // reported once each, in first-seen order.
    #[test]
    fn collects_call_types_deduplicated_in_first_seen_order() {
        let tree = Value::List(vec![
            Value::Call {
                func: "a.fn".to_string(),
                args: vec![Value::String("x".to_string())],
                state: None,
            },
            Value::Object {
                module: "m".to_string(),
                name: "N".to_string(),
                args: vec![Value::Call {
                    func: "b.fn".to_string(),
                    args: Vec::new(),
                    state: None,
                }],
                state: Some(Box::new(Value::Call {
                    func: "a.fn".to_string(),
                    args: Vec::new(),
                    state: None,
                })),
            },
        ]);

        let calls = collect_call_types(&tree);
        assert_eq!(calls, vec!["a.fn".to_string(), "b.fn".to_string()]);
    }

    // Call projection emits `$type: call`, `$func` and a two-element `$args` sequence.
    #[test]
    fn projects_call_metadata() {
        let value = Value::Call {
            func: "ultralytics.utils.IterableSimpleNamespace".to_string(),
            args: vec![Value::String("x".to_string()), Value::Int(1)],
            state: None,
        };

        let projected = project_value_for_metadata(&value);
        let mapping = match projected {
            serde_yaml::Value::Mapping(map) => map,
            other => panic!("expected mapping, got {:?}", other),
        };

        let type_key = serde_yaml::Value::String("$type".to_string());
        let func_key = serde_yaml::Value::String("$func".to_string());
        let args_key = serde_yaml::Value::String("$args".to_string());
        assert_eq!(
            mapping.get(&type_key),
            Some(&serde_yaml::Value::String("call".to_string()))
        );
        assert_eq!(
            mapping.get(&func_key),
            Some(&serde_yaml::Value::String(
                "ultralytics.utils.IterableSimpleNamespace".to_string()
            ))
        );
        assert!(matches!(
            mapping.get(&args_key),
            Some(serde_yaml::Value::Sequence(items)) if items.len() == 2
        ));
    }

    // A call with dict state projects that state under a `$state` mapping.
    #[test]
    fn projects_call_metadata_with_state() {
        let value = Value::Call {
            func: "ultralytics.utils.IterableSimpleNamespace".to_string(),
            args: vec![Value::String("x".to_string())],
            state: Some(Box::new(Value::Dict(vec![(
                Value::String("k".to_string()),
                Value::String("v".to_string()),
            )]))),
        };

        let projected = project_value_for_metadata(&value);
        let mapping = match projected {
            serde_yaml::Value::Mapping(map) => map,
            other => panic!("expected mapping, got {:?}", other),
        };

        let state_key = serde_yaml::Value::String("$state".to_string());
        assert!(matches!(mapping.get(&state_key), Some(serde_yaml::Value::Mapping(_))));
    }

    // Writes a zip checkpoint with `archive/data.pkl` (safe or unsafe pickle) and storage
    // `archive/data/0` holding four little-endian f32 values 1.0..=4.0.
    fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
        let file = File::create(path)?;
        let mut zip = ZipWriter::new(file);
        let options = SimpleFileOptions::default();

        let data_pkl = if unsafe_payload {
            build_unsafe_pickle()
        } else {
            build_safe_pickle()
        };

        zip.start_file("archive/data.pkl", options)?;
        zip.write_all(&data_pkl)?;

        let floats = [1.0f32, 2.0, 3.0, 4.0];
        let mut raw = Vec::new();
        for value in floats {
            raw.extend_from_slice(&value.to_le_bytes());
        }

        zip.start_file("archive/data/0", options)?;
        zip.write_all(&raw)?;
        zip.finish()?;
        Ok(())
    }

    // Hand-assembles a protocol-2 pickle for `{"layer.weight": tensor}`. The tensor is built
    // by calling `torch._utils._rebuild_tensor_v2` with these arguments:
    // - persistent storage id ("storage", FloatStorage, "0", "cpu", 4);
    // - storage offset 0, size (2, 2), stride (2, 1);
    // - requires_grad False, backward hooks None.
    fn build_safe_pickle() -> Vec<u8> {
        let mut out = Vec::new();
        out.extend_from_slice(&[0x80, 0x02]); // PROTO 2

        out.push(b'}'); // EMPTY_DICT
        out.push(b'('); // MARK: start of SETITEMS key/value pairs

        push_binunicode(&mut out, "layer.weight");
        out.extend_from_slice(b"ctorch._utils\n_rebuild_tensor_v2\n"); // GLOBAL

        out.push(b'('); // MARK: start of the rebuild argument tuple

        out.push(b'('); // MARK: start of the persistent-id tuple
        push_binunicode(&mut out, "storage");
        out.extend_from_slice(b"ctorch\nFloatStorage\n"); // GLOBAL: storage type
        push_binunicode(&mut out, "0"); // storage key (matches archive/data/0)
        push_binunicode(&mut out, "cpu"); // location
        out.push(b'K'); // BININT1
        out.push(4); // storage element count
        out.push(b't'); // TUPLE
        out.push(b'Q'); // BINPERSID

        out.push(b'K'); // BININT1
        out.push(0); // storage offset

        out.push(b'('); // MARK: size tuple (2, 2)
        out.push(b'K');
        out.push(2);
        out.push(b'K');
        out.push(2);
        out.push(b't'); // TUPLE

        out.push(b'('); // MARK: stride tuple (2, 1)
        out.push(b'K');
        out.push(2);
        out.push(b'K');
        out.push(1);
        out.push(b't'); // TUPLE

        out.push(0x89); // NEWFALSE: requires_grad
        out.push(b'N'); // NONE: backward hooks

        out.push(b't'); // TUPLE: close the argument tuple
        out.push(b'R'); // REDUCE: _rebuild_tensor_v2(*args)

        out.push(b'u'); // SETITEMS
        out.push(b'.'); // STOP
        out
    }

    // Pickle equivalent to `os.system("echo hacked")` (GLOBAL + REDUCE). It is only used
    // to check that loading rejects it; the test never runs it as a Python pickle.
    fn build_unsafe_pickle() -> Vec<u8> {
        let mut out = Vec::new();
        out.extend_from_slice(&[0x80, 0x02]); // PROTO 2
        out.extend_from_slice(b"cos\nsystem\n"); // GLOBAL os.system
        out.push(b'('); // MARK
        push_binunicode(&mut out, "echo hacked");
        out.push(b't'); // TUPLE
        out.push(b'R'); // REDUCE
        out.push(b'.'); // STOP
        out
    }

    // Appends a BINUNICODE opcode: 'X', a u32 little-endian byte length, then UTF-8 bytes.
    fn push_binunicode(out: &mut Vec<u8>, value: &str) {
        out.push(b'X');
        out.extend_from_slice(&(value.len() as u32).to_le_bytes());
        out.extend_from_slice(value.as_bytes());
    }
}