//! Adapters that translate tensor snapshots between Burn's module conventions
//! and those of other frameworks (currently PyTorch).

use crate::TensorSnapshot;

use alloc::boxed::Box;
use alloc::rc::Rc;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec;

use burn_tensor::TensorData;

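/// Canonical container-type names used to recognize the built-in modules that
/// need special handling during conversion.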
mod module_names {
    #[allow(unused_imports)]
    use burn_nn::{BatchNorm, GroupNorm, LayerNorm, Linear};

    pub const LINEAR: &str = "Struct:Linear";
    pub const BATCH_NORM: &str = "Struct:BatchNorm";
    pub const LAYER_NORM: &str = "Struct:LayerNorm";
    pub const GROUP_NORM: &str = "Struct:GroupNorm";
}

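/// Adapts tensor snapshots so that modules saved under one framework's
/// conventions can be loaded under another's. A minimal sketch of direct use
/// (illustrative; `snapshot` is any `TensorSnapshot`):
///
/// ```ignore
/// let adapter = PyTorchToBurnAdapter;
/// let adapted = adapter.adapt(&snapshot);
/// ```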
pub trait ModuleAdapter: Send + Sync {
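    /// Adapts a single tensor snapshot, returning the transformed snapshot.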
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot;

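    /// Returns an alternative name for `param_name` within the given
    /// `container_type`, or `None` if no alternative applies.
    ///
    /// The default implementation returns `None`.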
    fn get_alternative_param_name(
        &self,
        _param_name: &str,
        _container_type: &str,
    ) -> Option<String> {
        None
    }

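    /// Clones this adapter into a boxed trait object.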
    fn clone_box(&self) -> Box<dyn ModuleAdapter>;
}

impl Clone for Box<dyn ModuleAdapter> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

#[derive(Debug, Clone, Default)]
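/// A no-op adapter that returns every snapshot unchanged.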
pub struct IdentityAdapter;

impl ModuleAdapter for IdentityAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        snapshot.clone()
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

#[derive(Debug, Clone, Default)]
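/// Adapts tensors stored with PyTorch conventions for use in Burn modules:
///
/// - Transposes 2D `Linear` weights (PyTorch and Burn use opposite layouts).
/// - Renames normalization parameters (`weight`/`bias` to `gamma`/`beta`).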
pub struct PyTorchToBurnAdapter;

impl ModuleAdapter for PyTorchToBurnAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::PyTorchToBurn)
    }

    fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
        if is_normalization_layer(container_type) {
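            // The Burn-side name (e.g. `gamma`) is mapped back to the PyTorch
            // name (e.g. `weight`) so the stored tensor can be located.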
            burn_norm_param_to_pytorch(param_name).map(|s| s.to_string())
        } else {
            None
        }
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

#[derive(Debug, Clone, Default)]
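/// Adapts Burn tensors for storage with PyTorch conventions:
///
/// - Transposes 2D `Linear` weights back to PyTorch's layout.
/// - Renames normalization parameters (`gamma`/`beta` to `weight`/`bias`).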
pub struct BurnToPyTorchAdapter;

impl ModuleAdapter for BurnToPyTorchAdapter {
    fn adapt(&self, snapshot: &TensorSnapshot) -> TensorSnapshot {
        adapt_pytorch_tensor(snapshot, PyTorchConversionDirection::BurnToPyTorch)
    }

    fn get_alternative_param_name(&self, param_name: &str, container_type: &str) -> Option<String> {
        if is_normalization_layer(container_type) {
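            // The PyTorch-side name (e.g. `weight`) is mapped back to the Burn
            // name (e.g. `gamma`) so the stored tensor can be located.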
            pytorch_norm_param_to_burn(param_name).map(|s| s.to_string())
        } else {
            None
        }
    }

    fn clone_box(&self) -> Box<dyn ModuleAdapter> {
        Box::new(self.clone())
    }
}

#[derive(Debug, Clone, Copy)]
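/// Direction of a PyTorch <-> Burn tensor conversion.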
enum PyTorchConversionDirection {
    PyTorchToBurn,
    BurnToPyTorch,
}

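/// Returns `true` if `container_type` names one of the normalization layers
/// whose parameter names differ between PyTorch and Burn.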
fn is_normalization_layer(container_type: &str) -> bool {
    matches!(
        container_type,
        module_names::BATCH_NORM | module_names::LAYER_NORM | module_names::GROUP_NORM
    )
}

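/// Maps a PyTorch normalization parameter name to its Burn equivalent.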
fn pytorch_norm_param_to_burn(param_name: &str) -> Option<&'static str> {
    match param_name {
        "weight" => Some("gamma"),
        "bias" => Some("beta"),
        _ => None,
    }
}

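/// Maps a Burn normalization parameter name to its PyTorch equivalent.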
fn burn_norm_param_to_pytorch(param_name: &str) -> Option<&'static str> {
    match param_name {
        "gamma" => Some("weight"),
        "beta" => Some("bias"),
        _ => None,
    }
}

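/// Applies the PyTorch <-> Burn conversions to a single snapshot: transposes
/// 2D `Linear` weights and renames normalization parameters. Snapshots without
/// path or module-type information are returned unchanged.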
fn adapt_pytorch_tensor(
    snapshot: &TensorSnapshot,
    direction: PyTorchConversionDirection,
) -> TensorSnapshot {
    let (path_stack, param_name) = match get_path_and_param(snapshot) {
        Some(result) => result,
        None => return snapshot.clone(),
    };

    let module_type = match snapshot.module_type() {
        Some(mt) => mt,
        None => return snapshot.clone(),
    };

    if module_type == module_names::LINEAR && param_name == "weight" && snapshot.shape.len() == 2 {
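        // PyTorch stores `Linear` weights as [out, in]; Burn expects [in, out],
        // so 2D weights are transposed in either direction.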
        return transpose_2d_tensor(snapshot);
    }

    if is_normalization_layer(&module_type) {
        let new_name = match direction {
            PyTorchConversionDirection::PyTorchToBurn => pytorch_norm_param_to_burn(param_name),
            PyTorchConversionDirection::BurnToPyTorch => burn_norm_param_to_pytorch(param_name),
        };

        if let Some(new_name) = new_name {
            return rename_parameter(snapshot, path_stack, new_name);
        }
    }

    snapshot.clone()
}

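/// Extracts the snapshot's path stack and its final segment (the parameter name).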
fn get_path_and_param(snapshot: &TensorSnapshot) -> Option<(&[String], &str)> {
    let path_stack = snapshot.path_stack.as_ref()?;
    let param_name = path_stack.last()?.as_str();
    Some((path_stack.as_slice(), param_name))
}

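/// Returns a copy of the snapshot whose final path segment is replaced by `new_name`.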
fn rename_parameter(
    snapshot: &TensorSnapshot,
    path_stack: &[String],
    new_name: &str,
) -> TensorSnapshot {
    let mut new_path = path_stack.to_vec();
    *new_path.last_mut().unwrap() = new_name.to_string();

    TensorSnapshot::from_closure(
        snapshot.clone_data_fn(),
        snapshot.dtype,
        snapshot.shape.clone(),
        new_path,
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

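/// Returns a snapshot whose data closure lazily transposes the underlying 2D
/// tensor; non-2D snapshots are returned unchanged.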
fn transpose_2d_tensor(snapshot: &TensorSnapshot) -> TensorSnapshot {
    if snapshot.shape.len() != 2 {
        return snapshot.clone();
    }

    let original_data_fn = snapshot.clone_data_fn();
    let dtype = snapshot.dtype;
    let transposed_shape = vec![snapshot.shape[1], snapshot.shape[0]];

    let transposed_data_fn = Rc::new(move || {
        let data = original_data_fn()?;
        Ok(transpose_tensor_data(data))
    });

    TensorSnapshot::from_closure(
        transposed_data_fn,
        dtype,
        transposed_shape,
        snapshot.path_stack.clone().unwrap_or_default(),
        snapshot.container_stack.clone().unwrap_or_default(),
        snapshot.tensor_id.unwrap_or_default(),
    )
}

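/// Transposes a 2D `TensorData` by copying elements byte-wise, which keeps the
/// routine dtype-agnostic.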
fn transpose_tensor_data(data: TensorData) -> TensorData {
    let shape = &data.shape;
    let rows = shape[0];
    let cols = shape[1];
    let transposed_shape = vec![cols, rows];

    let bytes = data.as_bytes();
    let element_size = data.dtype.size();

    let mut transposed_bytes = vec![0u8; bytes.len()];

    for i in 0..rows {
        for j in 0..cols {
            let src_idx = (i * cols + j) * element_size;
            let dst_idx = (j * rows + i) * element_size;

            transposed_bytes[dst_idx..dst_idx + element_size]
                .copy_from_slice(&bytes[src_idx..src_idx + element_size]);
        }
    }

    TensorData::from_bytes_vec(transposed_bytes, transposed_shape, data.dtype)
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use burn_tensor::{DType, TensorData};

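    /// Builds a snapshot of all-ones `f32` data for the given dotted path,
    /// shape, and container type.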
    fn create_test_snapshot(path: &str, shape: Vec<usize>, container_type: &str) -> TensorSnapshot {
        let path_parts: Vec<String> = path.split('.').map(|s| s.to_string()).collect();
        let values = vec![1.0f32; shape.iter().product()];
        let data = TensorData::new(values, shape.clone());

        TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            DType::F32,
            shape,
            path_parts,
            vec![container_type.to_string()],
            burn_core::module::ParamId::new(),
        )
    }

    #[test]
    fn test_pytorch_to_burn_linear_weight() {
        let adapter = PyTorchToBurnAdapter;

        let snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![5, 10]);

        let snapshot = create_test_snapshot("fc.bias", vec![10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10]);
    }

    #[test]
    fn test_pytorch_to_burn_norm_params() {
        let adapter = PyTorchToBurnAdapter;

        let snapshot = create_test_snapshot("norm.weight", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.gamma");

        let snapshot = create_test_snapshot("norm.bias", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.beta");
    }

    #[test]
    fn test_burn_to_pytorch_linear_weight() {
        let adapter = BurnToPyTorchAdapter;

        let snapshot = create_test_snapshot("fc.weight", vec![5, 10], module_names::LINEAR);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]);
    }

    #[test]
    fn test_burn_to_pytorch_norm_params() {
        let adapter = BurnToPyTorchAdapter;

        let snapshot = create_test_snapshot("norm.gamma", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.weight");

        let snapshot = create_test_snapshot("norm.beta", vec![10], module_names::BATCH_NORM);
        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.full_path(), "norm.bias");
    }

    #[test]
    fn test_transpose_different_dtypes() {
        let f32_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let transposed = transpose_tensor_data(f32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<f32>().unwrap();
        assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        let i32_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]);
        let transposed = transpose_tensor_data(i32_data);
        assert_eq!(transposed.shape, vec![3, 2]);
        let values = transposed.to_vec::<i32>().unwrap();
        assert_eq!(values, vec![1, 4, 2, 5, 3, 6]);

        let f64_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]);
        let transposed = transpose_tensor_data(f64_data);
        assert_eq!(transposed.shape, vec![2, 2]);
        let values = transposed.to_vec::<f64>().unwrap();
        assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    fn test_no_container_info() {
        let adapter = PyTorchToBurnAdapter;

        let mut snapshot = create_test_snapshot("fc.weight", vec![10, 5], module_names::LINEAR);
        snapshot.container_stack = None;

        let adapted = adapter.adapt(&snapshot);
        assert_eq!(adapted.shape, vec![10, 5]);

        let mut snapshot2 = create_test_snapshot("other.weight", vec![10, 5], "Struct:Other");
        snapshot2.container_stack = None;
        let adapted2 = adapter.adapt(&snapshot2);
        assert_eq!(adapted2.shape, vec![10, 5]);
    }
}