burn_store/adapter.rs

//! Module adapters for transforming tensor snapshots during save/load
//!
//! This module provides adapters for:
//! - PyTorch/Burn format conversion (weight transposition, parameter renaming)
//! - Mixed-precision storage (F32/F16 dtype casting via [`HalfPrecisionAdapter`])
//! - Adapter chaining for composing multiple transformations

use crate::TensorSnapshot;

use alloc::boxed::Box;
use alloc::format;
use alloc::rc::Rc;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec;

use burn_tensor::shape;
use burn_tensor::{DType, TensorData};
use hashbrown::HashSet;

// Module type names as they appear in the container_type field
// These come from the Module derive macro which uses stringify! on the struct name
// Format: "Struct:TypeName" for user-defined structs
mod module_names {
    // The actual string constants that match what the Module derive macro produces
    pub const LINEAR: &str = "Struct:Linear";
    pub const BATCH_NORM: &str = "Struct:BatchNorm";
    pub const LAYER_NORM: &str = "Struct:LayerNorm";
    pub const GROUP_NORM: &str = "Struct:GroupNorm";
    pub const EMBEDDING: &str = "Struct:Embedding";
    pub const CONV1D: &str = "Struct:Conv1d";
    pub const CONV2D: &str = "Struct:Conv2d";
    pub const CONV3D: &str = "Struct:Conv3d";
    pub const CONV_TRANSPOSE1D: &str = "Struct:ConvTranspose1d";
    pub const CONV_TRANSPOSE2D: &str = "Struct:ConvTranspose2d";
    pub const CONV_TRANSPOSE3D: &str = "Struct:ConvTranspose3d";
    pub const DEFORM_CONV2D: &str = "Struct:DeformConv2d";
    pub const INSTANCE_NORM: &str = "Struct:InstanceNorm";
    pub const RMS_NORM: &str = "Struct:RmsNorm";
    pub const PRELU: &str = "Struct:PRelu";
}

/// Trait for adapting tensor snapshots between different module formats
pub trait ModuleAdapter: Send + Sync {
    /// Adapt a tensor snapshot based on its container type and parameter name
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;

    /// Get alternative parameter name to try during matching
    ///
    /// When looking for a parameter in a module, this method provides an alternative
    /// name to try if the direct name doesn't match. This enables matching parameters
    /// with different naming conventions (e.g., PyTorch's "weight" vs Burn's "gamma").
    ///
    /// # Arguments
    /// * `param_name` - The parameter name we're looking for
    /// * `container_type` - The type of container module (e.g., "Struct:BatchNorm")
    ///
    /// # Returns
    /// Alternative parameter name to try, or None if no alternative exists
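    ///
    /// # Examples
    ///
    /// A minimal sketch (assuming `ModuleAdapter` and `PyTorchToBurnAdapter` are
    /// re-exported at the crate root, like `HalfPrecisionAdapter` in the examples below):
    ///
    /// ```rust
    /// # use burn_store::{ModuleAdapter, PyTorchToBurnAdapter};
    /// // Looking for Burn's "gamma" in a PyTorch source suggests trying "weight" instead.
    /// let alt = PyTorchToBurnAdapter.get_alternative_param_name("gamma", "Struct:LayerNorm");
    /// assert_eq!(alt.as_deref(), Some("weight"));
    /// ```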
    fn get_alternative_param_name(
        &self,
        _param_name: &str,
        _container_type: &str,
    ) -> Option<String> {
        None
    }

    /// Clone the adapter into a boxed trait object
    fn clone_box(&self) -> Box<dyn ModuleAdapter>;

    /// Chain adapters together, applying `self` first and then `next`.
    ///
    /// This is useful when importing model weights requires multiple transformations
    /// (e.g. PyTorch -> Burn layout conversion, then dtype casting, then custom remapping).
    ///
    /// The semantics follow a simple pipeline:
    /// - `adapt`: `next.adapt(&self.adapt(snapshot))`
    /// - `get_alternative_param_name`: ask `self` first; if it returns an alternative name,
    ///   ask `next` for an alternative to that name, falling back to `self`'s result.
    ///   If `self` returns `None`, ask `next` with the original name.
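    ///
    /// # Examples
    ///
    /// A minimal sketch of a chained pipeline (assuming these adapters are re-exported
    /// at the crate root, as in the `HalfPrecisionAdapter` examples below):
    ///
    /// ```rust
    /// # use burn_store::{ModuleAdapter, PyTorchToBurnAdapter, HalfPrecisionAdapter};
    /// let adapter = PyTorchToBurnAdapter.chain(HalfPrecisionAdapter::new());
    /// // Transposes/renames PyTorch parameters first, then casts F32 weights to F16.
    /// // store.with_from_adapter(adapter);
    /// ```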
    fn chain<A>(self, next: A) -> ChainAdapter
    where
        Self: Sized + 'static,
        A: ModuleAdapter + 'static,
    {
        ChainAdapter::new(self, next)
    }
}

impl Clone for Box<dyn ModuleAdapter> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

/// Adapter that applies two adapters in sequence.
///
/// This allows composing smaller adapters instead of creating one large monolithic adapter.
#[derive(Clone)]
pub struct ChainAdapter {
    first: Box<dyn ModuleAdapter>,
    second: Box<dyn ModuleAdapter>,
}

impl ChainAdapter {
    /// Create a new adapter chain.
    pub fn new<A, B>(first: A, second: B) -> Self
    where
        A: ModuleAdapter + 'static,
        B: ModuleAdapter + 'static,
    {
        Self {
            first: Box::new(first),
            second: Box::new(second),
        }
    }
}

impl ModuleAdapter for ChainAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        let snapshot = self.first.adapt(snapshot);
        self.second.adapt(&snapshot)
    }

    fn get_alternative_param_name(
        &self,
        param_name: &str,
        container_type: &str,
    ) -> Option<String> {
        if let Some(name) = self
            .first
            .get_alternative_param_name(param_name, container_type)
        {
            self.second
                .get_alternative_param_name(&name, container_type)
                .or(Some(name))
        } else {
            self.second
                .get_alternative_param_name(param_name, container_type)
        }
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Identity adapter that passes tensors through unchanged
#[derive(Debug, Clone, Default)]
pub struct IdentityAdapter;

impl ModuleAdapter for IdentityAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        snapshot.clone()
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Returns the default set of module types that `HalfPrecisionAdapter` converts.
///
/// Includes: Linear, Embedding, all Conv variants, LayerNorm, GroupNorm,
/// InstanceNorm, RmsNorm, PRelu.
///
/// Excludes BatchNorm by default because its running statistics (notably `running_var`)
/// can underflow or lose precision in F16.
fn default_half_precision_modules() -> HashSet<String> {
    let modules = [
        module_names::LINEAR,
        module_names::EMBEDDING,
        module_names::CONV1D,
        module_names::CONV2D,
        module_names::CONV3D,
        module_names::CONV_TRANSPOSE1D,
        module_names::CONV_TRANSPOSE2D,
        module_names::CONV_TRANSPOSE3D,
        module_names::DEFORM_CONV2D,
        module_names::LAYER_NORM,
        module_names::GROUP_NORM,
        module_names::INSTANCE_NORM,
        module_names::RMS_NORM,
        module_names::PRELU,
    ];
    modules.iter().map(|s| s.to_string()).collect()
}

/// Adapter for mixed-precision (F32/F16) model storage.
///
/// Auto-detects conversion direction from the snapshot's dtype:
/// - F32 source -> cast to F16 (typical for saving)
/// - F16 source -> cast to F32 (typical for loading)
/// - Other dtypes -> passed through unchanged
///
/// The same instance works for both `with_to_adapter` (save) and `with_from_adapter` (load).
///
/// By default, converts weights in: Linear, Embedding, Conv*, LayerNorm, GroupNorm,
/// InstanceNorm, RmsNorm, PRelu. BatchNorm is excluded because its running statistics
/// (notably `running_var`) can underflow or lose precision in F16.
///
/// # Examples
///
/// Default usage (same adapter for save and load):
/// ```rust
/// # use burn_store::HalfPrecisionAdapter;
/// let adapter = HalfPrecisionAdapter::new();
/// // store.with_to_adapter(adapter.clone());  // F32 -> F16 on save
/// // store.with_from_adapter(adapter);        // F16 -> F32 on load
/// ```
///
/// Exclude a module type:
/// ```rust
/// # use burn_store::HalfPrecisionAdapter;
/// let adapter = HalfPrecisionAdapter::new()
///     .without_module("LayerNorm");
/// ```
///
/// Add a custom module type:
/// ```rust
/// # use burn_store::HalfPrecisionAdapter;
/// let adapter = HalfPrecisionAdapter::new()
///     .with_module("CustomLayer");
/// ```
#[derive(Debug, Clone)]
pub struct HalfPrecisionAdapter {
    modules: HashSet<String>,
}

impl HalfPrecisionAdapter {
    /// Create a new adapter with the default set of modules.
    pub fn new() -> Self {
        Self {
            modules: default_half_precision_modules(),
        }
    }

    /// Add a module type to convert. Accepts both short (`"MyLayer"`) and
    /// qualified (`"Struct:MyLayer"`) forms.
    ///
    /// Note: short names are mapped to `"Struct:Name"`. If you have an Enum-based
    /// module, use the qualified form `"Enum:MyModule"` explicitly.
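    ///
    /// A short sketch of both forms (the `CustomLayer` and `MyEnumModule` names below
    /// are illustrative placeholders, not real Burn modules):
    ///
    /// ```rust
    /// # use burn_store::HalfPrecisionAdapter;
    /// let adapter = HalfPrecisionAdapter::new()
    ///     .with_module("CustomLayer")        // stored as "Struct:CustomLayer"
    ///     .with_module("Enum:MyEnumModule"); // qualified form is used verbatim
    /// ```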
    pub fn with_module(mut self, module_type: impl Into<String>) -> Self {
        let name = module_type.into();
        if name.contains(':') {
            self.modules.insert(name);
        } else {
            self.modules.insert(format!("Struct:{}", name));
        }
        self
    }

    /// Remove a module type from conversion. Accepts both short and qualified forms.
    pub fn without_module(mut self, module_type: impl Into<String>) -> Self {
        let name = module_type.into();
        let key = if name.contains(':') {
            name
        } else {
            format!("Struct:{}", name)
        };
        // Removing a type that is not in the set is a no-op.
        self.modules.remove(&key);
        self
    }

    /// Check whether the tensor belongs to a module that should be converted.
    fn should_convert(&self, snapshot: &TensorSnapshot) -> bool {
        snapshot
            .module_type()
            .is_some_and(|mt| self.modules.contains(&mt))
    }
}

impl Default for HalfPrecisionAdapter {
    fn default() -> Self {
        Self::new()
    }
}

impl ModuleAdapter for HalfPrecisionAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        // Determine target dtype from source: F32 -> F16, F16 -> F32, anything else -> skip
        let target_dtype = match snapshot.dtype {
            DType::F32 => DType::F16,
            DType::F16 => DType::F32,
            _ => return snapshot.clone(),
        };

        if !self.should_convert(snapshot) {
            return snapshot.clone();
        }

        let original_data_fn = snapshot.clone_data_fn();

        let cast_data_fn = Rc::new(move || {
            let data = original_data_fn()?;
            Ok(data.convert_dtype(target_dtype))
        });

        TensorSnapshot::from_closure(
            cast_data_fn,
            target_dtype,
            snapshot.shape.clone(),
            snapshot.path_stack.clone().unwrap_or_default(),
            snapshot.container_stack.clone().unwrap_or_default(),
            snapshot.tensor_id.unwrap_or_default(),
        )
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Adapter for converting from PyTorch format to Burn format
///
/// Handles:
/// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out])
/// - Normalization parameter renaming (weight → gamma, bias → beta)
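///
/// # Examples
///
/// A minimal usage sketch mirroring the `HalfPrecisionAdapter` examples above (the
/// commented `store` call is illustrative only):
///
/// ```rust
/// # use burn_store::PyTorchToBurnAdapter;
/// let adapter = PyTorchToBurnAdapter;
/// // store.with_from_adapter(adapter); // applied while loading PyTorch weights
/// ```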
#[derive(Debug, Clone, Default)]
pub struct PyTorchToBurnAdapter;

impl ModuleAdapter for PyTorchToBurnAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
    }

    fn get_alternative_param_name(
        &self,
        param_name: &str,
        container_type: &str,
    ) -> Option<String> {
        // For PyTorch->Burn: When looking for Burn names (gamma/beta), try PyTorch names (weight/bias)
        if is_normalization_layer(container_type) {
            burn_norm_param_to_pytorch(param_name).map(|s| s.to_string())
        } else {
            None
        }
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Adapter for converting from Burn format to PyTorch format
///
/// Handles:
/// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in])
/// - Normalization parameter renaming (gamma → weight, beta → bias)
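///
/// # Examples
///
/// A minimal usage sketch mirroring the examples above (assuming the adapter is
/// re-exported at the crate root; the commented `store` call is illustrative only):
///
/// ```rust
/// # use burn_store::BurnToPyTorchAdapter;
/// let adapter = BurnToPyTorchAdapter;
/// // store.with_to_adapter(adapter); // applied while saving to PyTorch format
/// ```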
#[derive(Debug, Clone, Default)]
pub struct BurnToPyTorchAdapter;

impl ModuleAdapter for BurnToPyTorchAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
    }

    fn get_alternative_param_name(
        &self,
        param_name: &str,
        container_type: &str,
    ) -> Option<String> {
        // For Burn->PyTorch: When looking for PyTorch names (weight/bias), try Burn names (gamma/beta)
        if is_normalization_layer(container_type) {
            pytorch_norm_param_to_burn(param_name).map(|s| s.to_string())
        } else {
            None
        }
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Direction of PyTorch conversion for parameter naming
#[derive(Debug, Clone, Copy)]
enum PyTorchConversionDirection {
    PyTorchToBurn,
    BurnToPyTorch,
}

/// Check if container type is a normalization layer
fn is_normalization_layer(container_type: &str) -> bool {
    matches!(
        container_type,
        module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM
    )
}

/// Map PyTorch normalization parameter name to Burn
fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> {
    match param_name {
        "weight" => Some("gamma"),
        "bias" => Some("beta"),
        _ => None,
    }
}

/// Map Burn normalization parameter name to PyTorch
fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> {
    match param_name {
        "gamma" => Some("weight"),
        "beta" => Some("bias"),
        _ => None,
    }
}

/// Core tensor adaptation logic for PyTorch format conversions
fn adapt_pytorch_tensor(
    snapshot: &TensorSnapshot,
    direction: PyTorchConversionDirection,
) -> TensorSnapshot {
    // Extract path and parameter name
    let (path_stack, param_name) = match get_path_and_param(snapshot) {
        Some(result) => result,
        None => return snapshot.clone(),
    };

    // Get module type for matching (ignores Vec/Array wrappers)
    let module_type = match snapshot.module_type() {
        Some(mt) => mt,
        None => return snapshot.clone(), // No user-defined module found
    };

    // Linear: transpose weight (bidirectional - same operation both ways)
    if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 {
        return transpose_2d_tensor(snapshot);
    }

    // Normalization layers: rename parameters based on direction
    if is_normalization_layer(&module_type) {
        let new_name = match direction {
            PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name),
            PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name),
        };

        if let Some(new_name) = new_name {
            return rename_parameter(snapshot, path_stack, new_name);
        }
    }

    snapshot.clone()
}

/// Extract path stack and parameter name from snapshot
fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
    let path_stack = snapshot.path_stack.as_ref()?;
    let param_name = path_stack.last()?.as_str();
    Some((path_stack.as_slice(), param_name))
}

/// Rename a parameter in the snapshot
fn rename_parameter(
    snapshot: &TensorSnapshot,
    path_stack: &[String],
    new_name: &str,
) -> TensorSnapshot {
    let mut new_path = path_stack.to_vec();
    *new_path.last_mut().unwrap() = new_name.to_string();

    TensorSnapshot::from_closure(
        snapshot.clone_data_fn(),
        snapshot.dtype,
        snapshot.shape.clone(),
        new_path,
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

/// Transpose a 2D tensor
fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
    if snapshot.shape.len() != 2 {
        return snapshot.clone();
    }

    let original_data_fn = snapshot.clone_data_fn();
    let dtype = snapshot.dtype;
    let transposed_shape = shape![snapshot.shape[1], snapshot.shape[0]];

    // Create a lazy closure that transposes when called
    let transposed_data_fn = Rc::new(move || {
        let data = original_data_fn()?;
        Ok(transpose_tensor_data(data))
    });

    TensorSnapshot::from_closure(
        transposed_data_fn,
        dtype,
        transposed_shape,
        snapshot.path_stack.clone().unwrap_or_default(),
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

/// Transpose tensor data (assumes 2D shape is already validated)
fn transpose_tensor_data(data: TensorData) -> TensorData {
    let shape = &data.shape;
    let rows = shape[0];
    let cols = shape[1];
    let transposed_shape = vec![cols, rows];

    // Get the raw bytes and element size
    let bytes = data.as_bytes();
    let element_size = data.dtype.size();

    // Create a new buffer for transposed data
    let mut transposed_bytes = vec![0u8; bytes.len()];

    // Transpose at the byte level - works for any data type
    for i in 0..rows {
        for j in 0..cols {
            let src_idx = (i * cols + j) * element_size;
            let dst_idx = (j * rows + i) * element_size;

            // Copy the bytes for this element
            transposed_bytes[dst_idx..dst_idx + element_size]
                .copy_from_slice(&bytes[src_idx..src_idx + element_size]);
        }
    }

    // Create new TensorData from transposed bytes
    TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use alloc::sync::Arc;
    use burn_tensor::{DType, Shape, TensorData};
    use core::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_module_names_match_burn_nn() {
        // If these types are renamed or moved in `burn-nn`, this test will fail to compile.
        #[allow(unused_imports)]
        use burn_nn::{
            BatchNorm, Embedding, GroupNorm, InstanceNorm, LayerNorm, Linear, PRelu, RmsNorm,
            conv::{
                Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d,
                DeformConv2d,
            },
        };

        assert_eq!(module_names::LINEAR, "Struct:Linear");
        assert_eq!(module_names::BATCH_NORM, "Struct:BatchNorm");
        assert_eq!(module_names::LAYER_NORM, "Struct:LayerNorm");
        assert_eq!(module_names::GROUP_NORM, "Struct:GroupNorm");
        assert_eq!(module_names::EMBEDDING, "Struct:Embedding");
        assert_eq!(module_names::CONV1D, "Struct:Conv1d");
        assert_eq!(module_names::CONV2D, "Struct:Conv2d");
        assert_eq!(module_names::CONV3D, "Struct:Conv3d");
        assert_eq!(module_names::CONV_TRANSPOSE1D, "Struct:ConvTranspose1d");
        assert_eq!(module_names::CONV_TRANSPOSE2D, "Struct:ConvTranspose2d");
        assert_eq!(module_names::CONV_TRANSPOSE3D, "Struct:ConvTranspose3d");
        assert_eq!(module_names::DEFORM_CONV2D, "Struct:DeformConv2d");
        assert_eq!(module_names::INSTANCE_NORM, "Struct:InstanceNorm");
        assert_eq!(module_names::RMS_NORM, "Struct:RmsNorm");
        assert_eq!(module_names::PRELU, "Struct:PRelu");
    }

    fn create_test_snapshot(path: &str, shape: Shape, container_type: &str) -> TensorSnapshot {
        let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
        let values = vec![1.0f32; shape.iter().product()];
        let data = TensorData::new(values, shape.clone());

        TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F32,
            shape,
            path_parts,
            vec![container_type.to_string()],
            burn_core::module::ParamId::new(),
        )
    }

    #[test]
    fn test_pytorch_to_burn_linear_weight() {
        let adapter = PyTorchToBurnAdapter;

        // Linear layer weight should be transposed
        let snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![5, 10]);

        // Linear layer bias should not be transposed
        let snapshot = create_test_snapshot("fc.bias", shape![10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![10]);
    }

    #[test]
    fn test_pytorch_to_burn_norm_params() {
        let adapter = PyTorchToBurnAdapter;

        // BatchNorm weight -> gamma
        let snapshot = create_test_snapshot("norm.weight", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.gamma");

        // BatchNorm bias -> beta
        let snapshot = create_test_snapshot("norm.bias", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.beta");
    }

    #[test]
    fn test_burn_to_pytorch_linear_weight() {
        let adapter = BurnToPyTorchAdapter;

        // Linear layer weight should be transposed
        let snapshot = create_test_snapshot("fc.weight", shape![5, 10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![10, 5]);
    }

    #[test]
    fn test_burn_to_pytorch_norm_params() {
        let adapter = BurnToPyTorchAdapter;

        // BatchNorm gamma -> weight
        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.weight");

        // BatchNorm beta -> bias
        let snapshot = create_test_snapshot("norm.beta", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.bias");
    }

    #[test]
    fn test_transpose_different_dtypes() {
        // Test that transpose works for different data types

        // Test with F32
        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], [2, 3]);
        let transposed = transpose_tensor_data(f32_data);
        assert_eq!(transposed.shape, shape![3, 2]);
        let values = transposed.to_vec::<f32>().unwrap();
        assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // Test with I32
        let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], [2, 3]);
        let transposed = transpose_tensor_data(i32_data);
        assert_eq!(transposed.shape, shape![3, 2]);
        let values = transposed.to_vec::<i32>().unwrap();
        assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);

        // Test with F64
        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]);
        let transposed = transpose_tensor_data(f64_data);
        assert_eq!(transposed.shape, shape![2, 2]);
        let values = transposed.to_vec::<f64>().unwrap();
        assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_no_container_info() {
        let adapter = PyTorchToBurnAdapter;

        // Without container info, adapter returns unchanged for non-norm parameters
        let mut snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR);
        snapshot.container_stack = None;

        // Without container info, no transformation occurs for linear layers
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, shape![10, 5]); // No transposition without container info

        // Test a non-linear, non-norm parameter - should pass through unchanged
        let mut snapshot2 = create_test_snapshot("other.weight", shape![10, 5], "Struct:Other");
        snapshot2.container_stack = None;
        let adapted2 = adapter.adapt(&snapshot2);
        assert_eq!(adapted2.shape, shape![10, 5]); // No transposition
    }

    #[derive(Clone)]
    struct RenameParamAdapter {
        from: &'static str,
        to: &'static str,
        called: Arc<AtomicUsize>,
    }

    impl ModuleAdapter for RenameParamAdapter {
        fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
            self.called.fetch_add(1, Ordering::Relaxed);

            let path_stack = match snapshot.path_stack.as_ref() {
                Some(stack) => stack,
                None => return snapshot.clone(),
            };
            let param = match path_stack.last() {
                Some(p) => p.as_str(),
                None => return snapshot.clone(),
            };
            if param != self.from {
                return snapshot.clone();
            }

            let mut new_path = path_stack.to_vec();
            *new_path.last_mut().unwrap() = self.to.to_string();

            TensorSnapshot::from_closure(
                snapshot.clone_data_fn(),
                snapshot.dtype,
                snapshot.shape.clone(),
                new_path,
                snapshot.container_stack.clone().unwrap_or_default(),
                snapshot.tensor_id.unwrap_or_default(),
            )
        }

        fn get_alternative_param_name(
            &self,
            _param_name: &str,
            _container_type: &str,
        ) -> Option<String> {
            None
        }

        fn clone_box(&self) -> Box<dyn ModuleAdapter> {
            Box::new(self.clone())
        }
    }

    #[derive(Clone)]
    struct AltNameAdapter {
        from: &'static str,
        to: &'static str,
        called: Arc<AtomicUsize>,
    }

    impl ModuleAdapter for AltNameAdapter {
        fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
            TensorSnapshot::from_closure(
                snapshot.clone_data_fn(),
                snapshot.dtype,
                snapshot.shape.clone(),
                snapshot.path_stack.clone().unwrap_or_default(),
                snapshot.container_stack.clone().unwrap_or_default(),
                snapshot.tensor_id.unwrap_or_default(),
            )
        }

        fn get_alternative_param_name(
            &self,
            param_name: &str,
            _container_type: &str,
        ) -> Option<String> {
            self.called.fetch_add(1, Ordering::Relaxed);
            if param_name == self.from {
                Some(self.to.to_string())
            } else {
                None
            }
        }

        fn clone_box(&self) -> Box<dyn ModuleAdapter> {
            Box::new(self.clone())
        }
    }

    #[test]
    fn test_chain_adapter_pipes_adapt() {
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));

        let a = RenameParamAdapter {
            from: "weight",
            to: "a",
            called: called1.clone(),
        };
        let b = RenameParamAdapter {
            from: "a",
            to: "b",
            called: called2.clone(),
        };

        let chain = a.chain(b);
        let snapshot = create_test_snapshot("fc.weight", shape![2, 2], module_names::LINEAR);
        let adapted = chain.adapt(&snapshot);

        assert_eq!(adapted.full_path(), "fc.b");
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn test_chain_adapter_alternative_name_pipes_and_fallbacks() {
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));

        let a = AltNameAdapter {
            from: "gamma",
            to: "weight",
            called: called1.clone(),
        };
        let b = AltNameAdapter {
            from: "weight",
            to: "scale",
            called: called2.clone(),
        };

        let chain = a.chain(b);
        let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("scale"));
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);

        // If the second adapter doesn't have a mapping for the first alternative,
        // fall back to the first alternative name.
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));
        let a = AltNameAdapter {
            from: "gamma",
            to: "weight",
            called: called1.clone(),
        };
        let b = AltNameAdapter {
            from: "something-else",
            to: "unused",
            called: called2.clone(),
        };
        let chain = a.chain(b);
        let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("weight"));
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);

        // If the first adapter doesn't provide an alternative, try the second with the original name.
        let called1 = Arc::new(AtomicUsize::new(0));
        let called2 = Arc::new(AtomicUsize::new(0));
        let a = AltNameAdapter {
            from: "something-else",
            to: "unused",
            called: called1.clone(),
        };
        let b = AltNameAdapter {
            from: "gamma",
            to: "weight",
            called: called2.clone(),
        };
        let chain = a.chain(b);
        let alt = chain.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("weight"));
        assert_eq!(called1.load(Ordering::Relaxed), 1);
        assert_eq!(called2.load(Ordering::Relaxed), 1);

        // clone_box must preserve behavior.
        let boxed = chain.clone_box();
        let alt = boxed.get_alternative_param_name("gamma", module_names::LAYER_NORM);
        assert_eq!(alt.as_deref(), Some("weight"));
    }

    #[test]
    fn test_half_precision_f32_to_f16() {
        let adapter = HalfPrecisionAdapter::new();
        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F16);
        assert_eq!(adapted.shape, shape![2, 3]);

        let data = adapted.to_data().unwrap();
        assert_eq!(data.dtype, DType::F16);
    }

    #[test]
    fn test_half_precision_f16_to_f32() {
        let adapter = HalfPrecisionAdapter::new();

        // Create an F16 snapshot
        let values = vec![1.0f32; 6];
        let data = TensorData::new(values, shape![2, 3]).convert_dtype(DType::F16);
        let path_parts = vec!["fc".to_string(), "weight".to_string()];
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F16,
            shape![2, 3],
            path_parts,
            vec![module_names::LINEAR.to_string()],
            burn_core::module::ParamId::new(),
        );

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F32);
    }

    #[test]
    fn test_half_precision_skips_batch_norm() {
        let adapter = HalfPrecisionAdapter::new();

        // BatchNorm is excluded by default
        let snapshot = create_test_snapshot("norm.weight", shape![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F32); // unchanged
    }

    #[test]
    fn test_half_precision_converts_default_modules() {
        let adapter = HalfPrecisionAdapter::new();

        // Linear
        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        // Embedding
        let snapshot = create_test_snapshot("emb.weight", shape![100, 64], module_names::EMBEDDING);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        // Conv2d
        let snapshot =
            create_test_snapshot("conv.weight", shape![3, 3, 3, 3], module_names::CONV2D);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        // LayerNorm (included by default)
        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        // GroupNorm
        let snapshot = create_test_snapshot("gn.gamma", shape![10], module_names::GROUP_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);

        // RmsNorm
        let snapshot = create_test_snapshot("rms.weight", shape![10], module_names::RMS_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    #[test]
    fn test_half_precision_without_module() {
        let adapter = HalfPrecisionAdapter::new().without_module("LayerNorm");

        // LayerNorm removed from conversion set
        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32);

        // Linear still converted
        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    #[test]
    fn test_half_precision_with_module() {
        let adapter = HalfPrecisionAdapter::new().with_module("CustomLayer");

        // Custom module should now be converted
        let snapshot = create_test_snapshot("custom.weight", shape![5], "Struct:CustomLayer");
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    #[test]
    fn test_half_precision_with_qualified_name() {
        let adapter = HalfPrecisionAdapter::new().with_module("Struct:CustomLayer");

        let snapshot = create_test_snapshot("custom.weight", shape![5], "Struct:CustomLayer");
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }

    #[test]
    fn test_half_precision_chain() {
        let adapter = PyTorchToBurnAdapter.chain(HalfPrecisionAdapter::new());

        let snapshot = create_test_snapshot("fc.weight", shape![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);

        // Should be both transposed and cast
        assert_eq!(adapted.shape, shape![5, 10]);
        assert_eq!(adapted.dtype, DType::F16);
    }

    #[test]
    fn test_half_precision_skips_no_container() {
        let adapter = HalfPrecisionAdapter::new();
        let mut snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        snapshot.container_stack = None;

        // No module type info: skip
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.dtype, DType::F32);
    }

    #[test]
    fn test_half_precision_skips_non_float() {
        use burn_tensor::quantization::QuantScheme;

        let adapter = HalfPrecisionAdapter::new();

        // QFloat source: skip
        let qfloat_dtype = DType::QFloat(QuantScheme::default());
        let snapshot = create_test_snapshot("fc.weight", shape![2, 3], module_names::LINEAR);
        let qfloat_snapshot = TensorSnapshot::from_closure(
            snapshot.clone_data_fn(),
            qfloat_dtype,
            snapshot.shape.clone(),
            snapshot.path_stack.clone().unwrap_or_default(),
            snapshot.container_stack.clone().unwrap_or_default(),
            snapshot.tensor_id.unwrap_or_default(),
        );
        let adapted = adapter.adapt(&qfloat_snapshot);
        assert_eq!(adapted.dtype, qfloat_dtype);
    }

    #[test]
    fn test_half_precision_default_module_count() {
        let adapter = HalfPrecisionAdapter::new();
        // 14 modules: Linear, Embedding, Conv1d-3d, ConvTranspose1d-3d,
        // DeformConv2d, LayerNorm, GroupNorm, InstanceNorm, RmsNorm, PRelu
        assert_eq!(adapter.modules.len(), 14);
    }

    #[test]
    fn test_half_precision_without_module_qualified() {
        let adapter = HalfPrecisionAdapter::new().without_module("Struct:LayerNorm");

        let snapshot = create_test_snapshot("norm.gamma", shape![10], module_names::LAYER_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F32);
    }

    #[test]
    fn test_half_precision_with_module_batch_norm_opt_in() {
        let adapter = HalfPrecisionAdapter::new().with_module("BatchNorm");

        let snapshot = create_test_snapshot("bn.weight", shape![10], module_names::BATCH_NORM);
        assert_eq!(adapter.adapt(&snapshot).dtype, DType::F16);
    }
1019}