// hpt_dataloader/compression_trait.rs

1use std::{collections::HashMap, io::Write};
2
3use hpt_common::{shape::shape::Shape, slice::Slice, strides::strides::Strides};
4use hpt_traits::{CommonBounds, TensorInfo};
5use num::traits::{FromBytes, ToBytes};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    data_loader::{Endian, HeaderInfo},
10    load::load_compressed_slice,
11    save::save,
12    CPUTensorCreator,
13};
14
/// Uniform interface over the streaming compressors used when saving tensors.
///
/// In this file it is implemented (through `impl_compression_trait!`) for flate2's
/// `GzEncoder`, `DeflateEncoder` and `ZlibEncoder` writing into a `Vec<u8>`; those
/// impls delegate to `Write::write_all`, `Write::flush` and the encoder's `finish`.
pub trait CompressionTrait {
    /// Feeds the whole of `buf` into the compressor.
    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()>;
    /// Flushes buffered data through the compressor.
    fn flush_all(&mut self) -> std::io::Result<()>;
    /// Finalizes the stream, consuming the compressor, and returns the output bytes.
    fn finish_all(self) -> std::io::Result<Vec<u8>>;
}
20
/// Implements `CompressionTrait` for one or more encoder types.
///
/// Each entry has the form `[label, Type, Type, ...]`. The `label` expression only
/// documents the entry at the call site and is not used in the expansion. Every
/// listed type gets its own impl (previously several types in one entry were pasted
/// together into invalid syntax). The generated impls call `write_all` and `flush`
/// from `std::io::Write` plus an inherent `finish(self) -> io::Result<Vec<u8>>`, so
/// each type must provide those and `std::io::Write` must be in scope where the
/// macro is invoked. A trailing comma after the last entry is accepted.
macro_rules! impl_compression_trait {
    ($([$name:expr, $($t:ty),*]),* $(,)?) => {
        $(
            $(
                impl CompressionTrait for $t {
                    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()> {
                        self.write_all(buf)
                    }
                    fn flush_all(&mut self) -> std::io::Result<()> {
                        self.flush()
                    }
                    fn finish_all(self) -> std::io::Result<Vec<u8>> {
                        self.finish()
                    }
                }
            )*
        )*
    };
}
38
39impl_compression_trait!(
40    ["gzip", flate2::write::GzEncoder<Vec<u8>>],
41    ["deflate", flate2::write::DeflateEncoder<Vec<u8>>],
42    ["zlib", flate2::write::ZlibEncoder<Vec<u8>>]
43);
44
/// Type-erased access to a tensor's layout and element bytes.
///
/// `TensorSaver::push` stores implementors as `Box<dyn DataLoaderTrait>` (the
/// `Meta::data_saver` field), so the save path can read elements without knowing
/// their concrete type.
pub trait DataLoaderTrait {
    /// Logical shape of the tensor.
    fn shape(&self) -> &Shape;
    /// Strides of the tensor's data.
    fn strides(&self) -> &Strides;
    /// Writes the element at `offset` into `writer` in native byte order.
    ///
    /// In the `DataLoader<T>` impl, `offset` is counted in elements from the current
    /// base pointer and `writer.len()` must equal the element's byte width, otherwise
    /// `copy_from_slice` panics.
    fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]);
    /// Same as `fill_ne_bytes_slice`, but writes big-endian bytes.
    fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]);
    /// Same as `fill_ne_bytes_slice`, but writes little-endian bytes.
    fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]);
    /// Moves the base data pointer by `offset` elements (in the `DataLoader<T>` impl).
    fn offset(&mut self, offset: isize);
    /// Total number of elements (`shape.size()` in the `DataLoader<T>` impl).
    fn size(&self) -> usize;
    /// Name of the element type (`T::STR` in the `DataLoader<T>` impl).
    fn dtype(&self) -> &'static str;
    /// Size in bytes of one element (`size_of::<T>()` in the `DataLoader<T>` impl).
    fn mem_size(&self) -> usize;
}
56
/// Non-owning view of tensor data to be serialized: shape, strides and a raw base
/// pointer.
///
/// NOTE(review): `data` is a raw pointer with no lifetime attached, so the compiler
/// does not stop the underlying buffer from being freed while this loader is alive;
/// callers must keep it valid for as long as the loader is used.
pub struct DataLoader<T> {
    /// Logical shape of the tensor.
    pub(crate) shape: Shape,
    /// Strides of the tensor's data.
    pub(crate) strides: Strides,
    /// Base pointer that element offsets are applied to by the `fill_*_bytes_slice`
    /// methods; moved by `DataLoaderTrait::offset`.
    pub(crate) data: *const T,
}
62
63impl<T> DataLoader<T> {
64    pub fn new(shape: Shape, strides: Strides, data: *const T) -> Self {
65        Self {
66            shape,
67            strides,
68            data,
69        }
70    }
71}
72
73impl<T, const N: usize> DataLoaderTrait for DataLoader<T>
74where
75    T: CommonBounds + ToBytes<Bytes = [u8; N]>,
76{
77    fn shape(&self) -> &Shape {
78        &self.shape
79    }
80
81    fn strides(&self) -> &Strides {
82        &self.strides
83    }
84
85    fn offset(&mut self, offset: isize) {
86        self.data = unsafe { self.data.offset(offset) };
87    }
88
89    fn size(&self) -> usize {
90        self.shape.size() as usize
91    }
92
93    fn mem_size(&self) -> usize {
94        std::mem::size_of::<T>()
95    }
96
97    fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
98        let val = unsafe { self.data.offset(offset).read() };
99        writer.copy_from_slice(&val.to_ne_bytes());
100    }
101
102    fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
103        let val = unsafe { self.data.offset(offset).read() };
104        writer.copy_from_slice(&val.to_be_bytes());
105    }
106
107    fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
108        let val = unsafe { self.data.offset(offset).read() };
109        writer.copy_from_slice(&val.to_le_bytes());
110    }
111
112    fn dtype(&self) -> &'static str {
113        T::STR
114    }
115}
116
/// Compression applied to a tensor's bytes when it is saved.
///
/// NOTE(review): serde's derived impls encode variants by index in
/// non-self-describing formats, so the declaration order below may be part of the
/// on-disk format. Append new variants rather than reordering. TODO confirm which
/// serializer the file header uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompressionAlgo {
    /// gzip stream (presumably produced with flate2's `GzEncoder`; the mapping lives
    /// outside this file).
    Gzip,
    /// Raw deflate stream (presumably flate2's `DeflateEncoder`).
    Deflate,
    /// zlib stream (presumably flate2's `ZlibEncoder`).
    Zlib,
    /// No compression applied (presumably; handled outside this file).
    NoCompression,
}
124
/// Everything `TensorSaver` records about one tensor queued for saving.
pub struct Meta {
    /// Name the tensor is saved under.
    pub name: String,
    /// Compression algorithm for this tensor's bytes.
    pub compression_algo: CompressionAlgo,
    /// Byte order for the element bytes; presumably selects which
    /// `fill_*_bytes_slice` method is used (TODO confirm in `save`).
    pub endian: Endian,
    /// Type-erased access to the tensor's shape, strides and element bytes.
    pub data_saver: Box<dyn DataLoaderTrait>,
    /// Compression level. NOTE(review): the valid range depends on the algorithm
    /// (flate2 levels are 0-9); confirm how `save` interprets it.
    pub compression_level: u32,
}
132
/// Builder that queues named tensors with `push` and writes them all to one file
/// with `save`.
pub struct TensorSaver {
    /// Destination file.
    file_path: std::path::PathBuf,
    /// Queued tensors; `None` until the first `push`.
    to_saves: Option<Vec<Meta>>,
}
137
138impl TensorSaver {
139    pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
140        Self {
141            file_path: file_path.into(),
142            to_saves: None,
143        }
144    }
145
146    pub fn push<
147        const N: usize,
148        T: ToBytes<Bytes = [u8; N]> + CommonBounds,
149        A: Into<DataLoader<T>>,
150    >(
151        mut self,
152        name: &str,
153        tensor: A,
154        compression_algo: CompressionAlgo,
155        endian: Endian,
156        compression_level: u32,
157    ) -> Self {
158        let data_loader = tensor.into();
159        let meta = Meta {
160            name: name.to_string(),
161            compression_algo,
162            endian,
163            data_saver: Box::new(data_loader),
164            compression_level,
165        };
166        if let Some(to_saves) = &mut self.to_saves {
167            to_saves.push(meta);
168        } else {
169            self.to_saves = Some(vec![meta]);
170        }
171        self
172    }
173
174    pub fn save(self) -> std::io::Result<()> {
175        save(
176            self.file_path.to_str().unwrap().into(),
177            self.to_saves.unwrap(),
178        )
179    }
180}
181
/// Builder that records which tensors, and which slices of them, to read from a file.
pub struct TensorLoader {
    /// File to read from.
    file_path: std::path::PathBuf,
    /// Requested `(name, slices)` pairs; `None` until the first `push`.
    to_loads: Option<Vec<(String, Vec<Slice>)>>,
}
186
187impl TensorLoader {
188    pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
189        Self {
190            file_path: file_path.into(),
191            to_loads: None,
192        }
193    }
194
195    pub fn push(mut self, name: &str, slices: &[Slice]) -> Self {
196        if let Some(to_loads) = &mut self.to_loads {
197            to_loads.push((name.to_string(), slices.to_vec()));
198        } else {
199            self.to_loads = Some(vec![(name.to_string(), slices.to_vec())]);
200        }
201        self
202    }
203
204    pub fn load<T, B, const N: usize>(self) -> std::io::Result<HashMap<String, B>>
205    where
206        T: CommonBounds + FromBytes<Bytes = [u8; N]>,
207        B: CPUTensorCreator<T>,
208        <B as CPUTensorCreator<T>>::Output: Into<B> + TensorInfo<T>,
209    {
210        let res = load_compressed_slice::<T, B, N>(
211            self.file_path.to_str().unwrap().into(),
212            self.to_loads.expect("no tensors to load"),
213        )
214        .expect("failed to load tensor");
215        Ok(res)
216    }
217
218    pub fn load_all<T, B, const N: usize>(self) -> std::io::Result<HashMap<String, B>>
219    where
220        T: CommonBounds + FromBytes<Bytes = [u8; N]>,
221        B: CPUTensorCreator<T>,
222        <B as CPUTensorCreator<T>>::Output: Into<B> + TensorInfo<T>,
223    {
224        let res = HeaderInfo::parse_header_compressed(self.file_path.to_str().unwrap().into())
225            .expect("failed to parse header");
226        let res = load_compressed_slice::<T, B, N>(
227            self.file_path.to_str().unwrap().into(),
228            res.into_values().map(|x| (x.name, vec![])).collect(),
229        )
230        .expect("failed to load tensor");
231        Ok(res)
232    }
233}