burn_store/burnpack/
base.rs1use alloc::collections::BTreeMap;
6use alloc::string::String;
7use alloc::vec::Vec;
8use burn_tensor::DType;
9use byteorder::{ByteOrder, LittleEndian};
10use serde::{Deserialize, Serialize};
11
/// Magic number identifying a Burnpack file.
///
/// Read as a big-endian `u32`, these bytes spell ASCII `BURN`. The header stores the value
/// little-endian.
pub const MAGIC_NUMBER: u32 = 0x4255524E;

/// Format version written into newly created headers.
pub const FORMAT_VERSION: u16 = 0x0001;

/// Size in bytes of the magic-number field at the start of the header.
pub const MAGIC_SIZE: usize = 4;

/// Size in bytes of the format-version field that follows the magic number.
pub const VERSION_SIZE: usize = 2;

/// Size in bytes of the field holding the metadata length.
pub const METADATA_SIZE_FIELD_SIZE: usize = 4;

/// Total size of the fixed header: magic, then version, then metadata length.
pub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE;

/// Byte alignment applied to the start of the tensor data section.
pub const TENSOR_ALIGNMENT: u64 = 256;

/// Returns the byte offset at which the tensor data section starts.
///
/// The data section follows the fixed header (`HEADER_SIZE` bytes) and `metadata_size` bytes
/// of metadata. That end offset is rounded up to the next multiple of `TENSOR_ALIGNMENT`.
/// For example, `aligned_data_section_start(0)` is `256`.
///
/// All arithmetic is done in `u64`, so the header-plus-metadata sum cannot overflow `usize`
/// on 32-bit targets. If the aligned offset is not representable as a `usize`, the result
/// saturates to `usize::MAX`. That can only happen for `metadata_size` values far beyond
/// `MAX_METADATA_SIZE`. Saturating avoids wrapping to a small, plausible-looking offset.
#[inline]
pub fn aligned_data_section_start(metadata_size: usize) -> usize {
    // `usize -> u64` is lossless on every supported pointer width.
    let unaligned_start = (HEADER_SIZE as u64).saturating_add(metadata_size as u64);
    let aligned_start = unaligned_start
        .checked_next_multiple_of(TENSOR_ALIGNMENT)
        .unwrap_or(u64::MAX);
    usize::try_from(aligned_start).unwrap_or(usize::MAX)
}
62
/// Largest metadata section length, in bytes, that is considered valid (100 MiB).
///
/// Typed `u32` to match the header's metadata-length field.
/// NOTE(review): presumably enforced by the reader after parsing the header; confirm.
pub const MAX_METADATA_SIZE: u32 = 100 << 20;

/// Largest size, in bytes, of a single tensor's data on 32-bit targets (2 GiB).
///
/// The 64-bit cap of 10 GiB would not fit in a 32-bit `usize`.
#[cfg(target_pointer_width = "32")]
pub const MAX_TENSOR_SIZE: usize = 2 << 30;
/// Largest size, in bytes, of a single tensor's data (10 GiB).
#[cfg(not(target_pointer_width = "32"))]
pub const MAX_TENSOR_SIZE: usize = 10 << 30;

/// Largest number of tensors considered valid.
pub const MAX_TENSOR_COUNT: usize = 100_000;

/// Maximum nesting depth for CBOR decoding.
/// NOTE(review): presumably handed to the metadata deserializer; confirm at the call site.
pub const MAX_CBOR_RECURSION_DEPTH: usize = 128;

/// Largest total file size, in bytes, considered valid (100 GiB).
///
/// Only defined when the `std` feature is enabled.
#[cfg(feature = "std")]
pub const MAX_FILE_SIZE: u64 = 100 << 30;
92
93pub const fn magic_range() -> core::ops::Range<usize> {
95 let start = 0;
96 let end = start + MAGIC_SIZE;
97 start..end
98}
99
100pub const fn version_range() -> core::ops::Range<usize> {
102 let start = MAGIC_SIZE;
103 let end = start + VERSION_SIZE;
104 start..end
105}
106
107pub const fn metadata_size_range() -> core::ops::Range<usize> {
109 let start = MAGIC_SIZE + VERSION_SIZE;
110 let end = start + METADATA_SIZE_FIELD_SIZE;
111 start..end
112}
113
114const _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE);
116
/// Fixed-size header at the very start of a Burnpack file.
///
/// The header holds the magic number, the format version and the metadata length, in that
/// order (see `magic_range`, `version_range` and `metadata_size_range`).
/// `PartialEq`/`Eq` are derived so headers can be compared directly, e.g. in round-trip checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BurnpackHeader {
    /// Magic number. A valid file has `MAGIC_NUMBER` here.
    pub magic: u32,
    /// Format version the file was written with.
    pub version: u16,
    /// Length in bytes of the metadata section that follows the header.
    pub metadata_size: u32,
}
127
128impl BurnpackHeader {
129 #[allow(dead_code)]
131 pub fn new(metadata_size: u32) -> Self {
132 Self {
133 magic: MAGIC_NUMBER,
134 version: FORMAT_VERSION,
135 metadata_size,
136 }
137 }
138
139 pub fn into_bytes(self) -> [u8; HEADER_SIZE] {
141 let mut bytes = [0u8; HEADER_SIZE];
142 LittleEndian::write_u32(&mut bytes[magic_range()], self.magic);
143 LittleEndian::write_u16(&mut bytes[version_range()], self.version);
144 LittleEndian::write_u32(&mut bytes[metadata_size_range()], self.metadata_size);
145 bytes
146 }
147
148 pub fn from_bytes(bytes: &[u8]) -> Result<Self, BurnpackError> {
150 if bytes.len() < HEADER_SIZE {
151 return Err(BurnpackError::InvalidHeader);
152 }
153
154 let magic = LittleEndian::read_u32(&bytes[magic_range()]);
155 if magic != MAGIC_NUMBER {
156 return Err(BurnpackError::InvalidMagicNumber);
157 }
158
159 let version = LittleEndian::read_u16(&bytes[version_range()]);
160 let metadata_size = LittleEndian::read_u32(&bytes[metadata_size_range()]);
161
162 Ok(Self {
163 magic,
164 version,
165 metadata_size,
166 })
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct BurnpackMetadata {
173 pub tensors: BTreeMap<String, TensorDescriptor>,
175 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
177 pub metadata: BTreeMap<String, String>,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct TensorDescriptor {
183 pub dtype: DType,
185 pub shape: Vec<u64>,
187 pub data_offsets: (u64, u64),
189 #[serde(default, skip_serializing_if = "Option::is_none")]
192 pub param_id: Option<u64>,
193}
194
/// Errors raised while reading, writing or validating Burnpack data.
#[derive(Debug)]
pub enum BurnpackError {
    /// The input was too short to contain a complete header.
    InvalidHeader,
    /// The header's magic field did not match the expected magic number.
    InvalidMagicNumber,
    /// The format version is not supported.
    InvalidVersion,
    /// Metadata serialization failed; carries a description shown in the `Display` output.
    MetadataSerializationError(String),
    /// Metadata deserialization failed; carries a description shown in the `Display` output.
    MetadataDeserializationError(String),
    /// An I/O operation failed; carries a description shown in the `Display` output.
    IoError(String),
    /// No tensor with the given name was found.
    TensorNotFound(String),
    /// A tensor's byte length did not match expectations; carries a description.
    TensorBytesSizeMismatch(String),
    /// A generic validation failure; carries a description.
    ValidationError(String),
}

impl core::fmt::Display for BurnpackError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::InvalidHeader => f.write_str("Invalid header: insufficient bytes"),
            Self::InvalidMagicNumber => f.write_str("Invalid magic number"),
            Self::InvalidVersion => f.write_str("Unsupported version"),
            // Variants carrying a description append it after a fixed prefix.
            Self::MetadataSerializationError(detail) => {
                write!(f, "Metadata serialization error: {detail}")
            }
            Self::MetadataDeserializationError(detail) => {
                write!(f, "Metadata deserialization error: {detail}")
            }
            Self::IoError(detail) => write!(f, "I/O error: {detail}"),
            Self::TensorNotFound(tensor_name) => write!(f, "Tensor not found: {tensor_name}"),
            Self::TensorBytesSizeMismatch(detail) => {
                write!(f, "Tensor bytes size mismatch: {detail}")
            }
            Self::ValidationError(detail) => write!(f, "Validation error: {detail}"),
        }
    }
}

impl core::error::Error for BurnpackError {}