1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8pub mod writer;
9
10pub use types::{
11 CheckpointMetadata, CheckpointSecurity, CheckpointTensorMetadata, ConvertError, DType, ExportFormat,
12 ExportOptions, ExportResult, TensorManifest,
13 LoadOptions, ReconstructSource, Result, StorageRef, TensorArray, TensorData, TensorRef, Value,
14};
15
16use ndarray::{ArrayD, IxDyn};
17use serde::Deserialize;
18use std::collections::{BTreeMap, HashMap};
19use std::fs;
20use std::fs::File;
21use std::io::Read;
22use std::path::Path;
23use std::time::{SystemTime, UNIX_EPOCH};
24use zip::read::ZipArchive;
25
26use extract::{contiguous_stride, extract_state_dict_tensors, numel};
27use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
28use metadata::{collect_call_types, collect_constructor_types, project_root_metadata};
29use parser::parse_pickle;
30use types::ParsedCheckpoint;
31use writer::{write_metadata_yaml, write_safetensors};
32
/// A parsed checkpoint held in memory: its tensors plus the metadata that
/// describes them.
///
/// Built either by [`PtCheckpoint::load`] (from a torch zip checkpoint) or by
/// [`PtCheckpoint::from_metadata`] (from previously exported metadata and
/// tensor data).
#[derive(Debug, Clone)]
pub struct PtCheckpoint {
 // SHA-256 of the source checkpoint as reported by the parser, or copied from
 // the supplied metadata when reconstructing.
 source_sha256: String,
 // Warnings carried over from parsing; both constructors in this file
 // currently produce an empty list.
 warnings: Vec<String>,
 // Metadata describing the checkpoint; `export` writes an adjusted copy.
 metadata: CheckpointMetadata,
 // Flat tensor map keyed by the merged (root-qualified) tensor name.
 tensors: BTreeMap<String, TensorData>,
 // The same tensors grouped per root key, keyed by their in-group name.
 tensor_groups: BTreeMap<String, BTreeMap<String, TensorData>>,
}
41
42impl PtCheckpoint {
43 pub fn load(path: impl AsRef<Path>, opts: LoadOptions) -> Result<Self> {
44 let path = path.as_ref();
45 let parsed = parse_checkpoint(path, &opts)?;
46 let metadata = build_checkpoint_metadata(
47 path.display().to_string(),
48 parsed.source_sha256.clone(),
49 &parsed.metadata,
50 &parsed.security,
51 &parsed.tensors,
52 "model.safetensors".to_string(),
53 );
54
55 Ok(Self {
56 source_sha256: parsed.source_sha256,
57 warnings: parsed.warnings,
58 metadata,
59 tensors: parsed.tensors,
60 tensor_groups: parsed.tensor_groups,
61 })
62 }
63
64 pub fn from_metadata(metadata: CheckpointMetadata, source: ReconstructSource) -> Result<Self> {
65 let tensors = match source {
66 ReconstructSource::WeightsFile(path) => read_safetensors_tensors(&path)?,
67 ReconstructSource::StateDict(values) => values,
68 };
69
70 validate_metadata_against_tensors(&metadata, &tensors)?;
71 let mut tensor_groups = BTreeMap::new();
72 tensor_groups.insert("root".to_string(), tensors.clone());
73
74 Ok(Self {
75 source_sha256: metadata.source_sha256.clone(),
76 warnings: Vec::new(),
77 metadata,
78 tensors,
79 tensor_groups,
80 })
81 }
82
83 pub fn metadata(&self) -> &CheckpointMetadata {
84 &self.metadata
85 }
86
87 pub fn source_sha256(&self) -> &str {
88 &self.source_sha256
89 }
90
91 pub fn warnings(&self) -> &[String] {
92 &self.warnings
93 }
94
95 pub fn tensor_count(&self) -> usize {
96 self.tensors.len()
97 }
98
99 #[cfg(feature = "pyo3")]
100 pub(crate) fn raw_tensors(&self) -> &BTreeMap<String, TensorData> {
101 &self.tensors
102 }
103
104 pub fn state_dict(&self) -> Result<BTreeMap<String, TensorArray>> {
105 let mut out = BTreeMap::new();
106 for (name, tensor) in &self.tensors {
107 out.insert(name.clone(), tensor_data_to_array(tensor)?);
108 }
109 Ok(out)
110 }
111
112 pub fn export(&self, out_dir: impl AsRef<Path>, opts: ExportOptions) -> Result<ExportResult> {
113 match opts.format {
114 ExportFormat::Safetensors => {}
115 }
116
117 let out_dir = out_dir.as_ref();
118 fs::create_dir_all(out_dir)?;
119
120 let is_multi_root = self.tensor_groups.len() > 1 || !self.tensor_groups.contains_key("root");
121 let weights_path = out_dir.join(&opts.weights_filename);
122 let mut weights_paths = BTreeMap::new();
123 if is_multi_root {
124 for (root_key, tensors) in &self.tensor_groups {
125 let file_name = with_root_key_suffix(&opts.weights_filename, root_key)?;
126 let path = out_dir.join(&file_name);
127 if path.exists() && !opts.overwrite {
128 return Err(ConvertError::InvalidStructure(format!(
129 "output already exists: {}",
130 path.display()
131 )));
132 }
133 write_safetensors(&path, tensors, &self.source_sha256)?;
134 weights_paths.insert(root_key.clone(), path);
135 }
136 } else {
137 if weights_path.exists() && !opts.overwrite {
138 return Err(ConvertError::InvalidStructure(format!(
139 "output already exists: {}",
140 weights_path.display()
141 )));
142 }
143 write_safetensors(&weights_path, &self.tensors, &self.source_sha256)?;
144 weights_paths.insert("root".to_string(), weights_path.clone());
145 }
146
147 let metadata_path = if opts.include_metadata {
148 let metadata_path = out_dir.join(&opts.metadata_filename);
149 if metadata_path.exists() && !opts.overwrite {
150 return Err(ConvertError::InvalidStructure(format!(
151 "output already exists: {}",
152 metadata_path.display()
153 )));
154 }
155
156 let mut metadata = self.metadata.clone();
157 if is_multi_root {
158 metadata.safetensors_file.clear();
159 metadata.safetensors_files = weights_paths
160 .iter()
161 .map(|(key, path)| (key.clone(), file_name_or_path(path)))
162 .collect();
163 metadata.tensors = TensorManifest::ByRoot(
164 self
165 .tensor_groups
166 .iter()
167 .map(|(key, tensors)| (key.clone(), tensor_summaries_for_metadata(tensors)))
168 .collect(),
169 );
170 } else {
171 metadata.safetensors_file = opts.weights_filename.to_string_lossy().into_owned();
172 metadata.safetensors_files.clear();
173 metadata.tensors = TensorManifest::List(tensor_summaries_for_metadata(&self.tensors));
174 }
175 metadata.created_at_unix = now_unix_secs();
176 metadata.tensor_count = self.tensors.len();
177 metadata.total_tensor_bytes = total_tensor_bytes(&self.tensors);
178 write_metadata_yaml(&metadata_path, &metadata)?;
179 Some(metadata_path)
180 } else {
181 None
182 };
183
184 Ok(ExportResult {
185 weights_path,
186 weights_paths,
187 metadata_path,
188 source_sha256: self.source_sha256.clone(),
189 tensor_count: self.tensors.len(),
190 total_tensor_bytes: total_tensor_bytes(&self.tensors),
191 })
192 }
193}
194
195pub(crate) fn parse_checkpoint(path: &Path, opts: &LoadOptions) -> Result<ParsedCheckpoint> {
196 let file = File::open(path)?;
197 let metadata = file.metadata()?;
198 if metadata.len() > opts.max_archive_bytes {
199 return Err(ConvertError::ResourceLimitExceeded(format!(
200 "archive is {} bytes, limit is {}",
201 metadata.len(),
202 opts.max_archive_bytes
203 )));
204 }
205
206 let mut magic = [0u8; 4];
207 let mut fh = File::open(path)?;
208 fh.read_exact(&mut magic)?;
209 if magic != [0x50, 0x4b, 0x03, 0x04] {
210 return Err(ConvertError::UnsupportedFormat(
211 "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
212 ));
213 }
214
215 let source_sha256 = sha256_file(path)?;
216 let mut archive = ZipArchive::new(file)?;
217 let data_pkl_name = find_data_pkl_name(&mut archive)?;
218 let prefix = data_pkl_name
219 .strip_suffix("data.pkl")
220 .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
221 .to_string();
222 let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
223 if pickle_bytes.len() > opts.max_pickle_bytes {
224 return Err(ConvertError::ResourceLimitExceeded(format!(
225 "data.pkl is {} bytes, limit is {}",
226 pickle_bytes.len(),
227 opts.max_pickle_bytes
228 )));
229 }
230
231 let root = parse_pickle(&pickle_bytes, opts)?;
232 let metadata = project_root_metadata(&root);
233 let objects = collect_constructor_types(&root);
234 let calls = collect_call_types(&root);
235 let tensor_ref_groups = extract_state_dict_tensors(&root, opts)?;
236 if tensor_ref_groups.is_empty() {
237 return Err(ConvertError::InvalidStructure(
238 "no tensors found in checkpoint state_dict".to_string(),
239 ));
240 }
241 let tensor_ref_count = tensor_ref_groups.values().map(|group| group.len()).sum::<usize>();
242 if tensor_ref_count > opts.max_tensor_count {
243 return Err(ConvertError::ResourceLimitExceeded(format!(
244 "tensor count {} exceeds limit {}",
245 tensor_ref_count,
246 opts.max_tensor_count
247 )));
248 }
249
250 let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
251 for tensor_refs in tensor_ref_groups.values() {
252 for tensor in tensor_refs.values() {
253 let key = &tensor.storage.key;
254 if storage_blobs.contains_key(key) {
255 continue;
256 }
257 let blob = read_storage_blob(&mut archive, &prefix, key)?;
258 let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
259 if blob.len() < required_bytes {
260 return Err(ConvertError::InvalidStructure(format!(
261 "storage {} has {} bytes, expected at least {}",
262 key,
263 blob.len(),
264 required_bytes
265 )));
266 }
267 storage_blobs.insert(key.clone(), blob);
268 }
269 }
270
271 let mut tensors = BTreeMap::new();
272 let mut tensor_groups = BTreeMap::new();
273 for (root_key, tensor_refs) in tensor_ref_groups {
274 let mut group_tensors = BTreeMap::new();
275 for (name, tensor_ref) in tensor_refs {
276 if opts.strict_contiguous {
277 let expected = contiguous_stride(&tensor_ref.shape);
278 if expected != tensor_ref.stride {
279 return Err(ConvertError::InvalidStructure(format!(
280 "tensor {} has non-contiguous stride {:?}, expected {:?}",
281 name, tensor_ref.stride, expected
282 )));
283 }
284 }
285
286 let elem_size = tensor_ref.storage.dtype.elem_size();
287 let numel = numel(&tensor_ref.shape)?;
288 let start = tensor_ref
289 .offset_elems
290 .checked_mul(elem_size)
291 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
292 let byte_len = numel
293 .checked_mul(elem_size)
294 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
295 if byte_len > opts.max_tensor_bytes {
296 return Err(ConvertError::ResourceLimitExceeded(format!(
297 "tensor {} is {} bytes, limit is {}",
298 name, byte_len, opts.max_tensor_bytes
299 )));
300 }
301 let end = start
302 .checked_add(byte_len)
303 .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
304
305 let storage = storage_blobs
306 .get(&tensor_ref.storage.key)
307 .ok_or_else(|| ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key)))?;
308 if end > storage.len() {
309 return Err(ConvertError::InvalidStructure(format!(
310 "tensor {} slice [{}, {}) is out of storage bounds {}",
311 name,
312 start,
313 end,
314 storage.len()
315 )));
316 }
317
318 let raw = storage[start..end].to_vec();
319 let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
320 group_tensors.insert(name.clone(), normalized.clone());
321 let merged_name = merge_root_tensor_name(&root_key, &name);
322 tensors.insert(merged_name, normalized);
323 }
324 tensor_groups.insert(root_key, group_tensors);
325 }
326
327 Ok(ParsedCheckpoint {
328 source_sha256,
329 warnings: Vec::new(),
330 tensors,
331 tensor_groups,
332 metadata,
333 security: CheckpointSecurity { objects, calls },
334 })
335}
336
337fn build_checkpoint_metadata(
338 source_file: String,
339 source_sha256: String,
340 metadata: &serde_yaml::Value,
341 security: &CheckpointSecurity,
342 tensors: &BTreeMap<String, TensorData>,
343 safetensors_file: String,
344) -> CheckpointMetadata {
345 CheckpointMetadata {
346 format_version: 1,
347 source_file,
348 source_sha256,
349 safetensors_file,
350 safetensors_files: BTreeMap::new(),
351 created_at_unix: now_unix_secs(),
352 tensor_count: tensors.len(),
353 total_tensor_bytes: total_tensor_bytes(tensors),
354 metadata: metadata.clone(),
355 security: security.clone(),
356 tensors: TensorManifest::List(tensor_summaries_for_metadata(tensors)),
357 }
358}
359
360fn tensor_summaries_for_metadata(tensors: &BTreeMap<String, TensorData>) -> Vec<CheckpointTensorMetadata> {
361 tensors
362 .iter()
363 .map(|(name, tensor)| CheckpointTensorMetadata {
364 name: name.clone(),
365 dtype: tensor.dtype.as_safetensors().to_string(),
366 shape: tensor.shape.clone(),
367 sha256: sha256_hex(&tensor.bytes),
368 })
369 .collect()
370}
371
372fn total_tensor_bytes(tensors: &BTreeMap<String, TensorData>) -> usize {
373 tensors.values().map(|tensor| tensor.bytes.len()).sum()
374}
375
/// Returns the final path component as a (lossy) string, or the whole path
/// rendered with `display()` when the path has no final component.
fn file_name_or_path(path: &Path) -> String {
    match path.file_name() {
        Some(last) => last.to_string_lossy().into_owned(),
        None => path.display().to_string(),
    }
}
382
/// Qualifies a tensor name with its root key for the flat tensor map.
///
/// The name is returned unchanged when the root is `"root"`, when the name
/// equals the root, or when it already starts with `<root>.`; otherwise the
/// result is `<root>.<name>`.
fn merge_root_tensor_name(root: &str, name: &str) -> String {
    if root == "root" || name == root {
        return name.to_string();
    }

    let already_qualified = matches!(name.strip_prefix(root), Some(rest) if rest.starts_with('.'));
    if already_qualified {
        return name.to_string();
    }

    format!("{root}.{name}")
}
390
391fn with_root_key_suffix(base: &Path, root_key: &str) -> Result<std::path::PathBuf> {
392 let ext = base
393 .extension()
394 .map(|value| value.to_string_lossy().into_owned())
395 .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no extension".to_string()))?;
396 let stem = base
397 .file_stem()
398 .map(|value| value.to_string_lossy().into_owned())
399 .ok_or_else(|| ConvertError::InvalidStructure("weights filename has no stem".to_string()))?;
400 Ok(std::path::PathBuf::from(format!("{stem}.{root_key}.{ext}")))
401}
402
/// Current time as whole seconds since the Unix epoch, or 0 when the system
/// clock reads earlier than the epoch.
fn now_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
409
410fn validate_metadata_against_tensors(
411 metadata: &CheckpointMetadata,
412 tensors: &BTreeMap<String, TensorData>,
413) -> Result<()> {
414 if metadata.tensor_count != tensors.len() {
415 return Err(ConvertError::InvalidStructure(format!(
416 "metadata tensor_count={} does not match loaded tensor count={}",
417 metadata.tensor_count,
418 tensors.len()
419 )));
420 }
421
422 let tensor_bytes = total_tensor_bytes(tensors);
423 if metadata.total_tensor_bytes != tensor_bytes {
424 return Err(ConvertError::InvalidStructure(format!(
425 "metadata total_tensor_bytes={} does not match loaded tensor bytes={}",
426 metadata.total_tensor_bytes, tensor_bytes
427 )));
428 }
429
430 let flat_manifest = match &metadata.tensors {
431 TensorManifest::List(items) => items.iter().map(|item| (item.name.clone(), item)).collect::<Vec<_>>(),
432 TensorManifest::ByRoot(groups) => groups
433 .iter()
434 .flat_map(|(root, items)| {
435 items
436 .iter()
437 .map(move |item| (merge_root_tensor_name(root, &item.name), item))
438 })
439 .collect::<Vec<_>>(),
440 };
441 for (name, item) in flat_manifest {
442 let Some(tensor) = tensors.get(&name) else {
443 return Err(ConvertError::InvalidStructure(format!(
444 "metadata references missing tensor {}",
445 name
446 )));
447 };
448 if item.dtype != tensor.dtype.as_safetensors() {
449 return Err(ConvertError::InvalidStructure(format!(
450 "metadata dtype mismatch for {}: {} != {}",
451 name,
452 item.dtype,
453 tensor.dtype.as_safetensors()
454 )));
455 }
456 if item.shape != tensor.shape {
457 return Err(ConvertError::InvalidStructure(format!(
458 "metadata shape mismatch for {}",
459 name
460 )));
461 }
462 if item.sha256 != sha256_hex(&tensor.bytes) {
463 return Err(ConvertError::InvalidStructure(format!(
464 "metadata sha256 mismatch for {}",
465 name
466 )));
467 }
468 }
469
470 Ok(())
471}
472
/// One tensor entry of a safetensors JSON header, as deserialized by
/// `read_safetensors_tensors`.
#[derive(Debug, Deserialize)]
struct SafetensorHeaderEntry {
 // Dtype name; resolved through `DType::from_safetensors`.
 dtype: String,
 // Tensor dimensions; their product times the element size must equal the
 // byte span given by `data_offsets`.
 shape: Vec<usize>,
 // `[start, end)` byte range, relative to the data section that follows the
 // JSON header.
 data_offsets: [usize; 2],
}
479
480fn read_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, TensorData>> {
481 let file_bytes = fs::read(path)?;
482 if file_bytes.len() < 8 {
483 return Err(ConvertError::InvalidStructure(
484 "safetensors file is too short".to_string(),
485 ));
486 }
487
488 let header_len = u64::from_le_bytes(file_bytes[0..8].try_into().expect("8-byte header"));
489 let header_len = header_len as usize;
490 if file_bytes.len() < 8 + header_len {
491 return Err(ConvertError::InvalidStructure(
492 "safetensors header is truncated".to_string(),
493 ));
494 }
495
496 let header_bytes = &file_bytes[8..8 + header_len];
497 let data = &file_bytes[8 + header_len..];
498 let header: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(header_bytes)?;
499
500 let mut tensors = BTreeMap::new();
501 for (name, value) in header {
502 if name == "__metadata__" {
503 continue;
504 }
505 let entry: SafetensorHeaderEntry = serde_json::from_value(value)?;
506 let Some(dtype) = DType::from_safetensors(&entry.dtype) else {
507 return Err(ConvertError::InvalidStructure(format!(
508 "unsupported safetensors dtype {}",
509 entry.dtype
510 )));
511 };
512
513 let start = entry.data_offsets[0];
514 let end = entry.data_offsets[1];
515 if end < start || end > data.len() {
516 return Err(ConvertError::InvalidStructure(format!(
517 "invalid data_offsets for tensor {}",
518 name
519 )));
520 }
521
522 let expected_size = numel(&entry.shape)?
523 .checked_mul(dtype.elem_size())
524 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
525 if end - start != expected_size {
526 return Err(ConvertError::InvalidStructure(format!(
527 "tensor {} bytes mismatch: {} != {}",
528 name,
529 end - start,
530 expected_size
531 )));
532 }
533
534 tensors.insert(
535 name,
536 TensorData {
537 dtype,
538 shape: entry.shape,
539 bytes: data[start..end].to_vec(),
540 },
541 );
542 }
543
544 if tensors.is_empty() {
545 return Err(ConvertError::InvalidStructure(
546 "no tensors found in safetensors file".to_string(),
547 ));
548 }
549
550 Ok(tensors)
551}
552
553fn tensor_data_to_array(tensor: &TensorData) -> Result<TensorArray> {
554 let shape = IxDyn(&tensor.shape);
555 match tensor.dtype {
556 DType::F16 | DType::BF16 => Err(ConvertError::InvalidStructure(
557 "f16/bf16 should be normalized to f32 before state_dict()".to_string(),
558 )),
559 DType::F32 => {
560 let values = bytes_to_vec::<4, f32>(&tensor.bytes, f32::from_le_bytes)?;
561 Ok(TensorArray::F32(ArrayD::from_shape_vec(shape, values)?))
562 }
563 DType::F64 => {
564 let values = bytes_to_vec::<8, f64>(&tensor.bytes, f64::from_le_bytes)?;
565 Ok(TensorArray::F64(ArrayD::from_shape_vec(shape, values)?))
566 }
567 DType::I8 => {
568 let values = tensor.bytes.iter().map(|v| *v as i8).collect::<Vec<_>>();
569 Ok(TensorArray::I8(ArrayD::from_shape_vec(shape, values)?))
570 }
571 DType::I16 => {
572 let values = bytes_to_vec::<2, i16>(&tensor.bytes, i16::from_le_bytes)?;
573 Ok(TensorArray::I16(ArrayD::from_shape_vec(shape, values)?))
574 }
575 DType::I32 => {
576 let values = bytes_to_vec::<4, i32>(&tensor.bytes, i32::from_le_bytes)?;
577 Ok(TensorArray::I32(ArrayD::from_shape_vec(shape, values)?))
578 }
579 DType::I64 => {
580 let values = bytes_to_vec::<8, i64>(&tensor.bytes, i64::from_le_bytes)?;
581 Ok(TensorArray::I64(ArrayD::from_shape_vec(shape, values)?))
582 }
583 DType::U8 => Ok(TensorArray::U8(ArrayD::from_shape_vec(shape, tensor.bytes.clone())?)),
584 DType::Bool => {
585 let values = tensor.bytes.iter().map(|v| *v != 0).collect::<Vec<_>>();
586 Ok(TensorArray::Bool(ArrayD::from_shape_vec(shape, values)?))
587 }
588 }
589}
590
591fn bytes_to_vec<const N: usize, T>(bytes: &[u8], f: impl Fn([u8; N]) -> T) -> Result<Vec<T>> {
592 if bytes.len() % N != 0 {
593 return Err(ConvertError::InvalidStructure(format!(
594 "tensor bytes are not divisible by {}",
595 N
596 )));
597 }
598
599 Ok(
600 bytes
601 .chunks_exact(N)
602 .map(|chunk| {
603 let mut arr = [0u8; N];
604 arr.copy_from_slice(chunk);
605 f(arr)
606 })
607 .collect(),
608 )
609}
610
611fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
612 match dtype {
613 DType::F16 => Ok(TensorData {
614 dtype: DType::F32,
615 shape,
616 bytes: f16_bytes_to_f32_bytes(&bytes)?,
617 }),
618 DType::BF16 => Ok(TensorData {
619 dtype: DType::F32,
620 shape,
621 bytes: bf16_bytes_to_f32_bytes(&bytes)?,
622 }),
623 _ => Ok(TensorData { dtype, shape, bytes }),
624 }
625}
626
627fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
628 if input.len() % 2 != 0 {
629 return Err(ConvertError::InvalidStructure(
630 "f16 tensor bytes must be even-length".to_string(),
631 ));
632 }
633 let mut out = Vec::with_capacity(input.len() * 2);
634 for chunk in input.chunks_exact(2) {
635 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
636 let value = f16_bits_to_f32(bits);
637 out.extend_from_slice(&value.to_le_bytes());
638 }
639 Ok(out)
640}
641
642fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
643 if input.len() % 2 != 0 {
644 return Err(ConvertError::InvalidStructure(
645 "bf16 tensor bytes must be even-length".to_string(),
646 ));
647 }
648 let mut out = Vec::with_capacity(input.len() * 2);
649 for chunk in input.chunks_exact(2) {
650 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
651 let value = f32::from_bits((bits as u32) << 16);
652 out.extend_from_slice(&value.to_le_bytes());
653 }
654 Ok(out)
655}
656
/// Expands an IEEE 754 half-precision bit pattern into the equivalent `f32`.
///
/// Handles signed zero, subnormals (renormalized into the f32 format),
/// infinities and NaNs (payload shifted into the wider fraction) and normal
/// numbers (exponent rebiased from 15 to 127).
fn f16_bits_to_f32(bits: u16) -> f32 {
    let half = u32::from(bits);
    let sign_bit = (half & 0x8000) << 16;
    let exponent = (half >> 10) & 0x1f;
    let fraction = half & 0x03ff;

    let magnitude = match (exponent, fraction) {
        (0, 0) => 0,
        (0, _) => {
            // Subnormal: shift the leading one up to the implicit-bit position
            // (bit 10) and lower the exponent by the same amount.
            let shift = fraction.leading_zeros() - 21;
            let mantissa = (fraction << shift) & 0x03ff;
            // -14 - shift + 127
            ((113 - shift) << 23) | (mantissa << 13)
        }
        (0x1f, _) => (0xff << 23) | (fraction << 13),
        // exponent - 15 + 127
        _ => ((exponent + 112) << 23) | (fraction << 13),
    };

    f32::from_bits(sign_bit | magnitude)
}
685
686#[cfg(test)]
687mod tests {
688 use super::*;
689 use crate::metadata::{collect_call_types, collect_constructor_types, project_value_for_metadata};
690 use crate::types::Value;
691 use std::io::Write;
692 use tempfile::tempdir;
693 use zip::write::SimpleFileOptions;
694 use zip::ZipWriter;
695
696 #[test]
697 fn converts_simple_tensor_checkpoint() {
698 let tmp = tempdir().expect("tmp dir");
699 let pt_path = tmp.path().join("weights.pt");
700 write_fixture_checkpoint(&pt_path, false).expect("fixture checkpoint");
701
702 let out_dir = tmp.path().join("export");
703 let checkpoint = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect("checkpoint load should work");
704 let result = checkpoint
705 .export(&out_dir, ExportOptions::new(ExportFormat::Safetensors, Some(&pt_path)))
706 .expect("export should work");
707
708 assert!(result.weights_path.exists());
709 assert!(result.metadata_path.as_ref().expect("metadata path").exists());
710 assert_eq!(result.tensor_count, 1);
711
712 let yaml = fs::read_to_string(result.metadata_path.expect("metadata path")).expect("yaml readable");
713 assert!(yaml.contains("layer.weight"));
714 assert!(yaml.contains("dtype: F32") || yaml.contains("dtype: 'F32'"));
715 assert!(yaml.contains("security:"));
716 assert!(yaml.contains("objects: []"));
717 assert!(yaml.contains("calls: []"));
718 }
719
720 #[test]
721 fn rejects_unsafe_global_reduce() {
722 let tmp = tempdir().expect("tmp dir");
723 let pt_path = tmp.path().join("unsafe.pt");
724 write_fixture_checkpoint(&pt_path, true).expect("fixture checkpoint");
725
726 let err = PtCheckpoint::load(&pt_path, LoadOptions::default()).expect_err("unsafe pickle should fail");
727 let msg = err.to_string();
728 assert!(msg.contains("could not find a tensor state_dict"));
729 }
730
731 #[test]
732 fn projects_object_metadata_with_type_args_and_flattened_state() {
733 let value = Value::Object {
734 module: "ultralytics.nn.tasks".to_string(),
735 name: "DetectionModel".to_string(),
736 args: vec![
737 Value::String("arg0".to_string()),
738 Value::Int(42),
739 ],
740 state: Some(Box::new(Value::Dict(vec![(
741 Value::String("training".to_string()),
742 Value::Bool(false),
743 )]))),
744 };
745
746 let projected = project_value_for_metadata(&value);
747 let mapping = match projected {
748 serde_yaml::Value::Mapping(map) => map,
749 other => panic!("expected mapping, got {:?}", other),
750 };
751
752 let type_key = serde_yaml::Value::String("$type".to_string());
753 let class_key = serde_yaml::Value::String("$class".to_string());
754 let args_key = serde_yaml::Value::String("$args".to_string());
755 let training_key = serde_yaml::Value::String("training".to_string());
756
757 assert_eq!(
758 mapping.get(&type_key),
759 Some(&serde_yaml::Value::String("object".to_string()))
760 );
761 assert_eq!(
762 mapping.get(&class_key),
763 Some(&serde_yaml::Value::String(
764 "ultralytics.nn.tasks.DetectionModel".to_string()
765 ))
766 );
767 assert!(mapping.get(&args_key).is_some());
768 assert_eq!(mapping.get(&training_key), Some(&serde_yaml::Value::Bool(false)));
769 }
770
771 #[test]
772 fn omits_empty_object_args() {
773 let value = Value::Object {
774 module: "a".to_string(),
775 name: "B".to_string(),
776 args: Vec::new(),
777 state: None,
778 };
779 let projected = project_value_for_metadata(&value);
780 let mapping = match projected {
781 serde_yaml::Value::Mapping(map) => map,
782 other => panic!("expected mapping, got {:?}", other),
783 };
784
785 let args_key = serde_yaml::Value::String("$args".to_string());
786 assert!(!mapping.contains_key(&args_key));
787 }
788
789 #[test]
790 fn collects_constructor_types_deduplicated_in_first_seen_order() {
791 let tree = Value::List(vec![
792 Value::Object {
793 module: "a".to_string(),
794 name: "One".to_string(),
795 args: Vec::new(),
796 state: None,
797 },
798 Value::Dict(vec![(
799 Value::String("nested".to_string()),
800 Value::Object {
801 module: "b".to_string(),
802 name: "Two".to_string(),
803 args: Vec::new(),
804 state: None,
805 },
806 )]),
807 Value::Object {
808 module: "a".to_string(),
809 name: "One".to_string(),
810 args: Vec::new(),
811 state: None,
812 },
813 ]);
814
815 let objects = collect_constructor_types(&tree);
816 assert_eq!(objects, vec!["a.One".to_string(), "b.Two".to_string()]);
817 }
818
819 #[test]
820 fn collects_call_types_deduplicated_in_first_seen_order() {
821 let tree = Value::List(vec![
822 Value::Call {
823 func: "a.fn".to_string(),
824 args: vec![Value::String("x".to_string())],
825 state: None,
826 },
827 Value::Object {
828 module: "m".to_string(),
829 name: "N".to_string(),
830 args: vec![Value::Call {
831 func: "b.fn".to_string(),
832 args: Vec::new(),
833 state: None,
834 }],
835 state: Some(Box::new(Value::Call {
836 func: "a.fn".to_string(),
837 args: Vec::new(),
838 state: None,
839 })),
840 },
841 ]);
842
843 let calls = collect_call_types(&tree);
844 assert_eq!(calls, vec!["a.fn".to_string(), "b.fn".to_string()]);
845 }
846
847 #[test]
848 fn projects_call_metadata() {
849 let value = Value::Call {
850 func: "ultralytics.utils.IterableSimpleNamespace".to_string(),
851 args: vec![Value::String("x".to_string()), Value::Int(1)],
852 state: None,
853 };
854
855 let projected = project_value_for_metadata(&value);
856 let mapping = match projected {
857 serde_yaml::Value::Mapping(map) => map,
858 other => panic!("expected mapping, got {:?}", other),
859 };
860
861 let type_key = serde_yaml::Value::String("$type".to_string());
862 let func_key = serde_yaml::Value::String("$func".to_string());
863 let args_key = serde_yaml::Value::String("$args".to_string());
864 assert_eq!(
865 mapping.get(&type_key),
866 Some(&serde_yaml::Value::String("call".to_string()))
867 );
868 assert_eq!(
869 mapping.get(&func_key),
870 Some(&serde_yaml::Value::String(
871 "ultralytics.utils.IterableSimpleNamespace".to_string()
872 ))
873 );
874 assert!(matches!(
875 mapping.get(&args_key),
876 Some(serde_yaml::Value::Sequence(items)) if items.len() == 2
877 ));
878 }
879
880 #[test]
881 fn projects_call_metadata_with_state() {
882 let value = Value::Call {
883 func: "ultralytics.utils.IterableSimpleNamespace".to_string(),
884 args: vec![Value::String("x".to_string())],
885 state: Some(Box::new(Value::Dict(vec![(
886 Value::String("k".to_string()),
887 Value::String("v".to_string()),
888 )]))),
889 };
890
891 let projected = project_value_for_metadata(&value);
892 let mapping = match projected {
893 serde_yaml::Value::Mapping(map) => map,
894 other => panic!("expected mapping, got {:?}", other),
895 };
896
897 let state_key = serde_yaml::Value::String("$state".to_string());
898 assert!(matches!(mapping.get(&state_key), Some(serde_yaml::Value::Mapping(_))));
899 }
900
901 fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
902 let file = File::create(path)?;
903 let mut zip = ZipWriter::new(file);
904 let options = SimpleFileOptions::default();
905
906 let data_pkl = if unsafe_payload {
907 build_unsafe_pickle()
908 } else {
909 build_safe_pickle()
910 };
911
912 zip.start_file("archive/data.pkl", options)?;
913 zip.write_all(&data_pkl)?;
914
915 let floats = [1.0f32, 2.0, 3.0, 4.0];
916 let mut raw = Vec::new();
917 for value in floats {
918 raw.extend_from_slice(&value.to_le_bytes());
919 }
920
921 zip.start_file("archive/data/0", options)?;
922 zip.write_all(&raw)?;
923 zip.finish()?;
924 Ok(())
925 }
926
927 fn build_safe_pickle() -> Vec<u8> {
928 let mut out = Vec::new();
929 out.extend_from_slice(&[0x80, 0x02]);
930
931 out.push(b'}');
932 out.push(b'(');
933
934 push_binunicode(&mut out, "layer.weight");
935 out.extend_from_slice(b"ctorch._utils\n_rebuild_tensor_v2\n");
936
937 out.push(b'(');
938
939 out.push(b'(');
940 push_binunicode(&mut out, "storage");
941 out.extend_from_slice(b"ctorch\nFloatStorage\n");
942 push_binunicode(&mut out, "0");
943 push_binunicode(&mut out, "cpu");
944 out.push(b'K');
945 out.push(4);
946 out.push(b't');
947 out.push(b'Q');
948
949 out.push(b'K');
950 out.push(0);
951
952 out.push(b'(');
953 out.push(b'K');
954 out.push(2);
955 out.push(b'K');
956 out.push(2);
957 out.push(b't');
958
959 out.push(b'(');
960 out.push(b'K');
961 out.push(2);
962 out.push(b'K');
963 out.push(1);
964 out.push(b't');
965
966 out.push(0x89);
967 out.push(b'N');
968
969 out.push(b't');
970 out.push(b'R');
971
972 out.push(b'u');
973 out.push(b'.');
974 out
975 }
976
977 fn build_unsafe_pickle() -> Vec<u8> {
978 let mut out = Vec::new();
979 out.extend_from_slice(&[0x80, 0x02]);
980 out.extend_from_slice(b"cos\nsystem\n");
981 out.push(b'(');
982 push_binunicode(&mut out, "echo hacked");
983 out.push(b't');
984 out.push(b'R');
985 out.push(b'.');
986 out
987 }
988
989 fn push_binunicode(out: &mut Vec<u8>, value: &str) {
990 out.push(b'X');
991 out.extend_from_slice(&(value.len() as u32).to_le_bytes());
992 out.extend_from_slice(value.as_bytes());
993 }
994}