burn_store/traits.rs
use alloc::boxed::Box;
use alloc::vec::Vec;

use super::applier::{Applier, ApplyResult};
use crate::collector::Collector;
use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
use burn_core::module::Module;
use burn_tensor::backend::Backend;

/// Extension trait for modules that provides tensor storage functionality.
///
/// This trait provides convenient methods to collect tensor snapshots from, and apply them to, any Burn module.
/// Collection operations create lightweight tensor snapshots without immediately copying data.
/// Apply operations write tensor data from snapshots into the corresponding tensors in the module.
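///
/// # Examples
///
/// A minimal sketch of copying weights between two models of the same
/// architecture, assuming `source` and `target` are such modules:
///
/// ```rust,ignore
/// // Snapshot every tensor of the source model (lazy, no data copied yet)...
/// let snapshots = source.collect(None, None);
///
/// // ...and apply the snapshots onto the target model. `result` reports which
/// // tensors were applied, skipped, missing, or produced errors.
/// let result = target.apply(snapshots, None, None);
/// ```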
pub trait ModuleSnapshot<B: Backend>: Module<B> {
    /// Collects tensor snapshots for inspection without copying data.
    ///
    /// Returns a vector of `TensorSnapshot` objects that can lazily materialize the tensor data.
    /// Each `TensorSnapshot` contains the full path accessible via `snapshot.full_path()`.
    ///
    /// # Arguments
    ///
    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.
    ///   When `None`, all tensors are collected.
    /// * `adapter` - Optional adapter to transform tensors based on container types.
    ///   Applied to all collected tensors before returning.
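    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming `model` is any value implementing [`Module`]:
    ///
    /// ```rust,ignore
    /// use burn_store::PathFilter;
    ///
    /// // Collect every tensor in the module.
    /// let snapshots = model.collect(None, None);
    ///
    /// // Collect only the encoder tensors.
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// let encoder_snapshots = model.collect(Some(filter), None);
    /// ```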
    fn collect(
        &self,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
    ) -> Vec<TensorSnapshot> {
        let mut collector = Collector::new(filter, adapter);
        self.visit(&mut collector);
        collector.into_tensors()
    }

    /// Applies tensor snapshots to the module.
    ///
    /// This is the primary apply method that applies tensor data from `TensorSnapshot`s
    /// to the corresponding tensors in the module. The snapshots are typically obtained
    /// from `collect()` or loaded from storage.
    ///
    /// # Arguments
    ///
    /// * `snapshots` - A vector of `TensorSnapshot` objects
    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
    ///   When `None`, all available tensors are applied.
    /// * `adapter` - Optional adapter to transform tensors based on container types
    ///
    /// # Returns
    ///
    /// An [`ApplyResult`] containing information about applied, skipped, missing,
    /// and unused tensors, as well as any errors encountered.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use burn_store::PathFilter;
    ///
    /// // Apply all tensors
    /// let result = model.apply(snapshots, None, None);
    ///
    /// // Apply only encoder tensors
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// let result = model.apply(snapshots, Some(filter), None);
    ///
    /// // Apply with complex filter
    /// let filter = PathFilter::new()
    ///     .with_regex(r"^encoder\..*")
    ///     .with_regex(r"^decoder\..*")
    ///     .with_full_path("head.weight");
    /// let result = model.apply(snapshots, Some(filter), None);
    /// ```
    fn apply(
        &mut self,
        snapshots: Vec<TensorSnapshot>,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
    ) -> ApplyResult
    where
        Self: Sized,
    {
        let mut applier = Applier::new(snapshots, filter, adapter);

        // Use unsafe to avoid cloning the entire module, which would double the memory usage:
        // we read the module out, map it, then write it back.
        // See https://github.com/tracel-ai/burn/issues/3754
        unsafe {
            // Read the module out of self (a bitwise move; `self` must not be used
            // or dropped until it is overwritten below).
            let module = core::ptr::read(self as *const Self);

            // Map the module to create a new one with updated tensors.
            let new_module = module.map(&mut applier);

            // Write the new module back to self.
            core::ptr::write(self as *mut Self, new_module);
        }

        applier.into_result()
    }

    /// Saves tensor snapshots into a [`ModuleStore`].
    ///
    /// This method allows using a `ModuleStore` implementation to handle the
    /// collection and writing logic in a configurable way.
    ///
    /// # Arguments
    ///
    /// * `store` - A mutable reference to a [`ModuleStore`] that will collect and save the tensors
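    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming `store` is some concrete [`ModuleStore`]
    /// implementation (for example a file-backed store) configured elsewhere:
    ///
    /// ```rust,ignore
    /// // Collect the module's tensors and write them through the store.
    /// model.save_into(&mut store)?;
    /// ```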
    fn save_into<P>(&self, store: &mut P) -> Result<(), P::Error>
    where
        P: ModuleStore,
    {
        store.collect_from(self)
    }

    /// Loads tensor data from a [`ModuleStore`].
    ///
    /// This method allows using a `ModuleStore` implementation to handle the
    /// loading and application logic in a configurable way.
    ///
    /// # Arguments
    ///
    /// * `store` - A mutable reference to a [`ModuleStore`] that will load and apply tensors
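    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming `store` is a configured [`ModuleStore`] implementation:
    ///
    /// ```rust,ignore
    /// // Load tensors from the store and apply them to the module. The returned
    /// // `ApplyResult` reports applied, missing, and skipped tensors plus any
    /// // non-critical errors.
    /// let result = model.load_from(&mut store)?;
    /// ```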
    fn load_from<P>(&mut self, store: &mut P) -> Result<ApplyResult, P::Error>
    where
        P: ModuleStore,
    {
        store.apply_to(self)
    }
}

/// A trait for handling module storage operations.
///
/// `ModuleStore` provides a unified interface for saving and loading module
/// tensor data with support for various storage formats and advanced features like filtering,
/// remapping, and metadata handling.
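///
/// # Examples
///
/// A minimal in-memory implementation sketch. `InMemoryStore` is illustrative,
/// not part of this crate, and the `burn_store::` import paths assume these
/// types are re-exported at the crate root:
///
/// ```rust,ignore
/// use alloc::vec::Vec;
///
/// use burn_store::{ApplyResult, ModuleSnapshot, ModuleStore, TensorSnapshot};
/// use burn_tensor::backend::Backend;
///
/// /// Keeps collected snapshots in memory instead of writing them to disk.
/// #[derive(Default)]
/// struct InMemoryStore {
///     snapshots: Vec<TensorSnapshot>,
/// }
///
/// impl ModuleStore for InMemoryStore {
///     // This toy store cannot fail.
///     type Error = core::convert::Infallible;
///
///     fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
///         &mut self,
///         module: &M,
///     ) -> Result<(), Self::Error> {
///         // No filter or adapter: keep every tensor snapshot as-is.
///         self.snapshots = module.collect(None, None);
///         Ok(())
///     }
///
///     fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
///         &mut self,
///         module: &mut M,
///     ) -> Result<ApplyResult, Self::Error> {
///         // Move the stored snapshots out and apply them to the module.
///         let snapshots = core::mem::take(&mut self.snapshots);
///         Ok(module.apply(snapshots, None, None))
///     }
/// }
/// ```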
pub trait ModuleStore {
    /// The error type that can be returned during storage operations.
    ///
    /// This should be a format-specific error type that provides detailed
    /// information about what went wrong (e.g., I/O errors, format violations,
    /// unsupported tensor types).
    type Error: core::fmt::Debug + core::fmt::Display;

    /// Collect tensor data from a module and write it to storage.
    ///
    /// This method traverses the module structure, collects all tensor data
    /// according to the store's configuration (filters, remapping, etc.),
    /// and writes it to the underlying storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to collect tensor data from. The module must
    ///   implement `ModuleSnapshot` to provide tensor access.
    ///
    /// # Returns
    ///
    /// * `Ok(())` - If all tensors were successfully collected and stored
    /// * `Err(Self::Error)` - If an error occurred during collection or writing
    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &M,
    ) -> Result<(), Self::Error>;

    /// Load stored tensor data and apply it to a module.
    ///
    /// This method reads tensor data from storage and applies it to the provided
    /// module. The operation is flexible and can handle partial matches, missing
    /// tensors, and extra tensors in the storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to apply tensor data to. The module must
    ///   implement `ModuleSnapshot` to allow tensor updates.
    ///
    /// # Returns
    ///
    /// * `Ok(ApplyResult)` - Detailed information about the apply operation:
    ///   - `applied`: List of successfully applied tensor names
    ///   - `missing`: Tensors expected by the module but not found in storage
    ///   - `skipped`: Tensors in storage that were not applied (filtered or not needed)
    ///   - `errors`: Non-critical errors that occurred during apply
    /// * `Err(Self::Error)` - If a critical error prevented the apply operation
    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<ApplyResult, Self::Error>;
}

// Blanket implementation for all modules.
impl<B: Backend, M: Module<B>> ModuleSnapshot<B> for M {}