use super::base::{
    BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
    TENSOR_ALIGNMENT, TensorDescriptor, aligned_data_section_start,
};
use crate::TensorSnapshot;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::Bytes;
#[cfg(feature = "std")]
use std::fs::File;
#[cfg(feature = "std")]
use std::io::{BufWriter, Write};
#[cfg(feature = "std")]
use std::path::Path;
/// Round `offset` up to the next multiple of `alignment`.
///
/// Returns `offset` unchanged when it is already aligned. `alignment` must be
/// non-zero (a zero alignment panics on the modulo, as the original division
/// form would have).
#[inline]
const fn align_offset(offset: u64, alignment: u64) -> u64 {
    let remainder = offset % alignment;
    if remainder == 0 {
        offset
    } else {
        offset + (alignment - remainder)
    }
}
/// Serializes a collection of [`TensorSnapshot`]s into the Burnpack container
/// format: fixed header, CBOR-encoded metadata, then an aligned data section.
pub struct BurnpackWriter {
// Tensors to serialize; laid out in the data section in this order.
pub(crate) snapshots: Vec<TensorSnapshot>,
// User-supplied key/value pairs stored alongside the tensor descriptors.
pub(crate) metadata: BTreeMap<String, String>,
}
impl BurnpackWriter {
pub fn new(snapshots: Vec<TensorSnapshot>) -> Self {
Self {
snapshots,
metadata: BTreeMap::new(),
}
}
pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
self.metadata.insert(key.to_string(), value.to_string());
self
}
fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec<u8>), BurnpackError> {
let mut tensors = BTreeMap::new();
let mut current_offset = 0u64;
for snapshot in &self.snapshots {
let data_len = snapshot.data_len() as u64;
let aligned_start = align_offset(current_offset, TENSOR_ALIGNMENT);
let end = aligned_start.checked_add(data_len).ok_or_else(|| {
BurnpackError::IoError(format!(
"Tensor offset overflow: {} + {} exceeds maximum",
aligned_start, data_len
))
})?;
tensors.insert(
snapshot.full_path(),
TensorDescriptor {
dtype: snapshot.dtype,
shape: snapshot.shape.iter().map(|&s| s as u64).collect(),
data_offsets: (aligned_start, end),
param_id: snapshot.tensor_id.map(|id| id.val()),
},
);
current_offset = end;
}
let metadata = BurnpackMetadata {
tensors,
metadata: self.metadata.clone(),
};
let mut metadata_bytes = Vec::new();
ciborium::ser::into_writer(&metadata, &mut metadata_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
Ok((metadata, metadata_bytes))
}
pub fn size(&self) -> Result<usize, BurnpackError> {
let (metadata, metadata_bytes) = self.build_metadata()?;
let data_section_start = aligned_data_section_start(metadata_bytes.len());
let data_size = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0) as usize;
Ok(data_section_start + data_size)
}
pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> {
let (metadata, metadata_bytes) = self.build_metadata()?;
let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Metadata size {} exceeds maximum of {} bytes",
metadata_bytes.len(),
u32::MAX
))
})?;
let header = BurnpackHeader {
magic: MAGIC_NUMBER,
version: FORMAT_VERSION,
metadata_size,
};
let data_section_start = aligned_data_section_start(metadata_bytes.len());
let data_size = metadata
.tensors
.values()
.map(|t| t.data_offsets.1)
.max()
.unwrap_or(0) as usize;
let total_size = data_section_start + data_size;
if buffer.len() < total_size {
return Err(BurnpackError::IoError(format!(
"Buffer too small: need {} bytes, got {} bytes",
total_size,
buffer.len()
)));
}
let mut offset = 0;
let header_bytes = header.into_bytes();
buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes);
offset += HEADER_SIZE;
buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes);
offset += metadata_bytes.len();
if data_section_start > offset {
buffer[offset..data_section_start].fill(0);
offset = data_section_start;
}
for snapshot in &self.snapshots {
let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
BurnpackError::IoError(format!(
"Internal error: tensor '{}' not found in metadata",
snapshot.full_path()
))
})?;
let aligned_offset = descriptor.data_offsets.0 as usize;
let target_offset = data_section_start + aligned_offset;
if target_offset > offset {
buffer[offset..target_offset].fill(0);
offset = target_offset;
}
let expected_len = snapshot.data_len();
let data = snapshot.to_data().map_err(|e| {
BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
})?;
let actual_len = data.bytes.len();
if actual_len != expected_len {
return Err(BurnpackError::IoError(format!(
"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
snapshot.full_path(),
expected_len,
actual_len
)));
}
buffer[offset..offset + actual_len].copy_from_slice(&data.bytes);
offset += actual_len;
}
Ok(())
}
pub fn to_bytes(&self) -> Result<Bytes, BurnpackError> {
let size = self.size()?;
let mut buffer = vec![0u8; size];
self.write_into(&mut buffer)?;
Ok(Bytes::from_bytes_vec(buffer))
}
#[cfg(feature = "std")]
pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), BurnpackError> {
let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
let (metadata, metadata_bytes) = self.build_metadata()?;
let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
BurnpackError::IoError(format!(
"Metadata size {} exceeds maximum of {} bytes",
metadata_bytes.len(),
u32::MAX
))
})?;
let header = BurnpackHeader {
magic: MAGIC_NUMBER,
version: FORMAT_VERSION,
metadata_size,
};
file.write_all(&header.into_bytes())
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
file.write_all(&metadata_bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
let data_section_start = aligned_data_section_start(metadata_bytes.len());
let current_file_pos = HEADER_SIZE + metadata_bytes.len();
if data_section_start > current_file_pos {
let padding_size = data_section_start - current_file_pos;
let padding = vec![0u8; padding_size];
file.write_all(&padding)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
}
let mut data_offset = 0usize;
for snapshot in &self.snapshots {
let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
BurnpackError::IoError(format!(
"Internal error: tensor '{}' not found in metadata",
snapshot.full_path()
))
})?;
let aligned_offset = descriptor.data_offsets.0 as usize;
if aligned_offset > data_offset {
let padding_size = aligned_offset - data_offset;
let padding = vec![0u8; padding_size];
file.write_all(&padding)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
data_offset = aligned_offset;
}
let expected_len = snapshot.data_len();
let data = snapshot.to_data().map_err(|e| {
BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
})?;
let actual_len = data.bytes.len();
if actual_len != expected_len {
return Err(BurnpackError::IoError(format!(
"Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
snapshot.full_path(),
expected_len,
actual_len
)));
}
file.write_all(&data.bytes)
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
data_offset += actual_len;
}
file.flush()
.map_err(|e| BurnpackError::IoError(e.to_string()))?;
Ok(())
}
}