candle_core/
pickle.rs

1//! Just enough pickle support to be able to read PyTorch checkpoints.
2// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
3// composable/tensor agnostic at some point.
4use crate::{Context, DType, Error as E, Layout, Result, Tensor};
5use byteorder::{LittleEndian, ReadBytesExt};
6use std::collections::HashMap;
7use std::io::BufRead;
8
// Compile-time debug switch: when `true` the unpickler prints the protocol version it reads
// and `read_pth_tensor_info` dumps every decoded object, whatever its `verbose` argument is.
const VERBOSE: bool = false;
10
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
/// The subset of pickle op-codes understood by the unpickler in this module.
///
/// The discriminant of each variant is the byte used for the op-code in the pickle stream.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
    Proto = 0x80,
    Global = b'c',
    BinPut = b'q',
    LongBinPut = b'r',
    EmptyTuple = b')',
    Reduce = b'R',
    Mark = b'(',
    BinUnicode = b'X',
    BinInt = b'J',
    Tuple = b't',
    BinPersId = b'Q',
    BinInt1 = b'K',
    BinInt2 = b'M',
    Tuple1 = 0x85,
    Tuple2 = 0x86,
    Tuple3 = 0x87,
    NewTrue = 0x88,
    NewFalse = 0x89,
    None = b'N',
    BinGet = b'h',
    LongBinGet = b'j',
    SetItem = b's',
    SetItems = b'u',
    EmptyDict = b'}',
    Dict = b'd',
    Build = b'b',
    Stop = b'.',
    NewObj = 0x81,
    EmptyList = b']',
    BinFloat = b'G',
    Append = b'a',
    Appends = b'e',
    Long1 = 0x8a,
}

// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
    /// The unrecognised byte is handed back unchanged.
    type Error = u8;

    /// Maps a raw byte of the pickle stream to the matching op-code.
    ///
    /// Every arm mirrors the discriminant declared on the enum; in particular `b'd'` is
    /// `Dict` (build a dict from the items above the mark) and not `EmptyDict`.
    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            0x80 => Ok(Self::Proto),
            b'c' => Ok(Self::Global),
            b'q' => Ok(Self::BinPut),
            b'r' => Ok(Self::LongBinPut),
            b')' => Ok(Self::EmptyTuple),
            b'R' => Ok(Self::Reduce),
            b'(' => Ok(Self::Mark),
            b'X' => Ok(Self::BinUnicode),
            b'J' => Ok(Self::BinInt),
            b't' => Ok(Self::Tuple),
            b'Q' => Ok(Self::BinPersId),
            b'K' => Ok(Self::BinInt1),
            b'M' => Ok(Self::BinInt2),
            b'N' => Ok(Self::None),
            0x85 => Ok(Self::Tuple1),
            0x86 => Ok(Self::Tuple2),
            0x87 => Ok(Self::Tuple3),
            0x88 => Ok(Self::NewTrue),
            0x89 => Ok(Self::NewFalse),
            b'h' => Ok(Self::BinGet),
            b'j' => Ok(Self::LongBinGet),
            b's' => Ok(Self::SetItem),
            b'u' => Ok(Self::SetItems),
            b'}' => Ok(Self::EmptyDict),
            b'd' => Ok(Self::Dict),
            b'b' => Ok(Self::Build),
            b'.' => Ok(Self::Stop),
            0x81 => Ok(Self::NewObj),
            b']' => Ok(Self::EmptyList),
            b'G' => Ok(Self::BinFloat),
            b'a' => Ok(Self::Append),
            b'e' => Ok(Self::Appends),
            0x8a => Ok(Self::Long1),
            value => Err(value),
        }
    }
}
93
/// Reads bytes from `r` up to and including the next `\n` and returns them without the line
/// terminator (`\n` or `\r\n`).
///
/// If the stream ends before a newline is found, the bytes read so far are returned as they
/// are: nothing is stripped, so no payload byte is lost on a missing terminator.
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
    let mut data: Vec<u8> = Vec::with_capacity(32);
    r.read_until(b'\n', &mut data)?;
    // `read_until` keeps the delimiter when it finds one; only strip what is really there.
    if data.last() == Some(&b'\n') {
        data.pop();
        if data.last() == Some(&b'\r') {
            data.pop();
        }
    }
    Ok(data)
}
103
/// A value produced while interpreting a pickle stream.
///
/// Python objects are not instantiated: calls and state updates are recorded symbolically
/// (`Reduce`, `Build`, `PersistentLoad`) so that later code can pattern-match on them.
#[derive(Clone, Debug, PartialEq)]
pub enum Object {
    /// A global looked up by module and name (pushed by the `Global` op-code).
    Class {
        module_name: String,
        class_name: String,
    },
    /// Integer read by one of the `BinInt*` op-codes.
    Int(i32),
    /// Integer read by the `Long1` op-code.
    Long(i64),
    Float(f64),
    Unicode(String),
    Bool(bool),
    None,
    Tuple(Vec<Object>),
    List(Vec<Object>),
    /// Stack marker delimiting the operands of the variable-length op-codes.
    Mark,
    /// A dictionary kept as a list of `(key, value)` pairs.
    Dict(Vec<(Object, Object)>),
    /// A recorded call of `callable` with `args`.
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// A recorded `Build` op-code: `args` is the state applied to `callable`.
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// A persistent id wrapped by the `BinPersId` op-code.
    PersistentLoad(Box<Object>),
}
130
131type OResult<T> = std::result::Result<T, Object>;
132
133impl Object {
134    pub fn unicode(self) -> OResult<String> {
135        match self {
136            Self::Unicode(t) => Ok(t),
137            _ => Err(self),
138        }
139    }
140
141    pub fn reduce(self) -> OResult<(Self, Self)> {
142        match self {
143            Self::Reduce { callable, args } => Ok((*callable, *args)),
144            _ => Err(self),
145        }
146    }
147
148    pub fn none(self) -> OResult<()> {
149        match self {
150            Self::None => Ok(()),
151            _ => Err(self),
152        }
153    }
154
155    pub fn persistent_load(self) -> OResult<Self> {
156        match self {
157            Self::PersistentLoad(t) => Ok(*t),
158            _ => Err(self),
159        }
160    }
161
162    pub fn bool(self) -> OResult<bool> {
163        match self {
164            Self::Bool(t) => Ok(t),
165            _ => Err(self),
166        }
167    }
168
169    pub fn int(self) -> OResult<i32> {
170        match self {
171            Self::Int(t) => Ok(t),
172            _ => Err(self),
173        }
174    }
175
176    pub fn int_or_long(self) -> OResult<i64> {
177        match self {
178            Self::Int(t) => Ok(t as i64),
179            Self::Long(t) => Ok(t),
180            _ => Err(self),
181        }
182    }
183
184    pub fn tuple(self) -> OResult<Vec<Self>> {
185        match self {
186            Self::Tuple(t) => Ok(t),
187            _ => Err(self),
188        }
189    }
190
191    pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
192        match self {
193            Self::Dict(t) => Ok(t),
194            _ => Err(self),
195        }
196    }
197
198    pub fn class(self) -> OResult<(String, String)> {
199        match self {
200            Self::Class {
201                module_name,
202                class_name,
203            } => Ok((module_name, class_name)),
204            _ => Err(self),
205        }
206    }
207
208    pub fn into_tensor_info(
209        self,
210        name: Self,
211        dir_name: &std::path::Path,
212    ) -> Result<Option<TensorInfo>> {
213        let name = match name.unicode() {
214            Ok(name) => name,
215            Err(_) => return Ok(None),
216        };
217        let (callable, args) = match self.reduce() {
218            Ok(callable_args) => callable_args,
219            _ => return Ok(None),
220        };
221        let (callable, args) = match callable {
222            Object::Class {
223                module_name,
224                class_name,
225            } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
226                let mut args = args.tuple()?;
227                let callable = args.remove(0);
228                let args = args.remove(1);
229                (callable, args)
230            }
231            Object::Class {
232                module_name,
233                class_name,
234            } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
235                let mut args = args.tuple()?;
236                args.remove(0).reduce()?
237            }
238            _ => (callable, args),
239        };
240        match callable {
241            Object::Class {
242                module_name,
243                class_name,
244            } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
245            _ => return Ok(None),
246        };
247        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
248        Ok(Some(TensorInfo {
249            name,
250            dtype,
251            layout,
252            path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
253            storage_size,
254        }))
255    }
256}
257
258impl TryFrom<Object> for String {
259    type Error = Object;
260    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
261        match value {
262            Object::Unicode(s) => Ok(s),
263            other => Err(other),
264        }
265    }
266}
267
268impl TryFrom<Object> for usize {
269    type Error = Object;
270    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
271        match value {
272            Object::Int(s) if s >= 0 => Ok(s as usize),
273            other => Err(other),
274        }
275    }
276}
277
278impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
279    type Error = Object;
280    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
281        match value {
282            Object::Tuple(values) => {
283                // This does not return the appropriate value in the error case but instead return
284                // the object related to the first error.
285                values
286                    .into_iter()
287                    .map(|v| T::try_from(v))
288                    .collect::<std::result::Result<Vec<T>, Self::Error>>()
289            }
290            other => Err(other),
291        }
292    }
293}
294
295#[derive(Debug)]
296pub struct Stack {
297    stack: Vec<Object>,
298    memo: HashMap<u32, Object>,
299}
300
301impl Stack {
302    pub fn empty() -> Self {
303        Self {
304            stack: Vec::with_capacity(512),
305            memo: HashMap::new(),
306        }
307    }
308
309    pub fn stack(&self) -> &[Object] {
310        self.stack.as_slice()
311    }
312
313    pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
314        loop {
315            if self.read(r)? {
316                break;
317            }
318        }
319        Ok(())
320    }
321
322    pub fn finalize(mut self) -> Result<Object> {
323        self.pop()
324    }
325
326    fn push(&mut self, obj: Object) {
327        self.stack.push(obj)
328    }
329
330    fn pop(&mut self) -> Result<Object> {
331        match self.stack.pop() {
332            None => crate::bail!("unexpected empty stack"),
333            Some(obj) => Ok(obj),
334        }
335    }
336
337    // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
338    fn build(&mut self) -> Result<()> {
339        let args = self.pop()?;
340        let obj = self.pop()?;
341        let obj = match (obj, args) {
342            (Object::Dict(mut obj), Object::Dict(mut args)) => {
343                obj.append(&mut args);
344                Object::Dict(obj)
345            }
346            (obj, args) => Object::Build {
347                callable: Box::new(obj),
348                args: Box::new(args),
349            },
350        };
351        self.push(obj);
352        Ok(())
353    }
354
355    fn reduce(&mut self) -> Result<()> {
356        let args = self.pop()?;
357        let callable = self.pop()?;
358        #[allow(clippy::single_match)]
359        let reduced = match &callable {
360            Object::Class {
361                module_name,
362                class_name,
363            } => {
364                if module_name == "collections"
365                    && (class_name == "OrderedDict" || class_name == "defaultdict")
366                {
367                    // TODO: have a separate ordered dict and a separate default dict.
368                    Some(Object::Dict(vec![]))
369                } else {
370                    None
371                }
372            }
373            _ => None,
374        };
375        let reduced = reduced.unwrap_or_else(|| Object::Reduce {
376            callable: Box::new(callable),
377            args: Box::new(args),
378        });
379        self.push(reduced);
380        Ok(())
381    }
382
383    fn last(&mut self) -> Result<&mut Object> {
384        match self.stack.last_mut() {
385            None => crate::bail!("unexpected empty stack"),
386            Some(obj) => Ok(obj),
387        }
388    }
389
390    fn memo_get(&self, id: u32) -> Result<Object> {
391        match self.memo.get(&id) {
392            None => crate::bail!("missing object in memo {id}"),
393            Some(obj) => {
394                // Maybe we should use refcounting rather than doing potential large clones here.
395                Ok(obj.clone())
396            }
397        }
398    }
399
400    fn memo_put(&mut self, id: u32) -> Result<()> {
401        let obj = self.last()?.clone();
402        self.memo.insert(id, obj);
403        Ok(())
404    }
405
406    fn persistent_load(&self, id: Object) -> Result<Object> {
407        Ok(Object::PersistentLoad(Box::new(id)))
408    }
409
410    fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
411        Ok(Object::Reduce {
412            callable: Box::new(class),
413            args: Box::new(args),
414        })
415    }
416
417    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
418        let mut mark_idx = None;
419        for (idx, obj) in self.stack.iter().enumerate().rev() {
420            if obj == &Object::Mark {
421                mark_idx = Some(idx);
422                break;
423            }
424        }
425        match mark_idx {
426            Some(mark_idx) => {
427                let objs = self.stack.split_off(mark_idx + 1);
428                self.stack.pop();
429                Ok(objs)
430            }
431            None => {
432                crate::bail!("marker object not found")
433            }
434        }
435    }
436
437    pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
438        let op_code = match OpCode::try_from(r.read_u8()?) {
439            Ok(op_code) => op_code,
440            Err(op_code) => {
441                crate::bail!("unknown op-code {op_code}")
442            }
443        };
444        // println!("op: {op_code:?}");
445        // println!("{:?}", self.stack);
446        match op_code {
447            OpCode::Proto => {
448                let version = r.read_u8()?;
449                if VERBOSE {
450                    println!("proto {version}");
451                }
452            }
453            OpCode::Global => {
454                let module_name = read_to_newline(r)?;
455                let class_name = read_to_newline(r)?;
456                let module_name = String::from_utf8_lossy(&module_name).to_string();
457                let class_name = String::from_utf8_lossy(&class_name).to_string();
458                self.push(Object::Class {
459                    module_name,
460                    class_name,
461                })
462            }
463            OpCode::BinInt1 => {
464                let arg = r.read_u8()?;
465                self.push(Object::Int(arg as i32))
466            }
467            OpCode::BinInt2 => {
468                let arg = r.read_u16::<LittleEndian>()?;
469                self.push(Object::Int(arg as i32))
470            }
471            OpCode::BinInt => {
472                let arg = r.read_i32::<LittleEndian>()?;
473                self.push(Object::Int(arg))
474            }
475            OpCode::BinFloat => {
476                // Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
477                // https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
478                // https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
479                let arg = r.read_f64::<byteorder::BigEndian>()?;
480                self.push(Object::Float(arg))
481            }
482            OpCode::BinUnicode => {
483                let len = r.read_u32::<LittleEndian>()?;
484                let mut data = vec![0u8; len as usize];
485                r.read_exact(&mut data)?;
486                let data = String::from_utf8(data).map_err(E::wrap)?;
487                self.push(Object::Unicode(data))
488            }
489            OpCode::BinPersId => {
490                let id = self.pop()?;
491                let obj = self.persistent_load(id)?;
492                self.push(obj)
493            }
494            OpCode::Tuple => {
495                let objs = self.pop_to_marker()?;
496                self.push(Object::Tuple(objs))
497            }
498            OpCode::Tuple1 => {
499                let obj = self.pop()?;
500                self.push(Object::Tuple(vec![obj]))
501            }
502            OpCode::Tuple2 => {
503                let obj2 = self.pop()?;
504                let obj1 = self.pop()?;
505                self.push(Object::Tuple(vec![obj1, obj2]))
506            }
507            OpCode::Tuple3 => {
508                let obj3 = self.pop()?;
509                let obj2 = self.pop()?;
510                let obj1 = self.pop()?;
511                self.push(Object::Tuple(vec![obj1, obj2, obj3]))
512            }
513            OpCode::NewTrue => self.push(Object::Bool(true)),
514            OpCode::NewFalse => self.push(Object::Bool(false)),
515            OpCode::Append => {
516                let value = self.pop()?;
517                let pylist = self.last()?;
518                if let Object::List(d) = pylist {
519                    d.push(value)
520                } else {
521                    crate::bail!("expected a list, got {pylist:?}")
522                }
523            }
524            OpCode::Appends => {
525                let objs = self.pop_to_marker()?;
526                let pylist = self.last()?;
527                if let Object::List(d) = pylist {
528                    d.extend(objs)
529                } else {
530                    crate::bail!("expected a list, got {pylist:?}")
531                }
532            }
533            OpCode::SetItem => {
534                let value = self.pop()?;
535                let key = self.pop()?;
536                let pydict = self.last()?;
537                if let Object::Dict(d) = pydict {
538                    d.push((key, value))
539                } else {
540                    crate::bail!("expected a dict, got {pydict:?}")
541                }
542            }
543            OpCode::SetItems => {
544                let mut objs = self.pop_to_marker()?;
545                let pydict = self.last()?;
546                if let Object::Dict(d) = pydict {
547                    if objs.len() % 2 != 0 {
548                        crate::bail!("setitems: not an even number of objects")
549                    }
550                    while let Some(value) = objs.pop() {
551                        let key = objs.pop().context("empty objs")?;
552                        d.push((key, value))
553                    }
554                } else {
555                    crate::bail!("expected a dict, got {pydict:?}")
556                }
557            }
558            OpCode::None => self.push(Object::None),
559            OpCode::Stop => {
560                return Ok(true);
561            }
562            OpCode::Build => self.build()?,
563            OpCode::EmptyDict => self.push(Object::Dict(vec![])),
564            OpCode::Dict => {
565                let mut objs = self.pop_to_marker()?;
566                let mut pydict = vec![];
567                if objs.len() % 2 != 0 {
568                    crate::bail!("setitems: not an even number of objects")
569                }
570                while let Some(value) = objs.pop() {
571                    let key = objs.pop().context("empty objs")?;
572                    pydict.push((key, value))
573                }
574                self.push(Object::Dict(pydict))
575            }
576            OpCode::Mark => self.push(Object::Mark),
577            OpCode::Reduce => self.reduce()?,
578            OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
579            OpCode::EmptyList => self.push(Object::List(vec![])),
580            OpCode::BinGet => {
581                let arg = r.read_u8()?;
582                let obj = self.memo_get(arg as u32)?;
583                self.push(obj)
584            }
585            OpCode::LongBinGet => {
586                let arg = r.read_u32::<LittleEndian>()?;
587                let obj = self.memo_get(arg)?;
588                self.push(obj)
589            }
590            OpCode::BinPut => {
591                let arg = r.read_u8()?;
592                self.memo_put(arg as u32)?
593            }
594            OpCode::LongBinPut => {
595                let arg = r.read_u32::<LittleEndian>()?;
596                self.memo_put(arg)?
597            }
598            OpCode::NewObj => {
599                let args = self.pop()?;
600                let class = self.pop()?;
601                let obj = self.new_obj(class, args)?;
602                self.push(obj)
603            }
604            OpCode::Long1 => {
605                let n_bytes = r.read_u8()?;
606                let mut v = 0;
607                // Decode the next n bytes in little endian
608                for i in 0..n_bytes {
609                    v |= (r.read_u8()? as i64) << (i * 8);
610                }
611                self.push(Object::Long(v))
612            }
613        }
614        Ok(false)
615    }
616}
617
618impl From<Object> for E {
619    fn from(value: Object) -> Self {
620        E::Msg(format!("conversion error on {value:?}"))
621    }
622}
623
624// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
625// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
626fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
627    let mut args = args.tuple()?;
628    let stride = Vec::<usize>::try_from(args.remove(3))?;
629    let size = Vec::<usize>::try_from(args.remove(2))?;
630    let offset = args.remove(1).int_or_long()? as usize;
631    let storage = args.remove(0).persistent_load()?;
632    let mut storage = storage.tuple()?;
633    let storage_size = storage.remove(4).int_or_long()? as usize;
634    let path = storage.remove(2).unicode()?;
635    let (_module_name, class_name) = storage.remove(1).class()?;
636    let dtype = match class_name.as_str() {
637        "FloatStorage" => DType::F32,
638        "DoubleStorage" => DType::F64,
639        "HalfStorage" => DType::F16,
640        "BFloat16Storage" => DType::BF16,
641        "ByteStorage" => DType::U8,
642        "LongStorage" => DType::I64,
643        other => {
644            crate::bail!("unsupported storage type {other}")
645        }
646    };
647    let layout = Layout::new(
648        crate::Shape::from(size),
649        stride,
650        offset * dtype.size_in_bytes(),
651    );
652    Ok((layout, dtype, path, storage_size))
653}
654
655#[derive(Debug, Clone)]
656pub struct TensorInfo {
657    pub name: String,
658    pub dtype: DType,
659    pub layout: Layout,
660    pub path: String,
661    pub storage_size: usize,
662}
663
664/// Read the tensor info from a .pth file.
665///
666/// # Arguments
667/// * `file` - The path to the .pth file.
668/// * `verbose` - Whether to print debug information.
669/// * `key` - Optional key to retrieve `state_dict` from the pth file.
670pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
671    file: P,
672    verbose: bool,
673    key: Option<&str>,
674) -> Result<Vec<TensorInfo>> {
675    let file = std::fs::File::open(file)?;
676    let zip_reader = std::io::BufReader::new(file);
677    let mut zip = zip::ZipArchive::new(zip_reader)?;
678    let zip_file_names = zip
679        .file_names()
680        .map(|f| f.to_string())
681        .collect::<Vec<String>>();
682
683    let mut tensor_infos = vec![];
684    for file_name in zip_file_names.iter() {
685        if !file_name.ends_with("data.pkl") {
686            continue;
687        }
688        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
689        let reader = zip.by_name(file_name)?;
690        let mut reader = std::io::BufReader::new(reader);
691        let mut stack = Stack::empty();
692        stack.read_loop(&mut reader)?;
693        let obj = stack.finalize()?;
694        if VERBOSE || verbose {
695            println!("{obj:#?}");
696        }
697
698        let obj = match obj {
699            Object::Build { callable, args } => match *callable {
700                Object::Reduce { callable, args: _ } => match *callable {
701                    Object::Class {
702                        module_name,
703                        class_name,
704                    } if module_name == "__torch__" && class_name == "Module" => *args,
705                    _ => continue,
706                },
707                _ => continue,
708            },
709            obj => obj,
710        };
711
712        // If key is provided, then we need to extract the state_dict from the object.
713        let obj = if let Some(key) = key {
714            if let Object::Dict(key_values) = obj {
715                key_values
716                    .into_iter()
717                    .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
718                    .map(|(_, v)| v)
719                    .ok_or_else(|| E::Msg(format!("key {key} not found")))?
720            } else {
721                obj
722            }
723        } else {
724            obj
725        };
726
727        // If the object is a dict, then we can extract the tensor info from it.
728        // NOTE: We are assuming that the `obj` is state_dict by this stage.
729        if let Object::Dict(key_values) = obj {
730            for (name, value) in key_values.into_iter() {
731                match value.into_tensor_info(name, &dir_name) {
732                    Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
733                    Ok(None) => {}
734                    Err(err) => eprintln!("skipping: {err:?}"),
735                }
736            }
737        }
738    }
739    Ok(tensor_infos)
740}
741
742/// Lazy tensor loader.
743pub struct PthTensors {
744    tensor_infos: HashMap<String, TensorInfo>,
745    path: std::path::PathBuf,
746    // We do not store a zip reader as it needs mutable access to extract data. Instead we
747    // re-create a zip reader for each tensor.
748}
749
750impl PthTensors {
751    pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
752        let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
753        let tensor_infos = tensor_infos
754            .into_iter()
755            .map(|ti| (ti.name.to_string(), ti))
756            .collect();
757        let path = path.as_ref().to_owned();
758        Ok(Self { tensor_infos, path })
759    }
760
761    pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
762        &self.tensor_infos
763    }
764
765    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
766        use std::io::Read;
767        let tensor_info = match self.tensor_infos.get(name) {
768            None => return Ok(None),
769            Some(tensor_info) => tensor_info,
770        };
771        // We hope that the file has not changed since first reading it.
772        let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
773        let mut zip = zip::ZipArchive::new(zip_reader)?;
774        let mut reader = zip.by_name(&tensor_info.path)?;
775        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
776        let rank = tensor_info.layout.shape().rank();
777
778        // Reading the data is a bit tricky as it can be strided, for now only support the basic
779        // case and when the tensor is fortran contiguous.
780        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
781            crate::bail!(
782                "cannot retrieve non-contiguous tensors {:?}",
783                tensor_info.layout
784            )
785        }
786        let start_offset = tensor_info.layout.start_offset();
787        if start_offset > 0 {
788            std::io::copy(
789                &mut reader.by_ref().take(start_offset as u64),
790                &mut std::io::sink(),
791            )?;
792        }
793        let tensor = Tensor::from_reader(
794            tensor_info.layout.shape().clone(),
795            tensor_info.dtype,
796            &mut reader,
797        )?;
798
799        if rank > 1 && is_fortran_contiguous {
800            // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
801            let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
802            let tensor = tensor.reshape(shape_reversed)?;
803
804            // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
805            let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
806            let tensor = tensor.permute(dim_indeces_reversed)?;
807            Ok(Some(tensor))
808        } else {
809            Ok(Some(tensor))
810        }
811    }
812}
813
814/// Read all the tensors from a PyTorch pth file with a given key.
815///
816/// # Arguments
817/// * `path` - Path to the pth file.
818/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
819///   contains multiple objects and the state_dict is the one we are interested in.
820pub fn read_all_with_key<P: AsRef<std::path::Path>>(
821    path: P,
822    key: Option<&str>,
823) -> Result<Vec<(String, Tensor)>> {
824    let pth = PthTensors::new(path, key)?;
825    let tensor_names = pth.tensor_infos.keys();
826    let mut tensors = Vec::with_capacity(tensor_names.len());
827    for name in tensor_names {
828        if let Some(tensor) = pth.get(name)? {
829            tensors.push((name.to_string(), tensor))
830        }
831    }
832    Ok(tensors)
833}
834
835/// Read all the tensors from a PyTorch pth file.
836///
837/// # Arguments
838/// * `path` - Path to the pth file.
839pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
840    read_all_with_key(path, None)
841}