hpt_dataloader/struct_save/
load.rs1use std::{
2 fmt::Display,
3 fs::File,
4 io::{Read, Seek},
5 marker::PhantomData,
6};
7
8use flate2::read::{DeflateDecoder, GzDecoder, ZlibDecoder};
9use hpt_common::shape::shape::Shape;
10use hpt_traits::tensor::{CommonBounds, TensorInfo};
11
12use crate::{
13 data_loader::{parse_header_compressed, TensorMeta},
14 CPUTensorCreator, CompressionAlgo,
15};
16
17use super::save::Save;
18
19pub trait Load: Sized + Save {
20 fn load<P: Into<std::path::PathBuf>>(path: P) -> std::io::Result<Self>
21 where
22 <Self as Save>::Meta: MetaLoad<Output = Self>,
23 {
24 let path: std::path::PathBuf = path.into();
25 let meta = parse_header_compressed::<Self, _>(&path).expect("failed to parse header");
26 let mut file = File::open(path)?;
27 meta.load(&mut file)
28 }
29}
30
31pub(crate) fn load<'a, T: CommonBounds + bytemuck::AnyBitPattern, B: CPUTensorCreator>(
32 file: &mut File,
33 meta: &TensorMeta<T, B>,
34) -> std::io::Result<<B as CPUTensorCreator>::Output>
35where
36 <B as CPUTensorCreator>::Output: Clone + TensorInfo<T> + Display,
37{
38 if meta.dtype != T::STR {
39 return Err(std::io::Error::new(
40 std::io::ErrorKind::InvalidInput,
41 format!(
42 "the dtype stored is {}, but the dtype requested is {}",
43 meta.dtype,
44 T::STR
45 ),
46 ));
47 }
48 let shape = meta.shape.clone();
50 let tensor = B::empty(&shape)
52 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e.to_string()))?;
53
54 if tensor.size() == 0 {
55 return Ok(tensor);
56 }
57 #[allow(unused_mut)]
58 let mut res: &mut [T] =
59 unsafe { std::slice::from_raw_parts_mut(tensor.ptr().ptr, tensor.size()) };
60
61 let mut res_idx = 0;
62 for (_, idx, compressed_len, current_block_mem_size) in meta.indices.iter() {
63 let uncompressed_data = uncompress_data(
64 file,
65 *idx as u64,
66 *compressed_len,
67 *current_block_mem_size,
68 meta.compression_algo,
69 )?;
70 for idx in (0..uncompressed_data.len()).step_by(std::mem::size_of::<T>()) {
71 let val = &uncompressed_data[idx..idx + std::mem::size_of::<T>()];
72 res[res_idx] = *bytemuck::from_bytes(val);
73 res_idx += 1;
74 }
75 }
76 Ok(tensor)
77}
78
79fn uncompress_data(
80 file: &mut File,
81 idx: u64,
82 compressed_len: usize,
83 block_mem_size: usize,
84 compression_type: CompressionAlgo,
85) -> std::io::Result<Vec<u8>> {
86 file.seek(std::io::SeekFrom::Start(idx))?;
87 let mut compressed_vec = vec![0u8; compressed_len];
88 file.read_exact(&mut compressed_vec)?;
89 let mut uncompressed_data = vec![0u8; block_mem_size];
90 match compression_type {
91 CompressionAlgo::Gzip => {
92 let mut decoder = GzDecoder::new(compressed_vec.as_slice());
93 decoder.read_exact(&mut uncompressed_data)?;
94 }
95 CompressionAlgo::Deflate => {
96 let mut decoder = DeflateDecoder::new(compressed_vec.as_slice());
97 decoder.read_exact(&mut uncompressed_data)?;
98 }
99 CompressionAlgo::Zlib => {
100 let mut decoder = ZlibDecoder::new(compressed_vec.as_slice());
101 decoder.read_exact(&mut uncompressed_data)?;
102 }
103 CompressionAlgo::NoCompression => {
104 uncompressed_data = compressed_vec;
105 }
106 }
107 Ok(uncompressed_data)
108}
109
/// Converts parsed header metadata into the value it describes.
///
/// Implementations may read more data from the open file they are given.
/// Metadata that already is the value can ignore the file and return a copy.
pub trait MetaLoad: Sized {
    /// The value produced from this metadata.
    type Output;

    /// Builds the output value, reading from `src` as needed.
    ///
    /// # Errors
    ///
    /// Implementations return any I/O or format error they encounter.
    fn load(&self, src: &mut std::fs::File) -> std::io::Result<Self::Output>;
}
114
/// Implements [`MetaLoad`] for a type whose metadata *is* the value. `load`
/// ignores the file and returns `self.clone()`.
///
/// Syntax: `impl_load!(Type, default_expr $(, GenericParam)*)`.
/// * `default_expr` is accepted for call-site compatibility but is not used by
///   the expansion.
/// * Each trailing token becomes a generic parameter of the `impl`. Several
///   parameters are now comma-separated, so `impl_load!(Foo<A, B>, Self, A, B)`
///   expands to `impl<A, B> ...`. Previously the parameters were concatenated
///   into the invalid `impl<A B>`.
macro_rules! impl_load {
    ($struct:ty, $default:expr $(, $generics:tt)*) => {
        impl<$($generics),*> MetaLoad for $struct {
            type Output = $struct;
            fn load(&self, _: &mut std::fs::File) -> std::io::Result<Self::Output> {
                Ok(self.clone())
            }
        }
    };
}
125
126impl_load!(bool, false);
127impl_load!(u8, 0);
128impl_load!(i8, 0);
129impl_load!(u16, 0);
130impl_load!(i16, 0);
131impl_load!(u32, 0);
132impl_load!(i32, 0);
133impl_load!(u64, 0);
134impl_load!(i64, 0);
135impl_load!(f32, 0.0);
136impl_load!(f64, 0.0);
137impl_load!(usize, 0);
138impl_load!(isize, 0);
139impl_load!(half::f16, half::f16::ZERO);
140impl_load!(half::bf16, half::bf16::ZERO);
141impl_load!(String, String::new());
142impl_load!(Shape, Shape::new([]));
143impl_load!(PhantomData<T>, Self, T);
144
145impl<T: MetaLoad> MetaLoad for Option<T> {
146 type Output = Option<T::Output>;
147 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
148 match self {
149 Some(x) => Ok(Some(T::load(x, file)?)),
150 None => Ok(None),
151 }
152 }
153}
154
155impl<T: MetaLoad> MetaLoad for Vec<T> {
156 type Output = Vec<T::Output>;
157 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
158 let mut res = Vec::with_capacity(self.len());
159 for i in 0..self.len() {
160 res.push(T::load(&self[i], file)?);
161 }
162 Ok(res)
163 }
164}
165
166impl<T: MetaLoad, const N: usize> MetaLoad for [T; N] {
167 type Output = [T::Output; N];
168 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
169 let mut arr: [std::mem::MaybeUninit<T::Output>; N] =
170 unsafe { std::mem::MaybeUninit::uninit().assume_init() };
171 for i in 0..N {
172 arr[i] = std::mem::MaybeUninit::new(T::load(&self[i], file)?);
173 }
174 Ok(unsafe {
175 let ptr = &arr as *const _ as *const [T::Output; N];
176 ptr.read()
177 })
178 }
179}