hpt_dataloader/
compression_trait.rs

1use std::{collections::HashMap, io::Write};
2
3use hpt_common::{shape::shape::Shape, slice::Slice, strides::strides::Strides};
4use hpt_traits::{CommonBounds, TensorCreator, TensorInfo};
5use num::traits::{FromBytes, ToBytes};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    data_loader::{Endian, HeaderInfo},
10    load::load_compressed_slice,
11    save::save,
12};
13
/// Uniform interface over encoders that compress into an in-memory buffer.
///
/// In this file the `impl_compression_trait!` macro implements the trait for
/// the `flate2` write-side encoders by forwarding each method to the
/// encoder's own `write_all`, `flush` and `finish`.
pub trait CompressionTrait {
    /// Feeds the whole of `buf` into the encoder.
    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()>;

    /// Flushes the encoder.
    fn flush_all(&mut self) -> std::io::Result<()>;

    /// Consumes the encoder and hands back the bytes it produced.
    fn finish_all(self) -> std::io::Result<Vec<u8>>;
}
19
/// Implements `CompressionTrait` for every encoder type that is listed.
///
/// Invocation shape: `impl_compression_trait!(["label", Type, ...], ...)`.
///
/// * The string label only documents the entry at the call site; it is not
///   used in the generated code.
/// * Each entry may list one or more types and gets one `impl` per type.
///   (Previously the types of an entry were pasted together into a single
///   `impl ... for A B`, which cannot compile for more than one type.)
/// * Trailing commas are accepted inside an entry and after the last entry.
///
/// Every listed type must offer `write_all` and `flush` (normally through
/// `std::io::Write`, which has to be in scope where the macro is invoked) and
/// a by-value `finish` returning `std::io::Result<Vec<u8>>`.
macro_rules! impl_compression_trait {
    ($([$name:expr, $($t:ty),+ $(,)?]),* $(,)?) => {
        $(
            $(
                impl CompressionTrait for $t {
                    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()> {
                        self.write_all(buf)
                    }
                    fn flush_all(&mut self) -> std::io::Result<()> {
                        self.flush()
                    }
                    fn finish_all(self) -> std::io::Result<Vec<u8>> {
                        self.finish()
                    }
                }
            )+
        )*
    };
}
37
38impl_compression_trait!(
39    ["gzip", flate2::write::GzEncoder<Vec<u8>>],
40    ["deflate", flate2::write::DeflateEncoder<Vec<u8>>],
41    ["zlib", flate2::write::ZlibEncoder<Vec<u8>>]
42);
43
44pub trait DataLoaderTrait {
45    fn shape(&self) -> &Shape;
46    fn strides(&self) -> &Strides;
47    fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]);
48    fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]);
49    fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]);
50    fn offset(&mut self, offset: isize);
51    fn size(&self) -> usize;
52    fn dtype(&self) -> &'static str;
53    fn mem_size(&self) -> usize;
54}
55
56pub struct DataLoader<T> {
57    pub(crate) shape: Shape,
58    pub(crate) strides: Strides,
59    pub(crate) data: *const T,
60}
61
62impl<T> DataLoader<T> {
63    pub fn new(shape: Shape, strides: Strides, data: *const T) -> Self {
64        Self {
65            shape,
66            strides,
67            data,
68        }
69    }
70}
71
72impl<T, const N: usize> DataLoaderTrait for DataLoader<T>
73where
74    T: CommonBounds + ToBytes<Bytes = [u8; N]>,
75{
76    fn shape(&self) -> &Shape {
77        &self.shape
78    }
79
80    fn strides(&self) -> &Strides {
81        &self.strides
82    }
83
84    fn offset(&mut self, offset: isize) {
85        self.data = unsafe { self.data.offset(offset) };
86    }
87
88    fn size(&self) -> usize {
89        self.shape.size() as usize
90    }
91
92    fn mem_size(&self) -> usize {
93        std::mem::size_of::<T>()
94    }
95
96    fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
97        let val = unsafe { self.data.offset(offset).read() };
98        writer.copy_from_slice(&val.to_ne_bytes());
99    }
100
101    fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
102        let val = unsafe { self.data.offset(offset).read() };
103        writer.copy_from_slice(&val.to_be_bytes());
104    }
105
106    fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
107        let val = unsafe { self.data.offset(offset).read() };
108        writer.copy_from_slice(&val.to_le_bytes());
109    }
110
111    fn dtype(&self) -> &'static str {
112        T::STR
113    }
114}
115
116#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
117pub enum CompressionAlgo {
118    Gzip,
119    Deflate,
120    Zlib,
121    NoCompression,
122}
123
124pub struct Meta {
125    pub name: String,
126    pub compression_algo: CompressionAlgo,
127    pub endian: Endian,
128    pub data_saver: Box<dyn DataLoaderTrait>,
129    pub compression_level: u32,
130}
131
132pub struct TensorSaver {
133    file_path: std::path::PathBuf,
134    to_saves: Option<Vec<Meta>>,
135}
136
137impl TensorSaver {
138    pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
139        Self {
140            file_path: file_path.into(),
141            to_saves: None,
142        }
143    }
144
145    pub fn push<
146        const N: usize,
147        T: ToBytes<Bytes = [u8; N]> + CommonBounds,
148        A: Into<DataLoader<T>>,
149    >(
150        mut self,
151        name: &str,
152        tensor: A,
153        compression_algo: CompressionAlgo,
154        endian: Endian,
155        compression_level: u32,
156    ) -> Self {
157        let data_loader = tensor.into();
158        let meta = Meta {
159            name: name.to_string(),
160            compression_algo,
161            endian,
162            data_saver: Box::new(data_loader),
163            compression_level,
164        };
165        if let Some(to_saves) = &mut self.to_saves {
166            to_saves.push(meta);
167        } else {
168            self.to_saves = Some(vec![meta]);
169        }
170        self
171    }
172
173    pub fn save(self) -> std::io::Result<()> {
174        save(
175            self.file_path.to_str().unwrap().into(),
176            self.to_saves.unwrap(),
177        )
178    }
179}
180
181pub struct TensorLoader {
182    file_path: std::path::PathBuf,
183    to_loads: Option<Vec<(String, Vec<Slice>)>>,
184}
185
186impl TensorLoader {
187    pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
188        Self {
189            file_path: file_path.into(),
190            to_loads: None,
191        }
192    }
193
194    pub fn push(mut self, name: &str, slices: &[Slice]) -> Self {
195        if let Some(to_loads) = &mut self.to_loads {
196            to_loads.push((name.to_string(), slices.to_vec()));
197        } else {
198            self.to_loads = Some(vec![(name.to_string(), slices.to_vec())]);
199        }
200        self
201    }
202
203    pub fn load<T, B, const N: usize>(self) -> std::io::Result<HashMap<String, B>>
204    where
205        T: CommonBounds + FromBytes<Bytes = [u8; N]>,
206        B: TensorCreator<T, Output = B> + Clone + TensorInfo<T>,
207    {
208        let res = load_compressed_slice::<T, B, N>(
209            self.file_path.to_str().unwrap().into(),
210            self.to_loads.expect("no tensors to load"),
211        )
212        .expect("failed to load tensor");
213        Ok(res)
214    }
215
216    pub fn load_all<T, B, const N: usize>(self) -> std::io::Result<HashMap<String, B>>
217    where
218        T: CommonBounds + FromBytes<Bytes = [u8; N]>,
219        B: TensorCreator<T, Output = B> + Clone + TensorInfo<T>,
220    {
221        let res = HeaderInfo::parse_header_compressed(self.file_path.to_str().unwrap().into())
222            .expect("failed to parse header");
223        let res = load_compressed_slice::<T, B, N>(
224            self.file_path.to_str().unwrap().into(),
225            res.into_values().map(|x| (x.name, vec![])).collect(),
226        )
227        .expect("failed to load tensor");
228        Ok(res)
229    }
230}