use crate::TensorSnapshot;

use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_core::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer};
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::path::Path;
use std::sync::Arc;

use super::lazy_data::LazyDataSource;
use super::pickle_reader::{Object, PickleError, read_pickle, read_pickle_with_data};

/// Errors that can occur while reading PyTorch files.
#[derive(Debug)]
pub enum PytorchError {
    /// Underlying I/O failure.
    Io(std::io::Error),
    /// Failure while parsing the embedded pickle stream.
    Pickle(PickleError),
    /// Failure while reading a ZIP archive.
    Zip(zip::result::ZipError),
    /// Failure while reading a TAR archive.
    Tar(std::io::Error),
    /// The file does not match a supported PyTorch format.
    InvalidFormat(String),
    /// A requested key was not found in the file.
    KeyNotFound(String),
    /// Failure while deserializing pickle data into a target type.
    Serde(burn_core::record::serde::error::Error),
}

impl From<std::io::Error> for PytorchError {
    fn from(e: std::io::Error) -> Self {
        PytorchError::Io(e)
    }
}

impl From<PickleError> for PytorchError {
    fn from(e: PickleError) -> Self {
        PytorchError::Pickle(e)
    }
}

impl From<zip::result::ZipError> for PytorchError {
    fn from(e: zip::result::ZipError) -> Self {
        PytorchError::Zip(e)
    }
}

impl From<burn_core::record::serde::error::Error> for PytorchError {
    fn from(e: burn_core::record::serde::error::Error) -> Self {
        PytorchError::Serde(e)
    }
}

impl std::fmt::Display for PytorchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PytorchError::Io(e) => write!(f, "IO error: {}", e),
            PytorchError::Pickle(e) => write!(
                f,
                "Pickle parsing error: {}. This may indicate an unsupported PyTorch file format or a corrupted file.",
                e
            ),
            PytorchError::Zip(e) => write!(f, "Zip archive error: {}", e),
            PytorchError::Tar(e) => write!(f, "TAR archive error: {}", e),
            PytorchError::InvalidFormat(msg) => write!(f, "Invalid PyTorch file format: {}", msg),
            PytorchError::KeyNotFound(key) => write!(
                f,
                "Key '{}' not found in PyTorch file. Available keys can be listed with the keys() method.",
                key
            ),
            PytorchError::Serde(e) => write!(f, "Serde deserialization error: {}", e),
        }
    }
}

impl std::error::Error for PytorchError {}

type Result<T> = std::result::Result<T, PytorchError>;

/// Metadata about a PyTorch file and how its tensors are stored.
#[derive(Debug, Clone)]
pub struct PytorchMetadata {
    /// Format version string from the `.format_version` archive entry, if present.
    pub format_version: Option<String>,
    /// Detected container format.
    pub format_type: FileFormat,
    /// Byte order of the stored tensor data.
    pub byte_order: ByteOrder,
    /// Whether the archive contains a `.storage_alignment` entry.
    pub has_storage_alignment: bool,
    /// PyTorch version string from the `version` archive entry, if present.
    pub pytorch_version: Option<String>,
    /// Number of tensors found in the file.
    pub tensor_count: usize,
    /// Total size in bytes of the tensor data, if known.
    pub total_data_size: Option<usize>,
}

impl PytorchMetadata {
    /// Returns `true` if the file uses the modern ZIP-based format (PyTorch 1.6+).
    pub fn is_modern_format(&self) -> bool {
        matches!(self.format_type, FileFormat::Zip)
    }

    /// Returns `true` if the file uses the legacy sequential pickle format.
    pub fn is_legacy_format(&self) -> bool {
        matches!(self.format_type, FileFormat::Legacy)
    }
}

/// Container format of a PyTorch file.
#[derive(Debug, Clone, PartialEq)]
pub enum FileFormat {
    /// Modern ZIP archive format (PyTorch 1.6+).
    Zip,
    /// TAR archive format used by very old PyTorch releases.
    Tar,
    /// Legacy sequential pickle format.
    Legacy,
    /// Plain pickle file without separate tensor storage.
    Pickle,
}

/// Byte order of stored tensor data.
#[derive(Debug, Clone, PartialEq)]
pub enum ByteOrder {
    LittleEndian,
    BigEndian,
}

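/// Reader for PyTorch checkpoint files (`.pt`/`.pth`), with automatic format detection.
///
/// A minimal usage sketch; the file name is hypothetical and the import path
/// depends on how this module is re-exported by the crate:
///
/// ```ignore
/// let reader = PytorchReader::new("model.pt").expect("readable PyTorch file");
/// println!("{} tensors ({:?})", reader.len(), reader.metadata().format_type);
/// for name in reader.keys() {
///     println!("{name}");
/// }
/// ```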
pub struct PytorchReader {
    tensors: HashMap<String, TensorSnapshot>,
    metadata: PytorchMetadata,
}

impl PytorchReader {
    /// Opens a PyTorch file and loads all tensors, detecting the format automatically.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), None)?;
        Ok(Self { tensors, metadata })
    }

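    /// Opens a PyTorch file and loads only the tensors nested under the given
    /// top-level key (commonly `"state_dict"` in training checkpoints).
    ///
    /// Sketch, with a hypothetical file and key:
    ///
    /// ```ignore
    /// let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")
    ///     .expect("checkpoint with a state_dict entry");
    /// ```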
    pub fn with_top_level_key<P: AsRef<Path>>(path: P, key: &str) -> Result<Self> {
        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), Some(key))?;
        Ok(Self { tensors, metadata })
    }

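    /// Loads tensors from an arbitrary reader containing a plain pickle stream.
    ///
    /// Because there is no backing file, only minimal metadata is available.
    /// Sketch, assuming the bytes hold a pickled state dict:
    ///
    /// ```ignore
    /// let bytes: Vec<u8> = std::fs::read("weights.pkl").unwrap();
    /// let reader = PytorchReader::from_reader(bytes.as_slice(), None).unwrap();
    /// ```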
    pub fn from_reader<R: Read>(reader: R, top_level_key: Option<&str>) -> Result<Self> {
        let tensors = load_from_reader(reader, top_level_key)?;
        // Plain pickle streams carry no archive metadata, so synthesize defaults.
        let metadata = PytorchMetadata {
            format_version: None,
            format_type: FileFormat::Pickle,
            byte_order: ByteOrder::LittleEndian,
            has_storage_alignment: false,
            pytorch_version: None,
            tensor_count: tensors.len(),
            total_data_size: None,
        };
        Ok(Self { tensors, metadata })
    }

    /// Returns the names of all tensors in the file.
    pub fn keys(&self) -> Vec<String> {
        self.tensors.keys().cloned().collect()
    }

    /// Returns the tensor with the given name, if present.
    pub fn get(&self, name: &str) -> Option<&TensorSnapshot> {
        self.tensors.get(name)
    }

    /// Returns all tensors keyed by name.
    pub fn tensors(&self) -> &HashMap<String, TensorSnapshot> {
        &self.tensors
    }

    /// Consumes the reader and returns the tensors.
    pub fn into_tensors(self) -> HashMap<String, TensorSnapshot> {
        self.tensors
    }

    /// Returns metadata collected while reading the file.
    pub fn metadata(&self) -> &PytorchMetadata {
        &self.metadata
    }

    /// Returns the number of tensors in the file.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` if the file contains no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

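    /// Reads the pickle payload of a PyTorch file as a plain [`PickleValue`]
    /// tree, without materializing tensor data. Useful for inspecting
    /// hyperparameters or other non-tensor entries stored in a checkpoint.
    ///
    /// Sketch, with a hypothetical checkpoint layout:
    ///
    /// ```ignore
    /// let value = PytorchReader::read_pickle_data("checkpoint.pt", None).unwrap();
    /// if let PickleValue::Dict(map) = value {
    ///     println!("top-level entries: {:?}", map.keys());
    /// }
    /// ```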
    pub fn read_pickle_data<P: AsRef<Path>>(
        path: P,
        top_level_key: Option<&str>,
    ) -> Result<PickleValue> {
        read_pickle_as_value(path.as_ref(), top_level_key)
    }

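    /// Reads a PyTorch file and deserializes its pickle data into a typed
    /// configuration struct.
    ///
    /// Sketch, assuming a hypothetical `ModelConfig` that implements
    /// `serde::Deserialize` and matches the stored dictionary:
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct ModelConfig {
    ///     hidden_size: usize,
    ///     num_layers: usize,
    /// }
    ///
    /// let config: ModelConfig =
    ///     PytorchReader::load_config("checkpoint.pt", Some("config")).unwrap();
    /// ```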
    pub fn load_config<D, P>(path: P, top_level_key: Option<&str>) -> Result<D>
    where
        D: DeserializeOwned,
        P: AsRef<Path>,
    {
        // Read the pickle payload as a generic value tree.
        let pickle_value = Self::read_pickle_data(path, top_level_key)?;
        // Convert it into burn's NestedValue representation.
        let nested_value = convert_pickle_to_nested_value(pickle_value)?;
        // Deserialize into the target type.
        let deserializer = Deserializer::<DefaultAdapter>::new(nested_value, false);
        let value = D::deserialize(deserializer)?;
        Ok(value)
    }
}

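/// A plain, tensor-free representation of pickle data.
///
/// Sketch of walking a value tree (hypothetical key name):
///
/// ```ignore
/// fn learning_rate(value: &PickleValue) -> Option<f64> {
///     match value {
///         PickleValue::Dict(map) => match map.get("lr")? {
///             PickleValue::Float(f) => Some(*f),
///             PickleValue::Int(i) => Some(*i as f64),
///             _ => None,
///         },
///         _ => None,
///     }
/// }
/// ```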
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    /// Python `None`.
    None,
    /// Boolean value.
    Bool(bool),
    /// Integer value.
    Int(i64),
    /// Floating-point value.
    Float(f64),
    /// String value.
    String(String),
    /// List or tuple of values.
    List(Vec<PickleValue>),
    /// Dictionary with string keys.
    Dict(HashMap<String, PickleValue>),
    /// Raw bytes (e.g. persistent-id payloads).
    Bytes(Vec<u8>),
}

fn load_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    // Try the modern ZIP-based format first (PyTorch 1.6+).
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        let mut pickle_data = Vec::new();
        let mut pickle_found = false;

        // Common locations of the pickle inside the archive.
        let possible_pickle_paths = ["data.pkl", "archive/data.pkl"];

        for pickle_path in &possible_pickle_paths {
            if archive.by_name(pickle_path).is_ok() {
                let mut pickle_file = archive.by_name(pickle_path)?;
                pickle_file.read_to_end(&mut pickle_data)?;
                pickle_found = true;
                break;
            }
        }

        if !pickle_found {
            // Fall back to scanning for any entry that ends with data.pkl.
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                // Release the borrow on the archive before re-indexing it.
                drop(file);

                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    pickle_found = true;
                    break;
                }
            }
        }

        if !pickle_found {
            return Err(PytorchError::InvalidFormat(
                "No data.pkl file found in ZIP archive. Expected PyTorch 1.6+ format with data.pkl or archive/data.pkl".to_string(),
            ));
        }

        // Optional archive entries carrying format metadata.
        let format_version = if let Ok(mut version_file) = archive.by_name(".format_version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            let version_str = String::from_utf8_lossy(&version_data);
            let version = version_str.trim().to_string();
            Some(version)
        } else {
            None
        };

        let is_big_endian = if let Ok(mut byteorder_file) = archive.by_name("byteorder") {
            let mut byteorder_data = Vec::new();
            byteorder_file.read_to_end(&mut byteorder_data)?;
            let byteorder_str = String::from_utf8_lossy(&byteorder_data);
            byteorder_str.trim() == "big"
        } else {
            // Files without a byteorder entry are little-endian by default.
            false
        };

        if is_big_endian {
            return Err(PytorchError::InvalidFormat(
                "Big-endian PyTorch files are not yet supported. The file was saved on a big-endian system and requires byte order conversion.".to_string()
            ));
        }

        let has_storage_alignment = archive.by_name(".storage_alignment").is_ok();

        let pytorch_version = if let Ok(mut version_file) = archive.by_name("version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            Some(String::from_utf8_lossy(&version_data).trim().to_string())
        } else {
            None
        };

        // Tensor data is read lazily from the archive on demand.
        let data_source = Arc::new(LazyDataSource::from_zip(path)?);

        // Sum the sizes of all tensor data entries for the metadata.
        let mut total_data_size = 0usize;
        for i in 0..archive.len() {
            let file = archive.by_index(i)?;
            let name = file.name();

            let is_data_file = (name.contains("/data/")
                || name.starts_with("data/")
                || name.starts_with("archive/data/"))
                && !name.ends_with(".pkl")
                && !name.ends_with("/");

            if is_data_file {
                total_data_size += file.size() as usize;
            }
        }

        let mut pickle_reader = BufReader::new(pickle_data.as_slice());
        let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;

        let tensors = extract_tensors_with_data(obj, top_level_key)?;

        let metadata = PytorchMetadata {
            format_version,
            format_type: FileFormat::Zip,
            byte_order: if is_big_endian {
                ByteOrder::BigEndian
            } else {
                ByteOrder::LittleEndian
            },
            has_storage_alignment,
            pytorch_version,
            tensor_count: tensors.len(),
            total_data_size: Some(total_data_size),
        };

        return Ok((tensors, metadata));
    }

    // Check for the old TAR-based format.
    if is_tar_file(path) {
        return load_tar_pytorch_file_with_metadata(path, top_level_key);
    }

    // Peek at the first bytes to detect the legacy sequential pickle format.
    let mut file = File::open(path)?;

    let mut header = [0u8; 15];
    let bytes_read = file.read(&mut header)?;
    file.seek(SeekFrom::Start(0))?;

    // Legacy files begin with a pickled magic number: PROTO 2 (0x80 0x02),
    // LONG1 of 10 bytes (0x8a 0x0a), the value 0x1950a86a20f9469cfc6c in
    // little-endian byte order, and STOP (0x2e).
    let is_legacy_format = bytes_read >= 15
        && header[0] == 0x80
        && header[1] == 0x02
        && header[2] == 0x8a
        && header[3] == 0x0a
        && header[4] == 0x6c
        && header[5] == 0xfc
        && header[6] == 0x9c
        && header[7] == 0x46
        && header[8] == 0xf9
        && header[9] == 0x20
        && header[10] == 0x6a
        && header[11] == 0xa8
        && header[12] == 0x50
        && header[13] == 0x19
        && header[14] == 0x2e;

    if is_legacy_format {
        return load_legacy_pytorch_file_with_metadata(path, top_level_key);
    }

    // Fall back to reading the file as a plain pickle stream (no tensor storage).
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    match read_pickle(&mut reader) {
        Ok(obj) => {
            let tensors = extract_tensors_with_data(obj, top_level_key)?;
            let tensor_count = tensors.len();
            Ok((
                tensors,
                PytorchMetadata {
                    format_version: None,
                    format_type: FileFormat::Pickle,
                    byte_order: ByteOrder::LittleEndian,
                    has_storage_alignment: false,
                    pytorch_version: None,
                    tensor_count,
                    total_data_size: None,
                },
            ))
        }
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            Err(PytorchError::InvalidFormat(
                "Pickle file contains tensor data but no data source is available. This file should be loaded as ZIP or legacy format.".to_string()
            ))
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}

fn load_from_reader<R: Read>(
    reader: R,
    top_level_key: Option<&str>,
) -> Result<HashMap<String, TensorSnapshot>> {
    let mut buf_reader = BufReader::new(reader);

    match read_pickle(&mut buf_reader) {
        Ok(obj) => extract_tensors_with_data(obj, top_level_key),
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            Err(PytorchError::InvalidFormat(
                "Reader contains tensor data but no data source is available. Use file-based loading instead.".to_string()
            ))
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}

fn extract_tensors_with_data(
    obj: Object,
    top_level_key: Option<&str>,
) -> Result<HashMap<String, TensorSnapshot>> {
    let dict = match obj {
        Object::Dict(dict) => {
            if let Some(key) = top_level_key {
                // Descend into the requested top-level entry.
                match dict.get(key) {
                    Some(Object::Dict(nested)) => nested.clone(),
                    _ => {
                        return Err(PytorchError::KeyNotFound(format!(
                            "Top-level key '{}' not found or is not a dictionary. Available top-level keys in file: {:?}",
                            key,
                            dict.keys().collect::<Vec<_>>()
                        )));
                    }
                }
            } else {
                dict
            }
        }
        _ => {
            return Err(PytorchError::InvalidFormat(
                "Expected a dictionary at the root of the PyTorch file, but found a different type. The file may be a full model save rather than a state_dict.".to_string(),
            ));
        }
    };

    let mut tensors = HashMap::new();
    let mut path = Vec::new();
    extract_tensors_recursive(&Object::Dict(dict), &mut path, &mut tensors);
    Ok(tensors)
}

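/// Recursively walks nested dictionaries, collecting tensors under
/// dot-separated paths.
///
/// Illustration of the flattening (hypothetical names):
///
/// ```text
/// {"encoder": {"weight": <tensor>, "bias": <tensor>}}
///   -> "encoder.weight", "encoder.bias"
/// ```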
fn extract_tensors_recursive<'a>(
    obj: &'a Object,
    path: &mut Vec<&'a str>,
    tensors: &mut HashMap<String, TensorSnapshot>,
) {
    match obj {
        Object::Dict(dict) => {
            for (key, value) in dict {
                path.push(key);
                extract_tensors_recursive(value, path, tensors);
                path.pop();
            }
        }
        Object::TorchParam(snapshot) => {
            tensors.insert(path.join("."), snapshot.clone());
        }
        _ => {}
    }
}

fn load_legacy_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    // The legacy format starts with three small pickles: a magic number,
    // a protocol version, and system info.
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read magic number from legacy format: {}",
            e
        ))
    })?;

    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read protocol version from legacy format: {}",
            e
        ))
    })?;

    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read system info from legacy format: {}",
            e
        ))
    })?;

    // Remember where the main pickle starts so it can be re-read with a data source.
    let main_pickle_pos = reader.stream_position()?;

    // First pass: skip over the main object to reach the storage key list.
    use crate::pytorch::pickle_reader::skip_pickle;
    skip_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to skip main object in legacy format: {}",
            e
        ))
    })?;

    // The storage keys follow the main object and define the order of the
    // raw storage blocks at the end of the file.
    let storage_keys = match read_pickle(&mut reader) {
        Ok(Object::List(keys)) => keys
            .into_iter()
            .filter_map(|obj| match obj {
                Object::String(s) => Some(s),
                _ => None,
            })
            .collect::<Vec<_>>(),
        _ => vec![],
    };

    // Everything after the storage keys is raw tensor data.
    let data_start_pos = reader.stream_position()?;
    let file_size = reader.seek(SeekFrom::End(0))?;
    let data_size = file_size - data_start_pos;

    let data_source = Arc::new(LazyDataSource::from_legacy_multi_storage(
        path,
        data_start_pos,
        data_size,
    ));

    // Hand the storage order to the data source so offsets can be resolved.
    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source
        && !storage_keys.is_empty()
    {
        let source = source
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        source.set_storage_keys(storage_keys.clone());
    }

    // Second pass: re-read the main pickle, now with tensor data available.
    reader.seek(SeekFrom::Start(main_pickle_pos))?;
    let main_obj = read_pickle_with_data(&mut reader, data_source.clone())?;

    let tensors = extract_tensors_with_data(main_obj, top_level_key)?;

    let metadata = PytorchMetadata {
        format_version: None,
        format_type: FileFormat::Legacy,
        byte_order: ByteOrder::LittleEndian,
        has_storage_alignment: false,
        pytorch_version: None,
        tensor_count: tensors.len(),
        total_data_size: Some(data_size as usize),
    };

    Ok((tensors, metadata))
}

/// Checks for the POSIX tar "ustar" magic at byte offset 257.
fn is_tar_file(path: &Path) -> bool {
    if let Ok(mut file) = File::open(path) {
        let mut header = [0u8; 263];
        if file.read_exact(&mut header).is_ok() {
            return &header[257..262] == b"ustar";
        }
    }
    false
}

fn load_tar_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    use tar::Archive;

    let file = File::open(path)?;
    let mut archive = Archive::new(BufReader::new(file));

    // The TAR format stores three entries: sys_info, pickle, and storages.
    let mut sys_info_data: Option<Vec<u8>> = None;
    let mut pickle_data: Option<Vec<u8>> = None;
    let mut storages_data: Option<Vec<u8>> = None;

    for entry in archive.entries().map_err(PytorchError::Tar)? {
        let mut entry = entry.map_err(PytorchError::Tar)?;
        let entry_path = entry
            .path()
            .map_err(PytorchError::Tar)?
            .to_string_lossy()
            .to_string();

        // Skip PAX extended-header entries.
        if entry_path.contains("@PaxHeader") {
            continue;
        }

        // Entry paths may be prefixed with "./".
        let normalized = entry_path.trim_start_matches("./");

        match normalized {
            "sys_info" => {
                let mut data = Vec::new();
                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
                sys_info_data = Some(data);
            }
            "pickle" => {
                let mut data = Vec::new();
                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
                pickle_data = Some(data);
            }
            "storages" => {
                let mut data = Vec::new();
                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
                storages_data = Some(data);
            }
            _ => {}
        }
    }

    let pickle_data = pickle_data.ok_or_else(|| {
        PytorchError::InvalidFormat("TAR file missing 'pickle' entry".to_string())
    })?;
    let storages_data = storages_data.ok_or_else(|| {
        PytorchError::InvalidFormat("TAR file missing 'storages' entry".to_string())
    })?;

    let is_little_endian = if let Some(ref data) = sys_info_data {
        parse_tar_sys_info(data)?
    } else {
        // Without sys_info, assume little-endian.
        true
    };

    if !is_little_endian {
        return Err(PytorchError::InvalidFormat(
            "Big-endian TAR PyTorch files are not supported".to_string(),
        ));
    }

    let data_source = Arc::new(LazyDataSource::from_tar(&storages_data)?);

    let mut pickle_reader = BufReader::new(pickle_data.as_slice());
    let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;

    let tensors = extract_tensors_with_data(obj, top_level_key)?;

    let metadata = PytorchMetadata {
        format_version: None,
        format_type: FileFormat::Tar,
        byte_order: ByteOrder::LittleEndian,
        has_storage_alignment: false,
        pytorch_version: None,
        tensor_count: tensors.len(),
        total_data_size: Some(storages_data.len()),
    };

    Ok((tensors, metadata))
}

/// Parses the `sys_info` pickle and returns whether the data is little-endian.
fn parse_tar_sys_info(data: &[u8]) -> Result<bool> {
    let mut reader = BufReader::new(data);
    let obj = read_pickle(&mut reader)?;

    if let Object::Dict(dict) = obj
        && let Some(Object::Bool(little_endian)) = dict.get("little_endian")
    {
        return Ok(*little_endian);
    }

    // Default to little-endian when the flag is absent.
    Ok(true)
}

fn read_pickle_as_value(path: &Path, top_level_key: Option<&str>) -> Result<PickleValue> {
    // Try the ZIP-based format first.
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        let mut pickle_data = Vec::new();

        for pickle_path in &["data.pkl", "archive/data.pkl"] {
            if let Ok(mut pickle_file) = archive.by_name(pickle_path) {
                pickle_file.read_to_end(&mut pickle_data)?;
                break;
            }
        }

        if pickle_data.is_empty() {
            // Fall back to scanning for any entry that ends with data.pkl.
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                drop(file);

                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    break;
                }
            }
        }

        if !pickle_data.is_empty() {
            let data_source = Arc::new(LazyDataSource::from_zip(path)?);

            let mut reader = BufReader::new(pickle_data.as_slice());
            let obj = read_pickle_with_data(&mut reader, data_source)?;
            return convert_object_to_value(obj, top_level_key);
        }
    }

    // Not a ZIP archive: read as a plain pickle stream.
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    match read_pickle(&mut reader) {
        Ok(obj) => convert_object_to_value(obj, top_level_key),
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            // The pickle references tensor storages; load through the full
            // reader and expose the tensor names as placeholder strings.
            let reader = PytorchReader::new(path)?;

            let mut result = HashMap::new();
            for key in reader.keys() {
                result.insert(
                    key.clone(),
                    PickleValue::String(format!("<Tensor:{}>", key)),
                );
            }

            if let Some(key) = top_level_key {
                Ok(PickleValue::Dict(
                    [(key.to_string(), PickleValue::Dict(result))]
                        .into_iter()
                        .collect(),
                ))
            } else {
                Ok(PickleValue::Dict(result))
            }
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}

/// Converts a pickle `Object` into a `PickleValue`, optionally descending into
/// a top-level dictionary key first.
fn convert_object_to_value(obj: Object, top_level_key: Option<&str>) -> Result<PickleValue> {
    if let Some(key) = top_level_key
        && let Object::Dict(dict) = obj
    {
        if let Some(value) = dict.get(key) {
            return object_to_pickle_value(value.clone());
        } else {
            return Err(PytorchError::KeyNotFound(format!(
                "Key '{}' not found in pickle data",
                key
            )));
        }
    }

    object_to_pickle_value(obj)
}

/// Converts a pickle `Object` into a tensor-free `PickleValue`.
fn object_to_pickle_value(obj: Object) -> Result<PickleValue> {
    Ok(match obj {
        Object::None => PickleValue::None,
        Object::Bool(b) => PickleValue::Bool(b),
        Object::Int(i) => PickleValue::Int(i),
        Object::Float(f) => PickleValue::Float(f),
        Object::String(s) => PickleValue::String(s),
        Object::Persistent(data) => {
            // Raw persistent-id payloads are exposed as bytes.
            PickleValue::Bytes(data)
        }
        Object::PersistentTuple(tuple) => {
            let mut values = Vec::new();
            for item in tuple {
                values.push(object_to_pickle_value(item)?);
            }
            PickleValue::List(values)
        }
        Object::List(list) => {
            let mut values = Vec::new();
            for item in list {
                values.push(object_to_pickle_value(item)?);
            }
            PickleValue::List(values)
        }
        Object::Dict(dict) => {
            let mut map = HashMap::new();
            for (k, v) in dict {
                map.insert(k, object_to_pickle_value(v)?);
            }
            PickleValue::Dict(map)
        }
        Object::Tuple(tuple) => {
            // Tuples are represented as lists.
            let mut values = Vec::new();
            for item in tuple {
                values.push(object_to_pickle_value(item)?);
            }
            PickleValue::List(values)
        }
        Object::TorchParam(_) => {
            // Tensors are not representable as plain pickle values.
            PickleValue::None
        }
        Object::Class { .. } | Object::Build { .. } | Object::Reduce { .. } => {
            // Arbitrary Python objects are not representable either.
            PickleValue::None
        }
    })
}

/// Converts a `PickleValue` tree into burn's `NestedValue` representation.
fn convert_pickle_to_nested_value(value: PickleValue) -> Result<NestedValue> {
    Ok(match value {
        PickleValue::None => NestedValue::Default(None),
        PickleValue::Bool(b) => NestedValue::Bool(b),
        PickleValue::Int(i) => NestedValue::I64(i),
        PickleValue::Float(f) => NestedValue::F64(f),
        PickleValue::String(s) => NestedValue::String(s),
        PickleValue::List(list) => {
            let mut vec = Vec::new();
            for item in list {
                vec.push(convert_pickle_to_nested_value(item)?);
            }
            NestedValue::Vec(vec)
        }
        PickleValue::Dict(dict) => {
            let mut map = HashMap::new();
            for (k, v) in dict {
                map.insert(k, convert_pickle_to_nested_value(v)?);
            }
            NestedValue::Map(map)
        }
        PickleValue::Bytes(data) => {
            // Represent raw bytes as a vector of u8 values.
            let vec: Vec<NestedValue> = data.into_iter().map(NestedValue::U8).collect();
            NestedValue::Vec(vec)
        }
    })
}