// burn_store/burnpack/base.rs
1//! Core types and constants for the Burnpack file format.
2//!
3//! See the [parent module](crate::burnpack) for the complete file format specification.
4
5use alloc::collections::BTreeMap;
6use alloc::string::String;
7use alloc::vec::Vec;
8use burn_tensor::DType;
9use byteorder::{ByteOrder, LittleEndian};
10use serde::{Deserialize, Serialize};
11
/// Magic number identifying a Burnpack file: "BURN" in ASCII (0x4255524E).
/// Expressed via `from_be_bytes` so the ASCII spelling is visible in source.
/// When written to file in little-endian format, appears as "NRUB" bytes.
pub const MAGIC_NUMBER: u32 = u32::from_be_bytes(*b"BURN");

/// Current format version
pub const FORMAT_VERSION: u16 = 0x0001;

/// Size of the magic number in bytes
pub const MAGIC_SIZE: usize = 4;

/// Size of the format version in bytes
pub const VERSION_SIZE: usize = 2;

/// Size of the metadata size field in bytes
pub const METADATA_SIZE_FIELD_SIZE: usize = 4;

/// Total header size (computed from components)
pub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE;

/// Alignment for tensor data in bytes.
///
/// All tensor data is aligned to 256-byte boundaries to enable efficient
/// memory-mapped (mmap) zero-copy loading. This alignment ensures:
/// - Proper pointer alignment for all tensor element types (f64 requires 8-byte alignment)
/// - Cache-line friendly access (most CPUs use 64-byte cache lines)
/// - GPU memory alignment (CUDA prefers 256-byte for coalesced access)
/// - Future-proofing for wider SIMD (AVX-512 = 64 bytes, future AVX-1024 = 128 bytes)
///
/// Industry alignment choices:
/// - 256-byte: GGUF, MLX, ncnn, MNN, TNN, vLLM-AWQ, Marlin (15+ formats)
/// - 64-byte: SafeTensors (minimum for AVX-512)
/// - 4096-byte: Core ML
///
/// 256-byte alignment has negligible overhead for typical tensor sizes while
/// providing maximum compatibility with current and future hardware.
pub const TENSOR_ALIGNMENT: u64 = 256;

/// Calculate the byte offset where the tensor data section starts.
///
/// The data section is padded to start at a 256-byte aligned position
/// so that all tensor offsets (which are relative to data section) result
/// in properly aligned absolute file positions for mmap zero-copy access.
///
/// This function must be used consistently by both writer and reader.
#[inline]
pub fn aligned_data_section_start(metadata_size: usize) -> usize {
    // Cast each operand BEFORE adding so the sum is computed in u64 space.
    // The previous `(HEADER_SIZE + metadata_size) as u64` performed the
    // addition in usize first, which could wrap on 32-bit targets before
    // the cast took effect.
    let unaligned_start = HEADER_SIZE as u64 + metadata_size as u64;
    // Round up to the next multiple of TENSOR_ALIGNMENT. The narrowing cast
    // back to usize is safe in practice because metadata_size is presumably
    // validated against MAX_METADATA_SIZE (100 MB) by callers — confirm.
    (unaligned_start.div_ceil(TENSOR_ALIGNMENT) * TENSOR_ALIGNMENT) as usize
}
62
// Security limits guarding against denial-of-service via resource
// exhaustion. Tune these bounds to match your deployment's needs.

/// Upper bound on metadata size: 100 MB (`100 << 20` bytes).
/// Stops a crafted header from demanding an enormous metadata allocation.
pub const MAX_METADATA_SIZE: u32 = 100 << 20;

/// Upper bound on the byte size of any single tensor.
/// Stops a crafted descriptor from demanding an enormous tensor allocation.
/// On 32-bit targets the cap is 2 GB (so it fits in usize);
/// on 64-bit targets it is 10 GB.
#[cfg(target_pointer_width = "32")]
pub const MAX_TENSOR_SIZE: usize = 2 << 30;
#[cfg(not(target_pointer_width = "32"))]
pub const MAX_TENSOR_SIZE: usize = 10 << 30;

/// Upper bound on the number of tensors in one file (100,000).
/// Stops resource exhaustion through an excessive tensor count.
pub const MAX_TENSOR_COUNT: usize = 100_000;

/// Upper bound on CBOR deserialization recursion depth (128 levels).
/// Stops stack-overflow attacks using deeply nested CBOR structures.
pub const MAX_CBOR_RECURSION_DEPTH: usize = 128;

/// Upper bound on total file size: 100 GB (`100 << 30` bytes).
/// Stops resource exhaustion from absurdly large files; enforced for
/// file-based loading (both mmap and buffered paths).
#[cfg(feature = "std")]
pub const MAX_FILE_SIZE: u64 = 100 << 30;
92
93/// Byte range for magic number in header
94pub const fn magic_range() -> core::ops::Range<usize> {
95    let start = 0;
96    let end = start + MAGIC_SIZE;
97    start..end
98}
99
100/// Byte range for format version in header
101pub const fn version_range() -> core::ops::Range<usize> {
102    let start = MAGIC_SIZE;
103    let end = start + VERSION_SIZE;
104    start..end
105}
106
107/// Byte range for metadata size field in header
108pub const fn metadata_size_range() -> core::ops::Range<usize> {
109    let start = MAGIC_SIZE + VERSION_SIZE;
110    let end = start + METADATA_SIZE_FIELD_SIZE;
111    start..end
112}
113
114// Compile-time validation that ranges are correct
115const _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE);
116
/// Fixed-size header that opens every Burnpack file.
///
/// Serialized as three little-endian fields totaling `HEADER_SIZE` (10)
/// bytes: magic (4), version (2), metadata size (4) — see `into_bytes`.
#[derive(Debug, Clone, Copy)]
pub struct BurnpackHeader {
    /// Magic number (4 bytes): 0x4255524E ("BURN")
    pub magic: u32,
    /// Format version (2 bytes)
    pub version: u16,
    /// Size of the CBOR metadata section in bytes (4 bytes)
    pub metadata_size: u32,
}
127
128impl BurnpackHeader {
129    /// Create a new header with the given metadata size
130    #[allow(dead_code)]
131    pub fn new(metadata_size: u32) -> Self {
132        Self {
133            magic: MAGIC_NUMBER,
134            version: FORMAT_VERSION,
135            metadata_size,
136        }
137    }
138
139    /// Serialize header into bytes
140    pub fn into_bytes(self) -> [u8; HEADER_SIZE] {
141        let mut bytes = [0u8; HEADER_SIZE];
142        LittleEndian::write_u32(&mut bytes[magic_range()], self.magic);
143        LittleEndian::write_u16(&mut bytes[version_range()], self.version);
144        LittleEndian::write_u32(&mut bytes[metadata_size_range()], self.metadata_size);
145        bytes
146    }
147
148    /// Deserialize header from bytes
149    pub fn from_bytes(bytes: &[u8]) -> Result<Self, BurnpackError> {
150        if bytes.len() < HEADER_SIZE {
151            return Err(BurnpackError::InvalidHeader);
152        }
153
154        let magic = LittleEndian::read_u32(&bytes[magic_range()]);
155        if magic != MAGIC_NUMBER {
156            return Err(BurnpackError::InvalidMagicNumber);
157        }
158
159        let version = LittleEndian::read_u16(&bytes[version_range()]);
160        let metadata_size = LittleEndian::read_u32(&bytes[metadata_size_range()]);
161
162        Ok(Self {
163            magic,
164            version,
165            metadata_size,
166        })
167    }
168}
169
/// Metadata structure serialized with CBOR.
///
/// Written immediately after the fixed header; its encoded length is the
/// header's `metadata_size` field. `BTreeMap` keeps keys sorted, giving a
/// deterministic serialization order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BurnpackMetadata {
    /// Tensor descriptors mapped by name for efficient lookup
    pub tensors: BTreeMap<String, TensorDescriptor>,
    /// Optional additional key/value metadata; omitted from the CBOR
    /// encoding entirely when empty
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub metadata: BTreeMap<String, String>,
}
179
/// Individual tensor descriptor stored in [`BurnpackMetadata::tensors`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDescriptor {
    /// Data type of the tensor
    pub dtype: DType,
    /// Tensor shape dimensions (u64 keeps the on-disk format independent
    /// of the writer's pointer width)
    pub shape: Vec<u64>,
    /// Byte offsets (start, end) relative to the aligned data section
    /// start — not absolute file positions
    pub data_offsets: (u64, u64),
    /// Parameter ID for training state persistence matching.
    /// Generated automatically if not present during loading.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub param_id: Option<u64>,
}
194
/// Error types for Burnpack operations.
///
/// Variants carry pre-formatted `String` payloads rather than nested error
/// types, keeping the enum usable in `alloc`-only (no_std) builds.
#[derive(Debug)]
pub enum BurnpackError {
    /// Header slice was shorter than `HEADER_SIZE` bytes.
    InvalidHeader,
    /// Magic field did not match `MAGIC_NUMBER`.
    InvalidMagicNumber,
    /// Unsupported format version. (Not constructed in this module —
    /// presumably raised by higher-level readers; confirm.)
    InvalidVersion,
    /// CBOR metadata could not be serialized.
    MetadataSerializationError(String),
    /// CBOR metadata could not be deserialized.
    MetadataDeserializationError(String),
    /// Underlying I/O failure, stringified.
    IoError(String),
    /// Lookup failed for the named tensor.
    TensorNotFound(String),
    /// Supplied byte buffer did not match the expected tensor byte size.
    TensorBytesSizeMismatch(String),
    /// A validation check failed.
    ValidationError(String),
}
208
209impl core::fmt::Display for BurnpackError {
210    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
211        match self {
212            BurnpackError::InvalidHeader => write!(f, "Invalid header: insufficient bytes"),
213            BurnpackError::InvalidMagicNumber => write!(f, "Invalid magic number"),
214            BurnpackError::InvalidVersion => write!(f, "Unsupported version"),
215            BurnpackError::MetadataSerializationError(e) => {
216                write!(f, "Metadata serialization error: {}", e)
217            }
218            BurnpackError::MetadataDeserializationError(e) => {
219                write!(f, "Metadata deserialization error: {}", e)
220            }
221            BurnpackError::IoError(e) => write!(f, "I/O error: {}", e),
222            BurnpackError::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
223            BurnpackError::TensorBytesSizeMismatch(e) => {
224                write!(f, "Tensor bytes size mismatch: {}", e)
225            }
226            BurnpackError::ValidationError(e) => write!(f, "Validation error: {}", e),
227        }
228    }
229}
230
// Marker impl: `Debug` + `Display` above satisfy the supertraits. No
// `source()` override — variants carry stringified causes, not nested errors.
impl core::error::Error for BurnpackError {}