1use crate::TensorSnapshot;
14use crate::pytorch::lazy_data::LazyDataSource;
15use alloc::rc::Rc;
16use alloc::string::{String, ToString};
17use alloc::vec::Vec;
18use burn_core::module::ParamId;
19use burn_tensor::{DType, TensorData};
20use byteorder::{LittleEndian, ReadBytesExt};
21use half::{bf16, f16};
22use std::collections::HashMap;
23use std::io::{self, BufRead};
24use std::sync::Arc;
25
26#[derive(Debug)]
28pub enum PickleError {
29 Io(io::Error),
30 InvalidOpCode(u8),
31 InvalidProtocol(u8),
32 UnexpectedOpCode(OpCode),
33 UnsupportedType(String),
34 InvalidData(String),
35 StackUnderflow,
36 MemoNotFound(u32),
37 InvalidShapeOrType,
38}
39
40impl From<io::Error> for PickleError {
41 fn from(e: io::Error) -> Self {
42 PickleError::Io(e)
43 }
44}
45
46impl std::fmt::Display for PickleError {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 match self {
49 PickleError::Io(e) => write!(f, "IO error: {}", e),
50 PickleError::InvalidOpCode(code) => write!(
51 f,
52 "Invalid pickle opcode: 0x{:02x}. The file may be corrupted or use an unsupported pickle protocol.",
53 code
54 ),
55 PickleError::InvalidProtocol(proto) => write!(
56 f,
57 "Invalid or unsupported pickle protocol version: {}. Supported versions are 2-5.",
58 proto
59 ),
60 PickleError::UnexpectedOpCode(op) => {
61 write!(f, "Unexpected pickle opcode {:?} in current context", op)
62 }
63 PickleError::UnsupportedType(ty) => write!(
64 f,
65 "Unsupported Python type '{}'. This may indicate a full model save rather than a state_dict.",
66 ty
67 ),
68 PickleError::InvalidData(msg) => write!(f, "Invalid data in pickle file: {}", msg),
69 PickleError::StackUnderflow => {
70 write!(f, "Pickle stack underflow - the file may be corrupted")
71 }
72 PickleError::MemoNotFound(idx) => write!(
73 f,
74 "Pickle memo reference {} not found - the file may be corrupted",
75 idx
76 ),
77 PickleError::InvalidShapeOrType => {
78 write!(f, "Invalid tensor shape or data type in PyTorch file")
79 }
80 }
81 }
82}
83
84impl std::error::Error for PickleError {}
85
86type Result<T> = std::result::Result<T, PickleError>;
87
/// The pickle opcodes this reader understands, each tagged with its byte value on the wire.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    // Stream framing: protocol header, end of stream, and the MARK used by variadic opcodes.
    Proto = 0x80,
    Stop = b'.',
    Mark = b'(',

    // Scalar values.
    Int = b'I',
    BinInt = b'J',
    BinInt1 = b'K',
    BinInt2 = b'M',
    Long1 = 0x8a,
    BinFloat = b'G',
    None = b'N',
    NewTrue = 0x88,
    NewFalse = 0x89,

    // Strings.
    BinUnicode = b'X',
    ShortBinString = b'U',

    // Tuples.
    EmptyTuple = b')',
    Tuple = b't',
    Tuple1 = 0x85,
    Tuple2 = 0x86,
    Tuple3 = 0x87,

    // Lists.
    EmptyList = b']',
    List = b'l',
    Append = b'a',
    Appends = b'e',

    // Dicts.
    EmptyDict = b'}',
    Dict = b'd',
    SetItem = b's',
    SetItems = b'u',

    // Memo table access.
    BinPut = b'q',
    LongBinPut = b'r',
    BinGet = b'h',
    LongBinGet = b'j',
    Memoize = 0x94,

    // Object construction and persistent references.
    Global = b'c',
    Reduce = b'R',
    Build = b'b',
    NewObj = 0x81,
    BinPersId = b'Q',
}
131
132impl TryFrom<u8> for OpCode {
134 type Error = u8;
135 fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
136 match value {
137 0x80 => Ok(Self::Proto),
138 b'c' => Ok(Self::Global),
139 b'q' => Ok(Self::BinPut),
140 b'r' => Ok(Self::LongBinPut),
141 b')' => Ok(Self::EmptyTuple),
142 b'R' => Ok(Self::Reduce),
143 b'(' => Ok(Self::Mark),
144 b'X' => Ok(Self::BinUnicode),
145 b'U' => Ok(Self::ShortBinString),
146 b'J' => Ok(Self::BinInt),
147 b'I' => Ok(Self::Int),
148 b't' => Ok(Self::Tuple),
149 b'Q' => Ok(Self::BinPersId),
150 b'K' => Ok(Self::BinInt1),
151 b'M' => Ok(Self::BinInt2),
152 b'N' => Ok(Self::None),
153 0x85 => Ok(Self::Tuple1),
154 0x86 => Ok(Self::Tuple2),
155 0x87 => Ok(Self::Tuple3),
156 0x88 => Ok(Self::NewTrue),
157 0x89 => Ok(Self::NewFalse),
158 b'h' => Ok(Self::BinGet),
159 b'j' => Ok(Self::LongBinGet),
160 b's' => Ok(Self::SetItem),
161 b'u' => Ok(Self::SetItems),
162 b'}' => Ok(Self::EmptyDict),
163 b'd' => Ok(Self::Dict),
164 b'b' => Ok(Self::Build),
165 b'.' => Ok(Self::Stop),
166 0x81 => Ok(Self::NewObj),
167 b']' => Ok(Self::EmptyList),
168 b'l' => Ok(Self::List),
169 b'G' => Ok(Self::BinFloat),
170 b'a' => Ok(Self::Append),
171 b'e' => Ok(Self::Appends),
172 0x8a => Ok(Self::Long1),
173 0x94 => Ok(Self::Memoize),
174 value => Err(value),
175 }
176 }
177}
178
179fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
180 let mut data: Vec<u8> = Vec::with_capacity(32);
181 r.read_until(b'\n', &mut data)?;
182 data.pop();
183 if data.last() == Some(&b'\r') {
184 data.pop();
185 }
186 Ok(data)
187}
188
189fn buf_to_str(buf: &[u8]) -> Result<String> {
190 String::from_utf8(buf.to_vec())
191 .map_err(|e| PickleError::InvalidData(format!("Invalid UTF-8: {}", e)))
192}
193
194#[derive(Debug, Clone)]
195pub enum Object {
196 Class {
197 module_name: String,
198 name: String,
199 },
200 String(String),
201 Int(i64),
202 Float(f64),
203 Bool(bool),
204 None,
205 Tuple(Vec<Object>),
206 List(Vec<Object>),
207 Dict(HashMap<String, Object>),
208 Persistent(Vec<u8>),
209 PersistentTuple(Vec<Object>),
210 Reduce {
211 callable: Box<Object>,
212 args: Box<Object>,
213 },
214 Build {
215 callable: Box<Object>,
216 args: Box<Object>,
217 },
218 TorchParam(TensorSnapshot),
219}
220
221fn rebuild_from_type_v2(
222 o: Object,
223 memo: &mut HashMap<u32, Object>,
224 data_source: &Option<Arc<LazyDataSource>>,
225) -> Result<Object> {
226 let args = if let Object::Tuple(args) = o {
227 if args.is_empty() {
228 return Err(PickleError::InvalidData(
229 "rebuild_from_type_v2: empty args".to_string(),
230 ));
231 }
232 args
233 } else {
234 return Err(PickleError::InvalidData(format!(
235 "rebuild_from_type_v2: expected tuple got {:?}",
236 o
237 )));
238 };
239 let func = &args[0];
240 match func {
241 Object::Class { module_name, name } => {
242 let module_name = module_name.as_str();
243 let name = name.as_str();
244 let actual_args = if args.len() == 2 && matches!(&args[1], Object::Tuple(_)) {
246 args[1].clone()
248 } else {
249 Object::Tuple(args[1..].to_vec())
251 };
252 if module_name == "torch._utils" && name == "_rebuild_tensor_v2" {
253 rebuild_tensor_v2(actual_args, memo, data_source)
254 } else if module_name == "torch._tensor" && name == "_rebuild_from_type_v2" {
255 rebuild_from_type_v2(actual_args, memo, data_source)
256 } else if module_name == "torch._utils" && name == "_rebuild_parameter" {
257 rebuild_parameter(actual_args, memo, data_source)
258 } else if module_name == "collections" && name == "OrderedDict" {
259 Ok(Object::Dict(HashMap::new()))
261 } else {
262 Err(PickleError::UnsupportedType(format!(
263 "{}.{}",
264 module_name, name
265 )))
266 }
267 }
268 _ => Err(PickleError::InvalidData(format!(
269 "rebuild_from_type_v2: expected class got {:?}",
270 func
271 ))),
272 }
273}
274
275fn rebuild_parameter(
276 args: Object,
277 memo: &mut HashMap<u32, Object>,
278 data_source: &Option<Arc<LazyDataSource>>,
279) -> Result<Object> {
280 let args = if let Object::Tuple(args) = args {
281 if args.is_empty() {
282 return Err(PickleError::InvalidData(
283 "rebuild_parameter: empty args".to_string(),
284 ));
285 }
286 args
287 } else {
288 return Err(PickleError::InvalidData(format!(
289 "rebuild_parameter: expected tuple got {:?}",
290 args
291 )));
292 };
293 let data = &args[0];
294 let tensor = match data {
295 Object::Reduce {
296 callable: _,
297 args: _,
298 } => rebuild_from_type_v2(data.clone(), memo, data_source)?,
299 _ => data.clone(),
300 };
301 Ok(tensor)
302}
303
304fn rebuild_tensor_v2(
305 args: Object,
306 _memo: &mut HashMap<u32, Object>,
307 data_source: &Option<Arc<LazyDataSource>>,
308) -> Result<Object> {
309 let args = if let Object::Tuple(args) = args {
311 args
312 } else {
313 return Err(PickleError::InvalidData(format!(
314 "rebuild_tensor_v2: expected tuple got {:?}",
315 args
316 )));
317 };
318
319 if args.len() < 5 {
320 return Err(PickleError::InvalidData(format!(
321 "rebuild_tensor_v2: expected at least 5 args, got {}",
322 args.len()
323 )));
324 }
325
326 let (storage_info, storage_tuple) = match &args[0] {
328 Object::Persistent(data) => (data.clone(), None),
329 Object::PersistentTuple(tuple) => (vec![], Some(tuple.clone())),
330 _ => {
331 return Err(PickleError::InvalidData(format!(
332 "rebuild_tensor_v2: expected persistent id got {:?}",
333 args[0]
334 )));
335 }
336 };
337
338 let storage_offset = match &args[1] {
340 Object::Int(offset) => *offset as usize,
341 _ => 0,
342 };
343
344 let shape = match &args[2] {
346 Object::Tuple(shape) => shape
347 .iter()
348 .map(|x| match x {
349 Object::Int(i) => Ok(*i as usize),
350 _ => Err(PickleError::InvalidData(
351 "shape must contain ints".to_string(),
352 )),
353 })
354 .collect::<Result<Vec<_>>>()?,
355 _ => {
356 return Err(PickleError::InvalidData(format!(
357 "rebuild_tensor_v2: expected shape tuple got {:?}",
358 args[2]
359 )));
360 }
361 };
362
363 let _stride = matches!(&args[3], Object::Tuple(_));
365
366 let (dtype, storage_key) = if let Some(tuple) = storage_tuple {
369 if tuple.len() >= 3 {
371 let storage_type = match &tuple[1] {
372 Object::String(s) => s.as_str(),
373 Object::Class {
374 module_name: _,
375 name,
376 } => name.as_str(),
377 _ => "FloatStorage",
378 };
379 let dtype = match storage_type {
380 "FloatStorage" => DType::F32,
381 "DoubleStorage" => DType::F64,
382 "HalfStorage" => DType::F16,
383 "BFloat16Storage" => DType::BF16,
384 "LongStorage" => DType::I64,
385 "IntStorage" => DType::I32,
386 "ShortStorage" => DType::I16,
387 "CharStorage" => DType::I8,
388 "ByteStorage" => DType::U8,
389 "BoolStorage" => DType::Bool,
390 _ => DType::F32, };
392 let key = match &tuple[2] {
393 Object::String(s) => s.clone(),
394 _ => "0".to_string(),
395 };
396 (dtype, key)
397 } else {
398 (DType::F32, "0".to_string())
399 }
400 } else if !storage_info.is_empty() {
401 let storage_str = String::from_utf8_lossy(&storage_info);
403 if storage_str.starts_with("Tuple(") {
404 let parts: Vec<&str> = storage_str
406 .trim_start_matches("Tuple(")
407 .trim_end_matches(")")
408 .split(", ")
409 .map(|s| {
410 let trimmed = s.trim_matches('"');
411 if let Some(inner) = trimmed
412 .strip_prefix("Object::String(\"")
413 .and_then(|s| s.strip_suffix("\")"))
414 {
415 inner
416 } else {
417 trimmed
418 }
419 })
420 .collect();
421
422 if parts.len() >= 3 {
423 let dtype = match parts[1] {
424 "FloatStorage" => DType::F32,
425 "DoubleStorage" => DType::F64,
426 "HalfStorage" => DType::F16,
427 "BFloat16Storage" => DType::BF16,
428 "LongStorage" => DType::I64,
429 "IntStorage" => DType::I32,
430 "ShortStorage" => DType::I16,
431 "CharStorage" => DType::I8,
432 "ByteStorage" => DType::U8,
433 _ => DType::F32, };
435 (dtype, parts[2].to_string())
436 } else {
437 (DType::F32, "0".to_string())
438 }
439 } else {
440 (DType::F32, "0".to_string())
441 }
442 } else {
443 (DType::F32, "0".to_string())
444 };
445
446 let data_source = match data_source {
448 Some(ds) => ds.clone(),
449 None => {
450 return Err(PickleError::InvalidData(
451 "Cannot load tensor data without a data source".to_string(),
452 ));
453 }
454 };
455
456 let data_source_clone = data_source.clone();
458 let shape_clone = shape.clone();
459
460 let data_file_key = {
462 let exact_key = format!("data/{}", storage_key);
463 if data_source.contains(&exact_key) {
464 exact_key
465 } else {
466 data_source
468 .keys()
469 .into_iter()
470 .find(|key| {
471 key.ends_with(&format!("/data/{}", storage_key))
472 || (key.contains("/data/") && key.rsplit('/').next() == Some(&storage_key))
473 })
474 .unwrap_or_else(|| format!("data/{}", storage_key))
475 }
476 };
477
478 if let LazyDataSource::LegacyMultiStorage(ref source) = *data_source {
481 let source = source
482 .lock()
483 .unwrap_or_else(|poisoned| poisoned.into_inner());
484 let num_elements: usize = shape.iter().product();
485 let bytes_needed = storage_offset * dtype.size() + num_elements * dtype.size();
486 source.track_storage_usage(&storage_key, 0, bytes_needed);
487 }
488
489 Ok(Object::TorchParam(TensorSnapshot::from_closure(
492 Rc::new(move || {
493 if let Ok(data) = data_source_clone.read(&data_file_key) {
495 let num_elements = shape_clone.iter().product::<usize>().max(1);
497
498 let element_size = dtype.size();
500
501 let offset_bytes = storage_offset * element_size;
503 if offset_bytes >= data.len() {
504 return Ok(TensorData::new(
505 vec![0.0f32; num_elements],
506 shape_clone.clone(),
507 ));
508 }
509
510 let data_slice = &data[offset_bytes..];
511 let available_elements = data_slice.len() / element_size;
512 let elements_to_read = num_elements.min(available_elements);
513
514 match dtype {
516 DType::F32 => {
517 let mut values = Vec::with_capacity(num_elements);
518 for i in 0..elements_to_read {
519 let bytes = [
520 data_slice[i * element_size],
521 data_slice[i * element_size + 1],
522 data_slice[i * element_size + 2],
523 data_slice[i * element_size + 3],
524 ];
525 values.push(f32::from_le_bytes(bytes));
526 }
527 values.resize(num_elements, 0.0);
529 Ok(TensorData::new(values, shape_clone.clone()))
530 }
531 DType::F64 => {
532 let mut values = Vec::with_capacity(num_elements);
533 for i in 0..elements_to_read {
534 let mut bytes = [0u8; 8];
535 bytes.copy_from_slice(
536 &data_slice[i * element_size..(i + 1) * element_size],
537 );
538 values.push(f64::from_le_bytes(bytes));
539 }
540 values.resize(num_elements, 0.0);
541 Ok(TensorData::new(values, shape_clone.clone()))
542 }
543 DType::I64 => {
544 let mut values = Vec::with_capacity(num_elements);
545 for i in 0..elements_to_read {
546 let mut bytes = [0u8; 8];
547 bytes.copy_from_slice(
548 &data_slice[i * element_size..(i + 1) * element_size],
549 );
550 values.push(i64::from_le_bytes(bytes));
551 }
552 values.resize(num_elements, 0);
553 Ok(TensorData::new(values, shape_clone.clone()))
554 }
555 DType::I32 => {
556 let mut values = Vec::with_capacity(num_elements);
557 for i in 0..elements_to_read {
558 let mut bytes = [0u8; 4];
559 bytes.copy_from_slice(
560 &data_slice[i * element_size..(i + 1) * element_size],
561 );
562 values.push(i32::from_le_bytes(bytes));
563 }
564 values.resize(num_elements, 0);
565 Ok(TensorData::new(values, shape_clone.clone()))
566 }
567 DType::I16 => {
568 let mut values = Vec::with_capacity(num_elements);
569 for i in 0..elements_to_read {
570 let mut bytes = [0u8; 2];
571 bytes.copy_from_slice(
572 &data_slice[i * element_size..(i + 1) * element_size],
573 );
574 values.push(i16::from_le_bytes(bytes));
575 }
576 values.resize(num_elements, 0);
577 Ok(TensorData::new(values, shape_clone.clone()))
578 }
579 DType::I8 => {
580 let mut values = Vec::with_capacity(num_elements);
581 for &byte in data_slice.iter().take(elements_to_read) {
582 values.push(byte as i8);
583 }
584 values.resize(num_elements, 0);
585 Ok(TensorData::new(values, shape_clone.clone()))
586 }
587 DType::Bool => {
588 let mut values = Vec::with_capacity(num_elements);
589 for &byte in data_slice.iter().take(elements_to_read) {
590 values.push(byte != 0);
591 }
592 values.resize(num_elements, false);
593 Ok(TensorData::new(values, shape_clone.clone()))
594 }
595 DType::F16 => {
596 let mut values = Vec::with_capacity(num_elements);
597 for i in 0..elements_to_read {
598 let mut bytes = [0u8; 2];
599 bytes.copy_from_slice(
600 &data_slice[i * element_size..(i + 1) * element_size],
601 );
602 values.push(f16::from_le_bytes(bytes));
603 }
604 values.resize(num_elements, f16::ZERO);
605 Ok(TensorData::new(values, shape_clone.clone()))
606 }
607 DType::BF16 => {
608 let mut values = Vec::with_capacity(num_elements);
609 for i in 0..elements_to_read {
610 let mut bytes = [0u8; 2];
611 bytes.copy_from_slice(
612 &data_slice[i * element_size..(i + 1) * element_size],
613 );
614 values.push(bf16::from_le_bytes(bytes));
615 }
616 values.resize(num_elements, bf16::ZERO);
617 Ok(TensorData::new(values, shape_clone.clone()))
618 }
619 DType::U8 => {
620 let mut values = Vec::with_capacity(num_elements);
621 for &byte in data_slice.iter().take(elements_to_read) {
622 values.push(byte);
623 }
624 values.resize(num_elements, 0);
625 Ok(TensorData::new(values, shape_clone.clone()))
626 }
627 DType::U16 => {
628 let mut values = Vec::with_capacity(num_elements);
629 for i in 0..elements_to_read {
630 let mut bytes = [0u8; 2];
631 bytes.copy_from_slice(
632 &data_slice[i * element_size..(i + 1) * element_size],
633 );
634 values.push(u16::from_le_bytes(bytes));
635 }
636 values.resize(num_elements, 0);
637 Ok(TensorData::new(values, shape_clone.clone()))
638 }
639 DType::U32 => {
640 let mut values = Vec::with_capacity(num_elements);
641 for i in 0..elements_to_read {
642 let mut bytes = [0u8; 4];
643 bytes.copy_from_slice(
644 &data_slice[i * element_size..(i + 1) * element_size],
645 );
646 values.push(u32::from_le_bytes(bytes));
647 }
648 values.resize(num_elements, 0);
649 Ok(TensorData::new(values, shape_clone.clone()))
650 }
651 DType::U64 => {
652 let mut values = Vec::with_capacity(num_elements);
653 for i in 0..elements_to_read {
654 let mut bytes = [0u8; 8];
655 bytes.copy_from_slice(
656 &data_slice[i * element_size..(i + 1) * element_size],
657 );
658 values.push(u64::from_le_bytes(bytes));
659 }
660 values.resize(num_elements, 0);
661 Ok(TensorData::new(values, shape_clone.clone()))
662 }
663 _ => {
664 Err(crate::TensorSnapshotError::DataError(format!(
666 "Unsupported dtype for tensor data reading: {:?}",
667 dtype
668 )))
669 }
670 }
671 } else {
672 let num_elements = shape_clone.iter().product::<usize>().max(1);
674 match dtype {
675 DType::F32 => Ok(TensorData::new(
676 vec![0.0f32; num_elements],
677 shape_clone.clone(),
678 )),
679 DType::F64 => Ok(TensorData::new(
680 vec![0.0f64; num_elements],
681 shape_clone.clone(),
682 )),
683 DType::F16 => Ok(TensorData::new(
684 vec![f16::ZERO; num_elements],
685 shape_clone.clone(),
686 )),
687 DType::BF16 => Ok(TensorData::new(
688 vec![bf16::ZERO; num_elements],
689 shape_clone.clone(),
690 )),
691 DType::I64 => Ok(TensorData::new(
692 vec![0i64; num_elements],
693 shape_clone.clone(),
694 )),
695 DType::I32 => Ok(TensorData::new(
696 vec![0i32; num_elements],
697 shape_clone.clone(),
698 )),
699 DType::I16 => Ok(TensorData::new(
700 vec![0i16; num_elements],
701 shape_clone.clone(),
702 )),
703 DType::I8 => Ok(TensorData::new(
704 vec![0i8; num_elements],
705 shape_clone.clone(),
706 )),
707 DType::U8 => Ok(TensorData::new(
708 vec![0u8; num_elements],
709 shape_clone.clone(),
710 )),
711 DType::U16 => Ok(TensorData::new(
712 vec![0u16; num_elements],
713 shape_clone.clone(),
714 )),
715 DType::U32 => Ok(TensorData::new(
716 vec![0u32; num_elements],
717 shape_clone.clone(),
718 )),
719 DType::U64 => Ok(TensorData::new(
720 vec![0u64; num_elements],
721 shape_clone.clone(),
722 )),
723 DType::Bool => Ok(TensorData::new(
724 vec![false; num_elements],
725 shape_clone.clone(),
726 )),
727 _ => {
728 Err(crate::TensorSnapshotError::DataError(format!(
730 "Unsupported dtype for tensor data reading: {:?}",
731 dtype
732 )))
733 }
734 }
735 }
736 }),
737 dtype,
738 shape,
739 vec![], vec![], ParamId::new(), )))
743}
744
745pub struct Stack {
746 stack: Vec<Object>,
747 memo: HashMap<u32, Object>,
748 data_source: Option<Arc<LazyDataSource>>,
749}
750
751impl Default for Stack {
752 fn default() -> Self {
753 Self::new()
754 }
755}
756
757impl Stack {
758 pub fn new() -> Self {
759 Self {
761 stack: Vec::new(),
762 memo: HashMap::new(),
763 data_source: None,
764 }
765 }
766
767 pub fn with_data_source(data_source: Arc<LazyDataSource>) -> Self {
768 Self {
769 stack: Vec::new(),
770 memo: HashMap::new(),
771 data_source: Some(data_source),
772 }
773 }
774
775 fn push(&mut self, o: Object) {
776 self.stack.push(o)
777 }
778
779 fn pop(&mut self) -> Result<Object> {
780 match self.stack.pop() {
781 None => Err(PickleError::StackUnderflow),
782 Some(o) => Ok(o),
783 }
784 }
785
786 fn top(&self) -> Result<Object> {
787 match self.stack.last() {
788 None => Err(PickleError::StackUnderflow),
789 Some(o) => Ok(o.clone()),
790 }
791 }
792
793 fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
794 let marker_pos = self
795 .stack
796 .iter()
797 .rposition(|o| {
798 matches!(o, Object::Class { module_name, name }
799 if module_name == "mark" && name == "mark")
800 })
801 .ok_or(PickleError::InvalidData("marker not found".to_string()))?;
802
803 let result = self.stack.split_off(marker_pos + 1);
804 self.stack.pop(); Ok(result)
806 }
807
808 fn last_mut(&mut self) -> Result<&mut Object> {
809 match self.stack.last_mut() {
810 None => Err(PickleError::StackUnderflow),
811 Some(o) => Ok(o),
812 }
813 }
814
815 fn push_mark(&mut self) {
816 self.stack.push(Object::Class {
817 module_name: "mark".to_string(),
818 name: "mark".to_string(),
819 });
820 }
821
822 fn memo_get(&self, idx: u32) -> Result<Object> {
823 self.memo
824 .get(&idx)
825 .cloned()
826 .ok_or(PickleError::MemoNotFound(idx))
827 }
828
829 fn memo_put(&mut self, idx: u32, obj: Object) {
830 self.memo.insert(idx, obj);
831 }
832
833 fn memo_len(&self) -> usize {
834 self.memo.len()
835 }
836}
837
838fn read_global<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
839 let module_name = buf_to_str(&read_to_newline(r)?)?;
840 let name = buf_to_str(&read_to_newline(r)?)?;
841 stack.push(Object::Class { module_name, name });
842 Ok(())
843}
844
845fn read_long1<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
846 let len = r.read_u8()? as usize;
847 let mut data = vec![0u8; len];
848 r.read_exact(&mut data)?;
849 let mut value = 0i64;
851 for (i, &byte) in data.iter().enumerate().take(8) {
852 value |= (byte as i64).wrapping_shl((i as u32) * 8);
854 }
855 if len < 8 && data.last().is_some_and(|&b| b & 0x80 != 0) {
857 for i in len..8 {
859 value |= 0xffi64.wrapping_shl((i as u32) * 8);
860 }
861 }
862 stack.push(Object::Int(value));
863 Ok(())
864}
865
866fn read_string<R: BufRead>(r: &mut R, stack: &mut Stack, len: usize) -> Result<()> {
867 let mut data = vec![0u8; len];
868 r.read_exact(&mut data)?;
869 let s = buf_to_str(&data)?;
870 stack.push(Object::String(s));
871 Ok(())
872}
873
874fn read_bin_int<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
875 let v = r.read_i32::<LittleEndian>()?;
876 stack.push(Object::Int(v as i64));
877 Ok(())
878}
879
880fn read_int<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
881 let line = read_to_newline(r)?;
883 let s = buf_to_str(&line)?;
884 let v = s
885 .parse::<i64>()
886 .map_err(|e| PickleError::InvalidData(format!("Invalid INT value '{}': {}", s, e)))?;
887 stack.push(Object::Int(v));
888 Ok(())
889}
890
891fn read_bin_int1<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
892 let v = r.read_u8()?;
893 stack.push(Object::Int(v as i64));
894 Ok(())
895}
896
897fn read_bin_int2<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
898 let v = r.read_u16::<LittleEndian>()?;
899 stack.push(Object::Int(v as i64));
900 Ok(())
901}
902
903fn read_bin_float<R: BufRead>(r: &mut R, stack: &mut Stack) -> Result<()> {
904 let v = r.read_f64::<byteorder::BigEndian>()?;
906 stack.push(Object::Float(v));
907 Ok(())
908}
909
910pub fn read_pickle<R: BufRead>(r: &mut R) -> Result<Object> {
911 read_pickle_with_optional_data(r, None)
913}
914
915pub fn skip_pickle<R: BufRead>(r: &mut R) -> Result<()> {
919 let mut first_byte = [0u8; 1];
921 r.read_exact(&mut first_byte)?;
922
923 if first_byte[0] == 0x80 {
924 let mut proto_version = [0u8; 1];
926 r.read_exact(&mut proto_version)?;
927 } else {
928 }
932
933 loop {
935 let mut byte = [0u8; 1];
936 r.read_exact(&mut byte)?;
937
938 match byte[0] {
939 0x2e => {
940 break;
942 }
943 0x58 | 0x42 | 0x43 | 0x54 | 0x55 | 0x56 | 0x8c | 0x8d | 0x8e => {
944 let length = match byte[0] {
946 0x43 | 0x55 | 0x8c => {
947 let mut len_byte = [0u8; 1];
949 r.read_exact(&mut len_byte)?;
950 len_byte[0] as usize
951 }
952 0x42 | 0x54 | 0x58 | 0x56 => {
953 let mut len_bytes = [0u8; 4];
955 r.read_exact(&mut len_bytes)?;
956 u32::from_le_bytes(len_bytes) as usize
957 }
958 0x8d | 0x8e => {
959 let mut len_bytes = [0u8; 8];
961 r.read_exact(&mut len_bytes)?;
962 u64::from_le_bytes(len_bytes) as usize
963 }
964 _ => 0,
965 };
966
967 let mut skip_buf = vec![0u8; length.min(8192)];
969 let mut skipped = 0;
970 while skipped < length {
971 let to_skip = (length - skipped).min(skip_buf.len());
972 r.read_exact(&mut skip_buf[..to_skip])?;
973 skipped += to_skip;
974 }
975 }
976 0x4b | 0x4d | 0x4e => {
977 let skip_count = match byte[0] {
979 0x4b => 1,
980 0x4d => 2,
981 0x4e => 4,
982 _ => 0,
983 };
984 let mut skip_buf = vec![0u8; skip_count];
985 r.read_exact(&mut skip_buf)?;
986 }
987 0x47 => {
988 let mut skip_buf = [0u8; 8];
990 r.read_exact(&mut skip_buf)?;
991 }
992 0x4a => {
993 let mut skip_buf = [0u8; 4];
995 r.read_exact(&mut skip_buf)?;
996 }
997 0x8a => {
998 let mut len_byte = [0u8; 1];
1000 r.read_exact(&mut len_byte)?;
1001 let length = len_byte[0] as usize;
1002 let mut skip_buf = vec![0u8; length];
1003 r.read_exact(&mut skip_buf)?;
1004 }
1005 0x8b => {
1006 let mut len_bytes = [0u8; 4];
1008 r.read_exact(&mut len_bytes)?;
1009 let length = u32::from_le_bytes(len_bytes) as usize;
1010 let mut skip_buf = vec![0u8; length.min(8192)];
1011 let mut skipped = 0;
1012 while skipped < length {
1013 let to_skip = (length - skipped).min(skip_buf.len());
1014 r.read_exact(&mut skip_buf[..to_skip])?;
1015 skipped += to_skip;
1016 }
1017 }
1018 _ => {
1019 }
1022 }
1023 }
1024
1025 Ok(())
1026}
1027
1028pub fn read_pickle_with_data<R: BufRead>(
1029 r: &mut R,
1030 data_source: Arc<LazyDataSource>,
1031) -> Result<Object> {
1032 read_pickle_with_optional_data(r, Some(data_source))
1033}
1034
1035pub fn read_pickle_with_optional_data<R: BufRead>(
1036 r: &mut R,
1037 data_source: Option<Arc<LazyDataSource>>,
1038) -> Result<Object> {
1039 let mut stack = match data_source {
1040 Some(ds) => Stack::with_data_source(ds),
1041 None => Stack::new(),
1042 };
1043 loop {
1044 let op_code = r.read_u8()?;
1045 let op_code = OpCode::try_from(op_code).map_err(PickleError::InvalidOpCode)?;
1046 match op_code {
1047 OpCode::Proto => {
1048 let version = r.read_u8()?;
1049 if version > 5 {
1050 return Err(PickleError::InvalidProtocol(version));
1051 }
1052 }
1053 OpCode::Global => read_global(r, &mut stack)?,
1054 OpCode::BinInt => read_bin_int(r, &mut stack)?,
1055 OpCode::Int => read_int(r, &mut stack)?,
1056 OpCode::BinInt1 => read_bin_int1(r, &mut stack)?,
1057 OpCode::BinInt2 => read_bin_int2(r, &mut stack)?,
1058 OpCode::BinFloat => read_bin_float(r, &mut stack)?,
1059 OpCode::BinUnicode => {
1060 let len = r.read_u32::<LittleEndian>()? as usize;
1061 read_string(r, &mut stack, len)?
1062 }
1063 OpCode::ShortBinString => {
1064 let len = r.read_u8()? as usize;
1065 read_string(r, &mut stack, len)?
1066 }
1067 OpCode::Long1 => read_long1(r, &mut stack)?,
1068 OpCode::None => stack.push(Object::None),
1069 OpCode::NewTrue => stack.push(Object::Bool(true)),
1070 OpCode::NewFalse => stack.push(Object::Bool(false)),
1071 OpCode::EmptyTuple => stack.push(Object::Tuple(Vec::new())),
1072 OpCode::EmptyList => stack.push(Object::List(Vec::new())),
1073 OpCode::EmptyDict => stack.push(Object::Dict(HashMap::new())),
1074 OpCode::Tuple => {
1075 let objs = stack.pop_to_marker()?;
1076 stack.push(Object::Tuple(objs))
1077 }
1078 OpCode::Tuple1 => {
1079 let obj = stack.pop()?;
1080 stack.push(Object::Tuple(vec![obj]))
1081 }
1082 OpCode::Tuple2 => {
1083 let obj2 = stack.pop()?;
1084 let obj1 = stack.pop()?;
1085 stack.push(Object::Tuple(vec![obj1, obj2]))
1086 }
1087 OpCode::Tuple3 => {
1088 let obj3 = stack.pop()?;
1089 let obj2 = stack.pop()?;
1090 let obj1 = stack.pop()?;
1091 stack.push(Object::Tuple(vec![obj1, obj2, obj3]))
1092 }
1093 OpCode::Append => {
1094 let value = stack.pop()?;
1095 match stack.last_mut()? {
1096 Object::List(list) => list.push(value),
1097 _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1098 }
1099 }
1100 OpCode::Appends => {
1101 let objs = stack.pop_to_marker()?;
1102 match stack.last_mut()? {
1103 Object::List(list) => list.extend(objs),
1104 _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1105 }
1106 }
1107 OpCode::SetItem => {
1108 let value = stack.pop()?;
1109 let key = stack.pop()?;
1110 match stack.last_mut()? {
1111 Object::Dict(dict) => {
1112 if let Object::String(key) = key {
1113 dict.insert(key, value);
1114 } else {
1115 return Err(PickleError::InvalidData(
1116 "dict key must be a string".to_string(),
1117 ));
1118 }
1119 }
1120 _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1121 }
1122 }
1123 OpCode::SetItems => {
1124 let mut objs = stack.pop_to_marker()?;
1125 if objs.len() % 2 != 0 {
1126 return Err(PickleError::InvalidData(
1127 "setitems requires even number of objects".to_string(),
1128 ));
1129 }
1130 match stack.last_mut()? {
1131 Object::Dict(dict) => {
1132 while !objs.is_empty() {
1133 let key = objs.remove(0);
1134 let value = objs.remove(0);
1135 if let Object::String(key) = key {
1136 dict.insert(key, value);
1137 } else {
1138 return Err(PickleError::InvalidData(
1139 "dict key must be a string".to_string(),
1140 ));
1141 }
1142 }
1143 }
1144 _ => return Err(PickleError::UnexpectedOpCode(op_code)),
1145 }
1146 }
1147 OpCode::BinPut => {
1148 let idx = r.read_u8()? as u32;
1149 let obj = stack.top()?;
1150 stack.memo_put(idx, obj);
1151 }
1152 OpCode::LongBinPut => {
1153 let idx = r.read_u32::<LittleEndian>()?;
1154 let obj = stack.top()?;
1155 stack.memo_put(idx, obj);
1156 }
1157 OpCode::BinGet => {
1158 let idx = r.read_u8()? as u32;
1159 let obj = stack.memo_get(idx)?;
1160 stack.push(obj);
1161 }
1162 OpCode::LongBinGet => {
1163 let idx = r.read_u32::<LittleEndian>()?;
1164 let obj = stack.memo_get(idx)?;
1165 stack.push(obj);
1166 }
1167 OpCode::Mark => stack.push_mark(),
1168 OpCode::BinPersId => {
1169 let pid = stack.pop()?;
1170 match pid {
1171 Object::String(s) => {
1172 stack.push(Object::Persistent(s.into_bytes()));
1173 }
1174 Object::Tuple(tuple) => {
1175 stack.push(Object::PersistentTuple(tuple));
1178 }
1179 _ => {
1180 return Err(PickleError::InvalidData(format!(
1181 "persistent id must be a string or tuple, got {:?}",
1182 pid
1183 )));
1184 }
1185 }
1186 }
1187 OpCode::Reduce => {
1188 let args = stack.pop()?;
1189 let callable = stack.pop()?;
1190
1191 if let Object::Class { module_name, name } = &callable {
1193 if module_name == "collections" && name == "OrderedDict" {
1194 stack.push(Object::Dict(HashMap::new()));
1196 } else {
1197 let _obj = Object::Reduce {
1198 callable: Box::new(callable.clone()),
1199 args: Box::new(args.clone()),
1200 };
1201 let obj = rebuild_from_type_v2(
1202 Object::Tuple(vec![callable, args]),
1203 &mut stack.memo,
1204 &stack.data_source,
1205 )?;
1206 stack.push(obj);
1207 }
1208 } else {
1209 let _obj = Object::Reduce {
1210 callable: Box::new(callable.clone()),
1211 args: Box::new(args.clone()),
1212 };
1213 let obj = rebuild_from_type_v2(
1214 Object::Tuple(vec![callable, args]),
1215 &mut stack.memo,
1216 &stack.data_source,
1217 )?;
1218 stack.push(obj);
1219 }
1220 }
1221 OpCode::Build => {
1222 let args = stack.pop()?;
1223 let obj = stack.pop()?;
1224 match obj {
1225 Object::Dict(mut dict) => {
1226 if let Object::Dict(update) = args {
1228 dict.extend(update);
1229 }
1230 stack.push(Object::Dict(dict));
1231 }
1232 _ => {
1233 stack.push(Object::Build {
1234 callable: Box::new(obj),
1235 args: Box::new(args),
1236 });
1237 }
1238 }
1239 }
1240 OpCode::NewObj => {
1241 let args = stack.pop()?;
1242 let cls = stack.pop()?;
1243 stack.push(Object::Reduce {
1244 callable: Box::new(cls),
1245 args: Box::new(args),
1246 });
1247 }
1248 OpCode::Dict => {
1249 let objs = stack.pop_to_marker()?;
1250 let mut dict = HashMap::new();
1251 if objs.len() % 2 != 0 {
1252 return Err(PickleError::InvalidData(
1253 "dict requires even number of objects".to_string(),
1254 ));
1255 }
1256 for chunk in objs.chunks(2) {
1257 if let Object::String(key) = &chunk[0] {
1258 dict.insert(key.clone(), chunk[1].clone());
1259 } else {
1260 return Err(PickleError::InvalidData(
1261 "dict key must be a string".to_string(),
1262 ));
1263 }
1264 }
1265 stack.push(Object::Dict(dict));
1266 }
1267 OpCode::List => {
1268 let objs = stack.pop_to_marker()?;
1269 stack.push(Object::List(objs));
1270 }
1271 OpCode::Memoize => {
1272 let obj = stack.top()?;
1275 let idx = stack.memo_len() as u32;
1276 stack.memo_put(idx, obj);
1277 }
1278 OpCode::Stop => break,
1279 }
1280 }
1281 stack.pop()
1282}
1283
1284pub fn read_pickle_tensors<R: BufRead>(reader: &mut R) -> Result<HashMap<String, TensorSnapshot>> {
1286 let obj = read_pickle(reader)?;
1287
1288 let mut tensors = HashMap::new();
1290 let mut path = Vec::new();
1291 extract_tensors(&obj, &mut path, &mut tensors);
1292
1293 Ok(tensors)
1294}
1295
1296fn extract_tensors<'a>(
1297 obj: &'a Object,
1298 path: &mut Vec<&'a str>,
1299 tensors: &mut HashMap<String, TensorSnapshot>,
1300) {
1301 match obj {
1302 Object::Dict(dict) => {
1303 for (key, value) in dict {
1304 path.push(key);
1305 extract_tensors(value, path, tensors);
1306 path.pop();
1307 }
1308 }
1309 Object::TorchParam(snapshot) => {
1310 tensors.insert(path.join("."), snapshot.clone());
1312 }
1313 _ => {}
1314 }
1315}