1use std::io::{Seek, Write};
2use std::marker::PhantomData;
3
4use crate::compression_trait::{CompressionAlgo, DataLoaderTrait, Meta};
5use crate::Endian;
6use crate::{compression_trait::CompressionTrait, CHUNK_BUFF};
7use flate2::write::{DeflateEncoder, GzEncoder, ZlibEncoder};
8use hpt_common::shape::shape::Shape;
9use indicatif::ProgressBar;
10
11pub trait Save {
12 type Meta: for<'a> serde::Deserialize<'a>;
13 fn __save(
14 data: &Self,
15 file: &mut std::fs::File,
16 len_so_far: &mut usize,
17 global_cnt: &mut usize,
18 compression_algo: CompressionAlgo,
19 endian: Endian,
20 level: u32,
21 ) -> std::io::Result<Self::Meta>;
22 fn save(&self, path: &str) -> std::io::Result<()>
23 where
24 <Self as Save>::Meta: serde::Serialize,
25 {
26 let mut file = std::fs::File::create(path)?;
27 let meta = <Self as Save>::__save(
28 &self,
29 &mut file,
30 &mut 0,
31 &mut 0,
32 CompressionAlgo::NoCompression,
33 Endian::Native,
34 9,
35 )?;
36 let serialized = serde_json::to_string(&meta)?;
37 file.write_all(serialized.as_bytes())?;
38 Ok(())
39 }
40}
41
42fn generate_header_compressed(
43 meta: &Meta,
44) -> (
45 (
46 usize, String, Vec<i64>, Vec<i64>, usize, String, CompressionAlgo, Endian, Vec<(usize, usize, usize, usize)>, ),
56 (String, usize, usize, usize, usize),
57) {
58 let info = (
59 0usize,
60 meta.name.clone(),
61 meta.data_saver.shape().to_vec(),
62 meta.data_saver.shape().to_strides().to_vec(),
63 meta.data_saver.size(),
64 meta.data_saver.dtype().to_string(),
65 meta.compression_algo,
66 meta.endian,
67 vec![],
68 );
69 let res = {
70 let x = &meta.data_saver;
71 let outer = x.size() / (*x.shape().last().unwrap() as usize);
72 let inner = (*x.shape().last().unwrap() as usize) * x.mem_size();
73 let num_chunks;
74 let mut num_lines;
75 let mut remain = 0;
76 let mut buffer_size;
77 if x.size() * x.mem_size() < CHUNK_BUFF {
78 num_chunks = 1;
79 num_lines = outer;
80 buffer_size = num_lines * inner;
81 } else {
82 buffer_size = ((CHUNK_BUFF - 1) / inner) * inner;
83 num_lines = buffer_size / inner;
84 if num_lines == 0 {
85 num_lines = 1;
86 buffer_size = inner;
87 }
88 remain = outer % num_lines;
89 num_chunks = outer / num_lines;
90 }
91 (
92 meta.name.clone(),
93 num_chunks,
94 num_lines,
95 remain,
96 buffer_size,
97 )
98 };
99
100 (info, res)
101}
102
103pub fn save(
109 file: &mut std::fs::File,
110 mut meta: Meta,
111 len: &mut usize,
112 global_cnt: usize,
113) -> std::io::Result<(
114 usize, String, Vec<i64>, Vec<i64>, usize, String, CompressionAlgo, Endian, Vec<(usize, usize, usize, usize)>, )> {
124 let total_size: usize = meta.data_saver.size();
127
128 let pb = ProgressBar::new(total_size as u64);
129 pb.set_style(
130 indicatif::ProgressStyle::default_bar()
131 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
132 .unwrap(),
133 );
134 let (mut info, save_config) = generate_header_compressed(&meta);
141 const MAGIC_HEADER: &str = "FASTTENSOR";
146 const PLACEHOLDER: &str = "FASTTENSORHPLACEHOLD";
147 const HEADER_LEN: usize = MAGIC_HEADER.len() + PLACEHOLDER.len();
148 let mut len_so_far = if global_cnt == 0 {
149 file.write_all((MAGIC_HEADER.to_owned() + PLACEHOLDER).as_bytes())?;
150 HEADER_LEN
151 } else {
152 *len
153 };
154 let last_stride: i64 = *meta.data_saver.strides().last().unwrap() as i64;
155 let mut prg: Vec<i64> = vec![0; meta.data_saver.shape().len() - 1];
156 let mut shape: Vec<i64> = meta.data_saver.shape().iter().map(|x| *x as i64).collect();
157 shape.iter_mut().for_each(|x: &mut i64| {
158 *x -= 1;
159 });
160 let inner_loop_size: usize = *meta.data_saver.shape().last().unwrap() as usize;
161 let line_num: usize = save_config.2;
162 let num_chunks: usize = save_config.1;
163 let buffer_size: usize = save_config.4;
164 let mut attributes = vec![];
165 let mut chunk = vec![0u8; buffer_size];
166 let unpack = get_unpack_closure(meta.endian);
167 for k in 0..num_chunks {
168 for j in 0..line_num {
169 for i in 0..inner_loop_size {
170 let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
171 let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
172 unpack(
173 &mut meta.data_saver,
174 ((i as i64) * last_stride) as isize,
175 &mut chunk[start..end],
176 );
177 }
178 pb.inc(inner_loop_size as u64);
179 for h in (0..shape.len() - 1).rev() {
180 if prg[h] < shape[h] {
181 prg[h] += 1;
182 meta.data_saver
183 .offset(meta.data_saver.strides()[h] as isize);
184 break;
185 } else {
186 prg[h] = 0;
187 meta.data_saver
188 .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
189 }
190 }
191 }
192 compress_data(
193 &meta,
194 &chunk,
195 file,
196 &mut attributes,
197 &mut len_so_far,
198 k,
199 line_num,
200 )?;
201 }
202 let remain_outer: usize = save_config.3;
203 let mut remain_chunk = vec![0u8; remain_outer * inner_loop_size * meta.data_saver.mem_size()];
204 for j in 0..remain_outer {
205 for i in 0..inner_loop_size {
206 let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
207 let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
208 unpack(
209 &mut meta.data_saver,
210 ((i as i64) * last_stride) as isize,
211 &mut remain_chunk[start..end],
212 );
213 }
214 pb.inc(inner_loop_size as u64);
215 for h in (0..shape.len() - 1).rev() {
216 if prg[h] < shape[h] {
217 prg[h] += 1;
218 meta.data_saver
219 .offset(meta.data_saver.strides()[h] as isize);
220 break;
221 } else {
222 prg[h] = 0;
223 meta.data_saver
224 .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
225 }
226 }
227 }
228 compress_data(
229 &meta,
230 &remain_chunk,
231 file,
232 &mut attributes,
233 &mut len_so_far,
234 num_chunks,
235 line_num,
236 )?;
237 let current_pos = file.seek(std::io::SeekFrom::Current(0))?;
238 file.seek(std::io::SeekFrom::Start(0))?;
239 file.write_all(format!("FASTTENSOR{:20}", current_pos).as_bytes())?;
240 file.seek(std::io::SeekFrom::Start(current_pos))?;
241 let length = *len;
242 *len = len_so_far;
243 info.8 = attributes;
244 info.0 = length;
245
246 Ok(info)
247}
248
249fn compress_data(
250 meta: &Meta,
251 chunk: &[u8],
252 file: &mut std::fs::File,
253 attributes: &mut Vec<(usize, usize, usize, usize)>,
254 len_so_far: &mut usize,
255 k: usize,
256 line_num: usize,
257) -> std::io::Result<()> {
258 let mut closure = |compressed_data: &[u8]| -> std::io::Result<()> {
259 file.write_all(compressed_data)?;
260 attributes.push((
261 k * line_num,
262 *len_so_far, compressed_data.len(), chunk.len(), ));
266 *len_so_far += compressed_data.len();
267 Ok(())
268 };
269
270 match meta.compression_algo {
271 CompressionAlgo::Gzip => {
272 let mut encoder =
273 GzEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
274 encoder.write_all_data(chunk)?;
275 encoder.flush_all()?;
276 let compressed_data = encoder.finish_all()?;
277 closure(&compressed_data)?
278 }
279 CompressionAlgo::Deflate => {
280 let mut encoder =
281 DeflateEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
282 encoder.write_all_data(chunk)?;
283 encoder.flush_all()?;
284 let compressed_data = encoder.finish_all()?;
285 closure(&compressed_data)?
286 }
287 CompressionAlgo::Zlib => {
288 let mut encoder =
289 ZlibEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
290 encoder.write_all_data(chunk)?;
291 encoder.flush_all()?;
292 let compressed_data = encoder.finish_all()?;
293 closure(&compressed_data)?
294 }
295 CompressionAlgo::NoCompression => closure(chunk)?,
296 }
297 Ok(())
298}
299
300fn get_unpack_closure(endian: Endian) -> impl Fn(&mut Box<dyn DataLoaderTrait>, isize, &mut [u8]) {
301 match endian {
302 Endian::Little => {
303 |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
304 data_saver.fill_le_bytes_slice(offset, data)
305 }
306 }
307 Endian::Big => {
308 |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
309 data_saver.fill_be_bytes_slice(offset, data)
310 }
311 }
312 Endian::Native => {
313 |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
314 data_saver.fill_ne_bytes_slice(offset, data)
315 }
316 }
317 }
318}
319
320macro_rules! impl_save {
321 ($struct:ident) => {
322 impl Save for $struct {
323 type Meta = Self;
324 fn __save(
325 data: &Self,
326 _: &mut std::fs::File,
327 _: &mut usize,
328 _: &mut usize,
329 _: CompressionAlgo,
330 _: Endian,
331 _: u32,
332 ) -> std::io::Result<Self> {
333 Ok(data.clone())
334 }
335 }
336 };
337}
338
339impl_save!(bool);
340impl_save!(i8);
341impl_save!(i16);
342impl_save!(i32);
343impl_save!(i64);
344impl_save!(u8);
345impl_save!(u16);
346impl_save!(u32);
347impl_save!(u64);
348impl_save!(f32);
349impl_save!(f64);
350impl_save!(usize);
351impl_save!(isize);
352impl_save!(String);
353impl_save!(Shape);
354
355impl<T> Save for PhantomData<T> {
356 type Meta = Self;
357 fn __save(
358 data: &Self,
359 _: &mut std::fs::File,
360 _: &mut usize,
361 _: &mut usize,
362 _: CompressionAlgo,
363 _: Endian,
364 _: u32,
365 ) -> std::io::Result<Self> {
366 Ok(*data)
367 }
368}
369
370impl<T: Save> Save for Option<T> {
371 type Meta = Option<T::Meta>;
372 fn __save(
373 data: &Self,
374 file: &mut std::fs::File,
375 len: &mut usize,
376 global_cnt: &mut usize,
377 compression_algo: CompressionAlgo,
378 endian: Endian,
379 level: u32,
380 ) -> std::io::Result<Self::Meta> {
381 match data {
382 Some(x) => Ok(Some(T::__save(
383 x,
384 file,
385 len,
386 global_cnt,
387 compression_algo,
388 endian,
389 level,
390 )?)),
391 None => Ok(None),
392 }
393 }
394}
395
396impl<T: Save> Save for Vec<T> {
397 type Meta = Vec<T::Meta>;
398 fn __save(
399 data: &Self,
400 file: &mut std::fs::File,
401 len: &mut usize,
402 global_cnt: &mut usize,
403 compression_algo: CompressionAlgo,
404 endian: Endian,
405 level: u32,
406 ) -> std::io::Result<Self::Meta> {
407 let mut res = Vec::with_capacity(data.len());
408 for i in 0..data.len() {
409 res.push(T::__save(
410 &data[i],
411 file,
412 len,
413 global_cnt,
414 compression_algo,
415 endian,
416 level,
417 )?);
418 }
419 Ok(res)
420 }
421}
422
423impl<T: Save, const N: usize> Save for [T; N]
424where
425 [T::Meta; N]: for<'a> serde::Deserialize<'a>,
426{
427 type Meta = [T::Meta; N];
428 fn __save(
429 data: &Self,
430 file: &mut std::fs::File,
431 len: &mut usize,
432 global_cnt: &mut usize,
433 compression_algo: CompressionAlgo,
434 endian: Endian,
435 level: u32,
436 ) -> std::io::Result<Self::Meta> {
437 let mut arr: [std::mem::MaybeUninit<T::Meta>; N] =
438 unsafe { std::mem::MaybeUninit::uninit().assume_init() };
439
440 for i in 0..N {
441 arr[i] = std::mem::MaybeUninit::new(T::__save(
442 &data[i],
443 file,
444 len,
445 global_cnt,
446 compression_algo,
447 endian,
448 level,
449 )?);
450 }
451
452 Ok(unsafe {
453 let ptr = &arr as *const _ as *const [T::Meta; N];
454 ptr.read()
455 })
456 }
457}