hpt_dataloader/
data_loader.rs

1use hpt_traits::tensor::{CommonBounds, TensorInfo};
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4use std::marker::PhantomData;
5use std::{
6    collections::HashMap,
7    fs::File,
8    io::{Read, Seek},
9};
10
11use crate::struct_save::load::load;
12use crate::struct_save::load::MetaLoad;
13use crate::struct_save::save::Save;
14use crate::CPUTensorCreator;
15use crate::CompressionAlgo;
16
17#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
18pub enum Endian {
19    Little,
20    Big,
21    Native,
22}
23
24#[derive(Debug, Serialize, Deserialize)]
25pub(crate) struct HeaderInfos {
26    pub(crate) begin: u64,
27    pub(crate) infos: Vec<HeaderInfo>,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
31pub(crate) struct HeaderInfo {
32    pub(crate) begin: u64,
33    pub(crate) name: String,
34    pub(crate) shape: Vec<i64>,
35    pub(crate) strides: Vec<i64>,
36    pub(crate) size: usize,
37    pub(crate) indices: Vec<(usize, usize, usize, usize)>,
38    pub(crate) compress_algo: CompressionAlgo,
39    pub(crate) dtype: String,
40    pub(crate) endian: Endian,
41}
42/// the meta data of the tensor
43#[derive(Serialize, Deserialize)]
44#[must_use]
45pub struct TensorMeta<T: CommonBounds, B: CPUTensorCreator>
46where
47    <B as CPUTensorCreator>::Output: Clone + TensorInfo<T>,
48{
49    pub begin: usize,
50    pub shape: Vec<i64>,
51    pub strides: Vec<i64>,
52    pub size: usize,
53    pub dtype: String,
54    pub compression_algo: CompressionAlgo,
55    pub endian: Endian,
56    pub indices: Vec<(usize, usize, usize, usize)>,
57    pub phantom: PhantomData<(T, B)>,
58}
59
60pub fn parse_header_compressed<M: Save, P: Into<std::path::PathBuf>>(
61    file: P,
62) -> Result<<M as Save>::Meta, Box<dyn std::error::Error>> {
63    let mut file = File::open(file.into())?;
64    file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
65    let mut header_infos = [0u8; 20];
66    file.read_exact(&mut header_infos)?;
67    let header = std::str::from_utf8(&header_infos)?;
68    let header_int = header.trim().parse::<u64>()?;
69    file.seek(std::io::SeekFrom::Start(header_int))?; // offset for header
70    let mut buffer3 = vec![];
71    file.read_to_end(&mut buffer3)?;
72    let info = std::str::from_utf8(&buffer3)?;
73    let ret = serde_json::from_str::<M::Meta>(info)?;
74    Ok(ret)
75}
76
77impl<T, B: CPUTensorCreator> MetaLoad for TensorMeta<T, B>
78where
79    T: CommonBounds + bytemuck::AnyBitPattern,
80    <B as CPUTensorCreator>::Output: Clone + TensorInfo<T> + Display + Into<B>,
81{
82    type Output = B;
83    fn load(&self, file: &mut std::fs::File) -> std::io::Result<Self::Output> {
84        Ok(load::<T, B>(file, self)?.into())
85    }
86}
87
88impl HeaderInfo {
89    pub(crate) fn parse_header_compressed(
90        file: &str,
91    ) -> Result<HashMap<String, HeaderInfo>, Box<dyn std::error::Error>> {
92        let mut file = File::open(file)?;
93        file.read_exact(&mut [0u8; "FASTTENSOR".len()])?;
94        let mut header_infos = [0u8; 20];
95        file.read_exact(&mut header_infos)?;
96        let header = std::str::from_utf8(&header_infos)?;
97        let header_int = header.trim().parse::<u64>()?;
98        file.seek(std::io::SeekFrom::Start(header_int))?; // offset for header
99        let mut buffer3 = vec![];
100        file.read_to_end(&mut buffer3)?;
101        let info = std::str::from_utf8(&buffer3)?;
102        let ret: HashMap<String, HeaderInfo> =
103            serde_json::from_str::<HashMap<String, Self>>(&info)?;
104        Ok(ret)
105    }
106}