1use crate::TensorSnapshot;
9
10use alloc::boxed::Box;
11use alloc::format;
12use alloc::rc::Rc;
13use alloc::string::String;
14use alloc::string::ToString;
15use alloc::vec;
16
17use burn_tensor::shape;
18use burn_tensor::{DType, TensorData};
19use hashbrown::HashSet;
20
/// Fully-qualified container-type names ("Struct:" followed by the type
/// name) for the burn modules recognized by the built-in adapters.
///
/// These must match the container-type strings recorded in snapshots; the
/// test `test_module_names_match_burn_nn` pins them against `burn_nn` types.
mod module_names {
    pub const LINEAR: &str = "Struct:Linear";
    pub const BATCH_NORM: &str = "Struct:BatchNorm";
    pub const LAYER_NORM: &str = "Struct:LayerNorm";
    pub const GROUP_NORM: &str = "Struct:GroupNorm";
    pub const EMBEDDING: &str = "Struct:Embedding";
    pub const CONV1D: &str = "Struct:Conv1d";
    pub const CONV2D: &str = "Struct:Conv2d";
    pub const CONV3D: &str = "Struct:Conv3d";
    pub const CONV_TRANSPOSE1D: &str = "Struct:ConvTranspose1d";
    pub const CONV_TRANSPOSE2D: &str = "Struct:ConvTranspose2d";
    pub const CONV_TRANSPOSE3D: &str = "Struct:ConvTranspose3d";
    pub const DEFORM_CONV2D: &str = "Struct:DeformConv2d";
    pub const INSTANCE_NORM: &str = "Struct:InstanceNorm";
    pub const RMS_NORM: &str = "Struct:RmsNorm";
    pub const PRELU: &str = "Struct:PRelu";
}
42
/// Transforms [`TensorSnapshot`]s between module-format conventions
/// (e.g. PyTorch <-> Burn naming/layout, or precision casts).
pub trait ModuleAdapter: Send + Sync {
    /// Adapts a single tensor snapshot, returning the (possibly transformed)
    /// result. Implementations return a clone when nothing applies.
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;

    /// Suggests an alternative parameter name to try for a container of
    /// `_container_type` when `_param_name` is not found. Defaults to `None`.
    fn get_alternative_param_name(
        &self,
        _param_name: &str,
        _container_type: &str,
    ) -> Option<String> {
        None
    }

    /// Boxes a clone of this adapter; enables `Clone` for
    /// `Box<dyn ModuleAdapter>` (see the impl below).
    fn clone_box(&self) -> Box<dyn ModuleAdapter>;

    /// Chains this adapter with `next`: `self` is applied first, then `next`.
    fn chain<A>(self, next: A) -> ChainAdapter
    where
        Self: Sized + 'static,
        A: ModuleAdapter + 'static,
    {
        ChainAdapter::new(self, next)
    }
}
88
/// Delegates `Clone` for boxed adapter trait objects to `clone_box`.
impl Clone for Box<dyn ModuleAdapter> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
94
/// Composition of two adapters applied in sequence: `first`, then `second`.
#[derive(Clone)]
pub struct ChainAdapter {
    first: Box<dyn ModuleAdapter>,
    second: Box<dyn ModuleAdapter>,
}
103
104impl ChainAdapter {
105 pub fn new<A, B>(first: A, second: B) -> Self
107 where
108 A: ModuleAdapter + 'static,
109 B: ModuleAdapter + 'static,
110 {
111 Self {
112 first: Box::new(first),
113 second: Box::new(second),
114 }
115 }
116}
117
118impl ModuleAdapter for ChainAdapter {
119 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
120 let snapshot = self.first.adapt(snapshot);
121 self.second.adapt(&snapshot)
122 }
123
124 fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
125 if let Some(name) = self
126 .first
127 .get_alternative_param_name(param_name, container_type)
128 {
129 self.second
130 .get_alternative_param_name(&name, container_type)
131 .or(Some(name))
132 } else {
133 self.second
134 .get_alternative_param_name(param_name, container_type)
135 }
136 }
137
138 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
139 Box::new(self.clone())
140 }
141}
142
143#[derive(Debug, Clone, Default)]
145pub struct IdentityAdapter;
146
147impl ModuleAdapter for IdentityAdapter {
148 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
149 snapshot.clone()
150 }
151
152 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
153 Box::new(self.clone())
154 }
155}
156
157fn default_half_precision_modules() -> HashSet<String> {
164 let modules = [
165 module_names::LINEAR,
166 module_names::EMBEDDING,
167 module_names::CONV1D,
168 module_names::CONV2D,
169 module_names::CONV3D,
170 module_names::CONV_TRANSPOSE1D,
171 module_names::CONV_TRANSPOSE2D,
172 module_names::CONV_TRANSPOSE3D,
173 module_names::DEFORM_CONV2D,
174 module_names::LAYER_NORM,
175 module_names::GROUP_NORM,
176 module_names::INSTANCE_NORM,
177 module_names::RMS_NORM,
178 module_names::PRELU,
179 ];
180 modules.iter().map(|s| s.to_string()).collect()
181}
182
/// Adapter that toggles tensor precision (F32 <-> F16) for a configurable
/// set of module types; other dtypes and unlisted modules pass through.
#[derive(Debug, Clone)]
pub struct HalfPrecisionAdapter {
    // Fully-qualified container-type names (e.g. "Struct:Linear") to convert.
    modules: HashSet<String>,
}
222
223impl HalfPrecisionAdapter {
224 pub fn new() -> Self {
226 Self {
227 modules: default_half_precision_modules(),
228 }
229 }
230
231 pub fn with_module(mut self, module_type: impl Into<String>) -> Self {
237 let name = module_type.into();
238 if name.contains(':') {
239 self.modules.insert(name);
240 } else {
241 self.modules.insert(format!("Struct:{}", name));
242 }
243 self
244 }
245
246 pub fn without_module(mut self, module_type: impl Into<String>) -> Self {
248 let name = module_type.into();
249 let key = if name.contains(':') {
250 name
251 } else {
252 format!("Struct:{}", name)
253 };
254 assert!(
255 self.modules.contains(&key),
256 "without_module called with '{}' which is not in the module set",
257 key
258 );
259 self.modules.remove(&key);
260 self
261 }
262
263 fn should_convert(&self, snapshot: &TensorSnapshot) -> bool {
265 snapshot
266 .module_type()
267 .is_some_and(|mt| self.modules.contains(&mt))
268 }
269}
270
/// Defaults to [`HalfPrecisionAdapter::new`], i.e. the default module set.
impl Default for HalfPrecisionAdapter {
    fn default() -> Self {
        Self::new()
    }
}
276
277impl ModuleAdapter for HalfPrecisionAdapter {
278 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
279 let target_dtype = match snapshot.dtype {
281 DType::F32 => DType::F16,
282 DType::F16 => DType::F32,
283 _ => return snapshot.clone(),
284 };
285
286 if !self.should_convert(snapshot) {
287 return snapshot.clone();
288 }
289
290 let original_data_fn = snapshot.clone_data_fn();
291
292 let cast_data_fn = Rc::new(move || {
293 let data = original_data_fn()?;
294 Ok(data.convert_dtype(target_dtype))
295 });
296
297 TensorSnapshot::from_closure(
298 cast_data_fn,
299 target_dtype,
300 snapshot.shape.clone(),
301 snapshot.path_stack.clone().unwrap_or_default(),
302 snapshot.container_stack.clone().unwrap_or_default(),
303 snapshot.tensor_id.unwrap_or_default(),
304 )
305 }
306
307 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
308 Box::new(self.clone())
309 }
310}
311
/// Adapter converting PyTorch-convention tensors to Burn conventions:
/// transposes rank-2 `Linear` weights and renames normalization
/// `weight`/`bias` parameters to `gamma`/`beta`.
#[derive(Debug, Clone, Default)]
pub struct PyTorchToBurnAdapter;
319
320impl ModuleAdapter for PyTorchToBurnAdapter {
321 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
322 adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
323 }
324
325 fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
326 if is_normalization_layer(container_type) {
328 burn_norm_param_to_pytorch(param_name).map(|s| s.to_string())
329 } else {
330 None
331 }
332 }
333
334 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
335 Box::new(self.clone())
336 }
337}
338
/// Adapter converting Burn-convention tensors to PyTorch conventions:
/// transposes rank-2 `Linear` weights and renames normalization
/// `gamma`/`beta` parameters to `weight`/`bias`.
#[derive(Debug, Clone, Default)]
pub struct BurnToPyTorchAdapter;
346
347impl ModuleAdapter for BurnToPyTorchAdapter {
348 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
349 adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
350 }
351
352 fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
353 if is_normalization_layer(container_type) {
355 pytorch_norm_param_to_burn(param_name).map(|s| s.to_string())
356 } else {
357 None
358 }
359 }
360
361 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
362 Box::new(self.clone())
363 }
364}
365
/// Direction of a PyTorch <-> Burn tensor conversion.
#[derive(Debug, Clone, Copy)]
enum PyTorchConversionDirection {
    PyTorchToBurn,
    BurnToPyTorch,
}
372
373fn is_normalization_layer(container_type: &str) -> bool {
375 matches!(
376 container_type,
377 module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM
378 )
379}
380
/// Maps a PyTorch normalization parameter name to its Burn counterpart,
/// or `None` when the name has no Burn-specific equivalent.
fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> {
    let renamed = match param_name {
        "weight" => "gamma",
        "bias" => "beta",
        _ => return None,
    };
    Some(renamed)
}
389
/// Maps a Burn normalization parameter name to its PyTorch counterpart,
/// or `None` when the name has no PyTorch-specific equivalent.
fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> {
    let renamed = match param_name {
        "gamma" => "weight",
        "beta" => "bias",
        _ => return None,
    };
    Some(renamed)
}
398
399fn adapt_pytorch_tensor(
401 snapshot: &TensorSnapshot,
402 direction: PyTorchConversionDirection,
403) -> TensorSnapshot {
404 let (path_stack, param_name) = match get_path_and_param(snapshot) {
406 Some(result) => result,
407 None => return snapshot.clone(),
408 };
409
410 let module_type = match snapshot.module_type() {
412 Some(mt) => mt,
413 None => return snapshot.clone(), };
415
416 if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 {
418 return transpose_2d_tensor(snapshot);
419 }
420
421 if is_normalization_layer(&module_type) {
423 let new_name = match direction {
424 PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name),
425 PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name),
426 };
427
428 if let Some(new_name) = new_name {
429 return rename_parameter(snapshot, path_stack, new_name);
430 }
431 }
432
433 snapshot.clone()
434}
435
436fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
438 let path_stack = snapshot.path_stack.as_ref()?;
439 let param_name = path_stack.last()?.as_str();
440 Some((path_stack.as_slice(), param_name))
441}
442
443fn rename_parameter(
445 snapshot: &TensorSnapshot,
446 path_stack: &[String],
447 new_name: &str,
448) -> TensorSnapshot {
449 let mut new_path = path_stack.to_vec();
450 *new_path.last_mut().unwrap() = new_name.to_string();
451
452 TensorSnapshot::from_closure(
453 snapshot.clone_data_fn(),
454 snapshot.dtype,
455 snapshot.shape.clone(),
456 new_path,
457 snapshot.container_stack.clone().unwrap_or_default(),
458 snapshot.tensor_id.unwrap_or_default(),
459 )
460}
461
462fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
464 if snapshot.shape.len() != 2 {
465 return snapshot.clone();
466 }
467
468 let original_data_fn = snapshot.clone_data_fn();
469 let dtype = snapshot.dtype;
470 let transposed_shape = shape![snapshot.shape[1], snapshot.shape[0]];
471
472 let transposed_data_fn = Rc::new(move || {
474 let data = original_data_fn()?;
475 Ok(transpose_tensor_data(data))
476 });
477
478 TensorSnapshot::from_closure(
479 transposed_data_fn,
480 dtype,
481 transposed_shape,
482 snapshot.path_stack.clone().unwrap_or_default(),
483 snapshot.container_stack.clone().unwrap_or_default(),
484 snapshot.tensor_id.unwrap_or_default(),
485 )
486}
487
488fn transpose_tensor_data(data: TensorData) -> TensorData {
490 let shape = &data.shape;
491 let rows = shape[0];
492 let cols = shape[1];
493 let transposed_shape = vec![cols, rows];
494
495 let bytes = data.as_bytes();
497 let element_size = data.dtype.size();
498
499 let mut transposed_bytes = vec![0u8; bytes.len()];
501
502 for i in 0..rows {
504 for j in 0..cols {
505 let src_idx = (i * cols + j) * element_size;
506 let dst_idx = (j * rows + i) * element_size;
507
508 transposed_bytes[dst_idx..dst_idx + element_size]
510 .copy_from_slice(&bytes[src_idx..src_idx + element_size]);
511 }
512 }
513
514 TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use alloc::sync::Arc;
    use burn_tensor::{DType, Shape, TensorData};
    use core::sync::atomic::{AtomicUsize, Ordering};

    // The `module_names` constants follow a "Struct:" + type-name convention;
    // importing the burn_nn types ensures those type names actually exist.
    #[test]
    fn test_module_names_match_burn_nn() {
        #[allow(unused_imports)]
        use burn_nn::{
            BatchNorm, Embedding, GroupNorm, InstanceNorm, LayerNorm, Linear, PRelu, RmsNorm,
            conv::{
                Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d,
                DeformConv2d,
            },
        };

        assert_eq!(module_names::LINEAR, "Struct:Linear");
        assert_eq!(module_names::BATCH_NORM, "Struct:BatchNorm");
        assert_eq!(module_names::LAYER_NORM, "Struct:LayerNorm");
        assert_eq!(module_names::GROUP_NORM, "Struct:GroupNorm");
        assert_eq!(module_names::EMBEDDING, "Struct:Embedding");
        assert_eq!(module_names::CONV1D, "Struct:Conv1d");
        assert_eq!(module_names::CONV2D, "Struct:Conv2d");
        assert_eq!(module_names::CONV3D, "Struct:Conv3d");
        assert_eq!(module_names::CONV_TRANSPOSE1D, "Struct:ConvTranspose1d");
        assert_eq!(module_names::CONV_TRANSPOSE2D, "Struct:ConvTranspose2d");
        assert_eq!(module_names::CONV_TRANSPOSE3D, "Struct:ConvTranspose3d");
        assert_eq!(module_names::DEFORM_CONV2D, "Struct:DeformConv2d");
        assert_eq!(module_names::INSTANCE_NORM, "Struct:InstanceNorm");
        assert_eq!(module_names::RMS_NORM, "Struct:RmsNorm");
        assert_eq!(module_names::PRELU, "Struct:PRelu");
    }

    /// Builds an F32 snapshot of ones with a dotted `path` and a
    /// single-entry container stack of `container_type`.
    fn create_test_snapshot(path: &str, shape: Shape, container_type: &str) -> TensorSnapshot {
        let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
        let values = vec![1.0f32; shape.iter().product()];
        let data = TensorData::new(values, shape.clone());

        TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F32,
            shape,
            path_parts,
            vec![container_type.to_string()],
            burn_core::module::ParamId::new(),
        )
    }

    // PyTorch -> Burn transposes rank-2 Linear weights; biases are untouched.
    #[test]
    fn test_pytorch_to_burn_linear_weight() {
        let adapter = PyTorchToBurnAdapter;

        let snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![5, 10]);

        let snapshot = create_test_snapshot("fc.bias", shape![10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![10]);
    }

    // PyTorch -> Burn renames norm weight/bias to gamma/beta.
    #[test]
    fn test_pytorch_to_burn_norm_params() {
        let adapter = PyTorchToBurnAdapter;

        let snapshot = create_test_snapshot("norm.weight", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.gamma");

        let snapshot = create_test_snapshot("norm.bias", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.beta");
    }

    // Burn -> PyTorch transposes Linear weights in the opposite direction.
    #[test]
    fn test_burn_to_pytorch_linear_weight() {
        let adapter = BurnToPyTorchAdapter;

        let snapshot = create_test_snapshot("fc.weight", shape![5, 10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![10, 5]);
    }

    // Burn -> PyTorch renames norm gamma/beta back to weight/bias.
    #[test]
    fn test_burn_to_pytorch_norm_params() {
        let adapter = BurnToPyTorchAdapter;

        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.weight");

        let snapshot = create_test_snapshot("norm.beta", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.bias");
    }

    // The byte-wise transpose is dtype-agnostic: f32, i32 and f64 all work.
    #[test]
    fn test_transpose_different_dtypes() {
        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);
        let transposed = transpose_tensor_data(f32_data);
        assert_eq!(transposed.shape, shape![3, 2]);
        let values = transposed.to_vec::<f32>().unwrap();
        assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], [2, 3]);
        let transposed = transpose_tensor_data(i32_data);
        assert_eq!(transposed.shape, shape![3, 2]);
        let values = transposed.to_vec::<i32>().unwrap();
        assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);

        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);
        let transposed = transpose_tensor_data(f64_data);
        assert_eq!(transposed.shape, shape![2, 2]);
        let values = transposed.to_vec::<f64>().unwrap();
        assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
    }

    // Without container metadata no module type can be resolved, so tensors
    // pass through untouched (no transpose, no rename).
    #[test]
    fn test_no_container_info() {
        let adapter = PyTorchToBurnAdapter;

        let mut snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR);
        snapshot.container_stack = None;

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![10, 5]);

        let mut snapshot2 = create_test_snapshot("other.weight", shape![10, 5], "Struct:Other");
        snapshot2.container_stack = None;
        let adapted2 = adapter.adapt(&snapshot2);
        assert_eq!(adapted2.shape, shape![10, 5]);
    }

    /// Test adapter that renames the last path segment `from` -> `to` and
    /// counts how many times `adapt` is invoked.
    #[derive(Clone)]
    struct RenameParamAdapter {
        from: &'static str,
        to: &'static str,
        called: Arc<AtomicUsize>,
    }

    impl ModuleAdapter for RenameParamAdapter {
        fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
            self.called.fetch_add(1, Ordering::Relaxed);

            let path_stack = match snapshot.path_stack.as_ref() {
                Some(stack) => stack,
                None => return snapshot.clone(),
            };
            let param = match path_stack.last() {
                Some(p) => p.as_str(),
                None => return snapshot.clone(),
            };
            if param != self.from {
                return snapshot.clone();
            }

            let mut new_path = path_stack.to_vec();
            *new_path.last_mut().unwrap() = self.to.to_string();

            TensorSnapshot::from_closure(
                snapshot.clone_data_fn(),
                snapshot.dtype,
                snapshot.shape.clone(),
                new_path,
                snapshot.container_stack.clone().unwrap_or_default(),
                snapshot.tensor_id.unwrap_or_default(),
            )
        }

        fn get_alternative_param_name(
            &self,
            _param_name: &str,
            _container_type: &str,
        ) -> Option<String> {
            None
        }

        fn clone_box(&self) -> Box<dyn ModuleAdapter> {
            Box::new(self.clone())
        }
    }

    /// Test adapter whose `adapt` is an identity rebuild but whose
    /// alternative-name lookup maps `from` -> `to`, counting invocations.
    #[derive(Clone)]
    struct AltNameAdapter {
        from: &'static str,
        to: &'static str,
        called: Arc<AtomicUsize>,
    }

    impl ModuleAdapter for AltNameAdapter {
        fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
            TensorSnapshot::from_closure(
                snapshot.clone_data_fn(),
                snapshot.dtype,
                snapshot.shape.clone(),
                snapshot.path_stack.clone().unwrap_or_default(),
                snapshot.container_stack.clone().unwrap_or_default(),
                snapshot.tensor_id.unwrap_or_default(),
            )
        }

        fn get_alternative_param_name(
            &self,
            param_name: &str,
            _container_type: &str,
        ) -> Option<String> {
            self.called.fetch_add(1, Ordering::Relaxed);
            if param_name == self.from {
                Some(self.to.to_string())
            } else {
                None
            }
        }

        fn clone_box(&self) -> Box<dyn ModuleAdapter> {
            Box::new(self.clone())
        }
    }

    // chain() applies the first adapter, then the second, exactly once each:
    // weight -> a -> b.
    #[test]
    fn test_chain_adapter_pipes_adapt() {
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));

        let a = RenameParamAdapter {
            from: "weight",
            to: "a",
            called: called1.clone(),
        };
        let b = RenameParamAdapter {
            from: "a",
            to: "b",
            called: called2.clone(),
        };

        let chain = a.chain(b);
        let snapshot = create_test_snapshot("fc.weight", shape![2, 2], module_names::LINEAR);
        let adapted = chain.adapt(&snapshot);

        assert_eq!(adapted.full_path(), "fc.b");
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);
    }

    // Alternative-name lookups pipe through both adapters, falling back to
    // the first suggestion (or to the second adapter alone) when one declines.
    #[test]
    fn test_chain_adapter_alternative_name_pipes_and_fallbacks() {
        // Both adapters match: gamma -> weight -> scale.
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));

        let a = AltNameAdapter {
            from: "gamma",
            to: "weight",
            called: called1.clone(),
        };
        let b = AltNameAdapter {
            from: "weight",
            to: "scale",
            called: called2.clone(),
        };

        let chain = a.chain(b);
        let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("scale"));
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);

        // Second adapter declines: the first suggestion is kept.
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));
        let a = AltNameAdapter {
            from: "gamma",
            to: "weight",
            called: called1.clone(),
        };
        let b = AltNameAdapter {
            from: "something-else",
            to: "unused",
            called: called2.clone(),
        };
        let chain = a.chain(b);
        let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("weight"));
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);

        // First adapter declines: the second is queried with the original name.
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));
        let a = AltNameAdapter {
            from: "something-else",
            to: "unused",
            called: called1.clone(),
        };
        let b = AltNameAdapter {
            from: "gamma",
            to: "weight",
            called: called2.clone(),
        };
        let chain = a.chain(b);
        let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("weight"));
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);

        // A boxed clone of the chain behaves identically.
        let boxed = chain.clone_box();
        let alt = boxed.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("weight"));
    }

    // F32 tensors of targeted modules are cast to F16, lazily via to_data().
    #[test]
    fn test_half_precision_f32_to_f16() {
        let adapter = HalfPrecisionAdapter::new();
        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F16);
        assert_eq!(adapted.shape, shape![2, 3]);

        let data = adapted.to_data().unwrap();
        assert_eq!(data.dtype, DType::F16);
    }

    // The adapter is a toggle: F16 tensors are cast back up to F32.
    #[test]
    fn test_half_precision_f16_to_f32() {
        let adapter = HalfPrecisionAdapter::new();

        let values = vec![1.0f32; 6];
        let data = TensorData::new(values, shape![2, 3]).convert_dtype(DType::F16);
        let path_parts = vec!["fc".to_string(), "weight".to_string()];
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F16,
            shape![2, 3],
            path_parts,
            vec![module_names::LINEAR.to_string()],
            burn_core::module::ParamId::new(),
        );

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F32);
    }

    // BatchNorm is not in the default module set, so it stays F32.
    #[test]
    fn test_half_precision_skips_batch_norm() {
        let adapter = HalfPrecisionAdapter::new();

        let snapshot = create_test_snapshot("norm.weight", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F32);
    }

    // Spot-check that the default set covers the expected module families.
    #[test]
    fn test_half_precision_converts_default_modules() {
        let adapter = HalfPrecisionAdapter::new();

        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        let snapshot = create_test_snapshot("emb.weight", shape![100, 64], module_names::EMBEDDING);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        let snapshot =
            create_test_snapshot("conv.weight", shape![3, 3, 3, 3], module_names::CONV2D);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        let snapshot = create_test_snapshot("gn.gamma", shape![10], module_names::GROUP_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        let snapshot = create_test_snapshot("rms.weight", shape![10], module_names::RMS_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    // without_module (bare name) removes a default module; others still convert.
    #[test]
    fn test_half_precision_without_module() {
        let adapter = HalfPrecisionAdapter::new().without_module("LayerNorm");

        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32);

        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    // with_module accepts a bare name and qualifies it as "Struct:...".
    #[test]
    fn test_half_precision_with_module() {
        let adapter = HalfPrecisionAdapter::new().with_module("CustomLayer");

        let snapshot = create_test_snapshot("custom.weight", shape![5], "Struct:CustomLayer");
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    // with_module also accepts an already-qualified name.
    #[test]
    fn test_half_precision_with_qualified_name() {
        let adapter = HalfPrecisionAdapter::new().with_module("Struct:CustomLayer");

        let snapshot = create_test_snapshot("custom.weight", shape![5], "Struct:CustomLayer");
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    // Chaining PyTorch conversion with precision: transpose AND cast apply.
    #[test]
    fn test_half_precision_chain() {
        let adapter = PyTorchToBurnAdapter.chain(HalfPrecisionAdapter::new());

        let snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);

        assert_eq!(adapted.shape, shape![5, 10]);
        assert_eq!(adapted.dtype, DType::F16);
    }

    // No container metadata -> no module type -> no conversion.
    #[test]
    fn test_half_precision_skips_no_container() {
        let adapter = HalfPrecisionAdapter::new();
        let mut snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        snapshot.container_stack = None;

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F32);
    }

    // Non-F32/F16 dtypes (here: quantized) pass through unchanged.
    #[test]
    fn test_half_precision_skips_non_float() {
        use burn_tensor::quantization::QuantScheme;

        let adapter = HalfPrecisionAdapter::new();

        let qfloat_dtype = DType::QFloat(QuantScheme::default());
        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        let qfloat_snapshot = TensorSnapshot::from_closure(
            snapshot.clone_data_fn(),
            qfloat_dtype,
            snapshot.shape.clone(),
            snapshot.path_stack.clone().unwrap_or_default(),
            snapshot.container_stack.clone().unwrap_or_default(),
            snapshot.tensor_id.unwrap_or_default(),
        );
        let adapted = adapter.adapt(&qfloat_snapshot);
        assert_eq!(adapted.dtype, qfloat_dtype);
    }

    // Guards against accidentally growing/shrinking the default module set.
    #[test]
    fn test_half_precision_default_module_count() {
        let adapter = HalfPrecisionAdapter::new();
        assert_eq!(adapter.modules.len(), 14);
    }

    // without_module also accepts an already-qualified name.
    #[test]
    fn test_half_precision_without_module_qualified() {
        let adapter = HalfPrecisionAdapter::new().without_module("Struct:LayerNorm");

        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32);
    }

    // BatchNorm, excluded by default, can be opted in via with_module.
    #[test]
    fn test_half_precision_with_module_batch_norm_opt_in() {
        let adapter = HalfPrecisionAdapter::new().with_module("BatchNorm");

        let snapshot = create_test_snapshot("bn.weight", shape![10], module_names::BATCH_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }
}