burn_store/burnpack/store.rs

#[cfg(feature = "std")]
use std::path::PathBuf;

use super::reader::BurnpackReader;
use super::writer::BurnpackWriter;
#[cfg(feature = "std")]
use crate::KeyRemapper;
use crate::burnpack::base::BurnpackError;
use crate::{ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot};
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use burn_core::prelude::Backend;
use burn_tensor::Bytes;

/// Store mode for BurnpackStore
enum StoreMode {
    #[cfg(feature = "std")]
    File(PathBuf),
    Bytes(Option<Bytes>),
}

/// BurnpackStore - A Burn-specific file format store using CBOR for metadata
pub struct BurnpackStore {
    /// Store mode - either file path or bytes
    mode: StoreMode,
    /// Optional filter for selective loading/saving
    filter: Option<PathFilter>,
    /// Additional metadata
    metadata: BTreeMap<String, String>,
    /// Allow partial loading (ignore missing tensors)
    allow_partial: bool,
    /// Validate tensors during loading (check shapes and dtypes)
    validate: bool,
    /// Allow overwriting existing files (default: false)
    overwrite: bool,
    /// Enable zero-copy tensor loading (default: false)
    ///
    /// When enabled and the backend supports it, tensor data is sliced from
    /// the source without copying. This requires keeping the source data alive.
    zero_copy: bool,
    /// Automatically append .bpk extension if not present (default: true)
    #[cfg(feature = "std")]
    auto_extension: bool,
    /// Key remapper for tensor name transformations
    #[cfg(feature = "std")]
    remapper: KeyRemapper,
    /// Writer for saving
    writer: Option<BurnpackWriter>,
    /// Reader for loading
    reader: Option<BurnpackReader>,
    /// Cached tensor snapshots (parsed once, reused)
    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}

impl BurnpackStore {
    /// Get the default metadata that includes Burn framework information.
    ///
    /// This includes:
    /// - `format`: "burnpack"
    /// - `producer`: "burn"
    /// - `version`: the version of the `burn-store` crate (from `CARGO_PKG_VERSION`)
    ///
    /// These metadata fields are automatically added to all saved models.
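    ///
    /// # Examples
    ///
    /// A minimal sketch of inspecting the defaults:
    ///
    /// ```no_run
    /// # use burn_store::BurnpackStore;
    /// let meta = BurnpackStore::default_metadata();
    /// assert_eq!(meta.get("format").map(String::as_str), Some("burnpack"));
    /// assert_eq!(meta.get("producer").map(String::as_str), Some("burn"));
    /// ```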
    pub fn default_metadata() -> BTreeMap<String, String> {
        let mut metadata = BTreeMap::new();
        metadata.insert("format".into(), "burnpack".into());
        metadata.insert("producer".into(), "burn".into());
        metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
        metadata
    }

    /// Create a new store from a file path
    ///
    /// By default, automatically appends `.bpk` extension if the path doesn't have one.
    /// Use `.auto_extension(false)` to disable this behavior.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use burn_store::BurnpackStore;
    /// // Automatically appends .bpk
    /// let store = BurnpackStore::from_file("model");  // creates "model.bpk"
    ///
    /// // Already has extension, no append
    /// let store = BurnpackStore::from_file("model.bpk");  // uses "model.bpk"
    /// let store = BurnpackStore::from_file("model.myext");  // uses "model.myext"
    ///
    /// // Disable auto-extension
    /// let store = BurnpackStore::from_file("model").auto_extension(false);  // uses "model"
    /// ```
    #[cfg(feature = "std")]
    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
        Self {
            mode: StoreMode::File(path.as_ref().to_path_buf()),
            filter: None,
            metadata: Self::default_metadata(),
            allow_partial: false,
            validate: true,
            overwrite: false,
            zero_copy: false,
            #[cfg(feature = "std")]
            auto_extension: true,
            #[cfg(feature = "std")]
            remapper: KeyRemapper::new(),
            writer: None,
            reader: None,
            snapshots_cache: None,
        }
    }

    /// Create a new store from existing bytes (for reading) or with `None` (for writing)
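    ///
    /// # Examples
    ///
    /// A minimal sketch of both modes (the `bytes` value is assumed to hold a
    /// previously serialized model):
    ///
    /// ```ignore
    /// // Writing: start with no bytes; they are produced when collecting a module.
    /// let store = BurnpackStore::from_bytes(None);
    ///
    /// // Reading: wrap existing serialized bytes.
    /// let store = BurnpackStore::from_bytes(Some(bytes));
    /// ```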
    pub fn from_bytes(bytes: Option<Bytes>) -> Self {
        Self {
            mode: StoreMode::Bytes(bytes),
            filter: None,
            metadata: Self::default_metadata(),
            allow_partial: false,
            validate: true,
            overwrite: false,
            zero_copy: false,
            #[cfg(feature = "std")]
            auto_extension: false, // Not used for bytes mode
            #[cfg(feature = "std")]
            remapper: KeyRemapper::new(),
            writer: None,
            reader: None,
            snapshots_cache: None,
        }
    }

    /// Create a new store from static bytes with zero-copy loading enabled.
    ///
    /// This is optimized for embedded model weights where the data lives in the
    /// binary's `.rodata` section. Tensor data is sliced without copying, keeping
    /// the static reference alive.
    ///
    /// # Example
    ///
    /// ```ignore
    /// static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
    /// let store = BurnpackStore::from_static(MODEL_DATA);
    /// ```
    pub fn from_static(data: &'static [u8]) -> Self {
        use burn_tensor::AllocationProperty;

        // Create bytes::Bytes from static data (zero-copy, stays in .rodata)
        let shared = bytes::Bytes::from_static(data);

        // Wrap in cubecl Bytes with shared-bytes allocation controller
        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);

        Self {
            mode: StoreMode::Bytes(Some(bytes)),
            filter: None,
            metadata: Self::default_metadata(),
            allow_partial: false,
            validate: true,
            overwrite: false,
            zero_copy: true, // Enable zero-copy by default for static data
            #[cfg(feature = "std")]
            auto_extension: false,
            #[cfg(feature = "std")]
            remapper: KeyRemapper::new(),
            writer: None,
            reader: None,
            snapshots_cache: None,
        }
    }

    /// Add metadata key-value pair
    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Clear all metadata (including defaults)
    ///
    /// This removes all metadata including the default format, producer, and version fields.
    /// Use with caution as some tools may expect these fields to be present.
    pub fn clear_metadata(mut self) -> Self {
        self.metadata.clear();
        self
    }

    /// Allow partial loading (ignore missing tensors)
    ///
    /// When set to `true`, the store will not fail if some tensors are missing
    /// during loading. This is useful when loading a subset of a model's parameters.
    ///
    /// Default: `false`
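    ///
    /// # Examples
    ///
    /// A sketch of loading a subset of weights without failing on the rest:
    ///
    /// ```no_run
    /// # use burn_store::BurnpackStore;
    /// let store = BurnpackStore::from_file("model").allow_partial(true);
    /// ```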
    pub fn allow_partial(mut self, allow: bool) -> Self {
        self.allow_partial = allow;
        self
    }

    /// Enable or disable validation during loading
    ///
    /// When validation is enabled, the store will check that loaded tensors
    /// match the expected shapes and data types. Disabling validation can
    /// improve performance but may lead to runtime errors if data is corrupted.
    ///
    /// Default: `true`
    pub fn validate(mut self, validate: bool) -> Self {
        self.validate = validate;
        self
    }

    /// Allow overwriting existing files when saving
    ///
    /// When set to `false`, attempting to save to an existing file will result in an error.
    /// When set to `true`, existing files will be overwritten without warning.
    ///
    /// Default: `false`
    pub fn overwrite(mut self, overwrite: bool) -> Self {
        self.overwrite = overwrite;
        self
    }

    /// Enable or disable zero-copy tensor loading.
    ///
    /// When enabled and the backend supports it (memory-backed with shared bytes),
    /// tensor data is sliced from the source without copying. This keeps the source
    /// data alive as long as any tensor holds a reference.
    ///
    /// Zero-copy is automatically enabled when using [`from_static`](Self::from_static).
    /// Use this method to enable it for other memory-backed stores created with
    /// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`.
    ///
    /// Default: `false` (except for `from_static` which defaults to `true`)
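    ///
    /// # Examples
    ///
    /// A sketch of enabling zero-copy for a bytes-mode store backed by shared
    /// bytes (the `data` buffer is illustrative):
    ///
    /// ```ignore
    /// use burn_tensor::{AllocationProperty, Bytes};
    ///
    /// let shared = bytes::Bytes::from(data);
    /// let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
    /// let store = BurnpackStore::from_bytes(Some(bytes)).zero_copy(true);
    /// ```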
    pub fn zero_copy(mut self, enable: bool) -> Self {
        self.zero_copy = enable;
        self
    }

    /// Enable or disable automatic .bpk extension appending
    ///
    /// When enabled (default), automatically appends `.bpk` to the file path
    /// if no extension is detected. If an extension is already present, it is preserved.
    ///
    /// When disabled, uses the exact path provided without modification.
    ///
    /// Default: `true`
    ///
    /// # Examples
    ///
    /// ```no_run
    /// # use burn_store::BurnpackStore;
    /// // With auto_extension enabled (default)
    /// let store = BurnpackStore::from_file("model");  // -> "model.bpk"
    ///
    /// // With auto_extension disabled
    /// let store = BurnpackStore::from_file("model")
    ///     .auto_extension(false);  // -> "model"
    /// ```
    #[cfg(feature = "std")]
    pub fn auto_extension(mut self, enable: bool) -> Self {
        self.auto_extension = enable;
        self
    }

    /// Set path filter for selective loading/saving
    pub fn with_filter(mut self, filter: PathFilter) -> Self {
        self.filter = Some(filter);
        self
    }

    /// Add regex pattern to filter
    #[cfg(feature = "std")]
    pub fn with_regex(mut self, pattern: &str) -> Self {
        let filter = self.filter.unwrap_or_default();
        self.filter = Some(filter.with_regex(pattern));
        self
    }

    /// Add exact path to filter
    pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
        let filter = self.filter.unwrap_or_default();
        self.filter = Some(filter.with_full_path(path));
        self
    }

    /// Match all tensors (no filtering)
    pub fn match_all(mut self) -> Self {
        self.filter = Some(PathFilter::new().match_all());
        self
    }

    /// Set key remapper for tensor name transformations during loading
    #[cfg(feature = "std")]
    pub fn remap(mut self, remapper: KeyRemapper) -> Self {
        self.remapper = remapper;
        self
    }

    /// Add a single regex pattern for key remapping
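    ///
    /// # Panics
    ///
    /// Panics if `from` is not a valid regex pattern.
    ///
    /// # Examples
    ///
    /// A sketch renaming keys during loading (the key names are illustrative):
    ///
    /// ```no_run
    /// # use burn_store::BurnpackStore;
    /// let store = BurnpackStore::from_file("model")
    ///     .with_remap_pattern(r"^encoder\.", "backbone.");
    /// ```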
    #[cfg(feature = "std")]
    pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
    where
        S1: AsRef<str>,
        S2: Into<String>,
    {
        self.remapper = self
            .remapper
            .add_pattern(from.as_ref(), to.into())
            .expect("Invalid regex pattern");
        self
    }

    /// Set the path filter
    pub fn filter(mut self, filter: PathFilter) -> Self {
        self.filter = Some(filter);
        self
    }

    /// Get the bytes after writing (only valid for bytes mode after collecting)
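    ///
    /// # Examples
    ///
    /// A sketch of the write-then-retrieve flow (module setup omitted):
    ///
    /// ```ignore
    /// let mut store = BurnpackStore::from_bytes(None);
    /// store.collect_from(&module)?;
    /// let bytes = store.get_bytes()?;
    /// ```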
    pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
        if let Some(writer) = &self.writer {
            return writer.to_bytes();
        }

        match &self.mode {
            StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
            _ => Err(BurnpackError::IoError("No bytes available".into())),
        }
    }

    /// Process the file path with auto-extension logic
    #[cfg(feature = "std")]
    fn process_path(&self, path: &std::path::Path) -> PathBuf {
        if !self.auto_extension {
            return path.to_path_buf();
        }

        // Check if path already has an extension
        if path.extension().is_some() {
            // Has extension, use as-is
            return path.to_path_buf();
        }

        // No extension, append .bpk
        let mut new_path = path.to_path_buf();
        new_path.set_extension("bpk");
        new_path
    }

    /// Ensure the reader is initialized, loading from storage if needed
    fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> {
        if self.reader.is_none() {
            let reader = match &self.mode {
                #[cfg(feature = "std")]
                StoreMode::File(path) => {
                    let final_path = self.process_path(path);
                    BurnpackReader::from_file(&final_path)?
                }
                StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
                StoreMode::Bytes(None) => {
                    return Err(BurnpackError::IoError("No bytes to read from".into()));
                }
            };
            self.reader = Some(reader);
        }

        self.reader
            .as_ref()
            .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))
    }
}

impl ModuleStore for BurnpackStore {
    type Error = BurnpackError;

    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &M,
    ) -> Result<(), Self::Error> {
        // Invalidate cache since we're writing new data
        self.snapshots_cache = None;
        self.reader = None;

        // Collect snapshots from module
        let snapshots = module.collect(self.filter.clone(), None, false);

        // Initialize writer with snapshots
        let mut writer = BurnpackWriter::new(snapshots);

        // Add metadata using builder pattern
        for (key, value) in &self.metadata {
            writer = writer.with_metadata(key.as_str(), value.as_str());
        }

        // Store the writer for finalization
        self.writer = Some(writer);

        // Write to storage based on mode
        if let Some(writer) = &self.writer {
            match &self.mode {
                #[cfg(feature = "std")]
                StoreMode::File(path) => {
                    // Process path with auto-extension logic
                    let final_path = self.process_path(path);

                    // Check if file exists and overwrite is disabled
                    if final_path.exists() && !self.overwrite {
                        return Err(BurnpackError::IoError(format!(
                            "File already exists: {}. Use .overwrite(true) to overwrite.",
                            final_path.display()
                        )));
                    }
                    writer.write_to_file(&final_path)?;
                }
                StoreMode::Bytes(_) => {
                    // Generate and store the bytes
                    let bytes_data = writer.to_bytes()?;
                    // Update mode with bytes - this pattern is irrefutable in no-std mode
                    #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
                    let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
                        unreachable!("We just matched Bytes variant");
                    };
                    *bytes_ref = Some(bytes_data);
                }
            }
        }

        Ok(())
    }

    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<crate::ApplyResult, Self::Error> {
        // Get all snapshots using the cached method
        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();

        // Apply all snapshots at once to the module
        // Burnpack is Burn's native format, so no enum variant skipping needed
        // Filter is applied here during apply, not during cache population
        let result = module.apply(snapshots, self.filter.clone(), None, false);

        // Validate if needed
        if self.validate && !result.errors.is_empty() {
            return Err(BurnpackError::ValidationError(format!(
                "Import errors: {:?}",
                result.errors
            )));
        }

        // Check for missing tensors if partial loading is not allowed
        if !self.allow_partial && !result.missing.is_empty() {
            return Err(BurnpackError::ValidationError(format!(
                "Missing tensors: {:?}",
                result.missing
            )));
        }

        Ok(result)
    }

    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
        // Ensure cache is populated
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap().get(name))
    }

    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
        // Ensure cache is populated
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap())
    }

    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
        // Always use the cache to ensure remapping is applied consistently
        Ok(self.get_all_snapshots()?.keys().cloned().collect())
    }
}

impl BurnpackStore {
    /// Ensure the snapshots cache is populated
    fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> {
        if self.snapshots_cache.is_some() {
            return Ok(());
        }

        // Ensure reader is loaded
        self.ensure_reader()?;

        // Get snapshots from reader with zero-copy if enabled
        let reader = self.reader.as_ref().unwrap();
        let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?;

        // Apply remapping if configured (but NOT filtering - that's done at apply time)
        #[cfg(feature = "std")]
        let snapshots = if !self.remapper.patterns.is_empty() {
            let (remapped, _remapped_names) = self.remapper.remap(snapshots);
            remapped
        } else {
            snapshots
        };

        // Build the cache as BTreeMap
        let cache: BTreeMap<String, TensorSnapshot> =
            snapshots.into_iter().map(|s| (s.full_path(), s)).collect();

        self.snapshots_cache = Some(cache);
        Ok(())
    }
507}