candle_core_temp/
pickle.rs

1// Just enough pickle support to be able to read PyTorch checkpoints.
2// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
3// composable/tensor agnostic at some point.
4use crate::{DType, Error as E, Layout, Result, Tensor};
5use byteorder::{LittleEndian, ReadBytesExt};
6use std::collections::HashMap;
7use std::io::BufRead;
8
/// Compile-time switch for debug tracing: when `true`, the unpickler prints the protocol version
/// it encounters and `read_pth_tensor_info` prints every decoded root object.
const VERBOSE: bool = false;
10
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
/// Subset of the pickle opcodes understood by this unpickler.
///
/// The discriminant of every variant is the byte that encodes the opcode in a pickle stream, and
/// `TryFrom<u8>` below is its exact inverse.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
    Proto = 0x80,
    Global = b'c',
    BinPut = b'q',
    LongBinPut = b'r',
    EmptyTuple = b')',
    Reduce = b'R',
    Mark = b'(',
    BinUnicode = b'X',
    BinInt = b'J',
    Tuple = b't',
    BinPersId = b'Q',
    BinInt1 = b'K',
    BinInt2 = b'M',
    Tuple1 = 0x85,
    Tuple2 = 0x86,
    Tuple3 = 0x87,
    NewTrue = 0x88,
    NewFalse = 0x89,
    None = b'N',
    BinGet = b'h',
    LongBinGet = b'j',
    SetItem = b's',
    SetItems = b'u',
    EmptyDict = b'}',
    Dict = b'd',
    Build = b'b',
    Stop = b'.',
    NewObj = 0x81,
    EmptyList = b']',
    // BINFLOAT is the uppercase 'G'; the lowercase 'g' is the text-mode GET opcode.
    BinFloat = b'G',
    Append = b'a',
    Appends = b'e',
}

// Avoid using FromPrimitive so as not to drag another dependency.
impl TryFrom<u8> for OpCode {
    /// The unsupported opcode byte is returned as the error.
    type Error = u8;

    /// Decodes one opcode byte. This must stay in sync with the discriminants above.
    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            0x80 => Ok(Self::Proto),
            b'c' => Ok(Self::Global),
            b'q' => Ok(Self::BinPut),
            b'r' => Ok(Self::LongBinPut),
            b')' => Ok(Self::EmptyTuple),
            b'R' => Ok(Self::Reduce),
            b'(' => Ok(Self::Mark),
            b'X' => Ok(Self::BinUnicode),
            b'J' => Ok(Self::BinInt),
            b't' => Ok(Self::Tuple),
            b'Q' => Ok(Self::BinPersId),
            b'K' => Ok(Self::BinInt1),
            b'M' => Ok(Self::BinInt2),
            b'N' => Ok(Self::None),
            0x85 => Ok(Self::Tuple1),
            0x86 => Ok(Self::Tuple2),
            0x87 => Ok(Self::Tuple3),
            0x88 => Ok(Self::NewTrue),
            0x89 => Ok(Self::NewFalse),
            b'h' => Ok(Self::BinGet),
            b'j' => Ok(Self::LongBinGet),
            b's' => Ok(Self::SetItem),
            b'u' => Ok(Self::SetItems),
            b'}' => Ok(Self::EmptyDict),
            // DICT builds a dict from the mark-delimited key/value pairs; it is not EMPTY_DICT.
            b'd' => Ok(Self::Dict),
            b'b' => Ok(Self::Build),
            b'.' => Ok(Self::Stop),
            0x81 => Ok(Self::NewObj),
            b']' => Ok(Self::EmptyList),
            b'G' => Ok(Self::BinFloat),
            b'a' => Ok(Self::Append),
            b'e' => Ok(Self::Appends),
            value => Err(value),
        }
    }
}
91
92fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
93    let mut data: Vec<u8> = Vec::with_capacity(32);
94    r.read_until(b'\n', &mut data)?;
95    data.pop();
96    if data.last() == Some(&b'\r') {
97        data.pop();
98    }
99    Ok(data)
100}
101
/// A value produced by the unpickler.
///
/// Only the shapes needed to recover PyTorch tensor metadata are modelled. Constructs the
/// unpickler does not evaluate are kept symbolically (`Reduce`, `Build`, `PersistentLoad`).
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
    /// A `module_name.class_name` reference pushed by the GLOBAL opcode.
    Class {
        module_name: String,
        class_name: String,
    },
    /// Integer pushed by BININT, BININT1 or BININT2.
    Int(i32),
    /// Float pushed by BINFLOAT.
    Float(f64),
    /// String pushed by BINUNICODE.
    Unicode(String),
    /// Boolean pushed by NEWTRUE / NEWFALSE.
    Bool(bool),
    /// Python `None`.
    None,
    /// Tuple built by TUPLE, TUPLE1/2/3 or EMPTY_TUPLE.
    Tuple(Vec<Object>),
    /// List built by EMPTY_LIST and filled by APPEND / APPENDS.
    List(Vec<Object>),
    /// MARK sentinel delimiting variable-length opcode arguments on the stack.
    Mark,
    /// Dictionary stored as a vector of `(key, value)` pairs.
    Dict(Vec<(Object, Object)>),
    /// An unevaluated `callable(*args)` from REDUCE or NEWOBJ.
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// An unevaluated BUILD (object plus state) that could not be folded into a dict.
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// A persistent id from BINPERSID; tensor storages are referenced this way.
    PersistentLoad(Box<Object>),
}

/// Result of the `Object` accessors: on a variant mismatch the untouched object is handed back.
type OResult<T> = std::result::Result<T, Object>;
129
130impl Object {
131    pub fn unicode(self) -> OResult<String> {
132        match self {
133            Self::Unicode(t) => Ok(t),
134            _ => Err(self),
135        }
136    }
137
138    pub fn reduce(self) -> OResult<(Self, Self)> {
139        match self {
140            Self::Reduce { callable, args } => Ok((*callable, *args)),
141            _ => Err(self),
142        }
143    }
144
145    pub fn none(self) -> OResult<()> {
146        match self {
147            Self::None => Ok(()),
148            _ => Err(self),
149        }
150    }
151
152    pub fn persistent_load(self) -> OResult<Self> {
153        match self {
154            Self::PersistentLoad(t) => Ok(*t),
155            _ => Err(self),
156        }
157    }
158
159    pub fn bool(self) -> OResult<bool> {
160        match self {
161            Self::Bool(t) => Ok(t),
162            _ => Err(self),
163        }
164    }
165
166    pub fn int(self) -> OResult<i32> {
167        match self {
168            Self::Int(t) => Ok(t),
169            _ => Err(self),
170        }
171    }
172
173    pub fn tuple(self) -> OResult<Vec<Self>> {
174        match self {
175            Self::Tuple(t) => Ok(t),
176            _ => Err(self),
177        }
178    }
179
180    pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
181        match self {
182            Self::Dict(t) => Ok(t),
183            _ => Err(self),
184        }
185    }
186
187    pub fn class(self) -> OResult<(String, String)> {
188        match self {
189            Self::Class {
190                module_name,
191                class_name,
192            } => Ok((module_name, class_name)),
193            _ => Err(self),
194        }
195    }
196
197    pub fn into_tensor_info(
198        self,
199        name: Self,
200        dir_name: &std::path::Path,
201    ) -> Result<Option<TensorInfo>> {
202        let name = match name.unicode() {
203            Ok(name) => name,
204            Err(_) => return Ok(None),
205        };
206        let (callable, args) = match self.reduce() {
207            Ok(callable_args) => callable_args,
208            _ => return Ok(None),
209        };
210        let (callable, args) = match callable {
211            Object::Class {
212                module_name,
213                class_name,
214            } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
215                let mut args = args.tuple()?;
216                let callable = args.remove(0);
217                let args = args.remove(1);
218                (callable, args)
219            }
220            _ => (callable, args),
221        };
222        match callable {
223            Object::Class {
224                module_name,
225                class_name,
226            } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
227            _ => return Ok(None),
228        };
229        let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
230        let mut path = dir_name.to_path_buf();
231        path.push(file_path);
232        Ok(Some(TensorInfo {
233            name,
234            dtype,
235            layout,
236            path: path.to_string_lossy().into_owned(),
237            storage_size,
238        }))
239    }
240}
241
242impl TryFrom<Object> for String {
243    type Error = Object;
244    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
245        match value {
246            Object::Unicode(s) => Ok(s),
247            other => Err(other),
248        }
249    }
250}
251
252impl TryFrom<Object> for usize {
253    type Error = Object;
254    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
255        match value {
256            Object::Int(s) if s >= 0 => Ok(s as usize),
257            other => Err(other),
258        }
259    }
260}
261
262impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
263    type Error = Object;
264    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
265        match value {
266            Object::Tuple(values) => {
267                // This does not return the appropriate value in the error case but instead return
268                // the object related to the first error.
269                values
270                    .into_iter()
271                    .map(|v| T::try_from(v))
272                    .collect::<std::result::Result<Vec<T>, Self::Error>>()
273            }
274            other => Err(other),
275        }
276    }
277}
278
279#[derive(Debug)]
280pub struct Stack {
281    stack: Vec<Object>,
282    memo: HashMap<u32, Object>,
283}
284
285impl Stack {
286    pub fn empty() -> Self {
287        Self {
288            stack: Vec::with_capacity(512),
289            memo: HashMap::new(),
290        }
291    }
292
293    pub fn stack(&self) -> &[Object] {
294        self.stack.as_slice()
295    }
296
297    pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
298        loop {
299            if self.read(r)? {
300                break;
301            }
302        }
303        Ok(())
304    }
305
306    pub fn finalize(mut self) -> Result<Object> {
307        self.pop()
308    }
309
310    fn push(&mut self, obj: Object) {
311        self.stack.push(obj)
312    }
313
314    fn pop(&mut self) -> Result<Object> {
315        match self.stack.pop() {
316            None => crate::bail!("unexpected empty stack"),
317            Some(obj) => Ok(obj),
318        }
319    }
320
321    // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
322    fn build(&mut self) -> Result<()> {
323        let args = self.pop()?;
324        let obj = self.pop()?;
325        let obj = match (obj, args) {
326            (Object::Dict(mut obj), Object::Dict(mut args)) => {
327                obj.append(&mut args);
328                Object::Dict(obj)
329            }
330            (obj, args) => Object::Build {
331                callable: Box::new(obj),
332                args: Box::new(args),
333            },
334        };
335        self.push(obj);
336        Ok(())
337    }
338
339    fn reduce(&mut self) -> Result<()> {
340        let args = self.pop()?;
341        let callable = self.pop()?;
342        #[allow(clippy::single_match)]
343        let reduced = match &callable {
344            Object::Class {
345                module_name,
346                class_name,
347            } => {
348                if module_name == "collections" && class_name == "OrderedDict" {
349                    // TODO: have a separate ordered dict.
350                    Some(Object::Dict(vec![]))
351                } else {
352                    None
353                }
354            }
355            _ => None,
356        };
357        let reduced = reduced.unwrap_or_else(|| Object::Reduce {
358            callable: Box::new(callable),
359            args: Box::new(args),
360        });
361        self.push(reduced);
362        Ok(())
363    }
364
365    fn last(&mut self) -> Result<&mut Object> {
366        match self.stack.last_mut() {
367            None => crate::bail!("unexpected empty stack"),
368            Some(obj) => Ok(obj),
369        }
370    }
371
372    fn memo_get(&self, id: u32) -> Result<Object> {
373        match self.memo.get(&id) {
374            None => crate::bail!("missing object in memo {id}"),
375            Some(obj) => {
376                // Maybe we should use refcounting rather than doing potential large clones here.
377                Ok(obj.clone())
378            }
379        }
380    }
381
382    fn memo_put(&mut self, id: u32) -> Result<()> {
383        let obj = self.last()?.clone();
384        self.memo.insert(id, obj);
385        Ok(())
386    }
387
388    fn persistent_load(&self, id: Object) -> Result<Object> {
389        Ok(Object::PersistentLoad(Box::new(id)))
390    }
391
392    fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
393        Ok(Object::Reduce {
394            callable: Box::new(class),
395            args: Box::new(args),
396        })
397    }
398
399    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
400        let mut mark_idx = None;
401        for (idx, obj) in self.stack.iter().enumerate().rev() {
402            if obj == &Object::Mark {
403                mark_idx = Some(idx);
404                break;
405            }
406        }
407        match mark_idx {
408            Some(mark_idx) => {
409                let objs = self.stack.split_off(mark_idx + 1);
410                self.stack.pop();
411                Ok(objs)
412            }
413            None => {
414                crate::bail!("marker object not found")
415            }
416        }
417    }
418
419    pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
420        let op_code = match OpCode::try_from(r.read_u8()?) {
421            Ok(op_code) => op_code,
422            Err(op_code) => {
423                crate::bail!("unknown op-code {op_code}")
424            }
425        };
426        // println!("op: {op_code:?}");
427        // println!("{:?}", self.stack);
428        match op_code {
429            OpCode::Proto => {
430                let version = r.read_u8()?;
431                if VERBOSE {
432                    println!("proto {version}");
433                }
434            }
435            OpCode::Global => {
436                let module_name = read_to_newline(r)?;
437                let class_name = read_to_newline(r)?;
438                let module_name = String::from_utf8_lossy(&module_name).to_string();
439                let class_name = String::from_utf8_lossy(&class_name).to_string();
440                self.push(Object::Class {
441                    module_name,
442                    class_name,
443                })
444            }
445            OpCode::BinInt1 => {
446                let arg = r.read_u8()?;
447                self.push(Object::Int(arg as i32))
448            }
449            OpCode::BinInt2 => {
450                let arg = r.read_u16::<LittleEndian>()?;
451                self.push(Object::Int(arg as i32))
452            }
453            OpCode::BinInt => {
454                let arg = r.read_i32::<LittleEndian>()?;
455                self.push(Object::Int(arg))
456            }
457            OpCode::BinFloat => {
458                let arg = r.read_f64::<LittleEndian>()?;
459                self.push(Object::Float(arg))
460            }
461            OpCode::BinUnicode => {
462                let len = r.read_u32::<LittleEndian>()?;
463                let mut data = vec![0u8; len as usize];
464                r.read_exact(&mut data)?;
465                let data = String::from_utf8(data).map_err(E::wrap)?;
466                self.push(Object::Unicode(data))
467            }
468            OpCode::BinPersId => {
469                let id = self.pop()?;
470                let obj = self.persistent_load(id)?;
471                self.push(obj)
472            }
473            OpCode::Tuple => {
474                let objs = self.pop_to_marker()?;
475                self.push(Object::Tuple(objs))
476            }
477            OpCode::Tuple1 => {
478                let obj = self.pop()?;
479                self.push(Object::Tuple(vec![obj]))
480            }
481            OpCode::Tuple2 => {
482                let obj2 = self.pop()?;
483                let obj1 = self.pop()?;
484                self.push(Object::Tuple(vec![obj1, obj2]))
485            }
486            OpCode::Tuple3 => {
487                let obj3 = self.pop()?;
488                let obj2 = self.pop()?;
489                let obj1 = self.pop()?;
490                self.push(Object::Tuple(vec![obj1, obj2, obj3]))
491            }
492            OpCode::NewTrue => self.push(Object::Bool(true)),
493            OpCode::NewFalse => self.push(Object::Bool(false)),
494            OpCode::Append => {
495                let value = self.pop()?;
496                let pylist = self.last()?;
497                if let Object::List(d) = pylist {
498                    d.push(value)
499                } else {
500                    crate::bail!("expected a list, got {pylist:?}")
501                }
502            }
503            OpCode::Appends => {
504                let objs = self.pop_to_marker()?;
505                let pylist = self.last()?;
506                if let Object::List(d) = pylist {
507                    d.extend(objs)
508                } else {
509                    crate::bail!("expected a list, got {pylist:?}")
510                }
511            }
512            OpCode::SetItem => {
513                let value = self.pop()?;
514                let key = self.pop()?;
515                let pydict = self.last()?;
516                if let Object::Dict(d) = pydict {
517                    d.push((key, value))
518                } else {
519                    crate::bail!("expected a dict, got {pydict:?}")
520                }
521            }
522            OpCode::SetItems => {
523                let mut objs = self.pop_to_marker()?;
524                let pydict = self.last()?;
525                if let Object::Dict(d) = pydict {
526                    if objs.len() % 2 != 0 {
527                        crate::bail!("setitems: not an even number of objects")
528                    }
529                    while let Some(value) = objs.pop() {
530                        let key = objs.pop().unwrap();
531                        d.push((key, value))
532                    }
533                } else {
534                    crate::bail!("expected a dict, got {pydict:?}")
535                }
536            }
537            OpCode::None => self.push(Object::None),
538            OpCode::Stop => {
539                return Ok(true);
540            }
541            OpCode::Build => self.build()?,
542            OpCode::EmptyDict => self.push(Object::Dict(vec![])),
543            OpCode::Dict => {
544                let mut objs = self.pop_to_marker()?;
545                let mut pydict = vec![];
546                if objs.len() % 2 != 0 {
547                    crate::bail!("setitems: not an even number of objects")
548                }
549                while let Some(value) = objs.pop() {
550                    let key = objs.pop().unwrap();
551                    pydict.push((key, value))
552                }
553                self.push(Object::Dict(pydict))
554            }
555            OpCode::Mark => self.push(Object::Mark),
556            OpCode::Reduce => self.reduce()?,
557            OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
558            OpCode::EmptyList => self.push(Object::List(vec![])),
559            OpCode::BinGet => {
560                let arg = r.read_u8()?;
561                let obj = self.memo_get(arg as u32)?;
562                self.push(obj)
563            }
564            OpCode::LongBinGet => {
565                let arg = r.read_u32::<LittleEndian>()?;
566                let obj = self.memo_get(arg)?;
567                self.push(obj)
568            }
569            OpCode::BinPut => {
570                let arg = r.read_u8()?;
571                self.memo_put(arg as u32)?
572            }
573            OpCode::LongBinPut => {
574                let arg = r.read_u32::<LittleEndian>()?;
575                self.memo_put(arg)?
576            }
577            OpCode::NewObj => {
578                let args = self.pop()?;
579                let class = self.pop()?;
580                let obj = self.new_obj(class, args)?;
581                self.push(obj)
582            }
583        }
584        Ok(false)
585    }
586}
587
588impl From<Object> for E {
589    fn from(value: Object) -> Self {
590        E::Msg(format!("conversion error on {value:?}"))
591    }
592}
593
594// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
595// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
596fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
597    let mut args = args.tuple()?;
598    let stride = Vec::<usize>::try_from(args.remove(3))?;
599    let size = Vec::<usize>::try_from(args.remove(2))?;
600    let offset = args.remove(1).int()? as usize;
601    let storage = args.remove(0).persistent_load()?;
602    let mut storage = storage.tuple()?;
603    let storage_size = storage.remove(4).int()? as usize;
604    let path = storage.remove(2).unicode()?;
605    let (_module_name, class_name) = storage.remove(1).class()?;
606    let dtype = match class_name.as_str() {
607        "FloatStorage" => DType::F32,
608        "DoubleStorage" => DType::F64,
609        "HalfStorage" => DType::F16,
610        "BFloat16Storage" => DType::BF16,
611        "ByteStorage" => DType::U8,
612        "LongStorage" => DType::I64,
613        other => {
614            crate::bail!("unsupported storage type {other}")
615        }
616    };
617    let layout = Layout::new(crate::Shape::from(size), stride, offset);
618    Ok((layout, dtype, path, storage_size))
619}
620
621#[derive(Debug, Clone)]
622pub struct TensorInfo {
623    pub name: String,
624    pub dtype: DType,
625    pub layout: Layout,
626    pub path: String,
627    pub storage_size: usize,
628}
629
630pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
631    file: P,
632    verbose: bool,
633) -> Result<Vec<TensorInfo>> {
634    let file = std::fs::File::open(file)?;
635    let zip_reader = std::io::BufReader::new(file);
636    let mut zip = zip::ZipArchive::new(zip_reader)?;
637    let zip_file_names = zip
638        .file_names()
639        .map(|f| f.to_string())
640        .collect::<Vec<String>>();
641
642    let mut tensor_infos = vec![];
643    for file_name in zip_file_names.iter() {
644        if !file_name.ends_with("data.pkl") {
645            continue;
646        }
647        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
648        let reader = zip.by_name(file_name)?;
649        let mut reader = std::io::BufReader::new(reader);
650        let mut stack = Stack::empty();
651        stack.read_loop(&mut reader)?;
652        let obj = stack.finalize()?;
653        if VERBOSE || verbose {
654            println!("{obj:?}");
655        }
656        let obj = match obj {
657            Object::Build { callable, args } => match *callable {
658                Object::Reduce { callable, args: _ } => match *callable {
659                    Object::Class {
660                        module_name,
661                        class_name,
662                    } if module_name == "__torch__" && class_name == "Module" => *args,
663                    _ => continue,
664                },
665                _ => continue,
666            },
667            obj => obj,
668        };
669        if let Object::Dict(key_values) = obj {
670            for (name, value) in key_values.into_iter() {
671                match value.into_tensor_info(name, &dir_name) {
672                    Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
673                    Ok(None) => {}
674                    Err(err) => eprintln!("skipping: {err:?}"),
675                }
676            }
677        }
678    }
679    Ok(tensor_infos)
680}
681
682/// Lazy tensor loader.
683pub struct PthTensors {
684    tensor_infos: HashMap<String, TensorInfo>,
685    path: std::path::PathBuf,
686    // We do not store a zip reader as it needs mutable access to extract data. Instead we
687    // re-create a zip reader for each tensor.
688}
689
690impl PthTensors {
691    pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
692        let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
693        let tensor_infos = tensor_infos
694            .into_iter()
695            .map(|ti| (ti.name.to_string(), ti))
696            .collect();
697        let path = path.as_ref().to_owned();
698        Ok(Self { tensor_infos, path })
699    }
700
701    pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
702        &self.tensor_infos
703    }
704
705    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
706        let tensor_info = match self.tensor_infos.get(name) {
707            None => return Ok(None),
708            Some(tensor_info) => tensor_info,
709        };
710        // We hope that the file has not changed since first reading it.
711        let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
712        let mut zip = zip::ZipArchive::new(zip_reader)?;
713        let mut reader = zip.by_name(&tensor_info.path)?;
714
715        // Reading the data is a bit tricky as it can be strided, use an offset, etc.
716        // For now only support the basic case.
717        if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
718            crate::bail!(
719                "cannot retrieve non-contiguous tensors {:?}",
720                tensor_info.layout
721            )
722        }
723        let tensor = Tensor::from_reader(
724            tensor_info.layout.shape().clone(),
725            tensor_info.dtype,
726            &mut reader,
727        )?;
728        Ok(Some(tensor))
729    }
730}
731
732/// Read all the tensors from a PyTorch pth file.
733pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
734    let pth = PthTensors::new(path)?;
735    let tensor_names = pth.tensor_infos.keys();
736    let mut tensors = Vec::with_capacity(tensor_names.len());
737    for name in tensor_names {
738        if let Some(tensor) = pth.get(name)? {
739            tensors.push((name.to_string(), tensor))
740        }
741    }
742    Ok(tensors)
743}