hpt_dataloader/struct_save/
load.rs

1use std::{
2    fmt::Display,
3    fs::File,
4    io::{Read, Seek},
5    marker::PhantomData,
6};
7
8use flate2::read::{DeflateDecoder, GzDecoder, ZlibDecoder};
9use hpt_common::shape::shape::Shape;
10use hpt_traits::{CommonBounds, TensorInfo};
11use num::traits::FromBytes;
12
13use crate::{
14    data_loader::{parse_header_compressed, TensorMeta},
15    CPUTensorCreator, CompressionAlgo, Endian,
16};
17
18use super::save::Save;
19
20pub trait Load: Sized + Save {
21    fn load(path: &str) -> std::io::Result<Self>
22    where
23        <Self as Save>::Meta: MetaLoad<Output = Self>,
24    {
25        let meta = parse_header_compressed::<Self>(path).expect("failed to parse header");
26        let mut file = File::open(path)?;
27        meta.load(&mut file)
28    }
29}
30
31pub(crate) fn load<
32    'a,
33    T: CommonBounds + FromBytes<Bytes = [u8; N]>,
34    B: CPUTensorCreator<T>,
35    const N: usize,
36>(
37    file: &mut File,
38    meta: &TensorMeta<T, B>,
39) -> std::io::Result<<B as CPUTensorCreator<T>>::Output>
40where
41    <B as CPUTensorCreator<T>>::Output: Clone + TensorInfo<T> + Display,
42{
43    if meta.dtype != T::STR {
44        return Err(std::io::Error::new(
45            std::io::ErrorKind::InvalidInput,
46            format!(
47                "the dtype stored is {}, but the dtype requested is {}",
48                meta.dtype,
49                T::STR
50            ),
51        ));
52    }
53    // since the shape is scaled with mem_size, we need to scale it back
54    let shape = meta.shape.clone();
55    // create the tensor
56    let tensor = B::empty(&shape)
57        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e.to_string()))?;
58
59    if tensor.size() == 0 {
60        return Ok(tensor);
61    }
62    #[allow(unused_mut)]
63    let mut res: &mut [T] =
64        unsafe { std::slice::from_raw_parts_mut(tensor.ptr().ptr, tensor.size()) };
65
66    let pack = get_pack_closure::<T, N>(meta.endian);
67
68    let mut res_idx = 0;
69    for (_, idx, compressed_len, current_block_mem_size) in meta.indices.iter() {
70        let uncompressed_data = uncompress_data(
71            file,
72            *idx as u64,
73            *compressed_len,
74            *current_block_mem_size,
75            meta.compression_algo,
76        )?;
77        for idx in (0..uncompressed_data.len()).step_by(std::mem::size_of::<T>()) {
78            let val = &uncompressed_data[idx..idx + std::mem::size_of::<T>()];
79            let data = pack(val);
80            res[res_idx] = data;
81            res_idx += 1;
82        }
83    }
84    Ok(tensor)
85}
86
87fn uncompress_data(
88    file: &mut File,
89    idx: u64,
90    compressed_len: usize,
91    block_mem_size: usize,
92    compression_type: CompressionAlgo,
93) -> std::io::Result<Vec<u8>> {
94    file.seek(std::io::SeekFrom::Start(idx))?;
95    let mut compressed_vec = vec![0u8; compressed_len];
96    file.read_exact(&mut compressed_vec)?;
97    let mut uncompressed_data = vec![0u8; block_mem_size];
98    match compression_type {
99        CompressionAlgo::Gzip => {
100            let mut decoder = GzDecoder::new(compressed_vec.as_slice());
101            decoder.read_exact(&mut uncompressed_data)?;
102        }
103        CompressionAlgo::Deflate => {
104            let mut decoder = DeflateDecoder::new(compressed_vec.as_slice());
105            decoder.read_exact(&mut uncompressed_data)?;
106        }
107        CompressionAlgo::Zlib => {
108            let mut decoder = ZlibDecoder::new(compressed_vec.as_slice());
109            decoder.read_exact(&mut uncompressed_data)?;
110        }
111        CompressionAlgo::NoCompression => {
112            uncompressed_data = compressed_vec;
113        }
114    }
115    Ok(uncompressed_data)
116}
117
118fn get_pack_closure<T: CommonBounds + FromBytes<Bytes = [u8; N]>, const N: usize>(
119    endian: Endian,
120) -> impl Fn(&[u8]) -> T {
121    match endian {
122        Endian::Little => |val: &[u8]| T::from_le_bytes(val.try_into().unwrap()),
123        Endian::Big => |val: &[u8]| T::from_be_bytes(val.try_into().unwrap()),
124        Endian::Native => |val: &[u8]| T::from_ne_bytes(val.try_into().unwrap()),
125    }
126}
127
/// Describes how a piece of saved metadata is turned back into a value.
pub trait MetaLoad: Sized {
    /// The value produced by [`MetaLoad::load`].
    type Output;

    /// Builds the output value, reading from `file` whatever the
    /// implementation needs.
    fn load(&self, file: &mut File) -> std::io::Result<Self::Output>;
}
132
133macro_rules! impl_load {
134    ($struct:ty, $default:expr $(, $generics:tt)*) => {
135        impl<$($generics)*> MetaLoad for $struct {
136            type Output = $struct;
137            fn load(&self, _: &mut std::fs::File) -> std::io::Result<Self::Output> {
138                Ok(self.clone())
139            }
140        }
141    };
142}
143
144impl_load!(bool, false);
145impl_load!(u8, 0);
146impl_load!(i8, 0);
147impl_load!(u16, 0);
148impl_load!(i16, 0);
149impl_load!(u32, 0);
150impl_load!(i32, 0);
151impl_load!(u64, 0);
152impl_load!(i64, 0);
153impl_load!(f32, 0.0);
154impl_load!(f64, 0.0);
155impl_load!(usize, 0);
156impl_load!(isize, 0);
157impl_load!(String, String::new());
158impl_load!(Shape, Shape::new([]));
159impl_load!(PhantomData<T>, Self, T);
160
161impl<T: MetaLoad> MetaLoad for Option<T> {
162    type Output = Option<T::Output>;
163    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
164        match self {
165            Some(x) => Ok(Some(T::load(x, file)?)),
166            None => Ok(None),
167        }
168    }
169}
170
171impl<T: MetaLoad> MetaLoad for Vec<T> {
172    type Output = Vec<T::Output>;
173    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
174        let mut res = Vec::with_capacity(self.len());
175        for i in 0..self.len() {
176            res.push(T::load(&self[i], file)?);
177        }
178        Ok(res)
179    }
180}
181
182impl<T: MetaLoad, const N: usize> MetaLoad for [T; N] {
183    type Output = [T::Output; N];
184    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
185        let mut arr: [std::mem::MaybeUninit<T::Output>; N] =
186            unsafe { std::mem::MaybeUninit::uninit().assume_init() };
187        for i in 0..N {
188            arr[i] = std::mem::MaybeUninit::new(T::load(&self[i], file)?);
189        }
190        Ok(unsafe {
191            let ptr = &arr as *const _ as *const [T::Output; N];
192            ptr.read()
193        })
194    }
195}