burn_store/pytorch/pickle_reader.rs

1//! Just enough pickle support to be able to read PyTorch checkpoints.
2//!
3//! This implementation is based on the candle project's pickle loader with significant
4//! modifications for improved separation of concerns and extended PyTorch compatibility.
5//!
6//! Original source: https://github.com/huggingface/candle/blob/main/candle-core/src/pickle.rs
7//!
8//! Modifications include:
9//! - Lazy tensor data loading for memory efficiency
10//! - Extended PyTorch version compatibility (0.1.10 - 2.x)
11//! - Better separation of pickle parsing and tensor extraction
12//! - Support for both legacy and modern PyTorch formats
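//!
//! A minimal usage sketch (hypothetical path; error handling elided). Plain
//! pickles can be read with [`read_pickle`]; checkpoints whose tensors live in
//! separate storage entries need [`read_pickle_with_data`] and a `LazyDataSource`:
//!
//! ```ignore
//! use std::fs::File;
//! use std::io::BufReader;
//!
//! let mut reader = BufReader::new(File::open("archive/data.pkl")?);
//! let obj = read_pickle(&mut reader)?;
//! ```
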
13use crate::TensorSnapshot;
14use crate::pytorch::lazy_data::LazyDataSource;
15use alloc::rc::Rc;
16use alloc::string::{String, ToString};
17use alloc::vec::Vec;
18use burn_core::module::ParamId;
19use burn_tensor::{DType, TensorData};
20use byteorder::{LittleEndian, ReadBytesExt};
21use half::{bf16, f16};
22use std::collections::HashMap;
23use std::io::{self, BufRead};
24use std::sync::Arc;
25
26/// Error type for pickle operations
27#[derive(Debug)]
28pub enum PickleError {
29    Io(io::Error),
30    InvalidOpCode(u8),
31    InvalidProtocol(u8),
32    UnexpectedOpCode(OpCode),
33    UnsupportedType(String),
34    InvalidData(String),
35    StackUnderflow,
36    MemoNotFound(u32),
37    InvalidShapeOrType,
38}
39
40impl From<io::Error> for PickleError {
41    fn from(e: io::Error) -> Self {
42        PickleError::Io(e)
43    }
44}
45
46impl std::fmt::Display for PickleError {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            PickleError::Io(e) => write!(f, "IO error: {}", e),
50            PickleError::InvalidOpCode(code) => write!(
51                f,
52                "Invalid pickle opcode: 0x{:02x}. The file may be corrupted or use an unsupported pickle protocol.",
53                code
54            ),
55            PickleError::InvalidProtocol(proto) => write!(
56                f,
57                "Invalid or unsupported pickle protocol version: {}. Supported versions are 2-5.",
58                proto
59            ),
60            PickleError::UnexpectedOpCode(op) => {
61                write!(f, "Unexpected pickle opcode {:?} in current context", op)
62            }
63            PickleError::UnsupportedType(ty) => write!(
64                f,
65                "Unsupported Python type '{}'. This may indicate a full model save rather than a state_dict.",
66                ty
67            ),
68            PickleError::InvalidData(msg) => write!(f, "Invalid data in pickle file: {}", msg),
69            PickleError::StackUnderflow => {
70                write!(f, "Pickle stack underflow - the file may be corrupted")
71            }
72            PickleError::MemoNotFound(idx) => write!(
73                f,
74                "Pickle memo reference {} not found - the file may be corrupted",
75                idx
76            ),
77            PickleError::InvalidShapeOrType => {
78                write!(f, "Invalid tensor shape or data type in PyTorch file")
79            }
80        }
81    }
82}
83
84impl std::error::Error for PickleError {}
85
86type Result<T> = std::result::Result<T, PickleError>;
87
88// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
89#[repr(u8)]
90#[derive(Debug, Eq, PartialEq, Clone)]
91pub enum OpCode {
92    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
93    Proto = 0x80,
94    Global = b'c',
95    BinPut = b'q',
96    LongBinPut = b'r',
97    EmptyTuple = b')',
98    Reduce = b'R',
99    Mark = b'(',
100    BinUnicode = b'X',
101    ShortBinString = b'U',
102    BinInt = b'J',
103    Int = b'I',
104    Tuple = b't',
105    BinPersId = b'Q',
106    BinInt1 = b'K',
107    BinInt2 = b'M',
108    Tuple1 = 0x85,
109    Tuple2 = 0x86,
110    Tuple3 = 0x87,
111    NewTrue = 0x88,
112    NewFalse = 0x89,
113    None = b'N',
114    BinGet = b'h',
115    LongBinGet = b'j',
116    SetItem = b's',
117    SetItems = b'u',
118    EmptyDict = b'}',
119    Dict = b'd',
120    Build = b'b',
121    Stop = b'.',
122    NewObj = 0x81,
123    EmptyList = b']',
124    List = b'l',
125    BinFloat = b'G',
126    Append = b'a',
127    Appends = b'e',
128    Long1 = 0x8a,
129    Memoize = 0x94,
130}
131
132// Avoid using FromPrimitive so as not to drag another dependency.
133impl TryFrom<u8> for OpCode {
134    type Error = u8;
135    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
136        match value {
137            0x80 => Ok(Self::Proto),
138            b'c' => Ok(Self::Global),
139            b'q' => Ok(Self::BinPut),
140            b'r' => Ok(Self::LongBinPut),
141            b')' => Ok(Self::EmptyTuple),
142            b'R' => Ok(Self::Reduce),
143            b'(' => Ok(Self::Mark),
144            b'X' => Ok(Self::BinUnicode),
145            b'U' => Ok(Self::ShortBinString),
146            b'J' => Ok(Self::BinInt),
147            b'I' => Ok(Self::Int),
148            b't' => Ok(Self::Tuple),
149            b'Q' => Ok(Self::BinPersId),
150            b'K' => Ok(Self::BinInt1),
151            b'M' => Ok(Self::BinInt2),
152            b'N' => Ok(Self::None),
153            0x85 => Ok(Self::Tuple1),
154            0x86 => Ok(Self::Tuple2),
155            0x87 => Ok(Self::Tuple3),
156            0x88 => Ok(Self::NewTrue),
157            0x89 => Ok(Self::NewFalse),
158            b'h' => Ok(Self::BinGet),
159            b'j' => Ok(Self::LongBinGet),
160            b's' => Ok(Self::SetItem),
161            b'u' => Ok(Self::SetItems),
162            b'}' => Ok(Self::EmptyDict),
163            b'd' => Ok(Self::Dict),
164            b'b' => Ok(Self::Build),
165            b'.' => Ok(Self::Stop),
166            0x81 => Ok(Self::NewObj),
167            b']' => Ok(Self::EmptyList),
168            b'l' => Ok(Self::List),
169            b'G' => Ok(Self::BinFloat),
170            b'a' => Ok(Self::Append),
171            b'e' => Ok(Self::Appends),
172            0x8a => Ok(Self::Long1),
173            0x94 => Ok(Self::Memoize),
174            value => Err(value),
175        }
176    }
177}
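
// A small decoding sketch: unknown opcode bytes are returned as `Err` so the
// parser can surface `PickleError::InvalidOpCode` with the offending byte.
//
// ```ignore
// assert_eq!(OpCode::try_from(b'.'), Ok(OpCode::Stop));
// assert_eq!(OpCode::try_from(0xFF), Err(0xFF));
// ```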
178
179fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
180    let mut data: Vec<u8> = Vec::with_capacity(32);
181    r.read_until(b'\n', &mut data)?;
182    data.pop();
183    if data.last() == Some(&b'\r') {
184        data.pop();
185    }
186    Ok(data)
187}
188
189fn buf_to_str(buf: &[u8]) -> Result<String> {
190    String::from_utf8(buf.to_vec())
191        .map_err(|e| PickleError::InvalidData(format!("Invalid UTF-8: {}", e)))
192}
193
194#[derive(Debug, Clone)]
195pub enum Object {
196    Class {
197        module_name: String,
198        name: String,
199    },
200    String(String),
201    Int(i64),
202    Float(f64),
203    Bool(bool),
204    None,
205    Tuple(Vec<Object>),
206    List(Vec<Object>),
207    Dict(HashMap<String, Object>),
208    Persistent(Vec<u8>),
209    PersistentTuple(Vec<Object>),
210    Reduce {
211        callable: Box<Object>,
212        args: Box<Object>,
213    },
214    Build {
215        callable: Box<Object>,
216        args: Box<Object>,
217    },
218    TorchParam(TensorSnapshot),
219}
220
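/// Dispatches a REDUCE-style call: the first tuple element names the Python
/// callable (e.g. `torch._utils._rebuild_tensor_v2`) and the remaining elements
/// are its arguments. Only the handful of callables that appear in PyTorch
/// state_dict checkpoints are handled; anything else yields `UnsupportedType`.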
221fn rebuild_from_type_v2(
222    o: Object,
223    memo: &mut HashMap<u32, Object>,
224    data_source: &Option<Arc<LazyDataSource>>,
225) -> Result<Object> {
226    let args = if let Object::Tuple(args) = o {
227        if args.is_empty() {
228            return Err(PickleError::InvalidData(
229                "rebuild_from_type_v2: empty args".to_string(),
230            ));
231        }
232        args
233    } else {
234        return Err(PickleError::InvalidData(format!(
235            "rebuild_from_type_v2: expected tuple got {:?}",
236            o
237        )));
238    };
239    let func = &args[0];
240    match func {
241        Object::Class { module_name, name } => {
242            let module_name = module_name.as_str();
243            let name = name.as_str();
244            // For rebuild_tensor_v2, the args might already be in a tuple
245            let actual_args = if args.len() == 2 && matches!(&args[1], Object::Tuple(_)) {
246                // If there's only one arg and it's a tuple, use it directly
247                args[1].clone()
248            } else {
249                // Otherwise, wrap the remaining args in a tuple
250                Object::Tuple(args[1..].to_vec())
251            };
252            if module_name == "torch._utils" && name == "_rebuild_tensor_v2" {
253                rebuild_tensor_v2(actual_args, memo, data_source)
254            } else if module_name == "torch._tensor" && name == "_rebuild_from_type_v2" {
255                rebuild_from_type_v2(actual_args, memo, data_source)
256            } else if module_name == "torch._utils" && name == "_rebuild_parameter" {
257                rebuild_parameter(actual_args, memo, data_source)
258            } else if module_name == "collections" && name == "OrderedDict" {
259                // OrderedDict is treated as a regular Dict in our implementation
260                Ok(Object::Dict(HashMap::new()))
261            } else {
262                Err(PickleError::UnsupportedType(format!(
263                    "{}.{}",
264                    module_name, name
265                )))
266            }
267        }
268        _ => Err(PickleError::InvalidData(format!(
269            "rebuild_from_type_v2: expected class got {:?}",
270            func
271        ))),
272    }
273}
274
275fn rebuild_parameter(
276    args: Object,
277    memo: &mut HashMap<u32, Object>,
278    data_source: &Option<Arc<LazyDataSource>>,
279) -> Result<Object> {
280    let args = if let Object::Tuple(args) = args {
281        if args.is_empty() {
282            return Err(PickleError::InvalidData(
283                "rebuild_parameter: empty args".to_string(),
284            ));
285        }
286        args
287    } else {
288        return Err(PickleError::InvalidData(format!(
289            "rebuild_parameter: expected tuple got {:?}",
290            args
291        )));
292    };
293    let data = &args[0];
294    let tensor = match data {
295        Object::Reduce {
296            callable: _,
297            args: _,
298        } => rebuild_from_type_v2(data.clone(), memo, data_source)?,
299        _ => data.clone(),
300    };
301    Ok(tensor)
302}
303
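/// Rebuilds a tensor from `torch._utils._rebuild_tensor_v2` arguments:
/// `(storage, storage_offset, shape, stride, requires_grad, backward_hooks)`.
/// The returned [`TensorSnapshot`] wraps a closure, so storage bytes are only
/// read from the data source when the tensor is actually materialized.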
304fn rebuild_tensor_v2(
305    args: Object,
306    _memo: &mut HashMap<u32, Object>,
307    data_source: &Option<Arc<LazyDataSource>>,
308) -> Result<Object> {
309    // args is (storage, storage_offset, shape, stride, requires_grad, backward_hooks)
310    let args = if let Object::Tuple(args) = args {
311        args
312    } else {
313        return Err(PickleError::InvalidData(format!(
314            "rebuild_tensor_v2: expected tuple got {:?}",
315            args
316        )));
317    };
318
319    if args.len() < 5 {
320        return Err(PickleError::InvalidData(format!(
321            "rebuild_tensor_v2: expected at least 5 args, got {}",
322            args.len()
323        )));
324    }
325
326    // First argument is the storage (persistent ID)
327    let (storage_info, storage_tuple) = match &args[0] {
328        Object::Persistent(data) => (data.clone(), None),
329        Object::PersistentTuple(tuple) => (vec![], Some(tuple.clone())),
330        _ => {
331            return Err(PickleError::InvalidData(format!(
332                "rebuild_tensor_v2: expected persistent id got {:?}",
333                args[0]
334            )));
335        }
336    };
337
338    // Second argument is storage offset
339    let storage_offset = match &args[1] {
340        Object::Int(offset) => *offset as usize,
341        _ => 0,
342    };
343
344    // Third argument is shape
345    let shape = match &args[2] {
346        Object::Tuple(shape) => shape
347            .iter()
348            .map(|x| match x {
349                Object::Int(i) => Ok(*i as usize),
350                _ => Err(PickleError::InvalidData(
351                    "shape must contain ints".to_string(),
352                )),
353            })
354            .collect::<Result<Vec<_>>>()?,
355        _ => {
356            return Err(PickleError::InvalidData(format!(
357                "rebuild_tensor_v2: expected shape tuple got {:?}",
358                args[2]
359            )));
360        }
361    };
362
    // Fourth argument is stride (currently unused; storage is read as contiguous)
    let _stride = matches!(&args[3], Object::Tuple(_));
365
366    // Parse the storage info to extract dtype and storage key
367    // The persistent ID is typically a tuple like: ('storage', 'FloatStorage', '0', 'cpu', 4)
368    let (dtype, storage_key) = if let Some(tuple) = storage_tuple {
369        // Direct tuple access
370        if tuple.len() >= 3 {
371            let storage_type = match &tuple[1] {
372                Object::String(s) => s.as_str(),
373                Object::Class {
374                    module_name: _,
375                    name,
376                } => name.as_str(),
377                _ => "FloatStorage",
378            };
379            let dtype = match storage_type {
380                "FloatStorage" => DType::F32,
381                "DoubleStorage" => DType::F64,
382                "HalfStorage" => DType::F16,
383                "BFloat16Storage" => DType::BF16,
384                "LongStorage" => DType::I64,
385                "IntStorage" => DType::I32,
386                "ShortStorage" => DType::I16,
387                "CharStorage" => DType::I8,
388                "ByteStorage" => DType::U8,
389                "BoolStorage" => DType::Bool,
390                _ => DType::F32, // Default to F32
391            };
392            let key = match &tuple[2] {
393                Object::String(s) => s.clone(),
394                _ => "0".to_string(),
395            };
396            (dtype, key)
397        } else {
398            (DType::F32, "0".to_string())
399        }
400    } else if !storage_info.is_empty() {
401        // Legacy string-based parsing
402        let storage_str = String::from_utf8_lossy(&storage_info);
403        if storage_str.starts_with("Tuple(") {
404            // Parse from the debug representation we stored
405            let parts: Vec<&str> = storage_str
406                .trim_start_matches("Tuple(")
407                .trim_end_matches(")")
408                .split(", ")
409                .map(|s| {
410                    let trimmed = s.trim_matches('"');
411                    if let Some(inner) = trimmed
412                        .strip_prefix("Object::String(\"")
413                        .and_then(|s| s.strip_suffix("\")"))
414                    {
415                        inner
416                    } else {
417                        trimmed
418                    }
419                })
420                .collect();
421
422            if parts.len() >= 3 {
423                let dtype = match parts[1] {
424                    "FloatStorage" => DType::F32,
425                    "DoubleStorage" => DType::F64,
426                    "HalfStorage" => DType::F16,
427                    "BFloat16Storage" => DType::BF16,
428                    "LongStorage" => DType::I64,
429                    "IntStorage" => DType::I32,
430                    "ShortStorage" => DType::I16,
431                    "CharStorage" => DType::I8,
432                    "ByteStorage" => DType::U8,
433                    _ => DType::F32, // Default to F32
434                };
435                (dtype, parts[2].to_string())
436            } else {
437                (DType::F32, "0".to_string())
438            }
439        } else {
440            (DType::F32, "0".to_string())
441        }
442    } else {
443        (DType::F32, "0".to_string())
444    };
445
446    // If no data source, we can't load tensor data
447    let data_source = match data_source {
448        Some(ds) => ds.clone(),
449        None => {
450            return Err(PickleError::InvalidData(
451                "Cannot load tensor data without a data source".to_string(),
452            ));
453        }
454    };
455
456    // Create clones for the closure
457    let data_source_clone = data_source.clone();
458    let shape_clone = shape.clone();
459
460    // Find the correct data file key
461    let data_file_key = {
462        let exact_key = format!("data/{}", storage_key);
463        if data_source.contains(&exact_key) {
464            exact_key
465        } else {
466            // Try other patterns
467            data_source
468                .keys()
469                .into_iter()
470                .find(|key| {
471                    key.ends_with(&format!("/data/{}", storage_key))
472                        || (key.contains("/data/") && key.rsplit('/').next() == Some(&storage_key))
473                })
474                .unwrap_or_else(|| format!("data/{}", storage_key))
475        }
476    };
477
478    // Track storage usage IMMEDIATELY for lazy boundary detection
479    // This must happen BEFORE creating the closure, not inside it!
480    if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source {
481        let source = source
482            .lock()
483            .unwrap_or_else(|poisoned| poisoned.into_inner());
484        let num_elements: usize = shape.iter().product();
485        let bytes_needed = storage_offset * dtype.size() + num_elements * dtype.size();
486        source.track_storage_usage(&storage_key, 0, bytes_needed);
487    }
488
489    // Create a TensorSnapshot with a closure that loads the actual data on-demand
    // TODO: extract this long closure into a standalone function (part of a bigger refactor)
491    Ok(Object::TorchParam(TensorSnapshot::from_closure(
492        Rc::new(move || {
493            // Load data only when needed
494            if let Ok(data) = data_source_clone.read(&data_file_key) {
495                // Parse the binary data based on dtype
496                let num_elements = shape_clone.iter().product::<usize>().max(1);
497
498                // Use dtype.size() to get element size in bytes
499                let element_size = dtype.size();
500
501                // Apply storage offset
502                let offset_bytes = storage_offset * element_size;
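                // Offset points past the end of the storage bytes: fall back to
                // zero-filled f32 data instead of erroring out.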
503                if offset_bytes >= data.len() {
504                    return Ok(TensorData::new(
505                        vec![0.0f32; num_elements],
506                        shape_clone.clone(),
507                    ));
508                }
509
510                let data_slice = &data[offset_bytes..];
511                let available_elements = data_slice.len() / element_size;
512                let elements_to_read = num_elements.min(available_elements);
513
514                // Convert bytes to the appropriate type
515                match dtype {
516                    DType::F32 => {
517                        let mut values = Vec::with_capacity(num_elements);
518                        for i in 0..elements_to_read {
519                            let bytes = [
520                                data_slice[i * element_size],
521                                data_slice[i * element_size + 1],
522                                data_slice[i * element_size + 2],
523                                data_slice[i * element_size + 3],
524                            ];
525                            values.push(f32::from_le_bytes(bytes));
526                        }
527                        // Pad with zeros if needed
528                        values.resize(num_elements, 0.0);
529                        Ok(TensorData::new(values, shape_clone.clone()))
530                    }
531                    DType::F64 => {
532                        let mut values = Vec::with_capacity(num_elements);
533                        for i in 0..elements_to_read {
534                            let mut bytes = [0u8; 8];
535                            bytes.copy_from_slice(
536                                &data_slice[i * element_size..(i + 1) * element_size],
537                            );
538                            values.push(f64::from_le_bytes(bytes));
539                        }
540                        values.resize(num_elements, 0.0);
541                        Ok(TensorData::new(values, shape_clone.clone()))
542                    }
543                    DType::I64 => {
544                        let mut values = Vec::with_capacity(num_elements);
545                        for i in 0..elements_to_read {
546                            let mut bytes = [0u8; 8];
547                            bytes.copy_from_slice(
548                                &data_slice[i * element_size..(i + 1) * element_size],
549                            );
550                            values.push(i64::from_le_bytes(bytes));
551                        }
552                        values.resize(num_elements, 0);
553                        Ok(TensorData::new(values, shape_clone.clone()))
554                    }
555                    DType::I32 => {
556                        let mut values = Vec::with_capacity(num_elements);
557                        for i in 0..elements_to_read {
558                            let mut bytes = [0u8; 4];
559                            bytes.copy_from_slice(
560                                &data_slice[i * element_size..(i + 1) * element_size],
561                            );
562                            values.push(i32::from_le_bytes(bytes));
563                        }
564                        values.resize(num_elements, 0);
565                        Ok(TensorData::new(values, shape_clone.clone()))
566                    }
567                    DType::I16 => {
568                        let mut values = Vec::with_capacity(num_elements);
569                        for i in 0..elements_to_read {
570                            let mut bytes = [0u8; 2];
571                            bytes.copy_from_slice(
572                                &data_slice[i * element_size..(i + 1) * element_size],
573                            );
574                            values.push(i16::from_le_bytes(bytes));
575                        }
576                        values.resize(num_elements, 0);
577                        Ok(TensorData::new(values, shape_clone.clone()))
578                    }
579                    DType::I8 => {
580                        let mut values = Vec::with_capacity(num_elements);
581                        for &byte in data_slice.iter().take(elements_to_read) {
582                            values.push(byte as i8);
583                        }
584                        values.resize(num_elements, 0);
585                        Ok(TensorData::new(values, shape_clone.clone()))
586                    }
587                    DType::Bool => {
588                        let mut values = Vec::with_capacity(num_elements);
589                        for &byte in data_slice.iter().take(elements_to_read) {
590                            values.push(byte != 0);
591                        }
592                        values.resize(num_elements, false);
593                        Ok(TensorData::new(values, shape_clone.clone()))
594                    }
595                    DType::F16 => {
596                        let mut values = Vec::with_capacity(num_elements);
597                        for i in 0..elements_to_read {
598                            let mut bytes = [0u8; 2];
599                            bytes.copy_from_slice(
600                                &data_slice[i * element_size..(i + 1) * element_size],
601                            );
602                            values.push(f16::from_le_bytes(bytes));
603                        }
604                        values.resize(num_elements, f16::ZERO);
605                        Ok(TensorData::new(values, shape_clone.clone()))
606                    }
607                    DType::BF16 => {
608                        let mut values = Vec::with_capacity(num_elements);
609                        for i in 0..elements_to_read {
610                            let mut bytes = [0u8; 2];
611                            bytes.copy_from_slice(
612                                &data_slice[i * element_size..(i + 1) * element_size],
613                            );
614                            values.push(bf16::from_le_bytes(bytes));
615                        }
616                        values.resize(num_elements, bf16::ZERO);
617                        Ok(TensorData::new(values, shape_clone.clone()))
618                    }
619                    DType::U8 => {
620                        let mut values = Vec::with_capacity(num_elements);
621                        for &byte in data_slice.iter().take(elements_to_read) {
622                            values.push(byte);
623                        }
624                        values.resize(num_elements, 0);
625                        Ok(TensorData::new(values, shape_clone.clone()))
626                    }
627                    DType::U16 => {
628                        let mut values = Vec::with_capacity(num_elements);
629                        for i in 0..elements_to_read {
630                            let mut bytes = [0u8; 2];
631                            bytes.copy_from_slice(
632                                &data_slice[i * element_size..(i + 1) * element_size],
633                            );
634                            values.push(u16::from_le_bytes(bytes));
635                        }
636                        values.resize(num_elements, 0);
637                        Ok(TensorData::new(values, shape_clone.clone()))
638                    }
639                    DType::U32 => {
640                        let mut values = Vec::with_capacity(num_elements);
641                        for i in 0..elements_to_read {
642                            let mut bytes = [0u8; 4];
643                            bytes.copy_from_slice(
644                                &data_slice[i * element_size..(i + 1) * element_size],
645                            );
646                            values.push(u32::from_le_bytes(bytes));
647                        }
648                        values.resize(num_elements, 0);
649                        Ok(TensorData::new(values, shape_clone.clone()))
650                    }
651                    DType::U64 => {
652                        let mut values = Vec::with_capacity(num_elements);
653                        for i in 0..elements_to_read {
654                            let mut bytes = [0u8; 8];
655                            bytes.copy_from_slice(
656                                &data_slice[i * element_size..(i + 1) * element_size],
657                            );
658                            values.push(u64::from_le_bytes(bytes));
659                        }
660                        values.resize(num_elements, 0);
661                        Ok(TensorData::new(values, shape_clone.clone()))
662                    }
663                    _ => {
664                        // For any remaining unsupported types, return an error
665                        Err(crate::TensorSnapshotError::DataError(format!(
666                            "Unsupported dtype for tensor data reading: {:?}",
667                            dtype
668                        )))
669                    }
670                }
671            } else {
672                // If no data file found, return zeros of the appropriate type
673                let num_elements = shape_clone.iter().product::<usize>().max(1);
674                match dtype {
675                    DType::F32 => Ok(TensorData::new(
676                        vec![0.0f32; num_elements],
677                        shape_clone.clone(),
678                    )),
679                    DType::F64 => Ok(TensorData::new(
680                        vec![0.0f64; num_elements],
681                        shape_clone.clone(),
682                    )),
683                    DType::F16 => Ok(TensorData::new(
684                        vec![f16::ZERO; num_elements],
685                        shape_clone.clone(),
686                    )),
687                    DType::BF16 => Ok(TensorData::new(
688                        vec![bf16::ZERO; num_elements],
689                        shape_clone.clone(),
690                    )),
691                    DType::I64 => Ok(TensorData::new(
692                        vec![0i64; num_elements],
693                        shape_clone.clone(),
694                    )),
695                    DType::I32 => Ok(TensorData::new(
696                        vec![0i32; num_elements],
697                        shape_clone.clone(),
698                    )),
699                    DType::I16 => Ok(TensorData::new(
700                        vec![0i16; num_elements],
701                        shape_clone.clone(),
702                    )),
703                    DType::I8 => Ok(TensorData::new(
704                        vec![0i8; num_elements],
705                        shape_clone.clone(),
706                    )),
707                    DType::U8 => Ok(TensorData::new(
708                        vec![0u8; num_elements],
709                        shape_clone.clone(),
710                    )),
711                    DType::U16 => Ok(TensorData::new(
712                        vec![0u16; num_elements],
713                        shape_clone.clone(),
714                    )),
715                    DType::U32 => Ok(TensorData::new(
716                        vec![0u32; num_elements],
717                        shape_clone.clone(),
718                    )),
719                    DType::U64 => Ok(TensorData::new(
720                        vec![0u64; num_elements],
721                        shape_clone.clone(),
722                    )),
723                    DType::Bool => Ok(TensorData::new(
724                        vec![false; num_elements],
725                        shape_clone.clone(),
726                    )),
727                    _ => {
728                        // For any remaining unsupported types, return an error
729                        Err(crate::TensorSnapshotError::DataError(format!(
730                            "Unsupported dtype for tensor data reading: {:?}",
731                            dtype
732                        )))
733                    }
734                }
735            }
736        }),
737        dtype,
738        shape,
739        vec![],         // path_stack
740        vec![],         // container_stack
741        ParamId::new(), // tensor_id
742    )))
743}
744
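/// Pickle virtual-machine state: the value stack, the memo table used by the
/// PUT/GET/MEMOIZE opcodes, and an optional data source for lazily loading
/// tensor storage.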
745pub struct Stack {
746    stack: Vec<Object>,
747    memo: HashMap<u32, Object>,
748    data_source: Option<Arc<LazyDataSource>>,
749}
750
751impl Default for Stack {
752    fn default() -> Self {
753        Self::new()
754    }
755}
756
757impl Stack {
758    pub fn new() -> Self {
759        // For cases where no data source is needed (pure pickle without tensor data)
760        Self {
761            stack: Vec::new(),
762            memo: HashMap::new(),
763            data_source: None,
764        }
765    }
766
767    pub fn with_data_source(data_source: Arc<LazyDataSource>) -> Self {
768        Self {
769            stack: Vec::new(),
770            memo: HashMap::new(),
771            data_source: Some(data_source),
772        }
773    }
774
775    fn push(&mut self, o: Object) {
776        self.stack.push(o)
777    }
778
779    fn pop(&mut self) -> Result<Object> {
780        match self.stack.pop() {
781            None => Err(PickleError::StackUnderflow),
782            Some(o) => Ok(o),
783        }
784    }
785
786    fn top(&self) -> Result<Object> {
787        match self.stack.last() {
788            None => Err(PickleError::StackUnderflow),
789            Some(o) => Ok(o.clone()),
790        }
791    }
792
793    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
794        let marker_pos = self
795            .stack
796            .iter()
797            .rposition(|o| {
798                matches!(o, Object::Class { module_name, name }
799                if module_name == "mark" && name == "mark")
800            })
801            .ok_or(PickleError::InvalidData("marker not found".to_string()))?;
802
803        let result = self.stack.split_off(marker_pos + 1);
804        self.stack.pop(); // Remove the marker
805        Ok(result)
806    }
807
808    fn last_mut(&mut self) -> Result<&mut Object> {
809        match self.stack.last_mut() {
810            None => Err(PickleError::StackUnderflow),
811            Some(o) => Ok(o),
812        }
813    }
814
815    fn push_mark(&mut self) {
816        self.stack.push(Object::Class {
817            module_name: "mark".to_string(),
818            name: "mark".to_string(),
819        });
820    }
821
822    fn memo_get(&self, idx: u32) -> Result<Object> {
823        self.memo
824            .get(&idx)
825            .cloned()
826            .ok_or(PickleError::MemoNotFound(idx))
827    }
828
829    fn memo_put(&mut self, idx: u32, obj: Object) {
830        self.memo.insert(idx, obj);
831    }
832
833    fn memo_len(&self) -> usize {
834        self.memo.len()
835    }
836}
837
838fn read_global<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
839    let module_name = buf_to_str(&read_to_newline(r)?)?;
840    let name = buf_to_str(&read_to_newline(r)?)?;
841    stack.push(Object::Class { module_name, name });
842    Ok(())
843}
844
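/// Reads a LONG1 value: a one-byte length followed by that many bytes of a
/// little-endian two's-complement integer (bytes beyond the eighth are ignored,
/// truncating to `i64`). For example, the payload `[0x2C, 0x01]` decodes to 300
/// and `[0xFF]` decodes to -1 after sign extension.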
845fn read_long1<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
846    let len = r.read_u8()? as usize;
847    let mut data = vec![0u8; len];
848    r.read_exact(&mut data)?;
849    // Handle little-endian signed integer
850    let mut value = 0i64;
851    for (i, &byte) in data.iter().enumerate().take(8) {
852        // Only process up to 8 bytes for i64, and use wrapping to avoid overflow
853        value |= (byte as i64).wrapping_shl((i as u32) * 8);
854    }
855    // Handle sign extension for negative numbers
856    if len < 8 && data.last().is_some_and(|&b| b & 0x80 != 0) {
857        // Sign extend
858        for i in len..8 {
859            value |= 0xffi64.wrapping_shl((i as u32) * 8);
860        }
861    }
862    stack.push(Object::Int(value));
863    Ok(())
864}
865
866fn read_string<R: BufRead>(r: &mut R, stack: &mut Stack, len: usize) -> Result<()> {
867    let mut data = vec![0u8; len];
868    r.read_exact(&mut data)?;
869    let s = buf_to_str(&data)?;
870    stack.push(Object::String(s));
871    Ok(())
872}
873
874fn read_bin_int<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
875    let v = r.read_i32::<LittleEndian>()?;
876    stack.push(Object::Int(v as i64));
877    Ok(())
878}
879
880fn read_int<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
881    // INT opcode reads an integer as ASCII string followed by newline
882    let line = read_to_newline(r)?;
883    let s = buf_to_str(&line)?;
884    let v = s
885        .parse::<i64>()
886        .map_err(|e| PickleError::InvalidData(format!("Invalid INT value '{}': {}", s, e)))?;
887    stack.push(Object::Int(v));
888    Ok(())
889}
890
891fn read_bin_int1<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
892    let v = r.read_u8()?;
893    stack.push(Object::Int(v as i64));
894    Ok(())
895}
896
897fn read_bin_int2<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
898    let v = r.read_u16::<LittleEndian>()?;
899    stack.push(Object::Int(v as i64));
900    Ok(())
901}
902
903fn read_bin_float<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
904    // Python's BINFLOAT uses big-endian encoding
905    let v = r.read_f64::<byteorder::BigEndian>()?;
906    stack.push(Object::Float(v));
907    Ok(())
908}
909
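/// Parses a pickle stream into an [`Object`] tree without a tensor data source.
/// Use [`read_pickle_with_data`] when the pickle references tensor storages.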
910pub fn read_pickle<R: BufRead>(r: &mut R) -> Result<Object> {
911    // For pure pickle without tensor data, no data source is needed
912    read_pickle_with_optional_data(r, None)
913}
914
/// Skip over a pickle stream without fully parsing it.
///
/// This is useful for the legacy format, where the main object containing
/// tensors must be skipped before a data source has been set up.
918pub fn skip_pickle<R: BufRead>(r: &mut R) -> Result<()> {
919    // Read the protocol marker if present
920    let mut first_byte = [0u8; 1];
921    r.read_exact(&mut first_byte)?;
922
923    if first_byte[0] == 0x80 {
924        // PROTO marker - read protocol version
925        let mut proto_version = [0u8; 1];
926        r.read_exact(&mut proto_version)?;
    } else {
        // Not a PROTO marker. The byte has already been consumed, so treat it
        // as an argument-less opcode and continue scanning from the next byte.
    }
932
933    // Scan until we find STOP (0x2e) opcode
934    loop {
935        let mut byte = [0u8; 1];
936        r.read_exact(&mut byte)?;
937
938        match byte[0] {
939            0x2e => {
940                // STOP - end of pickle
941                break;
942            }
943            0x58 | 0x42 | 0x43 | 0x54 | 0x55 | 0x56 | 0x8c | 0x8d | 0x8e => {
944                // String/bytes opcodes with length prefixes
945                let length = match byte[0] {
946                    0x43 | 0x55 | 0x8c => {
947                        // SHORT versions - 1 byte length
948                        let mut len_byte = [0u8; 1];
949                        r.read_exact(&mut len_byte)?;
950                        len_byte[0] as usize
951                    }
952                    0x42 | 0x54 | 0x58 | 0x56 => {
953                        // Regular versions - 4 byte length
954                        let mut len_bytes = [0u8; 4];
955                        r.read_exact(&mut len_bytes)?;
956                        u32::from_le_bytes(len_bytes) as usize
957                    }
958                    0x8d | 0x8e => {
959                        // 8-byte length versions
960                        let mut len_bytes = [0u8; 8];
961                        r.read_exact(&mut len_bytes)?;
962                        u64::from_le_bytes(len_bytes) as usize
963                    }
964                    _ => 0,
965                };
966
967                // Skip the actual data
968                let mut skip_buf = vec![0u8; length.min(8192)];
969                let mut skipped = 0;
970                while skipped < length {
971                    let to_skip = (length - skipped).min(skip_buf.len());
972                    r.read_exact(&mut skip_buf[..to_skip])?;
973                    skipped += to_skip;
974                }
975            }
            0x4b | 0x4d => {
                // BININT1 (1 byte) and BININT2 (2 bytes) - skip the integer bytes.
                // 0x4e ('N', NONE) carries no argument and falls through to the
                // catch-all arm below; BININT (0x4a) is handled separately.
                let skip_count = match byte[0] {
                    0x4b => 1,
                    0x4d => 2,
                    _ => 0,
                };
                let mut skip_buf = vec![0u8; skip_count];
                r.read_exact(&mut skip_buf)?;
            }
987            0x47 => {
988                // BINFLOAT - skip 8 bytes
989                let mut skip_buf = [0u8; 8];
990                r.read_exact(&mut skip_buf)?;
991            }
992            0x4a => {
993                // BININT - skip 4 bytes (signed)
994                let mut skip_buf = [0u8; 4];
995                r.read_exact(&mut skip_buf)?;
996            }
997            0x8a => {
998                // LONG1 - 1 byte length, then that many bytes
999                let mut len_byte = [0u8; 1];
1000                r.read_exact(&mut len_byte)?;
1001                let length = len_byte[0] as usize;
1002                let mut skip_buf = vec![0u8; length];
1003                r.read_exact(&mut skip_buf)?;
1004            }
1005            0x8b => {
1006                // LONG4 - 4 byte length, then that many bytes
1007                let mut len_bytes = [0u8; 4];
1008                r.read_exact(&mut len_bytes)?;
1009                let length = u32::from_le_bytes(len_bytes) as usize;
1010                let mut skip_buf = vec![0u8; length.min(8192)];
1011                let mut skipped = 0;
1012                while skipped < length {
1013                    let to_skip = (length - skipped).min(skip_buf.len());
1014                    r.read_exact(&mut skip_buf[..to_skip])?;
1015                    skipped += to_skip;
1016                }
1017            }
1018            _ => {
1019                // Other opcodes - most don't have additional data
1020                // or are stack manipulation opcodes we can ignore
1021            }
1022        }
1023    }
1024
1025    Ok(())
1026}
1027
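/// Parses a pickle stream, resolving tensor storages lazily through `data_source`.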
1028pub fn read_pickle_with_data<R: BufRead>(
1029    r: &mut R,
1030    data_source: Arc<LazyDataSource>,
1031) -> Result<Object> {
1032    read_pickle_with_optional_data(r, Some(data_source))
1033}
1034
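/// Core pickle virtual machine: interprets opcodes until `Stop` and returns the
/// object left on top of the stack. The data source, when present, is threaded
/// through to the torch tensor rebuild handlers.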
1035pub fn read_pickle_with_optional_data<R: BufRead>(
1036    r: &mut R,
1037    data_source: Option<Arc<LazyDataSource>>,
1038) -> Result<Object> {
1039    let mut stack = match data_source {
1040        Some(ds) => Stack::with_data_source(ds),
1041        None => Stack::new(),
1042    };
1043    loop {
1044        let op_code = r.read_u8()?;
1045        let op_code = OpCode::try_from(op_code).map_err(PickleError::InvalidOpCode)?;
1046        match op_code {
1047            OpCode::Proto => {
1048                let version = r.read_u8()?;
1049                if version > 5 {
1050                    return Err(PickleError::InvalidProtocol(version));
1051                }
1052            }
1053            OpCode::Global => read_global(r, &mut stack)?,
1054            OpCode::BinInt => read_bin_int(r, &mut stack)?,
1055            OpCode::Int => read_int(r, &mut stack)?,
1056            OpCode::BinInt1 => read_bin_int1(r, &mut stack)?,
1057            OpCode::BinInt2 => read_bin_int2(r, &mut stack)?,
1058            OpCode::BinFloat => read_bin_float(r, &mut stack)?,
1059            OpCode::BinUnicode => {
1060                let len = r.read_u32::<LittleEndian>()? as usize;
1061                read_string(r, &mut stack, len)?
1062            }
1063            OpCode::ShortBinString => {
1064                let len = r.read_u8()? as usize;
1065                read_string(r, &mut stack, len)?
1066            }
1067            OpCode::Long1 => read_long1(r, &mut stack)?,
1068            OpCode::None => stack.push(Object::None),
1069            OpCode::NewTrue => stack.push(Object::Bool(true)),
1070            OpCode::NewFalse => stack.push(Object::Bool(false)),
1071            OpCode::EmptyTuple => stack.push(Object::Tuple(Vec::new())),
1072            OpCode::EmptyList => stack.push(Object::List(Vec::new())),
1073            OpCode::EmptyDict => stack.push(Object::Dict(HashMap::new())),
1074            OpCode::Tuple => {
1075                let objs = stack.pop_to_marker()?;
1076                stack.push(Object::Tuple(objs))
1077            }
1078            OpCode::Tuple1 => {
1079                let obj = stack.pop()?;
1080                stack.push(Object::Tuple(vec![obj]))
1081            }
1082            OpCode::Tuple2 => {
1083                let obj2 = stack.pop()?;
1084                let obj1 = stack.pop()?;
1085                stack.push(Object::Tuple(vec![obj1, obj2]))
1086            }
1087            OpCode::Tuple3 => {
1088                let obj3 = stack.pop()?;
1089                let obj2 = stack.pop()?;
1090                let obj1 = stack.pop()?;
1091                stack.push(Object::Tuple(vec![obj1, obj2, obj3]))
1092            }
1093            OpCode::Append => {
1094                let value = stack.pop()?;
1095                match stack.last_mut()? {
1096                    Object::List(list) => list.push(value),
1097                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1098                }
1099            }
1100            OpCode::Appends => {
1101                let objs = stack.pop_to_marker()?;
1102                match stack.last_mut()? {
1103                    Object::List(list) => list.extend(objs),
1104                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1105                }
1106            }
1107            OpCode::SetItem => {
1108                let value = stack.pop()?;
1109                let key = stack.pop()?;
1110                match stack.last_mut()? {
1111                    Object::Dict(dict) => {
1112                        if let Object::String(key) = key {
1113                            dict.insert(key, value);
1114                        } else {
1115                            return Err(PickleError::InvalidData(
1116                                "dict key must be a string".to_string(),
1117                            ));
1118                        }
1119                    }
1120                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1121                }
1122            }
1123            OpCode::SetItems => {
1124                let mut objs = stack.pop_to_marker()?;
1125                if objs.len() % 2 != 0 {
1126                    return Err(PickleError::InvalidData(
1127                        "setitems requires even number of objects".to_string(),
1128                    ));
1129                }
1130                match stack.last_mut()? {
1131                    Object::Dict(dict) => {
1132                        while !objs.is_empty() {
1133                            let key = objs.remove(0);
1134                            let value = objs.remove(0);
1135                            if let Object::String(key) = key {
1136                                dict.insert(key, value);
1137                            } else {
1138                                return Err(PickleError::InvalidData(
1139                                    "dict key must be a string".to_string(),
1140                                ));
1141                            }
1142                        }
1143                    }
1144                    _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1145                }
1146            }
1147            OpCode::BinPut => {
1148                let idx = r.read_u8()? as u32;
1149                let obj = stack.top()?;
1150                stack.memo_put(idx, obj);
1151            }
1152            OpCode::LongBinPut => {
1153                let idx = r.read_u32::<LittleEndian>()?;
1154                let obj = stack.top()?;
1155                stack.memo_put(idx, obj);
1156            }
1157            OpCode::BinGet => {
1158                let idx = r.read_u8()? as u32;
1159                let obj = stack.memo_get(idx)?;
1160                stack.push(obj);
1161            }
1162            OpCode::LongBinGet => {
1163                let idx = r.read_u32::<LittleEndian>()?;
1164                let obj = stack.memo_get(idx)?;
1165                stack.push(obj);
1166            }
1167            OpCode::Mark => stack.push_mark(),
1168            OpCode::BinPersId => {
1169                let pid = stack.pop()?;
1170                match pid {
1171                    Object::String(s) => {
1172                        stack.push(Object::Persistent(s.into_bytes()));
1173                    }
1174                    Object::Tuple(tuple) => {
1175                        // The persistent ID is a tuple (e.g., ('storage', 'FloatStorage', '0', 'cpu', 4))
1176                        // Store it as a PersistentTuple for proper handling
1177                        stack.push(Object::PersistentTuple(tuple));
1178                    }
1179                    _ => {
1180                        return Err(PickleError::InvalidData(format!(
1181                            "persistent id must be a string or tuple, got {:?}",
1182                            pid
1183                        )));
1184                    }
1185                }
1186            }
            OpCode::Reduce => {
                let args = stack.pop()?;
                let callable = stack.pop()?;

                // OrderedDict is created with empty args; treat it as a plain dict.
                // Everything else goes through the torch rebuild dispatch.
                if matches!(&callable, Object::Class { module_name, name }
                    if module_name == "collections" && name == "OrderedDict")
                {
                    stack.push(Object::Dict(HashMap::new()));
                } else {
                    let obj = rebuild_from_type_v2(
                        Object::Tuple(vec![callable, args]),
                        &mut stack.memo,
                        &stack.data_source,
                    )?;
                    stack.push(obj);
                }
            }
1221            OpCode::Build => {
1222                let args = stack.pop()?;
1223                let obj = stack.pop()?;
1224                match obj {
1225                    Object::Dict(mut dict) => {
1226                        // For dicts, BUILD updates with the args
1227                        if let Object::Dict(update) = args {
1228                            dict.extend(update);
1229                        }
1230                        stack.push(Object::Dict(dict));
1231                    }
1232                    _ => {
1233                        stack.push(Object::Build {
1234                            callable: Box::new(obj),
1235                            args: Box::new(args),
1236                        });
1237                    }
1238                }
1239            }
1240            OpCode::NewObj => {
1241                let args = stack.pop()?;
1242                let cls = stack.pop()?;
1243                stack.push(Object::Reduce {
1244                    callable: Box::new(cls),
1245                    args: Box::new(args),
1246                });
1247            }
1248            OpCode::Dict => {
1249                let objs = stack.pop_to_marker()?;
1250                let mut dict = HashMap::new();
1251                if objs.len() % 2 != 0 {
1252                    return Err(PickleError::InvalidData(
1253                        "dict requires even number of objects".to_string(),
1254                    ));
1255                }
1256                for chunk in objs.chunks(2) {
1257                    if let Object::String(key) = &chunk[0] {
1258                        dict.insert(key.clone(), chunk[1].clone());
1259                    } else {
1260                        return Err(PickleError::InvalidData(
1261                            "dict key must be a string".to_string(),
1262                        ));
1263                    }
1264                }
1265                stack.push(Object::Dict(dict));
1266            }
1267            OpCode::List => {
1268                let objs = stack.pop_to_marker()?;
1269                stack.push(Object::List(objs));
1270            }
1271            OpCode::Memoize => {
1272                // Store top of stack in memo without popping
1273                // The memo index is the current number of items in the memo
1274                let obj = stack.top()?;
1275                let idx = stack.memo_len() as u32;
1276                stack.memo_put(idx, obj);
1277            }
1278            OpCode::Stop => break,
1279        }
1280    }
1281    stack.pop()
1282}
1283
1284/// Load tensors from a pickle file (PyTorch checkpoint format)
1285pub fn read_pickle_tensors<R: BufRead>(reader: &mut R) -> Result<HashMap<String, TensorSnapshot>> {
1286    let obj = read_pickle(reader)?;
1287
1288    // Extract tensors from the loaded object
1289    let mut tensors = HashMap::new();
1290    let mut path = Vec::new();
1291    extract_tensors(&obj, &mut path, &mut tensors);
1292
1293    Ok(tensors)
1294}
1295
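/// Recursively walks nested dictionaries in the decoded object, collecting
/// `TorchParam` leaves into a flat map keyed by their dot-separated path
/// (for example, `layer1.weight`).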
1296fn extract_tensors<'a>(
1297    obj: &'a Object,
1298    path: &mut Vec<&'a str>,
1299    tensors: &mut HashMap<String, TensorSnapshot>,
1300) {
1301    match obj {
1302        Object::Dict(dict) => {
1303            for (key, value) in dict {
1304                path.push(key);
1305                extract_tensors(value, path, tensors);
1306                path.pop();
1307            }
1308        }
1309        Object::TorchParam(snapshot) => {
1310            // Only allocate the string here when we actually insert
1311            tensors.insert(path.join("."), snapshot.clone());
1312        }
1313        _ => {}
1314    }
1315}
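
// The example below is a minimal, self-contained sketch of how the opcode
// interpreter consumes a stream: it feeds a tiny hand-written pickle
// (equivalent to the Python dict `{"answer": 42}`, though not necessarily the
// exact bytes Python would emit) through `read_pickle`. The module and test
// names are illustrative additions, not part of the original crate.
#[cfg(test)]
mod pickle_reader_example {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn parses_a_minimal_dict_pickle() {
        let bytes: Vec<u8> = vec![
            0x80, 0x02, // PROTO 2
            b'}', // EMPTY_DICT
            b'q', 0x00, // BINPUT 0
            b'U', 0x06, b'a', b'n', b's', b'w', b'e', b'r', // SHORT_BINSTRING "answer"
            b'K', 0x2a, // BININT1 42
            b's', // SETITEM
            b'.', // STOP
        ];
        let obj = read_pickle(&mut Cursor::new(bytes)).unwrap();
        match obj {
            Object::Dict(dict) => match dict.get("answer") {
                Some(Object::Int(42)) => {}
                other => panic!("unexpected value for \"answer\": {:?}", other),
            },
            other => panic!("expected a dict, got {:?}", other),
        }
    }
}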