burn_store/
adapter.rs

//! Module adapters for transforming tensors between different formats
//!
//! This module provides adapters that handle differences between PyTorch and Burn:
//! - Linear layer weight transposition
//! - Normalization parameter naming (weight/bias vs gamma/beta)

use crate::TensorSnapshot;

use alloc::boxed::Box;
use alloc::rc::Rc;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec;

use burn_tensor::TensorData;

// Module type names as they appear in the container_type field
// These come from the Module derive macro which uses stringify! on the struct name
// Format: "Struct:TypeName" for user-defined structs
mod module_names {
    // Import the types to ensure they exist at compile time
    // If these types are renamed or moved, we'll get a compile error
    #[allow(unused_imports)]
    use burn_nn::{BatchNorm, GroupNorm, LayerNorm, Linear};

    // The actual string constants that match what the Module derive macro produces
    // The imports above ensure these types exist at compile time
    pub const LINEAR: &str = "Struct:Linear";
    pub const BATCH_NORM: &str = "Struct:BatchNorm";
    pub const LAYER_NORM: &str = "Struct:LayerNorm";
    pub const GROUP_NORM: &str = "Struct:GroupNorm";
}

/// Trait for adapting tensor snapshots between different module formats
pub trait ModuleAdapter: Send + Sync {
    /// Adapt a tensor snapshot based on its container type and parameter name
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;

    /// Get an alternative parameter name to try during matching
    ///
    /// When looking for a parameter in a module, this method provides an alternative
    /// name to try if the direct name doesn't match. This enables matching parameters
    /// with different naming conventions (e.g., PyTorch's "weight" vs Burn's "gamma").
    ///
    /// # Arguments
    /// * `param_name` - The parameter name we're looking for
    /// * `container_type` - The type of the container module (e.g., "Struct:BatchNorm")
    ///
    /// # Returns
    /// The alternative parameter name to try, or `None` if no alternative exists
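    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest), using the `PyTorchToBurnAdapter`
    /// defined below; the container-type string matches the `module_names` constants:
    ///
    /// ```rust,ignore
    /// let adapter = PyTorchToBurnAdapter;
    /// // Burn's BatchNorm exposes `gamma`; a PyTorch file stores the same tensor as `weight`.
    /// assert_eq!(
    ///     adapter.get_alternative_param_name("gamma", "Struct:BatchNorm"),
    ///     Some("weight".to_string())
    /// );
    /// ```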
    fn get_alternative_param_name(
        &self,
        _param_name: &str,
        _container_type: &str,
    ) -> Option<String> {
        None
    }

    /// Clone the adapter into a boxed trait object
    fn clone_box(&self) -> Box<dyn ModuleAdapter>;
}

impl Clone for Box<dyn ModuleAdapter> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

/// Identity adapter that passes tensors through unchanged
#[derive(Debug, Clone, Default)]
pub struct IdentityAdapter;

impl ModuleAdapter for IdentityAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        snapshot.clone()
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Adapter for converting from PyTorch format to Burn format
///
/// Handles:
/// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out])
/// - Normalization parameter renaming (weight → gamma, bias → beta)
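///
/// # Example
///
/// A minimal sketch (not compiled as a doctest), assuming `snapshot` is the
/// `TensorSnapshot` of a `Linear` weight collected from a PyTorch checkpoint with
/// shape `[out_features, in_features]`:
///
/// ```rust,ignore
/// let adapter = PyTorchToBurnAdapter;
/// let adapted = adapter.adapt(&snapshot);
/// // Only the shape metadata changes eagerly; the data itself is transposed lazily
/// // when the snapshot's data closure is evaluated.
/// assert_eq!(adapted.shape, vec![in_features, out_features]);
/// ```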
#[derive(Debug, Clone, Default)]
pub struct PyTorchToBurnAdapter;

impl ModuleAdapter for PyTorchToBurnAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
    }

    fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
        // For PyTorch->Burn: When looking for Burn names (gamma/beta), try PyTorch names (weight/bias)
        if is_normalization_layer(container_type) {
            burn_norm_param_to_pytorch(param_name).map(|s| s.to_string())
        } else {
            None
        }
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Adapter for converting from Burn format to PyTorch format
///
/// Handles:
/// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in])
/// - Normalization parameter renaming (gamma → weight, beta → bias)
#[derive(Debug, Clone, Default)]
pub struct BurnToPyTorchAdapter;

impl ModuleAdapter for BurnToPyTorchAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
    }

    fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
        // For Burn->PyTorch: When looking for PyTorch names (weight/bias), try Burn names (gamma/beta)
        if is_normalization_layer(container_type) {
            pytorch_norm_param_to_burn(param_name).map(|s| s.to_string())
        } else {
            None
        }
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Direction of PyTorch conversion for parameter naming
#[derive(Debug, Clone, Copy)]
enum PyTorchConversionDirection {
    PyTorchToBurn,
    BurnToPyTorch,
}

/// Check if container type is a normalization layer
fn is_normalization_layer(container_type: &str) -> bool {
    matches!(
        container_type,
        module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM
    )
}

/// Map PyTorch normalization parameter name to Burn
fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> {
    match param_name {
        "weight" => Some("gamma"),
        "bias" => Some("beta"),
        _ => None,
    }
}

/// Map Burn normalization parameter name to PyTorch
fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> {
    match param_name {
        "gamma" => Some("weight"),
        "beta" => Some("bias"),
        _ => None,
    }
}

/// Core tensor adaptation logic for PyTorch format conversions
fn adapt_pytorch_tensor(
    snapshot: &TensorSnapshot,
    direction: PyTorchConversionDirection,
) -> TensorSnapshot {
    // Extract path and parameter name
    let (path_stack, param_name) = match get_path_and_param(snapshot) {
        Some(result) => result,
        None => return snapshot.clone(),
    };

    // Get module type for matching (ignores Vec/Array wrappers)
    let module_type = match snapshot.module_type() {
        Some(mt) => mt,
        None => return snapshot.clone(), // No user-defined module found
    };

    // Linear: transpose weight (bidirectional - same operation both ways)
    if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 {
        return transpose_2d_tensor(snapshot);
    }

    // Normalization layers: rename parameters based on direction
    if is_normalization_layer(&module_type) {
        let new_name = match direction {
            PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name),
            PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name),
        };

        if let Some(new_name) = new_name {
            return rename_parameter(snapshot, path_stack, new_name);
        }
    }

    snapshot.clone()
}

/// Extract path stack and parameter name from snapshot
fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
    let path_stack = snapshot.path_stack.as_ref()?;
    let param_name = path_stack.last()?.as_str();
    Some((path_stack.as_slice(), param_name))
}

/// Rename a parameter in the snapshot
fn rename_parameter(
    snapshot: &TensorSnapshot,
    path_stack: &[String],
    new_name: &str,
) -> TensorSnapshot {
    let mut new_path = path_stack.to_vec();
    *new_path.last_mut().unwrap() = new_name.to_string();

    TensorSnapshot::from_closure(
        snapshot.clone_data_fn(),
        snapshot.dtype,
        snapshot.shape.clone(),
        new_path,
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

/// Transpose a 2D tensor
fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
    if snapshot.shape.len() != 2 {
        return snapshot.clone();
    }

    let original_data_fn = snapshot.clone_data_fn();
    let dtype = snapshot.dtype;
    let transposed_shape = vec![snapshot.shape[1], snapshot.shape[0]];

    // Create a lazy closure that transposes when called
    let transposed_data_fn = Rc::new(move || {
        let data = original_data_fn()?;
        Ok(transpose_tensor_data(data))
    });

    TensorSnapshot::from_closure(
        transposed_data_fn,
        dtype,
        transposed_shape,
        snapshot.path_stack.clone().unwrap_or_default(),
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

/// Transpose tensor data (assumes 2D shape is already validated)
fn transpose_tensor_data(data: TensorData) -> TensorData {
    let shape = &data.shape;
    let rows = shape[0];
    let cols = shape[1];
    let transposed_shape = vec![cols, rows];

    // Get the raw bytes and element size
    let bytes = data.as_bytes();
    let element_size = data.dtype.size();

    // Create a new buffer for transposed data
    let mut transposed_bytes = vec![0u8; bytes.len()];

    // Transpose at the byte level - works for any data type
    for i in 0..rows {
        for j in 0..cols {
            let src_idx = (i * cols + j) * element_size;
            let dst_idx = (j * rows + i) * element_size;

            // Copy the bytes for this element
            transposed_bytes[dst_idx..dst_idx + element_size]
                .copy_from_slice(&bytes[src_idx..src_idx + element_size]);
        }
    }

    // Create new TensorData from transposed bytes
    TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use burn_tensor::{DType, TensorData};

    fn create_test_snapshot(path: &str, shape: Vec<usize>, container_type: &str) -> TensorSnapshot {
        let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
        let values = vec![1.0f32; shape.iter().product()];
        let data = TensorData::new(values, shape.clone());

        TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F32,
            shape,
            path_parts,
            vec![container_type.to_string()],
            burn_core::module::ParamId::new(),
        )
    }

    #[test]
    fn test_pytorch_to_burn_linear_weight() {
        let adapter = PyTorchToBurnAdapter;

        // Linear layer weight should be transposed
        let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![5, 10]);

        // Linear layer bias should not be transposed
        let snapshot = create_test_snapshot("fc.bias", vec![10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10]);
    }

    #[test]
    fn test_pytorch_to_burn_norm_params() {
        let adapter = PyTorchToBurnAdapter;

        // BatchNorm weight -> gamma
        let snapshot = create_test_snapshot("norm.weight", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.gamma");

        // BatchNorm bias -> beta
        let snapshot = create_test_snapshot("norm.bias", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.beta");
    }

    #[test]
    fn test_burn_to_pytorch_linear_weight() {
        let adapter = BurnToPyTorchAdapter;

        // Linear layer weight should be transposed
        let snapshot = create_test_snapshot("fc.weight", vec![5, 10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]);
    }

    #[test]
    fn test_burn_to_pytorch_norm_params() {
        let adapter = BurnToPyTorchAdapter;

        // BatchNorm gamma -> weight
        let snapshot = create_test_snapshot("norm.gamma", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.weight");

        // BatchNorm beta -> bias
        let snapshot = create_test_snapshot("norm.beta", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.bias");
    }

    #[test]
    fn test_transpose_different_dtypes() {
        // Test that transpose works for different data types

        // Test with F32
        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let transposed = transpose_tensor_data(f32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<f32>().unwrap();
        assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // Test with I32
        let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]);
        let transposed = transpose_tensor_data(i32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<i32>().unwrap();
        assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);

        // Test with F64
        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]);
        let transposed = transpose_tensor_data(f64_data);
        assert_eq!(transposed.shape, vec![2, 2]);
        let values = transposed.to_vec::<f64>().unwrap();
        assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_no_container_info() {
        let adapter = PyTorchToBurnAdapter;

        // Without container info, the module type cannot be identified,
        // so the snapshot passes through unchanged
        let mut snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        snapshot.container_stack = None;

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]); // No transposition without container info

        // A non-linear, non-norm parameter should also pass through unchanged
        let mut snapshot2 = create_test_snapshot("other.weight", vec![10, 5], "Struct:Other");
        snapshot2.container_stack = None;
        let adapted2 = adapter.adapt(&snapshot2);
        assert_eq!(adapted2.shape, vec![10, 5]); // No transposition
    }
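
    // The next two tests exercise the alternative-name lookup used during matching
    // and the identity adapter's pass-through; they rely only on the adapters and
    // constants defined in this file.
    #[test]
    fn test_alternative_param_names() {
        // PyTorch -> Burn: when asked for Burn's names (gamma/beta), suggest the
        // PyTorch names (weight/bias) stored in the file
        let adapter = PyTorchToBurnAdapter;
        assert_eq!(
            adapter.get_alternative_param_name("gamma", module_names::BATCH_NORM),
            Some("weight".to_string())
        );
        assert_eq!(
            adapter.get_alternative_param_name("beta", module_names::LAYER_NORM),
            Some("bias".to_string())
        );
        // Non-normalization containers have no alternative names
        assert_eq!(
            adapter.get_alternative_param_name("gamma", module_names::LINEAR),
            None
        );

        // Burn -> PyTorch: the mapping runs in the opposite direction
        let adapter = BurnToPyTorchAdapter;
        assert_eq!(
            adapter.get_alternative_param_name("weight", module_names::GROUP_NORM),
            Some("gamma".to_string())
        );
        assert_eq!(
            adapter.get_alternative_param_name("bias", module_names::BATCH_NORM),
            Some("beta".to_string())
        );
    }

    #[test]
    fn test_identity_adapter_passthrough() {
        // The identity adapter never renames or transposes, even for Linear weights
        let adapter = IdentityAdapter;
        let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]);
        assert_eq!(adapted.full_path(), "fc.weight");
    }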
}