#[cfg(feature = "std")]
use super::base::MAX_FILE_SIZE;
use super::base::{
BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
MAX_CBOR_RECURSION_DEPTH, MAX_METADATA_SIZE, MAX_TENSOR_COUNT, MAX_TENSOR_SIZE,
aligned_data_section_start,
};
use crate::TensorSnapshot;
use alloc::format;
use alloc::rc::Rc;
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
use burn_core::module::ParamId;
use burn_tensor::{Bytes, Shape, TensorData};
#[cfg(feature = "std")]
use std::cell::RefCell;
#[cfg(feature = "std")]
use std::fs::File;
#[cfg(feature = "std")]
use std::io::{Read, Seek};
#[cfg(feature = "std")]
use std::path::Path;
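/// Backing storage for tensor data: either a fully in-memory buffer or an
/// open file handle that is read on demand.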
pub(crate) enum StorageBackend {
Memory(Rc<Bytes>),
#[cfg(feature = "std")]
#[allow(dead_code)]
FileBuffered { file: Rc<RefCell<File>> },
}
impl StorageBackend {
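    /// Reads exactly `bytes.len()` bytes starting at absolute `offset`.
    /// The in-memory backend is bounds-checked; the file backend seeks and
    /// reads, surfacing failures as [`BurnpackError::IoError`].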
pub(crate) fn read_into(&self, bytes: &mut [u8], offset: usize) -> Result<(), BurnpackError> {
match self {
StorageBackend::Memory(data) => {
let data_bytes = data.as_ref();
let end = offset.checked_add(bytes.len()).ok_or_else(|| {
BurnpackError::IoError(format!(
"Offset overflow: offset {} + length {} exceeds maximum",
offset,
bytes.len()
))
})?;
if end > data_bytes.len() {
return Err(BurnpackError::IoError(format!(
"Read out of bounds: requested {}..{} but data length is {}",
offset,
end,
data_bytes.len()
)));
}
bytes.copy_from_slice(&data_bytes[offset..end]);
Ok(())
}
#[cfg(feature = "std")]
StorageBackend::FileBuffered { file } => {
use std::io::SeekFrom;
let mut file = file.borrow_mut();
file.seek(SeekFrom::Start(offset as u64)).map_err(|e| {
BurnpackError::IoError(format!("Failed to seek in file: {}", e))
})?;
file.read_exact(bytes).map_err(|e| {
BurnpackError::IoError(format!("Failed to read from file: {}", e))
})?;
Ok(())
}
}
}
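    /// Returns the full underlying byte slice. Only the in-memory backend can
    /// provide this; the buffered file backend has no contiguous buffer.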
#[allow(dead_code)]
pub(crate) fn as_bytes(&self) -> Result<&[u8], BurnpackError> {
match self {
StorageBackend::Memory(data) => Ok(data.as_ref()),
#[cfg(feature = "std")]
StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(
"Cannot get full bytes reference for FileBuffered backend".into(),
)),
}
}
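    /// Returns the `start..end` range of the data as `Bytes` without copying.
    /// Only supported for the in-memory backend.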
pub(crate) fn slice_bytes(&self, start: usize, end: usize) -> Result<Bytes, BurnpackError> {
if end < start {
return Err(BurnpackError::IoError(format!(
"Invalid slice range: end ({}) < start ({})",
end, start
)));
}
match self {
StorageBackend::Memory(data) => {
let cloned = (**data).clone();
let (_, right) = cloned.split(start).map_err(|(_, e)| {
BurnpackError::IoError(format!("Failed to split at start {}: {:?}", start, e))
})?;
let slice_len = end - start;
let (middle, _) = right.split(slice_len).map_err(|(_, e)| {
BurnpackError::IoError(format!(
"Failed to split at length {}: {:?}",
slice_len, e
))
})?;
Ok(middle)
}
#[cfg(feature = "std")]
StorageBackend::FileBuffered { .. } => Err(BurnpackError::IoError(
"Zero-copy not supported for buffered file reading. Use from_file() with memmap feature for zero-copy loading.".into(),
)),
}
}
}
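/// Reader for the Burnpack format. Parses and validates the header and CBOR
/// metadata up front, then exposes tensors lazily as [`TensorSnapshot`]s.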
pub struct BurnpackReader {
pub(crate) metadata: BurnpackMetadata,
pub(crate) storage: StorageBackend,
pub(crate) data_offset: usize,
}
impl BurnpackReader {
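    /// Creates a reader over an in-memory buffer, validating the magic number,
    /// format version, metadata size, tensor count, and data offsets before
    /// accepting the input.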
pub fn from_bytes(bytes: Bytes) -> Result<Self, BurnpackError> {
if bytes.len() < HEADER_SIZE {
return Err(BurnpackError::InvalidHeader);
}
let header = BurnpackHeader::from_bytes(&bytes[..HEADER_SIZE])?;
if header.magic != MAGIC_NUMBER {
return Err(BurnpackError::InvalidMagicNumber);
}
if header.version > FORMAT_VERSION {
return Err(BurnpackError::InvalidVersion);
}
if header.metadata_size > MAX_METADATA_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
header.metadata_size, MAX_METADATA_SIZE
)));
}
let metadata_start = HEADER_SIZE;
let metadata_end = metadata_start
.checked_add(header.metadata_size as usize)
.ok_or_else(|| {
BurnpackError::IoError(format!(
"Metadata size overflow: {} + {}",
metadata_start, header.metadata_size
))
})?;
if bytes.len() < metadata_end {
return Err(BurnpackError::InvalidHeader);
}
let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
&bytes[metadata_start..metadata_end],
MAX_CBOR_RECURSION_DEPTH,
)
.map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;
if metadata.tensors.len() > MAX_TENSOR_COUNT {
return Err(BurnpackError::ValidationError(format!(
"File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
metadata.tensors.len(),
MAX_TENSOR_COUNT
)));
}
if !metadata.tensors.is_empty() {
let max_data_offset = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0);
let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Data offset {} exceeds platform maximum",
max_data_offset
))
})?;
let min_file_size =
metadata_end
.checked_add(max_data_offset_usize)
.ok_or_else(|| {
BurnpackError::ValidationError("File size calculation overflow".into())
})?;
if bytes.len() < min_file_size {
return Err(BurnpackError::ValidationError(format!(
"File truncated: expected at least {} bytes, got {} bytes",
min_file_size,
bytes.len()
)));
}
}
Ok(Self {
metadata,
storage: StorageBackend::Memory(Rc::new(bytes)),
data_offset: aligned_data_section_start(header.metadata_size as usize),
})
}
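    /// Memory-maps the file and applies the same validation as
    /// [`Self::from_bytes`]. The mapping is wrapped in `Bytes`, so subsequent
    /// tensor reads stay zero-copy. Mapping is `unsafe` because the file must
    /// not be truncated or modified while mapped.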
#[cfg(all(feature = "std", feature = "memmap"))]
pub(crate) fn from_file_mmap<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
let file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
let file_size = file
.metadata()
.map_err(|e| BurnpackError::IoError(e.to_string()))?
.len();
if file_size > MAX_FILE_SIZE {
return Err(BurnpackError::ValidationError(format!(
"File size {} bytes exceeds maximum allowed size of {} bytes",
file_size, MAX_FILE_SIZE
)));
}
let mmap = unsafe {
memmap2::MmapOptions::new()
.map(&file)
.map_err(|e| BurnpackError::IoError(e.to_string()))?
};
if mmap.len() < HEADER_SIZE {
return Err(BurnpackError::InvalidHeader);
}
let header = BurnpackHeader::from_bytes(&mmap[..HEADER_SIZE])?;
if header.magic != MAGIC_NUMBER {
return Err(BurnpackError::InvalidMagicNumber);
}
if header.version > FORMAT_VERSION {
return Err(BurnpackError::InvalidVersion);
}
if header.metadata_size > MAX_METADATA_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
header.metadata_size, MAX_METADATA_SIZE
)));
}
let metadata_start = HEADER_SIZE;
let metadata_end = metadata_start
.checked_add(header.metadata_size as usize)
.ok_or_else(|| {
BurnpackError::IoError(format!(
"Metadata size overflow: {} + {}",
metadata_start, header.metadata_size
))
})?;
if mmap.len() < metadata_end {
return Err(BurnpackError::InvalidHeader);
}
let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
&mmap[metadata_start..metadata_end],
MAX_CBOR_RECURSION_DEPTH,
)
.map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;
if metadata.tensors.len() > MAX_TENSOR_COUNT {
return Err(BurnpackError::ValidationError(format!(
"File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
metadata.tensors.len(),
MAX_TENSOR_COUNT
)));
}
if !metadata.tensors.is_empty() {
let max_data_offset = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0);
let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Data offset {} exceeds platform maximum",
max_data_offset
))
})?;
let min_file_size =
metadata_end
.checked_add(max_data_offset_usize)
.ok_or_else(|| {
BurnpackError::ValidationError("File size calculation overflow".into())
})?;
if mmap.len() < min_file_size {
return Err(BurnpackError::ValidationError(format!(
"File truncated: expected at least {} bytes, got {} bytes",
min_file_size,
mmap.len()
)));
}
}
let shared_bytes = bytes::Bytes::from_owner(mmap);
let bytes = Bytes::from_shared(shared_bytes, burn_tensor::AllocationProperty::File);
Ok(Self {
metadata,
storage: StorageBackend::Memory(Rc::new(bytes)),
data_offset: aligned_data_section_start(header.metadata_size as usize),
})
}
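    /// Opens a Burnpack file, memory-mapping it when the `memmap` feature is
    /// enabled and falling back to buffered reads otherwise.
    ///
    /// Illustrative usage (the path is a placeholder):
    ///
    /// ```ignore
    /// let reader = BurnpackReader::from_file("model.bpk")?;
    /// let snapshots = reader.get_snapshots()?;
    /// ```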
#[cfg(feature = "std")]
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
#[cfg(feature = "memmap")]
{
Self::from_file_mmap(path)
}
#[cfg(not(feature = "memmap"))]
{
Self::from_file_buffered(path)
}
}
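    /// Fallback loader that validates the header and metadata, then keeps the
    /// file handle open and reads tensor data on demand.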
#[cfg(feature = "std")]
#[allow(dead_code)]
pub(crate) fn from_file_buffered<P: AsRef<Path>>(path: P) -> Result<Self, BurnpackError> {
let mut file = File::open(&path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
let file_size = file
.metadata()
.map_err(|e| BurnpackError::IoError(e.to_string()))?
.len();
if file_size > MAX_FILE_SIZE {
return Err(BurnpackError::ValidationError(format!(
"File size {} bytes exceeds maximum allowed size of {} bytes",
file_size, MAX_FILE_SIZE
)));
}
let mut header_bytes = [0u8; HEADER_SIZE];
file.read_exact(&mut header_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
        let header = BurnpackHeader::from_bytes(&header_bytes)?;
        if header.magic != MAGIC_NUMBER {
            return Err(BurnpackError::InvalidMagicNumber);
        }
        if header.version > FORMAT_VERSION {
            return Err(BurnpackError::InvalidVersion);
        }
if header.metadata_size > MAX_METADATA_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Metadata size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
header.metadata_size, MAX_METADATA_SIZE
)));
}
let mut metadata_bytes = vec![0u8; header.metadata_size as usize];
file.read_exact(&mut metadata_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
let metadata: BurnpackMetadata = ciborium::de::from_reader_with_recursion_limit(
metadata_bytes.as_slice(),
MAX_CBOR_RECURSION_DEPTH,
)
.map_err(|e| BurnpackError::MetadataDeserializationError(e.to_string()))?;
if metadata.tensors.len() > MAX_TENSOR_COUNT {
return Err(BurnpackError::ValidationError(format!(
"File contains {} tensors, exceeding maximum of {} (potential DoS attack)",
metadata.tensors.len(),
MAX_TENSOR_COUNT
)));
}
let metadata_end = HEADER_SIZE
.checked_add(header.metadata_size as usize)
.ok_or_else(|| {
BurnpackError::IoError(format!(
"Metadata size overflow: {} + {}",
HEADER_SIZE, header.metadata_size
))
})?;
if !metadata.tensors.is_empty() {
let max_data_offset = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0);
let max_data_offset_usize: usize = max_data_offset.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Data offset {} exceeds platform maximum",
max_data_offset
))
})?;
let min_file_size =
metadata_end
.checked_add(max_data_offset_usize)
.ok_or_else(|| {
BurnpackError::ValidationError("File size calculation overflow".into())
})?;
let file_size = file
.metadata()
.map_err(|e| BurnpackError::IoError(e.to_string()))?
.len() as usize;
if file_size < min_file_size {
return Err(BurnpackError::ValidationError(format!(
"File truncated: expected at least {} bytes, got {} bytes",
min_file_size, file_size
)));
}
}
Ok(Self {
metadata,
storage: StorageBackend::FileBuffered {
file: Rc::new(RefCell::new(file)),
},
data_offset: aligned_data_section_start(header.metadata_size as usize),
})
}
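    /// Returns lazy snapshots of all tensors; each snapshot copies its bytes
    /// out of storage when its data is first requested.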
pub fn get_snapshots(&self) -> Result<Vec<TensorSnapshot>, BurnpackError> {
self.get_snapshots_internal(false)
}
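    /// Like [`Self::get_snapshots`], but when `zero_copy` is `true` the
    /// snapshots slice the in-memory buffer instead of copying. Zero-copy is
    /// not available for the buffered file backend.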
pub fn get_snapshots_zero_copy(
&self,
zero_copy: bool,
) -> Result<Vec<TensorSnapshot>, BurnpackError> {
self.get_snapshots_internal(zero_copy)
}
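    /// Shared implementation: validates each tensor's shape, offsets, and
    /// size, then builds a deferred-read closure per tensor.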
fn get_snapshots_internal(
&self,
zero_copy: bool,
) -> Result<Vec<TensorSnapshot>, BurnpackError> {
let mut snapshots = Vec::new();
for (name, descriptor) in &self.metadata.tensors {
            let shape: Shape = Shape::from(
                descriptor
                    .shape
                    .iter()
                    .map(|&s| {
                        s.try_into().map_err(|_| {
                            BurnpackError::ValidationError(format!(
                                "Tensor '{}' has corrupted shape data: dimension {} exceeds platform maximum",
                                name, s
                            ))
                        })
                    })
                    .collect::<Result<Vec<usize>, BurnpackError>>()?,
            );
let dtype = descriptor.dtype;
let storage = match &self.storage {
StorageBackend::Memory(data) => StorageBackend::Memory(data.clone()),
#[cfg(feature = "std")]
StorageBackend::FileBuffered { file } => {
StorageBackend::FileBuffered { file: file.clone() }
}
};
let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum",
name, descriptor.data_offsets.0
))
})?;
let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum",
name, descriptor.data_offsets.1
))
})?;
let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: start offset overflow {} + {}",
name, self.data_offset, offset_start
))
})?;
let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {
BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: end offset overflow {} + {}",
name, self.data_offset, offset_end
))
})?;
let shape_for_closure = shape.clone();
if end < start {
return Err(BurnpackError::ValidationError(format!(
"Tensor '{}' has corrupted offset data: end offset {} < start offset {}",
name, end, start
)));
}
let tensor_size = end - start;
if tensor_size > MAX_TENSOR_SIZE {
return Err(BurnpackError::ValidationError(format!(
"Tensor '{}' size {} exceeds maximum allowed size of {} bytes (potential DoS attack)",
name, tensor_size, MAX_TENSOR_SIZE
)));
}
let tensor_id = descriptor
.param_id
.map(ParamId::from)
.unwrap_or_else(ParamId::new);
let data_fn: Rc<dyn Fn() -> Result<TensorData, crate::TensorSnapshotError>> =
if zero_copy {
Rc::new(move || {
let bytes = storage.slice_bytes(start, end).map_err(|e| {
crate::TensorSnapshotError::IoError(format!(
"Zero-copy slice failed: {}",
e
))
})?;
Ok(TensorData::from_bytes(
bytes,
shape_for_closure.clone(),
dtype,
))
})
} else {
Rc::new(move || {
let len = end - start;
let mut data_bytes = vec![0u8; len];
storage.read_into(&mut data_bytes, start).map_err(|e| {
crate::TensorSnapshotError::IoError(format!(
"Failed to read tensor data: {}",
e
))
})?;
Ok(TensorData::from_bytes_vec(
data_bytes,
shape_for_closure.clone(),
dtype,
))
})
};
let snapshot = TensorSnapshot::from_closure(
data_fn,
dtype,
shape,
name.split('.').map(|s| s.to_string()).collect(),
                vec![],
                tensor_id,
            );
snapshots.push(snapshot);
}
Ok(snapshots)
}
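    /// Looks up a single tensor snapshot by its dot-separated path.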
#[allow(dead_code)]
pub(crate) fn get_tensor_snapshot(&self, name: &str) -> Result<TensorSnapshot, BurnpackError> {
let snapshots = self.get_snapshots()?;
snapshots
.into_iter()
.find(|s| s.full_path() == name)
.ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))
}
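    /// Returns the names of all tensors recorded in the metadata.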
#[allow(dead_code)]
pub(crate) fn tensor_names(&self) -> Vec<&str> {
self.metadata
.tensors
.keys()
.map(|name| name.as_str())
.collect()
}
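    /// Returns the parsed file metadata.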
#[allow(dead_code)]
pub(crate) fn metadata(&self) -> &BurnpackMetadata {
&self.metadata
}
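    /// Reads one tensor's raw bytes into a freshly allocated buffer after
    /// validating its offsets.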
#[allow(dead_code)]
pub(crate) fn get_tensor_data(&self, name: &str) -> Result<Vec<u8>, BurnpackError> {
let descriptor = self
.metadata
.tensors
.get(name)
.ok_or_else(|| BurnpackError::TensorNotFound(name.to_string()))?;
let offset_start: usize = descriptor.data_offsets.0.try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: start offset {} exceeds platform maximum",
name, descriptor.data_offsets.0
))
})?;
let offset_end: usize = descriptor.data_offsets.1.try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: end offset {} exceeds platform maximum",
name, descriptor.data_offsets.1
))
})?;
let start = self.data_offset.checked_add(offset_start).ok_or_else(|| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: start offset overflow {} + {}",
name, self.data_offset, offset_start
))
})?;
let end = self.data_offset.checked_add(offset_end).ok_or_else(|| {
BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: end offset overflow {} + {}",
name, self.data_offset, offset_end
))
})?;
if end < start {
return Err(BurnpackError::IoError(format!(
"Tensor '{}' has corrupted offset data: end offset {} < start offset {}",
name, end, start
)));
}
let len = end - start;
let mut buffer = vec![0u8; len];
self.storage.read_into(&mut buffer, start)?;
Ok(buffer)
}
}