burn_store/traits.rs
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;

use super::applier::Applier;
use super::apply_result::ApplyResult;
use crate::collector::Collector;
use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
use burn_core::module::Module;
use burn_tensor::backend::Backend;

/// Extension trait for modules that provides tensor storage functionality.
///
/// This trait provides convenient methods to collect and apply tensor snapshots from any Burn module.
/// Collection operations create lightweight tensor snapshots without immediately copying data.
/// Apply operations write tensor data from snapshots into the corresponding tensors in the module.
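///
/// # Examples
///
/// A minimal round-trip sketch (hypothetical `model` and `other_model` values; any
/// Burn module works, since a blanket implementation is provided at the bottom of
/// this file):
///
/// ```rust,ignore
/// // Snapshot every tensor in one module...
/// let snapshots = model.collect(None, None, false);
/// // ...and apply them to another module with the same structure.
/// let result = other_model.apply(snapshots, None, None, false);
/// ```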
pub trait ModuleSnapshot<B: Backend>: Module<B> {
    /// Collects tensor snapshots for inspection without copying data.
    ///
    /// Returns a vector of `TensorSnapshot` objects that can lazily materialize the tensor data.
    /// Each `TensorSnapshot` contains the full path accessible via `snapshot.full_path()`.
    ///
    /// # Arguments
    ///
    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.
    ///   When `None`, all tensors are collected.
    /// * `adapter` - Optional adapter to transform tensors based on container types.
    ///   Applied to all collected tensors before returning.
    /// * `skip_enum_variants` - Skip enum variant names when building paths.
    ///   When true, paths will not include enum variant names (e.g., "feature.weight"
    ///   instead of "feature.BaseConv.weight"). Useful when exporting to formats
    ///   like PyTorch/SafeTensors that don't use enum variants.
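    ///
    /// # Examples
    ///
    /// A minimal sketch (hypothetical `model` value):
    ///
    /// ```rust,ignore
    /// use burn_store::PathFilter;
    ///
    /// // Collect every tensor in the module.
    /// let snapshots = model.collect(None, None, false);
    ///
    /// // Collect only the encoder tensors.
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// let snapshots = model.collect(Some(filter), None, false);
    ///
    /// for snapshot in &snapshots {
    ///     println!("{}", snapshot.full_path());
    /// }
    /// ```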
    fn collect(
        &self,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
        skip_enum_variants: bool,
    ) -> Vec<TensorSnapshot> {
        let mut collector = Collector::new(filter, adapter, skip_enum_variants);
        self.visit(&mut collector);
        collector.into_tensors()
    }

    /// Applies tensor snapshots to the module.
    ///
    /// This is the primary method for applying tensor data from `TensorSnapshot`s
    /// to the corresponding tensors in the module. The snapshots are typically obtained
    /// from `collect()` or loaded from storage.
    ///
    /// # Arguments
    ///
    /// * `snapshots` - A vector of `TensorSnapshot` objects
    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
    ///   When `None`, all available tensors are applied.
    /// * `adapter` - Optional adapter to transform tensors based on container types
    /// * `skip_enum_variants` - Skip enum variant names when matching tensor paths
    ///
    /// # Returns
    ///
    /// An [`ApplyResult`] containing information about applied, skipped, missing,
    /// and unused tensors, as well as any errors encountered.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use burn_store::PathFilter;
    ///
    /// // Apply all tensors
    /// let result = model.apply(snapshots, None, None, false);
    ///
    /// // Apply only encoder tensors
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// let result = model.apply(snapshots, Some(filter), None, false);
    ///
    /// // Apply with complex filter
    /// let filter = PathFilter::new()
    ///     .with_regex(r"^encoder\..*")
    ///     .with_regex(r"^decoder\..*")
    ///     .with_full_path("head.weight");
    /// let result = model.apply(snapshots, Some(filter), None, false);
    ///
    /// // Apply with enum variant skipping (for PyTorch models)
    /// let result = model.apply(snapshots, None, None, true);
    /// ```
    fn apply(
        &mut self,
        snapshots: Vec<TensorSnapshot>,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
        skip_enum_variants: bool,
    ) -> ApplyResult
    where
        Self: Sized,
    {
        let mut applier = Applier::new(snapshots, filter, adapter, skip_enum_variants);

        // Use unsafe to avoid cloning the entire module, which would double the memory usage.
        // We read the module out of `self`, map it, then write the result back.
        // See https://github.com/tracel-ai/burn/issues/3754
        unsafe {
            // Move the module out of `self`, leaving `self` logically uninitialized
            // until the write below restores it; this relies on `map` not panicking.
            let module = core::ptr::read(self as *const Self);

            // Map the module to create a new one with updated tensors.
            let new_module = module.map(&mut applier);

            // Write the new module back into `self`.
            core::ptr::write(self as *mut Self, new_module);
        }

        applier.into_result()
    }

    /// Saves tensor snapshots into a [`ModuleStore`].
    ///
    /// This method allows using a `ModuleStore` implementation to handle the
    /// collection and writing logic in a configurable way.
    ///
    /// # Arguments
    ///
    /// * `store` - A mutable reference to a [`ModuleStore`] that will collect and save the tensors
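    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a SafeTensors-backed store like the one used in
    /// the `get_all_snapshots` example below:
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// model.save_into(&mut store)?;
    /// ```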
    fn save_into<P>(&self, store: &mut P) -> Result<(), P::Error>
    where
        P: ModuleStore,
    {
        store.collect_from(self)
    }

    /// Loads tensor data from a [`ModuleStore`].
    ///
    /// This method allows using a `ModuleStore` implementation to handle the
    /// loading and application logic in a configurable way.
    ///
    /// # Arguments
    ///
    /// * `store` - A mutable reference to a [`ModuleStore`] that will load and apply tensors
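    ///
    /// # Examples
    ///
    /// A minimal sketch mirroring `save_into` (store type assumed):
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// let result = model.load_from(&mut store)?;
    /// println!("Applied {} tensors", result.applied.len());
    /// ```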
    fn load_from<P>(&mut self, store: &mut P) -> Result<ApplyResult, P::Error>
    where
        P: ModuleStore,
    {
        store.apply_to(self)
    }
}

/// A trait for handling module storage operations.
///
/// `ModuleStore` provides a unified interface for saving and loading module
/// tensor data with support for various storage formats and advanced features like filtering,
/// remapping, and metadata handling.
pub trait ModuleStore {
    /// The error type that can be returned during storage operations.
    ///
    /// This should be a format-specific error type that provides detailed
    /// information about what went wrong (e.g., I/O errors, format violations,
    /// unsupported tensor types).
    type Error: core::fmt::Debug + core::fmt::Display;

    /// Collect tensor data from a module and store it to storage.
    ///
    /// This method traverses the module structure, collects all tensor data
    /// according to the store's configuration (filters, remapping, etc.),
    /// and writes it to the underlying storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to collect tensor data from. The module must
    ///   implement `ModuleSnapshot` to provide tensor access.
    ///
    /// # Returns
    ///
    /// * `Ok(())` - If all tensors were successfully collected and stored
    /// * `Err(Self::Error)` - If an error occurred during collection or writing
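    ///
    /// # Example
    ///
    /// A sketch of direct use; this is what `ModuleSnapshot::save_into` delegates to:
    ///
    /// ```rust,ignore
    /// store.collect_from(&model)?;
    /// ```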
    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &M,
    ) -> Result<(), Self::Error>;

    /// Load stored tensor data and apply it to a module.
    ///
    /// This method reads tensor data from storage and applies it to the provided
    /// module. The operation is flexible and can handle partial matches, missing
    /// tensors, and extra tensors in the storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to apply tensor data to. The module must
    ///   implement `ModuleSnapshot` to allow tensor updates.
    ///
    /// # Returns
    ///
    /// * `Ok(ApplyResult)` - Detailed information about the apply operation:
    ///   - `applied`: List of successfully applied tensor names
    ///   - `missing`: Tensors expected by the module but not found in storage
    ///   - `skipped`: Tensors in storage that were not applied (filtered or not needed)
    ///   - `errors`: Non-critical errors that occurred during apply
    /// * `Err(Self::Error)` - If a critical error prevented the apply operation
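    ///
    /// # Example
    ///
    /// A sketch of inspecting the result (field names as listed above):
    ///
    /// ```rust,ignore
    /// let result = store.apply_to(&mut model)?;
    /// if !result.missing.is_empty() {
    ///     println!("Missing tensors: {:?}", result.missing);
    /// }
    /// ```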
    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<ApplyResult, Self::Error>;

    /// Get a single tensor snapshot by name.
    ///
    /// This method provides direct access to individual tensors in storage without
    /// requiring a module. The returned `TensorSnapshot` uses lazy loading - tensor
    /// data is only materialized when `to_data()` is called.
    ///
    /// **Note:** Key remapping is applied, so use the remapped name if configured.
    /// Filters are NOT applied - use `apply_to()` for filtered loading.
    ///
    /// Results are cached after the first call for efficient repeated access.
    ///
    /// # Arguments
    ///
    /// * `name` - The tensor name/path (e.g., "encoder.layer1.weight")
    ///
    /// # Returns
    ///
    /// * `Ok(Some(&TensorSnapshot))` - Reference to the tensor snapshot if found
    /// * `Ok(None)` - If no tensor with that name exists
    /// * `Err(Self::Error)` - If an error occurred accessing storage
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let mut store = BurnpackStore::from_file("model.bpk");
    /// if let Some(snapshot) = store.get_snapshot("encoder.weight")? {
    ///     println!("Shape: {:?}", snapshot.shape);
    ///     println!("Dtype: {:?}", snapshot.dtype);
    ///     let data = snapshot.to_data()?; // Lazy load
    /// }
    /// ```
    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error>;

    /// Get all tensor snapshots from storage as an ordered map.
    ///
    /// This method returns all tensors in storage as lazy-loading snapshots,
    /// organized in a `BTreeMap` for efficient lookup by name. The map preserves
    /// alphabetical ordering of tensor names.
    ///
    /// **Note:** This returns ALL tensors in storage, regardless of any filter
    /// settings. Filters are only applied during `apply_to()`. Key remapping
    /// IS applied, so tensor names reflect any configured remapping.
    ///
    /// Results are cached after the first call for efficient repeated access.
    ///
    /// # Returns
    ///
    /// * `Ok(&BTreeMap<String, TensorSnapshot>)` - Reference to all tensor snapshots
    /// * `Err(Self::Error)` - If an error occurred accessing storage
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let mut store = SafetensorsStore::from_file("model.safetensors");
    /// let snapshots = store.get_all_snapshots()?;
    /// for (name, snapshot) in snapshots {
    ///     println!("{}: {:?}", name, snapshot.shape);
    /// }
    /// ```
    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error>;

    /// Get all tensor names/keys in storage.
    ///
    /// This method returns the names of all tensors in storage.
    /// Useful for inspecting storage contents or checking if specific tensors exist.
    ///
    /// **Note:** Returns ALL tensor names regardless of filter settings.
    /// Key remapping IS applied, so names reflect any configured remapping.
    ///
    /// # Returns
    ///
    /// * `Ok(Vec<String>)` - All tensor names in storage
    /// * `Err(Self::Error)` - If an error occurred accessing storage
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let mut store = PytorchStore::from_file("model.pth");
    /// let keys = store.keys()?;
    /// println!("Tensors in file: {:?}", keys);
    /// ```
    fn keys(&mut self) -> Result<Vec<String>, Self::Error>;
}

// Blanket implementation for all modules
impl<B: Backend, M: Module<B>> ModuleSnapshot<B> for M {}