hpt_dataloader/
data_loader.rs1use anyhow::Ok;
2use anyhow::Result;
3use hpt_traits::CommonBounds;
4use hpt_traits::TensorCreator;
5use hpt_traits::TensorInfo;
6use num::traits::FromBytes;
7use serde::{Deserialize, Serialize};
8use std::fmt::Display;
9use std::marker::PhantomData;
10use std::{
11    collections::HashMap,
12    fs::File,
13    io::{Read, Seek},
14};
15
16use crate::struct_save::load::load;
17use crate::struct_save::load::MetaLoad;
18use crate::struct_save::save::Save;
19use crate::CompressionAlgo;
20
21#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
22pub enum Endian {
23    Little,
24    Big,
25    Native,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29pub(crate) struct HeaderInfos {
30    pub(crate) begin: u64,
31    pub(crate) infos: Vec<HeaderInfo>,
32}
33
34#[derive(Debug, Serialize, Deserialize)]
35pub(crate) struct HeaderInfo {
36    pub(crate) begin: u64,
37    pub(crate) name: String,
38    pub(crate) shape: Vec<i64>,
39    pub(crate) strides: Vec<i64>,
40    pub(crate) size: usize,
41    pub(crate) indices: Vec<(usize, usize, usize, usize)>,
42    pub(crate) compress_algo: CompressionAlgo,
43    pub(crate) dtype: String,
44    pub(crate) endian: Endian,
45}
46#[derive(Serialize, Deserialize)]
48#[must_use]
49pub struct TensorMeta<T: CommonBounds, B: TensorCreator<T, Output = B> + Clone + TensorInfo<T>> {
50    pub begin: usize,
51    pub shape: Vec<i64>,
52    pub strides: Vec<i64>,
53    pub size: usize,
54    pub dtype: String,
55    pub compression_algo: CompressionAlgo,
56    pub endian: Endian,
57    pub indices: Vec<(usize, usize, usize, usize)>,
58    pub phantom: PhantomData<(T, B)>,
59}
60
61pub fn parse_header_compressed<M: Save>(file: &str) -> anyhow::Result<<M as Save>::Meta> {
62    let mut file = File::open(file)?;
63    file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
64    let mut header_infos = [0u8; 20];
65    file.read_exact(&mut header_infos)?;
66    let header = std::str::from_utf8(&header_infos)?;
67    let header_int = header.trim().parse::<u64>()?;
68    file.seek(std::io::SeekFrom::Start(header_int))?; let mut buffer3 = vec![];
70    file.read_to_end(&mut buffer3)?;
71    let info = std::str::from_utf8(&buffer3)?;
72    let ret = serde_json::from_str::<M::Meta>(info)?;
73    Ok(ret)
74}
75
76impl<
77        T: CommonBounds + FromBytes<Bytes = [u8; N]>,
78        B: TensorCreator<T, Output = B> + Clone + TensorInfo<T> + Display,
79        const N: usize,
80    > MetaLoad for TensorMeta<T, B>
81{
82    type Output = B;
83    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
84        load::<T, B, N>(file, self)
85    }
86}
87
88impl HeaderInfo {
89    pub(crate) fn parse_header_compressed(file: &str) -> Result<HashMap<String, HeaderInfo>> {
90        let mut file = File::open(file)?;
91        file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
92        let mut header_infos = [0u8; 20];
93        file.read_exact(&mut header_infos)?;
94        let header = std::str::from_utf8(&header_infos)?;
95        let header_int = header.trim().parse::<u64>()?;
96        file.seek(std::io::SeekFrom::Start(header_int))?; let mut buffer3 = vec![];
98        file.read_to_end(&mut buffer3)?;
99        let info = std::str::from_utf8(&buffer3)?;
100        let ret: HashMap<String, HeaderInfo> =
101            serde_json::from_str::<HashMap<String, Self>>(&info)?;
102        Ok(ret)
103    }
104}