burn_store/traits.rs

use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;

use super::applier::Applier;
use super::apply_result::ApplyResult;
use crate::collector::Collector;
use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
use burn_core::module::Module;
use burn_tensor::backend::Backend;

/// Extension trait for modules that provides tensor storage functionality.
///
/// This trait provides convenient methods to collect tensor snapshots from, and apply them to, any Burn module.
/// Collection operations create lightweight tensor snapshots without immediately copying data.
/// Apply operations write tensor data from snapshots into the corresponding tensors in the module.
pub trait ModuleSnapshot<B: Backend>: Module<B> {
    /// Collects tensor snapshots for inspection without copying data.
    ///
    /// Returns a vector of `TensorSnapshot` objects that can lazily materialize the tensor data.
    /// Each `TensorSnapshot` contains the full path accessible via `snapshot.full_path()`.
    ///
    /// # Arguments
    ///
    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.
    ///   When `None`, all tensors are collected.
    /// * `adapter` - Optional adapter to transform tensors based on container types.
    ///   Applied to all collected tensors before returning.
    /// * `skip_enum_variants` - Skip enum variant names when building paths.
    ///   When true, paths will not include enum variant names (e.g., "feature.weight"
    ///   instead of "feature.BaseConv.weight"). Useful when exporting to formats
    ///   like PyTorch/SafeTensors that don't use enum variants.
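    ///
    /// # Examples
    ///
    /// A minimal sketch; `model` stands for any value whose type implements [`Module`]:
    ///
    /// ```rust,ignore
    /// use burn_store::PathFilter;
    ///
    /// // Collect every tensor in the module
    /// let snapshots = model.collect(None, None, false);
    ///
    /// // Collect only encoder tensors and inspect their paths
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// for snapshot in model.collect(Some(filter), None, false) {
    ///     println!("{}", snapshot.full_path());
    /// }
    /// ```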
    fn collect(
        &self,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
        skip_enum_variants: bool,
    ) -> Vec<TensorSnapshot> {
        let mut collector = Collector::new(filter, adapter, skip_enum_variants);
        self.visit(&mut collector);
        collector.into_tensors()
    }

    /// Applies tensor snapshots to the module.
    ///
    /// This is the primary apply method that applies tensor data from `TensorSnapshot`s
    /// to the corresponding tensors in the module. The snapshots are typically obtained
    /// from `collect()` or loaded from storage.
    ///
    /// # Arguments
    ///
    /// * `snapshots` - A vector of `TensorSnapshot` objects
    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
    ///   When `None`, all available tensors are applied.
    /// * `adapter` - Optional adapter to transform tensors based on container types
    /// * `skip_enum_variants` - Skip enum variant names when matching tensor paths
    ///
    /// # Returns
    ///
    /// An [`ApplyResult`] containing information about applied, skipped, missing,
    /// and unused tensors, as well as any errors encountered.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use burn_store::PathFilter;
    ///
    /// // Apply all tensors
    /// let result = model.apply(snapshots, None, None, false);
    ///
    /// // Apply only encoder tensors
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// let result = model.apply(snapshots, Some(filter), None, false);
    ///
    /// // Apply with a complex filter
    /// let filter = PathFilter::new()
    ///     .with_regex(r"^encoder\..*")
    ///     .with_regex(r"^decoder\..*")
    ///     .with_full_path("head.weight");
    /// let result = model.apply(snapshots, Some(filter), None, false);
    ///
    /// // Apply with enum variant skipping (for PyTorch models)
    /// let result = model.apply(snapshots, None, None, true);
    /// ```
    fn apply(
        &mut self,
        snapshots: Vec<TensorSnapshot>,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
        skip_enum_variants: bool,
    ) -> ApplyResult
    where
        Self: Sized,
    {
        let mut applier = Applier::new(snapshots, filter, adapter, skip_enum_variants);

        // Use unsafe to avoid cloning the entire module, which would double the memory usage.
        // We read the module out, map it, then write it back.
        // See https://github.com/tracel-ai/burn/issues/3754
        unsafe {
            // Read the module out of self (moves it, leaving self logically
            // uninitialized; if `map` panicked here, the module would be dropped twice)
            let module = core::ptr::read(self as *const Self);

            // Map the module to create a new one with updated tensors
            let new_module = module.map(&mut applier);

            // Write the new module back to self
            core::ptr::write(self as *mut Self, new_module);
        }

        applier.into_result()
    }

    /// Saves tensor snapshots into a [`ModuleStore`].
    ///
    /// This method allows using a `ModuleStore` implementation to handle the
    /// collection and writing logic in a configurable way.
    ///
    /// # Arguments
    ///
    /// * `store` - A mutable reference to a [`ModuleStore`] that will collect and save the tensors
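    ///
    /// # Example
    ///
    /// A minimal sketch; the concrete store type and file name are illustrative:
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// model.save_into(&mut store)?;
    /// ```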
    fn save_into<P>(&self, store: &mut P) -> Result<(), P::Error>
    where
        P: ModuleStore,
    {
        store.collect_from(self)
    }

    /// Loads tensor data from a [`ModuleStore`].
    ///
    /// This method allows using a `ModuleStore` implementation to handle the
    /// loading and application logic in a configurable way.
    ///
    /// # Arguments
    ///
    /// * `store` - A mutable reference to a [`ModuleStore`] that will load and apply tensors
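    ///
    /// # Example
    ///
    /// A minimal sketch; the store type and file name are illustrative, and
    /// `applied` is assumed to be the list of loaded tensor names described
    /// on [`ApplyResult`]:
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// let result = model.load_from(&mut store)?;
    /// println!("Applied {} tensors", result.applied.len());
    /// ```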
    fn load_from<P>(&mut self, store: &mut P) -> Result<ApplyResult, P::Error>
    where
        P: ModuleStore,
    {
        store.apply_to(self)
    }
}

/// A trait for handling module storage operations.
///
/// `ModuleStore` provides a unified interface for saving and loading module
/// tensor data with support for various storage formats and advanced features like filtering,
/// remapping, and metadata handling.
pub trait ModuleStore {
    /// The error type that can be returned during storage operations.
    ///
    /// This should be a format-specific error type that provides detailed
    /// information about what went wrong (e.g., I/O errors, format violations,
    /// unsupported tensor types).
    type Error: core::fmt::Debug + core::fmt::Display;

    /// Collect tensor data from a module and write it to storage.
    ///
    /// This method traverses the module structure, collects all tensor data
    /// according to the store's configuration (filters, remapping, etc.),
    /// and writes it to the underlying storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to collect tensor data from. The module must
    ///   implement `ModuleSnapshot` to provide tensor access.
    ///
    /// # Returns
    ///
    /// * `Ok(())` - If all tensors were successfully collected and stored
    /// * `Err(Self::Error)` - If an error occurred during collection or writing
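    ///
    /// # Example
    ///
    /// A minimal sketch of calling this directly rather than via
    /// [`ModuleSnapshot::save_into`]; the store type and file name are illustrative:
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// store.collect_from(&model)?;
    /// ```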
    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &M,
    ) -> Result<(), Self::Error>;

    /// Load stored tensor data and apply it to a module.
    ///
    /// This method reads tensor data from storage and applies it to the provided
    /// module. The operation is flexible and can handle partial matches, missing
    /// tensors, and extra tensors in the storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to apply tensor data to. The module must
    ///   implement `ModuleSnapshot` to allow tensor updates.
    ///
    /// # Returns
    ///
    /// * `Ok(ApplyResult)` - Detailed information about the apply operation:
    ///   - `applied`: List of successfully applied tensor names
    ///   - `missing`: Tensors expected by the module but not found in storage
    ///   - `skipped`: Tensors in storage that were not applied (filtered or not needed)
    ///   - `errors`: Non-critical errors that occurred during apply
    /// * `Err(Self::Error)` - If a critical error prevented the apply operation
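    ///
    /// # Example
    ///
    /// A minimal sketch; the store type and file name are illustrative, and
    /// `missing` is assumed to be the list described above:
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// let result = store.apply_to(&mut model)?;
    /// if !result.missing.is_empty() {
    ///     println!("Missing tensors: {:?}", result.missing);
    /// }
    /// ```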
    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<ApplyResult, Self::Error>;

    /// Get a single tensor snapshot by name.
    ///
    /// This method provides direct access to individual tensors in storage without
    /// requiring a module. The returned `TensorSnapshot` uses lazy loading - tensor
    /// data is only materialized when `to_data()` is called.
    ///
    /// **Note:** Key remapping is applied, so use the remapped name if configured.
    /// Filters are NOT applied - use `apply_to()` for filtered loading.
    ///
    /// Results are cached after the first call for efficient repeated access.
    ///
    /// # Arguments
    ///
    /// * `name` - The tensor name/path (e.g., "encoder.layer1.weight")
    ///
    /// # Returns
    ///
    /// * `Ok(Some(&TensorSnapshot))` - Reference to the tensor snapshot if found
    /// * `Ok(None)` - If no tensor with that name exists
    /// * `Err(Self::Error)` - If an error occurred accessing storage
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let mut store = BurnpackStore::from_file("model.bpk");
    /// if let Some(snapshot) = store.get_snapshot("encoder.weight")? {
    ///     println!("Shape: {:?}", snapshot.shape);
    ///     println!("Dtype: {:?}", snapshot.dtype);
    ///     let data = snapshot.to_data()?;  // Lazy load
    /// }
    /// ```
    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error>;

    /// Get all tensor snapshots from storage as an ordered map.
    ///
    /// This method returns all tensors in storage as lazy-loading snapshots,
    /// organized in a `BTreeMap` for efficient lookup by name. The map preserves
    /// alphabetical ordering of tensor names.
    ///
    /// **Note:** This returns ALL tensors in storage, regardless of any filter
    /// settings. Filters are only applied during `apply_to()`. Key remapping
    /// IS applied, so tensor names reflect any configured remapping.
    ///
    /// Results are cached after the first call for efficient repeated access.
    ///
    /// # Returns
    ///
    /// * `Ok(&BTreeMap<String, TensorSnapshot>)` - Reference to all tensor snapshots
    /// * `Err(Self::Error)` - If an error occurred accessing storage
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// let snapshots = store.get_all_snapshots()?;
    /// for (name, snapshot) in snapshots {
    ///     println!("{}: {:?}", name, snapshot.shape);
    /// }
    /// ```
    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error>;

    /// Get all tensor names/keys in storage.
    ///
    /// This method returns the names of all tensors in storage.
    /// Useful for inspecting storage contents or checking if specific tensors exist.
    ///
    /// **Note:** Returns ALL tensor names regardless of filter settings.
    /// Key remapping IS applied, so names reflect any configured remapping.
    ///
    /// # Returns
    ///
    /// * `Ok(Vec<String>)` - All tensor names in storage
    /// * `Err(Self::Error)` - If an error occurred accessing storage
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let mut store = PytorchStore::from_file("model.pth");
    /// let keys = store.keys()?;
    /// println!("Tensors in file: {:?}", keys);
    /// ```
    fn keys(&mut self) -> Result<Vec<String>, Self::Error>;
}

// Blanket implementation for all modules
impl<B: Backend, M: Module<B>> ModuleSnapshot<B> for M {}