candle_core/
pickle.rs

1//! Just enough pickle support to be able to read PyTorch checkpoints.
2// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
3// composable/tensor agnostic at some point.
4use crate::{Context, DType, Error as E, Layout, Result, Tensor};
5use byteorder::{LittleEndian, ReadBytesExt};
6use std::collections::HashMap;
7use std::io::BufRead;
8
9const VERBOSE: bool = false;
10
11// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
/// Pickle op-codes understood by this reader.
///
/// The discriminant of each variant is the byte that encodes the op-code in a pickle stream.
/// Only the op-codes listed here are supported; any other byte is rejected by
/// `OpCode::try_from`.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
    Proto = 0x80,
    Global = b'c',
    BinPut = b'q',
    LongBinPut = b'r',
    EmptyTuple = b')',
    Reduce = b'R',
    Mark = b'(',
    BinUnicode = b'X',
    BinInt = b'J',
    Tuple = b't',
    BinPersId = b'Q',
    BinInt1 = b'K',
    BinInt2 = b'M',
    Tuple1 = 0x85,
    Tuple2 = 0x86,
    Tuple3 = 0x87,
    NewTrue = 0x88,
    NewFalse = 0x89,
    None = b'N',
    BinGet = b'h',
    LongBinGet = b'j',
    SetItem = b's',
    SetItems = b'u',
    EmptyDict = b'}',
    Dict = b'd',
    Build = b'b',
    Stop = b'.',
    NewObj = 0x81,
    EmptyList = b']',
    BinFloat = b'G',
    Append = b'a',
    Appends = b'e',
}

// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
    type Error = u8;

    /// Decodes a raw op-code byte.
    ///
    /// Returns the matching variant, or the unrecognized byte itself as the error so that the
    /// caller can report it.
    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            0x80 => Ok(Self::Proto),
            b'c' => Ok(Self::Global),
            b'q' => Ok(Self::BinPut),
            b'r' => Ok(Self::LongBinPut),
            b')' => Ok(Self::EmptyTuple),
            b'R' => Ok(Self::Reduce),
            b'(' => Ok(Self::Mark),
            b'X' => Ok(Self::BinUnicode),
            b'J' => Ok(Self::BinInt),
            b't' => Ok(Self::Tuple),
            b'Q' => Ok(Self::BinPersId),
            b'K' => Ok(Self::BinInt1),
            b'M' => Ok(Self::BinInt2),
            b'N' => Ok(Self::None),
            0x85 => Ok(Self::Tuple1),
            0x86 => Ok(Self::Tuple2),
            0x87 => Ok(Self::Tuple3),
            0x88 => Ok(Self::NewTrue),
            0x89 => Ok(Self::NewFalse),
            b'h' => Ok(Self::BinGet),
            b'j' => Ok(Self::LongBinGet),
            b's' => Ok(Self::SetItem),
            b'u' => Ok(Self::SetItems),
            b'}' => Ok(Self::EmptyDict),
            // DICT builds a dict from the items above the topmost mark; it must not be confused
            // with EMPTY_DICT, which pushes a fresh empty dict and leaves the stack untouched.
            b'd' => Ok(Self::Dict),
            b'b' => Ok(Self::Build),
            b'.' => Ok(Self::Stop),
            0x81 => Ok(Self::NewObj),
            b']' => Ok(Self::EmptyList),
            b'G' => Ok(Self::BinFloat),
            b'a' => Ok(Self::Append),
            b'e' => Ok(Self::Appends),
            value => Err(value),
        }
    }
}
91
/// Reads bytes from `r` up to and including the next `\n` and returns them without the line
/// terminator.
///
/// A trailing `\n`, or `\r\n`, is stripped. When the reader reaches end-of-file before any
/// newline is seen, the bytes read so far are returned unchanged rather than losing their last
/// byte.
///
/// # Errors
/// Propagates any I/O error reported by the underlying reader.
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
    let mut data: Vec<u8> = Vec::with_capacity(32);
    r.read_until(b'\n', &mut data)?;
    // `read_until` keeps the delimiter only when it was actually found, so only strip the
    // terminator in that case.
    if data.last() == Some(&b'\n') {
        data.pop();
        if data.last() == Some(&b'\r') {
            data.pop();
        }
    }
    Ok(data)
}
101
/// A value produced while interpreting a pickle stream.
///
/// Calls are never evaluated: REDUCE, BUILD, NEWOBJ and persistent-id lookups are kept in
/// symbolic form so that callers can pattern-match on them.
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
    /// A global reference made of a module name and a class (or function) name, as pushed by
    /// the GLOBAL op-code.
    Class {
        module_name: String,
        class_name: String,
    },
    /// An integer pushed by the BININT, BININT1 or BININT2 op-codes.
    Int(i32),
    /// A float pushed by the BINFLOAT op-code.
    Float(f64),
    /// A string pushed by the BINUNICODE op-code.
    Unicode(String),
    /// A boolean pushed by the NEWTRUE / NEWFALSE op-codes.
    Bool(bool),
    /// The `None` value.
    None,
    /// A tuple of objects.
    Tuple(Vec<Object>),
    /// A list of objects.
    List(Vec<Object>),
    /// Stack marker pushed by the MARK op-code; it delimits the items consumed by op-codes such
    /// as TUPLE, APPENDS or SETITEMS.
    Mark,
    /// A dictionary stored as a plain list of key/value pairs (keys are not de-duplicated).
    Dict(Vec<(Object, Object)>),
    /// An unevaluated call of `callable` with `args`, produced by REDUCE and NEWOBJ.
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// An unevaluated BUILD operation applying the state `args` to the object `callable`.
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// A persistent-id reference wrapping the id object popped by BINPERSID.
    PersistentLoad(Box<Object>),
}
127
128type OResult<T> = std::result::Result<T, Object>;
129
130impl Object {
131    pub fn unicode(self) -> OResult<String> {
132        match self {
133            Self::Unicode(t) => Ok(t),
134            _ => Err(self),
135        }
136    }
137
138    pub fn reduce(self) -> OResult<(Self, Self)> {
139        match self {
140            Self::Reduce { callable, args } => Ok((*callable, *args)),
141            _ => Err(self),
142        }
143    }
144
145    pub fn none(self) -> OResult<()> {
146        match self {
147            Self::None => Ok(()),
148            _ => Err(self),
149        }
150    }
151
152    pub fn persistent_load(self) -> OResult<Self> {
153        match self {
154            Self::PersistentLoad(t) => Ok(*t),
155            _ => Err(self),
156        }
157    }
158
159    pub fn bool(self) -> OResult<bool> {
160        match self {
161            Self::Bool(t) => Ok(t),
162            _ => Err(self),
163        }
164    }
165
166    pub fn int(self) -> OResult<i32> {
167        match self {
168            Self::Int(t) => Ok(t),
169            _ => Err(self),
170        }
171    }
172
173    pub fn tuple(self) -> OResult<Vec<Self>> {
174        match self {
175            Self::Tuple(t) => Ok(t),
176            _ => Err(self),
177        }
178    }
179
180    pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
181        match self {
182            Self::Dict(t) => Ok(t),
183            _ => Err(self),
184        }
185    }
186
187    pub fn class(self) -> OResult<(String, String)> {
188        match self {
189            Self::Class {
190                module_name,
191                class_name,
192            } => Ok((module_name, class_name)),
193            _ => Err(self),
194        }
195    }
196
197    pub fn into_tensor_info(
198        self,
199        name: Self,
200        dir_name: &std::path::Path,
201    ) -> Result<Option<TensorInfo>> {
202        let name = match name.unicode() {
203            Ok(name) => name,
204            Err(_) => return Ok(None),
205        };
206        let (callable, args) = match self.reduce() {
207            Ok(callable_args) => callable_args,
208            _ => return Ok(None),
209        };
210        let (callable, args) = match callable {
211            Object::Class {
212                module_name,
213                class_name,
214            } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
215                let mut args = args.tuple()?;
216                let callable = args.remove(0);
217                let args = args.remove(1);
218                (callable, args)
219            }
220            Object::Class {
221                module_name,
222                class_name,
223            } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
224                let mut args = args.tuple()?;
225                args.remove(0).reduce()?
226            }
227            _ => (callable, args),
228        };
229        match callable {
230            Object::Class {
231                module_name,
232                class_name,
233            } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
234            _ => return Ok(None),
235        };
236        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
237        Ok(Some(TensorInfo {
238            name,
239            dtype,
240            layout,
241            path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
242            storage_size,
243        }))
244    }
245}
246
247impl TryFrom<Object> for String {
248    type Error = Object;
249    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
250        match value {
251            Object::Unicode(s) => Ok(s),
252            other => Err(other),
253        }
254    }
255}
256
257impl TryFrom<Object> for usize {
258    type Error = Object;
259    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
260        match value {
261            Object::Int(s) if s >= 0 => Ok(s as usize),
262            other => Err(other),
263        }
264    }
265}
266
267impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
268    type Error = Object;
269    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
270        match value {
271            Object::Tuple(values) => {
272                // This does not return the appropriate value in the error case but instead return
273                // the object related to the first error.
274                values
275                    .into_iter()
276                    .map(|v| T::try_from(v))
277                    .collect::<std::result::Result<Vec<T>, Self::Error>>()
278            }
279            other => Err(other),
280        }
281    }
282}
283
/// State of the pickle interpreter: the object stack and the memo table.
#[derive(Debug)]
pub struct Stack {
    // Objects built so far; op-codes push to and pop from the end of this vector.
    stack: Vec<Object>,
    // Copies of objects saved by BINPUT / LONG_BINPUT, keyed by memo id and read back by
    // BINGET / LONG_BINGET.
    memo: HashMap<u32, Object>,
}
289
290impl Stack {
291    pub fn empty() -> Self {
292        Self {
293            stack: Vec::with_capacity(512),
294            memo: HashMap::new(),
295        }
296    }
297
298    pub fn stack(&self) -> &[Object] {
299        self.stack.as_slice()
300    }
301
302    pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
303        loop {
304            if self.read(r)? {
305                break;
306            }
307        }
308        Ok(())
309    }
310
311    pub fn finalize(mut self) -> Result<Object> {
312        self.pop()
313    }
314
315    fn push(&mut self, obj: Object) {
316        self.stack.push(obj)
317    }
318
319    fn pop(&mut self) -> Result<Object> {
320        match self.stack.pop() {
321            None => crate::bail!("unexpected empty stack"),
322            Some(obj) => Ok(obj),
323        }
324    }
325
326    // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
327    fn build(&mut self) -> Result<()> {
328        let args = self.pop()?;
329        let obj = self.pop()?;
330        let obj = match (obj, args) {
331            (Object::Dict(mut obj), Object::Dict(mut args)) => {
332                obj.append(&mut args);
333                Object::Dict(obj)
334            }
335            (obj, args) => Object::Build {
336                callable: Box::new(obj),
337                args: Box::new(args),
338            },
339        };
340        self.push(obj);
341        Ok(())
342    }
343
344    fn reduce(&mut self) -> Result<()> {
345        let args = self.pop()?;
346        let callable = self.pop()?;
347        #[allow(clippy::single_match)]
348        let reduced = match &callable {
349            Object::Class {
350                module_name,
351                class_name,
352            } => {
353                if module_name == "collections"
354                    && (class_name == "OrderedDict" || class_name == "defaultdict")
355                {
356                    // TODO: have a separate ordered dict and a separate default dict.
357                    Some(Object::Dict(vec![]))
358                } else {
359                    None
360                }
361            }
362            _ => None,
363        };
364        let reduced = reduced.unwrap_or_else(|| Object::Reduce {
365            callable: Box::new(callable),
366            args: Box::new(args),
367        });
368        self.push(reduced);
369        Ok(())
370    }
371
372    fn last(&mut self) -> Result<&mut Object> {
373        match self.stack.last_mut() {
374            None => crate::bail!("unexpected empty stack"),
375            Some(obj) => Ok(obj),
376        }
377    }
378
379    fn memo_get(&self, id: u32) -> Result<Object> {
380        match self.memo.get(&id) {
381            None => crate::bail!("missing object in memo {id}"),
382            Some(obj) => {
383                // Maybe we should use refcounting rather than doing potential large clones here.
384                Ok(obj.clone())
385            }
386        }
387    }
388
389    fn memo_put(&mut self, id: u32) -> Result<()> {
390        let obj = self.last()?.clone();
391        self.memo.insert(id, obj);
392        Ok(())
393    }
394
395    fn persistent_load(&self, id: Object) -> Result<Object> {
396        Ok(Object::PersistentLoad(Box::new(id)))
397    }
398
399    fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
400        Ok(Object::Reduce {
401            callable: Box::new(class),
402            args: Box::new(args),
403        })
404    }
405
406    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
407        let mut mark_idx = None;
408        for (idx, obj) in self.stack.iter().enumerate().rev() {
409            if obj == &Object::Mark {
410                mark_idx = Some(idx);
411                break;
412            }
413        }
414        match mark_idx {
415            Some(mark_idx) => {
416                let objs = self.stack.split_off(mark_idx + 1);
417                self.stack.pop();
418                Ok(objs)
419            }
420            None => {
421                crate::bail!("marker object not found")
422            }
423        }
424    }
425
426    pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
427        let op_code = match OpCode::try_from(r.read_u8()?) {
428            Ok(op_code) => op_code,
429            Err(op_code) => {
430                crate::bail!("unknown op-code {op_code}")
431            }
432        };
433        // println!("op: {op_code:?}");
434        // println!("{:?}", self.stack);
435        match op_code {
436            OpCode::Proto => {
437                let version = r.read_u8()?;
438                if VERBOSE {
439                    println!("proto {version}");
440                }
441            }
442            OpCode::Global => {
443                let module_name = read_to_newline(r)?;
444                let class_name = read_to_newline(r)?;
445                let module_name = String::from_utf8_lossy(&module_name).to_string();
446                let class_name = String::from_utf8_lossy(&class_name).to_string();
447                self.push(Object::Class {
448                    module_name,
449                    class_name,
450                })
451            }
452            OpCode::BinInt1 => {
453                let arg = r.read_u8()?;
454                self.push(Object::Int(arg as i32))
455            }
456            OpCode::BinInt2 => {
457                let arg = r.read_u16::<LittleEndian>()?;
458                self.push(Object::Int(arg as i32))
459            }
460            OpCode::BinInt => {
461                let arg = r.read_i32::<LittleEndian>()?;
462                self.push(Object::Int(arg))
463            }
464            OpCode::BinFloat => {
465                // Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
466                // https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
467                // https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
468                let arg = r.read_f64::<byteorder::BigEndian>()?;
469                self.push(Object::Float(arg))
470            }
471            OpCode::BinUnicode => {
472                let len = r.read_u32::<LittleEndian>()?;
473                let mut data = vec![0u8; len as usize];
474                r.read_exact(&mut data)?;
475                let data = String::from_utf8(data).map_err(E::wrap)?;
476                self.push(Object::Unicode(data))
477            }
478            OpCode::BinPersId => {
479                let id = self.pop()?;
480                let obj = self.persistent_load(id)?;
481                self.push(obj)
482            }
483            OpCode::Tuple => {
484                let objs = self.pop_to_marker()?;
485                self.push(Object::Tuple(objs))
486            }
487            OpCode::Tuple1 => {
488                let obj = self.pop()?;
489                self.push(Object::Tuple(vec![obj]))
490            }
491            OpCode::Tuple2 => {
492                let obj2 = self.pop()?;
493                let obj1 = self.pop()?;
494                self.push(Object::Tuple(vec![obj1, obj2]))
495            }
496            OpCode::Tuple3 => {
497                let obj3 = self.pop()?;
498                let obj2 = self.pop()?;
499                let obj1 = self.pop()?;
500                self.push(Object::Tuple(vec![obj1, obj2, obj3]))
501            }
502            OpCode::NewTrue => self.push(Object::Bool(true)),
503            OpCode::NewFalse => self.push(Object::Bool(false)),
504            OpCode::Append => {
505                let value = self.pop()?;
506                let pylist = self.last()?;
507                if let Object::List(d) = pylist {
508                    d.push(value)
509                } else {
510                    crate::bail!("expected a list, got {pylist:?}")
511                }
512            }
513            OpCode::Appends => {
514                let objs = self.pop_to_marker()?;
515                let pylist = self.last()?;
516                if let Object::List(d) = pylist {
517                    d.extend(objs)
518                } else {
519                    crate::bail!("expected a list, got {pylist:?}")
520                }
521            }
522            OpCode::SetItem => {
523                let value = self.pop()?;
524                let key = self.pop()?;
525                let pydict = self.last()?;
526                if let Object::Dict(d) = pydict {
527                    d.push((key, value))
528                } else {
529                    crate::bail!("expected a dict, got {pydict:?}")
530                }
531            }
532            OpCode::SetItems => {
533                let mut objs = self.pop_to_marker()?;
534                let pydict = self.last()?;
535                if let Object::Dict(d) = pydict {
536                    if objs.len() % 2 != 0 {
537                        crate::bail!("setitems: not an even number of objects")
538                    }
539                    while let Some(value) = objs.pop() {
540                        let key = objs.pop().context("empty objs")?;
541                        d.push((key, value))
542                    }
543                } else {
544                    crate::bail!("expected a dict, got {pydict:?}")
545                }
546            }
547            OpCode::None => self.push(Object::None),
548            OpCode::Stop => {
549                return Ok(true);
550            }
551            OpCode::Build => self.build()?,
552            OpCode::EmptyDict => self.push(Object::Dict(vec![])),
553            OpCode::Dict => {
554                let mut objs = self.pop_to_marker()?;
555                let mut pydict = vec![];
556                if objs.len() % 2 != 0 {
557                    crate::bail!("setitems: not an even number of objects")
558                }
559                while let Some(value) = objs.pop() {
560                    let key = objs.pop().context("empty objs")?;
561                    pydict.push((key, value))
562                }
563                self.push(Object::Dict(pydict))
564            }
565            OpCode::Mark => self.push(Object::Mark),
566            OpCode::Reduce => self.reduce()?,
567            OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
568            OpCode::EmptyList => self.push(Object::List(vec![])),
569            OpCode::BinGet => {
570                let arg = r.read_u8()?;
571                let obj = self.memo_get(arg as u32)?;
572                self.push(obj)
573            }
574            OpCode::LongBinGet => {
575                let arg = r.read_u32::<LittleEndian>()?;
576                let obj = self.memo_get(arg)?;
577                self.push(obj)
578            }
579            OpCode::BinPut => {
580                let arg = r.read_u8()?;
581                self.memo_put(arg as u32)?
582            }
583            OpCode::LongBinPut => {
584                let arg = r.read_u32::<LittleEndian>()?;
585                self.memo_put(arg)?
586            }
587            OpCode::NewObj => {
588                let args = self.pop()?;
589                let class = self.pop()?;
590                let obj = self.new_obj(class, args)?;
591                self.push(obj)
592            }
593        }
594        Ok(false)
595    }
596}
597
598impl From<Object> for E {
599    fn from(value: Object) -> Self {
600        E::Msg(format!("conversion error on {value:?}"))
601    }
602}
603
604// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
605// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
606fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
607    let mut args = args.tuple()?;
608    let stride = Vec::<usize>::try_from(args.remove(3))?;
609    let size = Vec::<usize>::try_from(args.remove(2))?;
610    let offset = args.remove(1).int()? as usize;
611    let storage = args.remove(0).persistent_load()?;
612    let mut storage = storage.tuple()?;
613    let storage_size = storage.remove(4).int()? as usize;
614    let path = storage.remove(2).unicode()?;
615    let (_module_name, class_name) = storage.remove(1).class()?;
616    let dtype = match class_name.as_str() {
617        "FloatStorage" => DType::F32,
618        "DoubleStorage" => DType::F64,
619        "HalfStorage" => DType::F16,
620        "BFloat16Storage" => DType::BF16,
621        "ByteStorage" => DType::U8,
622        "LongStorage" => DType::I64,
623        other => {
624            crate::bail!("unsupported storage type {other}")
625        }
626    };
627    let layout = Layout::new(crate::Shape::from(size), stride, offset);
628    Ok((layout, dtype, path, storage_size))
629}
630
/// Description of a tensor found in a pickled checkpoint, without its data.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Key under which the tensor was stored in the pickled dictionary.
    pub name: String,
    /// Element type, derived from the storage class name found in the pickle.
    pub dtype: DType,
    /// Shape, strides and storage offset as recorded in the pickle.
    pub layout: Layout,
    /// Name of the zip entry holding the storage, built as `<dir_name>/<storage file name>`.
    pub path: String,
    /// Fifth element of the persistent-id storage tuple.
    // NOTE(review): presumably the number of elements in the storage - confirm against PyTorch.
    pub storage_size: usize,
}
639
640/// Read the tensor info from a .pth file.
641///
642/// # Arguments
643/// * `file` - The path to the .pth file.
644/// * `verbose` - Whether to print debug information.
645/// * `key` - Optional key to retrieve `state_dict` from the pth file.
646pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
647    file: P,
648    verbose: bool,
649    key: Option<&str>,
650) -> Result<Vec<TensorInfo>> {
651    let file = std::fs::File::open(file)?;
652    let zip_reader = std::io::BufReader::new(file);
653    let mut zip = zip::ZipArchive::new(zip_reader)?;
654    let zip_file_names = zip
655        .file_names()
656        .map(|f| f.to_string())
657        .collect::<Vec<String>>();
658
659    let mut tensor_infos = vec![];
660    for file_name in zip_file_names.iter() {
661        if !file_name.ends_with("data.pkl") {
662            continue;
663        }
664        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
665        let reader = zip.by_name(file_name)?;
666        let mut reader = std::io::BufReader::new(reader);
667        let mut stack = Stack::empty();
668        stack.read_loop(&mut reader)?;
669        let obj = stack.finalize()?;
670        if VERBOSE || verbose {
671            println!("{obj:#?}");
672        }
673
674        let obj = match obj {
675            Object::Build { callable, args } => match *callable {
676                Object::Reduce { callable, args: _ } => match *callable {
677                    Object::Class {
678                        module_name,
679                        class_name,
680                    } if module_name == "__torch__" && class_name == "Module" => *args,
681                    _ => continue,
682                },
683                _ => continue,
684            },
685            obj => obj,
686        };
687
688        // If key is provided, then we need to extract the state_dict from the object.
689        let obj = if let Some(key) = key {
690            if let Object::Dict(key_values) = obj {
691                key_values
692                    .into_iter()
693                    .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
694                    .map(|(_, v)| v)
695                    .ok_or_else(|| E::Msg(format!("key {key} not found")))?
696            } else {
697                obj
698            }
699        } else {
700            obj
701        };
702
703        // If the object is a dict, then we can extract the tensor info from it.
704        // NOTE: We are assuming that the `obj` is state_dict by this stage.
705        if let Object::Dict(key_values) = obj {
706            for (name, value) in key_values.into_iter() {
707                match value.into_tensor_info(name, &dir_name) {
708                    Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
709                    Ok(None) => {}
710                    Err(err) => eprintln!("skipping: {err:?}"),
711                }
712            }
713        }
714    }
715    Ok(tensor_infos)
716}
717
/// Lazy tensor loader.
///
/// The tensor descriptions are read once when the loader is created; the data of a tensor is
/// only read from the file when that tensor is requested with `get`.
pub struct PthTensors {
    // Tensor descriptions indexed by tensor name.
    tensor_infos: HashMap<String, TensorInfo>,
    // Path of the .pth file, re-opened on every `get` call.
    path: std::path::PathBuf,
    // We do not store a zip reader as it needs mutable access to extract data. Instead we
    // re-create a zip reader for each tensor.
}
725
726impl PthTensors {
727    pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
728        let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
729        let tensor_infos = tensor_infos
730            .into_iter()
731            .map(|ti| (ti.name.to_string(), ti))
732            .collect();
733        let path = path.as_ref().to_owned();
734        Ok(Self { tensor_infos, path })
735    }
736
737    pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
738        &self.tensor_infos
739    }
740
741    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
742        use std::io::Read;
743        let tensor_info = match self.tensor_infos.get(name) {
744            None => return Ok(None),
745            Some(tensor_info) => tensor_info,
746        };
747        // We hope that the file has not changed since first reading it.
748        let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
749        let mut zip = zip::ZipArchive::new(zip_reader)?;
750        let mut reader = zip.by_name(&tensor_info.path)?;
751        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
752        let rank = tensor_info.layout.shape().rank();
753
754        // Reading the data is a bit tricky as it can be strided, for now only support the basic
755        // case and when the tensor is fortran contiguous.
756        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
757            crate::bail!(
758                "cannot retrieve non-contiguous tensors {:?}",
759                tensor_info.layout
760            )
761        }
762        let start_offset = tensor_info.layout.start_offset();
763        if start_offset > 0 {
764            std::io::copy(
765                &mut reader.by_ref().take(start_offset as u64),
766                &mut std::io::sink(),
767            )?;
768        }
769        let tensor = Tensor::from_reader(
770            tensor_info.layout.shape().clone(),
771            tensor_info.dtype,
772            &mut reader,
773        )?;
774
775        if rank > 1 && is_fortran_contiguous {
776            // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
777            let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
778            let tensor = tensor.reshape(shape_reversed)?;
779
780            // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
781            let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
782            let tensor = tensor.permute(dim_indeces_reversed)?;
783            Ok(Some(tensor))
784        } else {
785            Ok(Some(tensor))
786        }
787    }
788}
789
790/// Read all the tensors from a PyTorch pth file with a given key.
791///
792/// # Arguments
793/// * `path` - Path to the pth file.
794/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
795///           contains multiple objects and the state_dict is the one we are interested in.
796pub fn read_all_with_key<P: AsRef<std::path::Path>>(
797    path: P,
798    key: Option<&str>,
799) -> Result<Vec<(String, Tensor)>> {
800    let pth = PthTensors::new(path, key)?;
801    let tensor_names = pth.tensor_infos.keys();
802    let mut tensors = Vec::with_capacity(tensor_names.len());
803    for name in tensor_names {
804        if let Some(tensor) = pth.get(name)? {
805            tensors.push((name.to_string(), tensor))
806        }
807    }
808    Ok(tensors)
809}
810
/// Read all the tensors from a PyTorch pth file.
///
/// This is `read_all_with_key` called without a key.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    read_all_with_key(path, None)
}
817}