// burn_store/burnpack/store.rs
1#[cfg(feature = "std")]
2use std::path::PathBuf;
3
4use super::reader::BurnpackReader;
5use super::writer::BurnpackWriter;
6#[cfg(feature = "std")]
7use crate::KeyRemapper;
8use crate::burnpack::base::BurnpackError;
9use crate::{
10    IdentityAdapter, ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter, TensorSnapshot,
11};
12use alloc::boxed::Box;
13use alloc::collections::BTreeMap;
14use alloc::format;
15use alloc::string::String;
16use alloc::vec::Vec;
17use burn_core::prelude::Backend;
18use burn_tensor::Bytes;
19
/// Store mode for BurnpackStore
///
/// Determines where serialized data is read from and written to.
enum StoreMode {
    /// Read/write a file on disk (only available with the `std` feature).
    #[cfg(feature = "std")]
    File(PathBuf),
    /// In-memory bytes: `Some` when data is available (constructed from
    /// existing bytes, or populated after a write), `None` for an empty
    /// store that has not been written to yet.
    Bytes(Option<Bytes>),
}
26
/// BurnpackStore - A Burn-specific file format store using CBOR for metadata
///
/// Configured through a builder-style API; actual I/O happens lazily via the
/// cached `writer`/`reader` and the snapshot cache.
pub struct BurnpackStore {
    /// Store mode - either file path or bytes
    mode: StoreMode,
    /// Optional filter for selective loading/saving
    filter: Option<PathFilter>,
    /// Additional metadata (written into the file header on save)
    metadata: BTreeMap<String, String>,
    /// Allow partial loading (ignore missing tensors)
    allow_partial: bool,
    /// Validate tensors during loading (check shapes and dtypes)
    validate: bool,
    /// Allow overwriting existing files (default: false)
    overwrite: bool,
    /// Enable zero-copy tensor loading (default: false)
    ///
    /// When enabled and the backend supports it, tensor data is sliced from
    /// the source without copying. This requires keeping the source data alive.
    zero_copy: bool,
    /// Automatically append .bpk extension if not present (default: true)
    #[cfg(feature = "std")]
    auto_extension: bool,
    /// Key remapper for tensor name transformations
    #[cfg(feature = "std")]
    remapper: KeyRemapper,
    /// Adapter applied when loading (source -> Burn)
    from_adapter: Box<dyn ModuleAdapter>,
    /// Adapter applied when saving (Burn -> target)
    to_adapter: Box<dyn ModuleAdapter>,
    /// Writer for saving (populated by `collect_from`)
    writer: Option<BurnpackWriter>,
    /// Reader for loading (lazily initialized on first read access)
    reader: Option<BurnpackReader>,
    /// Cached tensor snapshots (parsed once, reused)
    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}
63
64impl BurnpackStore {
65    /// Get the default metadata that includes Burn framework information.
66    ///
67    /// This includes:
68    /// - `format`: "burnpack"
69    /// - `producer`: "burn"
70    /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)
71    ///
72    /// These metadata fields are automatically added to all saved models.
73    pub fn default_metadata() -> BTreeMap<String, String> {
74        let mut metadata = BTreeMap::new();
75        metadata.insert("format".into(), "burnpack".into());
76        metadata.insert("producer".into(), "burn".into());
77        metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
78        metadata
79    }
80    /// Create a new store from a file path
81    ///
82    /// By default, automatically appends `.bpk` extension if the path doesn't have one.
83    /// Use `.auto_extension(false)` to disable this behavior.
84    ///
85    /// # Examples
86    ///
87    /// ```no_run
88    /// # use burn_store::BurnpackStore;
89    /// // Automatically appends .bpk
90    /// let store = BurnpackStore::from_file("model");  // creates "model.bpk"
91    ///
92    /// // Already has extension, no append
93    /// let store = BurnpackStore::from_file("model.bpk");  // uses "model.bpk"
94    /// let store = BurnpackStore::from_file("model.myext");  // uses "model.myext"
95    ///
96    /// // Disable auto-extension
97    /// let store = BurnpackStore::from_file("model").auto_extension(false);  // uses "model"
98    /// ```
99    #[cfg(feature = "std")]
100    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
101        Self {
102            mode: StoreMode::File(path.as_ref().to_path_buf()),
103            filter: None,
104            metadata: Self::default_metadata(),
105            allow_partial: false,
106            validate: true,
107            overwrite: false,
108            zero_copy: false,
109            #[cfg(feature = "std")]
110            auto_extension: true,
111            #[cfg(feature = "std")]
112            remapper: KeyRemapper::new(),
113            from_adapter: Box::new(IdentityAdapter),
114            to_adapter: Box::new(IdentityAdapter),
115            writer: None,
116            reader: None,
117            snapshots_cache: None,
118        }
119    }
120
121    /// Create a new store from bytes (for reading) or empty (for writing)
122    pub fn from_bytes(bytes: Option<Bytes>) -> Self {
123        Self {
124            mode: StoreMode::Bytes(bytes),
125            filter: None,
126            metadata: Self::default_metadata(),
127            allow_partial: false,
128            validate: true,
129            overwrite: false,
130            zero_copy: false,
131            #[cfg(feature = "std")]
132            auto_extension: false, // Not used for bytes mode
133            #[cfg(feature = "std")]
134            remapper: KeyRemapper::new(),
135            from_adapter: Box::new(IdentityAdapter),
136            to_adapter: Box::new(IdentityAdapter),
137            writer: None,
138            reader: None,
139            snapshots_cache: None,
140        }
141    }
142
143    /// Create a new store from static bytes with zero-copy loading enabled.
144    ///
145    /// This is optimized for embedded model weights where the data lives in the
146    /// binary's `.rodata` section. Tensor data is sliced without copying, keeping
147    /// the static reference alive.
148    ///
149    /// # Example
150    ///
151    /// ```ignore
152    /// static MODEL_DATA: &[u8] = include_bytes!("model.bpk");
153    /// let store = BurnpackStore::from_static(MODEL_DATA);
154    /// ```
155    pub fn from_static(data: &'static [u8]) -> Self {
156        use burn_tensor::AllocationProperty;
157
158        // Create bytes::Bytes from static data (zero-copy, stays in .rodata)
159        let shared = bytes::Bytes::from_static(data);
160
161        // Wrap in cubecl Bytes with shared-bytes allocation controller
162        let bytes = Bytes::from_shared(shared, AllocationProperty::Other);
163
164        Self {
165            mode: StoreMode::Bytes(Some(bytes)),
166            filter: None,
167            metadata: Self::default_metadata(),
168            allow_partial: false,
169            validate: true,
170            overwrite: false,
171            zero_copy: true, // Enable zero-copy by default for static data
172            #[cfg(feature = "std")]
173            auto_extension: false,
174            #[cfg(feature = "std")]
175            remapper: KeyRemapper::new(),
176            from_adapter: Box::new(IdentityAdapter),
177            to_adapter: Box::new(IdentityAdapter),
178            writer: None,
179            reader: None,
180            snapshots_cache: None,
181        }
182    }
183
184    /// Add metadata key-value pair
185    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
186        self.metadata.insert(key.into(), value.into());
187        self
188    }
189
190    /// Clear all metadata (including defaults)
191    ///
192    /// This removes all metadata including the default format, producer, and version fields.
193    /// Use with caution as some tools may expect these fields to be present.
194    pub fn clear_metadata(mut self) -> Self {
195        self.metadata.clear();
196        self
197    }
198
199    /// Allow partial loading (ignore missing tensors)
200    ///
201    /// When set to `true`, the store will not fail if some tensors are missing
202    /// during loading. This is useful when loading a subset of a model's parameters.
203    ///
204    /// Default: `false`
205    pub fn allow_partial(mut self, allow: bool) -> Self {
206        self.allow_partial = allow;
207        self
208    }
209
210    /// Enable or disable validation during loading
211    ///
212    /// When validation is enabled, the store will check that loaded tensors
213    /// match the expected shapes and data types. Disabling validation can
214    /// improve performance but may lead to runtime errors if data is corrupted.
215    ///
216    /// Default: `true`
217    pub fn validate(mut self, validate: bool) -> Self {
218        self.validate = validate;
219        self
220    }
221
222    /// Allow overwriting existing files when saving
223    ///
224    /// When set to `false`, attempting to save to an existing file will result in an error.
225    /// When set to `true`, existing files will be overwritten without warning.
226    ///
227    /// Default: `false`
228    pub fn overwrite(mut self, overwrite: bool) -> Self {
229        self.overwrite = overwrite;
230        self
231    }
232
233    /// Enable or disable zero-copy tensor loading.
234    ///
235    /// When enabled and the backend supports it (memory-backed with shared bytes),
236    /// tensor data is sliced from the source without copying. This keeps the source
237    /// data alive as long as any tensor holds a reference.
238    ///
239    /// Zero-copy is automatically enabled when using [`from_static`](Self::from_static).
240    /// Use this method to enable it for other memory-backed stores created with
241    /// [`from_bytes`](Self::from_bytes) when using `Bytes::from_shared()`.
242    ///
243    /// Default: `false` (except for `from_static` which defaults to `true`)
244    pub fn zero_copy(mut self, enable: bool) -> Self {
245        self.zero_copy = enable;
246        self
247    }
248
249    /// Enable or disable automatic .bpk extension appending
250    ///
251    /// When enabled (default), automatically appends `.bpk` to the file path
252    /// if no extension is detected. If an extension is already present, it is preserved.
253    ///
254    /// When disabled, uses the exact path provided without modification.
255    ///
256    /// Default: `true`
257    ///
258    /// # Examples
259    ///
260    /// ```no_run
261    /// # use burn_store::BurnpackStore;
262    /// // With auto_extension enabled (default)
263    /// let store = BurnpackStore::from_file("model");  // -> "model.bpk"
264    ///
265    /// // With auto_extension disabled
266    /// let store = BurnpackStore::from_file("model")
267    ///     .auto_extension(false);  // -> "model"
268    /// ```
269    #[cfg(feature = "std")]
270    pub fn auto_extension(mut self, enable: bool) -> Self {
271        self.auto_extension = enable;
272        self
273    }
274
275    /// Set the adapter for loading tensors (converting from source format to Burn).
276    pub fn with_from_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {
277        self.from_adapter = Box::new(adapter);
278        self
279    }
280
281    /// Set the adapter for saving tensors (converting from Burn to target format).
282    pub fn with_to_adapter(mut self, adapter: impl ModuleAdapter + 'static) -> Self {
283        self.to_adapter = Box::new(adapter);
284        self
285    }
286
287    /// Set path filter for selective loading/saving
288    pub fn with_filter(mut self, filter: PathFilter) -> Self {
289        self.filter = Some(filter);
290        self
291    }
292
293    /// Add regex pattern to filter
294    #[cfg(feature = "std")]
295    pub fn with_regex(mut self, pattern: &str) -> Self {
296        let filter = self.filter.unwrap_or_default();
297        self.filter = Some(filter.with_regex(pattern));
298        self
299    }
300
301    /// Add exact path to filter
302    pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
303        let filter = self.filter.unwrap_or_default();
304        self.filter = Some(filter.with_full_path(path));
305        self
306    }
307
308    /// Match all tensors (no filtering)
309    pub fn match_all(mut self) -> Self {
310        self.filter = Some(PathFilter::new().match_all());
311        self
312    }
313
314    /// Set key remapper for tensor name transformations during loading
315    #[cfg(feature = "std")]
316    pub fn remap(mut self, remapper: KeyRemapper) -> Self {
317        self.remapper = remapper;
318        self
319    }
320
321    /// Add a single regex pattern for key remapping
322    #[cfg(feature = "std")]
323    pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
324    where
325        S1: AsRef<str>,
326        S2: Into<String>,
327    {
328        self.remapper = self
329            .remapper
330            .add_pattern(from.as_ref(), to.into())
331            .expect("Invalid regex pattern");
332        self
333    }
334
335    /// Set the path filter
336    pub fn filter(mut self, filter: PathFilter) -> Self {
337        self.filter = Some(filter);
338        self
339    }
340
341    /// Get the bytes after writing (only valid for bytes mode after collecting)
342    pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
343        if let Some(writer) = &self.writer {
344            return writer.to_bytes();
345        }
346
347        match &self.mode {
348            StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
349            _ => Err(BurnpackError::IoError("No bytes available".into())),
350        }
351    }
352
353    /// Process the file path with auto-extension logic
354    #[cfg(feature = "std")]
355    fn process_path(&self, path: &std::path::Path) -> PathBuf {
356        if !self.auto_extension {
357            return path.to_path_buf();
358        }
359
360        // Check if path already has an extension
361        if path.extension().is_some() {
362            // Has extension, use as-is
363            return path.to_path_buf();
364        }
365
366        // No extension, append .bpk
367        let mut new_path = path.to_path_buf();
368        new_path.set_extension("bpk");
369        new_path
370    }
371
372    /// Ensure the reader is initialized, loading from storage if needed
373    fn ensure_reader(&mut self) -> Result<&BurnpackReader, BurnpackError> {
374        if self.reader.is_none() {
375            let reader = match &self.mode {
376                #[cfg(feature = "std")]
377                StoreMode::File(path) => {
378                    let final_path = self.process_path(path);
379                    BurnpackReader::from_file(&final_path)?
380                }
381                StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
382                StoreMode::Bytes(None) => {
383                    return Err(BurnpackError::IoError("No bytes to read from".into()));
384                }
385            };
386            self.reader = Some(reader);
387        }
388
389        self.reader
390            .as_ref()
391            .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))
392    }
393}
394
impl ModuleStore for BurnpackStore {
    type Error = BurnpackError;

    /// Collect tensor snapshots from `module` and write them out according to
    /// the store's mode (file on disk, or in-memory bytes).
    ///
    /// Any previously cached reader/snapshots are invalidated first, since the
    /// store's contents are about to be replaced.
    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &M,
    ) -> Result<(), Self::Error> {
        // Invalidate cache since we're writing new data
        self.snapshots_cache = None;
        self.reader = None;

        // Collect snapshots from module with adapter
        let snapshots = module.collect(self.filter.clone(), Some(self.to_adapter.clone()), false);

        // Initialize writer with snapshots
        let mut writer = BurnpackWriter::new(snapshots);

        // Add metadata using builder pattern
        for (key, value) in &self.metadata {
            writer = writer.with_metadata(key.as_str(), value.as_str());
        }

        // Store the writer for finalization
        self.writer = Some(writer);

        // Write to storage based on mode
        if let Some(writer) = &self.writer {
            match &self.mode {
                #[cfg(feature = "std")]
                StoreMode::File(path) => {
                    // Process path with auto-extension logic
                    let final_path = self.process_path(path);

                    // Check if file exists and overwrite is disabled
                    if final_path.exists() && !self.overwrite {
                        return Err(BurnpackError::IoError(format!(
                            "File already exists: {}. Use .overwrite(true) to overwrite.",
                            final_path.display()
                        )));
                    }
                    writer.write_to_file(&final_path)?;
                }
                StoreMode::Bytes(_) => {
                    // Generate and store the bytes
                    let bytes_data = writer.to_bytes()?;
                    // Update mode with bytes - this pattern is irrefutable in no-std mode
                    // (File variant doesn't exist there), hence the cfg_attr allow.
                    #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
                    let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
                        unreachable!("We just matched Bytes variant");
                    };
                    *bytes_ref = Some(bytes_data);
                }
            }
        }

        Ok(())
    }

    /// Load the stored snapshots into `module`.
    ///
    /// Filtering and the `from_adapter` are applied here (at apply time), not
    /// when the cache is populated. Returns the apply result, or an error when
    /// validation/partial-loading settings reject it.
    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<crate::ApplyResult, Self::Error> {
        // Get all snapshots using the cached method
        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();

        // Apply all snapshots at once to the module
        // Burnpack is Burn's native format, so no enum variant skipping needed
        // Filter is applied here during apply, not during cache population
        let result = module.apply(
            snapshots,
            self.filter.clone(),
            Some(self.from_adapter.clone()),
            false,
        );

        // Validate if needed
        if self.validate && !result.errors.is_empty() {
            return Err(BurnpackError::ValidationError(format!(
                "Import errors: {:?}",
                result.errors
            )));
        }

        // Check for missing tensors if partial loading is not allowed
        if !self.allow_partial && !result.missing.is_empty() {
            return Err(BurnpackError::ValidationError(format!(
                "Missing tensors: {:?}",
                result.missing
            )));
        }

        Ok(result)
    }

    /// Look up a single snapshot by (possibly remapped) tensor name.
    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
        // Ensure cache is populated
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap().get(name))
    }

    /// Return all snapshots, keyed by tensor path, populating the cache if needed.
    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
        // Ensure cache is populated
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap())
    }

    /// Return all tensor names present in the store (after remapping).
    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
        // Always use the cache to ensure remapping is applied consistently
        Ok(self.get_all_snapshots()?.keys().cloned().collect())
    }
}
506
impl BurnpackStore {
    /// Ensure the snapshots cache is populated
    ///
    /// Reads snapshots from the (lazily initialized) reader, applies key
    /// remapping when configured (std only), and stores them in a
    /// `BTreeMap` keyed by the tensor's full path. Idempotent: a populated
    /// cache is never rebuilt.
    fn ensure_snapshots_cache(&mut self) -> Result<(), BurnpackError> {
        if self.snapshots_cache.is_some() {
            return Ok(());
        }

        // Ensure reader is loaded
        self.ensure_reader()?;

        // Get snapshots from reader with zero-copy if enabled
        let reader = self.reader.as_ref().unwrap();
        let snapshots = reader.get_snapshots_zero_copy(self.zero_copy)?;

        // Apply remapping if configured (but NOT filtering - that's done at apply time)
        // Note: this shadowing rebind only exists under the std feature; in
        // no-std builds the unremapped snapshots flow straight through.
        #[cfg(feature = "std")]
        let snapshots = if !self.remapper.patterns.is_empty() {
            let (remapped, _remapped_names) = self.remapper.remap(snapshots);
            remapped
        } else {
            snapshots
        };

        // Build the cache as BTreeMap
        let cache: BTreeMap<String, TensorSnapshot> =
            snapshots.into_iter().map(|s| (s.full_path(), s)).collect();

        self.snapshots_cache = Some(cache);
        Ok(())
    }
}
537}