burn_store/applier.rs

//! Applier that applies tensor snapshots to module parameters, with adapter support

use alloc::boxed::Box;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;

use hashbrown::{HashMap, HashSet};

use burn_core::module::{ModuleMapper, Param};
use burn_tensor::{Bool, Int, Shape, Tensor, backend::Backend};

use crate::apply_result::{ApplyError, ApplyResult};
use crate::{ModuleAdapter, PathFilter, TensorSnapshot};

/// Applier that applies tensor snapshots to module parameters
/// with proper adapter support using container type information
pub struct Applier<B: Backend> {
    /// Map of tensor paths to their snapshots
    snapshots: HashMap<String, TensorSnapshot>,
    /// Current path in the module hierarchy
    path_stack: Vec<String>,
    /// Current container type stack in the module hierarchy
    container_stack: Vec<String>,
    /// Optional filter for selective application
    filter: Option<PathFilter>,
    /// Optional adapter to transform tensors based on container types
    adapter: Option<Box<dyn ModuleAdapter>>,
    /// Successfully applied tensor paths
    applied: Vec<String>,
    /// Skipped tensor paths
    skipped: HashSet<String>,
    /// Errors encountered during application
    errors: Vec<ApplyError>,
    /// Track visited paths with their container stacks (in dot notation) to find missing tensors
    visited_paths: HashMap<String, String>,
    /// Skip enum variant names when matching paths.
    /// When true, "feature.BaseConv.weight" will also try to match "feature.weight"
    skip_enum_variants: bool,
    /// Phantom data for backend type
    _backend: core::marker::PhantomData<B>,
}

impl<B: Backend> Applier<B> {
    /// Create a new applier with snapshots, optional filter, and optional adapter
    ///
    /// # Arguments
    ///
    /// * `views` - A vector of [`TensorSnapshot`] objects to apply
    /// * `filter` - An optional [`PathFilter`] to determine which tensors to apply.
    ///   When `None`, all available tensors are applied.
    /// * `adapter` - Optional adapter to transform tensors based on container types
    /// * `skip_enum_variants` - Skip enum variant names when matching paths
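    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled; assumes `snapshots` is a
    /// `Vec<TensorSnapshot>` collected elsewhere, `model` is a Burn module,
    /// and `MyBackend` is your backend type):
    ///
    /// ```ignore
    /// let mut applier = Applier::<MyBackend>::new(snapshots, None, None, false);
    /// let model = model.map(&mut applier);
    /// let result = applier.into_result();
    /// assert!(result.errors.is_empty());
    /// ```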
    pub fn new(
        views: Vec<TensorSnapshot>,
        filter: Option<PathFilter>,
        adapter: Option<Box<dyn ModuleAdapter>>,
        skip_enum_variants: bool,
    ) -> Self {
        let views_map: HashMap<String, TensorSnapshot> = views
            .into_iter()
            .map(|view| (view.full_path(), view))
            .collect();

        Self {
            snapshots: views_map,
            path_stack: Vec::new(),
            container_stack: Vec::new(),
            filter,
            adapter,
            applied: Vec::new(),
            skipped: HashSet::new(),
            errors: Vec::new(),
            visited_paths: HashMap::new(),
            skip_enum_variants,
            _backend: core::marker::PhantomData,
        }
    }

    /// Get the current path in the module hierarchy
    fn current_path(&self) -> String {
        self.path_stack.join(".")
    }

    /// Get the current module type (last Struct/Enum in container stack)
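    ///
    /// For example, with a container stack like `["Struct:Model", "Vec",
    /// "Struct:Linear"]`, this returns `Struct:Linear` (an illustrative stack;
    /// the actual entries depend on the module tree being traversed).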
    fn current_module_type(&self) -> Option<&str> {
        self.container_stack
            .iter()
            .rev()
            .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
            .map(|s| s.as_str())
    }

    /// Check if a tensor should be applied based on filter
    fn should_apply(&self) -> bool {
        match &self.filter {
            None => true,
            Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack),
        }
    }

    /// Convert the applier into a result
    pub fn into_result(self) -> ApplyResult {
        let mut unused: Vec<String> = self
            .snapshots
            .keys()
            .filter(|path| !self.visited_paths.contains_key(*path) && !self.skipped.contains(*path))
            .cloned()
            .collect();
        // Sort for stable output order
        unused.sort();

        // Create a set of successfully applied paths for efficient lookup
        let applied_set: HashSet<String> = self.applied.iter().cloned().collect();

        // Extract paths that have errors - these are not "missing"; they were found but had issues
        let errored_paths: HashSet<String> = self
            .errors
            .iter()
            .map(|e| match e {
                ApplyError::ShapeMismatch { path, .. } => path.clone(),
                ApplyError::DTypeMismatch { path, .. } => path.clone(),
                ApplyError::AdapterError { path, .. } => path.clone(),
                ApplyError::LoadError { path, .. } => path.clone(),
            })
            .collect();

        // A path is missing if it was visited but not successfully applied, not skipped,
        // and didn't have an error. Store both the path and its container stack (in dot notation).
        let mut missing: Vec<(String, String)> = self
            .visited_paths
            .into_iter()
            .filter(|(p, _)| {
                !applied_set.contains(p) && !self.skipped.contains(p) && !errored_paths.contains(p)
            })
            .collect();
        // Sort for stable output order (by path)
        missing.sort_by(|a, b| a.0.cmp(&b.0));

        // Convert skipped HashSet to sorted Vec for stable output
        let mut skipped: Vec<String> = self.skipped.into_iter().collect();
        skipped.sort();

        ApplyResult {
            applied: self.applied,
            skipped,
            missing,
            unused,
            errors: self.errors,
        }
    }

    /// Apply a tensor snapshot with shape validation and optional adapter transformation.
    ///
    /// Returns `None` if the snapshot is not found, the path is filtered out, or validation fails.
    fn apply_tensor<const D: usize, K>(
        &mut self,
        target_device: &B::Device,
        target_shape: Shape,
    ) -> Option<Tensor<B, D, K>>
    where
        K: burn_tensor::TensorKind<B>,
        K: burn_tensor::BasicOps<B>,
    {
        let path = self.current_path();
        let container_stack_str = self.container_stack.join(".");
        self.visited_paths.insert(path.clone(), container_stack_str);

        // Try to get snapshot with original path first
        let mut snapshot = self.snapshots.get(&path).cloned();

        // If not found and we have an adapter, try alternative parameter names
        if snapshot.is_none()
            && let Some(ref adapter) = self.adapter
            && let Some(module_type) = self.current_module_type()
        {
            // Get alternative name based on current module type (user-defined module only)
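            // For example, a PyTorch-style adapter might report "weight" as an
            // alternative for a norm layer's "gamma" parameter (illustrative;
            // the actual mapping is up to the adapter implementation).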
            let param_name = self.path_stack.last()?;

            if let Some(alt_name) = adapter.get_alternative_param_name(param_name, module_type) {
                // Build alternative path with parameter name substitution
                let mut alt_path_stack = self.path_stack.clone();
                *alt_path_stack.last_mut().unwrap() = alt_name.clone();
                let alt_path = alt_path_stack.join(".");

                // Try to get snapshot with alternative name
                snapshot = self.snapshots.get(&alt_path).cloned();

                // Don't mark the alternative path as visited - only the original Burn path
                // should be tracked. The alternative path is just for lookup.
            }
        }

        let mut snapshot = snapshot?;

        // Apply adapter transformation using current container_stack context (for data transformations such as transpose)
        if let Some(ref adapter) = self.adapter {
            // Create a temporary snapshot with current context for adaptation
            let snapshot_with_context = TensorSnapshot::from_closure(
                snapshot.clone_data_fn(),
                snapshot.dtype,
                snapshot.shape.clone(),
                self.path_stack.clone(),
                self.container_stack.clone(),
                snapshot.tensor_id.unwrap_or_default(),
            );

            // Transform using adapter (handles transpose)
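            // (e.g., a linear weight stored as [out, in] may be transposed to
            // [in, out]; the exact transformation is adapter-specific)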
            snapshot = adapter.adapt(&snapshot_with_context);
        }

        // Check if we should apply based on filter
        if !self.should_apply() {
            self.skipped.insert(path.clone());
            return None;
        }

        // Load tensor data
        let data = match snapshot.to_data() {
            Ok(data) => data,
            Err(e) => {
                self.errors.push(ApplyError::LoadError {
                    path: path.clone(),
                    message: format!("Failed to load tensor data: {:?}", e),
                });
                return None; // Signal caller to fall back to initialization
            }
        };

        // Validate shape
        if data.shape != target_shape.dims {
            self.errors.push(ApplyError::ShapeMismatch {
                path: path.clone(),
                expected: target_shape.dims,
                found: data.shape.clone(),
            });
            return None; // Signal caller to fall back to initialization
        }

        self.applied.push(path);
        Some(Tensor::from_data_dtype(data, target_device, snapshot.dtype))
    }
}

impl<B: Backend> ModuleMapper<B> for Applier<B> {
    fn enter_module(&mut self, name: &str, container_type: &str) {
        // Always track the container type for proper module type detection
        self.container_stack.push(container_type.to_string());

        // Only add to path if it's not an enum variant (when skip_enum_variants is enabled).
        // This ensures paths are built without enum variant names from the start.
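        // For example, with skip_enum_variants enabled, entering "Enum:ConvVariant"
        // under "feature" leaves the path at "feature", so the parameter resolves
        // as "feature.weight" rather than "feature.BaseConv.weight".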
        if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
            self.path_stack.push(name.to_string());
        }
    }

    fn exit_module(&mut self, _name: &str, container_type: &str) {
        self.container_stack.pop();

        // Only pop from path if we added it (not an enum variant when skip_enum_variants is enabled)
        if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
            self.path_stack.pop();
        }
    }

    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let param_id = param.id;
        let target_device = param.lazy_device();
        let target_shape = param.lazy_shape();

        // Try to apply snapshot with shape validation
        match self.apply_tensor(&target_device, target_shape) {
            Some(tensor) => {
                // We have a tensor to apply - load it
                param.transform_for_load(tensor, param_id)
            }
            None => {
                // No snapshot, filtered, or validation failed - return param unchanged
                param
            }
        }
    }

    fn map_int<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Int>>,
    ) -> Param<Tensor<B, D, Int>> {
        let param_id = param.id;
        let target_device = param.lazy_device();
        let target_shape = param.lazy_shape();

        // Try to apply snapshot with shape validation
        match self.apply_tensor(&target_device, target_shape) {
            Some(tensor) => {
                // We have a tensor to apply - load it
                param.transform_for_load(tensor, param_id)
            }
            None => {
                // No snapshot, filtered, or validation failed - return param unchanged
                param
            }
        }
    }

    fn map_bool<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Bool>>,
    ) -> Param<Tensor<B, D, Bool>> {
        let param_id = param.id;
        let target_device = param.lazy_device();
        let target_shape = param.lazy_shape();

        // Try to apply snapshot with shape validation
        match self.apply_tensor(&target_device, target_shape) {
            Some(tensor) => {
                // We have a tensor to apply - load it
                param.transform_for_load(tensor, param_id)
            }
            None => {
                // No snapshot, filtered, or validation failed - return param unchanged
                param
            }
        }
    }
}

#[cfg(all(test, feature = "std", target_has_atomic = "ptr"))]
mod tests {
    use super::*;
    use burn_core::module::{ModuleMapper, Param, ParamId};
    use burn_tensor::{DType, Tensor, TensorData};

    type TestBackend = burn_ndarray::NdArray;

    #[test]
    fn root_level_parameters() {
        let device = Default::default();

        // Create root-level parameters (not inside any module)
        let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);

        // Create snapshots with root-level paths (single-element path, no nested modules)
        let weight_snapshot = crate::TensorSnapshot::from_data(
            weight.val().to_data(),
            vec!["weight".to_string()], // root-level parameter name
            vec![],                     // no container
            ParamId::new(),
        );

        let bias_snapshot = crate::TensorSnapshot::from_data(
            bias.val().to_data(),
            vec!["bias".to_string()], // root-level parameter name
            vec![],                   // no container
            ParamId::new(),
        );

        // Create applier with root-level snapshots
        let mut applier =
            Applier::<TestBackend>::new(vec![weight_snapshot, bias_snapshot], None, None, false);

        // Create new params to load into
        let weight_target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 2>::zeros([2, 2], &device),
        );
        let bias_target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 1>::zeros([2], &device),
        );

        // Apply using the ModuleMapper interface - simulate module traversal
        // Enter "weight" path (as if we're visiting a field named "weight")
        applier.enter_module("weight", "");
        let weight_loaded = applier.map_float(weight_target);
        applier.exit_module("weight", "");

        // Enter "bias" path (as if we're visiting a field named "bias")
        applier.enter_module("bias", "");
        let bias_loaded = applier.map_float(bias_target);
        applier.exit_module("bias", "");

        // Verify values were loaded
        let weight_data = weight_loaded.val().to_data().to_vec::<f32>().unwrap();
        let bias_data = bias_loaded.val().to_data().to_vec::<f32>().unwrap();

        assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(bias_data, vec![5.0, 6.0]);

        // Verify applier result
        let result = applier.into_result();
        assert_eq!(result.applied.len(), 2);
        assert_eq!(result.errors.len(), 0);
    }

    /// Test that the applier preserves dtype when loading tensor data.
    /// This is a regression test for the bug where F16 tensors were being
    /// loaded as F32 because `Tensor::from_data` was used instead of
    /// `Tensor::from_data_dtype`.
    #[test]
    fn dtype_preservation_f64() {
        // Use NdArray<f64> backend to properly test F64 dtype preservation
        type TestBackendF64 = burn_ndarray::NdArray<f64>;
        let device = Default::default();

        // Create TensorData with F64 dtype explicitly
        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(f64_data.dtype, DType::F64, "Test setup: data should be F64");

        // Create a snapshot with F64 data
        let snapshot = crate::TensorSnapshot::from_data(
            f64_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );
        assert_eq!(
            snapshot.dtype,
            DType::F64,
            "Snapshot should preserve F64 dtype"
        );

        // Create applier with the F64 snapshot
        let mut applier = Applier::<TestBackendF64>::new(vec![snapshot], None, None, false);

        // Create target parameter
        let target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackendF64, 2>::zeros([2, 2], &device),
        );

        // Apply the snapshot
        applier.enter_module("weight", "");
        let loaded = applier.map_float(target);
        applier.exit_module("weight", "");

        // Verify dtype is preserved - this would fail before the fix
        // because the data would be converted to the backend's default FloatElem
        assert_eq!(
            loaded.val().dtype(),
            DType::F64,
            "Loaded tensor should have F64 dtype"
        );

        // Verify data values are correct
        let loaded_data = loaded.val().to_data().to_vec::<f64>().unwrap();
        assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);

        // Verify applier result
        let result = applier.into_result();
        assert_eq!(result.applied.len(), 1);
        assert_eq!(result.errors.len(), 0);
    }

    /// Test that F32 dtype is preserved when loading (verifies we didn't break F32 handling)
    #[test]
    fn dtype_preservation_f32() {
        let device = Default::default();

        // Create TensorData with F32 dtype
        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(f32_data.dtype, DType::F32);

        // Create a snapshot with F32 data
        let snapshot = crate::TensorSnapshot::from_data(
            f32_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );
        assert_eq!(snapshot.dtype, DType::F32);

        // Create applier with the F32 snapshot
        let mut applier = Applier::<TestBackend>::new(vec![snapshot], None, None, false);

        // Create target parameter
        let target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 2>::zeros([2, 2], &device),
        );

        // Apply the snapshot
        applier.enter_module("weight", "");
        let loaded = applier.map_float(target);
        applier.exit_module("weight", "");

        // Verify dtype is F32
        assert_eq!(loaded.val().dtype(), DType::F32);

        // Verify data values
        let loaded_data = loaded.val().to_data().to_vec::<f32>().unwrap();
        assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Test that F16 dtype is correctly preserved in TensorSnapshot.
    ///
    /// Note: Full F16 tensor loading requires a backend that supports F16
    /// (e.g., CUDA, WebGPU). The NdArray backend does not support F16.
    /// This test verifies that the snapshot correctly preserves F16 dtype,
    /// which is the key part of the dtype preservation fix.
    #[test]
    fn dtype_preservation_f16_snapshot() {
        use half::f16;

        // Create TensorData with F16 dtype using the half crate
        let f16_values: Vec<f16> = vec![
            f16::from_f32(1.0),
            f16::from_f32(2.0),
            f16::from_f32(3.0),
            f16::from_f32(4.0),
        ];
        let f16_data = TensorData::new(f16_values.clone(), [2, 2]);
        assert_eq!(
            f16_data.dtype,
            DType::F16,
            "TensorData should have F16 dtype"
        );

        // Create a snapshot with F16 data
        let snapshot = crate::TensorSnapshot::from_data(
            f16_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );

        // Verify snapshot preserves F16 dtype
        assert_eq!(
            snapshot.dtype,
            DType::F16,
            "TensorSnapshot should preserve F16 dtype"
        );

        // Verify the data can be retrieved with correct dtype
        let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
        assert_eq!(
            retrieved_data.dtype,
            DType::F16,
            "Retrieved data should have F16 dtype"
        );

        // Verify the actual values are preserved
        let retrieved_values: Vec<f16> = retrieved_data
            .to_vec()
            .expect("Should be able to convert to f16 vec");
        assert_eq!(
            retrieved_values, f16_values,
            "F16 values should be preserved"
        );

        // Note: To fully test F16 tensor creation, you would need a backend
        // that supports F16 (like CUDA or WebGPU). The applier fix ensures
        // that `Tensor::from_data_dtype(data, device, snapshot.dtype)` is
        // called with DType::F16, which will correctly create an F16 tensor
        // on backends that support it.
    }

    /// Test that BF16 dtype is correctly preserved in TensorSnapshot.
    #[test]
    fn dtype_preservation_bf16_snapshot() {
        use half::bf16;

        // Create TensorData with BF16 dtype
        let bf16_values: Vec<bf16> = vec![
            bf16::from_f32(1.0),
            bf16::from_f32(2.0),
            bf16::from_f32(3.0),
            bf16::from_f32(4.0),
        ];
        let bf16_data = TensorData::new(bf16_values.clone(), [2, 2]);
        assert_eq!(
            bf16_data.dtype,
            DType::BF16,
            "TensorData should have BF16 dtype"
        );

        // Create a snapshot with BF16 data
        let snapshot = crate::TensorSnapshot::from_data(
            bf16_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );

        // Verify snapshot preserves BF16 dtype
        assert_eq!(
            snapshot.dtype,
            DType::BF16,
            "TensorSnapshot should preserve BF16 dtype"
        );

        // Verify the data can be retrieved with correct dtype
        let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
        assert_eq!(
            retrieved_data.dtype,
            DType::BF16,
            "Retrieved data should have BF16 dtype"
        );

        // Verify the actual values are preserved
        let retrieved_values: Vec<bf16> = retrieved_data
            .to_vec()
            .expect("Should be able to convert to bf16 vec");
        assert_eq!(
            retrieved_values, bf16_values,
            "BF16 values should be preserved"
        );
    }
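
    /// Sketch of the shape-mismatch path (an illustrative test derived from
    /// the `apply_tensor` logic above): a mismatched snapshot should surface
    /// as a `ShapeMismatch` error while the target keeps its initialized values.
    #[test]
    fn shape_mismatch_reports_error() {
        let device = Default::default();

        // Snapshot with shape [2, 2]...
        let snapshot = crate::TensorSnapshot::from_data(
            TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );

        let mut applier = Applier::<TestBackend>::new(vec![snapshot], None, None, false);

        // ...applied to a [3, 3] target parameter.
        let target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 2>::zeros([3, 3], &device),
        );

        applier.enter_module("weight", "");
        let loaded = applier.map_float(target);
        applier.exit_module("weight", "");

        // The target keeps its initialized zeros because the shapes differ.
        let loaded_data = loaded.val().to_data().to_vec::<f32>().unwrap();
        assert_eq!(loaded_data, vec![0.0; 9]);

        // The mismatch surfaces as an error, not as a "missing" tensor.
        let result = applier.into_result();
        assert_eq!(result.applied.len(), 0);
        assert_eq!(result.errors.len(), 1);
    }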
}