// hpt_dataloader/data_loader.rs

use anyhow::Context;
use anyhow::Result;
use hpt_traits::CommonBounds;
use hpt_traits::TensorInfo;
use num::traits::FromBytes;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::marker::PhantomData;
use std::{
    collections::HashMap,
    fs::File,
    io::{Read, Seek},
};

use crate::struct_save::load::load;
use crate::struct_save::load::MetaLoad;
use crate::struct_save::save::Save;
use crate::CPUTensorCreator;
use crate::CompressionAlgo;
19
20#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
21pub enum Endian {
22    Little,
23    Big,
24    Native,
25}
26
27#[derive(Debug, Serialize, Deserialize)]
28pub(crate) struct HeaderInfos {
29    pub(crate) begin: u64,
30    pub(crate) infos: Vec<HeaderInfo>,
31}
32
33#[derive(Debug, Serialize, Deserialize)]
34pub(crate) struct HeaderInfo {
35    pub(crate) begin: u64,
36    pub(crate) name: String,
37    pub(crate) shape: Vec<i64>,
38    pub(crate) strides: Vec<i64>,
39    pub(crate) size: usize,
40    pub(crate) indices: Vec<(usize, usize, usize, usize)>,
41    pub(crate) compress_algo: CompressionAlgo,
42    pub(crate) dtype: String,
43    pub(crate) endian: Endian,
44}
45/// the meta data of the tensor
46#[derive(Serialize, Deserialize)]
47#[must_use]
48pub struct TensorMeta<T: CommonBounds, B: CPUTensorCreator<T>>
49where
50    <B as CPUTensorCreator<T>>::Output: Clone + TensorInfo<T>,
51{
52    pub begin: usize,
53    pub shape: Vec<i64>,
54    pub strides: Vec<i64>,
55    pub size: usize,
56    pub dtype: String,
57    pub compression_algo: CompressionAlgo,
58    pub endian: Endian,
59    pub indices: Vec<(usize, usize, usize, usize)>,
60    pub phantom: PhantomData<(T, B)>,
61}
62
63pub fn parse_header_compressed<M: Save>(file: &str) -> anyhow::Result<<M as Save>::Meta> {
64    let mut file = File::open(file)?;
65    file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
66    let mut header_infos = [0u8; 20];
67    file.read_exact(&mut header_infos)?;
68    let header = std::str::from_utf8(&header_infos)?;
69    let header_int = header.trim().parse::<u64>()?;
70    file.seek(std::io::SeekFrom::Start(header_int))?; // offset for header
71    let mut buffer3 = vec![];
72    file.read_to_end(&mut buffer3)?;
73    let info = std::str::from_utf8(&buffer3)?;
74    let ret = serde_json::from_str::<M::Meta>(info)?;
75    Ok(ret)
76}
77
78impl<T: CommonBounds + FromBytes<Bytes = [u8; N]>, B: CPUTensorCreator<T>, const N: usize> MetaLoad
79    for TensorMeta<T, B>
80where
81    <B as CPUTensorCreator<T>>::Output: Clone + TensorInfo<T> + Display + Into<B>,
82{
83    type Output = B;
84    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
85        Ok(load::<T, B, N>(file, self)?.into())
86    }
87}
88
89impl HeaderInfo {
90    pub(crate) fn parse_header_compressed(file: &str) -> Result<HashMap<String, HeaderInfo>> {
91        let mut file = File::open(file)?;
92        file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
93        let mut header_infos = [0u8; 20];
94        file.read_exact(&mut header_infos)?;
95        let header = std::str::from_utf8(&header_infos)?;
96        let header_int = header.trim().parse::<u64>()?;
97        file.seek(std::io::SeekFrom::Start(header_int))?; // offset for header
98        let mut buffer3 = vec![];
99        file.read_to_end(&mut buffer3)?;
100        let info = std::str::from_utf8(&buffer3)?;
101        let ret: HashMap<String, HeaderInfo> =
102            serde_json::from_str::<HashMap<String, Self>>(&info)?;
103        Ok(ret)
104    }
105}