//! Burnpack format writer (burn_store/burnpack/writer.rs).
use super::base::{
    BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
    TENSOR_ALIGNMENT, TensorDescriptor, aligned_data_section_start,
};
use crate::TensorSnapshot;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::Bytes;

#[cfg(feature = "std")]
use std::fs::File;
#[cfg(feature = "std")]
use std::io::{BufWriter, Write};
#[cfg(feature = "std")]
use std::path::Path;
19
/// Rounds `offset` up to the nearest multiple of `alignment`.
///
/// `alignment` must be non-zero — `div_ceil` panics on division by zero.
#[inline]
const fn align_offset(offset: u64, alignment: u64) -> u64 {
    // Count whole alignment blocks needed to cover `offset`, then scale back up.
    let blocks_needed = offset.div_ceil(alignment);
    blocks_needed * alignment
}
27
/// Writer for the burnpack serialization format.
///
/// Collects tensor snapshots plus optional string key/value metadata and
/// serializes them as: fixed header, CBOR-encoded metadata section, then an
/// aligned tensor-data section.
pub struct BurnpackWriter {
    // Tensors to serialize; data-section offsets are assigned in this order.
    pub(crate) snapshots: Vec<TensorSnapshot>,
    // User-supplied key/value pairs embedded in the CBOR metadata section.
    pub(crate) metadata: BTreeMap<String, String>,
}
35
36impl BurnpackWriter {
37 pub fn new(snapshots: Vec<TensorSnapshot>) -> Self {
39 Self {
40 snapshots,
41 metadata: BTreeMap::new(),
42 }
43 }
44
45 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
47 self.metadata.insert(key.to_string(), value.to_string());
48 self
49 }
50
51 fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec<u8>), BurnpackError> {
53 let mut tensors = BTreeMap::new();
55 let mut current_offset = 0u64;
56
57 for snapshot in &self.snapshots {
58 let data_len = snapshot.data_len() as u64;
59
60 let aligned_start = align_offset(current_offset, TENSOR_ALIGNMENT);
62 let end = aligned_start.checked_add(data_len).ok_or_else(|| {
63 BurnpackError::IoError(format!(
64 "Tensor offset overflow: {} + {} exceeds maximum",
65 aligned_start, data_len
66 ))
67 })?;
68
69 tensors.insert(
70 snapshot.full_path(),
71 TensorDescriptor {
72 dtype: snapshot.dtype,
73 shape: snapshot.shape.iter().map(|&s| s as u64).collect(),
74 data_offsets: (aligned_start, end),
75 param_id: snapshot.tensor_id.map(|id| id.val()),
76 },
77 );
78
79 current_offset = end;
80 }
81
82 let metadata = BurnpackMetadata {
84 tensors,
85 metadata: self.metadata.clone(),
86 };
87
88 let mut metadata_bytes = Vec::new();
90 ciborium::ser::into_writer(&metadata, &mut metadata_bytes)
91 .map_err(|e| BurnpackError::IoError(e.to_string()))?;
92
93 Ok((metadata, metadata_bytes))
94 }
95
96 pub fn size(&self) -> Result<usize, BurnpackError> {
101 let (metadata, metadata_bytes) = self.build_metadata()?;
102
103 let data_section_start = aligned_data_section_start(metadata_bytes.len());
105
106 let data_size = metadata
109 .tensors
110 .values()
111 .map(|t| t.data_offsets.1)
112 .max()
113 .unwrap_or(0) as usize;
114
115 Ok(data_section_start + data_size)
116 }
117
    /// Serializes the complete burnpack image (header, CBOR metadata, padding,
    /// tensor data) into `buffer`.
    ///
    /// `buffer` must be at least `self.size()` bytes long; bytes past the
    /// serialized image are left untouched.
    ///
    /// # Errors
    ///
    /// Returns `BurnpackError::IoError` when metadata cannot be built, when it
    /// exceeds `u32::MAX` bytes, when the buffer is too small, or when a
    /// tensor's materialized bytes disagree with its declared length.
    pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> {
        let (metadata, metadata_bytes) = self.build_metadata()?;

        // The header stores the metadata length as a u32; reject anything larger.
        let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
            BurnpackError::IoError(format!(
                "Metadata size {} exceeds maximum of {} bytes",
                metadata_bytes.len(),
                u32::MAX
            ))
        })?;

        let header = BurnpackHeader {
            magic: MAGIC_NUMBER,
            version: FORMAT_VERSION,
            metadata_size,
        };

        // Aligned offset where the tensor data section begins.
        let data_section_start = aligned_data_section_start(metadata_bytes.len());

        // Data section length = largest end offset over all tensors.
        let data_size = metadata
            .tensors
            .values()
            .map(|t| t.data_offsets.1)
            .max()
            .unwrap_or(0) as usize;
        let total_size = data_section_start + data_size;

        if buffer.len() < total_size {
            return Err(BurnpackError::IoError(format!(
                "Buffer too small: need {} bytes, got {} bytes",
                total_size,
                buffer.len()
            )));
        }

        let mut offset = 0;

        // 1) Fixed-size header.
        let header_bytes = header.into_bytes();
        buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes);
        offset += HEADER_SIZE;

        // 2) CBOR metadata immediately after the header.
        buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes);
        offset += metadata_bytes.len();

        // 3) Zero padding up to the aligned start of the data section.
        if data_section_start > offset {
            buffer[offset..data_section_start].fill(0);
            offset = data_section_start;
        }

        // 4) Tensor payloads in snapshot order. Offsets were assigned in this
        //    same order by `build_metadata`, so `target_offset` only moves forward.
        for snapshot in &self.snapshots {
            let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
                BurnpackError::IoError(format!(
                    "Internal error: tensor '{}' not found in metadata",
                    snapshot.full_path()
                ))
            })?;
            let aligned_offset = descriptor.data_offsets.0 as usize;
            let target_offset = data_section_start + aligned_offset;

            // Zero-fill the alignment gap before this tensor, if any.
            if target_offset > offset {
                buffer[offset..target_offset].fill(0);
                offset = target_offset;
            }

            let expected_len = snapshot.data_len();
            let data = snapshot.to_data().map_err(|e| {
                BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
            })?;
            let actual_len = data.bytes.len();

            // Guard against snapshots whose materialized bytes disagree with the
            // length that was used to lay out offsets.
            if actual_len != expected_len {
                return Err(BurnpackError::IoError(format!(
                    "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
                    snapshot.full_path(),
                    expected_len,
                    actual_len
                )));
            }

            buffer[offset..offset + actual_len].copy_from_slice(&data.bytes);
            offset += actual_len;
        }

        Ok(())
    }
228
229 pub fn to_bytes(&self) -> Result<Bytes, BurnpackError> {
234 let size = self.size()?;
235 let mut buffer = vec![0u8; size];
236 self.write_into(&mut buffer)?;
237 Ok(Bytes::from_bytes_vec(buffer))
238 }
239
240 #[cfg(feature = "std")]
242 pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), BurnpackError> {
243 let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
244
245 let (metadata, metadata_bytes) = self.build_metadata()?;
246
247 let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
249 BurnpackError::IoError(format!(
250 "Metadata size {} exceeds maximum of {} bytes",
251 metadata_bytes.len(),
252 u32::MAX
253 ))
254 })?;
255
256 let header = BurnpackHeader {
258 magic: MAGIC_NUMBER,
259 version: FORMAT_VERSION,
260 metadata_size,
261 };
262
263 file.write_all(&header.into_bytes())
264 .map_err(|e| BurnpackError::IoError(e.to_string()))?;
265
266 file.write_all(&metadata_bytes)
268 .map_err(|e| BurnpackError::IoError(e.to_string()))?;
269
270 let data_section_start = aligned_data_section_start(metadata_bytes.len());
272 let current_file_pos = HEADER_SIZE + metadata_bytes.len();
273
274 if data_section_start > current_file_pos {
276 let padding_size = data_section_start - current_file_pos;
277 let padding = vec![0u8; padding_size];
278 file.write_all(&padding)
279 .map_err(|e| BurnpackError::IoError(e.to_string()))?;
280 }
281
282 let mut data_offset = 0usize;
284
285 for snapshot in &self.snapshots {
287 let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
289 BurnpackError::IoError(format!(
290 "Internal error: tensor '{}' not found in metadata",
291 snapshot.full_path()
292 ))
293 })?;
294 let aligned_offset = descriptor.data_offsets.0 as usize;
295
296 if aligned_offset > data_offset {
298 let padding_size = aligned_offset - data_offset;
299 let padding = vec![0u8; padding_size];
300 file.write_all(&padding)
301 .map_err(|e| BurnpackError::IoError(e.to_string()))?;
302 data_offset = aligned_offset;
303 }
304
305 let expected_len = snapshot.data_len();
306 let data = snapshot.to_data().map_err(|e| {
307 BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
308 })?;
309 let actual_len = data.bytes.len();
310
311 if actual_len != expected_len {
313 return Err(BurnpackError::IoError(format!(
314 "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
315 snapshot.full_path(),
316 expected_len,
317 actual_len
318 )));
319 }
320
321 file.write_all(&data.bytes)
322 .map_err(|e| BurnpackError::IoError(e.to_string()))?;
323 data_offset += actual_len;
324 }
325
326 file.flush()
327 .map_err(|e| BurnpackError::IoError(e.to_string()))?;
328
329 Ok(())
330 }
331}