// burn_store/burnpack/writer.rs

1use super::base::{
2    BurnpackError, BurnpackHeader, BurnpackMetadata, FORMAT_VERSION, HEADER_SIZE, MAGIC_NUMBER,
3    TENSOR_ALIGNMENT, TensorDescriptor, aligned_data_section_start,
4};
5use crate::TensorSnapshot;
6use alloc::collections::BTreeMap;
7use alloc::format;
8use alloc::string::{String, ToString};
9use alloc::vec;
10use alloc::vec::Vec;
11use burn_tensor::Bytes;
12
13#[cfg(feature = "std")]
14use std::fs::File;
15#[cfg(feature = "std")]
16use std::io::Write;
17#[cfg(feature = "std")]
18use std::path::Path;
19
/// Round `offset` up to the next multiple of `alignment`.
///
/// Returns the smallest value >= `offset` that is a multiple of `alignment`.
#[inline]
const fn align_offset(offset: u64, alignment: u64) -> u64 {
    let remainder = offset % alignment;
    if remainder == 0 {
        offset
    } else {
        offset + (alignment - remainder)
    }
}
27
/// Writer for creating Burnpack files.
///
/// Collects tensor snapshots plus free-form string metadata, then serializes
/// them as header + CBOR metadata + aligned tensor data via `to_bytes`,
/// `write_into`, or (with the `std` feature) `write_to_file`.
pub struct BurnpackWriter {
    /// Tensors to write; file data layout follows this order
    pub(crate) snapshots: Vec<TensorSnapshot>,
    /// Metadata key-value pairs; BTreeMap's sorted order keeps serialization deterministic
    pub(crate) metadata: BTreeMap<String, String>,
}
35
36impl BurnpackWriter {
37    /// Create a new writer
38    pub fn new(snapshots: Vec<TensorSnapshot>) -> Self {
39        Self {
40            snapshots,
41            metadata: BTreeMap::new(),
42        }
43    }
44
45    /// Builder pattern: add metadata and return self
46    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
47        self.metadata.insert(key.to_string(), value.to_string());
48        self
49    }
50
51    /// Build tensor descriptors and metadata
52    fn build_metadata(&self) -> Result<(BurnpackMetadata, Vec<u8>), BurnpackError> {
53        // Build tensor descriptors and calculate offsets with alignment
54        let mut tensors = BTreeMap::new();
55        let mut current_offset = 0u64;
56
57        for snapshot in &self.snapshots {
58            let data_len = snapshot.data_len() as u64;
59
60            // Align the start offset for mmap zero-copy support
61            let aligned_start = align_offset(current_offset, TENSOR_ALIGNMENT);
62            let end = aligned_start.checked_add(data_len).ok_or_else(|| {
63                BurnpackError::IoError(format!(
64                    "Tensor offset overflow: {} + {} exceeds maximum",
65                    aligned_start, data_len
66                ))
67            })?;
68
69            tensors.insert(
70                snapshot.full_path(),
71                TensorDescriptor {
72                    dtype: snapshot.dtype,
73                    shape: snapshot.shape.iter().map(|&s| s as u64).collect(),
74                    data_offsets: (aligned_start, end),
75                    param_id: snapshot.tensor_id.map(|id| id.val()),
76                },
77            );
78
79            current_offset = end;
80        }
81
82        // Create metadata structure
83        let metadata = BurnpackMetadata {
84            tensors,
85            metadata: self.metadata.clone(),
86        };
87
88        // Serialize metadata with CBOR
89        let mut metadata_bytes = Vec::new();
90        ciborium::ser::into_writer(&metadata, &mut metadata_bytes)
91            .map_err(|e| BurnpackError::IoError(e.to_string()))?;
92
93        Ok((metadata, metadata_bytes))
94    }
95
96    /// Calculate the total size needed for the burnpack data
97    ///
98    /// This is useful when you want to pre-allocate a buffer for `write_into()`.
99    /// The size includes padding bytes for both metadata alignment and tensor alignment.
100    pub fn size(&self) -> Result<usize, BurnpackError> {
101        let (metadata, metadata_bytes) = self.build_metadata()?;
102
103        // Data section starts at aligned position after header + metadata
104        let data_section_start = aligned_data_section_start(metadata_bytes.len());
105
106        // Calculate total data section size from aligned offsets
107        // The last tensor's end offset gives us the total data section size
108        let data_size = metadata
109            .tensors
110            .values()
111            .map(|t| t.data_offsets.1)
112            .max()
113            .unwrap_or(0) as usize;
114
115        Ok(data_section_start + data_size)
116    }
117
118    /// Write burnpack data into a caller-provided buffer
119    ///
120    /// The buffer must be large enough to hold all data. Use `size()` to determine
121    /// the required buffer size. If the buffer is too small, this will return an error.
122    ///
123    /// This allows the caller to control buffer allocation, enabling optimizations like:
124    /// - Buffer reuse across multiple writes
125    /// - Custom allocators
126    /// - Pinned memory for GPU transfers
127    ///
128    /// # Arguments
129    ///
130    /// * `buffer` - Mutable slice to write data into. Must be at least `size()` bytes.
131    pub fn write_into(&self, buffer: &mut [u8]) -> Result<(), BurnpackError> {
132        let (metadata, metadata_bytes) = self.build_metadata()?;
133
134        // Check metadata size fits in u32
135        let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
136            BurnpackError::IoError(format!(
137                "Metadata size {} exceeds maximum of {} bytes",
138                metadata_bytes.len(),
139                u32::MAX
140            ))
141        })?;
142
143        // Create header
144        let header = BurnpackHeader {
145            magic: MAGIC_NUMBER,
146            version: FORMAT_VERSION,
147            metadata_size,
148        };
149
150        // Data section starts at aligned position after header + metadata
151        let data_section_start = aligned_data_section_start(metadata_bytes.len());
152
153        // Calculate required size from aligned offsets
154        let data_size = metadata
155            .tensors
156            .values()
157            .map(|t| t.data_offsets.1)
158            .max()
159            .unwrap_or(0) as usize;
160        let total_size = data_section_start + data_size;
161
162        // Check buffer size
163        if buffer.len() < total_size {
164            return Err(BurnpackError::IoError(format!(
165                "Buffer too small: need {} bytes, got {} bytes",
166                total_size,
167                buffer.len()
168            )));
169        }
170
171        let mut offset = 0;
172
173        // Write header
174        let header_bytes = header.into_bytes();
175        buffer[offset..offset + HEADER_SIZE].copy_from_slice(&header_bytes);
176        offset += HEADER_SIZE;
177
178        // Write metadata
179        buffer[offset..offset + metadata_bytes.len()].copy_from_slice(&metadata_bytes);
180        offset += metadata_bytes.len();
181
182        // Write padding to align data section start
183        if data_section_start > offset {
184            buffer[offset..data_section_start].fill(0);
185            offset = data_section_start;
186        }
187
188        // Write tensor data with alignment padding
189        for snapshot in &self.snapshots {
190            // Get the aligned offset from metadata
191            let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
192                BurnpackError::IoError(format!(
193                    "Internal error: tensor '{}' not found in metadata",
194                    snapshot.full_path()
195                ))
196            })?;
197            let aligned_offset = descriptor.data_offsets.0 as usize;
198            let target_offset = data_section_start + aligned_offset;
199
200            // Write padding zeros if needed
201            if target_offset > offset {
202                buffer[offset..target_offset].fill(0);
203                offset = target_offset;
204            }
205
206            let expected_len = snapshot.data_len();
207            let data = snapshot.to_data().map_err(|e| {
208                BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
209            })?;
210            let actual_len = data.bytes.len();
211
212            // Validate data length consistency
213            if actual_len != expected_len {
214                return Err(BurnpackError::IoError(format!(
215                    "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
216                    snapshot.full_path(),
217                    expected_len,
218                    actual_len
219                )));
220            }
221
222            buffer[offset..offset + actual_len].copy_from_slice(&data.bytes);
223            offset += actual_len;
224        }
225
226        Ok(())
227    }
228
229    /// Write to a byte buffer (convenience method)
230    ///
231    /// This allocates a buffer internally and writes the burnpack data.
232    /// For more control over buffer allocation, use `size()` + `write_into()`.
233    pub fn to_bytes(&self) -> Result<Bytes, BurnpackError> {
234        let size = self.size()?;
235        let mut buffer = vec![0u8; size];
236        self.write_into(&mut buffer)?;
237        Ok(Bytes::from_bytes_vec(buffer))
238    }
239
240    /// Write directly to a file (more memory efficient for large models)
241    #[cfg(feature = "std")]
242    pub fn write_to_file<P: AsRef<Path>>(&self, path: P) -> Result<(), BurnpackError> {
243        let mut file = File::create(path).map_err(|e| BurnpackError::IoError(e.to_string()))?;
244
245        let (metadata, metadata_bytes) = self.build_metadata()?;
246
247        // Check metadata size fits in u32
248        let metadata_size: u32 = metadata_bytes.len().try_into().map_err(|_| {
249            BurnpackError::IoError(format!(
250                "Metadata size {} exceeds maximum of {} bytes",
251                metadata_bytes.len(),
252                u32::MAX
253            ))
254        })?;
255
256        // Create and write header
257        let header = BurnpackHeader {
258            magic: MAGIC_NUMBER,
259            version: FORMAT_VERSION,
260            metadata_size,
261        };
262
263        file.write_all(&header.into_bytes())
264            .map_err(|e| BurnpackError::IoError(e.to_string()))?;
265
266        // Write metadata
267        file.write_all(&metadata_bytes)
268            .map_err(|e| BurnpackError::IoError(e.to_string()))?;
269
270        // Data section starts at aligned position after header + metadata
271        let data_section_start = aligned_data_section_start(metadata_bytes.len());
272        let current_file_pos = HEADER_SIZE + metadata_bytes.len();
273
274        // Write padding to align data section start
275        if data_section_start > current_file_pos {
276            let padding_size = data_section_start - current_file_pos;
277            let padding = vec![0u8; padding_size];
278            file.write_all(&padding)
279                .map_err(|e| BurnpackError::IoError(e.to_string()))?;
280        }
281
282        // Track current position within data section (relative to data_section_start)
283        let mut data_offset = 0usize;
284
285        // Stream tensor data directly to file with alignment padding
286        for snapshot in &self.snapshots {
287            // Get the aligned offset from metadata
288            let descriptor = metadata.tensors.get(&snapshot.full_path()).ok_or_else(|| {
289                BurnpackError::IoError(format!(
290                    "Internal error: tensor '{}' not found in metadata",
291                    snapshot.full_path()
292                ))
293            })?;
294            let aligned_offset = descriptor.data_offsets.0 as usize;
295
296            // Write padding zeros if needed
297            if aligned_offset > data_offset {
298                let padding_size = aligned_offset - data_offset;
299                let padding = vec![0u8; padding_size];
300                file.write_all(&padding)
301                    .map_err(|e| BurnpackError::IoError(e.to_string()))?;
302                data_offset = aligned_offset;
303            }
304
305            let expected_len = snapshot.data_len();
306            let data = snapshot.to_data().map_err(|e| {
307                BurnpackError::IoError(format!("Failed to get tensor data: {:?}", e))
308            })?;
309            let actual_len = data.bytes.len();
310
311            // Validate data length consistency
312            if actual_len != expected_len {
313                return Err(BurnpackError::IoError(format!(
314                    "Data corruption: tensor '{}' has inconsistent length (expected {}, got {})",
315                    snapshot.full_path(),
316                    expected_len,
317                    actual_len
318                )));
319            }
320
321            file.write_all(&data.bytes)
322                .map_err(|e| BurnpackError::IoError(e.to_string()))?;
323            data_offset += actual_len;
324        }
325
326        file.flush()
327            .map_err(|e| BurnpackError::IoError(e.to_string()))?;
328
329        Ok(())
330    }
331}