burn_store/burnpack/
store.rs

1#[cfg(feature = "std")]
2use std::path::PathBuf;
3
4use super::reader::BurnpackReader;
5use super::writer::BurnpackWriter;
6#[cfg(feature = "std")]
7use crate::KeyRemapper;
8use crate::burnpack::base::BurnpackError;
9use crate::{ModuleSnapshot, ModuleStore, PathFilter};
10use alloc::collections::BTreeMap;
11use alloc::format;
12use alloc::string::String;
13use burn_core::prelude::Backend;
14use burn_tensor::Bytes;
15
/// Store mode for BurnpackStore
///
/// Selects the backing storage a `BurnpackStore` saves to and loads from.
enum StoreMode {
    /// A file on disk identified by its path (requires the `std` feature).
    #[cfg(feature = "std")]
    File(PathBuf),
    /// In-memory bytes. `Some` holds data that can be read back (either supplied
    /// by the caller or produced by a save); `None` means no data is held yet.
    Bytes(Option<Bytes>),
}
22
/// BurnpackStore - A Burn-specific file format store using CBOR for metadata
///
/// A store is created with `from_file` (with the `std` feature) or `from_bytes`
/// and then configured through the chainable builder methods, each of which
/// consumes the store and returns it.
pub struct BurnpackStore {
    /// Store mode - either file path or bytes
    mode: StoreMode,
    /// Optional filter for selective loading/saving; passed along when
    /// collecting tensors from a module and when applying them to one
    filter: Option<PathFilter>,
    /// Additional metadata attached to the writer on save; new stores start
    /// with the entries from `default_metadata()`
    metadata: BTreeMap<String, String>,
    /// Allow partial loading (ignore missing tensors); when `false`, loading
    /// fails if any tensors are reported missing
    allow_partial: bool,
    /// Validate tensors during loading (check shapes and dtypes); when `true`,
    /// loading fails if the apply result reports any errors
    validate: bool,
    /// Allow overwriting existing files (default: false)
    overwrite: bool,
    /// Automatically append .bpk extension if not present (default: true)
    #[cfg(feature = "std")]
    auto_extension: bool,
    /// Key remapper for tensor name transformations, applied to the tensors
    /// read during loading when it has at least one pattern
    #[cfg(feature = "std")]
    remapper: KeyRemapper,
    /// Writer for saving; set by the most recent collect and used by
    /// `get_bytes` to serialize the collected data
    writer: Option<BurnpackWriter>,
    /// Reader for loading; created on the first load and then reused
    reader: Option<BurnpackReader>,
}
48
49impl BurnpackStore {
50    /// Get the default metadata that includes Burn framework information.
51    ///
52    /// This includes:
53    /// - `format`: "burnpack"
54    /// - `producer`: "burn"
55    /// - `version`: The version of burn-store crate (from CARGO_PKG_VERSION)
56    ///
57    /// These metadata fields are automatically added to all saved models.
58    pub fn default_metadata() -> BTreeMap<String, String> {
59        let mut metadata = BTreeMap::new();
60        metadata.insert("format".into(), "burnpack".into());
61        metadata.insert("producer".into(), "burn".into());
62        metadata.insert("version".into(), env!("CARGO_PKG_VERSION").into());
63        metadata
64    }
65    /// Create a new store from a file path
66    ///
67    /// By default, automatically appends `.bpk` extension if the path doesn't have one.
68    /// Use `.auto_extension(false)` to disable this behavior.
69    ///
70    /// # Examples
71    ///
72    /// ```no_run
73    /// # use burn_store::BurnpackStore;
74    /// // Automatically appends .bpk
75    /// let store = BurnpackStore::from_file("model");  // creates "model.bpk"
76    ///
77    /// // Already has extension, no append
78    /// let store = BurnpackStore::from_file("model.bpk");  // uses "model.bpk"
79    /// let store = BurnpackStore::from_file("model.myext");  // uses "model.myext"
80    ///
81    /// // Disable auto-extension
82    /// let store = BurnpackStore::from_file("model").auto_extension(false);  // uses "model"
83    /// ```
84    #[cfg(feature = "std")]
85    pub fn from_file<P: AsRef<std::path::Path>>(path: P) -> Self {
86        Self {
87            mode: StoreMode::File(path.as_ref().to_path_buf()),
88            filter: None,
89            metadata: Self::default_metadata(),
90            allow_partial: false,
91            validate: true,
92            overwrite: false,
93            #[cfg(feature = "std")]
94            auto_extension: true,
95            #[cfg(feature = "std")]
96            remapper: KeyRemapper::new(),
97            writer: None,
98            reader: None,
99        }
100    }
101
102    /// Create a new store from bytes (for reading) or empty (for writing)
103    pub fn from_bytes(bytes: Option<Bytes>) -> Self {
104        Self {
105            mode: StoreMode::Bytes(bytes),
106            filter: None,
107            metadata: Self::default_metadata(),
108            allow_partial: false,
109            validate: true,
110            overwrite: false,
111            #[cfg(feature = "std")]
112            auto_extension: false, // Not used for bytes mode
113            #[cfg(feature = "std")]
114            remapper: KeyRemapper::new(),
115            writer: None,
116            reader: None,
117        }
118    }
119
120    /// Add metadata key-value pair
121    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
122        self.metadata.insert(key.into(), value.into());
123        self
124    }
125
126    /// Clear all metadata (including defaults)
127    ///
128    /// This removes all metadata including the default format, producer, and version fields.
129    /// Use with caution as some tools may expect these fields to be present.
130    pub fn clear_metadata(mut self) -> Self {
131        self.metadata.clear();
132        self
133    }
134
135    /// Allow partial loading (ignore missing tensors)
136    ///
137    /// When set to `true`, the store will not fail if some tensors are missing
138    /// during loading. This is useful when loading a subset of a model's parameters.
139    ///
140    /// Default: `false`
141    pub fn allow_partial(mut self, allow: bool) -> Self {
142        self.allow_partial = allow;
143        self
144    }
145
146    /// Enable or disable validation during loading
147    ///
148    /// When validation is enabled, the store will check that loaded tensors
149    /// match the expected shapes and data types. Disabling validation can
150    /// improve performance but may lead to runtime errors if data is corrupted.
151    ///
152    /// Default: `true`
153    pub fn validate(mut self, validate: bool) -> Self {
154        self.validate = validate;
155        self
156    }
157
158    /// Allow overwriting existing files when saving
159    ///
160    /// When set to `false`, attempting to save to an existing file will result in an error.
161    /// When set to `true`, existing files will be overwritten without warning.
162    ///
163    /// Default: `false`
164    pub fn overwrite(mut self, overwrite: bool) -> Self {
165        self.overwrite = overwrite;
166        self
167    }
168
169    /// Enable or disable automatic .bpk extension appending
170    ///
171    /// When enabled (default), automatically appends `.bpk` to the file path
172    /// if no extension is detected. If an extension is already present, it is preserved.
173    ///
174    /// When disabled, uses the exact path provided without modification.
175    ///
176    /// Default: `true`
177    ///
178    /// # Examples
179    ///
180    /// ```no_run
181    /// # use burn_store::BurnpackStore;
182    /// // With auto_extension enabled (default)
183    /// let store = BurnpackStore::from_file("model");  // -> "model.bpk"
184    ///
185    /// // With auto_extension disabled
186    /// let store = BurnpackStore::from_file("model")
187    ///     .auto_extension(false);  // -> "model"
188    /// ```
189    #[cfg(feature = "std")]
190    pub fn auto_extension(mut self, enable: bool) -> Self {
191        self.auto_extension = enable;
192        self
193    }
194
195    /// Set path filter for selective loading/saving
196    pub fn with_filter(mut self, filter: PathFilter) -> Self {
197        self.filter = Some(filter);
198        self
199    }
200
201    /// Add regex pattern to filter
202    #[cfg(feature = "std")]
203    pub fn with_regex(mut self, pattern: &str) -> Self {
204        let filter = self.filter.unwrap_or_default();
205        self.filter = Some(filter.with_regex(pattern));
206        self
207    }
208
209    /// Add exact path to filter
210    pub fn with_full_path(mut self, path: impl Into<String>) -> Self {
211        let filter = self.filter.unwrap_or_default();
212        self.filter = Some(filter.with_full_path(path));
213        self
214    }
215
216    /// Match all tensors (no filtering)
217    pub fn match_all(mut self) -> Self {
218        self.filter = Some(PathFilter::new().match_all());
219        self
220    }
221
222    /// Set key remapper for tensor name transformations during loading
223    #[cfg(feature = "std")]
224    pub fn remap(mut self, remapper: KeyRemapper) -> Self {
225        self.remapper = remapper;
226        self
227    }
228
229    /// Add a single regex pattern for key remapping
230    #[cfg(feature = "std")]
231    pub fn with_remap_pattern<S1, S2>(mut self, from: S1, to: S2) -> Self
232    where
233        S1: AsRef<str>,
234        S2: Into<String>,
235    {
236        self.remapper = self
237            .remapper
238            .add_pattern(from.as_ref(), to.into())
239            .expect("Invalid regex pattern");
240        self
241    }
242
243    /// Set the path filter
244    pub fn filter(mut self, filter: PathFilter) -> Self {
245        self.filter = Some(filter);
246        self
247    }
248
249    /// Get the bytes after writing (only valid for bytes mode after collecting)
250    pub fn get_bytes(&self) -> Result<Bytes, BurnpackError> {
251        if let Some(writer) = &self.writer {
252            return writer.to_bytes();
253        }
254
255        match &self.mode {
256            StoreMode::Bytes(Some(bytes)) => Ok(bytes.clone()),
257            _ => Err(BurnpackError::IoError("No bytes available".into())),
258        }
259    }
260
261    /// Process the file path with auto-extension logic
262    #[cfg(feature = "std")]
263    fn process_path(&self, path: &std::path::Path) -> PathBuf {
264        if !self.auto_extension {
265            return path.to_path_buf();
266        }
267
268        // Check if path already has an extension
269        if path.extension().is_some() {
270            // Has extension, use as-is
271            return path.to_path_buf();
272        }
273
274        // No extension, append .bpk
275        let mut new_path = path.to_path_buf();
276        new_path.set_extension("bpk");
277        new_path
278    }
279}
280
281impl ModuleStore for BurnpackStore {
282    type Error = BurnpackError;
283
284    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
285        &mut self,
286        module: &M,
287    ) -> Result<(), Self::Error> {
288        // Collect snapshots from module
289        let snapshots = module.collect(self.filter.clone(), None);
290
291        // Initialize writer with snapshots
292        let mut writer = BurnpackWriter::new(snapshots);
293
294        // Add metadata using builder pattern
295        for (key, value) in &self.metadata {
296            writer = writer.with_metadata(key.as_str(), value.as_str());
297        }
298
299        // Store the writer for finalization
300        self.writer = Some(writer);
301
302        // Write to storage based on mode
303        if let Some(writer) = &self.writer {
304            match &self.mode {
305                #[cfg(feature = "std")]
306                StoreMode::File(path) => {
307                    // Process path with auto-extension logic
308                    let final_path = self.process_path(path);
309
310                    // Check if file exists and overwrite is disabled
311                    if final_path.exists() && !self.overwrite {
312                        return Err(BurnpackError::IoError(format!(
313                            "File already exists: {}. Use .overwrite(true) to overwrite.",
314                            final_path.display()
315                        )));
316                    }
317                    writer.write_to_file(&final_path)?;
318                }
319                StoreMode::Bytes(_) => {
320                    // Generate and store the bytes
321                    let bytes_data = writer.to_bytes()?;
322                    // Update mode with bytes - this pattern is irrefutable in no-std mode
323                    #[cfg_attr(not(feature = "std"), allow(irrefutable_let_patterns))]
324                    let StoreMode::Bytes(bytes_ref) = &mut self.mode else {
325                        unreachable!("We just matched Bytes variant");
326                    };
327                    *bytes_ref = Some(bytes_data);
328                }
329            }
330        }
331
332        Ok(())
333    }
334
335    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
336        &mut self,
337        module: &mut M,
338    ) -> Result<crate::ApplyResult, Self::Error> {
339        // Load reader if not already loaded
340        if self.reader.is_none() {
341            let reader = match &self.mode {
342                #[cfg(feature = "std")]
343                StoreMode::File(path) => {
344                    // Process path with auto-extension logic
345                    let final_path = self.process_path(path);
346                    BurnpackReader::from_file(&final_path)?
347                }
348                StoreMode::Bytes(Some(bytes)) => BurnpackReader::from_bytes(bytes.clone())?,
349                StoreMode::Bytes(None) => {
350                    return Err(BurnpackError::IoError("No bytes to read from".into()));
351                }
352            };
353            self.reader = Some(reader);
354        }
355
356        let reader = self
357            .reader
358            .as_ref()
359            .ok_or_else(|| BurnpackError::IoError("Reader not initialized".into()))?;
360
361        // Get all snapshots at once for efficient loading
362        #[cfg(feature = "std")]
363        let snapshots = if !self.remapper.patterns.is_empty() {
364            let (remapped, _remapped_names) = self.remapper.remap(reader.get_snapshots()?);
365            // TODO figure what to do with remapped names
366            remapped
367        } else {
368            reader.get_snapshots()?
369        };
370
371        #[cfg(not(feature = "std"))]
372        let snapshots = reader.get_snapshots()?;
373
374        // Apply all snapshots at once to the module
375        let result = module.apply(snapshots, self.filter.clone(), None);
376
377        // Validate if needed
378        if self.validate && !result.errors.is_empty() {
379            return Err(BurnpackError::ValidationError(format!(
380                "Import errors: {:?}",
381                result.errors
382            )));
383        }
384
385        // Check for missing tensors if partial loading is not allowed
386        if !self.allow_partial && !result.missing.is_empty() {
387            return Err(BurnpackError::ValidationError(format!(
388                "Missing tensors: {:?}",
389                result.missing
390            )));
391        }
392
393        Ok(result)
394    }
395}