// burn_store/burnpack/store.rs
1#[cfg(feature = "std")]
2use std::path::PathBuf;
3
4use super::reader::BurnpackReader;
5use super::writer::BurnpackWriter;
6#[cfg(feature = "std")]
7use crate::KeyRemapper;
8use crate::burnpack::base::BurnpackError;
9use crate::{
10 IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot,
11};
12use alloc::boxed::Box;
13use alloc::collections::BTreeMap;
14use alloc::format;
15use alloc::string::String;
16use alloc::vec::Vec;
17use burn_core::prelude::Backend;
18use burn_tensor::Bytes;
19
/// Store mode for BurnpackStore
///
/// Selects the backing storage used for both reading and writing.
enum StoreMode {
    /// Read from / write to a file on disk (only available with `std`).
    #[cfg(feature = "std")]
    File(PathBuf),
    /// In-memory storage: `Some` when bytes are available to read (or have
    /// been produced by a save); `None` for a write-only store that has not
    /// collected any data yet.
    Bytes(Option<Bytes>),
}
26
/// BurnpackStore - A Burn-specific file format store using CBOR for metadata
pub struct BurnpackStore {
    /// Store mode - either file path or bytes
    mode: StoreMode,
    /// Optional filter for selective loading/saving
    ///
    /// `None` means match everything. Filtering is applied at collect/apply
    /// time, not when the snapshot cache is populated.
    filter: Option<PathFilter>,
    /// Additional metadata
    ///
    /// Every constructor seeds this with `default_metadata()`
    /// (format / producer / version); entries are written out on save.
    metadata: BTreeMap<String, String>,
    /// Allow partial loading (ignore missing tensors)
    allow_partial: bool,
    /// Validate tensors during loading (check shapes and dtypes)
    validate: bool,
    /// Allow overwriting existing files (default: false)
    overwrite: bool,
    /// Enable zero-copy tensor loading (default: false)
    ///
    /// When enabled and the backend supports it, tensor data is sliced from
    /// the source without copying. This requires keeping the source data alive.
    zero_copy: bool,
    /// Automatically append .bpk extension if not present (default: true)
    #[cfg(feature = "std")]
    auto_extension: bool,
    /// Key remapper for tensor name transformations
    #[cfg(feature = "std")]
    remapper: KeyRemapper,
    /// Adapter applied when loading (source -> Burn)
    from_adapter: Box<dyn ModuleAdapter>,
    /// Adapter applied when saving (Burn -> target)
    to_adapter: Box<dyn ModuleAdapter>,
    /// Writer for saving (populated by `collect_from`)
    writer: Option<BurnpackWriter>,
    /// Reader for loading (lazily initialized on first read access)
    reader: Option<BurnpackReader>,
    /// Cached tensor snapshots (parsed once, reused)
    ///
    /// Keyed by full tensor path; invalidated whenever new data is collected.
    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}
63
64impl BurnpackStore {
65 /// Get the default metadata that includes Burn framework information.
66 ///
67 /// This includes:
68 /// - `format`: "burnpack"
69 /// - `producer`: "burn"
70 /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)
71 ///
72 /// These metadata fields are automatically added to all saved models.
73 pub fn default_metadata() -> BTreeMap<String, String> {
74 let mut metadata = BTreeMap::new();
75 metadata.insert("format".into(), "burnpack".into());
76 metadata.insert("producer".into(), "burn".into());
77 metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
78 metadata
79 }
80 /// Create a new store from a file path
81 ///
82 /// By default, automatically appends `.bpk` extension if the path doesn't have one.
83 /// Use `.auto_extension(false)` to disable this behavior.
84 ///
85 /// # Examples
86 ///
87 /// ```no_run
88 /// # use burn_store::BurnpackStore;
89 /// // Automatically appends .bpk
90 /// let store = BurnpackStore::from_file("model"); // creates "model.bpk"
91 ///
92 /// // Already has extension, no append
93 /// let store = BurnpackStore::from_file("model.bpk"); // uses "model.bpk"
94 /// let store = BurnpackStore::from_file("model.myext"); // uses "model.myext"
95 ///
96 /// // Disable auto-extension
97 /// let store = BurnpackStore::from_file("model").auto_extension(false); // uses "model"
98 /// ```
99 #[cfg(feature = "std")]
100 pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
101 Self {
102 mode: StoreMode::File(path.as_ref().to_path_buf()),
103 filter: None,
104 metadata: Self::default_metadata(),
105 allow_partial: false,
106 validate: true,
107 overwrite: false,
108 zero_copy: false,
109 #[cfg(feature = "std")]
110 auto_extension: true,
111 #[cfg(feature = "std")]
112 remapper: KeyRemapper::new(),
113 from_adapter: Box::new(IdentityAdapter),
114 to_adapter: Box::new(IdentityAdapter),
115 writer: None,
116 reader: None,
117 snapshots_cache: None,
118 }
119 }
120
121 /// Create a new store from bytes (for reading) or empty (for writing)
122 pub fn from_bytes(bytes: Option<Bytes>) -> Self {
123 Self {
124 mode: StoreMode::Bytes(bytes),
125 filter: None,
126 metadata: Self::default_metadata(),
127 allow_partial: false,
128 validate: true,
129 overwrite: false,
130 zero_copy: false,
131 #[cfg(feature = "std")]
132 auto_extension: false, // Not used for bytes mode
133 #[cfg(feature = "std")]
134 remapper: KeyRemapper::new(),
135 from_adapter: Box::new(IdentityAdapter),
136 to_adapter: Box::new(IdentityAdapter),
137 writer: None,
138 reader: None,
139 snapshots_cache: None,
140 }
141 }
142
143 /// Create a new store from static bytes with zero-copy loading enabled.
144 ///
145 /// This is optimized for embedded model weights where the data lives in the
146 /// binary's `.rodata` section. Tensor data is sliced without copying, keeping
147 /// the static reference alive.
148 ///
149 /// # Example
150 ///
151 /// ```ignore
152 /// static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
153 /// let store = BurnpackStore::from_static(MODEL_DATA);
154 /// ```
155 pub fn from_static(data: &'static [u8]) -> Self {
156 use burn_tensor::AllocationProperty;
157
158 // Create bytes::Bytes from static data (zero-copy, stays in .rodata)
159 let shared = bytes::Bytes::from_static(data);
160
161 // Wrap in cubecl Bytes with shared-bytes allocation controller
162 let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
163
164 Self {
165 mode: StoreMode::Bytes(Some(bytes)),
166 filter: None,
167 metadata: Self::default_metadata(),
168 allow_partial: false,
169 validate: true,
170 overwrite: false,
171 zero_copy: true, // Enable zero-copy by default for static data
172 #[cfg(feature = "std")]
173 auto_extension: false,
174 #[cfg(feature = "std")]
175 remapper: KeyRemapper::new(),
176 from_adapter: Box::new(IdentityAdapter),
177 to_adapter: Box::new(IdentityAdapter),
178 writer: None,
179 reader: None,
180 snapshots_cache: None,
181 }
182 }
183
184 /// Add metadata key-value pair
185 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
186 self.metadata.insert(key.into(), value.into());
187 self
188 }
189
190 /// Clear all metadata (including defaults)
191 ///
192 /// This removes all metadata including the default format, producer, and version fields.
193 /// Use with caution as some tools may expect these fields to be present.
194 pub fn clear_metadata(mut self) -> Self {
195 self.metadata.clear();
196 self
197 }
198
199 /// Allow partial loading (ignore missing tensors)
200 ///
201 /// When set to `true`, the store will not fail if some tensors are missing
202 /// during loading. This is useful when loading a subset of a model's parameters.
203 ///
204 /// Default: `false`
205 pub fn allow_partial(mut self, allow: bool) -> Self {
206 self.allow_partial = allow;
207 self
208 }
209
210 /// Enable or disable validation during loading
211 ///
212 /// When validation is enabled, the store will check that loaded tensors
213 /// match the expected shapes and data types. Disabling validation can
214 /// improve performance but may lead to runtime errors if data is corrupted.
215 ///
216 /// Default: `true`
217 pub fn validate(mut self, validate: bool) -> Self {
218 self.validate = validate;
219 self
220 }
221
222 /// Allow overwriting existing files when saving
223 ///
224 /// When set to `false`, attempting to save to an existing file will result in an error.
225 /// When set to `true`, existing files will be overwritten without warning.
226 ///
227 /// Default: `false`
228 pub fn overwrite(mut self, overwrite: bool) -> Self {
229 self.overwrite = overwrite;
230 self
231 }
232
233 /// Enable or disable zero-copy tensor loading.
234 ///
235 /// When enabled and the backend supports it (memory-backed with shared bytes),
236 /// tensor data is sliced from the source without copying. This keeps the source
237 /// data alive as long as any tensor holds a reference.
238 ///
239 /// Zero-copy is automatically enabled when using [`from_static`](Self::from_static).
240 /// Use this method to enable it for other memory-backed stores created with
241 /// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`.
242 ///
243 /// Default: `false` (except for `from_static` which defaults to `true`)
244 pub fn zero_copy(mut self, enable: bool) -> Self {
245 self.zero_copy = enable;
246 self
247 }
248
249 /// Enable or disable automatic .bpk extension appending
250 ///
251 /// When enabled (default), automatically appends `.bpk` to the file path
252 /// if no extension is detected. If an extension is already present, it is preserved.
253 ///
254 /// When disabled, uses the exact path provided without modification.
255 ///
256 /// Default: `true`
257 ///
258 /// # Examples
259 ///
260 /// ```no_run
261 /// # use burn_store::BurnpackStore;
262 /// // With auto_extension enabled (default)
263 /// let store = BurnpackStore::from_file("model"); // -> "model.bpk"
264 ///
265 /// // With auto_extension disabled
266 /// let store = BurnpackStore::from_file("model")
267 /// .auto_extension(false); // -> "model"
268 /// ```
269 #[cfg(feature = "std")]
270 pub fn auto_extension(mut self, enable: bool) -> Self {
271 self.auto_extension = enable;
272 self
273 }
274
275 /// Set the adapter for loading tensors (converting from source format to Burn).
276 pub fn with_from_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {
277 self.from_adapter = Box::new(adapter);
278 self
279 }
280
281 /// Set the adapter for saving tensors (converting from Burn to target format).
282 pub fn with_to_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {
283 self.to_adapter = Box::new(adapter);
284 self
285 }
286
287 /// Set path filter for selective loading/saving
288 pub fn with_filter(mut self, filter: PathFilter) -> Self {
289 self.filter = Some(filter);
290 self
291 }
292
293 /// Add regex pattern to filter
294 #[cfg(feature = "std")]
295 pub fn with_regex(mut self, pattern: &str) -> Self {
296 let filter = self.filter.unwrap_or_default();
297 self.filter = Some(filter.with_regex(pattern));
298 self
299 }
300
301 /// Add exact path to filter
302 pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
303 let filter = self.filter.unwrap_or_default();
304 self.filter = Some(filter.with_full_path(path));
305 self
306 }
307
308 /// Match all tensors (no filtering)
309 pub fn match_all(mut self) -> Self {
310 self.filter = Some(PathFilter::new().match_all());
311 self
312 }
313
314 /// Set key remapper for tensor name transformations during loading
315 #[cfg(feature = "std")]
316 pub fn remap(mut self, remapper: KeyRemapper) -> Self {
317 self.remapper = remapper;
318 self
319 }
320
321 /// Add a single regex pattern for key remapping
322 #[cfg(feature = "std")]
323 pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
324 where
325 S1: AsRef<str>,
326 S2: Into<String>,
327 {
328 self.remapper = self
329 .remapper
330 .add_pattern(from.as_ref(), to.into())
331 .expect("Invalid regex pattern");
332 self
333 }
334
335 /// Set the path filter
336 pub fn filter(mut self, filter: PathFilter) -> Self {
337 self.filter = Some(filter);
338 self
339 }
340
341 /// Get the bytes after writing (only valid for bytes mode after collecting)
342 pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
343 if let Some(writer) = &self.writer {
344 return writer.to_bytes();
345 }
346
347 match &self.mode {
348 StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
349 _ => Err(BurnpackError::IoError("No bytes available".into())),
350 }
351 }
352
353 /// Process the file path with auto-extension logic
354 #[cfg(feature = "std")]
355 fn process_path(&self, path: &std::path::Path) -> PathBuf {
356 if !self.auto_extension {
357 return path.to_path_buf();
358 }
359
360 // Check if path already has an extension
361 if path.extension().is_some() {
362 // Has extension, use as-is
363 return path.to_path_buf();
364 }
365
366 // No extension, append .bpk
367 let mut new_path = path.to_path_buf();
368 new_path.set_extension("bpk");
369 new_path
370 }
371
372 /// Ensure the reader is initialized, loading from storage if needed
373 fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> {
374 if self.reader.is_none() {
375 let reader = match &self.mode {
376 #[cfg(feature = "std")]
377 StoreMode::File(path) => {
378 let final_path = self.process_path(path);
379 BurnpackReader::from_file(&final_path)?
380 }
381 StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
382 StoreMode::Bytes(None) => {
383 return Err(BurnpackError::IoError("No bytes to read from".into()));
384 }
385 };
386 self.reader = Some(reader);
387 }
388
389 self.reader
390 .as_ref()
391 .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))
392 }
393}
394
impl ModuleStore for BurnpackStore {
    type Error = BurnpackError;

    /// Collect tensor snapshots from `module` and write them to the backing
    /// storage (file or in-memory bytes), applying the configured filter and
    /// `to_adapter`, and stamping the stored metadata entries.
    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &M,
    ) -> Result<(), Self::Error> {
        // Invalidate cache since we're writing new data
        self.snapshots_cache = None;
        self.reader = None;

        // Collect snapshots from module with adapter
        let snapshots = module.collect(self.filter.clone(), Some(self.to_adapter.clone()), false);

        // Initialize writer with snapshots
        let mut writer = BurnpackWriter::new(snapshots);

        // Add metadata using builder pattern
        for (key, value) in &self.metadata {
            writer = writer.with_metadata(key.as_str(), value.as_str());
        }

        // Store the writer for finalization
        self.writer = Some(writer);

        // Write to storage based on mode
        if let Some(writer) = &self.writer {
            match &self.mode {
                #[cfg(feature = "std")]
                StoreMode::File(path) => {
                    // Process path with auto-extension logic
                    let final_path = self.process_path(path);

                    // Check if file exists and overwrite is disabled
                    if final_path.exists() && !self.overwrite {
                        return Err(BurnpackError::IoError(format!(
                            "File already exists: {}. Use .overwrite(true) to overwrite.",
                            final_path.display()
                        )));
                    }
                    writer.write_to_file(&final_path)?;
                }
                StoreMode::Bytes(_) => {
                    // Generate and store the bytes
                    let bytes_data = writer.to_bytes()?;
                    // Update mode with bytes - this pattern is irrefutable in no-std mode
                    // (Bytes is the only StoreMode variant without `std`), hence the
                    // cfg-gated lint allow. Re-borrowing `self.mode` mutably here is fine
                    // because the outer match pattern `Bytes(_)` binds nothing.
                    #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
                    let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
                        unreachable!("We just matched Bytes variant");
                    };
                    *bytes_ref = Some(bytes_data);
                }
            }
        }

        Ok(())
    }

    /// Apply the stored snapshots to `module`, converting with `from_adapter`
    /// and honoring the configured filter, `validate`, and `allow_partial`
    /// settings.
    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<crate::ApplyResult, Self::Error> {
        // Get all snapshots using the cached method
        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();

        // Apply all snapshots at once to the module
        // Burnpack is Burn's native format, so no enum variant skipping needed
        // Filter is applied here during apply, not during cache population
        let result = module.apply(
            snapshots,
            self.filter.clone(),
            Some(self.from_adapter.clone()),
            false,
        );

        // Validate if needed
        if self.validate && !result.errors.is_empty() {
            return Err(BurnpackError::ValidationError(format!(
                "Import errors: {:?}",
                result.errors
            )));
        }

        // Check for missing tensors if partial loading is not allowed
        if !self.allow_partial && !result.missing.is_empty() {
            return Err(BurnpackError::ValidationError(format!(
                "Missing tensors: {:?}",
                result.missing
            )));
        }

        Ok(result)
    }

    /// Look up a single snapshot by full tensor path (after remapping).
    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
        // Ensure cache is populated
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap().get(name))
    }

    /// Return all snapshots keyed by full tensor path, populating the cache
    /// on first access.
    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
        // Ensure cache is populated
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap())
    }

    /// List all tensor names present in the store.
    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
        // Always use the cache to ensure remapping is applied consistently
        Ok(self.get_all_snapshots()?.keys().cloned().collect())
    }
}
506
507impl BurnpackStore {
508 /// Ensure the snapshots cache is populated
509 fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> {
510 if self.snapshots_cache.is_some() {
511 return Ok(());
512 }
513
514 // Ensure reader is loaded
515 self.ensure_reader()?;
516
517 // Get snapshots from reader with zero-copy if enabled
518 let reader = self.reader.as_ref().unwrap();
519 let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?;
520
521 // Apply remapping if configured (but NOT filtering - that's done at apply time)
522 #[cfg(feature = "std")]
523 let snapshots = if !self.remapper.patterns.is_empty() {
524 let (remapped, _remapped_names) = self.remapper.remap(snapshots);
525 remapped
526 } else {
527 snapshots
528 };
529
530 // Build the cache as BTreeMap
531 let cache: BTreeMap<String, TensorSnapshot> =
532 snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
533
534 self.snapshots_cache = Some(cache);
535 Ok(())
536 }
537}