// burn_store/burnpack/store.rs
1#[cfg(feature = "std")]
2use std::path::PathBuf;
3
4use super::reader::BurnpackReader;
5use super::writer::BurnpackWriter;
6#[cfg(feature = "std")]
7use crate::KeyRemapper;
8use crate::burnpack::base::BurnpackError;
9use crate::{ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot};
10use alloc::collections::BTreeMap;
11use alloc::format;
12use alloc::string::String;
13use alloc::vec::Vec;
14use burn_core::prelude::Backend;
15use burn_tensor::Bytes;
16
/// Store mode for BurnpackStore
///
/// Selects the storage backend the store reads from and writes to.
enum StoreMode {
    /// Filesystem-backed storage at the given path (std only).
    #[cfg(feature = "std")]
    File(PathBuf),
    /// In-memory storage. `None` until bytes are supplied by the caller
    /// or produced by a save (see `collect_from`).
    Bytes(Option<Bytes>),
}
23
/// BurnpackStore - A Burn-specific file format store using CBOR for metadata
pub struct BurnpackStore {
    /// Store mode - either file path or bytes
    mode: StoreMode,
    /// Optional filter for selective loading/saving
    filter: Option<PathFilter>,
    /// Additional metadata written alongside tensor data on save
    metadata: BTreeMap<String, String>,
    /// Allow partial loading (ignore missing tensors)
    allow_partial: bool,
    /// Validate tensors during loading (check shapes and dtypes)
    validate: bool,
    /// Allow overwriting existing files (default: false)
    overwrite: bool,
    /// Enable zero-copy tensor loading (default: false)
    ///
    /// When enabled and the backend supports it, tensor data is sliced from
    /// the source without copying. This requires keeping the source data alive.
    zero_copy: bool,
    /// Automatically append .bpk extension if not present (default: true)
    #[cfg(feature = "std")]
    auto_extension: bool,
    /// Key remapper for tensor name transformations
    #[cfg(feature = "std")]
    remapper: KeyRemapper,
    /// Writer for saving; populated by `collect_from`, reused by `get_bytes`
    writer: Option<BurnpackWriter>,
    /// Reader for loading; lazily opened on first read access
    reader: Option<BurnpackReader>,
    /// Cached tensor snapshots keyed by full tensor path (parsed once, reused)
    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}
56
57impl BurnpackStore {
58 /// Get the default metadata that includes Burn framework information.
59 ///
60 /// This includes:
61 /// - `format`: "burnpack"
62 /// - `producer`: "burn"
63 /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)
64 ///
65 /// These metadata fields are automatically added to all saved models.
66 pub fn default_metadata() -> BTreeMap<String, String> {
67 let mut metadata = BTreeMap::new();
68 metadata.insert("format".into(), "burnpack".into());
69 metadata.insert("producer".into(), "burn".into());
70 metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
71 metadata
72 }
73 /// Create a new store from a file path
74 ///
75 /// By default, automatically appends `.bpk` extension if the path doesn't have one.
76 /// Use `.auto_extension(false)` to disable this behavior.
77 ///
78 /// # Examples
79 ///
80 /// ```no_run
81 /// # use burn_store::BurnpackStore;
82 /// // Automatically appends .bpk
83 /// let store = BurnpackStore::from_file("model"); // creates "model.bpk"
84 ///
85 /// // Already has extension, no append
86 /// let store = BurnpackStore::from_file("model.bpk"); // uses "model.bpk"
87 /// let store = BurnpackStore::from_file("model.myext"); // uses "model.myext"
88 ///
89 /// // Disable auto-extension
90 /// let store = BurnpackStore::from_file("model").auto_extension(false); // uses "model"
91 /// ```
92 #[cfg(feature = "std")]
93 pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
94 Self {
95 mode: StoreMode::File(path.as_ref().to_path_buf()),
96 filter: None,
97 metadata: Self::default_metadata(),
98 allow_partial: false,
99 validate: true,
100 overwrite: false,
101 zero_copy: false,
102 #[cfg(feature = "std")]
103 auto_extension: true,
104 #[cfg(feature = "std")]
105 remapper: KeyRemapper::new(),
106 writer: None,
107 reader: None,
108 snapshots_cache: None,
109 }
110 }
111
112 /// Create a new store from bytes (for reading) or empty (for writing)
113 pub fn from_bytes(bytes: Option<Bytes>) -> Self {
114 Self {
115 mode: StoreMode::Bytes(bytes),
116 filter: None,
117 metadata: Self::default_metadata(),
118 allow_partial: false,
119 validate: true,
120 overwrite: false,
121 zero_copy: false,
122 #[cfg(feature = "std")]
123 auto_extension: false, // Not used for bytes mode
124 #[cfg(feature = "std")]
125 remapper: KeyRemapper::new(),
126 writer: None,
127 reader: None,
128 snapshots_cache: None,
129 }
130 }
131
132 /// Create a new store from static bytes with zero-copy loading enabled.
133 ///
134 /// This is optimized for embedded model weights where the data lives in the
135 /// binary's `.rodata` section. Tensor data is sliced without copying, keeping
136 /// the static reference alive.
137 ///
138 /// # Example
139 ///
140 /// ```ignore
141 /// static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
142 /// let store = BurnpackStore::from_static(MODEL_DATA);
143 /// ```
144 pub fn from_static(data: &'static [u8]) -> Self {
145 use burn_tensor::AllocationProperty;
146
147 // Create bytes::Bytes from static data (zero-copy, stays in .rodata)
148 let shared = bytes::Bytes::from_static(data);
149
150 // Wrap in cubecl Bytes with shared-bytes allocation controller
151 let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
152
153 Self {
154 mode: StoreMode::Bytes(Some(bytes)),
155 filter: None,
156 metadata: Self::default_metadata(),
157 allow_partial: false,
158 validate: true,
159 overwrite: false,
160 zero_copy: true, // Enable zero-copy by default for static data
161 #[cfg(feature = "std")]
162 auto_extension: false,
163 #[cfg(feature = "std")]
164 remapper: KeyRemapper::new(),
165 writer: None,
166 reader: None,
167 snapshots_cache: None,
168 }
169 }
170
171 /// Add metadata key-value pair
172 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
173 self.metadata.insert(key.into(), value.into());
174 self
175 }
176
177 /// Clear all metadata (including defaults)
178 ///
179 /// This removes all metadata including the default format, producer, and version fields.
180 /// Use with caution as some tools may expect these fields to be present.
181 pub fn clear_metadata(mut self) -> Self {
182 self.metadata.clear();
183 self
184 }
185
186 /// Allow partial loading (ignore missing tensors)
187 ///
188 /// When set to `true`, the store will not fail if some tensors are missing
189 /// during loading. This is useful when loading a subset of a model's parameters.
190 ///
191 /// Default: `false`
192 pub fn allow_partial(mut self, allow: bool) -> Self {
193 self.allow_partial = allow;
194 self
195 }
196
197 /// Enable or disable validation during loading
198 ///
199 /// When validation is enabled, the store will check that loaded tensors
200 /// match the expected shapes and data types. Disabling validation can
201 /// improve performance but may lead to runtime errors if data is corrupted.
202 ///
203 /// Default: `true`
204 pub fn validate(mut self, validate: bool) -> Self {
205 self.validate = validate;
206 self
207 }
208
209 /// Allow overwriting existing files when saving
210 ///
211 /// When set to `false`, attempting to save to an existing file will result in an error.
212 /// When set to `true`, existing files will be overwritten without warning.
213 ///
214 /// Default: `false`
215 pub fn overwrite(mut self, overwrite: bool) -> Self {
216 self.overwrite = overwrite;
217 self
218 }
219
220 /// Enable or disable zero-copy tensor loading.
221 ///
222 /// When enabled and the backend supports it (memory-backed with shared bytes),
223 /// tensor data is sliced from the source without copying. This keeps the source
224 /// data alive as long as any tensor holds a reference.
225 ///
226 /// Zero-copy is automatically enabled when using [`from_static`](Self::from_static).
227 /// Use this method to enable it for other memory-backed stores created with
228 /// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`.
229 ///
230 /// Default: `false` (except for `from_static` which defaults to `true`)
231 pub fn zero_copy(mut self, enable: bool) -> Self {
232 self.zero_copy = enable;
233 self
234 }
235
236 /// Enable or disable automatic .bpk extension appending
237 ///
238 /// When enabled (default), automatically appends `.bpk` to the file path
239 /// if no extension is detected. If an extension is already present, it is preserved.
240 ///
241 /// When disabled, uses the exact path provided without modification.
242 ///
243 /// Default: `true`
244 ///
245 /// # Examples
246 ///
247 /// ```no_run
248 /// # use burn_store::BurnpackStore;
249 /// // With auto_extension enabled (default)
250 /// let store = BurnpackStore::from_file("model"); // -> "model.bpk"
251 ///
252 /// // With auto_extension disabled
253 /// let store = BurnpackStore::from_file("model")
254 /// .auto_extension(false); // -> "model"
255 /// ```
256 #[cfg(feature = "std")]
257 pub fn auto_extension(mut self, enable: bool) -> Self {
258 self.auto_extension = enable;
259 self
260 }
261
262 /// Set path filter for selective loading/saving
263 pub fn with_filter(mut self, filter: PathFilter) -> Self {
264 self.filter = Some(filter);
265 self
266 }
267
268 /// Add regex pattern to filter
269 #[cfg(feature = "std")]
270 pub fn with_regex(mut self, pattern: &str) -> Self {
271 let filter = self.filter.unwrap_or_default();
272 self.filter = Some(filter.with_regex(pattern));
273 self
274 }
275
276 /// Add exact path to filter
277 pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
278 let filter = self.filter.unwrap_or_default();
279 self.filter = Some(filter.with_full_path(path));
280 self
281 }
282
283 /// Match all tensors (no filtering)
284 pub fn match_all(mut self) -> Self {
285 self.filter = Some(PathFilter::new().match_all());
286 self
287 }
288
289 /// Set key remapper for tensor name transformations during loading
290 #[cfg(feature = "std")]
291 pub fn remap(mut self, remapper: KeyRemapper) -> Self {
292 self.remapper = remapper;
293 self
294 }
295
296 /// Add a single regex pattern for key remapping
297 #[cfg(feature = "std")]
298 pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
299 where
300 S1: AsRef<str>,
301 S2: Into<String>,
302 {
303 self.remapper = self
304 .remapper
305 .add_pattern(from.as_ref(), to.into())
306 .expect("Invalid regex pattern");
307 self
308 }
309
310 /// Set the path filter
311 pub fn filter(mut self, filter: PathFilter) -> Self {
312 self.filter = Some(filter);
313 self
314 }
315
316 /// Get the bytes after writing (only valid for bytes mode after collecting)
317 pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
318 if let Some(writer) = &self.writer {
319 return writer.to_bytes();
320 }
321
322 match &self.mode {
323 StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
324 _ => Err(BurnpackError::IoError("No bytes available".into())),
325 }
326 }
327
328 /// Process the file path with auto-extension logic
329 #[cfg(feature = "std")]
330 fn process_path(&self, path: &std::path::Path) -> PathBuf {
331 if !self.auto_extension {
332 return path.to_path_buf();
333 }
334
335 // Check if path already has an extension
336 if path.extension().is_some() {
337 // Has extension, use as-is
338 return path.to_path_buf();
339 }
340
341 // No extension, append .bpk
342 let mut new_path = path.to_path_buf();
343 new_path.set_extension("bpk");
344 new_path
345 }
346
347 /// Ensure the reader is initialized, loading from storage if needed
348 fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> {
349 if self.reader.is_none() {
350 let reader = match &self.mode {
351 #[cfg(feature = "std")]
352 StoreMode::File(path) => {
353 let final_path = self.process_path(path);
354 BurnpackReader::from_file(&final_path)?
355 }
356 StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
357 StoreMode::Bytes(None) => {
358 return Err(BurnpackError::IoError("No bytes to read from".into()));
359 }
360 };
361 self.reader = Some(reader);
362 }
363
364 self.reader
365 .as_ref()
366 .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))
367 }
368}
369
370impl ModuleStore for BurnpackStore {
371 type Error = BurnpackError;
372
373 fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
374 &mut self,
375 module: &M,
376 ) -> Result<(), Self::Error> {
377 // Invalidate cache since we're writing new data
378 self.snapshots_cache = None;
379 self.reader = None;
380
381 // Collect snapshots from module
382 let snapshots = module.collect(self.filter.clone(), None, false);
383
384 // Initialize writer with snapshots
385 let mut writer = BurnpackWriter::new(snapshots);
386
387 // Add metadata using builder pattern
388 for (key, value) in &self.metadata {
389 writer = writer.with_metadata(key.as_str(), value.as_str());
390 }
391
392 // Store the writer for finalization
393 self.writer = Some(writer);
394
395 // Write to storage based on mode
396 if let Some(writer) = &self.writer {
397 match &self.mode {
398 #[cfg(feature = "std")]
399 StoreMode::File(path) => {
400 // Process path with auto-extension logic
401 let final_path = self.process_path(path);
402
403 // Check if file exists and overwrite is disabled
404 if final_path.exists() && !self.overwrite {
405 return Err(BurnpackError::IoError(format!(
406 "File already exists: {}. Use .overwrite(true) to overwrite.",
407 final_path.display()
408 )));
409 }
410 writer.write_to_file(&final_path)?;
411 }
412 StoreMode::Bytes(_) => {
413 // Generate and store the bytes
414 let bytes_data = writer.to_bytes()?;
415 // Update mode with bytes - this pattern is irrefutable in no-std mode
416 #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
417 let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
418 unreachable!("We just matched Bytes variant");
419 };
420 *bytes_ref = Some(bytes_data);
421 }
422 }
423 }
424
425 Ok(())
426 }
427
428 fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
429 &mut self,
430 module: &mut M,
431 ) -> Result<crate::ApplyResult, Self::Error> {
432 // Get all snapshots using the cached method
433 let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();
434
435 // Apply all snapshots at once to the module
436 // Burnpack is Burn's native format, so no enum variant skipping needed
437 // Filter is applied here during apply, not during cache population
438 let result = module.apply(snapshots, self.filter.clone(), None, false);
439
440 // Validate if needed
441 if self.validate && !result.errors.is_empty() {
442 return Err(BurnpackError::ValidationError(format!(
443 "Import errors: {:?}",
444 result.errors
445 )));
446 }
447
448 // Check for missing tensors if partial loading is not allowed
449 if !self.allow_partial && !result.missing.is_empty() {
450 return Err(BurnpackError::ValidationError(format!(
451 "Missing tensors: {:?}",
452 result.missing
453 )));
454 }
455
456 Ok(result)
457 }
458
459 fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
460 // Ensure cache is populated
461 self.ensure_snapshots_cache()?;
462 Ok(self.snapshots_cache.as_ref().unwrap().get(name))
463 }
464
465 fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
466 // Ensure cache is populated
467 self.ensure_snapshots_cache()?;
468 Ok(self.snapshots_cache.as_ref().unwrap())
469 }
470
471 fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
472 // Always use the cache to ensure remapping is applied consistently
473 Ok(self.get_all_snapshots()?.keys().cloned().collect())
474 }
475}
476
477impl BurnpackStore {
478 /// Ensure the snapshots cache is populated
479 fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> {
480 if self.snapshots_cache.is_some() {
481 return Ok(());
482 }
483
484 // Ensure reader is loaded
485 self.ensure_reader()?;
486
487 // Get snapshots from reader with zero-copy if enabled
488 let reader = self.reader.as_ref().unwrap();
489 let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?;
490
491 // Apply remapping if configured (but NOT filtering - that's done at apply time)
492 #[cfg(feature = "std")]
493 let snapshots = if !self.remapper.patterns.is_empty() {
494 let (remapped, _remapped_names) = self.remapper.remap(snapshots);
495 remapped
496 } else {
497 snapshots
498 };
499
500 // Build the cache as BTreeMap
501 let cache: BTreeMap<String, TensorSnapshot> =
502 snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
503
504 self.snapshots_cache = Some(cache);
505 Ok(())
506 }
507}