burn_store/pytorch/
reader.rs

//! PyTorch file reader implementation.
//!
//! This module provides support for reading PyTorch checkpoint files (`.pt`/`.pth`).
//!
//! # Supported Formats
//!
//! ## 1. Modern ZIP Format (PyTorch 1.6+)
//! Files are ZIP archives containing:
//! - `data.pkl` (at the archive root, under `archive/`, or inside another top-level
//!   directory): pickled tensor metadata
//! - `data/` directory: binary tensor data files
//!
//! ## 2. Legacy Pickle Format (PyTorch 0.1.10 - 1.5)
//! Sequential pickle streams with the structure:
//! - Magic number pickle (0x1950a86a20f9469cfc6c)
//! - Protocol version pickle (e.g., 1001)
//! - System info pickle (endianness, type sizes)
//! - Model data pickle (state_dict or full model)
//! - Storage keys pickle, followed by the raw bytes of every storage
//!
//! ## 3. Simple Pickle Format
//! A direct pickle file with a dictionary at the root, commonly used for manually saved
//! state_dicts. Such files have no separate data section, so only pickles that do not
//! reference tensor storage can be loaded this way.
//!
//! # Compatibility
//!
//! The reader handles backward compatibility by detecting the file format
//! automatically. Files from PyTorch 0.1.10 through current versions are
//! supported, though full model saves (vs state_dict) may have limitations
//! as they contain Python code references. ZIP archives written on big-endian
//! systems are currently rejected.
29
30use crate::TensorSnapshot;
31use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33use burn_core::record::serde::{adapter::DefaultAdapter, data::NestedValue, de::Deserializer};
34use serde::de::DeserializeOwned;
35use std::collections::HashMap;
36use std::fs::File;
37use std::io::{BufReader, Read, Seek, SeekFrom};
38use std::path::Path;
39
40use super::lazy_data::LazyDataSource;
41use super::pickle_reader::{Object, PickleError, read_pickle, read_pickle_with_data};
42use std::sync::Arc;
43
44/// Error type for PyTorch file operations
45#[derive(Debug)]
46pub enum PytorchError {
47    /// IO error
48    Io(std::io::Error),
49    /// Pickle parsing error
50    Pickle(PickleError),
51    /// Zip archive error
52    Zip(zip::result::ZipError),
53    /// Invalid file format
54    InvalidFormat(String),
55    /// Key not found
56    KeyNotFound(String),
57    /// Serde deserialization error
58    Serde(burn_core::record::serde::error::Error),
59}
60
61impl From<std::io::Error> for PytorchError {
62    fn from(e: std::io::Error) -> Self {
63        PytorchError::Io(e)
64    }
65}
66
67impl From<PickleError> for PytorchError {
68    fn from(e: PickleError) -> Self {
69        PytorchError::Pickle(e)
70    }
71}
72
73impl From<zip::result::ZipError> for PytorchError {
74    fn from(e: zip::result::ZipError) -> Self {
75        PytorchError::Zip(e)
76    }
77}
78
79impl From<burn_core::record::serde::error::Error> for PytorchError {
80    fn from(e: burn_core::record::serde::error::Error) -> Self {
81        PytorchError::Serde(e)
82    }
83}
84
85impl std::fmt::Display for PytorchError {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            PytorchError::Io(e) => write!(f, "IO error: {}", e),
89            PytorchError::Pickle(e) => write!(
90                f,
91                "Pickle parsing error: {}. This may indicate an unsupported PyTorch file format or corrupted file.",
92                e
93            ),
94            PytorchError::Zip(e) => write!(f, "Zip archive error: {}", e),
95            PytorchError::InvalidFormat(msg) => write!(f, "Invalid PyTorch file format: {}", msg),
96            PytorchError::KeyNotFound(key) => write!(
97                f,
98                "Key '{}' not found in PyTorch file. Available keys may be listed with the keys() method.",
99                key
100            ),
101            PytorchError::Serde(e) => write!(f, "Serde deserialization error: {}", e),
102        }
103    }
104}
105
106impl std::error::Error for PytorchError {}
107
108type Result<T> = std::result::Result<T, PytorchError>;
109
110/// Metadata about a PyTorch file
111///
112/// Contains information about the file format, version, and other properties
113/// that can be useful for debugging or compatibility checking.
114#[derive(Debug, Clone)]
115pub struct PytorchMetadata {
116    /// Format version (e.g., "1.0" for modern ZIP format)
117    pub format_version: Option<String>,
118    /// File format type (ZIP, Legacy, or Pickle)
119    pub format_type: FileFormat,
120    /// Byte order (endianness) - currently only LittleEndian is supported
121    pub byte_order: ByteOrder,
122    /// Whether the file has storage alignment information
123    pub has_storage_alignment: bool,
124    /// PyTorch version that saved the file (if available)
125    pub pytorch_version: Option<String>,
126    /// Number of tensors in the file
127    pub tensor_count: usize,
128    /// Total size of tensor data in bytes (if available)
129    pub total_data_size: Option<usize>,
130}
131
132impl PytorchMetadata {
133    /// Check if this is a modern format file (ZIP-based, PyTorch 1.6+)
134    pub fn is_modern_format(&self) -> bool {
135        matches!(self.format_type, FileFormat::Zip)
136    }
137
138    /// Check if this is a legacy format file (PyTorch 0.1.10 - 1.5)
139    pub fn is_legacy_format(&self) -> bool {
140        matches!(self.format_type, FileFormat::Legacy)
141    }
142}
143
/// On-disk layout of a PyTorch checkpoint.
///
/// A fieldless enum, so it is cheap to copy and can be compared, hashed and used as a
/// map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FileFormat {
    /// ZIP-based format (PyTorch 1.6+)
    Zip,
    /// Legacy format (PyTorch 0.1.10 - 1.5)
    Legacy,
    /// Simple pickle file
    Pickle,
}

/// Byte order (endianness) of the tensor data stored in a checkpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ByteOrder {
    /// Least-significant byte first.
    LittleEndian,
    /// Most-significant byte first.
    BigEndian,
}
161
162/// PyTorch checkpoint reader
163///
164/// This is the main interface for reading PyTorch checkpoint files (.pt/.pth).
165/// It supports multiple PyTorch formats including modern ZIP-based format (1.6+),
166/// legacy format (0.1.10-1.5), and simple pickle files.
167///
168/// # Example
169/// ```rust,no_run
170/// # use burn_store::pytorch::PytorchReader;
171/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
172/// // Load a checkpoint file
173/// let reader = PytorchReader::new("model.pt")?;
174///
175/// // Get tensor names
176/// let keys = reader.keys();
177///
178/// // Access a specific tensor
179/// if let Some(tensor) = reader.get("conv1.weight") {
180///     let data = tensor.to_data(); // Materializes the tensor
181/// }
182///
183/// // Check file metadata
184/// println!("Format: {:?}", reader.metadata().format_type);
185/// println!("Tensor count: {}", reader.metadata().tensor_count);
186/// # Ok(())
187/// # }
188/// ```
189pub struct PytorchReader {
190    tensors: HashMap<String, TensorSnapshot>,
191    metadata: PytorchMetadata,
192}
193
194impl PytorchReader {
195    /// Load a PyTorch checkpoint file
196    ///
197    /// # Arguments
198    /// * `path` - Path to the PyTorch file (.pt or .pth)
199    ///
200    /// # Returns
201    /// A `PytorchReader` with lazy-loaded tensors and metadata
202    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
203        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), None)?;
204        Ok(Self { tensors, metadata })
205    }
206
207    /// Load a PyTorch checkpoint with a specific top-level key
208    ///
209    /// Many PyTorch checkpoints store the model weights under a specific key
210    /// like "state_dict", "model", or "model_state_dict".
211    ///
212    /// # Arguments
213    /// * `path` - Path to the PyTorch file
214    /// * `key` - Top-level key to extract (e.g., "state_dict")
215    ///
216    /// # Example
217    /// ```rust,no_run
218    /// # use burn_store::pytorch::PytorchReader;
219    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
220    /// let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")?;
221    /// # Ok(())
222    /// # }
223    /// ```
224    pub fn with_top_level_key<P: AsRef<Path>>(path: P, key: &str) -> Result<Self> {
225        let (tensors, metadata) = load_pytorch_file_with_metadata(path.as_ref(), Some(key))?;
226        Ok(Self { tensors, metadata })
227    }
228
229    /// Load from a reader
230    ///
231    /// This method is useful when loading from non-file sources like memory buffers.
232    /// Note: Metadata detection is limited when loading from a reader.
233    ///
234    /// # Arguments
235    /// * `reader` - Any type implementing `Read`
236    /// * `top_level_key` - Optional key to extract
237    pub fn from_reader<R: Read>(reader: R, top_level_key: Option<&str>) -> Result<Self> {
238        // For reader-based loading, we don't have full metadata access
239        let tensors = load_from_reader(reader, top_level_key)?;
240        let metadata = PytorchMetadata {
241            format_version: None,
242            format_type: FileFormat::Pickle, // Default assumption
243            byte_order: ByteOrder::LittleEndian,
244            has_storage_alignment: false,
245            pytorch_version: None,
246            tensor_count: tensors.len(),
247            total_data_size: None,
248        };
249        Ok(Self { tensors, metadata })
250    }
251
252    /// Get all tensor names
253    pub fn keys(&self) -> Vec<String> {
254        self.tensors.keys().cloned().collect()
255    }
256
257    /// Get a tensor by name
258    pub fn get(&self, name: &str) -> Option<&TensorSnapshot> {
259        self.tensors.get(name)
260    }
261
262    /// Get all tensors
263    pub fn tensors(&self) -> &HashMap<String, TensorSnapshot> {
264        &self.tensors
265    }
266
267    /// Take ownership of all tensors
268    pub fn into_tensors(self) -> HashMap<String, TensorSnapshot> {
269        self.tensors
270    }
271
272    /// Get metadata about the loaded file
273    ///
274    /// Provides information about the file format, version, endianness, etc.
275    pub fn metadata(&self) -> &PytorchMetadata {
276        &self.metadata
277    }
278
279    /// Get the number of tensors in the file
280    pub fn len(&self) -> usize {
281        self.tensors.len()
282    }
283
284    /// Check if the file contains no tensors
285    pub fn is_empty(&self) -> bool {
286        self.tensors.is_empty()
287    }
288
289    /// Read raw pickle data from a PyTorch file
290    ///
291    /// This is useful for extracting configuration or metadata that isn't tensor data.
292    /// Returns a simplified JSON-like structure that can be easily converted to other formats.
293    ///
294    /// # Arguments
295    /// * `path` - Path to the PyTorch file
296    /// * `top_level_key` - Optional key to extract from the top-level dictionary
297    ///
298    /// # Returns
299    /// A `PickleValue` representing the pickle data structure
300    pub fn read_pickle_data<P: AsRef<Path>>(
301        path: P,
302        top_level_key: Option<&str>,
303    ) -> Result<PickleValue> {
304        read_pickle_as_value(path.as_ref(), top_level_key)
305    }
306
307    /// Load and deserialize configuration data from a PyTorch file
308    ///
309    /// This method reads configuration or metadata stored in PyTorch checkpoint files
310    /// and deserializes it into the specified type. It's particularly useful for
311    /// extracting model configurations that might be saved alongside model weights.
312    ///
313    /// # Arguments
314    /// * `path` - Path to the PyTorch file (.pt or .pth)
315    /// * `top_level_key` - Optional key to extract specific data within the pickle file.
316    ///   If `None`, the entire content is deserialized.
317    ///
318    /// # Type Parameters
319    /// * `D` - The target type to deserialize into. Must implement `DeserializeOwned`.
320    ///
321    /// # Returns
322    /// A `Result` containing the deserialized configuration data, or an `Error` if
323    /// reading or deserialization fails.
324    ///
325    /// # Example
326    /// ```rust,no_run
327    /// # use burn_store::pytorch::PytorchReader;
328    /// # use serde::Deserialize;
329    /// # fn example() -> Result<(), Box<dyn std::error::Error>> {
330    /// #[derive(Debug, Deserialize)]
331    /// struct ModelConfig {
332    ///     hidden_size: usize,
333    ///     num_layers: usize,
334    /// }
335    ///
336    /// let config: ModelConfig = PytorchReader::load_config("model.pth", Some("config"))?;
337    /// # Ok(())
338    /// # }
339    /// ```
340    pub fn load_config<D, P>(path: P, top_level_key: Option<&str>) -> Result<D>
341    where
342        D: DeserializeOwned,
343        P: AsRef<Path>,
344    {
345        // Read the PyTorch file and extract the pickle data
346        let pickle_value = Self::read_pickle_data(path, top_level_key)?;
347
348        // Convert PickleValue to NestedValue
349        let nested_value = convert_pickle_to_nested_value(pickle_value)?;
350
351        // Create a deserializer with the default adapter
352        let deserializer = Deserializer::<DefaultAdapter>::new(nested_value, false);
353
354        // Deserialize the nested value into the target type
355        let value = D::deserialize(deserializer)?;
356        Ok(value)
357    }
358}
359
/// Loosely typed, JSON-like view of decoded pickle data.
///
/// Easier to consume than the reader's internal pickle `Object` model.
#[derive(Debug, Clone, PartialEq)]
pub enum PickleValue {
    /// Absent value (`None` / null)
    None,
    /// Boolean
    Bool(bool),
    /// Integer, stored as `i64`
    Int(i64),
    /// Floating-point number, stored as `f64`
    Float(f64),
    /// Text string
    String(String),
    /// List or array of values
    List(Vec<PickleValue>),
    /// Mapping from string keys to values
    Dict(HashMap<String, PickleValue>),
    /// Raw binary data
    Bytes(Vec<u8>),
}
383
384/// Internal function to load a PyTorch file with metadata
385fn load_pytorch_file_with_metadata(
386    path: &Path,
387    top_level_key: Option<&str>,
388) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
389    // First, try to read as a zip file
390    if let Ok(file) = File::open(path)
391        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
392    {
393        // PyTorch saves the main data in various locations within the zip
394        let mut pickle_data = Vec::new();
395        let mut pickle_found = false;
396
397        // Try different common pickle file locations
398        let possible_pickle_paths = [
399            "data.pkl",
400            "archive/data.pkl",
401            // Look for any .pkl file in the root or first-level directories
402        ];
403
404        for pickle_path in &possible_pickle_paths {
405            if archive.by_name(pickle_path).is_ok() {
406                let mut pickle_file = archive.by_name(pickle_path)?;
407                pickle_file.read_to_end(&mut pickle_data)?;
408                pickle_found = true;
409                break;
410            }
411        }
412
413        // If not found in standard locations, search for any .pkl file
414        if !pickle_found {
415            for i in 0..archive.len() {
416                let file = archive.by_index(i)?;
417                let name = file.name().to_string();
418                drop(file); // Release the borrow
419
420                if name.ends_with("data.pkl") {
421                    let mut file = archive.by_index(i)?;
422                    file.read_to_end(&mut pickle_data)?;
423                    pickle_found = true;
424                    break;
425                }
426            }
427        }
428
429        if !pickle_found {
430            return Err(PytorchError::InvalidFormat(
431                "No data.pkl file found in ZIP archive. Expected PyTorch 1.6+ format with data.pkl or archive/data.pkl".to_string(),
432            ));
433        }
434
435        // Check for format version (optional)
436        let format_version = if let Ok(mut version_file) = archive.by_name(".format_version") {
437            let mut version_data = Vec::new();
438            version_file.read_to_end(&mut version_data)?;
439            let version_str = String::from_utf8_lossy(&version_data);
440            let version = version_str.trim().to_string();
441            Some(version)
442        } else {
443            None
444        };
445
446        // Check for byteorder file to detect endianness
447        let is_big_endian = if let Ok(mut byteorder_file) = archive.by_name("byteorder") {
448            let mut byteorder_data = Vec::new();
449            byteorder_file.read_to_end(&mut byteorder_data)?;
450            let byteorder_str = String::from_utf8_lossy(&byteorder_data);
451            byteorder_str.trim() == "big"
452        } else {
453            false // Default to little-endian if no byteorder file
454        };
455
456        if is_big_endian {
457            // Big-endian files are not yet supported as they require different byte order conversion
458            // TODO: To support big-endian files, we need to:
459            // 1. Pass endianness info through to pickle_reader
460            // 2. Use from_be_bytes instead of from_le_bytes for tensor data
461            // 3. Handle byte swapping for all numeric types (f32, f64, i32, etc.)
462            return Err(PytorchError::InvalidFormat(
463                "Big-endian PyTorch files are not yet supported. The file was saved on a big-endian system and requires byte order conversion.".to_string()
464            ));
465        }
466
467        // Check for storage alignment file
468        let has_storage_alignment = archive.by_name(".storage_alignment").is_ok();
469
470        // Check for PyTorch version (if saved)
471        let pytorch_version = if let Ok(mut version_file) = archive.by_name("version") {
472            let mut version_data = Vec::new();
473            version_file.read_to_end(&mut version_data)?;
474            Some(String::from_utf8_lossy(&version_data).trim().to_string())
475        } else {
476            None
477        };
478
479        // Create a lazy data source instead of loading all data upfront
480        let data_source = Arc::new(LazyDataSource::from_zip(path)?);
481
482        // Calculate total data size without loading
483        let mut total_data_size = 0usize;
484        for i in 0..archive.len() {
485            let file = archive.by_index(i)?;
486            let name = file.name();
487
488            // Look for data files - they can be in various locations
489            let is_data_file = (name.contains("/data/")
490                || name.starts_with("data/")
491                || name.starts_with("archive/data/"))
492                && !name.ends_with(".pkl")
493                && !name.ends_with("/");
494
495            if is_data_file {
496                total_data_size += file.size() as usize;
497            }
498        }
499
500        // Parse the pickle data with lazy data source
501        let mut pickle_reader = BufReader::new(pickle_data.as_slice());
502        let obj = read_pickle_with_data(&mut pickle_reader, data_source)?;
503
504        // Extract tensors with their data
505        let tensors = extract_tensors_with_data(obj, top_level_key)?;
506
507        // Create metadata
508        let metadata = PytorchMetadata {
509            format_version,
510            format_type: FileFormat::Zip,
511            byte_order: if is_big_endian {
512                ByteOrder::BigEndian
513            } else {
514                ByteOrder::LittleEndian
515            },
516            has_storage_alignment,
517            pytorch_version,
518            tensor_count: tensors.len(),
519            total_data_size: Some(total_data_size),
520        };
521
522        return Ok((tensors, metadata));
523    }
524
525    // If not a zip or zip reading failed, try reading as a plain pickle file
526    let mut file = File::open(path)?;
527
528    // Check for PyTorch legacy format (starts with magic number as pickled integer)
529    let mut header = [0u8; 15];
530    // Use read() instead of read_exact() to handle files smaller than 15 bytes
531    let bytes_read = file.read(&mut header)?;
532    file.seek(std::io::SeekFrom::Start(0))?;
533
534    // Only check for legacy format if we have enough bytes
535    // PyTorch legacy format detection (PyTorch 0.1.10 - 1.3)
536    // Reference: https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L65
537    //
538    // These files use sequential pickle streams with metadata before the actual data.
539    // Format structure:
540    //   1. Magic number (0x1950a86a20f9469cfc6c) stored as LONG1 pickle
541    //   2. Protocol version (e.g., 1001)
542    //   3. System info dict (protocol_version, little_endian, type_sizes)
543    //   4. Actual model data (state_dict or full model)
544    //   5. Storage keys list (pickle)
545    //   6. Raw binary data for each storage
546    //
547    // The pattern is: 0x80 0x02 0x8a 0x0a (PROTO 2, LONG1 with 10 bytes)
548    // followed by 10 bytes of magic number (little-endian), then 0x2e (STOP)
549    let is_legacy_format = bytes_read >= 15
550        && header[0] == 0x80  // PROTO opcode
551        && header[1] == 0x02  // Protocol version 2
552        && header[2] == 0x8a  // LONG1 opcode
553        && header[3] == 0x0a  // 10 bytes follow
554        // Magic number 0x1950a86a20f9469cfc6c in little-endian
555        && header[4] == 0x6c
556        && header[5] == 0xfc
557        && header[6] == 0x9c
558        && header[7] == 0x46
559        && header[8] == 0xf9
560        && header[9] == 0x20
561        && header[10] == 0x6a
562        && header[11] == 0xa8
563        && header[12] == 0x50
564        && header[13] == 0x19
565        && header[14] == 0x2e; // STOP opcode
566
567    if is_legacy_format {
568        return load_legacy_pytorch_file_with_metadata(path, top_level_key);
569    }
570
571    // Standard pickle file
572    // This might be a pickle with tensor references, so we need to handle that case
573    // For plain pickle files without a separate data section, we can't use lazy loading
574    // so we'll just create empty placeholder tensors for the structure
575    let file = File::open(path)?;
576    let mut reader = BufReader::new(file);
577
578    // Try reading without data source first
579    match read_pickle(&mut reader) {
580        Ok(obj) => {
581            let tensors = extract_tensors_with_data(obj, top_level_key)?;
582            let tensor_count = tensors.len();
583            Ok((
584                tensors,
585                PytorchMetadata {
586                    format_version: None,
587                    format_type: FileFormat::Pickle,
588                    byte_order: ByteOrder::LittleEndian,
589                    has_storage_alignment: false,
590                    pytorch_version: None,
591                    tensor_count,
592                    total_data_size: None,
593                },
594            ))
595        }
596        Err(e)
597            if e.to_string()
598                .contains("Cannot load tensor data without a data source") =>
599        {
600            // This pickle file contains tensor data but we're trying to read it without
601            // providing a data source. This shouldn't happen in normal usage as PyTorch
602            // files with actual tensor data should be in ZIP or legacy format.
603            Err(PytorchError::InvalidFormat(
604                "Pickle file contains tensor data but no data source is available. This file should be loaded as ZIP or legacy format.".to_string()
605            ))
606        }
607        Err(e) => Err(PytorchError::Pickle(e)),
608    }
609}
610
611/// Load from a reader
612fn load_from_reader<R: Read>(
613    reader: R,
614    top_level_key: Option<&str>,
615) -> Result<HashMap<String, TensorSnapshot>> {
616    let mut buf_reader = BufReader::new(reader);
617
618    // Try reading without data source
619    match read_pickle(&mut buf_reader) {
620        Ok(obj) => extract_tensors_with_data(obj, top_level_key),
621        Err(e)
622            if e.to_string()
623                .contains("Cannot load tensor data without a data source") =>
624        {
625            // This reader contains tensor data but we can't load it without a file path
626            Err(PytorchError::InvalidFormat(
627                "Reader contains tensor data but no data source is available. Use file-based loading instead.".to_string()
628            ))
629        }
630        Err(e) => Err(PytorchError::Pickle(e)),
631    }
632}
633
634/// Extract tensors from a parsed pickle object
635fn extract_tensors_with_data(
636    obj: Object,
637    top_level_key: Option<&str>,
638) -> Result<HashMap<String, TensorSnapshot>> {
639    let dict = match obj {
640        Object::Dict(dict) => {
641            if let Some(key) = top_level_key {
642                // Extract the nested dictionary if a top-level key is specified
643                match dict.get(key) {
644                    Some(Object::Dict(nested)) => nested.clone(),
645                    _ => {
646                        return Err(PytorchError::KeyNotFound(format!(
647                            "Top-level key '{}' not found or is not a dictionary. Available top-level keys in file: {:?}",
648                            key,
649                            dict.keys().collect::<Vec<_>>()
650                        )));
651                    }
652                }
653            } else {
654                dict
655            }
656        }
657        _ => {
658            return Err(PytorchError::InvalidFormat(
659                "Expected a dictionary at the root of the PyTorch file, but found a different type. The file may be a full model save rather than a state_dict.".to_string(),
660            ));
661        }
662    };
663
664    let mut tensors = HashMap::new();
665    let mut path = Vec::new();
666    extract_tensors_recursive(&Object::Dict(dict), &mut path, &mut tensors);
667    Ok(tensors)
668}
669
670/// Recursively extract tensors from an object
671fn extract_tensors_recursive<'a>(
672    obj: &'a Object,
673    path: &mut Vec<&'a str>,
674    tensors: &mut HashMap<String, TensorSnapshot>,
675) {
676    match obj {
677        Object::Dict(dict) => {
678            for (key, value) in dict {
679                path.push(key);
680                extract_tensors_recursive(value, path, tensors);
681                path.pop();
682            }
683        }
684        Object::TorchParam(snapshot) => {
685            // The TensorSnapshot already contains the data loading closure
686            // Only allocate the string here when we actually insert
687            tensors.insert(path.join("."), snapshot.clone());
688        }
689        _ => {}
690    }
691}
692
693/// Load a legacy PyTorch file with metadata
694fn load_legacy_pytorch_file_with_metadata(
695    path: &Path,
696    top_level_key: Option<&str>,
697) -> Result<(HashMap<String, TensorSnapshot>, PytorchMetadata)> {
698    let file = File::open(path)?;
699    let mut reader = BufReader::new(file);
700
701    // Skip metadata pickles
702    // 1. Magic number
703    let _ = read_pickle(&mut reader).map_err(|e| {
704        PytorchError::InvalidFormat(format!(
705            "Failed to read magic number from legacy format: {}",
706            e
707        ))
708    })?;
709
710    // 2. Protocol version
711    let _ = read_pickle(&mut reader).map_err(|e| {
712        PytorchError::InvalidFormat(format!(
713            "Failed to read protocol version from legacy format: {}",
714            e
715        ))
716    })?;
717
718    // 3. System info
719    let _ = read_pickle(&mut reader).map_err(|e| {
720        PytorchError::InvalidFormat(format!(
721            "Failed to read system info from legacy format: {}",
722            e
723        ))
724    })?;
725
726    // Save position before main pickle
727    let main_pickle_pos = reader.stream_position()?;
728
729    // 4. Skip main object - it might contain tensors so we can't parse it yet
730    // We'll re-read it with a data source later
731    use crate::pytorch::pickle_reader::skip_pickle;
732    skip_pickle(&mut reader).map_err(|e| {
733        PytorchError::InvalidFormat(format!(
734            "Failed to skip main object in legacy format: {}",
735            e
736        ))
737    })?;
738
739    // 5. Storage keys list (sorted keys as written by PyTorch)
740    let storage_keys = match read_pickle(&mut reader) {
741        Ok(Object::List(keys)) => keys
742            .into_iter()
743            .filter_map(|obj| match obj {
744                Object::String(s) => Some(s),
745                _ => None,
746            })
747            .collect::<Vec<_>>(),
748        _ => vec![],
749    };
750
751    // 6. Raw binary data starts here
752    let data_start_pos = reader.stream_position()?;
753    let file_size = reader.seek(SeekFrom::End(0))?;
754    let data_size = file_size - data_start_pos;
755
756    // Create a lazy data source for legacy multi-storage format
757    let data_source = Arc::new(LazyDataSource::from_legacy_multi_storage(
758        path,
759        data_start_pos,
760        data_size,
761    ));
762
763    // Now re-read the main pickle with lazy data source
764    reader.seek(SeekFrom::Start(main_pickle_pos))?;
765    let main_obj = read_pickle_with_data(&mut reader, data_source.clone())?;
766
767    // Set storage keys for lazy boundary detection
768    // The boundaries will be calculated automatically as tensors are accessed
769    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source
770        && !storage_keys.is_empty()
771    {
772        let source = source
773            .lock()
774            .unwrap_or_else(|poisoned| poisoned.into_inner());
775        source.set_storage_keys(storage_keys.clone());
776    }
777
778    // Extract tensors normally
779    let tensors = extract_tensors_with_data(main_obj, top_level_key)?;
780
781    // Create metadata for legacy format
782    let metadata = PytorchMetadata {
783        format_version: None, // Legacy format doesn't have version files
784        format_type: FileFormat::Legacy,
785        byte_order: ByteOrder::LittleEndian, // Legacy format is little-endian
786        has_storage_alignment: false,
787        pytorch_version: None, // Could parse from protocol version, but not reliable
788        tensor_count: tensors.len(),
789        total_data_size: Some(data_size as usize),
790    };
791
792    Ok((tensors, metadata))
793}
794
795/// Read pickle data from a PyTorch file as a simplified value
796fn read_pickle_as_value(path: &Path, top_level_key: Option<&str>) -> Result<PickleValue> {
797    use crate::pytorch::lazy_data::LazyDataSource;
798    use crate::pytorch::pickle_reader::{read_pickle, read_pickle_with_data};
799    use std::sync::Arc;
800
801    // Try to open as ZIP first
802    if let Ok(file) = File::open(path)
803        && let Ok(mut archive) = zip::ZipArchive::new(BufReader::new(file))
804    {
805        // Read pickle data from ZIP
806        let mut pickle_data = Vec::new();
807
808        // Try standard locations
809        for pickle_path in &["data.pkl", "archive/data.pkl"] {
810            if let Ok(mut pickle_file) = archive.by_name(pickle_path) {
811                pickle_file.read_to_end(&mut pickle_data)?;
812                break;
813            }
814        }
815
816        // If not found, search for any .pkl file
817        if pickle_data.is_empty() {
818            for i in 0..archive.len() {
819                let file = archive.by_index(i)?;
820                let name = file.name().to_string();
821                drop(file);
822
823                if name.ends_with("data.pkl") {
824                    let mut file = archive.by_index(i)?;
825                    file.read_to_end(&mut pickle_data)?;
826                    break;
827                }
828            }
829        }
830
831        if !pickle_data.is_empty() {
832            // Create a data source for the ZIP file
833            let data_source = LazyDataSource::from_zip(path)?;
834            let data_source_arc = Arc::new(data_source);
835
836            let mut reader = BufReader::new(pickle_data.as_slice());
837            let obj = read_pickle_with_data(&mut reader, data_source_arc)?;
838            return convert_object_to_value(obj, top_level_key);
839        }
840    }
841
842    // Try as plain pickle file
843    // First attempt without data source (for pure metadata files)
844    let file = File::open(path)?;
845    let mut reader = BufReader::new(file);
846
847    match read_pickle(&mut reader) {
848        Ok(obj) => convert_object_to_value(obj, top_level_key),
849        Err(e)
850            if e.to_string()
851                .contains("Cannot load tensor data without a data source") =>
852        {
853            // File contains tensors, need to use full PytorchReader
854            // Use the regular reader to get proper tensor handling
855            let reader = PytorchReader::new(path)?;
856
857            // Convert tensors to PickleValue structure
858            let mut result = std::collections::HashMap::new();
859            for key in reader.keys() {
860                // For pickle value extraction, we just need the structure, not the actual data
861                result.insert(
862                    key.clone(),
863                    PickleValue::String(format!("<Tensor:{}>", key)),
864                );
865            }
866
867            if let Some(key) = top_level_key {
868                Ok(PickleValue::Dict(
869                    [(key.to_string(), PickleValue::Dict(result))]
870                        .into_iter()
871                        .collect(),
872                ))
873            } else {
874                Ok(PickleValue::Dict(result))
875            }
876        }
877        Err(e) => Err(PytorchError::Pickle(e)),
878    }
879}
880
881/// Convert internal Object to public PickleValue
882fn convert_object_to_value(obj: Object, top_level_key: Option<&str>) -> Result<PickleValue> {
883    use crate::pytorch::pickle_reader::Object;
884
885    // If a top-level key is specified, extract it first
886    if let Some(key) = top_level_key
887        && let Object::Dict(dict) = obj
888    {
889        if let Some(value) = dict.get(key) {
890            return object_to_pickle_value(value.clone());
891        } else {
892            return Err(PytorchError::KeyNotFound(format!(
893                "Key '{}' not found in pickle data",
894                key
895            )));
896        }
897    }
898
899    object_to_pickle_value(obj)
900}
901
902/// Convert Object to PickleValue
903fn object_to_pickle_value(obj: Object) -> Result<PickleValue> {
904    use crate::pytorch::pickle_reader::Object;
905
906    Ok(match obj {
907        Object::None => PickleValue::None,
908        Object::Bool(b) => PickleValue::Bool(b),
909        Object::Int(i) => PickleValue::Int(i),
910        Object::Float(f) => PickleValue::Float(f),
911        Object::String(s) => PickleValue::String(s),
912        Object::Persistent(data) => {
913            // Persistent data is raw bytes
914            PickleValue::Bytes(data)
915        }
916        Object::PersistentTuple(tuple) => {
917            // Convert persistent tuples to lists
918            let mut values = Vec::new();
919            for item in tuple {
920                values.push(object_to_pickle_value(item)?);
921            }
922            PickleValue::List(values)
923        }
924        Object::List(list) => {
925            let mut values = Vec::new();
926            for item in list {
927                values.push(object_to_pickle_value(item)?);
928            }
929            PickleValue::List(values)
930        }
931        Object::Dict(dict) => {
932            let mut map = HashMap::new();
933            for (k, v) in dict {
934                map.insert(k, object_to_pickle_value(v)?);
935            }
936            PickleValue::Dict(map)
937        }
938        Object::Tuple(tuple) => {
939            // Convert tuples to lists in the public API
940            let mut values = Vec::new();
941            for item in tuple {
942                values.push(object_to_pickle_value(item)?);
943            }
944            PickleValue::List(values)
945        }
946        Object::TorchParam(_) => {
947            // Skip tensor parameters in config reading
948            PickleValue::None
949        }
950        Object::Class { .. } | Object::Build { .. } | Object::Reduce { .. } => {
951            // Complex objects are represented as None for simplicity
952            PickleValue::None
953        }
954    })
955}
956
957/// Convert PickleValue to NestedValue for deserialization
958fn convert_pickle_to_nested_value(value: PickleValue) -> Result<NestedValue> {
959    Ok(match value {
960        PickleValue::None => NestedValue::Default(None),
961        PickleValue::Bool(b) => NestedValue::Bool(b),
962        PickleValue::Int(i) => NestedValue::I64(i),
963        PickleValue::Float(f) => NestedValue::F64(f),
964        PickleValue::String(s) => NestedValue::String(s),
965        PickleValue::List(list) => {
966            let mut vec = Vec::new();
967            for item in list {
968                vec.push(convert_pickle_to_nested_value(item)?);
969            }
970            NestedValue::Vec(vec)
971        }
972        PickleValue::Dict(dict) => {
973            let mut map = HashMap::new();
974            for (k, v) in dict {
975                map.insert(k, convert_pickle_to_nested_value(v)?);
976            }
977            NestedValue::Map(map)
978        }
979        PickleValue::Bytes(data) => {
980            // Convert bytes to a list of u8 values
981            let vec: Vec<NestedValue> = data.into_iter().map(NestedValue::U8).collect();
982            NestedValue::Vec(vec)
983        }
984    })
985}