1use crate::TensorSnapshot;
8
9use alloc::boxed::Box;
10use alloc::rc::Rc;
11use alloc::string::String;
12use alloc::string::ToString;
13use alloc::vec;
14
15use burn_tensor::TensorData;
16
/// Canonical container-type names that adapters match against a snapshot's
/// `container_stack` to decide which adaptation rule applies.
mod module_names {
    // NOTE(review): these types are imported but unused — presumably so the
    // `use` fails to compile if a burn-nn type is renamed, keeping the string
    // constants below in sync with the real type names. Confirm intent.
    #[allow(unused_imports)]
    use burn_nn::{BatchNorm, GroupNorm, LayerNorm, Linear};

    pub const LINEAR: &str = "Linear";
    pub const BATCH_NORM: &str = "BatchNorm";
    pub const LAYER_NORM: &str = "LayerNorm";
    pub const GROUP_NORM: &str = "GroupNorm";
}
32
/// Transforms tensor snapshots between module conventions (e.g. PyTorch
/// parameter names/layouts vs. Burn's).
///
/// `Send + Sync` so a `Box<dyn ModuleAdapter>` can be shared across threads.
pub trait ModuleAdapter: Send + Sync {
    /// Returns an adapted copy of `snapshot`; the input is not mutated.
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;

    /// Clones this adapter into a boxed trait object. Needed because `Clone`
    /// itself is not object-safe.
    fn clone_box(&self) -> Box<dyn ModuleAdapter>;
}
41
// `Clone` cannot be a supertrait of an object-safe trait, so boxed adapters
// delegate cloning to the object-safe `clone_box` method instead.
impl Clone for Box<dyn ModuleAdapter> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
47
/// No-op adapter: every snapshot is returned unchanged.
#[derive(Debug, Clone, Default)]
pub struct IdentityAdapter;
51
52impl ModuleAdapter for IdentityAdapter {
53 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
54 snapshot.clone()
55 }
56
57 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
58 Box::new(self.clone())
59 }
60}
61
/// Adapts PyTorch-convention tensors to Burn conventions: transposes 2D
/// `Linear` weights and renames norm parameters (`weight`/`bias` to
/// `gamma`/`beta`).
#[derive(Debug, Clone, Default)]
pub struct PyTorchToBurnAdapter;
69
70impl ModuleAdapter for PyTorchToBurnAdapter {
71 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
72 adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
73 }
74
75 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
76 Box::new(self.clone())
77 }
78}
79
/// Adapts Burn-convention tensors back to PyTorch conventions: transposes 2D
/// `Linear` weights and renames norm parameters (`gamma`/`beta` to
/// `weight`/`bias`).
#[derive(Debug, Clone, Default)]
pub struct BurnToPyTorchAdapter;
87
88impl ModuleAdapter for BurnToPyTorchAdapter {
89 fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
90 adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
91 }
92
93 fn clone_box(&self) -> Box<dyn ModuleAdapter> {
94 Box::new(self.clone())
95 }
96}
97
/// Direction of the PyTorch <-> Burn conversion applied by
/// `adapt_pytorch_tensor`.
#[derive(Debug, Clone, Copy)]
enum PyTorchConversionDirection {
    PyTorchToBurn,
    BurnToPyTorch,
}
104
105fn adapt_pytorch_tensor(
107 snapshot: &TensorSnapshot,
108 direction: PyTorchConversionDirection,
109) -> TensorSnapshot {
110 let (path_stack, param_name) = match get_path_and_param(snapshot) {
112 Some(result) => result,
113 None => return snapshot.clone(),
114 };
115
116 let container_type = match snapshot.container_stack.as_ref().and_then(|s| s.last()) {
118 Some(ct) => ct,
119 None => return snapshot.clone(),
120 };
121
122 match container_type.as_str() {
123 module_names::LINEAR if param_name == "weight" && snapshot.shape.len() == 2 => {
125 transpose_2d_tensor(snapshot)
126 }
127 module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM => {
129 let new_name = match direction {
130 PyTorchConversionDirection::PyTorchToBurn => match param_name {
131 "weight" => "gamma",
132 "bias" => "beta",
133 _ => return snapshot.clone(),
134 },
135 PyTorchConversionDirection::BurnToPyTorch => match param_name {
136 "gamma" => "weight",
137 "beta" => "bias",
138 _ => return snapshot.clone(),
139 },
140 };
141 rename_parameter(snapshot, path_stack, new_name)
142 }
143 _ => snapshot.clone(),
144 }
145}
146
147fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
149 let path_stack = snapshot.path_stack.as_ref()?;
150 let param_name = path_stack.last()?.as_str();
151 Some((path_stack.as_slice(), param_name))
152}
153
154fn rename_parameter(
156 snapshot: &TensorSnapshot,
157 path_stack: &[String],
158 new_name: &str,
159) -> TensorSnapshot {
160 let mut new_path = path_stack.to_vec();
161 *new_path.last_mut().unwrap() = new_name.to_string();
162
163 TensorSnapshot::from_closure(
164 snapshot.clone_data_fn(),
165 snapshot.dtype,
166 snapshot.shape.clone(),
167 new_path,
168 snapshot.container_stack.clone().unwrap_or_default(),
169 snapshot.tensor_id.unwrap_or_default(),
170 )
171}
172
173fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
175 if snapshot.shape.len() != 2 {
176 return snapshot.clone();
177 }
178
179 let original_data_fn = snapshot.clone_data_fn();
180 let dtype = snapshot.dtype;
181 let transposed_shape = vec![snapshot.shape[1], snapshot.shape[0]];
182
183 let transposed_data_fn = Rc::new(move || {
185 let data = original_data_fn()?;
186 Ok(transpose_tensor_data(data))
187 });
188
189 TensorSnapshot::from_closure(
190 transposed_data_fn,
191 dtype,
192 transposed_shape,
193 snapshot.path_stack.clone().unwrap_or_default(),
194 snapshot.container_stack.clone().unwrap_or_default(),
195 snapshot.tensor_id.unwrap_or_default(),
196 )
197}
198
199fn transpose_tensor_data(data: TensorData) -> TensorData {
201 let shape = &data.shape;
202 let rows = shape[0];
203 let cols = shape[1];
204 let transposed_shape = vec![cols, rows];
205
206 let bytes = data.as_bytes();
208 let element_size = data.dtype.size();
209
210 let mut transposed_bytes = vec![0u8; bytes.len()];
212
213 for i in 0..rows {
215 for j in 0..cols {
216 let src_idx = (i * cols + j) * element_size;
217 let dst_idx = (j * rows + i) * element_size;
218
219 transposed_bytes[dst_idx..dst_idx + element_size]
221 .copy_from_slice(&bytes[src_idx..src_idx + element_size]);
222 }
223 }
224
225 TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
227}
228
/// Unit tests exercising the PyTorch <-> Burn adapters end to end.
///
/// Fix: `test_no_container_info` had multiple statements fused onto single
/// lines (assert + declaration on one line), which was reformatted; test
/// behavior is unchanged.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use burn_tensor::{DType, TensorData};

    /// Builds a lazily-evaluated F32 snapshot of all-ones values for the
    /// dotted `path` (last segment is the parameter name), the given `shape`,
    /// and a single-entry container stack of `container_type`.
    fn create_test_snapshot(path: &str, shape: Vec<usize>, container_type: &str) -> TensorSnapshot {
        let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
        let values = vec![1.0f32; shape.iter().product()];
        let data = TensorData::new(values, shape.clone());

        TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F32,
            shape,
            path_parts,
            vec![container_type.to_string()],
            burn_core::module::ParamId::new(),
        )
    }

    #[test]
    fn test_pytorch_to_burn_linear_weight() {
        let adapter = PyTorchToBurnAdapter;

        // 2D Linear weights are transposed into Burn layout.
        let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![5, 10]);

        // 1D biases are left untouched.
        let snapshot = create_test_snapshot("fc.bias", vec![10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10]);
    }

    #[test]
    fn test_pytorch_to_burn_norm_params() {
        let adapter = PyTorchToBurnAdapter;

        // Norm `weight` is renamed to Burn's `gamma`.
        let snapshot = create_test_snapshot("norm.weight", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.gamma");

        // Norm `bias` is renamed to Burn's `beta`.
        let snapshot = create_test_snapshot("norm.bias", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.beta");
    }

    #[test]
    fn test_burn_to_pytorch_linear_weight() {
        let adapter = BurnToPyTorchAdapter;

        // The reverse adapter transposes Linear weights back to PyTorch layout.
        let snapshot = create_test_snapshot("fc.weight", vec![5, 10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]);
    }

    #[test]
    fn test_burn_to_pytorch_norm_params() {
        let adapter = BurnToPyTorchAdapter;

        // Burn's `gamma`/`beta` map back to PyTorch's `weight`/`bias`.
        let snapshot = create_test_snapshot("norm.gamma", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.weight");

        let snapshot = create_test_snapshot("norm.beta", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.bias");
    }

    #[test]
    fn test_transpose_different_dtypes() {
        // The byte-level transpose must be dtype-agnostic: f32 ...
        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let transposed = transpose_tensor_data(f32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<f32>().unwrap();
        assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        // ... i32 ...
        let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]);
        let transposed = transpose_tensor_data(i32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<i32>().unwrap();
        assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);

        // ... and f64 (square matrix).
        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]);
        let transposed = transpose_tensor_data(f64_data);
        assert_eq!(transposed.shape, vec![2, 2]);
        let values = transposed.to_vec::<f64>().unwrap();
        assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_no_container_info() {
        let adapter = PyTorchToBurnAdapter;

        // Without a container stack the adapter must pass tensors through,
        // even when the path looks like a Linear weight.
        let mut snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        snapshot.container_stack = None;
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]);

        // Same for an unrecognized container type.
        let mut snapshot2 = create_test_snapshot("other.weight", vec![10, 5], "Other");
        snapshot2.container_stack = None;
        let adapted2 = adapter.adapt(&snapshot2);
        assert_eq!(adapted2.shape, vec![10, 5]);
    }
}