1use crate::{Context, DType, Error as E, Layout, Result, Tensor};
5use byteorder::{LittleEndian, ReadBytesExt};
6use std::collections::HashMap;
7use std::io::BufRead;
8
/// Compile-time debug switch for the unpickler: when `true`, the version carried by each
/// `PROTO` opcode is printed and every decoded `data.pkl` root object is pretty-printed to
/// stdout (the latter can also be enabled per call via `read_pth_tensor_info`'s `verbose`).
const VERBOSE: bool = false;
10
/// Pickle opcodes understood by the unpickler; each discriminant is the opcode byte itself.
///
/// Only the opcodes listed here are supported; any other byte is rejected when decoding.
#[repr(u8)]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum OpCode {
    // --- framing ---
    /// `PROTO`: protocol version header (followed by one version byte).
    Proto = 0x80,
    /// `STOP`: end of the pickle stream.
    Stop = b'.',

    // --- scalars ---
    /// `NONE`: pushes Python `None`.
    None = b'N',
    /// `NEWTRUE`: pushes `True`.
    NewTrue = 0x88,
    /// `NEWFALSE`: pushes `False`.
    NewFalse = 0x89,
    /// `BININT`: 4-byte signed little-endian integer.
    BinInt = b'J',
    /// `BININT1`: 1-byte unsigned integer.
    BinInt1 = b'K',
    /// `BININT2`: 2-byte unsigned little-endian integer.
    BinInt2 = b'M',
    /// `LONG1`: integer with a one-byte length prefix.
    Long1 = 0x8a,
    /// `BINFLOAT`: 8-byte float.
    BinFloat = b'G',
    /// `BINUNICODE`: length-prefixed UTF-8 string.
    BinUnicode = b'X',

    // --- tuples ---
    /// `EMPTY_TUPLE`: pushes `()`.
    EmptyTuple = b')',
    /// `TUPLE`: builds a tuple from the items above the last mark.
    Tuple = b't',
    /// `TUPLE1`: builds a 1-tuple from the top of the stack.
    Tuple1 = 0x85,
    /// `TUPLE2`: builds a 2-tuple from the top two stack items.
    Tuple2 = 0x86,
    /// `TUPLE3`: builds a 3-tuple from the top three stack items.
    Tuple3 = 0x87,

    // --- lists ---
    /// `EMPTY_LIST`: pushes `[]`.
    EmptyList = b']',
    /// `APPEND`: appends one item to the list below it.
    Append = b'a',
    /// `APPENDS`: appends the items above the last mark to the list below it.
    Appends = b'e',

    // --- dicts ---
    /// `EMPTY_DICT`: pushes `{}`.
    EmptyDict = b'}',
    /// `DICT`: builds a dict from the key/value items above the last mark.
    Dict = b'd',
    /// `SETITEM`: inserts one key/value pair into the dict below it.
    SetItem = b's',
    /// `SETITEMS`: inserts the key/value items above the last mark into the dict below it.
    SetItems = b'u',

    // --- stack marks and memo ---
    /// `MARK`: pushes a mark delimiting a variable-length argument list.
    Mark = b'(',
    /// `BINGET`: pushes a memo entry (1-byte id).
    BinGet = b'h',
    /// `LONG_BINGET`: pushes a memo entry (4-byte id).
    LongBinGet = b'j',
    /// `BINPUT`: stores the top of the stack in the memo (1-byte id).
    BinPut = b'q',
    /// `LONG_BINPUT`: stores the top of the stack in the memo (4-byte id).
    LongBinPut = b'r',

    // --- objects and calls ---
    /// `GLOBAL`: pushes a `module\nname\n` class reference.
    Global = b'c',
    /// `REDUCE`: applies a callable to an argument tuple.
    Reduce = b'R',
    /// `BUILD`: applies a state to an object.
    Build = b'b',
    /// `NEWOBJ`: instantiates a class with an argument tuple.
    NewObj = 0x81,
    /// `BINPERSID`: resolves a persistent id taken from the stack.
    BinPersId = b'Q',
}
50
51impl TryFrom<u8> for OpCode {
53 type Error = u8;
54 fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
55 match value {
56 0x80 => Ok(Self::Proto),
57 b'c' => Ok(Self::Global),
58 b'q' => Ok(Self::BinPut),
59 b'r' => Ok(Self::LongBinPut),
60 b')' => Ok(Self::EmptyTuple),
61 b'R' => Ok(Self::Reduce),
62 b'(' => Ok(Self::Mark),
63 b'X' => Ok(Self::BinUnicode),
64 b'J' => Ok(Self::BinInt),
65 b't' => Ok(Self::Tuple),
66 b'Q' => Ok(Self::BinPersId),
67 b'K' => Ok(Self::BinInt1),
68 b'M' => Ok(Self::BinInt2),
69 b'N' => Ok(Self::None),
70 0x85 => Ok(Self::Tuple1),
71 0x86 => Ok(Self::Tuple2),
72 0x87 => Ok(Self::Tuple3),
73 0x88 => Ok(Self::NewTrue),
74 0x89 => Ok(Self::NewFalse),
75 b'h' => Ok(Self::BinGet),
76 b'j' => Ok(Self::LongBinGet),
77 b's' => Ok(Self::SetItem),
78 b'u' => Ok(Self::SetItems),
79 b'}' => Ok(Self::EmptyDict),
80 b'd' => Ok(Self::EmptyDict),
81 b'b' => Ok(Self::Build),
82 b'.' => Ok(Self::Stop),
83 0x81 => Ok(Self::NewObj),
84 b']' => Ok(Self::EmptyList),
85 b'G' => Ok(Self::BinFloat),
86 b'a' => Ok(Self::Append),
87 b'e' => Ok(Self::Appends),
88 0x8a => Ok(Self::Long1),
89 value => Err(value),
90 }
91 }
92}
93
94fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
95 let mut data: Vec<u8> = Vec::with_capacity(32);
96 r.read_until(b'\n', &mut data)?;
97 data.pop();
98 if data.last() == Some(&b'\r') {
99 data.pop();
100 }
101 Ok(data)
102}
103
/// A Python value decoded by the pickle stack machine.
///
/// Calls to Python callables are not executed. They are recorded symbolically as `Reduce` /
/// `Build` nodes so that tensor metadata can later be recognized structurally.
#[derive(Debug, Clone, PartialEq)]
pub enum Object {
    // --- scalars ---
    /// Python `None`.
    None,
    /// Python `bool`.
    Bool(bool),
    /// Integer decoded from a `BININT*` opcode.
    Int(i32),
    /// Integer decoded from a `LONG1` opcode.
    Long(i64),
    /// Python `float`.
    Float(f64),
    /// Python `str`.
    Unicode(String),

    // --- containers ---
    /// Python `tuple`.
    Tuple(Vec<Object>),
    /// Python `list`.
    List(Vec<Object>),
    /// Python `dict` as a list of key/value pairs; keys are not deduplicated.
    Dict(Vec<(Object, Object)>),

    // --- stack bookkeeping ---
    /// Marker pushed by `MARK` to delimit variable-length argument lists on the stack.
    Mark,

    // --- symbolic objects and calls ---
    /// A `module.class` global reference.
    Class {
        module_name: String,
        class_name: String,
    },
    /// An unevaluated call `callable(*args)`.
    Reduce {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// An unevaluated `BUILD`: `callable` is the object being built, `args` its state.
    Build {
        callable: Box<Object>,
        args: Box<Object>,
    },
    /// A persistent-id reference, wrapping the pickled id object.
    PersistentLoad(Box<Object>),
}
130
131type OResult<T> = std::result::Result<T, Object>;
132
133impl Object {
134 pub fn unicode(self) -> OResult<String> {
135 match self {
136 Self::Unicode(t) => Ok(t),
137 _ => Err(self),
138 }
139 }
140
141 pub fn reduce(self) -> OResult<(Self, Self)> {
142 match self {
143 Self::Reduce { callable, args } => Ok((*callable, *args)),
144 _ => Err(self),
145 }
146 }
147
148 pub fn none(self) -> OResult<()> {
149 match self {
150 Self::None => Ok(()),
151 _ => Err(self),
152 }
153 }
154
155 pub fn persistent_load(self) -> OResult<Self> {
156 match self {
157 Self::PersistentLoad(t) => Ok(*t),
158 _ => Err(self),
159 }
160 }
161
162 pub fn bool(self) -> OResult<bool> {
163 match self {
164 Self::Bool(t) => Ok(t),
165 _ => Err(self),
166 }
167 }
168
169 pub fn int(self) -> OResult<i32> {
170 match self {
171 Self::Int(t) => Ok(t),
172 _ => Err(self),
173 }
174 }
175
176 pub fn int_or_long(self) -> OResult<i64> {
177 match self {
178 Self::Int(t) => Ok(t as i64),
179 Self::Long(t) => Ok(t),
180 _ => Err(self),
181 }
182 }
183
184 pub fn tuple(self) -> OResult<Vec<Self>> {
185 match self {
186 Self::Tuple(t) => Ok(t),
187 _ => Err(self),
188 }
189 }
190
191 pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
192 match self {
193 Self::Dict(t) => Ok(t),
194 _ => Err(self),
195 }
196 }
197
198 pub fn class(self) -> OResult<(String, String)> {
199 match self {
200 Self::Class {
201 module_name,
202 class_name,
203 } => Ok((module_name, class_name)),
204 _ => Err(self),
205 }
206 }
207
208 pub fn into_tensor_info(
209 self,
210 name: Self,
211 dir_name: &std::path::Path,
212 ) -> Result<Option<TensorInfo>> {
213 let name = match name.unicode() {
214 Ok(name) => name,
215 Err(_) => return Ok(None),
216 };
217 let (callable, args) = match self.reduce() {
218 Ok(callable_args) => callable_args,
219 _ => return Ok(None),
220 };
221 let (callable, args) = match callable {
222 Object::Class {
223 module_name,
224 class_name,
225 } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
226 let mut args = args.tuple()?;
227 let callable = args.remove(0);
228 let args = args.remove(1);
229 (callable, args)
230 }
231 Object::Class {
232 module_name,
233 class_name,
234 } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
235 let mut args = args.tuple()?;
236 args.remove(0).reduce()?
237 }
238 _ => (callable, args),
239 };
240 match callable {
241 Object::Class {
242 module_name,
243 class_name,
244 } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
245 _ => return Ok(None),
246 };
247 let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
248 Ok(Some(TensorInfo {
249 name,
250 dtype,
251 layout,
252 path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
253 storage_size,
254 }))
255 }
256}
257
258impl TryFrom<Object> for String {
259 type Error = Object;
260 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
261 match value {
262 Object::Unicode(s) => Ok(s),
263 other => Err(other),
264 }
265 }
266}
267
268impl TryFrom<Object> for usize {
269 type Error = Object;
270 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
271 match value {
272 Object::Int(s) if s >= 0 => Ok(s as usize),
273 other => Err(other),
274 }
275 }
276}
277
278impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
279 type Error = Object;
280 fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
281 match value {
282 Object::Tuple(values) => {
283 values
286 .into_iter()
287 .map(|v| T::try_from(v))
288 .collect::<std::result::Result<Vec<T>, Self::Error>>()
289 }
290 other => Err(other),
291 }
292 }
293}
294
295#[derive(Debug)]
296pub struct Stack {
297 stack: Vec<Object>,
298 memo: HashMap<u32, Object>,
299}
300
301impl Stack {
302 pub fn empty() -> Self {
303 Self {
304 stack: Vec::with_capacity(512),
305 memo: HashMap::new(),
306 }
307 }
308
309 pub fn stack(&self) -> &[Object] {
310 self.stack.as_slice()
311 }
312
313 pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
314 loop {
315 if self.read(r)? {
316 break;
317 }
318 }
319 Ok(())
320 }
321
322 pub fn finalize(mut self) -> Result<Object> {
323 self.pop()
324 }
325
326 fn push(&mut self, obj: Object) {
327 self.stack.push(obj)
328 }
329
330 fn pop(&mut self) -> Result<Object> {
331 match self.stack.pop() {
332 None => crate::bail!("unexpected empty stack"),
333 Some(obj) => Ok(obj),
334 }
335 }
336
337 fn build(&mut self) -> Result<()> {
339 let args = self.pop()?;
340 let obj = self.pop()?;
341 let obj = match (obj, args) {
342 (Object::Dict(mut obj), Object::Dict(mut args)) => {
343 obj.append(&mut args);
344 Object::Dict(obj)
345 }
346 (obj, args) => Object::Build {
347 callable: Box::new(obj),
348 args: Box::new(args),
349 },
350 };
351 self.push(obj);
352 Ok(())
353 }
354
355 fn reduce(&mut self) -> Result<()> {
356 let args = self.pop()?;
357 let callable = self.pop()?;
358 #[allow(clippy::single_match)]
359 let reduced = match &callable {
360 Object::Class {
361 module_name,
362 class_name,
363 } => {
364 if module_name == "collections"
365 && (class_name == "OrderedDict" || class_name == "defaultdict")
366 {
367 Some(Object::Dict(vec![]))
369 } else {
370 None
371 }
372 }
373 _ => None,
374 };
375 let reduced = reduced.unwrap_or_else(|| Object::Reduce {
376 callable: Box::new(callable),
377 args: Box::new(args),
378 });
379 self.push(reduced);
380 Ok(())
381 }
382
383 fn last(&mut self) -> Result<&mut Object> {
384 match self.stack.last_mut() {
385 None => crate::bail!("unexpected empty stack"),
386 Some(obj) => Ok(obj),
387 }
388 }
389
390 fn memo_get(&self, id: u32) -> Result<Object> {
391 match self.memo.get(&id) {
392 None => crate::bail!("missing object in memo {id}"),
393 Some(obj) => {
394 Ok(obj.clone())
396 }
397 }
398 }
399
400 fn memo_put(&mut self, id: u32) -> Result<()> {
401 let obj = self.last()?.clone();
402 self.memo.insert(id, obj);
403 Ok(())
404 }
405
406 fn persistent_load(&self, id: Object) -> Result<Object> {
407 Ok(Object::PersistentLoad(Box::new(id)))
408 }
409
410 fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
411 Ok(Object::Reduce {
412 callable: Box::new(class),
413 args: Box::new(args),
414 })
415 }
416
417 fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
418 let mut mark_idx = None;
419 for (idx, obj) in self.stack.iter().enumerate().rev() {
420 if obj == &Object::Mark {
421 mark_idx = Some(idx);
422 break;
423 }
424 }
425 match mark_idx {
426 Some(mark_idx) => {
427 let objs = self.stack.split_off(mark_idx + 1);
428 self.stack.pop();
429 Ok(objs)
430 }
431 None => {
432 crate::bail!("marker object not found")
433 }
434 }
435 }
436
437 pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
438 let op_code = match OpCode::try_from(r.read_u8()?) {
439 Ok(op_code) => op_code,
440 Err(op_code) => {
441 crate::bail!("unknown op-code {op_code}")
442 }
443 };
444 match op_code {
447 OpCode::Proto => {
448 let version = r.read_u8()?;
449 if VERBOSE {
450 println!("proto {version}");
451 }
452 }
453 OpCode::Global => {
454 let module_name = read_to_newline(r)?;
455 let class_name = read_to_newline(r)?;
456 let module_name = String::from_utf8_lossy(&module_name).to_string();
457 let class_name = String::from_utf8_lossy(&class_name).to_string();
458 self.push(Object::Class {
459 module_name,
460 class_name,
461 })
462 }
463 OpCode::BinInt1 => {
464 let arg = r.read_u8()?;
465 self.push(Object::Int(arg as i32))
466 }
467 OpCode::BinInt2 => {
468 let arg = r.read_u16::<LittleEndian>()?;
469 self.push(Object::Int(arg as i32))
470 }
471 OpCode::BinInt => {
472 let arg = r.read_i32::<LittleEndian>()?;
473 self.push(Object::Int(arg))
474 }
475 OpCode::BinFloat => {
476 let arg = r.read_f64::<byteorder::BigEndian>()?;
480 self.push(Object::Float(arg))
481 }
482 OpCode::BinUnicode => {
483 let len = r.read_u32::<LittleEndian>()?;
484 let mut data = vec![0u8; len as usize];
485 r.read_exact(&mut data)?;
486 let data = String::from_utf8(data).map_err(E::wrap)?;
487 self.push(Object::Unicode(data))
488 }
489 OpCode::BinPersId => {
490 let id = self.pop()?;
491 let obj = self.persistent_load(id)?;
492 self.push(obj)
493 }
494 OpCode::Tuple => {
495 let objs = self.pop_to_marker()?;
496 self.push(Object::Tuple(objs))
497 }
498 OpCode::Tuple1 => {
499 let obj = self.pop()?;
500 self.push(Object::Tuple(vec![obj]))
501 }
502 OpCode::Tuple2 => {
503 let obj2 = self.pop()?;
504 let obj1 = self.pop()?;
505 self.push(Object::Tuple(vec![obj1, obj2]))
506 }
507 OpCode::Tuple3 => {
508 let obj3 = self.pop()?;
509 let obj2 = self.pop()?;
510 let obj1 = self.pop()?;
511 self.push(Object::Tuple(vec![obj1, obj2, obj3]))
512 }
513 OpCode::NewTrue => self.push(Object::Bool(true)),
514 OpCode::NewFalse => self.push(Object::Bool(false)),
515 OpCode::Append => {
516 let value = self.pop()?;
517 let pylist = self.last()?;
518 if let Object::List(d) = pylist {
519 d.push(value)
520 } else {
521 crate::bail!("expected a list, got {pylist:?}")
522 }
523 }
524 OpCode::Appends => {
525 let objs = self.pop_to_marker()?;
526 let pylist = self.last()?;
527 if let Object::List(d) = pylist {
528 d.extend(objs)
529 } else {
530 crate::bail!("expected a list, got {pylist:?}")
531 }
532 }
533 OpCode::SetItem => {
534 let value = self.pop()?;
535 let key = self.pop()?;
536 let pydict = self.last()?;
537 if let Object::Dict(d) = pydict {
538 d.push((key, value))
539 } else {
540 crate::bail!("expected a dict, got {pydict:?}")
541 }
542 }
543 OpCode::SetItems => {
544 let mut objs = self.pop_to_marker()?;
545 let pydict = self.last()?;
546 if let Object::Dict(d) = pydict {
547 if objs.len() % 2 != 0 {
548 crate::bail!("setitems: not an even number of objects")
549 }
550 while let Some(value) = objs.pop() {
551 let key = objs.pop().context("empty objs")?;
552 d.push((key, value))
553 }
554 } else {
555 crate::bail!("expected a dict, got {pydict:?}")
556 }
557 }
558 OpCode::None => self.push(Object::None),
559 OpCode::Stop => {
560 return Ok(true);
561 }
562 OpCode::Build => self.build()?,
563 OpCode::EmptyDict => self.push(Object::Dict(vec![])),
564 OpCode::Dict => {
565 let mut objs = self.pop_to_marker()?;
566 let mut pydict = vec![];
567 if objs.len() % 2 != 0 {
568 crate::bail!("setitems: not an even number of objects")
569 }
570 while let Some(value) = objs.pop() {
571 let key = objs.pop().context("empty objs")?;
572 pydict.push((key, value))
573 }
574 self.push(Object::Dict(pydict))
575 }
576 OpCode::Mark => self.push(Object::Mark),
577 OpCode::Reduce => self.reduce()?,
578 OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
579 OpCode::EmptyList => self.push(Object::List(vec![])),
580 OpCode::BinGet => {
581 let arg = r.read_u8()?;
582 let obj = self.memo_get(arg as u32)?;
583 self.push(obj)
584 }
585 OpCode::LongBinGet => {
586 let arg = r.read_u32::<LittleEndian>()?;
587 let obj = self.memo_get(arg)?;
588 self.push(obj)
589 }
590 OpCode::BinPut => {
591 let arg = r.read_u8()?;
592 self.memo_put(arg as u32)?
593 }
594 OpCode::LongBinPut => {
595 let arg = r.read_u32::<LittleEndian>()?;
596 self.memo_put(arg)?
597 }
598 OpCode::NewObj => {
599 let args = self.pop()?;
600 let class = self.pop()?;
601 let obj = self.new_obj(class, args)?;
602 self.push(obj)
603 }
604 OpCode::Long1 => {
605 let n_bytes = r.read_u8()?;
606 let mut v = 0;
607 for i in 0..n_bytes {
609 v |= (r.read_u8()? as i64) << (i * 8);
610 }
611 self.push(Object::Long(v))
612 }
613 }
614 Ok(false)
615 }
616}
617
618impl From<Object> for E {
619 fn from(value: Object) -> Self {
620 E::Msg(format!("conversion error on {value:?}"))
621 }
622}
623
624fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
627 let mut args = args.tuple()?;
628 let stride = Vec::<usize>::try_from(args.remove(3))?;
629 let size = Vec::<usize>::try_from(args.remove(2))?;
630 let offset = args.remove(1).int_or_long()? as usize;
631 let storage = args.remove(0).persistent_load()?;
632 let mut storage = storage.tuple()?;
633 let storage_size = storage.remove(4).int_or_long()? as usize;
634 let path = storage.remove(2).unicode()?;
635 let (_module_name, class_name) = storage.remove(1).class()?;
636 let dtype = match class_name.as_str() {
637 "FloatStorage" => DType::F32,
638 "DoubleStorage" => DType::F64,
639 "HalfStorage" => DType::F16,
640 "BFloat16Storage" => DType::BF16,
641 "ByteStorage" => DType::U8,
642 "LongStorage" => DType::I64,
643 other => {
644 crate::bail!("unsupported storage type {other}")
645 }
646 };
647 let layout = Layout::new(
648 crate::Shape::from(size),
649 stride,
650 offset * dtype.size_in_bytes(),
651 );
652 Ok((layout, dtype, path, storage_size))
653}
654
/// Metadata of one tensor found in a `.pth` archive, enough to locate and read its data.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Key of the tensor in the pickled state dict.
    pub name: String,
    /// Element type, derived from the pickled storage class (e.g. `FloatStorage` -> F32).
    pub dtype: DType,
    /// Shape and strides of the tensor. The start offset is stored in bytes
    /// (storage offset multiplied by the dtype size).
    pub layout: Layout,
    /// Zip entry holding the storage: the pickle's path without `.pkl`, then `/` and the
    /// storage key.
    pub path: String,
    /// Value at index 4 of the pickled storage tuple. Presumably the storage element count;
    /// TODO confirm against the torch version that wrote the file.
    pub storage_size: usize,
}
663
664pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
671 file: P,
672 verbose: bool,
673 key: Option<&str>,
674) -> Result<Vec<TensorInfo>> {
675 let file = std::fs::File::open(file)?;
676 let zip_reader = std::io::BufReader::new(file);
677 let mut zip = zip::ZipArchive::new(zip_reader)?;
678 let zip_file_names = zip
679 .file_names()
680 .map(|f| f.to_string())
681 .collect::<Vec<String>>();
682
683 let mut tensor_infos = vec![];
684 for file_name in zip_file_names.iter() {
685 if !file_name.ends_with("data.pkl") {
686 continue;
687 }
688 let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").context("no .pkl")?);
689 let reader = zip.by_name(file_name)?;
690 let mut reader = std::io::BufReader::new(reader);
691 let mut stack = Stack::empty();
692 stack.read_loop(&mut reader)?;
693 let obj = stack.finalize()?;
694 if VERBOSE || verbose {
695 println!("{obj:#?}");
696 }
697
698 let obj = match obj {
699 Object::Build { callable, args } => match *callable {
700 Object::Reduce { callable, args: _ } => match *callable {
701 Object::Class {
702 module_name,
703 class_name,
704 } if module_name == "__torch__" && class_name == "Module" => *args,
705 _ => continue,
706 },
707 _ => continue,
708 },
709 obj => obj,
710 };
711
712 let obj = if let Some(key) = key {
714 if let Object::Dict(key_values) = obj {
715 key_values
716 .into_iter()
717 .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
718 .map(|(_, v)| v)
719 .ok_or_else(|| E::Msg(format!("key {key} not found")))?
720 } else {
721 obj
722 }
723 } else {
724 obj
725 };
726
727 if let Object::Dict(key_values) = obj {
730 for (name, value) in key_values.into_iter() {
731 match value.into_tensor_info(name, &dir_name) {
732 Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
733 Ok(None) => {}
734 Err(err) => eprintln!("skipping: {err:?}"),
735 }
736 }
737 }
738 }
739 Ok(tensor_infos)
740}
741
742pub struct PthTensors {
744 tensor_infos: HashMap<String, TensorInfo>,
745 path: std::path::PathBuf,
746 }
749
750impl PthTensors {
751 pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
752 let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
753 let tensor_infos = tensor_infos
754 .into_iter()
755 .map(|ti| (ti.name.to_string(), ti))
756 .collect();
757 let path = path.as_ref().to_owned();
758 Ok(Self { tensor_infos, path })
759 }
760
761 pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
762 &self.tensor_infos
763 }
764
765 pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
766 use std::io::Read;
767 let tensor_info = match self.tensor_infos.get(name) {
768 None => return Ok(None),
769 Some(tensor_info) => tensor_info,
770 };
771 let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
773 let mut zip = zip::ZipArchive::new(zip_reader)?;
774 let mut reader = zip.by_name(&tensor_info.path)?;
775 let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
776 let rank = tensor_info.layout.shape().rank();
777
778 if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
781 crate::bail!(
782 "cannot retrieve non-contiguous tensors {:?}",
783 tensor_info.layout
784 )
785 }
786 let start_offset = tensor_info.layout.start_offset();
787 if start_offset > 0 {
788 std::io::copy(
789 &mut reader.by_ref().take(start_offset as u64),
790 &mut std::io::sink(),
791 )?;
792 }
793 let tensor = Tensor::from_reader(
794 tensor_info.layout.shape().clone(),
795 tensor_info.dtype,
796 &mut reader,
797 )?;
798
799 if rank > 1 && is_fortran_contiguous {
800 let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
802 let tensor = tensor.reshape(shape_reversed)?;
803
804 let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
806 let tensor = tensor.permute(dim_indeces_reversed)?;
807 Ok(Some(tensor))
808 } else {
809 Ok(Some(tensor))
810 }
811 }
812}
813
814pub fn read_all_with_key<P: AsRef<std::path::Path>>(
821 path: P,
822 key: Option<&str>,
823) -> Result<Vec<(String, Tensor)>> {
824 let pth = PthTensors::new(path, key)?;
825 let tensor_names = pth.tensor_infos.keys();
826 let mut tensors = Vec::with_capacity(tensor_names.len());
827 for name in tensor_names {
828 if let Some(tensor) = pth.get(name)? {
829 tensors.push((name.to_string(), tensor))
830 }
831 }
832 Ok(tensors)
833}
834
/// Reads every tensor of the PyTorch `.pth` archive at `path` into memory.
///
/// Equivalent to `read_all_with_key(path, None)`: no key is selected, so the top-level
/// pickled object itself is searched for tensors.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
    read_all_with_key(path, None)
}