use crate::TensorSnapshot;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_core::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer};
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::path::Path;

use super::lazy_data::LazyDataSource;
use super::pickle_reader::{Object, PickleError, read_pickle, read_pickle_with_data};
use std::sync::Arc;

/// Errors that can occur when reading PyTorch files.
#[derive(Debug)]
pub enum PytorchError {
    /// I/O error while reading the file.
    Io(std::io::Error),
    /// Error while parsing the pickle stream.
    Pickle(PickleError),
    /// Error while reading the ZIP archive.
    Zip(zip::result::ZipError),
    /// The file is not in a recognized PyTorch format.
    InvalidFormat(String),
    /// A requested key was not found.
    KeyNotFound(String),
    /// Error during serde deserialization.
    Serde(burn_core::record::serde::error::Error),
}

impl From<std::io::Error> for PytorchError {
    fn from(e: std::io::Error) -> Self {
        PytorchError::Io(e)
    }
}

impl From<PickleError> for PytorchError {
    fn from(e: PickleError) -> Self {
        PytorchError::Pickle(e)
    }
}

impl From<zip::result::ZipError> for PytorchError {
    fn from(e: zip::result::ZipError) -> Self {
        PytorchError::Zip(e)
    }
}

impl From<burn_core::record::serde::error::Error> for PytorchError {
    fn from(e: burn_core::record::serde::error::Error) -> Self {
        PytorchError::Serde(e)
    }
}

impl std::fmt::Display for PytorchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PytorchError::Io(e) => write!(f, "IO error: {}", e),
            PytorchError::Pickle(e) => write!(
                f,
                "Pickle parsing error: {}. This may indicate an unsupported PyTorch file format or corrupted file.",
                e
            ),
            PytorchError::Zip(e) => write!(f, "Zip archive error: {}", e),
            PytorchError::InvalidFormat(msg) => write!(f, "Invalid PyTorch file format: {}", msg),
            PytorchError::KeyNotFound(key) => write!(
                f,
                "Key '{}' not found in PyTorch file. Available keys may be listed with the keys() method.",
                key
            ),
            PytorchError::Serde(e) => write!(f, "Serde deserialization error: {}", e),
        }
    }
}

impl std::error::Error for PytorchError {}

type Result<T> = std::result::Result<T, PytorchError>;

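/// Metadata describing a parsed PyTorch file.
///
/// A minimal sketch of inspecting metadata after loading (the `model.pt`
/// path is illustrative):
///
/// ```ignore
/// let reader = PytorchReader::new("model.pt")?;
/// let meta = reader.metadata();
/// if meta.is_modern_format() {
///     println!("ZIP checkpoint with {} tensors", meta.tensor_count);
/// }
/// ```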
#[derive(Debug, Clone)]
pub struct PytorchMetadata {
    /// Format version string from the `.format_version` archive entry, if present.
    pub format_version: Option<String>,
    /// Container format of the file.
    pub format_type: FileFormat,
    /// Byte order of the stored tensor data.
    pub byte_order: ByteOrder,
    /// Whether the archive contains a `.storage_alignment` entry.
    pub has_storage_alignment: bool,
    /// PyTorch version string from the `version` archive entry, if present.
    pub pytorch_version: Option<String>,
    /// Number of tensors found in the file.
    pub tensor_count: usize,
    /// Total size of the tensor data in bytes, if known.
    pub total_data_size: Option<usize>,
}

impl PytorchMetadata {
    /// Returns `true` if the file uses the modern ZIP-based format (PyTorch 1.6+).
    pub fn is_modern_format(&self) -> bool {
        matches!(self.format_type, FileFormat::Zip)
    }

    /// Returns `true` if the file uses the legacy (pre-1.6) format.
    pub fn is_legacy_format(&self) -> bool {
        matches!(self.format_type, FileFormat::Legacy)
    }
}

/// Container format of a PyTorch file.
#[derive(Debug, Clone, PartialEq)]
pub enum FileFormat {
    /// ZIP-based format used by PyTorch 1.6+.
    Zip,
    /// Legacy (pre-1.6) format with a pickled magic number header.
    Legacy,
    /// Bare pickle stream without a container.
    Pickle,
}

/// Byte order of the stored tensor data.
#[derive(Debug, Clone, PartialEq)]
pub enum ByteOrder {
    LittleEndian,
    BigEndian,
}

/// Reader for PyTorch checkpoint files (`.pt` / `.pth`).
pub struct PytorchReader {
    tensors: HashMap<String, TensorSnapshot>,
    metadata: PytorchMetadata,
}

impl PytorchReader {
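    /// Loads all tensors from a PyTorch file, detecting the container format
    /// (ZIP, legacy, or bare pickle) automatically.
    ///
    /// A minimal usage sketch (the `model.pt` path is illustrative):
    ///
    /// ```ignore
    /// let reader = PytorchReader::new("model.pt")?;
    /// for name in reader.keys() {
    ///     println!("{name}");
    /// }
    /// ```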
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), None)?;
        Ok(Self { tensors, metadata })
    }

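    /// Loads tensors nested under a top-level dictionary key, for
    /// checkpoints that wrap their weights, e.g. `{"state_dict": {...}}`.
    ///
    /// A minimal sketch (path and key are illustrative):
    ///
    /// ```ignore
    /// let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")?;
    /// ```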
    pub fn with_top_level_key<P: AsRef<Path>>(path: P, key: &str) -> Result<Self> {
        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), Some(key))?;
        Ok(Self { tensors, metadata })
    }

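    /// Builds a reader from any `Read` source containing a bare pickle
    /// stream. Since there is no backing file, lazy tensor data is
    /// unavailable and the resulting metadata is minimal.
    ///
    /// A minimal sketch, assuming `bytes` holds a pickled state dict:
    ///
    /// ```ignore
    /// let bytes: Vec<u8> = std::fs::read("state_dict.pkl")?;
    /// let reader = PytorchReader::from_reader(bytes.as_slice(), None)?;
    /// println!("{} tensors", reader.len());
    /// ```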
    pub fn from_reader<R: Read>(reader: R, top_level_key: Option<&str>) -> Result<Self> {
        let tensors = load_from_reader(reader, top_level_key)?;
        let metadata = PytorchMetadata {
            format_version: None,
            format_type: FileFormat::Pickle,
            byte_order: ByteOrder::LittleEndian,
            has_storage_alignment: false,
            pytorch_version: None,
            tensor_count: tensors.len(),
            total_data_size: None,
        };
        Ok(Self { tensors, metadata })
    }

    /// Returns the names of all tensors in the file.
    pub fn keys(&self) -> Vec<String> {
        self.tensors.keys().cloned().collect()
    }

    /// Returns the tensor snapshot with the given name, if present.
    pub fn get(&self, name: &str) -> Option<&TensorSnapshot> {
        self.tensors.get(name)
    }

    /// Returns a reference to all tensor snapshots.
    pub fn tensors(&self) -> &HashMap<String, TensorSnapshot> {
        &self.tensors
    }

    /// Consumes the reader and returns the tensor snapshots.
    pub fn into_tensors(self) -> HashMap<String, TensorSnapshot> {
        self.tensors
    }

    /// Returns the metadata gathered while reading the file.
    pub fn metadata(&self) -> &PytorchMetadata {
        &self.metadata
    }

    /// Returns the number of tensors in the file.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` if the file contains no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

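    /// Reads the file's pickle payload as a plain `PickleValue` tree,
    /// without materializing tensor contents.
    ///
    /// A minimal sketch (the path is illustrative):
    ///
    /// ```ignore
    /// let value = PytorchReader::read_pickle_data("model.pt", None)?;
    /// if let PickleValue::Dict(map) = value {
    ///     println!("top-level entries: {}", map.len());
    /// }
    /// ```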
    pub fn read_pickle_data<P: AsRef<Path>>(
        path: P,
        top_level_key: Option<&str>,
    ) -> Result<PickleValue> {
        read_pickle_as_value(path.as_ref(), top_level_key)
    }

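    /// Deserializes non-tensor data (e.g. a hyperparameter dictionary) from
    /// the file into any `DeserializeOwned` type.
    ///
    /// A minimal sketch, assuming the checkpoint stores a config dict under
    /// the hypothetical top-level key `"config"`:
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize, Debug)]
    /// struct ModelConfig {
    ///     hidden_size: usize,
    ///     num_layers: usize,
    /// }
    ///
    /// let config: ModelConfig = PytorchReader::load_config("model.pt", Some("config"))?;
    /// ```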
    pub fn load_config<D, P>(path: P, top_level_key: Option<&str>) -> Result<D>
    where
        D: DeserializeOwned,
        P: AsRef<Path>,
    {
        let pickle_value = Self::read_pickle_data(path, top_level_key)?;

        let nested_value = convert_pickle_to_nested_value(pickle_value)?;

        let deserializer = Deserializer::<DefaultAdapter>::new(nested_value, false);

        let value = D::deserialize(deserializer)?;
        Ok(value)
    }
}

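/// A plain-data representation of a parsed pickle stream.
///
/// A sketch of walking a value tree (the helper below is illustrative, not
/// part of the API):
///
/// ```ignore
/// fn count_strings(value: &PickleValue) -> usize {
///     match value {
///         PickleValue::String(_) => 1,
///         PickleValue::List(items) => items.iter().map(count_strings).sum(),
///         PickleValue::Dict(map) => map.values().map(count_strings).sum(),
///         _ => 0,
///     }
/// }
/// ```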
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    None,
    Bool(bool),
    Int(i64),
    Float(f64),
    String(String),
    List(Vec<PickleValue>),
    Dict(HashMap<String, PickleValue>),
    /// Raw bytes, e.g. from a persistent storage reference.
    Bytes(Vec<u8>),
}

/// Loads a PyTorch file, auto-detecting ZIP, legacy, or bare-pickle format.
fn load_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    // Try the modern ZIP-based format first (PyTorch 1.6+).
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        let mut pickle_data = Vec::new();
        let mut pickle_found = false;

        // Common locations of the pickle payload inside the archive.
        let possible_pickle_paths = ["data.pkl", "archive/data.pkl"];

        for pickle_path in &possible_pickle_paths {
            if archive.by_name(pickle_path).is_ok() {
                let mut pickle_file = archive.by_name(pickle_path)?;
                pickle_file.read_to_end(&mut pickle_data)?;
                pickle_found = true;
                break;
            }
        }

        // Fall back to scanning for any entry that ends in data.pkl.
        if !pickle_found {
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                // Release the borrow on the archive before re-indexing.
                drop(file);

                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    pickle_found = true;
                    break;
                }
            }
        }

        if !pickle_found {
            return Err(PytorchError::InvalidFormat(
                "No data.pkl file found in ZIP archive. Expected PyTorch 1.6+ format with data.pkl or archive/data.pkl".to_string(),
            ));
        }

        // Optional metadata entries written by newer versions of PyTorch.
        let format_version = if let Ok(mut version_file) = archive.by_name(".format_version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            let version_str = String::from_utf8_lossy(&version_data);
            Some(version_str.trim().to_string())
        } else {
            None
        };

        let is_big_endian = if let Ok(mut byteorder_file) = archive.by_name("byteorder") {
            let mut byteorder_data = Vec::new();
            byteorder_file.read_to_end(&mut byteorder_data)?;
            let byteorder_str = String::from_utf8_lossy(&byteorder_data);
            byteorder_str.trim() == "big"
        } else {
            // A missing byteorder entry implies little-endian.
            false
        };

        if is_big_endian {
            return Err(PytorchError::InvalidFormat(
                "Big-endian PyTorch files are not yet supported. The file was saved on a big-endian system and requires byte order conversion.".to_string()
            ));
        }

        let has_storage_alignment = archive.by_name(".storage_alignment").is_ok();

        let pytorch_version = if let Ok(mut version_file) = archive.by_name("version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            Some(String::from_utf8_lossy(&version_data).trim().to_string())
        } else {
            None
        };

        // Tensor payloads are read lazily from the archive on demand.
        let data_source = Arc::new(LazyDataSource::from_zip(path)?);

        // Sum the sizes of all tensor-data entries for the metadata.
        let mut total_data_size = 0usize;
        for i in 0..archive.len() {
            let file = archive.by_index(i)?;
            let name = file.name();

            let is_data_file = (name.contains("/data/")
                || name.starts_with("data/")
                || name.starts_with("archive/data/"))
                && !name.ends_with(".pkl")
                && !name.ends_with('/');

            if is_data_file {
                total_data_size += file.size() as usize;
            }
        }

        let mut pickle_reader = BufReader::new(pickle_data.as_slice());
        let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;

        let tensors = extract_tensors_with_data(obj, top_level_key)?;

        let metadata = PytorchMetadata {
            format_version,
            format_type: FileFormat::Zip,
            byte_order: if is_big_endian {
                ByteOrder::BigEndian
            } else {
                ByteOrder::LittleEndian
            },
            has_storage_alignment,
            pytorch_version,
            tensor_count: tensors.len(),
            total_data_size: Some(total_data_size),
        };

        return Ok((tensors, metadata));
    }

    // Not a ZIP archive: check for the legacy (pre-1.6) format, which begins
    // with a pickled magic number. The bytes below are the pickle encoding of
    // PyTorch's legacy magic number 0x1950a86a20f9469cfc6c: PROTO 2
    // (0x80 0x02), LONG1 with 10 little-endian bytes (0x8a 0x0a ...), then
    // STOP (0x2e).
    let mut file = File::open(path)?;

    let mut header = [0u8; 15];
    let bytes_read = file.read(&mut header)?;
    file.seek(std::io::SeekFrom::Start(0))?;

    let is_legacy_format = bytes_read >= 15
        && header
            == [
                0x80, 0x02, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50,
                0x19, 0x2e,
            ];

    if is_legacy_format {
        return load_legacy_pytorch_file_with_metadata(path, top_level_key);
    }

    // Otherwise, treat the file as a bare pickle stream.
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    match read_pickle(&mut reader) {
        Ok(obj) => {
            let tensors = extract_tensors_with_data(obj, top_level_key)?;
            let tensor_count = tensors.len();
            Ok((
                tensors,
                PytorchMetadata {
                    format_version: None,
                    format_type: FileFormat::Pickle,
                    byte_order: ByteOrder::LittleEndian,
                    has_storage_alignment: false,
                    pytorch_version: None,
                    tensor_count,
                    total_data_size: None,
                },
            ))
        }
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            Err(PytorchError::InvalidFormat(
                "Pickle file contains tensor data but no data source is available. This file should be loaded as ZIP or legacy format.".to_string()
            ))
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}

/// Loads tensors from a bare pickle stream.
fn load_from_reader<R: Read>(
    reader: R,
    top_level_key: Option<&str>,
) -> Result<HashMap<String, TensorSnapshot>> {
    let mut buf_reader = BufReader::new(reader);

    match read_pickle(&mut buf_reader) {
        Ok(obj) => extract_tensors_with_data(obj, top_level_key),
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            Err(PytorchError::InvalidFormat(
                "Reader contains tensor data but no data source is available. Use file-based loading instead.".to_string()
            ))
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}

/// Extracts named tensors from a parsed pickle object.
fn extract_tensors_with_data(
    obj: Object,
    top_level_key: Option<&str>,
) -> Result<HashMap<String, TensorSnapshot>> {
    let dict = match obj {
        Object::Dict(dict) => {
            if let Some(key) = top_level_key {
                match dict.get(key) {
                    Some(Object::Dict(nested)) => nested.clone(),
                    _ => {
                        return Err(PytorchError::KeyNotFound(format!(
                            "Top-level key '{}' not found or is not a dictionary. Available top-level keys in file: {:?}",
                            key,
                            dict.keys().collect::<Vec<_>>()
                        )));
                    }
                }
            } else {
                dict
            }
        }
        _ => {
            return Err(PytorchError::InvalidFormat(
                "Expected a dictionary at the root of the PyTorch file, but found a different type. The file may be a full model save rather than a state_dict.".to_string(),
            ));
        }
    };

    let mut tensors = HashMap::new();
    let mut path = Vec::new();
    extract_tensors_recursive(&Object::Dict(dict), &mut path, &mut tensors);
    Ok(tensors)
}

/// Recursively walks nested dictionaries, collecting tensor snapshots.
fn extract_tensors_recursive<'a>(
    obj: &'a Object,
    path: &mut Vec<&'a str>,
    tensors: &mut HashMap<String, TensorSnapshot>,
) {
    match obj {
        Object::Dict(dict) => {
            for (key, value) in dict {
                path.push(key);
                extract_tensors_recursive(value, path, tensors);
                path.pop();
            }
        }
        Object::TorchParam(snapshot) => {
            // Join the path segments into a dotted tensor name.
            tensors.insert(path.join("."), snapshot.clone());
        }
        _ => {}
    }
}

/// Loads a legacy (pre-1.6) PyTorch file.
fn load_legacy_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    // The legacy format starts with three small pickles: the magic number,
    // the protocol version, and system info. Read and discard each.
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read magic number from legacy format: {}",
            e
        ))
    })?;

    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read protocol version from legacy format: {}",
            e
        ))
    })?;

    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read system info from legacy format: {}",
            e
        ))
    })?;

    let main_pickle_pos = reader.stream_position()?;

    // Skip past the main object to reach the storage keys that follow it.
    use crate::pytorch::pickle_reader::skip_pickle;
    skip_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to skip main object in legacy format: {}",
            e
        ))
    })?;

    let storage_keys = match read_pickle(&mut reader) {
        Ok(Object::List(keys)) => keys
            .into_iter()
            .filter_map(|obj| match obj {
                Object::String(s) => Some(s),
                _ => None,
            })
            .collect::<Vec<_>>(),
        _ => vec![],
    };

    // Everything after the storage keys is raw tensor data.
    let data_start_pos = reader.stream_position()?;
    let file_size = reader.seek(SeekFrom::End(0))?;
    let data_size = file_size - data_start_pos;

    let data_source = Arc::new(LazyDataSource::from_legacy_multi_storage(
        path,
        data_start_pos,
        data_size,
    ));

    // Re-read the main object, now with the data source attached.
    reader.seek(SeekFrom::Start(main_pickle_pos))?;
    let main_obj = read_pickle_with_data(&mut reader, data_source.clone())?;

    // Hand the storage keys to the data source so it can locate each
    // storage's byte range.
    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source
        && !storage_keys.is_empty()
    {
        let source = source
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        source.set_storage_keys(storage_keys.clone());
    }

    let tensors = extract_tensors_with_data(main_obj, top_level_key)?;

    let metadata = PytorchMetadata {
        format_version: None,
        format_type: FileFormat::Legacy,
        byte_order: ByteOrder::LittleEndian,
        has_storage_alignment: false,
        pytorch_version: None,
        tensor_count: tensors.len(),
        total_data_size: Some(data_size as usize),
    };

    Ok((tensors, metadata))
}

/// Reads a PyTorch file's pickle payload as a `PickleValue` tree.
fn read_pickle_as_value(path: &Path, top_level_key: Option<&str>) -> Result<PickleValue> {
    // Try the ZIP format first, mirroring load_pytorch_file_with_metadata.
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        let mut pickle_data = Vec::new();

        for pickle_path in &["data.pkl", "archive/data.pkl"] {
            if let Ok(mut pickle_file) = archive.by_name(pickle_path) {
                pickle_file.read_to_end(&mut pickle_data)?;
                break;
            }
        }

        if pickle_data.is_empty() {
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                drop(file);

                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    break;
                }
            }
        }

        if !pickle_data.is_empty() {
            let data_source = Arc::new(LazyDataSource::from_zip(path)?);

            let mut reader = BufReader::new(pickle_data.as_slice());
            let obj = read_pickle_with_data(&mut reader, data_source)?;
            return convert_object_to_value(obj, top_level_key);
        }
    }

    // Fall back to reading the file as a bare pickle stream.
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    match read_pickle(&mut reader) {
        Ok(obj) => convert_object_to_value(obj, top_level_key),
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            // The pickle contains tensors: load them through the full reader
            // and represent each one as a placeholder string.
            let reader = PytorchReader::new(path)?;

            let mut result = HashMap::new();
            for key in reader.keys() {
                result.insert(key.clone(), PickleValue::String(format!("<Tensor:{}>", key)));
            }

            if let Some(key) = top_level_key {
                Ok(PickleValue::Dict(
                    [(key.to_string(), PickleValue::Dict(result))]
                        .into_iter()
                        .collect(),
                ))
            } else {
                Ok(PickleValue::Dict(result))
            }
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}

/// Applies an optional top-level key, then converts to a `PickleValue`.
fn convert_object_to_value(obj: Object, top_level_key: Option<&str>) -> Result<PickleValue> {
    if let Some(key) = top_level_key
        && let Object::Dict(dict) = &obj
    {
        if let Some(value) = dict.get(key) {
            return object_to_pickle_value(value.clone());
        } else {
            return Err(PytorchError::KeyNotFound(format!(
                "Key '{}' not found in pickle data",
                key
            )));
        }
    }

    object_to_pickle_value(obj)
}

/// Converts a parsed pickle `Object` into a `PickleValue`.
fn object_to_pickle_value(obj: Object) -> Result<PickleValue> {
    Ok(match obj {
        Object::None => PickleValue::None,
        Object::Bool(b) => PickleValue::Bool(b),
        Object::Int(i) => PickleValue::Int(i),
        Object::Float(f) => PickleValue::Float(f),
        Object::String(s) => PickleValue::String(s),
        // Persistent storage references carry raw bytes.
        Object::Persistent(data) => PickleValue::Bytes(data),
        Object::PersistentTuple(tuple) => {
            let mut values = Vec::new();
            for item in tuple {
                values.push(object_to_pickle_value(item)?);
            }
            PickleValue::List(values)
        }
        Object::List(list) => {
            let mut values = Vec::new();
            for item in list {
                values.push(object_to_pickle_value(item)?);
            }
            PickleValue::List(values)
        }
        Object::Dict(dict) => {
            let mut map = HashMap::new();
            for (k, v) in dict {
                map.insert(k, object_to_pickle_value(v)?);
            }
            PickleValue::Dict(map)
        }
        // Tuples are flattened into lists.
        Object::Tuple(tuple) => {
            let mut values = Vec::new();
            for item in tuple {
                values.push(object_to_pickle_value(item)?);
            }
            PickleValue::List(values)
        }
        // Tensors and arbitrary Python objects have no plain-data form.
        Object::TorchParam(_) => PickleValue::None,
        Object::Class { .. } | Object::Build { .. } | Object::Reduce { .. } => PickleValue::None,
    })
}

/// Converts a `PickleValue` into a `NestedValue` for serde deserialization.
fn convert_pickle_to_nested_value(value: PickleValue) -> Result<NestedValue> {
    Ok(match value {
        PickleValue::None => NestedValue::Default(None),
        PickleValue::Bool(b) => NestedValue::Bool(b),
        PickleValue::Int(i) => NestedValue::I64(i),
        PickleValue::Float(f) => NestedValue::F64(f),
        PickleValue::String(s) => NestedValue::String(s),
        PickleValue::List(list) => {
            let mut vec = Vec::new();
            for item in list {
                vec.push(convert_pickle_to_nested_value(item)?);
            }
            NestedValue::Vec(vec)
        }
        PickleValue::Dict(dict) => {
            let mut map = HashMap::new();
            for (k, v) in dict {
                map.insert(k, convert_pickle_to_nested_value(v)?);
            }
            NestedValue::Map(map)
        }
        PickleValue::Bytes(data) => {
            // Represent raw bytes as a vector of u8 values.
            let vec: Vec<NestedValue> = data.into_iter().map(NestedValue::U8).collect();
            NestedValue::Vec(vec)
        }
    })
}