1use std::{collections::HashMap, io::Write};
2
3use hpt_common::{shape::shape::Shape, strides::strides::Strides};
4use hpt_traits::tensor::{CommonBounds, TensorInfo};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 data_loader::{Endian, HeaderInfo},
9 load::load_compressed_slice,
10 save::save,
11 utils::ToDataLoader,
12 CPUTensorCreator,
13};
14
/// Common interface over streaming compressors whose finished output is an
/// in-memory byte buffer.
pub trait CompressionTrait {
    /// Consumes the compressor and returns the finished output bytes.
    fn finish_all(self) -> std::io::Result<Vec<u8>>;

    /// Feeds the whole of `buf` into the compressor.
    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()>;

    /// Flushes any internally buffered data through the compressor.
    fn flush_all(&mut self) -> std::io::Result<()>;
}
20
/// Implements [`CompressionTrait`] for encoder types by forwarding to their
/// `std::io::Write` methods (`write_all`, `flush`) and their inherent
/// `finish(self) -> std::io::Result<Vec<u8>>`.
///
/// Syntax: `impl_compression_trait!([label, Type, ...], [label, Type, ...], ...)`.
/// The label is only for readability at the call site and is not used in the
/// expansion. Every listed type gets its own `impl`; the original expansion
/// concatenated all types of one entry into a single (invalid) `impl` header.
/// `std::io::Write` must be in scope at the invocation site so the forwarded
/// method calls resolve.
macro_rules! impl_compression_trait {
    ($([$name:expr, $($t:ty),*]),*) => {
        $(
            $(
                impl CompressionTrait for $t {
                    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()> {
                        self.write_all(buf)
                    }
                    fn flush_all(&mut self) -> std::io::Result<()> {
                        self.flush()
                    }
                    fn finish_all(self) -> std::io::Result<Vec<u8>> {
                        self.finish()
                    }
                }
            )*
        )*
    };
}
38
39impl_compression_trait!(
40 ["gzip", flate2::write::GzEncoder<Vec<u8>>],
41 ["deflate", flate2::write::DeflateEncoder<Vec<u8>>],
42 ["zlib", flate2::write::ZlibEncoder<Vec<u8>>]
43);
44
45pub trait DataLoaderTrait {
46 fn shape(&self) -> &Shape;
47 fn strides(&self) -> &Strides;
48 fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]);
49 fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]);
50 fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]);
51 fn offset(&mut self, offset: isize);
52 fn size(&self) -> usize;
53 fn dtype(&self) -> &'static str;
54 fn mem_size(&self) -> usize;
55}
56
57pub struct DataLoader<T, B> {
58 pub(crate) shape: Shape,
59 pub(crate) strides: Strides,
60 #[allow(dead_code)]
61 pub(crate) tensor: B, pub(crate) data: *const T,
63}
64
65impl<T, B> DataLoader<T, B>
66where
67 B: TensorInfo<T>,
68{
69 pub fn new(shape: Shape, strides: Strides, tensor: B) -> Self {
70 let ptr = tensor.ptr();
71 Self {
72 shape,
73 strides,
74 tensor,
75 data: ptr.ptr as *const T,
76 }
77 }
78}
79
80impl<B, T> DataLoaderTrait for DataLoader<T, B>
81where
82 B: TensorInfo<T>,
83 T: CommonBounds + bytemuck::NoUninit,
84{
85 fn shape(&self) -> &Shape {
86 &self.shape
87 }
88
89 fn strides(&self) -> &Strides {
90 &self.strides
91 }
92
93 fn offset(&mut self, offset: isize) {
94 self.data = unsafe { self.data.offset(offset) };
95 }
96
97 fn size(&self) -> usize {
98 self.shape.size() as usize
99 }
100
101 fn mem_size(&self) -> usize {
102 std::mem::size_of::<T>()
103 }
104
105 fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
106 let val = unsafe { self.data.offset(offset).read() };
107 writer.copy_from_slice(bytemuck::bytes_of(&val));
108 }
109
110 fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
111 let val = unsafe { self.data.offset(offset).read() };
112 writer.copy_from_slice(bytemuck::bytes_of(&val));
113 }
114
115 fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
116 let val = unsafe { self.data.offset(offset).read() };
117 writer.copy_from_slice(bytemuck::bytes_of(&val));
118 }
119
120 fn dtype(&self) -> &'static str {
121 T::STR
122 }
123}
124
125#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
126pub enum CompressionAlgo {
127 Gzip,
128 Deflate,
129 Zlib,
130 NoCompression,
131}
132
133pub struct Meta {
134 pub name: String,
135 pub compression_algo: CompressionAlgo,
136 pub endian: Endian,
137 pub data_saver: Box<dyn DataLoaderTrait>,
138 pub compression_level: u32,
139}
140
141pub struct TensorSaver {
142 file_path: std::path::PathBuf,
143 to_saves: Option<Vec<Meta>>,
144}
145
146impl TensorSaver {
147 pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
148 Self {
149 file_path: file_path.into(),
150 to_saves: None,
151 }
152 }
153
154 pub fn push<T, A>(
155 mut self,
156 name: &str,
157 tensor: A,
158 compression_algo: CompressionAlgo,
159 compression_level: u32,
160 ) -> Self
161 where
162 T: CommonBounds + bytemuck::NoUninit,
163 A: TensorInfo<T> + 'static + ToDataLoader,
164 <A as ToDataLoader>::Output: DataLoaderTrait,
165 {
166 let data_loader = tensor.to_dataloader();
167 let meta = Meta {
168 name: name.to_string(),
169 compression_algo,
170 endian: Endian::Native,
171 data_saver: Box::new(data_loader),
172 compression_level,
173 };
174 if let Some(to_saves) = &mut self.to_saves {
175 to_saves.push(meta);
176 } else {
177 self.to_saves = Some(vec![meta]);
178 }
179 self
180 }
181
182 pub fn save(self) -> std::io::Result<()> {
183 save(
184 self.file_path.to_str().unwrap().into(),
185 self.to_saves.unwrap(),
186 )
187 }
188}
189
/// Builder that selects tensors, with optional slices, to read back from a saved file.
pub struct TensorLoader {
    /// Source file.
    file_path: std::path::PathBuf,
    /// `(name, slices)` pairs queued by `push`; `None` until the first push.
    /// Each slice triple is presumably `(start, end, step)` — TODO confirm.
    to_loads: Option<Vec<(String, Vec<(i64, i64, i64)>)>>,
}
194
195impl TensorLoader {
196 pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
197 Self {
198 file_path: file_path.into(),
199 to_loads: None,
200 }
201 }
202
203 pub fn push(mut self, name: &str, slices: &[(i64, i64, i64)]) -> Self {
204 if let Some(to_loads) = &mut self.to_loads {
205 to_loads.push((name.to_string(), slices.to_vec()));
206 } else {
207 self.to_loads = Some(vec![(name.to_string(), slices.to_vec())]);
208 }
209 self
210 }
211
212 pub fn load<B>(self) -> std::io::Result<HashMap<String, B>>
213 where
214 B: CPUTensorCreator,
215 <B as CPUTensorCreator>::Output: Into<B> + TensorInfo<<B as CPUTensorCreator>::Meta>,
216 <B as CPUTensorCreator>::Meta: CommonBounds + bytemuck::AnyBitPattern,
217 {
218 let res = load_compressed_slice::<B>(
219 self.file_path.to_str().unwrap().into(),
220 self.to_loads.expect("no tensors to load"),
221 )
222 .expect("failed to load tensor");
223 Ok(res)
224 }
225
226 pub fn load_all<B>(self) -> std::io::Result<HashMap<String, B>>
227 where
228 B: CPUTensorCreator,
229 <B as CPUTensorCreator>::Output: Into<B> + TensorInfo<<B as CPUTensorCreator>::Meta>,
230 <B as CPUTensorCreator>::Meta: CommonBounds + bytemuck::AnyBitPattern,
231 {
232 let res = HeaderInfo::parse_header_compressed(self.file_path.to_str().unwrap().into())
233 .expect("failed to parse header");
234 let res = load_compressed_slice::<B>(
235 self.file_path.to_str().unwrap().into(),
236 res.into_values().map(|x| (x.name, vec![])).collect(),
237 )
238 .expect("failed to load tensor");
239 Ok(res)
240 }
241}