hpt_dataloader/struct_save/save.rs

use std::io::{Seek, Write};
use std::marker::PhantomData;

use crate::compression_trait::{CompressionAlgo, DataLoaderTrait, Meta};
use crate::Endian;
use crate::{compression_trait::CompressionTrait, CHUNK_BUFF};
use flate2::write::{DeflateEncoder, GzEncoder, ZlibEncoder};
use hpt_common::shape::shape::Shape;
use indicatif::ProgressBar;

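/// Trait for values that can be written into the FASTTENSOR container format.
///
/// `__save` writes the value's payload into `file` and returns a serializable
/// metadata value; `save` is the user-facing entry point that creates the
/// file, delegates to `__save`, and appends the metadata as JSON at the end.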
pub trait Save {
    type Meta: for<'a> serde::Deserialize<'a>;
    fn __save(
        data: &Self,
        file: &mut std::fs::File,
        len_so_far: &mut usize,
        global_cnt: &mut usize,
        compression_algo: CompressionAlgo,
        endian: Endian,
        level: u32,
    ) -> std::io::Result<Self::Meta>;
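    /// Saves `self` to `path` with the defaults used here: no compression,
    /// native endianness, and compression level 9 (the level is ignored when
    /// no compression is applied).
    ///
    /// A minimal illustrative sketch using one of the implementations below
    /// (the file name is hypothetical):
    ///
    /// ```ignore
    /// let weights: Vec<f32> = vec![1.0, 2.0, 3.0];
    /// weights.save("weights.ft").unwrap();
    /// ```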
    fn save(&self, path: &str) -> std::io::Result<()>
    where
        <Self as Save>::Meta: serde::Serialize,
    {
        let mut file = std::fs::File::create(path)?;
        let meta = <Self as Save>::__save(
            self,
            &mut file,
            &mut 0,
            &mut 0,
            CompressionAlgo::NoCompression,
            Endian::Native,
            9,
        )?;
        let serialized = serde_json::to_string(&meta)?;
        file.write_all(serialized.as_bytes())?;
        Ok(())
    }
}

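/// Builds the per-tensor header entry and the chunking plan for one tensor.
///
/// The first tuple is the header entry later stored in the file's trailing
/// header (`begin` and `indices` are filled in by `save`); the second is
/// `(name, num_chunks, num_lines, remain, buffer_size)`, describing how the
/// tensor is split into chunks of whole innermost lines, sized to fit within
/// `CHUNK_BUFF` where possible.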
fn generate_header_compressed(
    meta: &Meta,
) -> (
    (
        usize,                             /* begin */
        String,                            /* name */
        Vec<i64>,                          /* shape */
        Vec<i64>,                          /* strides */
        usize,                             /* size */
        String,                            /* dtype */
        CompressionAlgo,                   /* compression_algo */
        Endian,                            /* endian */
        Vec<(usize, usize, usize, usize)>, /* indices */
    ),
    (String, usize, usize, usize, usize),
) {
    let info = (
        0usize,
        meta.name.clone(),
        meta.data_saver.shape().to_vec(),
        meta.data_saver.shape().to_strides().to_vec(),
        meta.data_saver.size(),
        meta.data_saver.dtype().to_string(),
        meta.compression_algo,
        meta.endian,
        vec![],
    );
    let res = {
        let x = &meta.data_saver;
        let outer = x.size() / (*x.shape().last().unwrap() as usize);
        let inner = (*x.shape().last().unwrap() as usize) * x.mem_size();
        let num_chunks;
        let mut num_lines;
        let mut remain = 0;
        let mut buffer_size;
        if x.size() * x.mem_size() < CHUNK_BUFF {
            num_chunks = 1;
            num_lines = outer;
            buffer_size = num_lines * inner;
        } else {
            buffer_size = ((CHUNK_BUFF - 1) / inner) * inner;
            num_lines = buffer_size / inner;
            if num_lines == 0 {
                num_lines = 1;
                buffer_size = inner;
            }
            remain = outer % num_lines;
            num_chunks = outer / num_lines;
        }
        (
            meta.name.clone(),
            num_chunks,
            num_lines,
            remain,
            buffer_size,
        )
    };

    (info, res)
}

/// Compresses one tensor and writes it to `file`.
///
/// `file`: the already-open output file shared by all tensors being saved.
///
/// `meta`: the tensor's name, data saver, and compression/endian settings.
///
/// `len`: running byte offset of the data written so far; updated to the new
/// end-of-data offset before returning.
///
/// `global_cnt`: index of this tensor within the file; the magic header and
/// placeholder are only written for the first tensor (`global_cnt == 0`).
///
/// Returns the header entry (offsets, shape, strides, chunk indices, ...) that
/// the caller serializes into the trailing header, so the tensor can be
/// located and loaded again by name.
pub fn save(
    file: &mut std::fs::File,
    mut meta: Meta,
    len: &mut usize,
    global_cnt: usize,
) -> std::io::Result<(
    usize,                             /* begin */
    String,                            /* name */
    Vec<i64>,                          /* shape */
    Vec<i64>,                          /* strides */
    usize,                             /* size */
    String,                            /* dtype */
    CompressionAlgo,                   /* compression_algo */
    Endian,                            /* endian */
    Vec<(usize, usize, usize, usize)>, /* indices */
)> {
    // initialize progress bar based on total element size of the data to save
    let total_size: usize = meta.data_saver.size();

    let pb = ProgressBar::new(total_size as u64);
    pb.set_style(
        indicatif::ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
            .unwrap(),
    );

    // generate the compressed file header:
    // the tensor attributes ({"begin": 0, "name": "a", "shape": [1, 2, 3], ...}) are
    // serialized and used to recover the storage layout when loading the file, and
    // (String, usize, usize, usize, usize) holds
    // (tensor name, num_chunks, num_lines per chunk, remaining lines (when the outer
    // dimension is not divisible by num_lines), buffer_size)
    let (mut info, save_config) = generate_header_compressed(&meta);
    // Write the FASTTENSOR magic header; FASTTENSORHPLACEHOLD is a placeholder for the
    // offset of the real header, which sits at the end of the file.
    // For example: FASTTENSOR210123456789{"a": {"begin": 0, "name": "a", "shape": [1, 2, 3], ...}, ...}
    // where 21 is the offset of the real header and 0123456789 is the stored data.
    const MAGIC_HEADER: &str = "FASTTENSOR";
    const PLACEHOLDER: &str = "FASTTENSORHPLACEHOLD";
    const HEADER_LEN: usize = MAGIC_HEADER.len() + PLACEHOLDER.len();
    let mut len_so_far = if global_cnt == 0 {
        file.write_all((MAGIC_HEADER.to_owned() + PLACEHOLDER).as_bytes())?;
        HEADER_LEN
    } else {
        *len
    };
    let last_stride: i64 = *meta.data_saver.strides().last().unwrap() as i64;
    // prg is a per-dimension counter over the outer dimensions; shape is decremented
    // so that each entry holds the maximum index of its dimension
    let mut prg: Vec<i64> = vec![0; meta.data_saver.shape().len() - 1];
    let mut shape: Vec<i64> = meta.data_saver.shape().iter().map(|x| *x as i64).collect();
    shape.iter_mut().for_each(|x: &mut i64| {
        *x -= 1;
    });
    let inner_loop_size: usize = *meta.data_saver.shape().last().unwrap() as usize;
    // save_config = (name, num_chunks, num_lines, remain, buffer_size)
    let line_num: usize = save_config.2;
    let num_chunks: usize = save_config.1;
    let buffer_size: usize = save_config.4;
    let mut attributes = vec![];
    let mut chunk = vec![0u8; buffer_size];
    let unpack = get_unpack_closure(meta.endian);
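    // Walk the tensor chunk by chunk: each chunk holds `line_num` innermost lines,
    // filled one element at a time via `unpack`, while `prg` advances the data
    // saver's offset across the outer dimensions like an odometer.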
    for k in 0..num_chunks {
        for j in 0..line_num {
            for i in 0..inner_loop_size {
                let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
                let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
                unpack(
                    &mut meta.data_saver,
                    ((i as i64) * last_stride) as isize,
                    &mut chunk[start..end],
                );
            }
            pb.inc(inner_loop_size as u64);
            for h in (0..shape.len() - 1).rev() {
                if prg[h] < shape[h] {
                    prg[h] += 1;
                    meta.data_saver
                        .offset(meta.data_saver.strides()[h] as isize);
                    break;
                } else {
                    prg[h] = 0;
                    meta.data_saver
                        .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
                }
            }
        }
        compress_data(
            &meta,
            &chunk,
            file,
            &mut attributes,
            &mut len_so_far,
            k,
            line_num,
        )?;
    }
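    // Handle the leftover lines that did not fill a whole chunk.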
    let remain_outer: usize = save_config.3;
    let mut remain_chunk = vec![0u8; remain_outer * inner_loop_size * meta.data_saver.mem_size()];
    for j in 0..remain_outer {
        for i in 0..inner_loop_size {
            let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
            let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
            unpack(
                &mut meta.data_saver,
                ((i as i64) * last_stride) as isize,
                &mut remain_chunk[start..end],
            );
        }
        pb.inc(inner_loop_size as u64);
        for h in (0..shape.len() - 1).rev() {
            if prg[h] < shape[h] {
                prg[h] += 1;
                meta.data_saver
                    .offset(meta.data_saver.strides()[h] as isize);
                break;
            } else {
                prg[h] = 0;
                meta.data_saver
                    .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
            }
        }
    }
    compress_data(
        &meta,
        &remain_chunk,
        file,
        &mut attributes,
        &mut len_so_far,
        num_chunks,
        line_num,
    )?;
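    // Overwrite the placeholder at the start of the file with the current
    // end-of-data offset, where the trailing header will be written.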
    let current_pos = file.stream_position()?;
    file.seek(std::io::SeekFrom::Start(0))?;
    file.write_all(format!("{}{:20}", MAGIC_HEADER, current_pos).as_bytes())?;
    file.seek(std::io::SeekFrom::Start(current_pos))?;
    let length = *len;
    *len = len_so_far;
    info.8 = attributes;
    info.0 = length;

    Ok(info)
}

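/// Compresses `chunk` with the algorithm configured in `meta` (or writes it
/// unmodified for `NoCompression`), appends the bytes to `file`, and records
/// `(first line index, start offset, compressed length, uncompressed length)`
/// in `attributes` so the loader can seek to and decompress each chunk.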
fn compress_data(
    meta: &Meta,
    chunk: &[u8],
    file: &mut std::fs::File,
    attributes: &mut Vec<(usize, usize, usize, usize)>,
    len_so_far: &mut usize,
    k: usize,
    line_num: usize,
) -> std::io::Result<()> {
    let mut closure = |compressed_data: &[u8]| -> std::io::Result<()> {
        file.write_all(compressed_data)?;
        attributes.push((
            k * line_num,
            *len_so_far,           /* start offset of the compressed data */
            compressed_data.len(), /* length of the compressed data */
            chunk.len(),           /* bytes of the data in the chunk */
        ));
        *len_so_far += compressed_data.len();
        Ok(())
    };

    match meta.compression_algo {
        CompressionAlgo::Gzip => {
            let mut encoder =
                GzEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
            encoder.write_all_data(chunk)?;
            encoder.flush_all()?;
            let compressed_data = encoder.finish_all()?;
            closure(&compressed_data)?
        }
        CompressionAlgo::Deflate => {
            let mut encoder =
                DeflateEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
            encoder.write_all_data(chunk)?;
            encoder.flush_all()?;
            let compressed_data = encoder.finish_all()?;
            closure(&compressed_data)?
        }
        CompressionAlgo::Zlib => {
            let mut encoder =
                ZlibEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
            encoder.write_all_data(chunk)?;
            encoder.flush_all()?;
            let compressed_data = encoder.finish_all()?;
            closure(&compressed_data)?
        }
        CompressionAlgo::NoCompression => closure(chunk)?,
    }
    Ok(())
}

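/// Returns a closure that copies the element at `offset` from the data saver
/// into `data` using the requested byte order.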
fn get_unpack_closure(endian: Endian) -> impl Fn(&mut Box<dyn DataLoaderTrait>, isize, &mut [u8]) {
    match endian {
        Endian::Little => {
            |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
                data_saver.fill_le_bytes_slice(offset, data)
            }
        }
        Endian::Big => {
            |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
                data_saver.fill_be_bytes_slice(offset, data)
            }
        }
        Endian::Native => {
            |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
                data_saver.fill_ne_bytes_slice(offset, data)
            }
        }
    }
}

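// For plain scalar-like types the value itself is the metadata: nothing is
// written to the data section, the value is simply echoed back and ends up in
// the trailing JSON header.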
macro_rules! impl_save {
    ($struct:ident) => {
        impl Save for $struct {
            type Meta = Self;
            fn __save(
                data: &Self,
                _: &mut std::fs::File,
                _: &mut usize,
                _: &mut usize,
                _: CompressionAlgo,
                _: Endian,
                _: u32,
            ) -> std::io::Result<Self> {
                Ok(data.clone())
            }
        }
    };
}

impl_save!(bool);
impl_save!(i8);
impl_save!(i16);
impl_save!(i32);
impl_save!(i64);
impl_save!(u8);
impl_save!(u16);
impl_save!(u32);
impl_save!(u64);
impl_save!(f32);
impl_save!(f64);
impl_save!(usize);
impl_save!(isize);
impl_save!(String);
impl_save!(Shape);

impl<T> Save for PhantomData<T> {
    type Meta = Self;
    fn __save(
        data: &Self,
        _: &mut std::fs::File,
        _: &mut usize,
        _: &mut usize,
        _: CompressionAlgo,
        _: Endian,
        _: u32,
    ) -> std::io::Result<Self> {
        Ok(*data)
    }
}

impl<T: Save> Save for Option<T> {
    type Meta = Option<T::Meta>;
    fn __save(
        data: &Self,
        file: &mut std::fs::File,
        len: &mut usize,
        global_cnt: &mut usize,
        compression_algo: CompressionAlgo,
        endian: Endian,
        level: u32,
    ) -> std::io::Result<Self::Meta> {
        match data {
            Some(x) => Ok(Some(T::__save(
                x,
                file,
                len,
                global_cnt,
                compression_algo,
                endian,
                level,
            )?)),
            None => Ok(None),
        }
    }
}

impl<T: Save> Save for Vec<T> {
    type Meta = Vec<T::Meta>;
    fn __save(
        data: &Self,
        file: &mut std::fs::File,
        len: &mut usize,
        global_cnt: &mut usize,
        compression_algo: CompressionAlgo,
        endian: Endian,
        level: u32,
    ) -> std::io::Result<Self::Meta> {
        let mut res = Vec::with_capacity(data.len());
        for item in data {
            res.push(T::__save(
                item,
                file,
                len,
                global_cnt,
                compression_algo,
                endian,
                level,
            )?);
        }
        Ok(res)
    }
}

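// Fixed-size arrays save element-wise. The metadata array is assembled through
// `MaybeUninit` because `T::Meta` is not required to be `Copy` or `Default`.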
impl<T: Save, const N: usize> Save for [T; N]
where
    [T::Meta; N]: for<'a> serde::Deserialize<'a>,
{
    type Meta = [T::Meta; N];
    fn __save(
        data: &Self,
        file: &mut std::fs::File,
        len: &mut usize,
        global_cnt: &mut usize,
        compression_algo: CompressionAlgo,
        endian: Endian,
        level: u32,
    ) -> std::io::Result<Self::Meta> {
        // an array of uninitialized `MaybeUninit<T::Meta>` is itself valid, so this
        // `assume_init` is sound; each slot is filled in the loop below
        let mut arr: [std::mem::MaybeUninit<T::Meta>; N] =
            unsafe { std::mem::MaybeUninit::uninit().assume_init() };

        for i in 0..N {
            arr[i] = std::mem::MaybeUninit::new(T::__save(
                &data[i],
                file,
                len,
                global_cnt,
                compression_algo,
                endian,
                level,
            )?);
        }

        // all N elements are initialized at this point, so reinterpreting the
        // buffer as `[T::Meta; N]` and reading it out is sound
        Ok(unsafe {
            let ptr = &arr as *const _ as *const [T::Meta; N];
            ptr.read()
        })
    }
}
457}