hpt_dataloader/struct_save/
load.rs

1use std::{
2    fmt::Display,
3    fs::File,
4    io::{Read, Seek},
5    marker::PhantomData,
6};
7
8use flate2::read::{DeflateDecoder, GzDecoder, ZlibDecoder};
9use hpt_common::shape::shape::Shape;
10use hpt_traits::tensor::{CommonBounds, TensorInfo};
11
12use crate::{
13    data_loader::{parse_header_compressed, TensorMeta},
14    CPUTensorCreator, CompressionAlgo,
15};
16
17use super::save::Save;
18
19pub trait Load: Sized + Save {
20    fn load<P: Into<std::path::PathBuf>>(path: P) -> std::io::Result<Self>
21    where
22        <Self as Save>::Meta: MetaLoad<Output = Self>,
23    {
24        let path: std::path::PathBuf = path.into();
25        let meta = parse_header_compressed::<Self, _>(&path).expect("failed to parse header");
26        let mut file = File::open(path)?;
27        meta.load(&mut file)
28    }
29}
30
31pub(crate) fn load<'a, T: CommonBounds + bytemuck::AnyBitPattern, B: CPUTensorCreator>(
32    file: &mut File,
33    meta: &TensorMeta<T, B>,
34) -> std::io::Result<<B as CPUTensorCreator>::Output>
35where
36    <B as CPUTensorCreator>::Output: Clone + TensorInfo<T> + Display,
37{
38    if meta.dtype != T::STR {
39        return Err(std::io::Error::new(
40            std::io::ErrorKind::InvalidInput,
41            format!(
42                "the dtype stored is {}, but the dtype requested is {}",
43                meta.dtype,
44                T::STR
45            ),
46        ));
47    }
48    // since the shape is scaled with mem_size, we need to scale it back
49    let shape = meta.shape.clone();
50    // create the tensor
51    let tensor = B::empty(&shape)
52        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e.to_string()))?;
53
54    if tensor.size() == 0 {
55        return Ok(tensor);
56    }
57    #[allow(unused_mut)]
58    let mut res: &mut [T] =
59        unsafe { std::slice::from_raw_parts_mut(tensor.ptr().ptr, tensor.size()) };
60
61    let mut res_idx = 0;
62    for (_, idx, compressed_len, current_block_mem_size) in meta.indices.iter() {
63        let uncompressed_data = uncompress_data(
64            file,
65            *idx as u64,
66            *compressed_len,
67            *current_block_mem_size,
68            meta.compression_algo,
69        )?;
70        for idx in (0..uncompressed_data.len()).step_by(std::mem::size_of::<T>()) {
71            let val = &uncompressed_data[idx..idx + std::mem::size_of::<T>()];
72            res[res_idx] = *bytemuck::from_bytes(val);
73            res_idx += 1;
74        }
75    }
76    Ok(tensor)
77}
78
79fn uncompress_data(
80    file: &mut File,
81    idx: u64,
82    compressed_len: usize,
83    block_mem_size: usize,
84    compression_type: CompressionAlgo,
85) -> std::io::Result<Vec<u8>> {
86    file.seek(std::io::SeekFrom::Start(idx))?;
87    let mut compressed_vec = vec![0u8; compressed_len];
88    file.read_exact(&mut compressed_vec)?;
89    let mut uncompressed_data = vec![0u8; block_mem_size];
90    match compression_type {
91        CompressionAlgo::Gzip => {
92            let mut decoder = GzDecoder::new(compressed_vec.as_slice());
93            decoder.read_exact(&mut uncompressed_data)?;
94        }
95        CompressionAlgo::Deflate => {
96            let mut decoder = DeflateDecoder::new(compressed_vec.as_slice());
97            decoder.read_exact(&mut uncompressed_data)?;
98        }
99        CompressionAlgo::Zlib => {
100            let mut decoder = ZlibDecoder::new(compressed_vec.as_slice());
101            decoder.read_exact(&mut uncompressed_data)?;
102        }
103        CompressionAlgo::NoCompression => {
104            uncompressed_data = compressed_vec;
105        }
106    }
107    Ok(uncompressed_data)
108}
109
/// Rebuilds a value from the metadata that was saved for it.
///
/// `Load::load` requires `<Self as Save>::Meta` to implement this trait with `Output = Self`.
/// For the plain types covered by `impl_load!` the metadata is the value itself and is simply
/// cloned; other implementations may read additional payload from `file`.
pub trait MetaLoad: Sized {
    /// The value reconstructed from this metadata.
    type Output;

    /// Reconstructs the value, reading from `file` if the metadata refers to stored data.
    fn load(&self, file: &mut File) -> std::io::Result<Self::Output>;
}
114
/// Implements `MetaLoad` for a type whose saved metadata is the value itself: loading clones
/// `self` and never reads from the file.
///
/// Usage: `impl_load!(Type, default $(, Generic)*)`
/// - `Type`: the type to implement `MetaLoad` for; it is also the `Output`.
/// - `default`: accepted for call-site compatibility but not used by the expansion.
/// - `Generic`: optional generic parameters of `Type`, comma separated (e.g. the `T` of
///   `PhantomData<T>`). They are re-emitted comma separated, so several generics work.
macro_rules! impl_load {
    ($struct:ty, $default:expr $(, $generics:tt)*) => {
        impl<$($generics),*> MetaLoad for $struct {
            type Output = $struct;
            fn load(&self, _: &mut std::fs::File) -> std::io::Result<Self::Output> {
                Ok(self.clone())
            }
        }
    };
}
125
126impl_load!(bool, false);
127impl_load!(u8, 0);
128impl_load!(i8, 0);
129impl_load!(u16, 0);
130impl_load!(i16, 0);
131impl_load!(u32, 0);
132impl_load!(i32, 0);
133impl_load!(u64, 0);
134impl_load!(i64, 0);
135impl_load!(f32, 0.0);
136impl_load!(f64, 0.0);
137impl_load!(usize, 0);
138impl_load!(isize, 0);
139impl_load!(half::f16, half::f16::ZERO);
140impl_load!(half::bf16, half::bf16::ZERO);
141impl_load!(String, String::new());
142impl_load!(Shape, Shape::new([]));
143impl_load!(PhantomData<T>, Self, T);
144
145impl<T: MetaLoad> MetaLoad for Option<T> {
146    type Output = Option<T::Output>;
147    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
148        match self {
149            Some(x) => Ok(Some(T::load(x, file)?)),
150            None => Ok(None),
151        }
152    }
153}
154
155impl<T: MetaLoad> MetaLoad for Vec<T> {
156    type Output = Vec<T::Output>;
157    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
158        let mut res = Vec::with_capacity(self.len());
159        for i in 0..self.len() {
160            res.push(T::load(&self[i], file)?);
161        }
162        Ok(res)
163    }
164}
165
166impl<T: MetaLoad, const N: usize> MetaLoad for [T; N] {
167    type Output = [T::Output; N];
168    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
169        let mut arr: [std::mem::MaybeUninit<T::Output>; N] =
170            unsafe { std::mem::MaybeUninit::uninit().assume_init() };
171        for i in 0..N {
172            arr[i] = std::mem::MaybeUninit::new(T::load(&self[i], file)?);
173        }
174        Ok(unsafe {
175            let ptr = &arr as *const _ as *const [T::Output; N];
176            ptr.read()
177        })
178    }
179}