burn_store/
applier.rs

1//! Applier that correctly applies tensor snapshots with adapter support
2
3use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::{String, ToString};
6use alloc::vec::Vec;
7
8use hashbrown::{HashMap, HashSet};
9
10use burn_core::module::{ModuleMapper, Param};
11use burn_tensor::{Bool, DType, Int, Shape, Tensor, backend::Backend};
12
13use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
14
15/// Error types that can occur during tensor application
16#[derive(Debug, Clone)]
17pub enum ApplyError {
18    /// Shape mismatch between expected and actual tensor
19    ShapeMismatch {
20        /// Path of the tensor
21        path: String,
22        /// Expected shape
23        expected: Vec<usize>,
24        /// Found shape
25        found: Vec<usize>,
26    },
27    /// Data type mismatch between expected and actual tensor
28    DTypeMismatch {
29        /// Path of the tensor
30        path: String,
31        /// Expected data type
32        expected: DType,
33        /// Found data type
34        found: DType,
35    },
36    /// Error from adapter transformation
37    AdapterError {
38        /// Path of the tensor
39        path: String,
40        /// Error message
41        message: String,
42    },
43    /// Error loading tensor data
44    LoadError {
45        /// Path of the tensor
46        path: String,
47        /// Error message
48        message: String,
49    },
50}
51
52impl core::fmt::Display for ApplyError {
53    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
54        match self {
55            Self::ShapeMismatch {
56                path,
57                expected,
58                found,
59            } => {
60                write!(
61                    f,
62                    "Shape mismatch for '{}': expected {:?}, found {:?}",
63                    path, expected, found
64                )
65            }
66            Self::DTypeMismatch {
67                path,
68                expected,
69                found,
70            } => {
71                write!(
72                    f,
73                    "DType mismatch for '{}': expected {:?}, found {:?}",
74                    path, expected, found
75                )
76            }
77            Self::AdapterError { path, message } => {
78                write!(f, "Adapter error for '{}': {}", path, message)
79            }
80            Self::LoadError { path, message } => {
81                write!(f, "Load error for '{}': {}", path, message)
82            }
83        }
84    }
85}
86
87impl core::error::Error for ApplyError {}
88
89/// Result of applying tensor snapshots to a module
90#[derive(Debug, Clone)]
91pub struct ApplyResult {
92    /// Successfully applied tensor paths
93    pub applied: Vec<String>,
94    /// Skipped tensor paths (due to filter)
95    pub skipped: Vec<String>,
96    /// Missing tensor paths (in module but not in snapshots)
97    pub missing: Vec<String>,
98    /// Unused tensor paths (in snapshots but not in module)
99    pub unused: Vec<String>,
100    /// Errors encountered during application
101    pub errors: Vec<ApplyError>,
102}
103
104impl ApplyResult {
105    /// Check if the apply operation was successful (no errors)
106    /// Note: Missing tensors are not considered errors when allow_partial is true
107    pub fn is_success(&self) -> bool {
108        self.errors.is_empty()
109    }
110}
111
112/// Applier that applies tensor snapshots to module parameters
113/// with proper adapter support using container type information
114pub struct Applier<B: Backend> {
115    /// Map of tensor paths to their snapshots
116    snapshots: HashMap<String, TensorSnapshot>,
117    /// Current path in the module hierarchy
118    path_stack: Vec<String>,
119    /// Current container type stack in the module hierarchy
120    container_stack: Vec<String>,
121    /// Optional filter for selective application
122    filter: Option<PathFilter>,
123    /// Optional adapter to transform tensors based on container types
124    adapter: Option<Box<dyn ModuleAdapter>>,
125    /// Successfully applied tensor paths
126    applied: Vec<String>,
127    /// Skipped tensor paths
128    skipped: HashSet<String>,
129    /// Errors encountered during application
130    errors: Vec<ApplyError>,
131    /// Track visited paths to find missing tensors
132    visited_paths: HashSet<String>,
133    /// Phantom data for backend type
134    _backend: core::marker::PhantomData<B>,
135}
136
137impl<B: Backend> Applier<B> {
138    /// Create a new applier with snapshots, optional filter, and optional adapter
139    ///
140    /// # Arguments
141    ///
142    /// * `views` - A vector of TensorSnapshot objects to apply
143    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
144    ///   When `None`, all available tensors are applied.
145    /// * `adapter` - Optional adapter to transform tensors based on container types
146    pub fn new(
147        views: Vec<TensorSnapshot>,
148        filter: Option<PathFilter>,
149        adapter: Option<Box<dyn ModuleAdapter>>,
150    ) -> Self {
151        let views_map: HashMap<String, TensorSnapshot> = views
152            .into_iter()
153            .map(|view| (view.full_path(), view))
154            .collect();
155
156        Self {
157            snapshots: views_map,
158            path_stack: Vec::new(),
159            container_stack: Vec::new(),
160            filter,
161            adapter,
162            applied: Vec::new(),
163            skipped: HashSet::new(),
164            errors: Vec::new(),
165            visited_paths: HashSet::new(),
166            _backend: core::marker::PhantomData,
167        }
168    }
169
170    /// Get the current path in the module hierarchy
171    fn current_path(&self) -> String {
172        self.path_stack.join(".")
173    }
174
175    /// Check if a tensor should be applied based on filter
176    fn should_apply(&self) -> bool {
177        match &self.filter {
178            None => true,
179            Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack),
180        }
181    }
182
183    /// Apply adapter to a snapshot using current container information
184    fn adapt_snapshot(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
185        if let Some(ref adapter) = self.adapter {
186            // Create a snapshot with proper container information from module traversal
187            let snapshot_with_context = TensorSnapshot::from_closure(
188                snapshot.clone_data_fn(),
189                snapshot.dtype,
190                snapshot.shape.clone(),
191                self.path_stack.clone(), // Use current path from traversal
192                self.container_stack.clone(), // Use current container types!
193                snapshot.tensor_id.unwrap_or_default(),
194            );
195
196            // Apply adapter with full context
197            return adapter.adapt(&snapshot_with_context);
198        }
199        snapshot.clone()
200    }
201
202    /// Convert the applier into a result
203    pub fn into_result(self) -> ApplyResult {
204        let unused: Vec<String> = self
205            .snapshots
206            .keys()
207            .filter(|path| !self.visited_paths.contains(*path) && !self.skipped.contains(*path))
208            .cloned()
209            .collect();
210
211        let missing: Vec<String> = self
212            .visited_paths
213            .into_iter()
214            .filter(|p| !self.snapshots.contains_key(p) && !self.skipped.contains(p))
215            .collect();
216
217        ApplyResult {
218            applied: self.applied,
219            skipped: self.skipped.into_iter().collect(),
220            missing,
221            unused,
222            errors: self.errors,
223        }
224    }
225
226    /// Apply a tensor snapshot with shape validation
227    /// Returns None if snapshot not found, filtered, or validation fails
228    fn apply_tensor<const D: usize, K>(
229        &mut self,
230        target_device: &B::Device,
231        target_shape: Shape,
232    ) -> Option<Tensor<B, D, K>>
233    where
234        K: burn_tensor::TensorKind<B>,
235        K: burn_tensor::BasicOps<B>,
236    {
237        let path = self.current_path();
238        self.visited_paths.insert(path.clone());
239
240        // Check if we have a snapshot for this path
241        let snapshot = match self.snapshots.get(&path) {
242            Some(s) => s,
243            None => {
244                // No snapshot available - signal caller not to apply
245                return None;
246            }
247        };
248
249        // Check if we should apply based on filter
250        if !self.should_apply() {
251            self.skipped.insert(path.clone());
252            return None;
253        }
254
255        // Apply adapter with current container context
256        let adapted_snapshot = self.adapt_snapshot(snapshot);
257        let data = match adapted_snapshot.to_data() {
258            Ok(data) => data,
259            Err(e) => {
260                self.errors.push(ApplyError::LoadError {
261                    path: path.clone(),
262                    message: format!("Failed to load tensor data: {:?}", e),
263                });
264                return None; // Signal caller to fall back to initialization
265            }
266        };
267
268        // Validate shape
269        if data.shape != target_shape.dims {
270            self.errors.push(ApplyError::ShapeMismatch {
271                path: path.clone(),
272                expected: target_shape.dims,
273                found: data.shape.clone(),
274            });
275            return None; // Signal caller to fall back to initialization
276        }
277
278        self.applied.push(path);
279        Some(Tensor::from_data(data, target_device))
280    }
281}
282
283impl<B: Backend> ModuleMapper<B> for Applier<B> {
284    fn enter_module(&mut self, name: &str, container_type: &str) {
285        self.path_stack.push(name.to_string());
286        self.container_stack.push(container_type.to_string());
287    }
288
289    fn exit_module(&mut self, _name: &str, _container_type: &str) {
290        self.path_stack.pop();
291        self.container_stack.pop();
292    }
293
294    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
295        let param_id = param.id;
296        let target_device = param.lazy_device();
297        let target_shape = param.lazy_shape();
298
299        // Try to apply snapshot with shape validation
300        match self.apply_tensor(&target_device, target_shape) {
301            Some(tensor) => {
302                // We have a tensor to apply - load it
303                param.transform_for_load(tensor, param_id)
304            }
305            None => {
306                // No snapshot, filtered, or validation failed - return param unchanged
307                param
308            }
309        }
310    }
311
312    fn map_int<const D: usize>(
313        &mut self,
314        param: Param<Tensor<B, D, Int>>,
315    ) -> Param<Tensor<B, D, Int>> {
316        let param_id = param.id;
317        let target_device = param.lazy_device();
318        let target_shape = param.lazy_shape();
319
320        // Try to apply snapshot with shape validation
321        match self.apply_tensor(&target_device, target_shape) {
322            Some(tensor) => {
323                // We have a tensor to apply - load it
324                param.transform_for_load(tensor, param_id)
325            }
326            None => {
327                // No snapshot, filtered, or validation failed - return param unchanged
328                param
329            }
330        }
331    }
332
333    fn map_bool<const D: usize>(
334        &mut self,
335        param: Param<Tensor<B, D, Bool>>,
336    ) -> Param<Tensor<B, D, Bool>> {
337        let param_id = param.id;
338        let target_device = param.lazy_device();
339        let target_shape = param.lazy_shape();
340
341        // Try to apply snapshot with shape validation
342        match self.apply_tensor(&target_device, target_shape) {
343            Some(tensor) => {
344                // We have a tensor to apply - load it
345                param.transform_for_load(tensor, param_id)
346            }
347            None => {
348                // No snapshot, filtered, or validation failed - return param unchanged
349                param
350            }
351        }
352    }
353}
354
#[cfg(all(test, feature = "std", target_has_atomic = "ptr"))]
mod tests {
    use super::*;
    use burn_core::module::{ModuleMapper, Param, ParamId};
    use burn_tensor::Tensor;

    type TestBackend = burn_ndarray::NdArray;

    /// Parameters that live directly at the root (single-segment paths, no
    /// enclosing container) must be found and loaded by the applier.
    #[test]
    fn root_level_parameters() {
        let device = Default::default();

        // Source parameters whose values are captured as snapshots.
        let source_weight =
            Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let source_bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);

        // Snapshots use a one-element path and an empty container stack.
        let snapshots = vec![
            crate::TensorSnapshot::from_data(
                source_weight.val().to_data(),
                vec!["weight".to_string()],
                vec![],
                ParamId::new(),
            ),
            crate::TensorSnapshot::from_data(
                source_bias.val().to_data(),
                vec!["bias".to_string()],
                vec![],
                ParamId::new(),
            ),
        ];

        let mut applier = Applier::<TestBackend>::new(snapshots, None, None);

        // Zero-initialised targets that should be overwritten by the snapshots.
        let zero_weight = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 2>::zeros([2, 2], &device),
        );
        let zero_bias = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 1>::zeros([2], &device),
        );

        // Simulate the traversal of a field named "weight".
        applier.enter_module("weight", "");
        let loaded_weight = applier.map_float(zero_weight);
        applier.exit_module("weight", "");

        // Simulate the traversal of a field named "bias".
        applier.enter_module("bias", "");
        let loaded_bias = applier.map_float(zero_bias);
        applier.exit_module("bias", "");

        // The loaded parameters must carry the snapshot values.
        let weight_values = loaded_weight.val().to_data().to_vec::<f32>().unwrap();
        assert_eq!(weight_values, vec![1.0, 2.0, 3.0, 4.0]);

        let bias_values = loaded_bias.val().to_data().to_vec::<f32>().unwrap();
        assert_eq!(bias_values, vec![5.0, 6.0]);

        // Both tensors are reported as applied, without errors.
        let result = applier.into_result();
        assert_eq!(result.applied.len(), 2);
        assert_eq!(result.errors.len(), 0);
    }
}