hpt_dataloader/struct_save/
load.rs1use std::{
2 fmt::Display,
3 fs::File,
4 io::{Read, Seek},
5 marker::PhantomData,
6};
7
8use flate2::read::{DeflateDecoder, GzDecoder, ZlibDecoder};
9use hpt_common::shape::shape::Shape;
10use hpt_traits::{CommonBounds, TensorCreator, TensorInfo};
11use num::traits::FromBytes;
12
13use crate::{
14 data_loader::{parse_header_compressed, TensorMeta},
15 CompressionAlgo, Endian,
16};
17
18use super::save::Save;
19
20pub trait Load: Sized + Save {
21 fn load(path: &str) -> std::io::Result<Self>
22 where
23 <Self as Save>::Meta: MetaLoad<Output = Self>,
24 {
25 let meta = parse_header_compressed::<Self>(path).expect("failed to parse header");
26 let mut file = File::open(path)?;
27 meta.load(&mut file)
28 }
29}
30
31pub(crate) fn load<
32 'a,
33 T: CommonBounds + FromBytes<Bytes = [u8; N]>,
34 B: TensorCreator<T, Output = B> + Clone + TensorInfo<T> + Display,
35 const N: usize,
36>(
37 file: &mut File,
38 meta: &TensorMeta<T, B>,
39) -> std::io::Result<B> {
40 if meta.dtype != T::STR {
41 return Err(std::io::Error::new(
42 std::io::ErrorKind::InvalidInput,
43 format!(
44 "the dtype stored is {}, but the dtype requested is {}",
45 meta.dtype,
46 T::STR
47 ),
48 ));
49 }
50 let shape = meta.shape.clone();
52 let tensor: B = B::empty(&shape)
54 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e.to_string()))?;
55
56 if tensor.size() == 0 {
57 return Ok(tensor);
58 }
59 #[allow(unused_mut)]
60 let mut res: &mut [T] =
61 unsafe { std::slice::from_raw_parts_mut(tensor.ptr().ptr, tensor.size()) };
62
63 let pack = get_pack_closure::<T, N>(meta.endian);
64
65 let mut res_idx = 0;
66 for (_, idx, compressed_len, current_block_mem_size) in meta.indices.iter() {
67 let uncompressed_data = uncompress_data(
68 file,
69 *idx as u64,
70 *compressed_len,
71 *current_block_mem_size,
72 meta.compression_algo,
73 )?;
74 for idx in (0..uncompressed_data.len()).step_by(std::mem::size_of::<T>()) {
75 let val = &uncompressed_data[idx..idx + std::mem::size_of::<T>()];
76 let data = pack(val);
77 res[res_idx] = data;
78 res_idx += 1;
79 }
80 }
81 Ok(tensor)
82}
83
84fn uncompress_data(
85 file: &mut File,
86 idx: u64,
87 compressed_len: usize,
88 block_mem_size: usize,
89 compression_type: CompressionAlgo,
90) -> std::io::Result<Vec<u8>> {
91 file.seek(std::io::SeekFrom::Start(idx))?;
92 let mut compressed_vec = vec![0u8; compressed_len];
93 file.read_exact(&mut compressed_vec)?;
94 let mut uncompressed_data = vec![0u8; block_mem_size];
95 match compression_type {
96 CompressionAlgo::Gzip => {
97 let mut decoder = GzDecoder::new(compressed_vec.as_slice());
98 decoder.read_exact(&mut uncompressed_data)?;
99 }
100 CompressionAlgo::Deflate => {
101 let mut decoder = DeflateDecoder::new(compressed_vec.as_slice());
102 decoder.read_exact(&mut uncompressed_data)?;
103 }
104 CompressionAlgo::Zlib => {
105 let mut decoder = ZlibDecoder::new(compressed_vec.as_slice());
106 decoder.read_exact(&mut uncompressed_data)?;
107 }
108 CompressionAlgo::NoCompression => {
109 uncompressed_data = compressed_vec;
110 }
111 }
112 Ok(uncompressed_data)
113}
114
115fn get_pack_closure<T: CommonBounds + FromBytes<Bytes = [u8; N]>, const N: usize>(
116 endian: Endian,
117) -> impl Fn(&[u8]) -> T {
118 match endian {
119 Endian::Little => |val: &[u8]| T::from_le_bytes(val.try_into().unwrap()),
120 Endian::Big => |val: &[u8]| T::from_be_bytes(val.try_into().unwrap()),
121 Endian::Native => |val: &[u8]| T::from_ne_bytes(val.try_into().unwrap()),
122 }
123}
124
/// Metadata that knows how to rebuild the value it describes.
pub trait MetaLoad: Sized {
    /// The type produced by [`MetaLoad::load`].
    type Output;

    /// Rebuilds the described value, with `file` available for any bytes the
    /// implementation needs to read.
    ///
    /// # Errors
    ///
    /// Implementations report failures through the returned `io::Result`.
    fn load(&self, file: &mut File) -> std::io::Result<Self::Output>;
}
129
/// Implements `MetaLoad` for a type that acts as its own metadata.
///
/// The generated `load` ignores the file and returns a clone of `self`.
///
/// Arguments:
/// * `$struct`   - the type to implement `MetaLoad` for.
/// * `$default`  - accepted for compatibility with existing call sites; the
///   expansion does not use it.
/// * `$generics` - zero or more comma-separated generic parameters that are
///   declared on the generated `impl`.
macro_rules! impl_load {
    ($struct:ty, $default:expr $(, $generics:tt)*) => {
        // The parameters are re-joined with commas so that more than one
        // generic expands to a valid `impl<A, B>` header.
        impl<$($generics),*> MetaLoad for $struct {
            type Output = $struct;
            fn load(&self, _: &mut std::fs::File) -> std::io::Result<Self::Output> {
                Ok(self.clone())
            }
        }
    };
}
140
141impl_load!(bool, false);
142impl_load!(u8, 0);
143impl_load!(i8, 0);
144impl_load!(u16, 0);
145impl_load!(i16, 0);
146impl_load!(u32, 0);
147impl_load!(i32, 0);
148impl_load!(u64, 0);
149impl_load!(i64, 0);
150impl_load!(f32, 0.0);
151impl_load!(f64, 0.0);
152impl_load!(usize, 0);
153impl_load!(isize, 0);
154impl_load!(String, String::new());
155impl_load!(Shape, Shape::new([]));
156impl_load!(PhantomData<T>, Self, T);
157
158impl<T: MetaLoad> MetaLoad for Option<T> {
159 type Output = Option<T::Output>;
160 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
161 match self {
162 Some(x) => Ok(Some(T::load(x, file)?)),
163 None => Ok(None),
164 }
165 }
166}
167
168impl<T: MetaLoad> MetaLoad for Vec<T> {
169 type Output = Vec<T::Output>;
170 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
171 let mut res = Vec::with_capacity(self.len());
172 for i in 0..self.len() {
173 res.push(T::load(&self[i], file)?);
174 }
175 Ok(res)
176 }
177}
178
179impl<T: MetaLoad, const N: usize> MetaLoad for [T; N] {
180 type Output = [T::Output; N];
181 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
182 let mut arr: [std::mem::MaybeUninit<T::Output>; N] =
183 unsafe { std::mem::MaybeUninit::uninit().assume_init() };
184 for i in 0..N {
185 arr[i] = std::mem::MaybeUninit::new(T::load(&self[i], file)?);
186 }
187 Ok(unsafe {
188 let ptr = &arr as *const _ as *const [T::Output; N];
189 ptr.read()
190 })
191 }
192}