1use crate::{DType, Error as E, Layout, Result, Tensor};
5use byteorder::{LittleEndian, ReadBytesExt};
6use std::collections::HashMap;
7use std::io::BufRead;
8
/// Compile-time switch for extra diagnostic printing while decoding pickle
/// streams (for instance the protocol version and the final decoded object).
/// It can only be turned on by editing this constant and rebuilding.
const VERBOSE: bool = false;
10
/// The subset of pickle opcodes understood by this decoder.
///
/// Each variant's discriminant is the byte that encodes the opcode in a pickle
/// stream, so `op as u8` and `OpCode::try_from(byte)` are inverse operations.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    /// Protocol version marker, followed by a one-byte version number.
    Proto = 0x80,
    /// Pushes a class identified by two newline-terminated names.
    Global = b'c',
    /// Stores the top of the stack in the memo under a one-byte id.
    BinPut = b'q',
    /// Stores the top of the stack in the memo under a four-byte id.
    LongBinPut = b'r',
    /// Pushes an empty tuple.
    EmptyTuple = b')',
    /// Applies a callable to an argument tuple.
    Reduce = b'R',
    /// Pushes a marker delimiting a variable-length group of values.
    Mark = b'(',
    /// Pushes a length-prefixed UTF-8 string.
    BinUnicode = b'X',
    /// Pushes a four-byte signed integer.
    BinInt = b'J',
    /// Builds a tuple from every value above the topmost marker.
    Tuple = b't',
    /// Replaces the top of the stack by a persistent-load of it.
    BinPersId = b'Q',
    /// Pushes a one-byte unsigned integer.
    BinInt1 = b'K',
    /// Pushes a two-byte unsigned integer.
    BinInt2 = b'M',
    /// Builds a one-element tuple from the top of the stack.
    Tuple1 = 0x85,
    /// Builds a two-element tuple from the top two values.
    Tuple2 = 0x86,
    /// Builds a three-element tuple from the top three values.
    Tuple3 = 0x87,
    /// Pushes `true`.
    NewTrue = 0x88,
    /// Pushes `false`.
    NewFalse = 0x89,
    /// Pushes the `None` value.
    None = b'N',
    /// Pushes the memo entry stored under a one-byte id.
    BinGet = b'h',
    /// Pushes the memo entry stored under a four-byte id.
    LongBinGet = b'j',
    /// Adds one key/value pair to the dict below it on the stack.
    SetItem = b's',
    /// Adds every key/value pair above the topmost marker to a dict.
    SetItems = b'u',
    /// Pushes an empty dict.
    EmptyDict = b'}',
    /// Builds a dict from the key/value pairs above the topmost marker.
    Dict = b'd',
    /// Applies a state object to the object below it.
    Build = b'b',
    /// Ends the pickle stream.
    Stop = b'.',
    /// Instantiates a class with an argument tuple.
    NewObj = 0x81,
    /// Pushes an empty list.
    EmptyList = b']',
    /// Pushes an eight-byte binary float. The pickle byte for this opcode is
    /// the upper-case `G`; lower-case `g` is a different (textual GET) opcode.
    BinFloat = b'G',
    /// Appends the top of the stack to the list below it.
    Append = b'a',
    /// Appends every value above the topmost marker to a list.
    Appends = b'e',
}

impl TryFrom<u8> for OpCode {
    /// The unrecognised byte is handed back to the caller.
    type Error = u8;

    /// Decodes a raw opcode byte, returning `Err(value)` for bytes this
    /// decoder does not support.
    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
        match value {
            0x80 => Ok(Self::Proto),
            b'c' => Ok(Self::Global),
            b'q' => Ok(Self::BinPut),
            b'r' => Ok(Self::LongBinPut),
            b')' => Ok(Self::EmptyTuple),
            b'R' => Ok(Self::Reduce),
            b'(' => Ok(Self::Mark),
            b'X' => Ok(Self::BinUnicode),
            b'J' => Ok(Self::BinInt),
            b't' => Ok(Self::Tuple),
            b'Q' => Ok(Self::BinPersId),
            b'K' => Ok(Self::BinInt1),
            b'M' => Ok(Self::BinInt2),
            b'N' => Ok(Self::None),
            0x85 => Ok(Self::Tuple1),
            0x86 => Ok(Self::Tuple2),
            0x87 => Ok(Self::Tuple3),
            0x88 => Ok(Self::NewTrue),
            0x89 => Ok(Self::NewFalse),
            b'h' => Ok(Self::BinGet),
            b'j' => Ok(Self::LongBinGet),
            b's' => Ok(Self::SetItem),
            b'u' => Ok(Self::SetItems),
            b'}' => Ok(Self::EmptyDict),
            // `d` builds a dict from marked items; it is not the empty-dict opcode.
            b'd' => Ok(Self::Dict),
            b'b' => Ok(Self::Build),
            b'.' => Ok(Self::Stop),
            0x81 => Ok(Self::NewObj),
            b']' => Ok(Self::EmptyList),
            b'G' => Ok(Self::BinFloat),
            b'a' => Ok(Self::Append),
            b'e' => Ok(Self::Appends),
            value => Err(value),
        }
    }
}
91
92fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
93 let mut data: Vec<u8> = Vec::with_capacity(32);
94 r.read_until(b'\n', &mut data)?;
95 data.pop();
96 if data.last() == Some(&b'\r') {
97 data.pop();
98 }
99 Ok(data)
100}
101
/// A value produced while interpreting a pickle stream.
///
/// Only the shapes needed by this decoder are modelled; calls that cannot be
/// evaluated are kept symbolically (`Reduce`, `Build`, `PersistentLoad`).
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
    /// A reference to a class or function, named by module and attribute.
    Class {
        module_name: String,
        class_name: String,
    },
    /// A 32-bit signed integer.
    Int(i32),
    /// A 64-bit float.
    Float(f64),
    /// A UTF-8 string.
    Unicode(String),
    /// A boolean.
    Bool(bool),
    /// The `None` value.
    None,
    /// An ordered, fixed group of values.
    Tuple(Vec<Object>),
    /// An ordered, growable group of values.
    List(Vec<Object>),
    /// Stack marker delimiting a variable-length group of values.
    Mark,
    /// A dictionary kept as a list of key/value pairs.
    Dict(Vec<(Object, Object)>),
    /// An unevaluated call of `callable` with `args`.
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// An unevaluated application of the state `args` to the object `callable`.
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// An unevaluated persistent-load of the wrapped identifier.
    PersistentLoad(Box<Object>),
}

/// Result of extracting a specific variant from an [`Object`]; on mismatch the
/// original object is returned as the error.
type OResult<T> = std::result::Result<T, Object>;
129
130impl Object {
131 pub fn unicode(self) -> OResult<String> {
132 match self {
133 Self::Unicode(t) => Ok(t),
134 _ => Err(self),
135 }
136 }
137
138 pub fn reduce(self) -> OResult<(Self, Self)> {
139 match self {
140 Self::Reduce { callable, args } => Ok((*callable, *args)),
141 _ => Err(self),
142 }
143 }
144
145 pub fn none(self) -> OResult<()> {
146 match self {
147 Self::None => Ok(()),
148 _ => Err(self),
149 }
150 }
151
152 pub fn persistent_load(self) -> OResult<Self> {
153 match self {
154 Self::PersistentLoad(t) => Ok(*t),
155 _ => Err(self),
156 }
157 }
158
159 pub fn bool(self) -> OResult<bool> {
160 match self {
161 Self::Bool(t) => Ok(t),
162 _ => Err(self),
163 }
164 }
165
166 pub fn int(self) -> OResult<i32> {
167 match self {
168 Self::Int(t) => Ok(t),
169 _ => Err(self),
170 }
171 }
172
173 pub fn tuple(self) -> OResult<Vec<Self>> {
174 match self {
175 Self::Tuple(t) => Ok(t),
176 _ => Err(self),
177 }
178 }
179
180 pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
181 match self {
182 Self::Dict(t) => Ok(t),
183 _ => Err(self),
184 }
185 }
186
187 pub fn class(self) -> OResult<(String, String)> {
188 match self {
189 Self::Class {
190 module_name,
191 class_name,
192 } => Ok((module_name, class_name)),
193 _ => Err(self),
194 }
195 }
196
197 pub fn into_tensor_info(
198 self,
199 name: Self,
200 dir_name: &std::path::Path,
201 ) -> Result<Option<TensorInfo>> {
202 let name = match name.unicode() {
203 Ok(name) => name,
204 Err(_) => return Ok(None),
205 };
206 let (callable, args) = match self.reduce() {
207 Ok(callable_args) => callable_args,
208 _ => return Ok(None),
209 };
210 let (callable, args) = match callable {
211 Object::Class {
212 module_name,
213 class_name,
214 } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
215 let mut args = args.tuple()?;
216 let callable = args.remove(0);
217 let args = args.remove(1);
218 (callable, args)
219 }
220 _ => (callable, args),
221 };
222 match callable {
223 Object::Class {
224 module_name,
225 class_name,
226 } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
227 _ => return Ok(None),
228 };
229 let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
230 let mut path = dir_name.to_path_buf();
231 path.push(file_path);
232 Ok(Some(TensorInfo {
233 name,
234 dtype,
235 layout,
236 path: path.to_string_lossy().into_owned(),
237 storage_size,
238 }))
239 }
240}
241
242impl TryFrom<Object> for String {
243 type Error = Object;
244 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
245 match value {
246 Object::Unicode(s) => Ok(s),
247 other => Err(other),
248 }
249 }
250}
251
252impl TryFrom<Object> for usize {
253 type Error = Object;
254 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
255 match value {
256 Object::Int(s) if s >= 0 => Ok(s as usize),
257 other => Err(other),
258 }
259 }
260}
261
262impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
263 type Error = Object;
264 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
265 match value {
266 Object::Tuple(values) => {
267 values
270 .into_iter()
271 .map(|v| T::try_from(v))
272 .collect::<std::result::Result<Vec<T>, Self::Error>>()
273 }
274 other => Err(other),
275 }
276 }
277}
278
279#[derive(Debug)]
280pub struct Stack {
281 stack: Vec<Object>,
282 memo: HashMap<u32, Object>,
283}
284
285impl Stack {
286 pub fn empty() -> Self {
287 Self {
288 stack: Vec::with_capacity(512),
289 memo: HashMap::new(),
290 }
291 }
292
293 pub fn stack(&self) -> &[Object] {
294 self.stack.as_slice()
295 }
296
297 pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
298 loop {
299 if self.read(r)? {
300 break;
301 }
302 }
303 Ok(())
304 }
305
306 pub fn finalize(mut self) -> Result<Object> {
307 self.pop()
308 }
309
310 fn push(&mut self, obj: Object) {
311 self.stack.push(obj)
312 }
313
314 fn pop(&mut self) -> Result<Object> {
315 match self.stack.pop() {
316 None => crate::bail!("unexpected empty stack"),
317 Some(obj) => Ok(obj),
318 }
319 }
320
321 fn build(&mut self) -> Result<()> {
323 let args = self.pop()?;
324 let obj = self.pop()?;
325 let obj = match (obj, args) {
326 (Object::Dict(mut obj), Object::Dict(mut args)) => {
327 obj.append(&mut args);
328 Object::Dict(obj)
329 }
330 (obj, args) => Object::Build {
331 callable: Box::new(obj),
332 args: Box::new(args),
333 },
334 };
335 self.push(obj);
336 Ok(())
337 }
338
339 fn reduce(&mut self) -> Result<()> {
340 let args = self.pop()?;
341 let callable = self.pop()?;
342 #[allow(clippy::single_match)]
343 let reduced = match &callable {
344 Object::Class {
345 module_name,
346 class_name,
347 } => {
348 if module_name == "collections" && class_name == "OrderedDict" {
349 Some(Object::Dict(vec![]))
351 } else {
352 None
353 }
354 }
355 _ => None,
356 };
357 let reduced = reduced.unwrap_or_else(|| Object::Reduce {
358 callable: Box::new(callable),
359 args: Box::new(args),
360 });
361 self.push(reduced);
362 Ok(())
363 }
364
365 fn last(&mut self) -> Result<&mut Object> {
366 match self.stack.last_mut() {
367 None => crate::bail!("unexpected empty stack"),
368 Some(obj) => Ok(obj),
369 }
370 }
371
372 fn memo_get(&self, id: u32) -> Result<Object> {
373 match self.memo.get(&id) {
374 None => crate::bail!("missing object in memo {id}"),
375 Some(obj) => {
376 Ok(obj.clone())
378 }
379 }
380 }
381
382 fn memo_put(&mut self, id: u32) -> Result<()> {
383 let obj = self.last()?.clone();
384 self.memo.insert(id, obj);
385 Ok(())
386 }
387
388 fn persistent_load(&self, id: Object) -> Result<Object> {
389 Ok(Object::PersistentLoad(Box::new(id)))
390 }
391
392 fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
393 Ok(Object::Reduce {
394 callable: Box::new(class),
395 args: Box::new(args),
396 })
397 }
398
399 fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
400 let mut mark_idx = None;
401 for (idx, obj) in self.stack.iter().enumerate().rev() {
402 if obj == &Object::Mark {
403 mark_idx = Some(idx);
404 break;
405 }
406 }
407 match mark_idx {
408 Some(mark_idx) => {
409 let objs = self.stack.split_off(mark_idx + 1);
410 self.stack.pop();
411 Ok(objs)
412 }
413 None => {
414 crate::bail!("marker object not found")
415 }
416 }
417 }
418
419 pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
420 let op_code = match OpCode::try_from(r.read_u8()?) {
421 Ok(op_code) => op_code,
422 Err(op_code) => {
423 crate::bail!("unknown op-code {op_code}")
424 }
425 };
426 match op_code {
429 OpCode::Proto => {
430 let version = r.read_u8()?;
431 if VERBOSE {
432 println!("proto {version}");
433 }
434 }
435 OpCode::Global => {
436 let module_name = read_to_newline(r)?;
437 let class_name = read_to_newline(r)?;
438 let module_name = String::from_utf8_lossy(&module_name).to_string();
439 let class_name = String::from_utf8_lossy(&class_name).to_string();
440 self.push(Object::Class {
441 module_name,
442 class_name,
443 })
444 }
445 OpCode::BinInt1 => {
446 let arg = r.read_u8()?;
447 self.push(Object::Int(arg as i32))
448 }
449 OpCode::BinInt2 => {
450 let arg = r.read_u16::<LittleEndian>()?;
451 self.push(Object::Int(arg as i32))
452 }
453 OpCode::BinInt => {
454 let arg = r.read_i32::<LittleEndian>()?;
455 self.push(Object::Int(arg))
456 }
457 OpCode::BinFloat => {
458 let arg = r.read_f64::<LittleEndian>()?;
459 self.push(Object::Float(arg))
460 }
461 OpCode::BinUnicode => {
462 let len = r.read_u32::<LittleEndian>()?;
463 let mut data = vec![0u8; len as usize];
464 r.read_exact(&mut data)?;
465 let data = String::from_utf8(data).map_err(E::wrap)?;
466 self.push(Object::Unicode(data))
467 }
468 OpCode::BinPersId => {
469 let id = self.pop()?;
470 let obj = self.persistent_load(id)?;
471 self.push(obj)
472 }
473 OpCode::Tuple => {
474 let objs = self.pop_to_marker()?;
475 self.push(Object::Tuple(objs))
476 }
477 OpCode::Tuple1 => {
478 let obj = self.pop()?;
479 self.push(Object::Tuple(vec![obj]))
480 }
481 OpCode::Tuple2 => {
482 let obj2 = self.pop()?;
483 let obj1 = self.pop()?;
484 self.push(Object::Tuple(vec![obj1, obj2]))
485 }
486 OpCode::Tuple3 => {
487 let obj3 = self.pop()?;
488 let obj2 = self.pop()?;
489 let obj1 = self.pop()?;
490 self.push(Object::Tuple(vec![obj1, obj2, obj3]))
491 }
492 OpCode::NewTrue => self.push(Object::Bool(true)),
493 OpCode::NewFalse => self.push(Object::Bool(false)),
494 OpCode::Append => {
495 let value = self.pop()?;
496 let pylist = self.last()?;
497 if let Object::List(d) = pylist {
498 d.push(value)
499 } else {
500 crate::bail!("expected a list, got {pylist:?}")
501 }
502 }
503 OpCode::Appends => {
504 let objs = self.pop_to_marker()?;
505 let pylist = self.last()?;
506 if let Object::List(d) = pylist {
507 d.extend(objs)
508 } else {
509 crate::bail!("expected a list, got {pylist:?}")
510 }
511 }
512 OpCode::SetItem => {
513 let value = self.pop()?;
514 let key = self.pop()?;
515 let pydict = self.last()?;
516 if let Object::Dict(d) = pydict {
517 d.push((key, value))
518 } else {
519 crate::bail!("expected a dict, got {pydict:?}")
520 }
521 }
522 OpCode::SetItems => {
523 let mut objs = self.pop_to_marker()?;
524 let pydict = self.last()?;
525 if let Object::Dict(d) = pydict {
526 if objs.len() % 2 != 0 {
527 crate::bail!("setitems: not an even number of objects")
528 }
529 while let Some(value) = objs.pop() {
530 let key = objs.pop().unwrap();
531 d.push((key, value))
532 }
533 } else {
534 crate::bail!("expected a dict, got {pydict:?}")
535 }
536 }
537 OpCode::None => self.push(Object::None),
538 OpCode::Stop => {
539 return Ok(true);
540 }
541 OpCode::Build => self.build()?,
542 OpCode::EmptyDict => self.push(Object::Dict(vec![])),
543 OpCode::Dict => {
544 let mut objs = self.pop_to_marker()?;
545 let mut pydict = vec![];
546 if objs.len() % 2 != 0 {
547 crate::bail!("setitems: not an even number of objects")
548 }
549 while let Some(value) = objs.pop() {
550 let key = objs.pop().unwrap();
551 pydict.push((key, value))
552 }
553 self.push(Object::Dict(pydict))
554 }
555 OpCode::Mark => self.push(Object::Mark),
556 OpCode::Reduce => self.reduce()?,
557 OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
558 OpCode::EmptyList => self.push(Object::List(vec![])),
559 OpCode::BinGet => {
560 let arg = r.read_u8()?;
561 let obj = self.memo_get(arg as u32)?;
562 self.push(obj)
563 }
564 OpCode::LongBinGet => {
565 let arg = r.read_u32::<LittleEndian>()?;
566 let obj = self.memo_get(arg)?;
567 self.push(obj)
568 }
569 OpCode::BinPut => {
570 let arg = r.read_u8()?;
571 self.memo_put(arg as u32)?
572 }
573 OpCode::LongBinPut => {
574 let arg = r.read_u32::<LittleEndian>()?;
575 self.memo_put(arg)?
576 }
577 OpCode::NewObj => {
578 let args = self.pop()?;
579 let class = self.pop()?;
580 let obj = self.new_obj(class, args)?;
581 self.push(obj)
582 }
583 }
584 Ok(false)
585 }
586}
587
588impl From<Object> for E {
589 fn from(value: Object) -> Self {
590 E::Msg(format!("conversion error on {value:?}"))
591 }
592}
593
594fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
597 let mut args = args.tuple()?;
598 let stride = Vec::<usize>::try_from(args.remove(3))?;
599 let size = Vec::<usize>::try_from(args.remove(2))?;
600 let offset = args.remove(1).int()? as usize;
601 let storage = args.remove(0).persistent_load()?;
602 let mut storage = storage.tuple()?;
603 let storage_size = storage.remove(4).int()? as usize;
604 let path = storage.remove(2).unicode()?;
605 let (_module_name, class_name) = storage.remove(1).class()?;
606 let dtype = match class_name.as_str() {
607 "FloatStorage" => DType::F32,
608 "DoubleStorage" => DType::F64,
609 "HalfStorage" => DType::F16,
610 "BFloat16Storage" => DType::BF16,
611 "ByteStorage" => DType::U8,
612 "LongStorage" => DType::I64,
613 other => {
614 crate::bail!("unsupported storage type {other}")
615 }
616 };
617 let layout = Layout::new(crate::Shape::from(size), stride, offset);
618 Ok((layout, dtype, path, storage_size))
619}
620
621#[derive(Debug, Clone)]
622pub struct TensorInfo {
623 pub name: String,
624 pub dtype: DType,
625 pub layout: Layout,
626 pub path: String,
627 pub storage_size: usize,
628}
629
630pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
631 file: P,
632 verbose: bool,
633) -> Result<Vec<TensorInfo>> {
634 let file = std::fs::File::open(file)?;
635 let zip_reader = std::io::BufReader::new(file);
636 let mut zip = zip::ZipArchive::new(zip_reader)?;
637 let zip_file_names = zip
638 .file_names()
639 .map(|f| f.to_string())
640 .collect::<Vec<String>>();
641
642 let mut tensor_infos = vec![];
643 for file_name in zip_file_names.iter() {
644 if !file_name.ends_with("data.pkl") {
645 continue;
646 }
647 let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
648 let reader = zip.by_name(file_name)?;
649 let mut reader = std::io::BufReader::new(reader);
650 let mut stack = Stack::empty();
651 stack.read_loop(&mut reader)?;
652 let obj = stack.finalize()?;
653 if VERBOSE || verbose {
654 println!("{obj:?}");
655 }
656 let obj = match obj {
657 Object::Build { callable, args } => match *callable {
658 Object::Reduce { callable, args: _ } => match *callable {
659 Object::Class {
660 module_name,
661 class_name,
662 } if module_name == "__torch__" && class_name == "Module" => *args,
663 _ => continue,
664 },
665 _ => continue,
666 },
667 obj => obj,
668 };
669 if let Object::Dict(key_values) = obj {
670 for (name, value) in key_values.into_iter() {
671 match value.into_tensor_info(name, &dir_name) {
672 Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
673 Ok(None) => {}
674 Err(err) => eprintln!("skipping: {err:?}"),
675 }
676 }
677 }
678 }
679 Ok(tensor_infos)
680}
681
682pub struct PthTensors {
684 tensor_infos: HashMap<String, TensorInfo>,
685 path: std::path::PathBuf,
686 }
689
690impl PthTensors {
691 pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
692 let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
693 let tensor_infos = tensor_infos
694 .into_iter()
695 .map(|ti| (ti.name.to_string(), ti))
696 .collect();
697 let path = path.as_ref().to_owned();
698 Ok(Self { tensor_infos, path })
699 }
700
701 pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
702 &self.tensor_infos
703 }
704
705 pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
706 let tensor_info = match self.tensor_infos.get(name) {
707 None => return Ok(None),
708 Some(tensor_info) => tensor_info,
709 };
710 let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
712 let mut zip = zip::ZipArchive::new(zip_reader)?;
713 let mut reader = zip.by_name(&tensor_info.path)?;
714
715 if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
718 crate::bail!(
719 "cannot retrieve non-contiguous tensors {:?}",
720 tensor_info.layout
721 )
722 }
723 let tensor = Tensor::from_reader(
724 tensor_info.layout.shape().clone(),
725 tensor_info.dtype,
726 &mut reader,
727 )?;
728 Ok(Some(tensor))
729 }
730}
731
732pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
734 let pth = PthTensors::new(path)?;
735 let tensor_names = pth.tensor_infos.keys();
736 let mut tensors = Vec::with_capacity(tensor_names.len());
737 for name in tensor_names {
738 if let Some(tensor) = pth.get(name)? {
739 tensors.push((name.to_string(), tensor))
740 }
741 }
742 Ok(tensors)
743}