hpt_dataloader/struct_save/
save.rs

1use std::io::{Seek, Write};
2use std::marker::PhantomData;
3
4use crate::compression_trait::{CompressionAlgo, DataLoaderTrait, Meta};
5use crate::Endian;
6use crate::{compression_trait::CompressionTrait, CHUNK_BUFF};
7use flate2::write::{DeflateEncoder, GzEncoder, ZlibEncoder};
8use hpt_common::shape::shape::Shape;
9use indicatif::ProgressBar;
10
11pub trait Save {
12    type Meta: for<'a> serde::Deserialize<'a>;
13    fn __save(
14        data: &Self,
15        file: &mut std::fs::File,
16        len_so_far: &mut usize,
17        global_cnt: &mut usize,
18        compression_algo: CompressionAlgo,
19        level: u32,
20    ) -> std::io::Result<Self::Meta>;
21    fn save<P: Into<std::path::PathBuf>>(&self, path: P) -> std::io::Result<()>
22    where
23        <Self as Save>::Meta: serde::Serialize,
24    {
25        let mut file = std::fs::File::create(path.into())?;
26        let meta = <Self as Save>::__save(
27            &self,
28            &mut file,
29            &mut 0,
30            &mut 0,
31            CompressionAlgo::NoCompression,
32            9,
33        )?;
34        let serialized = serde_json::to_string(&meta)?;
35        file.write_all(serialized.as_bytes())?;
36        Ok(())
37    }
38}
39
40fn generate_header_compressed(
41    meta: &Meta,
42) -> (
43    (
44        usize,                             /*begin */
45        String,                            /* name */
46        Vec<i64>,                          /* shape */
47        Vec<i64>,                          /* strides */
48        usize,                             /* size */
49        String,                            /* dtype */
50        CompressionAlgo,                   /* compression_algo */
51        Endian,                            /* endian */
52        Vec<(usize, usize, usize, usize)>, /* indices */
53    ),
54    (String, usize, usize, usize, usize),
55) {
56    let info = (
57        0usize,
58        meta.name.clone(),
59        meta.data_saver.shape().to_vec(),
60        meta.data_saver.shape().to_strides().to_vec(),
61        meta.data_saver.size(),
62        meta.data_saver.dtype().to_string(),
63        meta.compression_algo,
64        meta.endian,
65        vec![],
66    );
67    let res = {
68        let x = &meta.data_saver;
69        let outer = x.size() / (*x.shape().last().unwrap() as usize);
70        let inner = (*x.shape().last().unwrap() as usize) * x.mem_size();
71        let num_chunks;
72        let mut num_lines;
73        let mut remain = 0;
74        let mut buffer_size;
75        if x.size() * x.mem_size() < CHUNK_BUFF {
76            num_chunks = 1;
77            num_lines = outer;
78            buffer_size = num_lines * inner;
79        } else {
80            buffer_size = ((CHUNK_BUFF - 1) / inner) * inner;
81            num_lines = buffer_size / inner;
82            if num_lines == 0 {
83                num_lines = 1;
84                buffer_size = inner;
85            }
86            remain = outer % num_lines;
87            num_chunks = outer / num_lines;
88        }
89        (
90            meta.name.clone(),
91            num_chunks,
92            num_lines,
93            remain,
94            buffer_size,
95        )
96    };
97
98    (info, res)
99}
100
101/// method to compress the tensor and save to file
102///
103/// `file_name`: name of the file to create
104///
105/// `tensors`: a list of tuples, [(name, tensor), ...]. Name will be used as the key to load the tensor
106pub fn save(
107    file: &mut std::fs::File,
108    mut meta: Meta,
109    len: &mut usize,
110    global_cnt: usize,
111) -> std::io::Result<(
112    usize,                             /*begin */
113    String,                            /* name */
114    Vec<i64>,                          /* shape */
115    Vec<i64>,                          /* strides */
116    usize,                             /* size */
117    String,                            /* dtype */
118    CompressionAlgo,                   /* compression_algo */
119    Endian,                            /* endian */
120    Vec<(usize, usize, usize, usize)>, /* indices */
121)> {
122    // initialize progress bar based on total element size of the data to save
123    ////////////////////////////////////////////initialize progress bar////////////////////////////////////////////
124    let total_size: usize = meta.data_saver.size();
125
126    let pb = ProgressBar::new(total_size as u64);
127    pb.set_style(
128        indicatif::ProgressStyle::default_bar()
129            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
130            .unwrap(),
131    );
132    ////////////////////////////////////////////initialize progress bar////////////////////////////////////////////
133
134    // generate compresse file header
135    // (Vec<String>, list of tensor attributes: {"begin": 0, "name": "a", "shape": [1, 2, 3], ...}, ...),
136    // will be used to deserialize and get the storing information when loading the file
137    // Vec<(String, usize, usize, usize)>:  (tensor name, num_chunks, num_lines of each chunk, remain lines(if it is not divisable by CHUNK_BUFF)))
138    let (mut info, save_config) = generate_header_compressed(&meta);
139    // write FASTTENSOR Magic Header, FASTTENSORHPLACEHOLD is a place holder for the location to the real header
140    // real header is at the end of the file.
141    // For Example: FASTTENSOR210123456789{"a": {"begin": 0, "name": "a", "shape": [1, 2, 3], ...}, ...}
142    // 21 is the location of the real header, 0123456789 is the data we store.
143    const MAGIC_HEADER: &str = "FASTTENSOR";
144    const PLACEHOLDER: &str = "FASTTENSORHPLACEHOLD";
145    const HEADER_LEN: usize = MAGIC_HEADER.len() + PLACEHOLDER.len();
146    let mut len_so_far = if global_cnt == 0 {
147        file.write_all((MAGIC_HEADER.to_owned() + PLACEHOLDER).as_bytes())?;
148        HEADER_LEN
149    } else {
150        *len
151    };
152    let last_stride: i64 = *meta.data_saver.strides().last().unwrap() as i64;
153    let mut prg: Vec<i64> = vec![0; meta.data_saver.shape().len() - 1];
154    let mut shape: Vec<i64> = meta.data_saver.shape().iter().map(|x| *x as i64).collect();
155    shape.iter_mut().for_each(|x: &mut i64| {
156        *x -= 1;
157    });
158    let inner_loop_size: usize = *meta.data_saver.shape().last().unwrap() as usize;
159    let line_num: usize = save_config.2;
160    let num_chunks: usize = save_config.1;
161    let buffer_size: usize = save_config.4;
162    let mut attributes = vec![];
163    let mut chunk = vec![0u8; buffer_size];
164    let unpack = get_unpack_closure(meta.endian);
165    for k in 0..num_chunks {
166        for j in 0..line_num {
167            for i in 0..inner_loop_size {
168                let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
169                let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
170                unpack(
171                    &mut meta.data_saver,
172                    ((i as i64) * last_stride) as isize,
173                    &mut chunk[start..end],
174                );
175            }
176            pb.inc(inner_loop_size as u64);
177            for h in (0..shape.len() - 1).rev() {
178                if prg[h] < shape[h] {
179                    prg[h] += 1;
180                    meta.data_saver
181                        .offset(meta.data_saver.strides()[h] as isize);
182                    break;
183                } else {
184                    prg[h] = 0;
185                    meta.data_saver
186                        .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
187                }
188            }
189        }
190        compress_data(
191            &meta,
192            &chunk,
193            file,
194            &mut attributes,
195            &mut len_so_far,
196            k,
197            line_num,
198        )?;
199    }
200    let remain_outer: usize = save_config.3;
201    let mut remain_chunk = vec![0u8; remain_outer * inner_loop_size * meta.data_saver.mem_size()];
202    for j in 0..remain_outer {
203        for i in 0..inner_loop_size {
204            let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
205            let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
206            unpack(
207                &mut meta.data_saver,
208                ((i as i64) * last_stride) as isize,
209                &mut remain_chunk[start..end],
210            );
211        }
212        pb.inc(inner_loop_size as u64);
213        for h in (0..shape.len() - 1).rev() {
214            if prg[h] < shape[h] {
215                prg[h] += 1;
216                meta.data_saver
217                    .offset(meta.data_saver.strides()[h] as isize);
218                break;
219            } else {
220                prg[h] = 0;
221                meta.data_saver
222                    .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
223            }
224        }
225    }
226    compress_data(
227        &meta,
228        &remain_chunk,
229        file,
230        &mut attributes,
231        &mut len_so_far,
232        num_chunks,
233        line_num,
234    )?;
235    let current_pos = file.seek(std::io::SeekFrom::Current(0))?;
236    file.seek(std::io::SeekFrom::Start(0))?;
237    file.write_all(format!("FASTTENSOR{:20}", current_pos).as_bytes())?;
238    file.seek(std::io::SeekFrom::Start(current_pos))?;
239    let length = *len;
240    *len = len_so_far;
241    info.8 = attributes;
242    info.0 = length;
243
244    Ok(info)
245}
246
247fn compress_data(
248    meta: &Meta,
249    chunk: &[u8],
250    file: &mut std::fs::File,
251    attributes: &mut Vec<(usize, usize, usize, usize)>,
252    len_so_far: &mut usize,
253    k: usize,
254    line_num: usize,
255) -> std::io::Result<()> {
256    let mut closure = |compressed_data: &[u8]| -> std::io::Result<()> {
257        file.write_all(compressed_data)?;
258        attributes.push((
259            k * line_num,
260            *len_so_far,           /* start offset of the compressed data */
261            compressed_data.len(), /* length of the compressed data */
262            chunk.len(),           /* bytes of the data in the chunk */
263        ));
264        *len_so_far += compressed_data.len();
265        Ok(())
266    };
267
268    match meta.compression_algo {
269        CompressionAlgo::Gzip => {
270            let mut encoder =
271                GzEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
272            encoder.write_all_data(chunk)?;
273            encoder.flush_all()?;
274            let compressed_data = encoder.finish_all()?;
275            closure(&compressed_data)?
276        }
277        CompressionAlgo::Deflate => {
278            let mut encoder =
279                DeflateEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
280            encoder.write_all_data(chunk)?;
281            encoder.flush_all()?;
282            let compressed_data = encoder.finish_all()?;
283            closure(&compressed_data)?
284        }
285        CompressionAlgo::Zlib => {
286            let mut encoder =
287                ZlibEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
288            encoder.write_all_data(chunk)?;
289            encoder.flush_all()?;
290            let compressed_data = encoder.finish_all()?;
291            closure(&compressed_data)?
292        }
293        CompressionAlgo::NoCompression => closure(chunk)?,
294    }
295    Ok(())
296}
297
298fn get_unpack_closure(endian: Endian) -> impl Fn(&mut Box<dyn DataLoaderTrait>, isize, &mut [u8]) {
299    match endian {
300        Endian::Little => {
301            |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
302                data_saver.fill_le_bytes_slice(offset, data)
303            }
304        }
305        Endian::Big => {
306            |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
307                data_saver.fill_be_bytes_slice(offset, data)
308            }
309        }
310        Endian::Native => {
311            |data_saver: &mut Box<dyn DataLoaderTrait>, offset: isize, data: &mut [u8]| {
312                data_saver.fill_ne_bytes_slice(offset, data)
313            }
314        }
315    }
316}
317
318macro_rules! impl_save {
319    ($struct:ty) => {
320        impl Save for $struct {
321            type Meta = Self;
322            fn __save(
323                data: &Self,
324                _: &mut std::fs::File,
325                _: &mut usize,
326                _: &mut usize,
327                _: CompressionAlgo,
328                _: u32,
329            ) -> std::io::Result<Self> {
330                Ok(data.clone())
331            }
332        }
333    };
334}
335
336impl_save!(bool);
337impl_save!(i8);
338impl_save!(i16);
339impl_save!(i32);
340impl_save!(i64);
341impl_save!(u8);
342impl_save!(u16);
343impl_save!(u32);
344impl_save!(u64);
345impl_save!(f32);
346impl_save!(f64);
347impl_save!(half::f16);
348impl_save!(half::bf16);
349impl_save!(usize);
350impl_save!(isize);
351impl_save!(String);
352impl_save!(Shape);
353
354impl<T> Save for PhantomData<T> {
355    type Meta = Self;
356    fn __save(
357        data: &Self,
358        _: &mut std::fs::File,
359        _: &mut usize,
360        _: &mut usize,
361        _: CompressionAlgo,
362        _: u32,
363    ) -> std::io::Result<Self> {
364        Ok(*data)
365    }
366}
367
368impl<T: Save> Save for Option<T> {
369    type Meta = Option<T::Meta>;
370    fn __save(
371        data: &Self,
372        file: &mut std::fs::File,
373        len: &mut usize,
374        global_cnt: &mut usize,
375        compression_algo: CompressionAlgo,
376        level: u32,
377    ) -> std::io::Result<Self::Meta> {
378        match data {
379            Some(x) => Ok(Some(T::__save(
380                x,
381                file,
382                len,
383                global_cnt,
384                compression_algo,
385                level,
386            )?)),
387            None => Ok(None),
388        }
389    }
390}
391
392impl<T: Save> Save for Vec<T> {
393    type Meta = Vec<T::Meta>;
394    fn __save(
395        data: &Self,
396        file: &mut std::fs::File,
397        len: &mut usize,
398        global_cnt: &mut usize,
399        compression_algo: CompressionAlgo,
400        level: u32,
401    ) -> std::io::Result<Self::Meta> {
402        let mut res = Vec::with_capacity(data.len());
403        for i in 0..data.len() {
404            res.push(T::__save(
405                &data[i],
406                file,
407                len,
408                global_cnt,
409                compression_algo,
410                level,
411            )?);
412        }
413        Ok(res)
414    }
415}
416
417impl<T: Save, const N: usize> Save for [T; N]
418where
419    [T::Meta; N]: for<'a> serde::Deserialize<'a>,
420{
421    type Meta = [T::Meta; N];
422    fn __save(
423        data: &Self,
424        file: &mut std::fs::File,
425        len: &mut usize,
426        global_cnt: &mut usize,
427        compression_algo: CompressionAlgo,
428        level: u32,
429    ) -> std::io::Result<Self::Meta> {
430        let mut arr: [std::mem::MaybeUninit<T::Meta>; N] =
431            unsafe { std::mem::MaybeUninit::uninit().assume_init() };
432
433        for i in 0..N {
434            arr[i] = std::mem::MaybeUninit::new(T::__save(
435                &data[i],
436                file,
437                len,
438                global_cnt,
439                compression_algo,
440                level,
441            )?);
442        }
443
444        Ok(unsafe {
445            let ptr = &arr as *const _ as *const [T::Meta; N];
446            ptr.read()
447        })
448    }
449}