hpt_dataloader/
data_loader.rs1use hpt_traits::tensor::{CommonBounds, TensorInfo};
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4use std::marker::PhantomData;
5use std::{
6 collections::HashMap,
7 fs::File,
8 io::{Read, Seek},
9};
10
11use crate::struct_save::load::load;
12use crate::struct_save::load::MetaLoad;
13use crate::struct_save::save::Save;
14use crate::CPUTensorCreator;
15use crate::CompressionAlgo;
16
17#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
18pub enum Endian {
19 Little,
20 Big,
21 Native,
22}
23
24#[derive(Debug, Serialize, Deserialize)]
25pub(crate) struct HeaderInfos {
26 pub(crate) begin: u64,
27 pub(crate) infos: Vec<HeaderInfo>,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
31pub(crate) struct HeaderInfo {
32 pub(crate) begin: u64,
33 pub(crate) name: String,
34 pub(crate) shape: Vec<i64>,
35 pub(crate) strides: Vec<i64>,
36 pub(crate) size: usize,
37 pub(crate) indices: Vec<(usize, usize, usize, usize)>,
38 pub(crate) compress_algo: CompressionAlgo,
39 pub(crate) dtype: String,
40 pub(crate) endian: Endian,
41}
42#[derive(Serialize, Deserialize)]
44#[must_use]
45pub struct TensorMeta<T: CommonBounds, B: CPUTensorCreator>
46where
47 <B as CPUTensorCreator>::Output: Clone + TensorInfo<T>,
48{
49 pub begin: usize,
50 pub shape: Vec<i64>,
51 pub strides: Vec<i64>,
52 pub size: usize,
53 pub dtype: String,
54 pub compression_algo: CompressionAlgo,
55 pub endian: Endian,
56 pub indices: Vec<(usize, usize, usize, usize)>,
57 pub phantom: PhantomData<(T, B)>,
58}
59
60pub fn parse_header_compressed<M: Save, P: Into<std::path::PathBuf>>(
61 file: P,
62) -> Result<<M as Save>::Meta, Box<dyn std::error::Error>> {
63 let mut file = File::open(file.into())?;
64 file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
65 let mut header_infos = [0u8; 20];
66 file.read_exact(&mut header_infos)?;
67 let header = std::str::from_utf8(&header_infos)?;
68 let header_int = header.trim().parse::<u64>()?;
69 file.seek(std::io::SeekFrom::Start(header_int))?; let mut buffer3 = vec![];
71 file.read_to_end(&mut buffer3)?;
72 let info = std::str::from_utf8(&buffer3)?;
73 let ret = serde_json::from_str::<M::Meta>(info)?;
74 Ok(ret)
75}
76
77impl<T, B: CPUTensorCreator> MetaLoad for TensorMeta<T, B>
78where
79 T: CommonBounds + bytemuck::AnyBitPattern,
80 <B as CPUTensorCreator>::Output: Clone + TensorInfo<T> + Display + Into<B>,
81{
82 type Output = B;
83 fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
84 Ok(load::<T, B>(file, self)?.into())
85 }
86}
87
88impl HeaderInfo {
89 pub(crate) fn parse_header_compressed(
90 file: &str,
91 ) -> Result<HashMap<String, HeaderInfo>, Box<dyn std::error::Error>> {
92 let mut file = File::open(file)?;
93 file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
94 let mut header_infos = [0u8; 20];
95 file.read_exact(&mut header_infos)?;
96 let header = std::str::from_utf8(&header_infos)?;
97 let header_int = header.trim().parse::<u64>()?;
98 file.seek(std::io::SeekFrom::Start(header_int))?; let mut buffer3 = vec![];
100 file.read_to_end(&mut buffer3)?;
101 let info = std::str::from_utf8(&buffer3)?;
102 let ret: HashMap<String, HeaderInfo> =
103 serde_json::from_str::<HashMap<String, Self>>(&info)?;
104 Ok(ret)
105 }
106}