// burn_store/adapter.rs

//! Module adapters for transforming tensors between different formats
//!
//! This module provides adapters that handle differences between PyTorch and Burn:
//! - Linear layer weight transposition
//! - Normalization parameter naming (weight/bias vs gamma/beta)
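//!
//! A minimal usage sketch (assumes a `TensorSnapshot` named `snapshot` is in
//! scope, e.g. one read from a PyTorch checkpoint):
//!
//! ```rust,ignore
//! let adapter = PyTorchToBurnAdapter;
//!
//! // A `Linear` weight of shape [out, in] comes back with shape [in, out];
//! // `weight`/`bias` on normalization layers are renamed to `gamma`/`beta`.
//! let adapted = adapter.adapt(&snapshot);
//! ```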

use crate::TensorSnapshot;

use alloc::boxed::Box;
use alloc::rc::Rc;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec;

use burn_tensor::TensorData;

// Module type names as they appear in the container_type field
// These come from the Module derive macro which uses stringify! on the struct name
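// For example, a module field of type `Linear<B>` shows up with the string
// "Linear" as the last entry of a snapshot's `container_stack`.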
mod module_names {
    // Import the types to ensure they exist at compile time
    // If these types are renamed or moved, we'll get a compile error
    #[allow(unused_imports)]
    use burn_nn::{BatchNorm, GroupNorm, LayerNorm, Linear};

    // The actual string constants that match what the Module derive macro produces
    // The imports above ensure these types exist at compile-time
    pub const LINEAR: &str = "Linear";
    pub const BATCH_NORM: &str = "BatchNorm";
    pub const LAYER_NORM: &str = "LayerNorm";
    pub const GROUP_NORM: &str = "GroupNorm";
}

/// Trait for adapting tensor snapshots between different module formats
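///
/// A custom adapter implements both methods; a minimal sketch (the adapter
/// name below is illustrative, not part of the crate):
///
/// ```rust,ignore
/// #[derive(Debug, Clone)]
/// struct NoOpAdapter;
///
/// impl ModuleAdapter for NoOpAdapter {
///     fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
///         // A real adapter would inspect `snapshot.container_stack` and
///         // `snapshot.path_stack` here before transforming anything.
///         snapshot.clone()
///     }
///
///     fn clone_box(&self) -> Box<dyn ModuleAdapter> {
///         Box::new(self.clone())
///     }
/// }
/// ```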
pub trait ModuleAdapter: Send + Sync {
    /// Adapt a tensor snapshot based on its container type and parameter name
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;

    /// Clone the adapter into a boxed trait object
    fn clone_box(&self) -> Box<dyn ModuleAdapter>;
}

impl Clone for Box<dyn ModuleAdapter> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

/// Identity adapter that passes tensors through unchanged
#[derive(Debug, Clone, Default)]
pub struct IdentityAdapter;

impl ModuleAdapter for IdentityAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        snapshot.clone()
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Adapter for converting from PyTorch format to Burn format
///
/// Handles:
/// - Linear layer weight transposition (PyTorch: [out, in] → Burn: [in, out])
/// - Normalization parameter renaming (weight → gamma, bias → beta)
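///
/// A small sketch of the effect on snapshots loaded from a PyTorch checkpoint
/// (paths and shapes are illustrative):
///
/// ```rust,ignore
/// let adapter = PyTorchToBurnAdapter;
///
/// // "fc.weight" with shape [out, in] -> same path, shape [in, out]
/// let weight = adapter.adapt(&fc_weight_snapshot);
///
/// // "norm.weight" -> "norm.gamma", "norm.bias" -> "norm.beta"
/// let gamma = adapter.adapt(&norm_weight_snapshot);
/// ```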
#[derive(Debug, Clone, Default)]
pub struct PyTorchToBurnAdapter;

impl ModuleAdapter for PyTorchToBurnAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Adapter for converting from Burn format to PyTorch format
///
/// Handles:
/// - Linear layer weight transposition (Burn: [in, out] → PyTorch: [out, in])
/// - Normalization parameter renaming (gamma → weight, beta → bias)
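///
/// The inverse of [`PyTorchToBurnAdapter`]; a short sketch (paths are
/// illustrative):
///
/// ```rust,ignore
/// let adapter = BurnToPyTorchAdapter;
///
/// // "fc.weight" with shape [in, out] -> shape [out, in]
/// // "norm.gamma" -> "norm.weight", "norm.beta" -> "norm.bias"
/// let exported = adapter.adapt(&snapshot);
/// ```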
#[derive(Debug, Clone, Default)]
pub struct BurnToPyTorchAdapter;

impl ModuleAdapter for BurnToPyTorchAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

/// Direction of PyTorch conversion for parameter naming
#[derive(Debug, Clone, Copy)]
enum PyTorchConversionDirection {
    PyTorchToBurn,
    BurnToPyTorch,
}

/// Core tensor adaptation logic for PyTorch format conversions
fn adapt_pytorch_tensor(
    snapshot: &TensorSnapshot,
    direction: PyTorchConversionDirection,
) -> TensorSnapshot {
    // Extract path and parameter name
    let (path_stack, param_name) = match get_path_and_param(snapshot) {
        Some(result) => result,
        None => return snapshot.clone(),
    };

    // Get container type
    let container_type = match snapshot.container_stack.as_ref().and_then(|s| s.last()) {
        Some(ct) => ct,
        None => return snapshot.clone(),
    };

    match container_type.as_str() {
        // Linear: transpose the weight (transposition is its own inverse, so
        // the same operation applies in both directions)
        module_names::LINEAR if param_name == "weight" && snapshot.shape.len() == 2 => {
            transpose_2d_tensor(snapshot)
        }
        // Normalization layers: rename parameters based on direction
        module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM => {
            let new_name = match direction {
                PyTorchConversionDirection::PyTorchToBurn => match param_name {
                    "weight" => "gamma",
                    "bias" => "beta",
                    _ => return snapshot.clone(),
                },
                PyTorchConversionDirection::BurnToPyTorch => match param_name {
                    "gamma" => "weight",
                    "beta" => "bias",
                    _ => return snapshot.clone(),
                },
            };
            rename_parameter(snapshot, path_stack, new_name)
        }
        _ => snapshot.clone(),
    }
}

/// Extract path stack and parameter name from snapshot
fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
    let path_stack = snapshot.path_stack.as_ref()?;
    let param_name = path_stack.last()?.as_str();
    Some((path_stack.as_slice(), param_name))
}

/// Rename a parameter in the snapshot
fn rename_parameter(
    snapshot: &TensorSnapshot,
    path_stack: &[String],
    new_name: &str,
) -> TensorSnapshot {
    let mut new_path = path_stack.to_vec();
    *new_path.last_mut().unwrap() = new_name.to_string();

    TensorSnapshot::from_closure(
        snapshot.clone_data_fn(),
        snapshot.dtype,
        snapshot.shape.clone(),
        new_path,
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

/// Transpose a 2D tensor
fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
    if snapshot.shape.len() != 2 {
        return snapshot.clone();
    }

    let original_data_fn = snapshot.clone_data_fn();
    let dtype = snapshot.dtype;
    let transposed_shape = vec![snapshot.shape[1], snapshot.shape[0]];

    // Create a lazy closure that transposes when called
    let transposed_data_fn = Rc::new(move || {
        let data = original_data_fn()?;
        Ok(transpose_tensor_data(data))
    });

    TensorSnapshot::from_closure(
        transposed_data_fn,
        dtype,
        transposed_shape,
        snapshot.path_stack.clone().unwrap_or_default(),
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

/// Transpose tensor data (assumes 2D shape is already validated)
fn transpose_tensor_data(data: TensorData) -> TensorData {
    let shape = &data.shape;
    let rows = shape[0];
    let cols = shape[1];
    let transposed_shape = vec![cols, rows];

    // Get the raw bytes and element size
    let bytes = data.as_bytes();
    let element_size = data.dtype.size();

    // Create a new buffer for transposed data
    let mut transposed_bytes = vec![0u8; bytes.len()];

    // Transpose at the byte level - works for any data type
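    // For example, with a 2x3 tensor the element at (i, j) = (0, 2) sits at
    // flat source index 0 * 3 + 2 = 2 and moves to flat destination index
    // 2 * 2 + 0 = 4 in the 3x2 output.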
    for i in 0..rows {
        for j in 0..cols {
            let src_idx = (i * cols + j) * element_size;
            let dst_idx = (j * rows + i) * element_size;

            // Copy the bytes for this element
            transposed_bytes[dst_idx..dst_idx + element_size]
                .copy_from_slice(&bytes[src_idx..src_idx + element_size]);
        }
    }

    // Create new TensorData from transposed bytes
    TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use burn_tensor::{DType, TensorData};

    fn create_test_snapshot(path: &str, shape: Vec<usize>, container_type: &str) -> TensorSnapshot {
        let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
        let values = vec![1.0f32; shape.iter().product()];
        let data = TensorData::new(values, shape.clone());

        TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F32,
            shape,
            path_parts,
            vec![container_type.to_string()],
            burn_core::module::ParamId::new(),
        )
    }

    #[test]
    fn test_pytorch_to_burn_linear_weight() {
        let adapter = PyTorchToBurnAdapter;

        // Linear layer weight should be transposed
        let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![5, 10]);

        // Linear layer bias should not be transposed
        let snapshot = create_test_snapshot("fc.bias", vec![10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10]);
    }

    #[test]
    fn test_pytorch_to_burn_norm_params() {
        let adapter = PyTorchToBurnAdapter;

        // BatchNorm weight -> gamma
        let snapshot = create_test_snapshot("norm.weight", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.gamma");

        // BatchNorm bias -> beta
        let snapshot = create_test_snapshot("norm.bias", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.beta");
    }

    #[test]
    fn test_burn_to_pytorch_linear_weight() {
        let adapter = BurnToPyTorchAdapter;

        // Linear layer weight should be transposed
        let snapshot = create_test_snapshot("fc.weight", vec![5, 10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]);
    }

    #[test]
    fn test_burn_to_pytorch_norm_params() {
        let adapter = BurnToPyTorchAdapter;

        // BatchNorm gamma -> weight
        let snapshot = create_test_snapshot("norm.gamma", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.weight");

        // BatchNorm beta -> bias
        let snapshot = create_test_snapshot("norm.beta", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.bias");
    }

    #[test]
    fn test_transpose_different_dtypes() {
        // Test that transpose works for different data types

        // Test with F32
        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let transposed = transpose_tensor_data(f32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<f32>().unwrap();
        assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // Test with I32
        let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]);
        let transposed = transpose_tensor_data(i32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<i32>().unwrap();
        assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);

        // Test with F64
        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]);
        let transposed = transpose_tensor_data(f64_data);
        assert_eq!(transposed.shape, vec![2, 2]);
        let values = transposed.to_vec::<f64>().unwrap();
        assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_no_container_info() {
        let adapter = PyTorchToBurnAdapter;

        // Without container info, the adapter returns the snapshot unchanged
        let mut snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        snapshot.container_stack = None;

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]); // No transposition without container info

        // A non-Linear, non-norm parameter should also pass through unchanged
        let mut snapshot2 = create_test_snapshot("other.weight", vec![10, 5], "Other");
        snapshot2.container_stack = None;
        let adapted2 = adapter.adapt(&snapshot2);
        assert_eq!(adapted2.shape, vec![10, 5]); // No transposition
    }
}