use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;
use burn_tensor::DType;
use byteorder::{ByteOrder, LittleEndian};
use serde::{Deserialize, Serialize};
/// Magic number that opens every Burnpack header.
///
/// Read as a big-endian `u32` this is the ASCII text `"BURN"`; the header stores it
/// little-endian, so the first four header bytes are `4E 52 55 42`.
pub const MAGIC_NUMBER: u32 = 0x4255_524E;
/// Format version stamped into newly created headers.
pub const FORMAT_VERSION: u16 = 1;
/// Width in bytes of the magic-number field (a `u32`).
pub const MAGIC_SIZE: usize = core::mem::size_of::<u32>();
/// Width in bytes of the format-version field (a `u16`).
pub const VERSION_SIZE: usize = core::mem::size_of::<u16>();
/// Width in bytes of the metadata-size field (a `u32`).
pub const METADATA_SIZE_FIELD_SIZE: usize = core::mem::size_of::<u32>();
/// Total length in bytes of the fixed header that precedes the metadata.
pub const HEADER_SIZE: usize = MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE;
/// Boundary that the start of the tensor data section is rounded up to.
pub const TENSOR_ALIGNMENT: u64 = 256;
/// Returns the offset at which the tensor data section starts.
///
/// The fixed header ([`HEADER_SIZE`] bytes) and `metadata_size` bytes of metadata come
/// first; their combined length is rounded up to the next multiple of
/// [`TENSOR_ALIGNMENT`]. A length that is already a multiple is returned unchanged.
#[inline]
pub fn aligned_data_section_start(metadata_size: usize) -> usize {
    let metadata_end = (HEADER_SIZE + metadata_size) as u64;
    let aligned = metadata_end.next_multiple_of(TENSOR_ALIGNMENT);
    aligned as usize
}
// Size limits below are written as power-of-two shifts: `<< 20` is Mi, `<< 30` is Gi.
// NOTE(review): sizes are presumably byte counts — confirm at the call sites.

/// Upper limit for the metadata section size: 100 MiB.
pub const MAX_METADATA_SIZE: u32 = 100 << 20;

/// Upper limit for a single tensor's size on 32-bit targets: 2 GiB.
#[cfg(target_pointer_width = "32")]
pub const MAX_TENSOR_SIZE: usize = 2 << 30;

/// Upper limit for a single tensor's size on all other targets: 10 GiB.
#[cfg(not(target_pointer_width = "32"))]
pub const MAX_TENSOR_SIZE: usize = 10 << 30;

/// Upper limit for the number of tensors.
pub const MAX_TENSOR_COUNT: usize = 100_000;

/// Recursion-depth limit for CBOR data.
/// NOTE(review): presumably applied when decoding the metadata section — confirm.
pub const MAX_CBOR_RECURSION_DEPTH: usize = 128;

/// Upper limit for the whole file size: 100 GiB. Only defined with the `std` feature.
#[cfg(feature = "std")]
pub const MAX_FILE_SIZE: u64 = 100 << 30;
/// Byte range of the little-endian `u32` magic number; it opens the header.
pub const fn magic_range() -> core::ops::Range<usize> {
    0..MAGIC_SIZE
}
/// Byte range of the little-endian `u16` format version, directly after the magic number.
pub const fn version_range() -> core::ops::Range<usize> {
    MAGIC_SIZE..MAGIC_SIZE + VERSION_SIZE
}
/// Byte range of the little-endian `u32` metadata-size field, the last field of the header.
pub const fn metadata_size_range() -> core::ops::Range<usize> {
    // The field sits after the magic number and the version.
    const OFFSET: usize = MAGIC_SIZE + VERSION_SIZE;
    OFFSET..OFFSET + METADATA_SIZE_FIELD_SIZE
}
const _: () = assert!(MAGIC_SIZE + VERSION_SIZE + METADATA_SIZE_FIELD_SIZE == HEADER_SIZE);
#[derive(Debug, Clone, Copy)]
pub struct BurnpackHeader {
pub magic: u32,
pub version: u16,
pub metadata_size: u32,
}
impl BurnpackHeader {
    /// Creates a header stamped with [`MAGIC_NUMBER`] and the current [`FORMAT_VERSION`].
    ///
    /// `metadata_size` is stored verbatim; no range check is performed here.
    // NOTE(review): the `dead_code` allow suggests some build configurations never call this
    // constructor — confirm it is still needed.
    #[allow(dead_code)]
    pub fn new(metadata_size: u32) -> Self {
        Self {
            metadata_size,
            magic: MAGIC_NUMBER,
            version: FORMAT_VERSION,
        }
    }

    /// Serializes the header into its fixed [`HEADER_SIZE`]-byte form.
    ///
    /// Each field is written little-endian at the offsets given by [`magic_range`],
    /// [`version_range`] and [`metadata_size_range`].
    pub fn into_bytes(self) -> [u8; HEADER_SIZE] {
        let Self {
            magic,
            version,
            metadata_size,
        } = self;

        let mut out = [0u8; HEADER_SIZE];
        LittleEndian::write_u32(&mut out[magic_range()], magic);
        LittleEndian::write_u16(&mut out[version_range()], version);
        LittleEndian::write_u32(&mut out[metadata_size_range()], metadata_size);
        out
    }

    /// Parses a header from the start of `bytes`.
    ///
    /// Only the first [`HEADER_SIZE`] bytes are inspected; anything after them is ignored.
    /// The version field is returned as read and is not validated here.
    ///
    /// # Errors
    ///
    /// - [`BurnpackError::InvalidHeader`] if `bytes` is shorter than [`HEADER_SIZE`].
    /// - [`BurnpackError::InvalidMagicNumber`] if the leading little-endian `u32` is not
    ///   [`MAGIC_NUMBER`].
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, BurnpackError> {
        let header = bytes
            .get(..HEADER_SIZE)
            .ok_or(BurnpackError::InvalidHeader)?;

        let magic = LittleEndian::read_u32(&header[magic_range()]);
        if magic != MAGIC_NUMBER {
            return Err(BurnpackError::InvalidMagicNumber);
        }

        Ok(Self {
            magic,
            version: LittleEndian::read_u16(&header[version_range()]),
            metadata_size: LittleEndian::read_u32(&header[metadata_size_range()]),
        })
    }
}
/// Metadata section of a Burnpack file: tensor descriptors plus optional free-form
/// string key/value pairs.
///
/// NOTE(review): the wire encoding is chosen by whoever (de)serializes this type.
/// `MAX_CBOR_RECURSION_DEPTH` suggests CBOR — confirm. serde's derive emits struct fields
/// in declaration order, so reordering the fields below may change the encoded bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BurnpackMetadata {
/// Tensor descriptors keyed by a string (presumably the tensor name — confirm against the
/// writer). Being a `BTreeMap`, iteration is in sorted key order.
pub tensors: BTreeMap<String, TensorDescriptor>,
/// Free-form string metadata. It is skipped when serializing if empty, and defaults to an
/// empty map when the field is absent during deserialization.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub metadata: BTreeMap<String, String>,
}
/// Description of one stored tensor: element type, shape and where its data lives.
///
/// NOTE(review): serde's derive emits struct fields in declaration order, so reordering the
/// fields below may change the encoded bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDescriptor {
/// Element data type.
pub dtype: DType,
/// Size of each dimension.
pub shape: Vec<u64>,
/// Pair of offsets locating the tensor's data.
// NOTE(review): presumably `(start, end)` byte offsets with `end` exclusive, measured from
// the aligned data section start — confirm against the reader/writer.
pub data_offsets: (u64, u64),
/// Optional parameter id. It is skipped when serializing if `None`, and defaults to `None`
/// when the field is absent during deserialization.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub param_id: Option<u64>,
}
/// Errors produced while reading, writing or validating Burnpack data.
#[derive(Debug)]
pub enum BurnpackError {
    /// Fewer bytes were available than a full header needs.
    InvalidHeader,
    /// The header's magic field did not equal [`MAGIC_NUMBER`].
    InvalidMagicNumber,
    /// The format version is not supported.
    InvalidVersion,
    /// Metadata could not be serialized; carries a message describing the failure.
    MetadataSerializationError(String),
    /// Metadata could not be deserialized; carries a message describing the failure.
    MetadataDeserializationError(String),
    /// An I/O operation failed; carries a message describing the failure.
    IoError(String),
    /// A requested tensor does not exist; carries the tensor name.
    TensorNotFound(String),
    /// A tensor's byte length did not match what was expected; carries details.
    TensorBytesSizeMismatch(String),
    /// A validation check failed; carries a message describing the failure.
    ValidationError(String),
}

impl core::fmt::Display for BurnpackError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Each variant maps to a fixed label plus an optional detail. Variants with a payload
        // render as "<label>: <detail>"; the others render as the bare label.
        let (label, detail): (&str, Option<&str>) = match self {
            Self::InvalidHeader => ("Invalid header: insufficient bytes", None),
            Self::InvalidMagicNumber => ("Invalid magic number", None),
            Self::InvalidVersion => ("Unsupported version", None),
            Self::MetadataSerializationError(msg) => {
                ("Metadata serialization error", Some(msg.as_str()))
            }
            Self::MetadataDeserializationError(msg) => {
                ("Metadata deserialization error", Some(msg.as_str()))
            }
            Self::IoError(msg) => ("I/O error", Some(msg.as_str())),
            Self::TensorNotFound(name) => ("Tensor not found", Some(name.as_str())),
            Self::TensorBytesSizeMismatch(msg) => {
                ("Tensor bytes size mismatch", Some(msg.as_str()))
            }
            Self::ValidationError(msg) => ("Validation error", Some(msg.as_str())),
        };

        match detail {
            Some(detail) => write!(f, "{}: {}", label, detail),
            None => f.write_str(label),
        }
    }
}

impl core::error::Error for BurnpackError {}