1mod extract;
2mod iohash;
3mod metadata;
4mod parser;
5#[cfg(feature = "pyo3")]
6mod python;
7mod types;
8pub mod writer;
9
10pub use types::{
11 CheckpointMetadata, CheckpointTensorMetadata, ConvertError, DType, ExportFormat, ExportOptions, ExportResult,
12 LoadOptions, ReconstructSource, Result, StorageRef, TensorArray, TensorData, TensorRef, Value,
13};
14
15use ndarray::{ArrayD, IxDyn};
16use serde::Deserialize;
17use std::collections::{BTreeMap, HashMap};
18use std::fs;
19use std::fs::File;
20use std::io::Read;
21use std::path::Path;
22use std::time::{SystemTime, UNIX_EPOCH};
23use zip::read::ZipArchive;
24
25use extract::{contiguous_stride, extract_state_dict_tensors, numel};
26use iohash::{find_data_pkl_name, read_storage_blob, read_zip_entry, sha256_file, sha256_hex};
27use metadata::{collect_constructor_types, project_root_metadata};
28use parser::parse_pickle;
29use types::ParsedCheckpoint;
30use writer::{write_metadata_yaml, write_safetensors};
31
32#[derive(Debug, Clone)]
33pub struct PtCheckpoint {
34 source_sha256: String,
35 warnings: Vec<String>,
36 metadata: CheckpointMetadata,
37 tensors: BTreeMap<String, TensorData>,
38}
39
40impl PtCheckpoint {
41 pub fn load(path: impl AsRef<Path>, opts: LoadOptions) -> Result<Self> {
42 let path = path.as_ref();
43 let parsed = parse_checkpoint(path, &opts)?;
44 let metadata = build_checkpoint_metadata(
45 path.display().to_string(),
46 parsed.source_sha256.clone(),
47 &parsed.metadata,
48 &parsed.objects,
49 &parsed.tensors,
50 "model.safetensors".to_string(),
51 );
52
53 Ok(Self {
54 source_sha256: parsed.source_sha256,
55 warnings: parsed.warnings,
56 metadata,
57 tensors: parsed.tensors,
58 })
59 }
60
61 pub fn from_metadata(metadata: CheckpointMetadata, source: ReconstructSource) -> Result<Self> {
62 let tensors = match source {
63 ReconstructSource::WeightsFile(path) => read_safetensors_tensors(&path)?,
64 ReconstructSource::StateDict(values) => values,
65 };
66
67 validate_metadata_against_tensors(&metadata, &tensors)?;
68
69 Ok(Self {
70 source_sha256: metadata.source_sha256.clone(),
71 warnings: Vec::new(),
72 metadata,
73 tensors,
74 })
75 }
76
77 pub fn metadata(&self) -> &CheckpointMetadata {
78 &self.metadata
79 }
80
81 pub fn source_sha256(&self) -> &str {
82 &self.source_sha256
83 }
84
85 pub fn warnings(&self) -> &[String] {
86 &self.warnings
87 }
88
89 pub fn tensor_count(&self) -> usize {
90 self.tensors.len()
91 }
92
93 #[cfg(feature = "pyo3")]
94 pub(crate) fn raw_tensors(&self) -> &BTreeMap<String, TensorData> {
95 &self.tensors
96 }
97
98 pub fn state_dict(&self) -> Result<BTreeMap<String, TensorArray>> {
99 let mut out = BTreeMap::new();
100 for (name, tensor) in &self.tensors {
101 out.insert(name.clone(), tensor_data_to_array(tensor)?);
102 }
103 Ok(out)
104 }
105
106 pub fn export(&self, out_dir: impl AsRef<Path>, opts: ExportOptions) -> Result<ExportResult> {
107 match opts.format {
108 ExportFormat::Safetensors => {}
109 }
110
111 let out_dir = out_dir.as_ref();
112 fs::create_dir_all(out_dir)?;
113
114 let weights_path = out_dir.join(&opts.weights_filename);
115 if weights_path.exists() && !opts.overwrite {
116 return Err(ConvertError::InvalidStructure(format!(
117 "output already exists: {}",
118 weights_path.display()
119 )));
120 }
121
122 write_safetensors(&weights_path, &self.tensors, &self.source_sha256)?;
123
124 let metadata_path = if opts.include_metadata {
125 let metadata_path = out_dir.join(&opts.metadata_filename);
126 if metadata_path.exists() && !opts.overwrite {
127 return Err(ConvertError::InvalidStructure(format!(
128 "output already exists: {}",
129 metadata_path.display()
130 )));
131 }
132
133 let mut metadata = self.metadata.clone();
134 metadata.safetensors_file = opts.weights_filename.to_string_lossy().into_owned();
135 metadata.created_at_unix = now_unix_secs();
136 metadata.tensor_count = self.tensors.len();
137 metadata.total_tensor_bytes = total_tensor_bytes(&self.tensors);
138 metadata.tensors = tensor_summaries_for_metadata(&self.tensors);
139 write_metadata_yaml(&metadata_path, &metadata)?;
140 Some(metadata_path)
141 } else {
142 None
143 };
144
145 Ok(ExportResult {
146 weights_path,
147 metadata_path,
148 source_sha256: self.source_sha256.clone(),
149 tensor_count: self.tensors.len(),
150 total_tensor_bytes: total_tensor_bytes(&self.tensors),
151 })
152 }
153}
154
155pub(crate) fn parse_checkpoint(path: &Path, opts: &LoadOptions) -> Result<ParsedCheckpoint> {
156 let file = File::open(path)?;
157 let metadata = file.metadata()?;
158 if metadata.len() > opts.max_archive_bytes {
159 return Err(ConvertError::ResourceLimitExceeded(format!(
160 "archive is {} bytes, limit is {}",
161 metadata.len(),
162 opts.max_archive_bytes
163 )));
164 }
165
166 let mut magic = [0u8; 4];
167 let mut fh = File::open(path)?;
168 fh.read_exact(&mut magic)?;
169 if magic != [0x50, 0x4b, 0x03, 0x04] {
170 return Err(ConvertError::UnsupportedFormat(
171 "only torch zip checkpoints are supported (legacy raw-pickle .pt is rejected)".to_string(),
172 ));
173 }
174
175 let source_sha256 = sha256_file(path)?;
176 let mut archive = ZipArchive::new(file)?;
177 let data_pkl_name = find_data_pkl_name(&mut archive)?;
178 let prefix = data_pkl_name
179 .strip_suffix("data.pkl")
180 .ok_or_else(|| ConvertError::InvalidStructure("invalid data.pkl entry name".to_string()))?
181 .to_string();
182 let pickle_bytes = read_zip_entry(&mut archive, &data_pkl_name)?;
183 if pickle_bytes.len() > opts.max_pickle_bytes {
184 return Err(ConvertError::ResourceLimitExceeded(format!(
185 "data.pkl is {} bytes, limit is {}",
186 pickle_bytes.len(),
187 opts.max_pickle_bytes
188 )));
189 }
190
191 let root = parse_pickle(&pickle_bytes, opts)?;
192 let metadata = project_root_metadata(&root);
193 let objects = collect_constructor_types(&root);
194 let tensor_refs = extract_state_dict_tensors(&root)?;
195 if tensor_refs.is_empty() {
196 return Err(ConvertError::InvalidStructure(
197 "no tensors found in checkpoint state_dict".to_string(),
198 ));
199 }
200 if tensor_refs.len() > opts.max_tensor_count {
201 return Err(ConvertError::ResourceLimitExceeded(format!(
202 "tensor count {} exceeds limit {}",
203 tensor_refs.len(),
204 opts.max_tensor_count
205 )));
206 }
207
208 let mut storage_blobs: HashMap<String, Vec<u8>> = HashMap::new();
209 for tensor in tensor_refs.values() {
210 let key = &tensor.storage.key;
211 if storage_blobs.contains_key(key) {
212 continue;
213 }
214 let blob = read_storage_blob(&mut archive, &prefix, key)?;
215 let required_bytes = tensor.storage.size_elems * tensor.storage.dtype.elem_size();
216 if blob.len() < required_bytes {
217 return Err(ConvertError::InvalidStructure(format!(
218 "storage {} has {} bytes, expected at least {}",
219 key,
220 blob.len(),
221 required_bytes
222 )));
223 }
224 storage_blobs.insert(key.clone(), blob);
225 }
226
227 let mut tensors = BTreeMap::new();
228 for (name, tensor_ref) in tensor_refs {
229 if opts.strict_contiguous {
230 let expected = contiguous_stride(&tensor_ref.shape);
231 if expected != tensor_ref.stride {
232 return Err(ConvertError::InvalidStructure(format!(
233 "tensor {} has non-contiguous stride {:?}, expected {:?}",
234 name, tensor_ref.stride, expected
235 )));
236 }
237 }
238
239 let elem_size = tensor_ref.storage.dtype.elem_size();
240 let numel = numel(&tensor_ref.shape)?;
241 let start = tensor_ref
242 .offset_elems
243 .checked_mul(elem_size)
244 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte offset overflow".to_string()))?;
245 let byte_len = numel
246 .checked_mul(elem_size)
247 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
248 if byte_len > opts.max_tensor_bytes {
249 return Err(ConvertError::ResourceLimitExceeded(format!(
250 "tensor {} is {} bytes, limit is {}",
251 name, byte_len, opts.max_tensor_bytes
252 )));
253 }
254 let end = start
255 .checked_add(byte_len)
256 .ok_or_else(|| ConvertError::InvalidStructure("tensor slice overflow".to_string()))?;
257
258 let storage = storage_blobs
259 .get(&tensor_ref.storage.key)
260 .ok_or_else(|| ConvertError::InvalidStructure(format!("missing storage blob {}", tensor_ref.storage.key)))?;
261 if end > storage.len() {
262 return Err(ConvertError::InvalidStructure(format!(
263 "tensor {} slice [{}, {}) is out of storage bounds {}",
264 name,
265 start,
266 end,
267 storage.len()
268 )));
269 }
270
271 let raw = storage[start..end].to_vec();
272 let normalized = normalize_tensor_dtype(tensor_ref.storage.dtype, tensor_ref.shape, raw)?;
273 tensors.insert(name, normalized);
274 }
275
276 Ok(ParsedCheckpoint {
277 source_sha256,
278 warnings: Vec::new(),
279 tensors,
280 metadata,
281 objects,
282 })
283}
284
285fn build_checkpoint_metadata(
286 source_file: String,
287 source_sha256: String,
288 metadata: &serde_yaml::Value,
289 objects: &[String],
290 tensors: &BTreeMap<String, TensorData>,
291 safetensors_file: String,
292) -> CheckpointMetadata {
293 CheckpointMetadata {
294 format_version: 1,
295 source_file,
296 source_sha256,
297 safetensors_file,
298 created_at_unix: now_unix_secs(),
299 tensor_count: tensors.len(),
300 total_tensor_bytes: total_tensor_bytes(tensors),
301 metadata: metadata.clone(),
302 objects: objects.to_vec(),
303 tensors: tensor_summaries_for_metadata(tensors),
304 }
305}
306
307fn tensor_summaries_for_metadata(tensors: &BTreeMap<String, TensorData>) -> Vec<CheckpointTensorMetadata> {
308 tensors
309 .iter()
310 .map(|(name, tensor)| CheckpointTensorMetadata {
311 name: name.clone(),
312 dtype: tensor.dtype.as_safetensors().to_string(),
313 shape: tensor.shape.clone(),
314 sha256: sha256_hex(&tensor.bytes),
315 })
316 .collect()
317}
318
319fn total_tensor_bytes(tensors: &BTreeMap<String, TensorData>) -> usize {
320 tensors.values().map(|tensor| tensor.bytes.len()).sum()
321}
322
/// Returns the current time as whole seconds since the Unix epoch, or 0 if
/// the system clock reports a time before the epoch.
fn now_unix_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
329
330fn validate_metadata_against_tensors(
331 metadata: &CheckpointMetadata,
332 tensors: &BTreeMap<String, TensorData>,
333) -> Result<()> {
334 if metadata.tensor_count != tensors.len() {
335 return Err(ConvertError::InvalidStructure(format!(
336 "metadata tensor_count={} does not match loaded tensor count={}",
337 metadata.tensor_count,
338 tensors.len()
339 )));
340 }
341
342 let tensor_bytes = total_tensor_bytes(tensors);
343 if metadata.total_tensor_bytes != tensor_bytes {
344 return Err(ConvertError::InvalidStructure(format!(
345 "metadata total_tensor_bytes={} does not match loaded tensor bytes={}",
346 metadata.total_tensor_bytes, tensor_bytes
347 )));
348 }
349
350 for item in &metadata.tensors {
351 let Some(tensor) = tensors.get(&item.name) else {
352 return Err(ConvertError::InvalidStructure(format!(
353 "metadata references missing tensor {}",
354 item.name
355 )));
356 };
357 if item.dtype != tensor.dtype.as_safetensors() {
358 return Err(ConvertError::InvalidStructure(format!(
359 "metadata dtype mismatch for {}: {} != {}",
360 item.name,
361 item.dtype,
362 tensor.dtype.as_safetensors()
363 )));
364 }
365 if item.shape != tensor.shape {
366 return Err(ConvertError::InvalidStructure(format!(
367 "metadata shape mismatch for {}",
368 item.name
369 )));
370 }
371 if item.sha256 != sha256_hex(&tensor.bytes) {
372 return Err(ConvertError::InvalidStructure(format!(
373 "metadata sha256 mismatch for {}",
374 item.name
375 )));
376 }
377 }
378
379 Ok(())
380}
381
382#[derive(Debug, Deserialize)]
383struct SafetensorHeaderEntry {
384 dtype: String,
385 shape: Vec<usize>,
386 data_offsets: [usize; 2],
387}
388
389fn read_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, TensorData>> {
390 let file_bytes = fs::read(path)?;
391 if file_bytes.len() < 8 {
392 return Err(ConvertError::InvalidStructure(
393 "safetensors file is too short".to_string(),
394 ));
395 }
396
397 let header_len = u64::from_le_bytes(file_bytes[0..8].try_into().expect("8-byte header"));
398 let header_len = header_len as usize;
399 if file_bytes.len() < 8 + header_len {
400 return Err(ConvertError::InvalidStructure(
401 "safetensors header is truncated".to_string(),
402 ));
403 }
404
405 let header_bytes = &file_bytes[8..8 + header_len];
406 let data = &file_bytes[8 + header_len..];
407 let header: serde_json::Map<String, serde_json::Value> = serde_json::from_slice(header_bytes)?;
408
409 let mut tensors = BTreeMap::new();
410 for (name, value) in header {
411 if name == "__metadata__" {
412 continue;
413 }
414 let entry: SafetensorHeaderEntry = serde_json::from_value(value)?;
415 let Some(dtype) = DType::from_safetensors(&entry.dtype) else {
416 return Err(ConvertError::InvalidStructure(format!(
417 "unsupported safetensors dtype {}",
418 entry.dtype
419 )));
420 };
421
422 let start = entry.data_offsets[0];
423 let end = entry.data_offsets[1];
424 if end < start || end > data.len() {
425 return Err(ConvertError::InvalidStructure(format!(
426 "invalid data_offsets for tensor {}",
427 name
428 )));
429 }
430
431 let expected_size = numel(&entry.shape)?
432 .checked_mul(dtype.elem_size())
433 .ok_or_else(|| ConvertError::InvalidStructure("tensor byte length overflow".to_string()))?;
434 if end - start != expected_size {
435 return Err(ConvertError::InvalidStructure(format!(
436 "tensor {} bytes mismatch: {} != {}",
437 name,
438 end - start,
439 expected_size
440 )));
441 }
442
443 tensors.insert(
444 name,
445 TensorData {
446 dtype,
447 shape: entry.shape,
448 bytes: data[start..end].to_vec(),
449 },
450 );
451 }
452
453 if tensors.is_empty() {
454 return Err(ConvertError::InvalidStructure(
455 "no tensors found in safetensors file".to_string(),
456 ));
457 }
458
459 Ok(tensors)
460}
461
462fn tensor_data_to_array(tensor: &TensorData) -> Result<TensorArray> {
463 let shape = IxDyn(&tensor.shape);
464 match tensor.dtype {
465 DType::F16 | DType::BF16 => Err(ConvertError::InvalidStructure(
466 "f16/bf16 should be normalized to f32 before state_dict()".to_string(),
467 )),
468 DType::F32 => {
469 let values = bytes_to_vec::<4, f32>(&tensor.bytes, f32::from_le_bytes)?;
470 Ok(TensorArray::F32(ArrayD::from_shape_vec(shape, values)?))
471 }
472 DType::F64 => {
473 let values = bytes_to_vec::<8, f64>(&tensor.bytes, f64::from_le_bytes)?;
474 Ok(TensorArray::F64(ArrayD::from_shape_vec(shape, values)?))
475 }
476 DType::I8 => {
477 let values = tensor.bytes.iter().map(|v| *v as i8).collect::<Vec<_>>();
478 Ok(TensorArray::I8(ArrayD::from_shape_vec(shape, values)?))
479 }
480 DType::I16 => {
481 let values = bytes_to_vec::<2, i16>(&tensor.bytes, i16::from_le_bytes)?;
482 Ok(TensorArray::I16(ArrayD::from_shape_vec(shape, values)?))
483 }
484 DType::I32 => {
485 let values = bytes_to_vec::<4, i32>(&tensor.bytes, i32::from_le_bytes)?;
486 Ok(TensorArray::I32(ArrayD::from_shape_vec(shape, values)?))
487 }
488 DType::I64 => {
489 let values = bytes_to_vec::<8, i64>(&tensor.bytes, i64::from_le_bytes)?;
490 Ok(TensorArray::I64(ArrayD::from_shape_vec(shape, values)?))
491 }
492 DType::U8 => Ok(TensorArray::U8(ArrayD::from_shape_vec(shape, tensor.bytes.clone())?)),
493 DType::Bool => {
494 let values = tensor.bytes.iter().map(|v| *v != 0).collect::<Vec<_>>();
495 Ok(TensorArray::Bool(ArrayD::from_shape_vec(shape, values)?))
496 }
497 }
498}
499
500fn bytes_to_vec<const N: usize, T>(bytes: &[u8], f: impl Fn([u8; N]) -> T) -> Result<Vec<T>> {
501 if bytes.len() % N != 0 {
502 return Err(ConvertError::InvalidStructure(format!(
503 "tensor bytes are not divisible by {}",
504 N
505 )));
506 }
507
508 Ok(
509 bytes
510 .chunks_exact(N)
511 .map(|chunk| {
512 let mut arr = [0u8; N];
513 arr.copy_from_slice(chunk);
514 f(arr)
515 })
516 .collect(),
517 )
518}
519
520fn normalize_tensor_dtype(dtype: DType, shape: Vec<usize>, bytes: Vec<u8>) -> Result<TensorData> {
521 match dtype {
522 DType::F16 => Ok(TensorData {
523 dtype: DType::F32,
524 shape,
525 bytes: f16_bytes_to_f32_bytes(&bytes)?,
526 }),
527 DType::BF16 => Ok(TensorData {
528 dtype: DType::F32,
529 shape,
530 bytes: bf16_bytes_to_f32_bytes(&bytes)?,
531 }),
532 _ => Ok(TensorData { dtype, shape, bytes }),
533 }
534}
535
536fn f16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
537 if input.len() % 2 != 0 {
538 return Err(ConvertError::InvalidStructure(
539 "f16 tensor bytes must be even-length".to_string(),
540 ));
541 }
542 let mut out = Vec::with_capacity(input.len() * 2);
543 for chunk in input.chunks_exact(2) {
544 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
545 let value = f16_bits_to_f32(bits);
546 out.extend_from_slice(&value.to_le_bytes());
547 }
548 Ok(out)
549}
550
551fn bf16_bytes_to_f32_bytes(input: &[u8]) -> Result<Vec<u8>> {
552 if input.len() % 2 != 0 {
553 return Err(ConvertError::InvalidStructure(
554 "bf16 tensor bytes must be even-length".to_string(),
555 ));
556 }
557 let mut out = Vec::with_capacity(input.len() * 2);
558 for chunk in input.chunks_exact(2) {
559 let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
560 let value = f32::from_bits((bits as u32) << 16);
561 out.extend_from_slice(&value.to_le_bytes());
562 }
563 Ok(out)
564}
565
/// Converts an IEEE 754 half-precision bit pattern to the `f32` with the
/// same value.
///
/// Zeros keep their sign, subnormals are renormalized, and infinities/NaNs
/// keep their fraction bits (shifted into the wider fraction field).
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign_bit = u32::from(bits >> 15) << 31;
    let exponent = u32::from((bits >> 10) & 0x1f);
    let fraction = u32::from(bits & 0x03ff);

    let widened = match (exponent, fraction) {
        // Signed zero.
        (0, 0) => sign_bit,
        // Subnormal: shift the fraction left until its leading one reaches
        // bit 10 (the implicit-one position), lowering the exponent by one
        // for every shift.
        (0, _) => {
            let shift = fraction.leading_zeros() - 21;
            let normalized = (fraction << shift) & 0x03ff;
            // Unbiased exponent is -14 - shift; re-biased for f32 that is 113 - shift.
            let exponent32 = 113 - shift;
            sign_bit | (exponent32 << 23) | (normalized << 13)
        }
        // Infinity or NaN: all-ones exponent in both formats.
        (0x1f, _) => sign_bit | 0x7f80_0000 | (fraction << 13),
        // Normal number: re-bias the exponent from 15 to 127.
        _ => sign_bit | ((exponent + 112) << 23) | (fraction << 13),
    };

    f32::from_bits(widened)
}
594
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{collect_constructor_types, project_value_for_metadata};
    use crate::types::Value;
    use std::io::Write;
    use tempfile::tempdir;
    use zip::write::SimpleFileOptions;
    use zip::ZipWriter;

    // Pickle opcodes used by the hand-assembled fixtures below.
    const PROTO: u8 = 0x80;
    const EMPTY_DICT: u8 = b'}';
    const MARK: u8 = b'(';
    const BINUNICODE: u8 = b'X';
    const GLOBAL: u8 = b'c';
    const BININT1: u8 = b'K';
    const TUPLE: u8 = b't';
    const BINPERSID: u8 = b'Q';
    const NEWFALSE: u8 = 0x89;
    const NONE: u8 = b'N';
    const REDUCE: u8 = b'R';
    const SETITEMS: u8 = b'u';
    const STOP: u8 = b'.';

    /// End-to-end: a minimal checkpoint loads and exports weights + YAML.
    #[test]
    fn converts_simple_tensor_checkpoint() {
        let workspace = tempdir().expect("tmp dir");
        let checkpoint_path = workspace.path().join("weights.pt");
        write_fixture_checkpoint(&checkpoint_path, false).expect("fixture checkpoint");

        let export_dir = workspace.path().join("export");
        let loaded =
            PtCheckpoint::load(&checkpoint_path, LoadOptions::default()).expect("checkpoint load should work");
        let exported = loaded
            .export(
                &export_dir,
                ExportOptions::new(ExportFormat::Safetensors, Some(&checkpoint_path)),
            )
            .expect("export should work");

        assert!(exported.weights_path.exists());
        let yaml_path = exported.metadata_path.as_ref().expect("metadata path");
        assert!(yaml_path.exists());
        assert_eq!(exported.tensor_count, 1);

        let yaml = fs::read_to_string(yaml_path).expect("yaml readable");
        assert!(yaml.contains("layer.weight"));
        let mentions_f32 = ["dtype: F32", "dtype: 'F32'"]
            .iter()
            .any(|needle| yaml.contains(*needle));
        assert!(mentions_f32);
    }

    /// A pickle that reduces over `os.system` must be refused at load time.
    #[test]
    fn rejects_unsafe_global_reduce() {
        let workspace = tempdir().expect("tmp dir");
        let checkpoint_path = workspace.path().join("unsafe.pt");
        write_fixture_checkpoint(&checkpoint_path, true).expect("fixture checkpoint");

        let message = PtCheckpoint::load(&checkpoint_path, LoadOptions::default())
            .expect_err("unsafe pickle should fail")
            .to_string();
        assert!(message.contains("unsupported GLOBAL") || message.contains("unsafe/unsupported pickle"));
    }

    #[test]
    fn projects_object_metadata_with_type_args_and_flattened_state() {
        let constructor_args = Value::Tuple(vec![Value::String("arg0".to_string()), Value::Int(42)]);
        let instance_state = Value::Dict(vec![(Value::String("training".to_string()), Value::Bool(false))]);
        let value = Value::Object {
            module: "ultralytics.nn.tasks".to_string(),
            name: "DetectionModel".to_string(),
            args: Some(Box::new(constructor_args)),
            state: Some(Box::new(instance_state)),
        };

        let mapping = expect_mapping(project_value_for_metadata(&value));

        assert_eq!(mapping.get(&yaml_key("$type")), Some(&yaml_key("object")));
        assert_eq!(
            mapping.get(&yaml_key("$class")),
            Some(&yaml_key("ultralytics.nn.tasks.DetectionModel"))
        );
        assert!(mapping.get(&yaml_key("$args")).is_some());
        assert_eq!(
            mapping.get(&yaml_key("training")),
            Some(&serde_yaml::Value::Bool(false))
        );
    }

    #[test]
    fn omits_empty_object_args() {
        let value = Value::Object {
            module: "a".to_string(),
            name: "B".to_string(),
            args: Some(Box::new(Value::Tuple(Vec::new()))),
            state: None,
        };

        let mapping = expect_mapping(project_value_for_metadata(&value));
        assert!(!mapping.contains_key(&yaml_key("$args")));
    }

    #[test]
    fn collects_constructor_types_deduplicated_in_first_seen_order() {
        let tree = Value::List(vec![
            bare_object("a", "One"),
            Value::Dict(vec![(Value::String("nested".to_string()), bare_object("b", "Two"))]),
            bare_object("a", "One"),
        ]);

        let objects = collect_constructor_types(&tree);
        assert_eq!(objects, vec!["a.One".to_string(), "b.Two".to_string()]);
    }

    /// Builds a YAML string value; used both as mapping key and expected value.
    fn yaml_key(text: &str) -> serde_yaml::Value {
        serde_yaml::Value::String(text.to_string())
    }

    /// Unwraps a projected value that must be a YAML mapping.
    fn expect_mapping(projected: serde_yaml::Value) -> serde_yaml::Mapping {
        match projected {
            serde_yaml::Value::Mapping(map) => map,
            other => panic!("expected mapping, got {:?}", other),
        }
    }

    /// An object value with neither constructor args nor state.
    fn bare_object(module: &str, name: &str) -> Value {
        Value::Object {
            module: module.to_string(),
            name: name.to_string(),
            args: None,
            state: None,
        }
    }

    /// Writes a zip checkpoint with `archive/data.pkl` and one storage blob
    /// (`archive/data/0`) holding four little-endian f32 values.
    fn write_fixture_checkpoint(path: &Path, unsafe_payload: bool) -> Result<()> {
        let mut zip = ZipWriter::new(File::create(path)?);
        let options = SimpleFileOptions::default();

        let data_pkl = match unsafe_payload {
            true => build_unsafe_pickle(),
            false => build_safe_pickle(),
        };
        zip.start_file("archive/data.pkl", options)?;
        zip.write_all(&data_pkl)?;

        let storage: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|value| value.to_le_bytes())
            .collect();
        zip.start_file("archive/data/0", options)?;
        zip.write_all(&storage)?;

        zip.finish()?;
        Ok(())
    }

    /// Protocol-2 pickle of `{"layer.weight": _rebuild_tensor_v2(...)}` whose
    /// storage is a persistent-id reference to float storage "0" (4 elements),
    /// with offset 0, shape (2, 2) and stride (2, 1).
    fn build_safe_pickle() -> Vec<u8> {
        let mut out = vec![PROTO, 0x02];

        // Open the state dict and its single key.
        out.extend_from_slice(&[EMPTY_DICT, MARK]);
        push_binunicode(&mut out, "layer.weight");

        // Callable plus the start of its argument tuple.
        push_global(&mut out, "torch._utils", "_rebuild_tensor_v2");
        out.push(MARK);

        // Argument 0: persistent id tuple resolved via BINPERSID.
        out.push(MARK);
        push_binunicode(&mut out, "storage");
        push_global(&mut out, "torch", "FloatStorage");
        push_binunicode(&mut out, "0");
        push_binunicode(&mut out, "cpu");
        push_binint1(&mut out, 4);
        out.extend_from_slice(&[TUPLE, BINPERSID]);

        // Argument 1: storage offset.
        push_binint1(&mut out, 0);
        // Arguments 2 and 3: shape and stride.
        push_small_int_tuple(&mut out, &[2, 2]);
        push_small_int_tuple(&mut out, &[2, 1]);
        // Arguments 4 and 5: False and None.
        out.extend_from_slice(&[NEWFALSE, NONE]);

        // Close the argument tuple, call, store into the dict, finish.
        out.extend_from_slice(&[TUPLE, REDUCE, SETITEMS, STOP]);
        out
    }

    /// Protocol-2 pickle that would call `os.system("echo hacked")`.
    fn build_unsafe_pickle() -> Vec<u8> {
        let mut out = vec![PROTO, 0x02];
        push_global(&mut out, "os", "system");
        out.push(MARK);
        push_binunicode(&mut out, "echo hacked");
        out.extend_from_slice(&[TUPLE, REDUCE, STOP]);
        out
    }

    /// Appends a BINUNICODE opcode: 4-byte little-endian length, then UTF-8.
    fn push_binunicode(out: &mut Vec<u8>, value: &str) {
        out.push(BINUNICODE);
        out.extend_from_slice(&(value.len() as u32).to_le_bytes());
        out.extend_from_slice(value.as_bytes());
    }

    /// Appends a GLOBAL opcode: newline-terminated module and attribute names.
    fn push_global(out: &mut Vec<u8>, module: &str, attr: &str) {
        out.push(GLOBAL);
        out.extend_from_slice(module.as_bytes());
        out.push(b'\n');
        out.extend_from_slice(attr.as_bytes());
        out.push(b'\n');
    }

    /// Appends a BININT1 opcode carrying a one-byte unsigned integer.
    fn push_binint1(out: &mut Vec<u8>, value: u8) {
        out.extend_from_slice(&[BININT1, value]);
    }

    /// Appends `MARK`, one BININT1 per item, then `TUPLE`.
    fn push_small_int_tuple(out: &mut Vec<u8>, items: &[u8]) {
        out.push(MARK);
        for &item in items {
            push_binint1(out, item);
        }
        out.push(TUPLE);
    }
}