//! Traits for collecting tensor snapshots from Burn modules and applying
//! them back, plus the `ModuleStore` storage abstraction (burn_store/traits.rs).

1use alloc::boxed::Box;
2use alloc::vec::Vec;
3
4use super::applier::{Applier, ApplyResult};
5use crate::collector::Collector;
6use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
7use burn_core::module::Module;
8use burn_tensor::backend::Backend;
9
10/// Extension trait for modules that provides tensor storage functionality.
11///
12/// This trait provides convenient methods to collect and apply tensor snapshots from any Burn module.
13/// Collection operations create lightweight tensor snapshots without immediately copying data.
14/// Apply operations apply tensor data from snapshots to the corresponding tensors in the module.
15pub trait ModuleSnapshot<B: Backend>: Module<B> {
16    /// Collects tensor snapshots for inspection without copying data.
17    ///
18    /// Returns a vector of `TensorSnapshot` objects that can lazily materialize the tensor data.
19    /// Each `TensorSnapshot` contains the full path accessible via `snapshot.full_path()`.
20    ///
21    /// # Arguments
22    ///
23    /// * `filter` - An optional [`PathFilter`] to determine which tensors to collect.
24    ///   When `None`, all tensors are collected.
25    /// * `adapter` - Optional adapter to transform tensors based on container types.
26    ///   Applied to all collected tensors before returning.
27    fn collect(
28        &self,
29        filter: Option<PathFilter>,
30        adapter: Option<Box<dyn ModuleAdapter>>,
31    ) -> Vec<TensorSnapshot> {
32        let mut collector = Collector::new(filter, adapter);
33        self.visit(&mut collector);
34        collector.into_tensors()
35    }
36
37    /// Applies tensor snapshots to the module.
38    ///
39    /// This is the primary apply method that applies tensor data from `TensorSnapshot`s
40    /// to the corresponding tensors in the module. The snapshots are typically obtained
41    /// from `collect()` or loaded from storage.
42    ///
43    /// # Arguments
44    ///
45    /// * `snapshots` - A vector of TensorSnapshot objects
46    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
47    ///   When `None`, all available tensors are applied.
48    /// * `adapter` - Optional adapter to transform tensors based on container types
49    ///
50    /// # Returns
51    ///
52    /// An [`ApplyResult`] containing information about applied, skipped, missing,
53    /// and unused tensors, as well as any errors encountered.
54    ///
55    /// # Examples
56    ///
57    /// ```rust,ignore
58    /// use burn_store::PathFilter;
59    ///
60    /// // Apply all tensors
61    /// let result = model.apply(snapshots, None, None);
62    ///
63    /// // Apply only encoder tensors
64    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
65    /// let result = model.apply(snapshots, Some(filter), None);
66    ///
67    /// // Apply with complex filter
68    /// let filter = PathFilter::new()
69    ///     .with_regex(r"^encoder\..*")
70    ///     .with_regex(r"^decoder\..*")
71    ///     .with_full_path("head.weight");
72    /// let result = model.apply(snapshots, Some(filter), None);
73    /// ```
74    fn apply(
75        &mut self,
76        snapshots: Vec<TensorSnapshot>,
77        filter: Option<PathFilter>,
78        adapter: Option<Box<dyn ModuleAdapter>>,
79    ) -> ApplyResult
80    where
81        Self: Sized,
82    {
83        let mut applier = Applier::new(snapshots, filter, adapter);
84
85        // Use unsafe to avoid cloning the entire module, which would double the memory usage
86        // We read the module out, map it, then write it back
87        // See https://github.com/tracel-ai/burn/issues/3754
88        unsafe {
89            // Read the module out of self (moves it, leaving self in undefined state)
90            let module = core::ptr::read(self as *const Self);
91
92            // Map the module to create a new one with updated tensors
93            let new_module = module.map(&mut applier);
94
95            // Write the new module back to self
96            core::ptr::write(self as *mut Self, new_module);
97        }
98
99        applier.into_result()
100    }
101
102    /// Saves tensor snapshots into a [`ModuleStore`].
103    ///
104    /// This method allows using a `ModuleStore` implementation to handle the
105    /// collection and writing logic in a configurable way.
106    ///
107    /// # Arguments
108    ///
109    /// * `store` - A mutable reference to a [`ModuleStore`] that will collect and save the tensors
110    fn save_into<P>(&self, store: &mut P) -> Result<(), P::Error>
111    where
112        P: ModuleStore,
113    {
114        store.collect_from(self)
115    }
116
117    /// Loads tensor data from a [`ModuleStore`].
118    ///
119    /// This method allows using a `ModuleStore` implementation to handle the
120    /// loading and application logic in a configurable way.
121    ///
122    /// # Arguments
123    ///
124    /// * `store` - A mutable reference to a [`ModuleStore`] that will load and apply tensors
125    fn load_from<P>(&mut self, store: &mut P) -> Result<ApplyResult, P::Error>
126    where
127        P: ModuleStore,
128    {
129        store.apply_to(self)
130    }
131}
132
/// A trait for handling module storage operations.
///
/// `ModuleStore` provides a unified interface for saving and loading module
/// tensor data with support for various storage formats and advanced features like filtering,
/// remapping, and metadata handling.
///
/// Implementations are driven through [`ModuleSnapshot::save_into`] and
/// [`ModuleSnapshot::load_from`], which delegate to `collect_from` and
/// `apply_to` respectively.
pub trait ModuleStore {
    /// The error type that can be returned during storage operations.
    ///
    /// This should be a format-specific error type that provides detailed
    /// information about what went wrong (e.g., I/O errors, format violations,
    /// unsupported tensor types). The `Debug + Display` bounds let callers
    /// log or surface errors without knowing the concrete type.
    type Error: core::fmt::Debug + core::fmt::Display;

    /// Collect tensor data from a module and store it to storage.
    ///
    /// This method traverses the module structure, collects all tensor data
    /// according to the store's configuration (filters, remapping, etc.),
    /// and writes it to the underlying storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to collect tensor data from. The module must
    ///   implement `ModuleSnapshot` to provide tensor access.
    ///
    /// # Returns
    ///
    /// * `Ok(())` - If all tensors were successfully collected and stored
    /// * `Err(Self::Error)` - If an error occurred during collection or writing
    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &M,
    ) -> Result<(), Self::Error>;

    /// Load stored tensor data and apply it to a module.
    ///
    /// This method reads tensor data from storage and applies it to the provided
    /// module. The operation is flexible and can handle partial matches, missing
    /// tensors, and extra tensors in the storage.
    ///
    /// # Arguments
    ///
    /// * `module` - The module to apply tensor data to. The module must
    ///   implement `ModuleSnapshot` to allow tensor updates.
    ///
    /// # Returns
    ///
    /// * `Ok(ApplyResult)` - Detailed information about the apply operation:
    ///   - `applied`: List of successfully applied tensor names
    ///   - `missing`: Tensors expected by the module but not found in storage
    ///   - `skipped`: Tensors in storage that were not applied (filtered or not needed)
    ///   - `errors`: Non-critical errors that occurred during apply
    /// * `Err(Self::Error)` - If a critical error prevented the apply operation
    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<ApplyResult, Self::Error>;
}
190
// Blanket implementation: every type implementing `Module<B>` automatically
// gains the `ModuleSnapshot` API, since all trait methods have default bodies.
impl<B: Backend, M: Module<B>> ModuleSnapshot<B> for M {}