1use std::io::{Seek, Write};
2use std::marker::PhantomData;
3
4use crate::compression_trait::{CompressionAlgo, DataLoaderTrait, Meta};
5use crate::Endian;
6use crate::{compression_trait::CompressionTrait, CHUNK_BUFF};
7use flate2::write::{DeflateEncoder, GzEncoder, ZlibEncoder};
8use hpt_common::shape::shape::Shape;
9use indicatif::ProgressBar;
10
11pub trait Save {
12 type Meta: for<'a> serde::Deserialize<'a>;
13 fn __save(
14 data: &Self,
15 file: &mut std::fs::File,
16 len_so_far: &mut usize,
17 global_cnt: &mut usize,
18 compression_algo: CompressionAlgo,
19 level: u32,
20 ) -> std::io::Result<Self::Meta>;
21 fn save<P: Into<std::path::PathBuf>>(&self, path: P) -> std::io::Result<()>
22 where
23 <Self as Save>::Meta: serde::Serialize,
24 {
25 let mut file = std::fs::File::create(path.into())?;
26 let meta = <Self as Save>::__save(
27 &self,
28 &mut file,
29 &mut 0,
30 &mut 0,
31 CompressionAlgo::NoCompression,
32 9,
33 )?;
34 let serialized = serde_json::to_string(&meta)?;
35 file.write_all(serialized.as_bytes())?;
36 Ok(())
37 }
38}
39
40fn generate_header_compressed(
41 meta: &Meta,
42) -> (
43 (
44 usize, String, Vec<i64>, Vec<i64>, usize, String, CompressionAlgo, Endian, Vec<(usize, usize, usize, usize)>, ),
54 (String, usize, usize, usize, usize),
55) {
56 let info = (
57 0usize,
58 meta.name.clone(),
59 meta.data_saver.shape().to_vec(),
60 meta.data_saver.shape().to_strides().to_vec(),
61 meta.data_saver.size(),
62 meta.data_saver.dtype().to_string(),
63 meta.compression_algo,
64 meta.endian,
65 vec![],
66 );
67 let res = {
68 let x = &meta.data_saver;
69 let outer = x.size() / (*x.shape().last().unwrap() as usize);
70 let inner = (*x.shape().last().unwrap() as usize) * x.mem_size();
71 let num_chunks;
72 let mut num_lines;
73 let mut remain = 0;
74 let mut buffer_size;
75 if x.size() * x.mem_size() < CHUNK_BUFF {
76 num_chunks = 1;
77 num_lines = outer;
78 buffer_size = num_lines * inner;
79 } else {
80 buffer_size = ((CHUNK_BUFF - 1) / inner) * inner;
81 num_lines = buffer_size / inner;
82 if num_lines == 0 {
83 num_lines = 1;
84 buffer_size = inner;
85 }
86 remain = outer % num_lines;
87 num_chunks = outer / num_lines;
88 }
89 (
90 meta.name.clone(),
91 num_chunks,
92 num_lines,
93 remain,
94 buffer_size,
95 )
96 };
97
98 (info, res)
99}
100
101pub fn save(
107 file: &mut std::fs::File,
108 mut meta: Meta,
109 len: &mut usize,
110 global_cnt: usize,
111) -> std::io::Result<(
112 usize, String, Vec<i64>, Vec<i64>, usize, String, CompressionAlgo, Endian, Vec<(usize, usize, usize, usize)>, )> {
122 let total_size: usize = meta.data_saver.size();
125
126 let pb = ProgressBar::new(total_size as u64);
127 pb.set_style(
128 indicatif::ProgressStyle::default_bar()
129 .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
130 .unwrap(),
131 );
132 let (mut info, save_config) = generate_header_compressed(&meta);
139 const MAGIC_HEADER: &str = "FASTTENSOR";
144 const PLACEHOLDER: &str = "FASTTENSORHPLACEHOLD";
145 const HEADER_LEN: usize = MAGIC_HEADER.len() + PLACEHOLDER.len();
146 let mut len_so_far = if global_cnt == 0 {
147 file.write_all((MAGIC_HEADER.to_owned() + PLACEHOLDER).as_bytes())?;
148 HEADER_LEN
149 } else {
150 *len
151 };
152 let last_stride: i64 = *meta.data_saver.strides().last().unwrap() as i64;
153 let mut prg: Vec<i64> = vec![0; meta.data_saver.shape().len() - 1];
154 let mut shape: Vec<i64> = meta.data_saver.shape().iter().map(|x| *x as i64).collect();
155 shape.iter_mut().for_each(|x: &mut i64| {
156 *x -= 1;
157 });
158 let inner_loop_size: usize = *meta.data_saver.shape().last().unwrap() as usize;
159 let line_num: usize = save_config.2;
160 let num_chunks: usize = save_config.1;
161 let buffer_size: usize = save_config.4;
162 let mut attributes = vec![];
163 let mut chunk = vec![0u8; buffer_size];
164 let unpack = get_unpack_closure(meta.endian);
165 for k in 0..num_chunks {
166 for j in 0..line_num {
167 for i in 0..inner_loop_size {
168 let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
169 let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
170 unpack(
171 &mut meta.data_saver,
172 ((i as i64) * last_stride) as isize,
173 &mut chunk[start..end],
174 );
175 }
176 pb.inc(inner_loop_size as u64);
177 for h in (0..shape.len() - 1).rev() {
178 if prg[h] < shape[h] {
179 prg[h] += 1;
180 meta.data_saver
181 .offset(meta.data_saver.strides()[h] as isize);
182 break;
183 } else {
184 prg[h] = 0;
185 meta.data_saver
186 .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
187 }
188 }
189 }
190 compress_data(
191 &meta,
192 &chunk,
193 file,
194 &mut attributes,
195 &mut len_so_far,
196 k,
197 line_num,
198 )?;
199 }
200 let remain_outer: usize = save_config.3;
201 let mut remain_chunk = vec![0u8; remain_outer * inner_loop_size * meta.data_saver.mem_size()];
202 for j in 0..remain_outer {
203 for i in 0..inner_loop_size {
204 let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
205 let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
206 unpack(
207 &mut meta.data_saver,
208 ((i as i64) * last_stride) as isize,
209 &mut remain_chunk[start..end],
210 );
211 }
212 pb.inc(inner_loop_size as u64);
213 for h in (0..shape.len() - 1).rev() {
214 if prg[h] < shape[h] {
215 prg[h] += 1;
216 meta.data_saver
217 .offset(meta.data_saver.strides()[h] as isize);
218 break;
219 } else {
220 prg[h] = 0;
221 meta.data_saver
222 .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
223 }
224 }
225 }
226 compress_data(
227 &meta,
228 &remain_chunk,
229 file,
230 &mut attributes,
231 &mut len_so_far,
232 num_chunks,
233 line_num,
234 )?;
235 let current_pos = file.seek(std::io::SeekFrom::Current(0))?;
236 file.seek(std::io::SeekFrom::Start(0))?;
237 file.write_all(format!("FASTTENSOR{:20}", current_pos).as_bytes())?;
238 file.seek(std::io::SeekFrom::Start(current_pos))?;
239 let length = *len;
240 *len = len_so_far;
241 info.8 = attributes;
242 info.0 = length;
243
244 Ok(info)
245}
246
247fn compress_data(
248 meta: &Meta,
249 chunk: &[u8],
250 file: &mut std::fs::File,
251 attributes: &mut Vec<(usize, usize, usize, usize)>,
252 len_so_far: &mut usize,
253 k: usize,
254 line_num: usize,
255) -> std::io::Result<()> {
256 let mut closure = |compressed_data: &[u8]| -> std::io::Result<()> {
257 file.write_all(compressed_data)?;
258 attributes.push((
259 k * line_num,
260 *len_so_far, compressed_data.len(), chunk.len(), ));
264 *len_so_far += compressed_data.len();
265 Ok(())
266 };
267
268 match meta.compression_algo {
269 CompressionAlgo::Gzip => {
270 let mut encoder =
271 GzEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
272 encoder.write_all_data(chunk)?;
273 encoder.flush_all()?;
274 let compressed_data = encoder.finish_all()?;
275 closure(&compressed_data)?
276 }
277 CompressionAlgo::Deflate => {
278 let mut encoder =
279 DeflateEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
280 encoder.write_all_data(chunk)?;
281 encoder.flush_all()?;
282 let compressed_data = encoder.finish_all()?;
283 closure(&compressed_data)?
284 }
285 CompressionAlgo::Zlib => {
286 let mut encoder =
287 ZlibEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
288 encoder.write_all_data(chunk)?;
289 encoder.flush_all()?;
290 let compressed_data = encoder.finish_all()?;
291 closure(&compressed_data)?
292 }
293 CompressionAlgo::NoCompression => closure(chunk)?,
294 }
295 Ok(())
296}
297
298fn get_unpack_closure(endian: Endian) -> impl Fn(&mut Box<dyn DataLoaderTrait>, isize, &mut [u8]) {
299 match endian {
300 Endian::Little => {
301 |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
302 data_saver.fill_le_bytes_slice(offset, data)
303 }
304 }
305 Endian::Big => {
306 |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
307 data_saver.fill_be_bytes_slice(offset, data)
308 }
309 }
310 Endian::Native => {
311 |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
312 data_saver.fill_ne_bytes_slice(offset, data)
313 }
314 }
315 }
316}
317
318macro_rules! impl_save {
319 ($struct:ty) => {
320 impl Save for $struct {
321 type Meta = Self;
322 fn __save(
323 data: &Self,
324 _: &mut std::fs::File,
325 _: &mut usize,
326 _: &mut usize,
327 _: CompressionAlgo,
328 _: u32,
329 ) -> std::io::Result<Self> {
330 Ok(data.clone())
331 }
332 }
333 };
334}
335
336impl_save!(bool);
337impl_save!(i8);
338impl_save!(i16);
339impl_save!(i32);
340impl_save!(i64);
341impl_save!(u8);
342impl_save!(u16);
343impl_save!(u32);
344impl_save!(u64);
345impl_save!(f32);
346impl_save!(f64);
347impl_save!(half::f16);
348impl_save!(half::bf16);
349impl_save!(usize);
350impl_save!(isize);
351impl_save!(String);
352impl_save!(Shape);
353
354impl<T> Save for PhantomData<T> {
355 type Meta = Self;
356 fn __save(
357 data: &Self,
358 _: &mut std::fs::File,
359 _: &mut usize,
360 _: &mut usize,
361 _: CompressionAlgo,
362 _: u32,
363 ) -> std::io::Result<Self> {
364 Ok(*data)
365 }
366}
367
368impl<T: Save> Save for Option<T> {
369 type Meta = Option<T::Meta>;
370 fn __save(
371 data: &Self,
372 file: &mut std::fs::File,
373 len: &mut usize,
374 global_cnt: &mut usize,
375 compression_algo: CompressionAlgo,
376 level: u32,
377 ) -> std::io::Result<Self::Meta> {
378 match data {
379 Some(x) => Ok(Some(T::__save(
380 x,
381 file,
382 len,
383 global_cnt,
384 compression_algo,
385 level,
386 )?)),
387 None => Ok(None),
388 }
389 }
390}
391
392impl<T: Save> Save for Vec<T> {
393 type Meta = Vec<T::Meta>;
394 fn __save(
395 data: &Self,
396 file: &mut std::fs::File,
397 len: &mut usize,
398 global_cnt: &mut usize,
399 compression_algo: CompressionAlgo,
400 level: u32,
401 ) -> std::io::Result<Self::Meta> {
402 let mut res = Vec::with_capacity(data.len());
403 for i in 0..data.len() {
404 res.push(T::__save(
405 &data[i],
406 file,
407 len,
408 global_cnt,
409 compression_algo,
410 level,
411 )?);
412 }
413 Ok(res)
414 }
415}
416
417impl<T: Save, const N: usize> Save for [T; N]
418where
419 [T::Meta; N]: for<'a> serde::Deserialize<'a>,
420{
421 type Meta = [T::Meta; N];
422 fn __save(
423 data: &Self,
424 file: &mut std::fs::File,
425 len: &mut usize,
426 global_cnt: &mut usize,
427 compression_algo: CompressionAlgo,
428 level: u32,
429 ) -> std::io::Result<Self::Meta> {
430 let mut arr: [std::mem::MaybeUninit<T::Meta>; N] =
431 unsafe { std::mem::MaybeUninit::uninit().assume_init() };
432
433 for i in 0..N {
434 arr[i] = std::mem::MaybeUninit::new(T::__save(
435 &data[i],
436 file,
437 len,
438 global_cnt,
439 compression_algo,
440 level,
441 )?);
442 }
443
444 Ok(unsafe {
445 let ptr = &arr as *const _ as *const [T::Meta; N];
446 ptr.read()
447 })
448 }
449}