Skip to main content

burn_store/pytorch/
reader.rs

//! PyTorch file reader implementation.
//!
//! This module provides support for reading PyTorch checkpoint files (.pt/.pth).
//!
//! # Supported Formats
//!
//! ## 1. Modern ZIP Format (PyTorch 1.6+)
//! Files are ZIP archives containing:
//! - `data.pkl` or `archive/data.pkl`: Pickled tensor metadata
//! - `data/` directory: Binary tensor data files
//!
//! ## 2. TAR Format (older torchvision models like AlexNet, SqueezeNet)
//! TAR archives containing:
//! - `sys_info`: System info pickle (endianness, type sizes)
//! - `pickle`: OrderedDict mapping tensor names to storage keys
//! - `tensors`: Tensor metadata (unused, metadata is in pickle)
//! - `storages`: Count pickle + sequential (metadata, num_elements, raw data)
//!
//! ## 3. Legacy Pickle Format (PyTorch 0.1.10 - 1.5)
//! Sequential pickle streams with the structure:
//! - Magic number pickle (0x1950a86a20f9469cfc6c)
//! - Protocol version pickle (e.g., 1001)
//! - System info pickle (endianness, type sizes)
//! - Model data pickle (state_dict or full model)
//!
//! ## 4. Simple Pickle Format
//! Direct pickle file with a dictionary at the root, commonly used for
//! manually saved state_dicts.
//!
//! # Compatibility
//!
//! The reader handles backward compatibility by detecting the file format
//! automatically. Files from PyTorch 0.1.10 through current versions are
//! supported, though full model saves (vs state_dict) may have limitations
//! as they contain Python code references.

37use crate::TensorSnapshot;
38use alloc::string::{String, ToString};
39use alloc::vec::Vec;
40use burn_core::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer};
41use serde::de::DeserializeOwned;
42use std::collections::HashMap;
43use std::fs::File;
44use std::io::{BufReader, Read, Seek, SeekFrom};
45use std::path::Path;
46
47use super::lazy_data::LazyDataSource;
48use super::pickle_reader::{Object, PickleError, read_pickle, read_pickle_with_data};
49use std::sync::Arc;
50
/// Error type for PyTorch file operations
///
/// Covers every failure mode of the reader: I/O, pickle parsing, archive
/// handling (ZIP and TAR), format detection, key lookup, and serde decoding.
#[derive(Debug)]
pub enum PytorchError {
    /// Underlying filesystem or stream I/O error
    Io(std::io::Error),
    /// Error while parsing a pickle stream
    Pickle(PickleError),
    /// Error from the ZIP archive layer (PyTorch 1.6+ format)
    Zip(zip::result::ZipError),
    /// Error from the TAR archive layer (older torchvision checkpoints).
    /// Carries a `std::io::Error` like `Io`, but kept as a distinct variant
    /// so the Display message identifies the archive layer.
    Tar(std::io::Error),
    /// The file does not match any supported PyTorch format
    InvalidFormat(String),
    /// A requested tensor name or top-level key is missing
    KeyNotFound(String),
    /// Deserialization error from burn's serde layer
    Serde(burn_core::record::serde::error::Error),
}
69
70impl From<std::io::Error> for PytorchError {
71    fn from(e: std::io::Error) -> Self {
72        PytorchError::Io(e)
73    }
74}
75
76impl From<PickleError> for PytorchError {
77    fn from(e: PickleError) -> Self {
78        PytorchError::Pickle(e)
79    }
80}
81
82impl From<zip::result::ZipError> for PytorchError {
83    fn from(e: zip::result::ZipError) -> Self {
84        PytorchError::Zip(e)
85    }
86}
87
88impl From<burn_core::record::serde::error::Error> for PytorchError {
89    fn from(e: burn_core::record::serde::error::Error) -> Self {
90        PytorchError::Serde(e)
91    }
92}
93
94impl std::fmt::Display for PytorchError {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        match self {
97            PytorchError::Io(e) => write!(f, "IO error: {}", e),
98            PytorchError::Pickle(e) => write!(
99                f,
100                "Pickle parsing error: {}. This may indicate an unsupported PyTorch file format or corrupted file.",
101                e
102            ),
103            PytorchError::Zip(e) => write!(f, "Zip archive error: {}", e),
104            PytorchError::Tar(e) => write!(f, "TAR archive error: {}", e),
105            PytorchError::InvalidFormat(msg) => write!(f, "Invalid PyTorch file format: {}", msg),
106            PytorchError::KeyNotFound(key) => write!(
107                f,
108                "Key '{}' not found in PyTorch file. Available keys may be listed with the keys() method.",
109                key
110            ),
111            PytorchError::Serde(e) => write!(f, "Serde deserialization error: {}", e),
112        }
113    }
114}
115
116impl std::error::Error for PytorchError {}
117
118type Result<T> = std::result::Result<T, PytorchError>;
119
/// Metadata about a PyTorch file
///
/// Contains information about the file format, version, and other properties
/// that can be useful for debugging or compatibility checking. Populated by
/// the format-detection logic at load time.
#[derive(Debug, Clone)]
pub struct PytorchMetadata {
    /// Format version (e.g., "1.0" for modern ZIP format); read from the
    /// archive's `.format_version` entry when present
    pub format_version: Option<String>,
    /// File format type (ZIP, TAR, Legacy, or Pickle)
    pub format_type: FileFormat,
    /// Byte order (endianness) - currently only LittleEndian is supported;
    /// big-endian files are rejected at load time
    pub byte_order: ByteOrder,
    /// Whether the archive contains a `.storage_alignment` entry
    pub has_storage_alignment: bool,
    /// PyTorch version that saved the file (if available)
    pub pytorch_version: Option<String>,
    /// Number of tensors in the file
    pub tensor_count: usize,
    /// Total size of tensor data in bytes (if available)
    pub total_data_size: Option<usize>,
}
141
142impl PytorchMetadata {
143    /// Check if this is a modern format file (ZIP-based, PyTorch 1.6+)
144    pub fn is_modern_format(&self) -> bool {
145        matches!(self.format_type, FileFormat::Zip)
146    }
147
148    /// Check if this is a legacy format file (PyTorch 0.1.10 - 1.5)
149    pub fn is_legacy_format(&self) -> bool {
150        matches!(self.format_type, FileFormat::Legacy)
151    }
152}
153
/// File format type detected while loading a checkpoint
#[derive(Debug, Clone, PartialEq)]
pub enum FileFormat {
    /// ZIP-based format (PyTorch 1.6+)
    Zip,
    /// TAR-based format (older torchvision models)
    Tar,
    /// Legacy sequential-pickle format (PyTorch 0.1.10 - 1.5)
    Legacy,
    /// Simple pickle file (manually saved state_dict)
    Pickle,
}
166
/// Byte order (endianness) of tensor data in the file
#[derive(Debug, Clone, PartialEq)]
pub enum ByteOrder {
    /// Least-significant byte first — the only order the reader can load
    LittleEndian,
    /// Most-significant byte first — detected but rejected at load time
    BigEndian,
}
173
/// PyTorch checkpoint reader
///
/// This is the main interface for reading PyTorch checkpoint files (.pt/.pth).
/// It supports multiple PyTorch formats including modern ZIP-based format (1.6+),
/// legacy format (0.1.10-1.5), and simple pickle files.
///
/// # Example
/// ```rust,no_run
/// # use burn_store::pytorch::PytorchReader;
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // Load a checkpoint file
/// let reader = PytorchReader::new("model.pt")?;
///
/// // Get tensor names
/// let keys = reader.keys();
///
/// // Access a specific tensor
/// if let Some(tensor) = reader.get("conv1.weight") {
///     let data = tensor.to_data(); // Materializes the tensor
/// }
///
/// // Check file metadata
/// println!("Format: {:?}", reader.metadata().format_type);
/// println!("Tensor count: {}", reader.metadata().tensor_count);
/// # Ok(())
/// # }
/// ```
pub struct PytorchReader {
    // Lazily-loaded tensor snapshots keyed by dotted parameter path
    // (e.g. "conv1.weight"); data is materialized only on access.
    tensors: HashMap<String, TensorSnapshot>,
    // Format/version information detected while loading the file.
    metadata: PytorchMetadata,
}
205
206impl PytorchReader {
207    /// Load a PyTorch checkpoint file
208    ///
209    /// # Arguments
210    /// * `path` - Path to the PyTorch file (.pt or .pth)
211    ///
212    /// # Returns
213    /// A `PytorchReader` with lazy-loaded tensors and metadata
214    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
215        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), None)?;
216        Ok(Self { tensors, metadata })
217    }
218
219    /// Load a PyTorch checkpoint with a specific top-level key
220    ///
221    /// Many PyTorch checkpoints store the model weights under a specific key
222    /// like "state_dict", "model", or "model_state_dict".
223    ///
224    /// # Arguments
225    /// * `path` - Path to the PyTorch file
226    /// * `key` - Top-level key to extract (e.g., "state_dict")
227    ///
228    /// # Example
229    /// ```rust,no_run
230    /// # use burn_store::pytorch::PytorchReader;
231    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
232    /// let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")?;
233    /// # Ok(())
234    /// # }
235    /// ```
236    pub fn with_top_level_key<P: AsRef<Path>>(path: P, key: &str) -> Result<Self> {
237        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), Some(key))?;
238        Ok(Self { tensors, metadata })
239    }
240
241    /// Load from a reader
242    ///
243    /// This method is useful when loading from non-file sources like memory buffers.
244    /// Note: Metadata detection is limited when loading from a reader.
245    ///
246    /// # Arguments
247    /// * `reader` - Any type implementing `Read`
248    /// * `top_level_key` - Optional key to extract
249    pub fn from_reader<R: Read>(reader: R, top_level_key: Option<&str>) -> Result<Self> {
250        // For reader-based loading, we don't have full metadata access
251        let tensors = load_from_reader(reader, top_level_key)?;
252        let metadata = PytorchMetadata {
253            format_version: None,
254            format_type: FileFormat::Pickle, // Default assumption
255            byte_order: ByteOrder::LittleEndian,
256            has_storage_alignment: false,
257            pytorch_version: None,
258            tensor_count: tensors.len(),
259            total_data_size: None,
260        };
261        Ok(Self { tensors, metadata })
262    }
263
264    /// Get all tensor names
265    pub fn keys(&self) -> Vec<String> {
266        self.tensors.keys().cloned().collect()
267    }
268
269    /// Get a tensor by name
270    pub fn get(&self, name: &str) -> Option<&TensorSnapshot> {
271        self.tensors.get(name)
272    }
273
274    /// Get all tensors
275    pub fn tensors(&self) -> &HashMap<String, TensorSnapshot> {
276        &self.tensors
277    }
278
279    /// Take ownership of all tensors
280    pub fn into_tensors(self) -> HashMap<String, TensorSnapshot> {
281        self.tensors
282    }
283
284    /// Get metadata about the loaded file
285    ///
286    /// Provides information about the file format, version, endianness, etc.
287    pub fn metadata(&self) -> &PytorchMetadata {
288        &self.metadata
289    }
290
291    /// Get the number of tensors in the file
292    pub fn len(&self) -> usize {
293        self.tensors.len()
294    }
295
296    /// Check if the file contains no tensors
297    pub fn is_empty(&self) -> bool {
298        self.tensors.is_empty()
299    }
300
301    /// Read raw pickle data from a PyTorch file
302    ///
303    /// This is useful for extracting configuration or metadata that isn't tensor data.
304    /// Returns a simplified JSON-like structure that can be easily converted to other formats.
305    ///
306    /// # Arguments
307    /// * `path` - Path to the PyTorch file
308    /// * `top_level_key` - Optional key to extract from the top-level dictionary
309    ///
310    /// # Returns
311    /// A `PickleValue` representing the pickle data structure
312    pub fn read_pickle_data<P: AsRef<Path>>(
313        path: P,
314        top_level_key: Option<&str>,
315    ) -> Result<PickleValue> {
316        read_pickle_as_value(path.as_ref(), top_level_key)
317    }
318
319    /// Load and deserialize configuration data from a PyTorch file
320    ///
321    /// This method reads configuration or metadata stored in PyTorch checkpoint files
322    /// and deserializes it into the specified type. It's particularly useful for
323    /// extracting model configurations that might be saved alongside model weights.
324    ///
325    /// # Arguments
326    /// * `path` - Path to the PyTorch file (.pt or .pth)
327    /// * `top_level_key` - Optional key to extract specific data within the pickle file.
328    ///   If `None`, the entire content is deserialized.
329    ///
330    /// # Type Parameters
331    /// * `D` - The target type to deserialize into. Must implement `DeserializeOwned`.
332    ///
333    /// # Returns
334    /// A `Result` containing the deserialized configuration data, or an `Error` if
335    /// reading or deserialization fails.
336    ///
337    /// # Example
338    /// ```rust,no_run
339    /// # use burn_store::pytorch::PytorchReader;
340    /// # use serde::Deserialize;
341    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
342    /// #[derive(Debug, Deserialize)]
343    /// struct ModelConfig {
344    ///     hidden_size: usize,
345    ///     num_layers: usize,
346    /// }
347    ///
348    /// let config: ModelConfig = PytorchReader::load_config("model.pth", Some("config"))?;
349    /// # Ok(())
350    /// # }
351    /// ```
352    pub fn load_config<D, P>(path: P, top_level_key: Option<&str>) -> Result<D>
353    where
354        D: DeserializeOwned,
355        P: AsRef<Path>,
356    {
357        // Read the PyTorch file and extract the pickle data
358        let pickle_value = Self::read_pickle_data(path, top_level_key)?;
359
360        // Convert PickleValue to NestedValue
361        let nested_value = convert_pickle_to_nested_value(pickle_value)?;
362
363        // Create a deserializer with the default adapter
364        let deserializer = Deserializer::<DefaultAdapter>::new(nested_value, false);
365
366        // Deserialize the nested value into the target type
367        let value = D::deserialize(deserializer)?;
368        Ok(value)
369    }
370}
371
/// Simplified representation of pickle data
///
/// This enum provides a JSON-like structure that's easier to work with
/// than the internal pickle `Object` type. Returned by
/// `PytorchReader::read_pickle_data`.
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    /// None/null value
    None,
    /// Boolean value
    Bool(bool),
    /// Integer value
    Int(i64),
    /// Floating point value
    Float(f64),
    /// String value
    String(String),
    /// List/array of values
    List(Vec<PickleValue>),
    /// Dictionary/map of string keys to values
    Dict(HashMap<String, PickleValue>),
    /// Binary data
    Bytes(Vec<u8>),
}
395
/// Internal function to load a PyTorch file with metadata
///
/// Tries the supported container formats in order:
/// 1. ZIP archive (PyTorch 1.6+) — detected by `zip::ZipArchive::new` succeeding
/// 2. TAR archive (older torchvision models) — detected by the "ustar" magic
/// 3. Legacy sequential-pickle format (PyTorch 0.1.10 - 1.5) — detected by the
///    magic-number pickle at the start of the file
/// 4. Plain pickle file — everything else
///
/// `top_level_key`, when given, selects a nested dictionary (e.g. "state_dict")
/// from the root object before tensor extraction.
fn load_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    // First, try to read as a zip file
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        // PyTorch saves the main data in various locations within the zip
        let mut pickle_data = Vec::new();
        let mut pickle_found = false;

        // Try different common pickle file locations
        let possible_pickle_paths = [
            "data.pkl",
            "archive/data.pkl",
            // Look for any .pkl file in the root or first-level directories
        ];

        for pickle_path in &possible_pickle_paths {
            // by_name is called twice deliberately: the first call only probes
            // for existence (its Ok value would hold a borrow of `archive`),
            // the second re-opens the entry for reading.
            if archive.by_name(pickle_path).is_ok() {
                let mut pickle_file = archive.by_name(pickle_path)?;
                pickle_file.read_to_end(&mut pickle_data)?;
                pickle_found = true;
                break;
            }
        }

        // If not found in standard locations, search for any .pkl file
        // (note: only entries whose names end in "data.pkl" are accepted,
        // which covers nested layouts like "model/data.pkl")
        if !pickle_found {
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                drop(file); // Release the borrow

                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    pickle_found = true;
                    break;
                }
            }
        }

        if !pickle_found {
            return Err(PytorchError::InvalidFormat(
                "No data.pkl file found in ZIP archive. Expected PyTorch 1.6+ format with data.pkl or archive/data.pkl".to_string(),
            ));
        }

        // Check for format version (optional)
        let format_version = if let Ok(mut version_file) = archive.by_name(".format_version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            let version_str = String::from_utf8_lossy(&version_data);
            let version = version_str.trim().to_string();
            Some(version)
        } else {
            None
        };

        // Check for byteorder file to detect endianness
        let is_big_endian = if let Ok(mut byteorder_file) = archive.by_name("byteorder") {
            let mut byteorder_data = Vec::new();
            byteorder_file.read_to_end(&mut byteorder_data)?;
            let byteorder_str = String::from_utf8_lossy(&byteorder_data);
            byteorder_str.trim() == "big"
        } else {
            false // Default to little-endian if no byteorder file
        };

        if is_big_endian {
            // Big-endian files are not yet supported as they require different byte order conversion
            // TODO: To support big-endian files, we need to:
            // 1. Pass endianness info through to pickle_reader
            // 2. Use from_be_bytes instead of from_le_bytes for tensor data
            // 3. Handle byte swapping for all numeric types (f32, f64, i32, etc.)
            return Err(PytorchError::InvalidFormat(
                "Big-endian PyTorch files are not yet supported. The file was saved on a big-endian system and requires byte order conversion.".to_string()
            ));
        }

        // Check for storage alignment file
        let has_storage_alignment = archive.by_name(".storage_alignment").is_ok();

        // Check for PyTorch version (if saved)
        let pytorch_version = if let Ok(mut version_file) = archive.by_name("version") {
            let mut version_data = Vec::new();
            version_file.read_to_end(&mut version_data)?;
            Some(String::from_utf8_lossy(&version_data).trim().to_string())
        } else {
            None
        };

        // Create a lazy data source instead of loading all data upfront
        let data_source = Arc::new(LazyDataSource::from_zip(path)?);

        // Calculate total data size without loading
        let mut total_data_size = 0usize;
        for i in 0..archive.len() {
            let file = archive.by_index(i)?;
            let name = file.name();

            // Look for data files - they can be in various locations
            let is_data_file = (name.contains("/data/")
                || name.starts_with("data/")
                || name.starts_with("archive/data/"))
                && !name.ends_with(".pkl")
                && !name.ends_with("/");

            if is_data_file {
                total_data_size += file.size() as usize;
            }
        }

        // Parse the pickle data with lazy data source
        let mut pickle_reader = BufReader::new(pickle_data.as_slice());
        let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;

        // Extract tensors with their data
        let tensors = extract_tensors_with_data(obj, top_level_key)?;

        // Create metadata
        let metadata = PytorchMetadata {
            format_version,
            format_type: FileFormat::Zip,
            // NOTE(review): is_big_endian is always false at this point because
            // of the early return above, so this branch always yields
            // LittleEndian; kept for when big-endian support lands.
            byte_order: if is_big_endian {
                ByteOrder::BigEndian
            } else {
                ByteOrder::LittleEndian
            },
            has_storage_alignment,
            pytorch_version,
            tensor_count: tensors.len(),
            total_data_size: Some(total_data_size),
        };

        return Ok((tensors, metadata));
    }

    // If not a zip or zip reading failed, try TAR format
    if is_tar_file(path) {
        return load_tar_pytorch_file_with_metadata(path, top_level_key);
    }

    // Try reading as a plain pickle file
    let mut file = File::open(path)?;

    // Check for PyTorch legacy format (starts with magic number as pickled integer)
    let mut header = [0u8; 15];
    // Use read() instead of read_exact() to handle files smaller than 15 bytes
    let bytes_read = file.read(&mut header)?;
    file.seek(std::io::SeekFrom::Start(0))?;

    // Only check for legacy format if we have enough bytes
    // PyTorch legacy format detection (PyTorch 0.1.10 - 1.3)
    // Reference: https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L65
    //
    // These files use sequential pickle streams with metadata before the actual data.
    // Format structure:
    //   1. Magic number (0x1950a86a20f9469cfc6c) stored as LONG1 pickle
    //   2. Protocol version (e.g., 1001)
    //   3. System info dict (protocol_version, little_endian, type_sizes)
    //   4. Actual model data (state_dict or full model)
    //   5. Storage keys list (pickle)
    //   6. Raw binary data for each storage
    //
    // The pattern is: 0x80 0x02 0x8a 0x0a (PROTO 2, LONG1 with 10 bytes)
    // followed by 10 bytes of magic number (little-endian), then 0x2e (STOP)
    let is_legacy_format = bytes_read >= 15
        && header[0] == 0x80  // PROTO opcode
        && header[1] == 0x02  // Protocol version 2
        && header[2] == 0x8a  // LONG1 opcode
        && header[3] == 0x0a  // 10 bytes follow
        // Magic number 0x1950a86a20f9469cfc6c in little-endian
        && header[4] == 0x6c
        && header[5] == 0xfc
        && header[6] == 0x9c
        && header[7] == 0x46
        && header[8] == 0xf9
        && header[9] == 0x20
        && header[10] == 0x6a
        && header[11] == 0xa8
        && header[12] == 0x50
        && header[13] == 0x19
        && header[14] == 0x2e; // STOP opcode

    if is_legacy_format {
        return load_legacy_pytorch_file_with_metadata(path, top_level_key);
    }

    // Standard pickle file
    // This might be a pickle with tensor references, so we need to handle that case
    // For plain pickle files without a separate data section, we can't use lazy loading
    // so we'll just create empty placeholder tensors for the structure
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    // Try reading without data source first
    match read_pickle(&mut reader) {
        Ok(obj) => {
            let tensors = extract_tensors_with_data(obj, top_level_key)?;
            let tensor_count = tensors.len();
            Ok((
                tensors,
                PytorchMetadata {
                    format_version: None,
                    format_type: FileFormat::Pickle,
                    byte_order: ByteOrder::LittleEndian,
                    has_storage_alignment: false,
                    pytorch_version: None,
                    tensor_count,
                    total_data_size: None,
                },
            ))
        }
        // NOTE(review): error classification by message substring is fragile;
        // it relies on the exact wording produced by pickle_reader — confirm
        // if PickleError ever gains a structured variant for this case.
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            // This pickle file contains tensor data but we're trying to read it without
            // providing a data source. This shouldn't happen in normal usage as PyTorch
            // files with actual tensor data should be in ZIP or legacy format.
            Err(PytorchError::InvalidFormat(
                "Pickle file contains tensor data but no data source is available. This file should be loaded as ZIP or legacy format.".to_string()
            ))
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}
627
628/// Load from a reader
629fn load_from_reader<R: Read>(
630    reader: R,
631    top_level_key: Option<&str>,
632) -> Result<HashMap<String, TensorSnapshot>> {
633    let mut buf_reader = BufReader::new(reader);
634
635    // Try reading without data source
636    match read_pickle(&mut buf_reader) {
637        Ok(obj) => extract_tensors_with_data(obj, top_level_key),
638        Err(e)
639            if e.to_string()
640                .contains("Cannot load tensor data without a data source") =>
641        {
642            // This reader contains tensor data but we can't load it without a file path
643            Err(PytorchError::InvalidFormat(
644                "Reader contains tensor data but no data source is available. Use file-based loading instead.".to_string()
645            ))
646        }
647        Err(e) => Err(PytorchError::Pickle(e)),
648    }
649}
650
651/// Extract tensors from a parsed pickle object
652fn extract_tensors_with_data(
653    obj: Object,
654    top_level_key: Option<&str>,
655) -> Result<HashMap<String, TensorSnapshot>> {
656    let dict = match obj {
657        Object::Dict(dict) => {
658            if let Some(key) = top_level_key {
659                // Extract the nested dictionary if a top-level key is specified
660                match dict.get(key) {
661                    Some(Object::Dict(nested)) => nested.clone(),
662                    _ => {
663                        return Err(PytorchError::KeyNotFound(format!(
664                            "Top-level key '{}' not found or is not a dictionary. Available top-level keys in file: {:?}",
665                            key,
666                            dict.keys().collect::<Vec<_>>()
667                        )));
668                    }
669                }
670            } else {
671                dict
672            }
673        }
674        _ => {
675            return Err(PytorchError::InvalidFormat(
676                "Expected a dictionary at the root of the PyTorch file, but found a different type. The file may be a full model save rather than a state_dict.".to_string(),
677            ));
678        }
679    };
680
681    let mut tensors = HashMap::new();
682    let mut path = Vec::new();
683    extract_tensors_recursive(&Object::Dict(dict), &mut path, &mut tensors);
684    Ok(tensors)
685}
686
687/// Recursively extract tensors from an object
688fn extract_tensors_recursive<'a>(
689    obj: &'a Object,
690    path: &mut Vec<&'a str>,
691    tensors: &mut HashMap<String, TensorSnapshot>,
692) {
693    match obj {
694        Object::Dict(dict) => {
695            for (key, value) in dict {
696                path.push(key);
697                extract_tensors_recursive(value, path, tensors);
698                path.pop();
699            }
700        }
701        Object::TorchParam(snapshot) => {
702            // The TensorSnapshot already contains the data loading closure
703            // Only allocate the string here when we actually insert
704            tensors.insert(path.join("."), snapshot.clone());
705        }
706        _ => {}
707    }
708}
709
/// Load a legacy PyTorch file with metadata
///
/// Legacy files (PyTorch 0.1.10 - 1.5) are a sequence of pickle streams:
/// magic number, protocol version, system info, the model data itself, a list
/// of storage keys, then raw binary storage data. The model pickle is parsed
/// twice: first skipped (to locate the storage-keys list and the data region),
/// then re-read with a lazy data source attached.
fn load_legacy_pytorch_file_with_metadata(
    path: &Path,
    top_level_key: Option<&str>,
) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    // Skip metadata pickles
    // 1. Magic number
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read magic number from legacy format: {}",
            e
        ))
    })?;

    // 2. Protocol version
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read protocol version from legacy format: {}",
            e
        ))
    })?;

    // 3. System info (endianness, type sizes) — currently unused
    let _ = read_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to read system info from legacy format: {}",
            e
        ))
    })?;

    // Save position before main pickle so we can re-read it later
    let main_pickle_pos = reader.stream_position()?;

    // 4. Skip main object - it might contain tensors so we can't parse it yet
    // We'll re-read it with a data source later
    use crate::pytorch::pickle_reader::skip_pickle;
    skip_pickle(&mut reader).map_err(|e| {
        PytorchError::InvalidFormat(format!(
            "Failed to skip main object in legacy format: {}",
            e
        ))
    })?;

    // 5. Storage keys list (sorted keys as written by PyTorch).
    // A missing or malformed list degrades to an empty key set rather than
    // failing the whole load.
    let storage_keys = match read_pickle(&mut reader) {
        Ok(Object::List(keys)) => keys
            .into_iter()
            .filter_map(|obj| match obj {
                Object::String(s) => Some(s),
                _ => None,
            })
            .collect::<Vec<_>>(),
        _ => vec![],
    };

    // 6. Raw binary data starts here; everything to EOF is storage data
    let data_start_pos = reader.stream_position()?;
    let file_size = reader.seek(SeekFrom::End(0))?;
    let data_size = file_size - data_start_pos;

    // Create a lazy data source for legacy multi-storage format
    let data_source = Arc::new(LazyDataSource::from_legacy_multi_storage(
        path,
        data_start_pos,
        data_size,
    ));

    // Set storage keys BEFORE parsing the main pickle
    // This is critical because track_storage_usage() is called during parsing
    // and it needs storage_keys to build the storage map
    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source
        && !storage_keys.is_empty()
    {
        // Recover from a poisoned lock rather than propagating the panic.
        let source = source
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        source.set_storage_keys(storage_keys.clone());
    }

    // Now re-read the main pickle with lazy data source
    reader.seek(SeekFrom::Start(main_pickle_pos))?;
    let main_obj = read_pickle_with_data(&mut reader, data_source.clone())?;

    // Extract tensors normally
    let tensors = extract_tensors_with_data(main_obj, top_level_key)?;

    // Create metadata for legacy format
    let metadata = PytorchMetadata {
        format_version: None, // Legacy format doesn't have version files
        format_type: FileFormat::Legacy,
        byte_order: ByteOrder::LittleEndian, // Legacy format is little-endian
        has_storage_alignment: false,
        pytorch_version: None, // Could parse from protocol version, but not reliable
        tensor_count: tensors.len(),
        total_data_size: Some(data_size as usize),
    };

    Ok((tensors, metadata))
}
812
/// Check if a file is a TAR archive
///
/// Detects the POSIX "ustar" magic at byte offset 257 of the first header
/// block. Unreadable paths and files shorter than 263 bytes are reported as
/// non-TAR (real tar archives are always at least one 512-byte block).
fn is_tar_file(path: &Path) -> bool {
    let mut header = [0u8; 263];
    File::open(path)
        .and_then(|mut file| file.read_exact(&mut header))
        .map(|_| &header[257..262] == b"ustar")
        .unwrap_or(false)
}
825
826/// Load a TAR format PyTorch file with metadata
827fn load_tar_pytorch_file_with_metadata(
828    path: &Path,
829    top_level_key: Option<&str>,
830) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
831    use tar::Archive;
832
833    let file = File::open(path)?;
834    let mut archive = Archive::new(BufReader::new(file));
835
836    // Extract the main entries from the TAR archive
837    let mut sys_info_data: Option<Vec<u8>> = None;
838    let mut pickle_data: Option<Vec<u8>> = None;
839    let mut storages_data: Option<Vec<u8>> = None;
840
841    for entry in archive.entries().map_err(PytorchError::Tar)? {
842        let mut entry = entry.map_err(PytorchError::Tar)?;
843        let entry_path = entry
844            .path()
845            .map_err(PytorchError::Tar)?
846            .to_string_lossy()
847            .to_string();
848
849        // Skip PAX headers
850        if entry_path.contains("@PaxHeader") {
851            continue;
852        }
853
854        // Normalize path (remove ./ prefix if present)
855        let normalized = entry_path.trim_start_matches("./");
856
857        match normalized {
858            "sys_info" => {
859                let mut data = Vec::new();
860                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
861                sys_info_data = Some(data);
862            }
863            "pickle" => {
864                let mut data = Vec::new();
865                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
866                pickle_data = Some(data);
867            }
868            "storages" => {
869                let mut data = Vec::new();
870                entry.read_to_end(&mut data).map_err(PytorchError::Tar)?;
871                storages_data = Some(data);
872            }
873            _ => {}
874        }
875    }
876
877    // Validate required entries
878    let pickle_data = pickle_data.ok_or_else(|| {
879        PytorchError::InvalidFormat("TAR file missing 'pickle' entry".to_string())
880    })?;
881    let storages_data = storages_data.ok_or_else(|| {
882        PytorchError::InvalidFormat("TAR file missing 'storages' entry".to_string())
883    })?;
884
885    // Parse sys_info to check endianness
886    let is_little_endian = if let Some(ref data) = sys_info_data {
887        parse_tar_sys_info(data)?
888    } else {
889        true // Default to little-endian
890    };
891
892    if !is_little_endian {
893        return Err(PytorchError::InvalidFormat(
894            "Big-endian TAR PyTorch files are not supported".to_string(),
895        ));
896    }
897
898    // Create TarSource for lazy loading
899    let data_source = Arc::new(LazyDataSource::from_tar(&storages_data)?);
900
901    // Parse the pickle (OrderedDict of name -> storage_key)
902    let mut pickle_reader = BufReader::new(pickle_data.as_slice());
903    let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;
904
905    // Extract tensors
906    let tensors = extract_tensors_with_data(obj, top_level_key)?;
907
908    let metadata = PytorchMetadata {
909        format_version: None,
910        format_type: FileFormat::Tar,
911        byte_order: ByteOrder::LittleEndian,
912        has_storage_alignment: false,
913        pytorch_version: None,
914        tensor_count: tensors.len(),
915        total_data_size: Some(storages_data.len()),
916    };
917
918    Ok((tensors, metadata))
919}
920
921/// Parse sys_info pickle from TAR format to extract endianness
922fn parse_tar_sys_info(data: &[u8]) -> Result<bool> {
923    let mut reader = BufReader::new(data);
924    let obj = read_pickle(&mut reader)?;
925
926    if let Object::Dict(dict) = obj
927        && let Some(Object::Bool(little_endian)) = dict.get("little_endian")
928    {
929        return Ok(*little_endian);
930    }
931
932    Ok(true) // Default assumption
933}
934
/// Read pickle data from a PyTorch file as a simplified value
///
/// Tries the modern ZIP container first (pickle at `data.pkl`,
/// `archive/data.pkl`, or any entry ending in `data.pkl`), then falls back
/// to treating `path` as a bare pickle stream. When the bare pickle turns
/// out to contain tensors, the full `PytorchReader` is used and each tensor
/// is represented as a `<Tensor:name>` placeholder string, since this API
/// only exposes the structural/config content.
fn read_pickle_as_value(path: &Path, top_level_key: Option<&str>) -> Result<PickleValue> {
    use crate::pytorch::lazy_data::LazyDataSource;
    use crate::pytorch::pickle_reader::{read_pickle, read_pickle_with_data};
    use std::sync::Arc;

    // Try to open as ZIP first
    if let Ok(file) = File::open(path)
        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
    {
        // Read pickle data from ZIP
        let mut pickle_data = Vec::new();

        // Try standard locations
        for pickle_path in &["data.pkl", "archive/data.pkl"] {
            if let Ok(mut pickle_file) = archive.by_name(pickle_path) {
                pickle_file.read_to_end(&mut pickle_data)?;
                break;
            }
        }

        // If not found, search for any .pkl file
        // (only names ending in "data.pkl" qualify, e.g. "model/data.pkl")
        if pickle_data.is_empty() {
            for i in 0..archive.len() {
                let file = archive.by_index(i)?;
                let name = file.name().to_string();
                // Release the mutable borrow on `archive` before re-indexing below.
                drop(file);

                if name.ends_with("data.pkl") {
                    let mut file = archive.by_index(i)?;
                    file.read_to_end(&mut pickle_data)?;
                    break;
                }
            }
        }

        if !pickle_data.is_empty() {
            // Create a data source for the ZIP file
            let data_source = LazyDataSource::from_zip(path)?;
            let data_source_arc = Arc::new(data_source);

            let mut reader = BufReader::new(pickle_data.as_slice());
            let obj = read_pickle_with_data(&mut reader, data_source_arc)?;
            return convert_object_to_value(obj, top_level_key);
        }
    }

    // Try as plain pickle file
    // First attempt without data source (for pure metadata files)
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);

    match read_pickle(&mut reader) {
        Ok(obj) => convert_object_to_value(obj, top_level_key),
        // NOTE(review): fragile dispatch on the error's Display text — must be
        // kept in sync with the pickle_reader error message; confirm there is
        // no typed variant that could be matched instead.
        Err(e)
            if e.to_string()
                .contains("Cannot load tensor data without a data source") =>
        {
            // File contains tensors, need to use full PytorchReader
            // Use the regular reader to get proper tensor handling
            let reader = PytorchReader::new(path)?;

            // Convert tensors to PickleValue structure
            let mut result = std::collections::HashMap::new();
            for key in reader.keys() {
                // For pickle value extraction, we just need the structure, not the actual data
                result.insert(
                    key.clone(),
                    PickleValue::String(format!("<Tensor:{}>", key)),
                );
            }

            // Preserve the requested top-level nesting so callers that asked
            // for a sub-dict still receive the expected shape.
            if let Some(key) = top_level_key {
                Ok(PickleValue::Dict(
                    [(key.to_string(), PickleValue::Dict(result))]
                        .into_iter()
                        .collect(),
                ))
            } else {
                Ok(PickleValue::Dict(result))
            }
        }
        Err(e) => Err(PytorchError::Pickle(e)),
    }
}
1020
1021/// Convert internal Object to public PickleValue
1022fn convert_object_to_value(obj: Object, top_level_key: Option<&str>) -> Result<PickleValue> {
1023    use crate::pytorch::pickle_reader::Object;
1024
1025    // If a top-level key is specified, extract it first
1026    if let Some(key) = top_level_key
1027        && let Object::Dict(dict) = obj
1028    {
1029        if let Some(value) = dict.get(key) {
1030            return object_to_pickle_value(value.clone());
1031        } else {
1032            return Err(PytorchError::KeyNotFound(format!(
1033                "Key '{}' not found in pickle data",
1034                key
1035            )));
1036        }
1037    }
1038
1039    object_to_pickle_value(obj)
1040}
1041
1042/// Convert Object to PickleValue
1043fn object_to_pickle_value(obj: Object) -> Result<PickleValue> {
1044    use crate::pytorch::pickle_reader::Object;
1045
1046    Ok(match obj {
1047        Object::None => PickleValue::None,
1048        Object::Bool(b) => PickleValue::Bool(b),
1049        Object::Int(i) => PickleValue::Int(i),
1050        Object::Float(f) => PickleValue::Float(f),
1051        Object::String(s) => PickleValue::String(s),
1052        Object::Persistent(data) => {
1053            // Persistent data is raw bytes
1054            PickleValue::Bytes(data)
1055        }
1056        Object::PersistentTuple(tuple) => {
1057            // Convert persistent tuples to lists
1058            let mut values = Vec::new();
1059            for item in tuple {
1060                values.push(object_to_pickle_value(item)?);
1061            }
1062            PickleValue::List(values)
1063        }
1064        Object::List(list) => {
1065            let mut values = Vec::new();
1066            for item in list {
1067                values.push(object_to_pickle_value(item)?);
1068            }
1069            PickleValue::List(values)
1070        }
1071        Object::Dict(dict) => {
1072            let mut map = HashMap::new();
1073            for (k, v) in dict {
1074                map.insert(k, object_to_pickle_value(v)?);
1075            }
1076            PickleValue::Dict(map)
1077        }
1078        Object::Tuple(tuple) => {
1079            // Convert tuples to lists in the public API
1080            let mut values = Vec::new();
1081            for item in tuple {
1082                values.push(object_to_pickle_value(item)?);
1083            }
1084            PickleValue::List(values)
1085        }
1086        Object::TorchParam(_) => {
1087            // Skip tensor parameters in config reading
1088            PickleValue::None
1089        }
1090        Object::Class { .. } | Object::Build { .. } | Object::Reduce { .. } => {
1091            // Complex objects are represented as None for simplicity
1092            PickleValue::None
1093        }
1094    })
1095}
1096
1097/// Convert PickleValue to NestedValue for deserialization
1098fn convert_pickle_to_nested_value(value: PickleValue) -> Result<NestedValue> {
1099    Ok(match value {
1100        PickleValue::None => NestedValue::Default(None),
1101        PickleValue::Bool(b) => NestedValue::Bool(b),
1102        PickleValue::Int(i) => NestedValue::I64(i),
1103        PickleValue::Float(f) => NestedValue::F64(f),
1104        PickleValue::String(s) => NestedValue::String(s),
1105        PickleValue::List(list) => {
1106            let mut vec = Vec::new();
1107            for item in list {
1108                vec.push(convert_pickle_to_nested_value(item)?);
1109            }
1110            NestedValue::Vec(vec)
1111        }
1112        PickleValue::Dict(dict) => {
1113            let mut map = HashMap::new();
1114            for (k, v) in dict {
1115                map.insert(k, convert_pickle_to_nested_value(v)?);
1116            }
1117            NestedValue::Map(map)
1118        }
1119        PickleValue::Bytes(data) => {
1120            // Convert bytes to a list of u8 values
1121            let vec: Vec<NestedValue> = data.into_iter().map(NestedValue::U8).collect();
1122            NestedValue::Vec(vec)
1123        }
1124    })
1125}