1use alloc::boxed::Box;
4use alloc::format;
5use alloc::string::{String, ToString};
6use alloc::vec::Vec;
7
8use hashbrown::{HashMap, HashSet};
9
10use burn_core::module::{ModuleMapper, Param};
11use burn_tensor::{Bool, Int, Shape, Tensor, backend::Backend};
12
13use crate::apply_result::{ApplyError, ApplyResult};
14use crate::{ModuleAdapter, PathFilter, TensorSnapshot};
15
/// Module mapper that applies stored [`TensorSnapshot`]s onto a module's parameters
/// while traversing the module tree, recording what was applied, skipped, missing,
/// unused, or errored.
pub struct Applier<B: Backend> {
    /// Snapshots to apply, keyed by their dot-joined full path.
    snapshots: HashMap<String, TensorSnapshot>,
    /// Path segments of the parameter currently being visited.
    path_stack: Vec<String>,
    /// Container-type names for the current traversal (entries prefixed with
    /// "Struct:" or "Enum:" identify module types — see `current_module_type`).
    container_stack: Vec<String>,
    /// Optional filter deciding which parameter paths are applied; `None` applies all.
    filter: Option<PathFilter>,
    /// Optional adapter that may rename parameters or transform snapshots before loading.
    adapter: Option<Box<dyn ModuleAdapter>>,
    /// Paths whose tensors were successfully applied.
    applied: Vec<String>,
    /// Paths rejected by the filter.
    skipped: HashSet<String>,
    /// Errors accumulated while applying snapshots.
    errors: Vec<ApplyError>,
    /// Every visited parameter path, mapped to its container-stack string.
    visited_paths: HashMap<String, String>,
    /// When true, enum variant container names are omitted from parameter paths.
    skip_enum_variants: bool,
    /// Marker tying the applier to a backend without storing backend data.
    _backend: core::marker::PhantomData<B>,
}
43
44impl<B: Backend> Applier<B> {
45 pub fn new(
55 views: Vec<TensorSnapshot>,
56 filter: Option<PathFilter>,
57 adapter: Option<Box<dyn ModuleAdapter>>,
58 skip_enum_variants: bool,
59 ) -> Self {
60 let views_map: HashMap<String, TensorSnapshot> = views
61 .into_iter()
62 .map(|view| (view.full_path(), view))
63 .collect();
64
65 Self {
66 snapshots: views_map,
67 path_stack: Vec::new(),
68 container_stack: Vec::new(),
69 filter,
70 adapter,
71 applied: Vec::new(),
72 skipped: HashSet::new(),
73 errors: Vec::new(),
74 visited_paths: HashMap::new(),
75 skip_enum_variants,
76 _backend: core::marker::PhantomData,
77 }
78 }
79
80 fn current_path(&self) -> String {
82 self.path_stack.join(".")
83 }
84
85 fn current_module_type(&self) -> Option<&str> {
87 self.container_stack
88 .iter()
89 .rev()
90 .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
91 .map(|s| s.as_str())
92 }
93
94 fn should_apply(&self) -> bool {
96 match &self.filter {
97 None => true,
98 Some(f) => f.matches_with_container_path(&self.path_stack, &self.container_stack),
99 }
100 }
101
102 pub fn into_result(self) -> ApplyResult {
104 let mut unused: Vec<String> = self
105 .snapshots
106 .keys()
107 .filter(|path| !self.visited_paths.contains_key(*path) && !self.skipped.contains(*path))
108 .cloned()
109 .collect();
110 unused.sort();
112
113 let applied_set: HashSet<String> = self.applied.iter().cloned().collect();
115
116 let errored_paths: HashSet<String> = self
118 .errors
119 .iter()
120 .map(|e| match e {
121 ApplyError::ShapeMismatch { path, .. } => path.clone(),
122 ApplyError::DTypeMismatch { path, .. } => path.clone(),
123 ApplyError::AdapterError { path, .. } => path.clone(),
124 ApplyError::LoadError { path, .. } => path.clone(),
125 })
126 .collect();
127
128 let mut missing: Vec<(String, String)> = self
131 .visited_paths
132 .into_iter()
133 .filter(|(p, _)| {
134 !applied_set.contains(p) && !self.skipped.contains(p) && !errored_paths.contains(p)
135 })
136 .collect();
137 missing.sort_by(|a, b| a.0.cmp(&b.0));
139
140 let mut skipped: Vec<String> = self.skipped.into_iter().collect();
142 skipped.sort();
143
144 ApplyResult {
145 applied: self.applied,
146 skipped,
147 missing,
148 unused,
149 errors: self.errors,
150 }
151 }
152
153 fn apply_tensor<const D: usize, K>(
156 &mut self,
157 target_device: &B::Device,
158 target_shape: Shape,
159 ) -> Option<Tensor<B, D, K>>
160 where
161 K: burn_tensor::TensorKind<B>,
162 K: burn_tensor::BasicOps<B>,
163 {
164 let path = self.current_path();
165 let container_stack_str = self.container_stack.join(".");
166 self.visited_paths.insert(path.clone(), container_stack_str);
167
168 let mut snapshot = self.snapshots.get(&path).cloned();
170
171 if snapshot.is_none()
173 && let Some(ref adapter) = self.adapter
174 && let Some(module_type) = self.current_module_type()
175 {
176 let param_name = self.path_stack.last()?;
178
179 if let Some(alt_name) = adapter.get_alternative_param_name(param_name, module_type) {
180 let mut alt_path_stack = self.path_stack.clone();
182 *alt_path_stack.last_mut().unwrap() = alt_name.clone();
183 let alt_path = alt_path_stack.join(".");
184
185 snapshot = self.snapshots.get(&alt_path).cloned();
187
188 }
191 }
192
193 let mut snapshot = snapshot?;
194
195 if let Some(ref adapter) = self.adapter {
197 let snapshot_with_context = TensorSnapshot::from_closure(
199 snapshot.clone_data_fn(),
200 snapshot.dtype,
201 snapshot.shape.clone(),
202 self.path_stack.clone(),
203 self.container_stack.clone(),
204 snapshot.tensor_id.unwrap_or_default(),
205 );
206
207 snapshot = adapter.adapt(&snapshot_with_context);
209 }
210
211 if !self.should_apply() {
213 self.skipped.insert(path.clone());
214 return None;
215 }
216
217 let data = match snapshot.to_data() {
219 Ok(data) => data,
220 Err(e) => {
221 self.errors.push(ApplyError::LoadError {
222 path: path.clone(),
223 message: format!("Failed to load tensor data: {:?}", e),
224 });
225 return None; }
227 };
228
229 if data.shape != target_shape.dims {
231 self.errors.push(ApplyError::ShapeMismatch {
232 path: path.clone(),
233 expected: target_shape.dims,
234 found: data.shape.clone(),
235 });
236 return None; }
238
239 self.applied.push(path);
240 Some(Tensor::from_data_dtype(data, target_device, snapshot.dtype))
241 }
242}
243
244impl<B: Backend> ModuleMapper<B> for Applier<B> {
245 fn enter_module(&mut self, name: &str, container_type: &str) {
246 self.container_stack.push(container_type.to_string());
248
249 if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
252 self.path_stack.push(name.to_string());
253 }
254 }
255
256 fn exit_module(&mut self, _name: &str, container_type: &str) {
257 self.container_stack.pop();
258
259 if !self.skip_enum_variants || !container_type.starts_with("Enum:") {
261 self.path_stack.pop();
262 }
263 }
264
265 fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
266 let param_id = param.id;
267 let target_device = param.lazy_device();
268 let target_shape = param.lazy_shape();
269
270 match self.apply_tensor(&target_device, target_shape) {
272 Some(tensor) => {
273 param.transform_for_load(tensor, param_id)
275 }
276 None => {
277 param
279 }
280 }
281 }
282
283 fn map_int<const D: usize>(
284 &mut self,
285 param: Param<Tensor<B, D, Int>>,
286 ) -> Param<Tensor<B, D, Int>> {
287 let param_id = param.id;
288 let target_device = param.lazy_device();
289 let target_shape = param.lazy_shape();
290
291 match self.apply_tensor(&target_device, target_shape) {
293 Some(tensor) => {
294 param.transform_for_load(tensor, param_id)
296 }
297 None => {
298 param
300 }
301 }
302 }
303
304 fn map_bool<const D: usize>(
305 &mut self,
306 param: Param<Tensor<B, D, Bool>>,
307 ) -> Param<Tensor<B, D, Bool>> {
308 let param_id = param.id;
309 let target_device = param.lazy_device();
310 let target_shape = param.lazy_shape();
311
312 match self.apply_tensor(&target_device, target_shape) {
314 Some(tensor) => {
315 param.transform_for_load(tensor, param_id)
317 }
318 None => {
319 param
321 }
322 }
323 }
324}
325
#[cfg(all(test, feature = "std", target_has_atomic = "ptr"))]
mod tests {
    use super::*;
    use burn_core::module::{ModuleMapper, Param, ParamId};
    use burn_tensor::{DType, Tensor, TensorData};

    type TestBackend = burn_ndarray::NdArray;

    /// Snapshots keyed by single-segment paths ("weight", "bias") are applied
    /// to parameters visited at the root of the module tree.
    #[test]
    fn root_level_parameters() {
        let device = Default::default();

        let weight = Param::<Tensor<TestBackend, 2>>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let bias = Param::<Tensor<TestBackend, 1>>::from_data([5.0, 6.0], &device);

        let weight_snapshot = crate::TensorSnapshot::from_data(
            weight.val().to_data(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );

        let bias_snapshot = crate::TensorSnapshot::from_data(
            bias.val().to_data(),
            vec!["bias".to_string()],
            vec![],
            ParamId::new(),
        );

        let mut applier =
            Applier::<TestBackend>::new(vec![weight_snapshot, bias_snapshot], None, None, false);

        // Targets start zeroed so any change proves the snapshot was applied.
        let weight_target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 2>::zeros([2, 2], &device),
        );
        let bias_target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 1>::zeros([2], &device),
        );

        // Simulate the traversal a module mapper would perform.
        applier.enter_module("weight", "");
        let weight_loaded = applier.map_float(weight_target);
        applier.exit_module("weight", "");

        applier.enter_module("bias", "");
        let bias_loaded = applier.map_float(bias_target);
        applier.exit_module("bias", "");

        let weight_data = weight_loaded.val().to_data().to_vec::<f32>().unwrap();
        let bias_data = bias_loaded.val().to_data().to_vec::<f32>().unwrap();

        assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(bias_data, vec![5.0, 6.0]);

        let result = applier.into_result();
        assert_eq!(result.applied.len(), 2);
        assert_eq!(result.errors.len(), 0);
    }

    /// Applying an F64 snapshot on an F64 backend keeps the F64 dtype end-to-end.
    #[test]
    fn dtype_preservation_f64() {
        type TestBackendF64 = burn_ndarray::NdArray<f64>;
        let device = Default::default();

        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(f64_data.dtype, DType::F64, "Test setup: data should be F64");

        let snapshot = crate::TensorSnapshot::from_data(
            f64_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );
        assert_eq!(
            snapshot.dtype,
            DType::F64,
            "Snapshot should preserve F64 dtype"
        );

        let mut applier = Applier::<TestBackendF64>::new(vec![snapshot], None, None, false);

        let target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackendF64, 2>::zeros([2, 2], &device),
        );

        applier.enter_module("weight", "");
        let loaded = applier.map_float(target);
        applier.exit_module("weight", "");

        assert_eq!(
            loaded.val().dtype(),
            DType::F64,
            "Loaded tensor should have F64 dtype"
        );

        let loaded_data = loaded.val().to_data().to_vec::<f64>().unwrap();
        assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);

        let result = applier.into_result();
        assert_eq!(result.applied.len(), 1);
        assert_eq!(result.errors.len(), 0);
    }

    /// Applying an F32 snapshot keeps the F32 dtype and values intact.
    #[test]
    fn dtype_preservation_f32() {
        let device = Default::default();

        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]);
        assert_eq!(f32_data.dtype, DType::F32);

        let snapshot = crate::TensorSnapshot::from_data(
            f32_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );
        assert_eq!(snapshot.dtype, DType::F32);

        let mut applier = Applier::<TestBackend>::new(vec![snapshot], None, None, false);

        let target = Param::initialized(
            ParamId::new(),
            Tensor::<TestBackend, 2>::zeros([2, 2], &device),
        );

        applier.enter_module("weight", "");
        let loaded = applier.map_float(target);
        applier.exit_module("weight", "");

        assert_eq!(loaded.val().dtype(), DType::F32);

        let loaded_data = loaded.val().to_data().to_vec::<f32>().unwrap();
        assert_eq!(loaded_data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// F16 round-trip through TensorSnapshot (snapshot-level only; no applier).
    #[test]
    fn dtype_preservation_f16_snapshot() {
        use half::f16;

        let f16_values: Vec<f16> = vec![
            f16::from_f32(1.0),
            f16::from_f32(2.0),
            f16::from_f32(3.0),
            f16::from_f32(4.0),
        ];
        let f16_data = TensorData::new(f16_values.clone(), [2, 2]);
        assert_eq!(
            f16_data.dtype,
            DType::F16,
            "TensorData should have F16 dtype"
        );

        let snapshot = crate::TensorSnapshot::from_data(
            f16_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );

        assert_eq!(
            snapshot.dtype,
            DType::F16,
            "TensorSnapshot should preserve F16 dtype"
        );

        let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
        assert_eq!(
            retrieved_data.dtype,
            DType::F16,
            "Retrieved data should have F16 dtype"
        );

        let retrieved_values: Vec<f16> = retrieved_data
            .to_vec()
            .expect("Should be able to convert to f16 vec");
        assert_eq!(
            retrieved_values, f16_values,
            "F16 values should be preserved"
        );

    }

    /// BF16 round-trip through TensorSnapshot (snapshot-level only; no applier).
    #[test]
    fn dtype_preservation_bf16_snapshot() {
        use half::bf16;

        let bf16_values: Vec<bf16> = vec![
            bf16::from_f32(1.0),
            bf16::from_f32(2.0),
            bf16::from_f32(3.0),
            bf16::from_f32(4.0),
        ];
        let bf16_data = TensorData::new(bf16_values.clone(), [2, 2]);
        assert_eq!(
            bf16_data.dtype,
            DType::BF16,
            "TensorData should have BF16 dtype"
        );

        let snapshot = crate::TensorSnapshot::from_data(
            bf16_data.clone(),
            vec!["weight".to_string()],
            vec![],
            ParamId::new(),
        );

        assert_eq!(
            snapshot.dtype,
            DType::BF16,
            "TensorSnapshot should preserve BF16 dtype"
        );

        let retrieved_data = snapshot.to_data().expect("Should be able to retrieve data");
        assert_eq!(
            retrieved_data.dtype,
            DType::BF16,
            "Retrieved data should have BF16 dtype"
        );

        let retrieved_values: Vec<bf16> = retrieved_data
            .to_vec()
            .expect("Should be able to convert to bf16 vec");
        assert_eq!(
            retrieved_values, bf16_values,
            "BF16 values should be preserved"
        );
    }
}