hpt_dataloader/
compression_trait.rs

1use std::{collections::HashMap, io::Write};
2
3use hpt_common::{shape::shape::Shape, strides::strides::Strides};
4use hpt_traits::tensor::{CommonBounds, TensorInfo};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    data_loader::{Endian, HeaderInfo},
9    load::load_compressed_slice,
10    save::save,
11    utils::ToDataLoader,
12    CPUTensorCreator,
13};
14
/// Minimal interface over a byte sink that accepts written data and hands
/// back an owned byte buffer when it is finished.
///
/// Implemented in this file for the `flate2` write-side encoders wrapping a
/// `Vec<u8>`.
pub trait CompressionTrait {
    /// Consumes the sink and returns the bytes it produced.
    fn finish_all(self) -> std::io::Result<Vec<u8>>;

    /// Flushes any buffered output of the sink.
    fn flush_all(&mut self) -> std::io::Result<()>;

    /// Writes the whole of `buf` into the sink.
    fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()>;
}
20
/// Generates a `CompressionTrait` impl for each `[label, Type]` entry.
///
/// Every method forwards to the encoder: `write_all` / `flush` come from
/// `std::io::Write`, `finish` is the encoder's own method. The `label`
/// expression is accepted by the matcher but is not used in the expansion.
macro_rules! impl_compression_trait {
    ($([$label:expr, $($encoder:ty),*]),*) => {
        $(
            impl CompressionTrait for $($encoder)* {
                fn write_all_data(&mut self, buf: &[u8]) -> std::io::Result<()> {
                    Write::write_all(self, buf)
                }
                fn flush_all(&mut self) -> std::io::Result<()> {
                    Write::flush(self)
                }
                fn finish_all(self) -> std::io::Result<Vec<u8>> {
                    self.finish()
                }
            }
        )*
    };
}
38
39impl_compression_trait!(
40    ["gzip", flate2::write::GzEncoder<Vec<u8>>],
41    ["deflate", flate2::write::DeflateEncoder<Vec<u8>>],
42    ["zlib", flate2::write::ZlibEncoder<Vec<u8>>]
43);
44
45pub trait DataLoaderTrait {
46    fn shape(&self) -> &Shape;
47    fn strides(&self) -> &Strides;
48    fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]);
49    fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]);
50    fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]);
51    fn offset(&mut self, offset: isize);
52    fn size(&self) -> usize;
53    fn dtype(&self) -> &'static str;
54    fn mem_size(&self) -> usize;
55}
56
/// Couples a raw element pointer with the shape/strides that describe it and
/// the tensor value that owns the memory.
pub struct DataLoader<T, B> {
    /// Shape of the tensor being read.
    pub(crate) shape: Shape,
    /// Strides of the tensor being read.
    pub(crate) strides: Strides,
    // Never read directly: the tensor is stored only so its backing data is
    // not dropped while `data` still points into it.
    #[allow(dead_code)]
    pub(crate) tensor: B, // this is just used to let rust hold the data not to drop
    /// Current element pointer; set from `tensor.ptr()` in `new` and moved by `offset`.
    pub(crate) data: *const T,
}
64
65impl<T, B> DataLoader<T, B>
66where
67    B: TensorInfo<T>,
68{
69    pub fn new(shape: Shape, strides: Strides, tensor: B) -> Self {
70        let ptr = tensor.ptr();
71        Self {
72            shape,
73            strides,
74            tensor,
75            data: ptr.ptr as *const T,
76        }
77    }
78}
79
80impl<B, T> DataLoaderTrait for DataLoader<T, B>
81where
82    B: TensorInfo<T>,
83    T: CommonBounds + bytemuck::NoUninit,
84{
85    fn shape(&self) -> &Shape {
86        &self.shape
87    }
88
89    fn strides(&self) -> &Strides {
90        &self.strides
91    }
92
93    fn offset(&mut self, offset: isize) {
94        self.data = unsafe { self.data.offset(offset) };
95    }
96
97    fn size(&self) -> usize {
98        self.shape.size() as usize
99    }
100
101    fn mem_size(&self) -> usize {
102        std::mem::size_of::<T>()
103    }
104
105    fn fill_ne_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
106        let val = unsafe { self.data.offset(offset).read() };
107        writer.copy_from_slice(bytemuck::bytes_of(&val));
108    }
109
110    fn fill_be_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
111        let val = unsafe { self.data.offset(offset).read() };
112        writer.copy_from_slice(bytemuck::bytes_of(&val));
113    }
114
115    fn fill_le_bytes_slice(&self, offset: isize, writer: &mut [u8]) {
116        let val = unsafe { self.data.offset(offset).read() };
117        writer.copy_from_slice(bytemuck::bytes_of(&val));
118    }
119
120    fn dtype(&self) -> &'static str {
121        T::STR
122    }
123}
124
/// Compression scheme recorded for a tensor in [`Meta::compression_algo`].
///
/// The first three variants share their names with the `flate2` encoders that
/// implement `CompressionTrait` in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompressionAlgo {
    /// Gzip container.
    Gzip,
    /// Raw deflate stream.
    Deflate,
    /// Zlib container.
    Zlib,
    /// Presumably stores the bytes without compression; handling lives in the save/load code.
    NoCompression,
}
132
/// Everything the save routine is handed for a single tensor.
pub struct Meta {
    /// Name the tensor is stored under.
    pub name: String,
    /// Compression scheme requested for this tensor.
    pub compression_algo: CompressionAlgo,
    /// Requested byte order; `TensorSaver::push` always sets `Endian::Native`.
    pub endian: Endian,
    /// Type-erased accessor over the tensor's elements.
    pub data_saver: Box<dyn DataLoaderTrait>,
    /// Compression level forwarded by the caller; its valid range is not checked here.
    pub compression_level: u32,
}
140
/// Builder that collects tensors and writes them to a single file.
pub struct TensorSaver {
    /// Destination file.
    file_path: std::path::PathBuf,
    /// Tensors queued by `push`; `None` until the first push.
    to_saves: Option<Vec<Meta>>,
}
145
146impl TensorSaver {
147    pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
148        Self {
149            file_path: file_path.into(),
150            to_saves: None,
151        }
152    }
153
154    pub fn push<T, A>(
155        mut self,
156        name: &str,
157        tensor: A,
158        compression_algo: CompressionAlgo,
159        compression_level: u32,
160    ) -> Self
161    where
162        T: CommonBounds + bytemuck::NoUninit,
163        A: TensorInfo<T> + 'static + ToDataLoader,
164        <A as ToDataLoader>::Output: DataLoaderTrait,
165    {
166        let data_loader = tensor.to_dataloader();
167        let meta = Meta {
168            name: name.to_string(),
169            compression_algo,
170            endian: Endian::Native,
171            data_saver: Box::new(data_loader),
172            compression_level,
173        };
174        if let Some(to_saves) = &mut self.to_saves {
175            to_saves.push(meta);
176        } else {
177            self.to_saves = Some(vec![meta]);
178        }
179        self
180    }
181
182    pub fn save(self) -> std::io::Result<()> {
183        save(
184            self.file_path.to_str().unwrap().into(),
185            self.to_saves.unwrap(),
186        )
187    }
188}
189
/// Builder that collects tensor names (with optional slices) to read back from a file.
pub struct TensorLoader {
    /// Source file.
    file_path: std::path::PathBuf,
    /// Queued `(name, slices)` requests; `None` until the first push.
    /// Each slice is an `(i64, i64, i64)` triple, presumably (start, end, step) per dimension.
    to_loads: Option<Vec<(String, Vec<(i64, i64, i64)>)>>,
}
194
195impl TensorLoader {
196    pub fn new<P: Into<std::path::PathBuf>>(file_path: P) -> Self {
197        Self {
198            file_path: file_path.into(),
199            to_loads: None,
200        }
201    }
202
203    pub fn push(mut self, name: &str, slices: &[(i64, i64, i64)]) -> Self {
204        if let Some(to_loads) = &mut self.to_loads {
205            to_loads.push((name.to_string(), slices.to_vec()));
206        } else {
207            self.to_loads = Some(vec![(name.to_string(), slices.to_vec())]);
208        }
209        self
210    }
211
212    pub fn load<B>(self) -> std::io::Result<HashMap<String, B>>
213    where
214        B: CPUTensorCreator,
215        <B as CPUTensorCreator>::Output: Into<B> + TensorInfo<<B as CPUTensorCreator>::Meta>,
216        <B as CPUTensorCreator>::Meta: CommonBounds + bytemuck::AnyBitPattern,
217    {
218        let res = load_compressed_slice::<B>(
219            self.file_path.to_str().unwrap().into(),
220            self.to_loads.expect("no tensors to load"),
221        )
222        .expect("failed to load tensor");
223        Ok(res)
224    }
225
226    pub fn load_all<B>(self) -> std::io::Result<HashMap<String, B>>
227    where
228        B: CPUTensorCreator,
229        <B as CPUTensorCreator>::Output: Into<B> + TensorInfo<<B as CPUTensorCreator>::Meta>,
230        <B as CPUTensorCreator>::Meta: CommonBounds + bytemuck::AnyBitPattern,
231    {
232        let res = HeaderInfo::parse_header_compressed(self.file_path.to_str().unwrap().into())
233            .expect("failed to parse header");
234        let res = load_compressed_slice::<B>(
235            self.file_path.to_str().unwrap().into(),
236            res.into_values().map(|x| (x.name, vec![])).collect(),
237        )
238        .expect("failed to load tensor");
239        Ok(res)
240    }
241}