1use std::{collections::HashMap, io::Write};
2
3use hpt_common::{shape::shape::Shape, slice::Slice, strides::strides::Strides};
4use hpt_traits::{CommonBounds, TensorInfo};
5use num::traits::{FromBytes, ToBytes};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 data_loader::{Endian, HeaderInfo},
10 load::load_compressed_slice,
11 save::save,
12 CPUTensorCreator,
13};
14
/// Minimal interface over a byte-oriented compressor that accumulates its
/// output in memory.
///
/// In this file the trait is implemented (through `impl_compression_trait!`)
/// by forwarding to an encoder's `write_all`, `flush` and `finish` methods.
pub trait CompressionTrait {
    /// Feeds the whole of `buf` into the compressor.
    ///
    /// # Errors
    /// Returns whatever I/O error the underlying writer reports.
    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()>;

    /// Flushes any data buffered inside the compressor.
    ///
    /// # Errors
    /// Returns whatever I/O error the underlying writer reports.
    fn flush_all(&mut self) -> std::io::Result<()>;

    /// Consumes the compressor and returns the bytes it produced.
    ///
    /// # Errors
    /// Returns whatever I/O error the underlying encoder reports while
    /// finishing the stream.
    fn finish_all(self) -> std::io::Result<Vec<u8>>;
}
20
/// Implements `CompressionTrait` for every encoder type listed in the
/// invocation.
///
/// Each bracketed group has the form `[label, Type, Type, ...]`:
///
/// * `label` is any expression (the call site in this file uses a string such
///   as `"gzip"`). It only documents the group and is not used in the
///   expansion.
/// * every `Type` in the group receives its own `impl CompressionTrait`.
///
/// Trailing commas are accepted both inside a group and after the last group.
///
/// Requirements on each `Type`: it must implement `std::io::Write` and expose
/// a `finish(self) -> std::io::Result<Vec<u8>>` method.
macro_rules! impl_compression_trait {
    ($([$name:expr, $($t:ty),* $(,)?]),* $(,)?) => {
        $(
            // One impl per listed type. Pasting all types of a group after a
            // single `impl ... for` would only be valid for exactly one type.
            $(
                impl CompressionTrait for $t {
                    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()> {
                        // Fully qualified so the expansion does not depend on
                        // `std::io::Write` being imported at the call site.
                        std::io::Write::write_all(self, buf)
                    }
                    fn flush_all(&mut self) -> std::io::Result<()> {
                        std::io::Write::flush(self)
                    }
                    fn finish_all(self) -> std::io::Result<Vec<u8>> {
                        self.finish()
                    }
                }
            )*
        )*
    };
}
38
39impl_compression_trait!(
40 ["gzip", flate2::write::GzEncoder<Vec<u8>>],
41 ["deflate", flate2::write::DeflateEncoder<Vec<u8>>],
42 ["zlib", flate2::write::ZlibEncoder<Vec<u8>>]
43);
44
45pub trait DataLoaderTrait {
46 fn shape(&self) -> &Shape;
47 fn strides(&self) -> &Strides;
48 fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]);
49 fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]);
50 fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]);
51 fn offset(&mut self, offset: isize);
52 fn size(&self) -> usize;
53 fn dtype(&self) -> &'static str;
54 fn mem_size(&self) -> usize;
55}
56
57pub struct DataLoader<T> {
58 pub(crate) shape: Shape,
59 pub(crate) strides: Strides,
60 pub(crate) data: *const T,
61}
62
63impl<T> DataLoader<T> {
64 pub fn new(shape: Shape, strides: Strides, data: *const T) -> Self {
65 Self {
66 shape,
67 strides,
68 data,
69 }
70 }
71}
72
73impl<T, const N: usize> DataLoaderTrait for DataLoader<T>
74where
75 T: CommonBounds + ToBytes<Bytes = [u8; N]>,
76{
77 fn shape(&self) -> &Shape {
78 &self.shape
79 }
80
81 fn strides(&self) -> &Strides {
82 &self.strides
83 }
84
85 fn offset(&mut self, offset: isize) {
86 self.data = unsafe { self.data.offset(offset) };
87 }
88
89 fn size(&self) -> usize {
90 self.shape.size() as usize
91 }
92
93 fn mem_size(&self) -> usize {
94 std::mem::size_of::<T>()
95 }
96
97 fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
98 let val = unsafe { self.data.offset(offset).read() };
99 writer.copy_from_slice(&val.to_ne_bytes());
100 }
101
102 fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
103 let val = unsafe { self.data.offset(offset).read() };
104 writer.copy_from_slice(&val.to_be_bytes());
105 }
106
107 fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
108 let val = unsafe { self.data.offset(offset).read() };
109 writer.copy_from_slice(&val.to_le_bytes());
110 }
111
112 fn dtype(&self) -> &'static str {
113 T::STR
114 }
115}
116
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
118pub enum CompressionAlgo {
119 Gzip,
120 Deflate,
121 Zlib,
122 NoCompression,
123}
124
125pub struct Meta {
126 pub name: String,
127 pub compression_algo: CompressionAlgo,
128 pub endian: Endian,
129 pub data_saver: Box<dyn DataLoaderTrait>,
130 pub compression_level: u32,
131}
132
133pub struct TensorSaver {
134 file_path: std::path::PathBuf,
135 to_saves: Option<Vec<Meta>>,
136}
137
138impl TensorSaver {
139 pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
140 Self {
141 file_path: file_path.into(),
142 to_saves: None,
143 }
144 }
145
146 pub fn push<
147 const N: usize,
148 T: ToBytes<Bytes = [u8; N]> + CommonBounds,
149 A: Into<DataLoader<T>>,
150 >(
151 mut self,
152 name: &str,
153 tensor: A,
154 compression_algo: CompressionAlgo,
155 endian: Endian,
156 compression_level: u32,
157 ) -> Self {
158 let data_loader = tensor.into();
159 let meta = Meta {
160 name: name.to_string(),
161 compression_algo,
162 endian,
163 data_saver: Box::new(data_loader),
164 compression_level,
165 };
166 if let Some(to_saves) = &mut self.to_saves {
167 to_saves.push(meta);
168 } else {
169 self.to_saves = Some(vec![meta]);
170 }
171 self
172 }
173
174 pub fn save(self) -> std::io::Result<()> {
175 save(
176 self.file_path.to_str().unwrap().into(),
177 self.to_saves.unwrap(),
178 )
179 }
180}
181
182pub struct TensorLoader {
183 file_path: std::path::PathBuf,
184 to_loads: Option<Vec<(String, Vec<Slice>)>>,
185}
186
187impl TensorLoader {
188 pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
189 Self {
190 file_path: file_path.into(),
191 to_loads: None,
192 }
193 }
194
195 pub fn push(mut self, name: &str, slices: &[Slice]) -> Self {
196 if let Some(to_loads) = &mut self.to_loads {
197 to_loads.push((name.to_string(), slices.to_vec()));
198 } else {
199 self.to_loads = Some(vec![(name.to_string(), slices.to_vec())]);
200 }
201 self
202 }
203
204 pub fn load<T, B, const N: usize>(self) -> std::io::Result<HashMap<String, B>>
205 where
206 T: CommonBounds + FromBytes<Bytes = [u8; N]>,
207 B: CPUTensorCreator<T>,
208 <B as CPUTensorCreator<T>>::Output: Into<B> + TensorInfo<T>,
209 {
210 let res = load_compressed_slice::<T, B, N>(
211 self.file_path.to_str().unwrap().into(),
212 self.to_loads.expect("no tensors to load"),
213 )
214 .expect("failed to load tensor");
215 Ok(res)
216 }
217
218 pub fn load_all<T, B, const N: usize>(self) -> std::io::Result<HashMap<String, B>>
219 where
220 T: CommonBounds + FromBytes<Bytes = [u8; N]>,
221 B: CPUTensorCreator<T>,
222 <B as CPUTensorCreator<T>>::Output: Into<B> + TensorInfo<T>,
223 {
224 let res = HeaderInfo::parse_header_compressed(self.file_path.to_str().unwrap().into())
225 .expect("failed to parse header");
226 let res = load_compressed_slice::<T, B, N>(
227 self.file_path.to_str().unwrap().into(),
228 res.into_values().map(|x| (x.name, vec![])).collect(),
229 )
230 .expect("failed to load tensor");
231 Ok(res)
232 }
233}