1use crate::{Context, DType, Error as E, Layout, Result, Tensor};
5use byteorder::{LittleEndian, ReadBytesExt};
6use std::collections::HashMap;
7use std::io::BufRead;
8
/// Compile-time debug switch: when `true`, the reader prints the pickle
/// protocol version and every decoded top-level object to stdout.
const VERBOSE: bool = false;
10
/// Pickle op-codes understood by this reader.
///
/// Each discriminant is the byte that encodes the op-code in the stream, so
/// `variant as u8` gives back its wire representation and
/// [`OpCode::try_from`] is the inverse mapping.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    /// Protocol header; followed by one version byte.
    Proto = 0x80,
    /// Global reference; followed by two newline-terminated names.
    Global = b'c',
    /// Store the top of the stack in the memo under a 1-byte id.
    BinPut = b'q',
    /// Store the top of the stack in the memo under a 4-byte little-endian id.
    LongBinPut = b'r',
    /// Push an empty tuple.
    EmptyTuple = b')',
    /// Pop arguments and callable, push the (unevaluated) call.
    Reduce = b'R',
    /// Push a marker delimiting a variable-length group of items.
    Mark = b'(',
    /// UTF-8 string prefixed by a 4-byte little-endian length.
    BinUnicode = b'X',
    /// 4-byte little-endian signed integer.
    BinInt = b'J',
    /// Build a tuple from the items above the topmost marker.
    Tuple = b't',
    /// Turn the top of the stack into a persistent-load reference.
    BinPersId = b'Q',
    /// 1-byte unsigned integer.
    BinInt1 = b'K',
    /// 2-byte little-endian unsigned integer.
    BinInt2 = b'M',
    /// Build a tuple from the top item.
    Tuple1 = 0x85,
    /// Build a tuple from the top two items.
    Tuple2 = 0x86,
    /// Build a tuple from the top three items.
    Tuple3 = 0x87,
    /// Push `true`.
    NewTrue = 0x88,
    /// Push `false`.
    NewFalse = 0x89,
    /// Push the none value.
    None = b'N',
    /// Push the memo entry with the given 1-byte id.
    BinGet = b'h',
    /// Push the memo entry with the given 4-byte little-endian id.
    LongBinGet = b'j',
    /// Add one key/value pair to the dict below it on the stack.
    SetItem = b's',
    /// Add the key/value pairs above the topmost marker to the dict below it.
    SetItems = b'u',
    /// Push an empty dict.
    EmptyDict = b'}',
    /// Build a dict from the key/value pairs above the topmost marker.
    Dict = b'd',
    /// Apply a state object to the object below it.
    Build = b'b',
    /// End of the pickle stream.
    Stop = b'.',
    /// Pop arguments and class, push the (unevaluated) construction.
    NewObj = 0x81,
    /// Push an empty list.
    EmptyList = b']',
    /// 8-byte big-endian float.
    BinFloat = b'G',
    /// Append the top item to the list below it.
    Append = b'a',
    /// Append the items above the topmost marker to the list below it.
    Appends = b'e',
}

impl TryFrom<u8> for OpCode {
    /// An unrecognised byte is handed back unchanged.
    type Error = u8;

    /// Decodes a raw byte into the op-code with that discriminant.
    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            0x80 => Ok(Self::Proto),
            b'c' => Ok(Self::Global),
            b'q' => Ok(Self::BinPut),
            b'r' => Ok(Self::LongBinPut),
            b')' => Ok(Self::EmptyTuple),
            b'R' => Ok(Self::Reduce),
            b'(' => Ok(Self::Mark),
            b'X' => Ok(Self::BinUnicode),
            b'J' => Ok(Self::BinInt),
            b't' => Ok(Self::Tuple),
            b'Q' => Ok(Self::BinPersId),
            b'K' => Ok(Self::BinInt1),
            b'M' => Ok(Self::BinInt2),
            b'N' => Ok(Self::None),
            0x85 => Ok(Self::Tuple1),
            0x86 => Ok(Self::Tuple2),
            0x87 => Ok(Self::Tuple3),
            0x88 => Ok(Self::NewTrue),
            0x89 => Ok(Self::NewFalse),
            b'h' => Ok(Self::BinGet),
            b'j' => Ok(Self::LongBinGet),
            b's' => Ok(Self::SetItem),
            b'u' => Ok(Self::SetItems),
            b'}' => Ok(Self::EmptyDict),
            // `d` must decode to `Dict`, not `EmptyDict`: `Dict` consumes the
            // marker and the pairs above it, `EmptyDict` would leave them on
            // the stack.
            b'd' => Ok(Self::Dict),
            b'b' => Ok(Self::Build),
            b'.' => Ok(Self::Stop),
            0x81 => Ok(Self::NewObj),
            b']' => Ok(Self::EmptyList),
            b'G' => Ok(Self::BinFloat),
            b'a' => Ok(Self::Append),
            b'e' => Ok(Self::Appends),
            value => Err(value),
        }
    }
}
91
/// Reads bytes from `r` up to and including the next `\n` and returns the line
/// without its terminator; a `\r` directly before the `\n` is stripped too.
///
/// If the stream ends before a newline is found, the bytes read so far are
/// returned unchanged (an empty vector at end of stream).
///
/// # Errors
/// Propagates any I/O error reported by the underlying reader.
fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
    let mut data: Vec<u8> = Vec::with_capacity(32);
    r.read_until(b'\n', &mut data)?;
    // `read_until` only leaves the delimiter in the buffer when it was
    // actually found, so only strip it in that case; popping unconditionally
    // would eat a data byte on an unterminated last line.
    if data.last() == Some(&b'\n') {
        data.pop();
        if data.last() == Some(&b'\r') {
            data.pop();
        }
    }
    Ok(data)
}
101
/// A value produced while decoding a pickle stream.
///
/// Plain data (ints, strings, tuples, ...) is stored directly; calls that the
/// stream asks for are not executed but recorded symbolically (`Reduce`,
/// `Build`, `PersistentLoad`) so that callers can inspect them.
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
    /// A global reference, identified by its module and name.
    Class {
        module_name: String,
        class_name: String,
    },
    /// A 32-bit signed integer.
    Int(i32),
    /// A 64-bit float.
    Float(f64),
    /// A UTF-8 string.
    Unicode(String),
    /// A boolean.
    Bool(bool),
    /// The none value.
    None,
    /// An ordered, fixed group of objects.
    Tuple(Vec<Object>),
    /// A growable list of objects.
    List(Vec<Object>),
    /// Stack marker pushed by the `Mark` op-code; delimits the items consumed
    /// by the op-codes that pop back to the marker.
    Mark,
    /// Key/value pairs, kept as a list in the order they were added.
    Dict(Vec<(Object, Object)>),
    /// An unevaluated call of `callable` with `args`.
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// An unevaluated application of the state `args` to the object `callable`.
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// A reference to externally stored data, wrapping its persistent id.
    PersistentLoad(Box<Object>),
}

/// Result of a checked downcast of an [`Object`]: on a variant mismatch the
/// original object is handed back as the error.
type OResult<T> = std::result::Result<T, Object>;
129
130impl Object {
131 pub fn unicode(self) -> OResult<String> {
132 match self {
133 Self::Unicode(t) => Ok(t),
134 _ => Err(self),
135 }
136 }
137
138 pub fn reduce(self) -> OResult<(Self, Self)> {
139 match self {
140 Self::Reduce { callable, args } => Ok((*callable, *args)),
141 _ => Err(self),
142 }
143 }
144
145 pub fn none(self) -> OResult<()> {
146 match self {
147 Self::None => Ok(()),
148 _ => Err(self),
149 }
150 }
151
152 pub fn persistent_load(self) -> OResult<Self> {
153 match self {
154 Self::PersistentLoad(t) => Ok(*t),
155 _ => Err(self),
156 }
157 }
158
159 pub fn bool(self) -> OResult<bool> {
160 match self {
161 Self::Bool(t) => Ok(t),
162 _ => Err(self),
163 }
164 }
165
166 pub fn int(self) -> OResult<i32> {
167 match self {
168 Self::Int(t) => Ok(t),
169 _ => Err(self),
170 }
171 }
172
173 pub fn tuple(self) -> OResult<Vec<Self>> {
174 match self {
175 Self::Tuple(t) => Ok(t),
176 _ => Err(self),
177 }
178 }
179
180 pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
181 match self {
182 Self::Dict(t) => Ok(t),
183 _ => Err(self),
184 }
185 }
186
187 pub fn class(self) -> OResult<(String, String)> {
188 match self {
189 Self::Class {
190 module_name,
191 class_name,
192 } => Ok((module_name, class_name)),
193 _ => Err(self),
194 }
195 }
196
197 pub fn into_tensor_info(
198 self,
199 name: Self,
200 dir_name: &std::path::Path,
201 ) -> Result<Option<TensorInfo>> {
202 let name = match name.unicode() {
203 Ok(name) => name,
204 Err(_) => return Ok(None),
205 };
206 let (callable, args) = match self.reduce() {
207 Ok(callable_args) => callable_args,
208 _ => return Ok(None),
209 };
210 let (callable, args) = match callable {
211 Object::Class {
212 module_name,
213 class_name,
214 } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
215 let mut args = args.tuple()?;
216 let callable = args.remove(0);
217 let args = args.remove(1);
218 (callable, args)
219 }
220 Object::Class {
221 module_name,
222 class_name,
223 } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
224 let mut args = args.tuple()?;
225 args.remove(0).reduce()?
226 }
227 _ => (callable, args),
228 };
229 match callable {
230 Object::Class {
231 module_name,
232 class_name,
233 } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
234 _ => return Ok(None),
235 };
236 let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
237 Ok(Some(TensorInfo {
238 name,
239 dtype,
240 layout,
241 path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
242 storage_size,
243 }))
244 }
245}
246
247impl TryFrom<Object> for String {
248 type Error = Object;
249 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
250 match value {
251 Object::Unicode(s) => Ok(s),
252 other => Err(other),
253 }
254 }
255}
256
257impl TryFrom<Object> for usize {
258 type Error = Object;
259 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
260 match value {
261 Object::Int(s) if s >= 0 => Ok(s as usize),
262 other => Err(other),
263 }
264 }
265}
266
267impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
268 type Error = Object;
269 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
270 match value {
271 Object::Tuple(values) => {
272 values
275 .into_iter()
276 .map(|v| T::try_from(v))
277 .collect::<std::result::Result<Vec<T>, Self::Error>>()
278 }
279 other => Err(other),
280 }
281 }
282}
283
/// State of the pickle reader: the value stack plus the memo table.
#[derive(Debug)]
pub struct Stack {
    /// Objects decoded so far; op-codes push to and pop from the end.
    stack: Vec<Object>,
    /// Objects saved by the put op-codes, keyed by memo id, so that the get
    /// op-codes can push a copy of them later.
    memo: HashMap<u32, Object>,
}
289
290impl Stack {
291 pub fn empty() -> Self {
292 Self {
293 stack: Vec::with_capacity(512),
294 memo: HashMap::new(),
295 }
296 }
297
298 pub fn stack(&self) -> &[Object] {
299 self.stack.as_slice()
300 }
301
302 pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
303 loop {
304 if self.read(r)? {
305 break;
306 }
307 }
308 Ok(())
309 }
310
311 pub fn finalize(mut self) -> Result<Object> {
312 self.pop()
313 }
314
315 fn push(&mut self, obj: Object) {
316 self.stack.push(obj)
317 }
318
319 fn pop(&mut self) -> Result<Object> {
320 match self.stack.pop() {
321 None => crate::bail!("unexpected empty stack"),
322 Some(obj) => Ok(obj),
323 }
324 }
325
326 fn build(&mut self) -> Result<()> {
328 let args = self.pop()?;
329 let obj = self.pop()?;
330 let obj = match (obj, args) {
331 (Object::Dict(mut obj), Object::Dict(mut args)) => {
332 obj.append(&mut args);
333 Object::Dict(obj)
334 }
335 (obj, args) => Object::Build {
336 callable: Box::new(obj),
337 args: Box::new(args),
338 },
339 };
340 self.push(obj);
341 Ok(())
342 }
343
344 fn reduce(&mut self) -> Result<()> {
345 let args = self.pop()?;
346 let callable = self.pop()?;
347 #[allow(clippy::single_match)]
348 let reduced = match &callable {
349 Object::Class {
350 module_name,
351 class_name,
352 } => {
353 if module_name == "collections"
354 && (class_name == "OrderedDict" || class_name == "defaultdict")
355 {
356 Some(Object::Dict(vec![]))
358 } else {
359 None
360 }
361 }
362 _ => None,
363 };
364 let reduced = reduced.unwrap_or_else(|| Object::Reduce {
365 callable: Box::new(callable),
366 args: Box::new(args),
367 });
368 self.push(reduced);
369 Ok(())
370 }
371
372 fn last(&mut self) -> Result<&mut Object> {
373 match self.stack.last_mut() {
374 None => crate::bail!("unexpected empty stack"),
375 Some(obj) => Ok(obj),
376 }
377 }
378
379 fn memo_get(&self, id: u32) -> Result<Object> {
380 match self.memo.get(&id) {
381 None => crate::bail!("missing object in memo {id}"),
382 Some(obj) => {
383 Ok(obj.clone())
385 }
386 }
387 }
388
389 fn memo_put(&mut self, id: u32) -> Result<()> {
390 let obj = self.last()?.clone();
391 self.memo.insert(id, obj);
392 Ok(())
393 }
394
395 fn persistent_load(&self, id: Object) -> Result<Object> {
396 Ok(Object::PersistentLoad(Box::new(id)))
397 }
398
399 fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
400 Ok(Object::Reduce {
401 callable: Box::new(class),
402 args: Box::new(args),
403 })
404 }
405
406 fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
407 let mut mark_idx = None;
408 for (idx, obj) in self.stack.iter().enumerate().rev() {
409 if obj == &Object::Mark {
410 mark_idx = Some(idx);
411 break;
412 }
413 }
414 match mark_idx {
415 Some(mark_idx) => {
416 let objs = self.stack.split_off(mark_idx + 1);
417 self.stack.pop();
418 Ok(objs)
419 }
420 None => {
421 crate::bail!("marker object not found")
422 }
423 }
424 }
425
426 pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
427 let op_code = match OpCode::try_from(r.read_u8()?) {
428 Ok(op_code) => op_code,
429 Err(op_code) => {
430 crate::bail!("unknown op-code {op_code}")
431 }
432 };
433 match op_code {
436 OpCode::Proto => {
437 let version = r.read_u8()?;
438 if VERBOSE {
439 println!("proto {version}");
440 }
441 }
442 OpCode::Global => {
443 let module_name = read_to_newline(r)?;
444 let class_name = read_to_newline(r)?;
445 let module_name = String::from_utf8_lossy(&module_name).to_string();
446 let class_name = String::from_utf8_lossy(&class_name).to_string();
447 self.push(Object::Class {
448 module_name,
449 class_name,
450 })
451 }
452 OpCode::BinInt1 => {
453 let arg = r.read_u8()?;
454 self.push(Object::Int(arg as i32))
455 }
456 OpCode::BinInt2 => {
457 let arg = r.read_u16::<LittleEndian>()?;
458 self.push(Object::Int(arg as i32))
459 }
460 OpCode::BinInt => {
461 let arg = r.read_i32::<LittleEndian>()?;
462 self.push(Object::Int(arg))
463 }
464 OpCode::BinFloat => {
465 let arg = r.read_f64::<byteorder::BigEndian>()?;
469 self.push(Object::Float(arg))
470 }
471 OpCode::BinUnicode => {
472 let len = r.read_u32::<LittleEndian>()?;
473 let mut data = vec![0u8; len as usize];
474 r.read_exact(&mut data)?;
475 let data = String::from_utf8(data).map_err(E::wrap)?;
476 self.push(Object::Unicode(data))
477 }
478 OpCode::BinPersId => {
479 let id = self.pop()?;
480 let obj = self.persistent_load(id)?;
481 self.push(obj)
482 }
483 OpCode::Tuple => {
484 let objs = self.pop_to_marker()?;
485 self.push(Object::Tuple(objs))
486 }
487 OpCode::Tuple1 => {
488 let obj = self.pop()?;
489 self.push(Object::Tuple(vec![obj]))
490 }
491 OpCode::Tuple2 => {
492 let obj2 = self.pop()?;
493 let obj1 = self.pop()?;
494 self.push(Object::Tuple(vec![obj1, obj2]))
495 }
496 OpCode::Tuple3 => {
497 let obj3 = self.pop()?;
498 let obj2 = self.pop()?;
499 let obj1 = self.pop()?;
500 self.push(Object::Tuple(vec![obj1, obj2, obj3]))
501 }
502 OpCode::NewTrue => self.push(Object::Bool(true)),
503 OpCode::NewFalse => self.push(Object::Bool(false)),
504 OpCode::Append => {
505 let value = self.pop()?;
506 let pylist = self.last()?;
507 if let Object::List(d) = pylist {
508 d.push(value)
509 } else {
510 crate::bail!("expected a list, got {pylist:?}")
511 }
512 }
513 OpCode::Appends => {
514 let objs = self.pop_to_marker()?;
515 let pylist = self.last()?;
516 if let Object::List(d) = pylist {
517 d.extend(objs)
518 } else {
519 crate::bail!("expected a list, got {pylist:?}")
520 }
521 }
522 OpCode::SetItem => {
523 let value = self.pop()?;
524 let key = self.pop()?;
525 let pydict = self.last()?;
526 if let Object::Dict(d) = pydict {
527 d.push((key, value))
528 } else {
529 crate::bail!("expected a dict, got {pydict:?}")
530 }
531 }
532 OpCode::SetItems => {
533 let mut objs = self.pop_to_marker()?;
534 let pydict = self.last()?;
535 if let Object::Dict(d) = pydict {
536 if objs.len() % 2 != 0 {
537 crate::bail!("setitems: not an even number of objects")
538 }
539 while let Some(value) = objs.pop() {
540 let key = objs.pop().context("empty objs")?;
541 d.push((key, value))
542 }
543 } else {
544 crate::bail!("expected a dict, got {pydict:?}")
545 }
546 }
547 OpCode::None => self.push(Object::None),
548 OpCode::Stop => {
549 return Ok(true);
550 }
551 OpCode::Build => self.build()?,
552 OpCode::EmptyDict => self.push(Object::Dict(vec![])),
553 OpCode::Dict => {
554 let mut objs = self.pop_to_marker()?;
555 let mut pydict = vec![];
556 if objs.len() % 2 != 0 {
557 crate::bail!("setitems: not an even number of objects")
558 }
559 while let Some(value) = objs.pop() {
560 let key = objs.pop().context("empty objs")?;
561 pydict.push((key, value))
562 }
563 self.push(Object::Dict(pydict))
564 }
565 OpCode::Mark => self.push(Object::Mark),
566 OpCode::Reduce => self.reduce()?,
567 OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
568 OpCode::EmptyList => self.push(Object::List(vec![])),
569 OpCode::BinGet => {
570 let arg = r.read_u8()?;
571 let obj = self.memo_get(arg as u32)?;
572 self.push(obj)
573 }
574 OpCode::LongBinGet => {
575 let arg = r.read_u32::<LittleEndian>()?;
576 let obj = self.memo_get(arg)?;
577 self.push(obj)
578 }
579 OpCode::BinPut => {
580 let arg = r.read_u8()?;
581 self.memo_put(arg as u32)?
582 }
583 OpCode::LongBinPut => {
584 let arg = r.read_u32::<LittleEndian>()?;
585 self.memo_put(arg)?
586 }
587 OpCode::NewObj => {
588 let args = self.pop()?;
589 let class = self.pop()?;
590 let obj = self.new_obj(class, args)?;
591 self.push(obj)
592 }
593 }
594 Ok(false)
595 }
596}
597
598impl From<Object> for E {
599 fn from(value: Object) -> Self {
600 E::Msg(format!("conversion error on {value:?}"))
601 }
602}
603
604fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
607 let mut args = args.tuple()?;
608 let stride = Vec::<usize>::try_from(args.remove(3))?;
609 let size = Vec::<usize>::try_from(args.remove(2))?;
610 let offset = args.remove(1).int()? as usize;
611 let storage = args.remove(0).persistent_load()?;
612 let mut storage = storage.tuple()?;
613 let storage_size = storage.remove(4).int()? as usize;
614 let path = storage.remove(2).unicode()?;
615 let (_module_name, class_name) = storage.remove(1).class()?;
616 let dtype = match class_name.as_str() {
617 "FloatStorage" => DType::F32,
618 "DoubleStorage" => DType::F64,
619 "HalfStorage" => DType::F16,
620 "BFloat16Storage" => DType::BF16,
621 "ByteStorage" => DType::U8,
622 "LongStorage" => DType::I64,
623 other => {
624 crate::bail!("unsupported storage type {other}")
625 }
626 };
627 let layout = Layout::new(crate::Shape::from(size), stride, offset);
628 Ok((layout, dtype, path, storage_size))
629}
630
/// Metadata describing one tensor found in an archive, without its data.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Key under which the tensor was stored in the pickled dictionary.
    pub name: String,
    /// Element type, derived from the storage class name.
    pub dtype: DType,
    /// Shape, strides and start offset of the tensor within its storage.
    pub layout: Layout,
    /// Name of the archive entry holding the tensor's storage, built as
    /// `<pickle entry name without ".pkl">/<storage entry name>`.
    pub path: String,
    /// Size recorded for the backing storage (presumably a number of
    /// elements; confirm against the writer).
    pub storage_size: usize,
}
639
640pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
647 file: P,
648 verbose: bool,
649 key: Option<&str>,
650) -> Result<Vec<TensorInfo>> {
651 let file = std::fs::File::open(file)?;
652 let zip_reader = std::io::BufReader::new(file);
653 let mut zip = zip::ZipArchive::new(zip_reader)?;
654 let zip_file_names = zip
655 .file_names()
656 .map(|f| f.to_string())
657 .collect::<Vec<String>>();
658
659 let mut tensor_infos = vec![];
660 for file_name in zip_file_names.iter() {
661 if !file_name.ends_with("data.pkl") {
662 continue;
663 }
664 let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
665 let reader = zip.by_name(file_name)?;
666 let mut reader = std::io::BufReader::new(reader);
667 let mut stack = Stack::empty();
668 stack.read_loop(&mut reader)?;
669 let obj = stack.finalize()?;
670 if VERBOSE || verbose {
671 println!("{obj:#?}");
672 }
673
674 let obj = match obj {
675 Object::Build { callable, args } => match *callable {
676 Object::Reduce { callable, args: _ } => match *callable {
677 Object::Class {
678 module_name,
679 class_name,
680 } if module_name == "__torch__" && class_name == "Module" => *args,
681 _ => continue,
682 },
683 _ => continue,
684 },
685 obj => obj,
686 };
687
688 let obj = if let Some(key) = key {
690 if let Object::Dict(key_values) = obj {
691 key_values
692 .into_iter()
693 .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
694 .map(|(_, v)| v)
695 .ok_or_else(|| E::Msg(format!("key {key} not found")))?
696 } else {
697 obj
698 }
699 } else {
700 obj
701 };
702
703 if let Object::Dict(key_values) = obj {
706 for (name, value) in key_values.into_iter() {
707 match value.into_tensor_info(name, &dir_name) {
708 Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
709 Ok(None) => {}
710 Err(err) => eprintln!("skipping: {err:?}"),
711 }
712 }
713 }
714 }
715 Ok(tensor_infos)
716}
717
/// Lazy reader for the tensors of an archive: the metadata is indexed once at
/// construction and tensor data is read from the file on each `get`.
pub struct PthTensors {
    /// Metadata of every tensor found in the archive, keyed by tensor name.
    tensor_infos: HashMap<String, TensorInfo>,
    /// Location of the archive; it is reopened for every `get`.
    path: std::path::PathBuf,
}
725
726impl PthTensors {
727 pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
728 let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
729 let tensor_infos = tensor_infos
730 .into_iter()
731 .map(|ti| (ti.name.to_string(), ti))
732 .collect();
733 let path = path.as_ref().to_owned();
734 Ok(Self { tensor_infos, path })
735 }
736
737 pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
738 &self.tensor_infos
739 }
740
741 pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
742 use std::io::Read;
743 let tensor_info = match self.tensor_infos.get(name) {
744 None => return Ok(None),
745 Some(tensor_info) => tensor_info,
746 };
747 let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
749 let mut zip = zip::ZipArchive::new(zip_reader)?;
750 let mut reader = zip.by_name(&tensor_info.path)?;
751 let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
752 let rank = tensor_info.layout.shape().rank();
753
754 if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
757 crate::bail!(
758 "cannot retrieve non-contiguous tensors {:?}",
759 tensor_info.layout
760 )
761 }
762 let start_offset = tensor_info.layout.start_offset();
763 if start_offset > 0 {
764 std::io::copy(
765 &mut reader.by_ref().take(start_offset as u64),
766 &mut std::io::sink(),
767 )?;
768 }
769 let tensor = Tensor::from_reader(
770 tensor_info.layout.shape().clone(),
771 tensor_info.dtype,
772 &mut reader,
773 )?;
774
775 if rank > 1 && is_fortran_contiguous {
776 let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
778 let tensor = tensor.reshape(shape_reversed)?;
779
780 let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
782 let tensor = tensor.permute(dim_indeces_reversed)?;
783 Ok(Some(tensor))
784 } else {
785 Ok(Some(tensor))
786 }
787 }
788}
789
790pub fn read_all_with_key<P: AsRef<std::path::Path>>(
797 path: P,
798 key: Option<&str>,
799) -> Result<Vec<(String, Tensor)>> {
800 let pth = PthTensors::new(path, key)?;
801 let tensor_names = pth.tensor_infos.keys();
802 let mut tensors = Vec::with_capacity(tensor_names.len());
803 for name in tensor_names {
804 if let Some(tensor) = pth.get(name)? {
805 tensors.push((name.to_string(), tensor))
806 }
807 }
808 Ok(tensors)
809}
810
/// Reads every tensor of the archive at `path`, treating the top-level
/// pickled object as the tensor dictionary (no key lookup).
///
/// Equivalent to `read_all_with_key(path, None)`.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    read_all_with_key(path, None)
}