1use alloc::collections::BTreeMap;
2use alloc::string::{String, ToString};
3use alloc::vec::Vec;
4
5use regex::{self, Regex};
6
7use crate::TensorSnapshot;
8
/// Remaps tensor snapshot paths using regex pattern/replacement pairs.
///
/// Patterns are applied to each path in insertion order; each pattern's
/// replacement may itself be rewritten by later patterns.
#[derive(Debug, Clone, Default)]
pub struct KeyRemapper {
    /// Compiled regex patterns paired with their replacement strings,
    /// applied in the order they were added.
    pub patterns: Vec<(Regex, String)>,
}
31
32impl KeyRemapper {
33 pub fn new() -> Self {
35 Self::default()
36 }
37
38 pub fn add_pattern<S1, S2>(mut self, from: S1, to: S2) -> Result<Self, regex::Error>
50 where
51 S1: AsRef<str>,
52 S2: Into<String>,
53 {
54 let regex = Regex::new(from.as_ref())?;
55 self.patterns.push((regex, to.into()));
56 Ok(self)
57 }
58
59 pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self {
61 Self { patterns }
62 }
63
64 pub fn from_patterns<S1, S2>(patterns: Vec<(S1, S2)>) -> Result<Self, regex::Error>
75 where
76 S1: AsRef<str>,
77 S2: Into<String>,
78 {
79 let mut compiled_patterns = Vec::new();
80 for (pattern, replacement) in patterns {
81 let regex = Regex::new(pattern.as_ref())?;
82 compiled_patterns.push((regex, replacement.into()));
83 }
84 Ok(Self {
85 patterns: compiled_patterns,
86 })
87 }
88
89 pub fn from_pattern_iter<I, S1, S2>(iter: I) -> Result<Self, regex::Error>
100 where
101 I: IntoIterator<Item = (S1, S2)>,
102 S1: AsRef<str>,
103 S2: Into<String>,
104 {
105 let patterns: Result<Vec<_>, _> = iter
106 .into_iter()
107 .map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into())))
108 .collect();
109 Ok(Self {
110 patterns: patterns?,
111 })
112 }
113
114 pub fn is_empty(&self) -> bool {
116 self.patterns.is_empty()
117 }
118
119 pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> {
121 self.patterns.clone()
122 }
123
124 pub fn remap(
136 &self,
137 mut tensors: Vec<TensorSnapshot>,
138 ) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
139 if self.patterns.is_empty() {
140 let remapped_names = tensors
141 .iter()
142 .map(|v| {
143 let path = v.full_path();
144 (path.clone(), path)
145 })
146 .collect();
147 return (tensors, remapped_names);
148 }
149
150 let mut remapped_snapshots = Vec::new();
151 let mut remapped_names = Vec::new();
152
153 for mut snapshot in tensors.drain(..) {
154 let original_path = snapshot.full_path();
155 let mut new_path = original_path.clone();
156
157 for (pattern, replacement) in &self.patterns {
159 if pattern.is_match(&new_path) {
160 new_path = pattern
161 .replace_all(&new_path, replacement.as_str())
162 .to_string();
163 }
164 }
165
166 if new_path != original_path
168 && let Some(ref mut path_stack) = snapshot.path_stack
169 {
170 *path_stack = new_path.split('.').map(|s| s.to_string()).collect();
171 }
172
173 remapped_names.push((new_path.clone(), original_path));
174 remapped_snapshots.push(snapshot);
175 }
176
177 (remapped_snapshots, remapped_names)
178 }
179}
180
181pub fn map_indices_contiguous(
226 mut tensors: Vec<TensorSnapshot>,
227) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
228 if tensors.is_empty() {
229 return (tensors, Vec::new());
230 }
231
232 let mut index_maps: BTreeMap<String, BTreeMap<usize, usize>> = BTreeMap::new();
239
240 for snapshot in &tensors {
242 let path = snapshot.full_path();
243 let parts: Vec<&str> = path.split('.').collect();
244
245 for (i, part) in parts.iter().enumerate() {
247 if let Ok(index) = part.parse::<usize>() {
248 let prefix = if i > 0 {
250 format!("{}.", parts[..i].join("."))
251 } else {
252 String::new()
253 };
254
255 index_maps
256 .entry(prefix)
257 .or_default()
258 .entry(index)
259 .or_insert(usize::MAX); }
261 }
262 }
263
264 for indices in index_maps.values_mut() {
266 let mut sorted_indices: Vec<usize> = indices.keys().cloned().collect();
267 sorted_indices.sort();
268
269 for (new_idx, old_idx) in sorted_indices.into_iter().enumerate() {
270 indices.insert(old_idx, new_idx);
271 }
272 }
273
274 let mut mapped_snapshots = Vec::new();
277 let mut transformations = Vec::new();
278
279 for mut snapshot in tensors.drain(..) {
280 let original_path = snapshot.full_path();
281 let new_path = remap_all_indices_with_original_prefix(&original_path, &index_maps);
282
283 if new_path != original_path
285 && let Some(ref mut path_stack) = snapshot.path_stack
286 {
287 *path_stack = new_path.split('.').map(|s| s.to_string()).collect();
288 }
289
290 transformations.push((new_path, original_path));
291 mapped_snapshots.push(snapshot);
292 }
293
294 (mapped_snapshots, transformations)
295}
296
/// Rewrites every numeric segment of `path` through `index_maps`, keyed by
/// the segment's *original* dotted prefix (with a trailing '.', or "" for a
/// leading index). Segments that are non-numeric, or whose prefix/index has
/// no entry, are kept verbatim.
fn remap_all_indices_with_original_prefix(
    path: &str,
    index_maps: &BTreeMap<String, BTreeMap<usize, usize>>,
) -> String {
    let segments: Vec<&str> = path.split('.').collect();

    let rewritten: Vec<String> = segments
        .iter()
        .enumerate()
        .map(|(pos, segment)| {
            segment
                .parse::<usize>()
                .ok()
                .and_then(|old_index| {
                    // Prefix is built from the ORIGINAL segments, so nested
                    // lookups stay keyed by the pre-renumbering path.
                    let prefix = match pos {
                        0 => String::new(),
                        _ => format!("{}.", segments[..pos].join(".")),
                    };
                    index_maps
                        .get(&prefix)
                        .and_then(|map| map.get(&old_index))
                        .map(|new_index| new_index.to_string())
                })
                // Non-numeric segment, unknown prefix, or unknown index:
                // keep the segment untouched.
                .unwrap_or_else(|| (*segment).to_string())
        })
        .collect();

    rewritten.join(".")
}
329
// Unit tests; gated on `std` because the fixtures allocate via burn_tensor.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use burn_core::module::ParamId;
    use burn_tensor::TensorData;

    // Builds a minimal 2x2 F32 snapshot whose path is `name` split on '.'.
    fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot {
        let data = TensorData {
            bytes: burn_tensor::Bytes::from_bytes_vec(vec![1, 2, 3, 4]),
            shape: vec![2, 2],
            dtype: burn_tensor::DType::F32,
        };
        let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();
        TensorSnapshot::from_data(data, path_parts, vec!["Test".to_string()], ParamId::new())
    }

    // A single prefix pattern rewrites matching paths and leaves others alone;
    // transformations are reported as (new_path, old_path) pairs.
    #[test]
    fn test_key_remapper_basic() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^encoder\.", "transformer.encoder.")
            .expect("valid regex");

        let tensors = vec![
            create_test_tensor_snapshot("encoder.layer1.weight"),
            create_test_tensor_snapshot("decoder.layer1.weight"),
        ];

        let (remapped, transformations) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "transformer.encoder.layer1.weight")
        );
        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "decoder.layer1.weight")
        );
        assert_eq!(remapped.len(), 2);

        let encoder_transform = transformations
            .iter()
            .find(|(_new, old)| old == "encoder.layer1.weight")
            .expect("should find encoder transformation");
        assert_eq!(encoder_transform.0, "transformer.encoder.layer1.weight");
    }

    // Patterns compose in registration order: the prefix rewrite and the
    // suffix rewrite both apply to the same path.
    #[test]
    fn test_key_remapper_multiple_patterns() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^encoder\.", "transformer.encoder.")
            .expect("valid regex")
            .add_pattern(r"\.gamma$", ".weight")
            .expect("valid regex");

        let tensors = vec![create_test_tensor_snapshot("encoder.layer1.gamma")];

        let (remapped, _) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "transformer.encoder.layer1.weight")
        );
        assert_eq!(remapped.len(), 1);
    }

    // The Vec-of-string-pairs constructor behaves like chained add_pattern.
    #[test]
    fn test_key_remapper_from_patterns() {
        let patterns = vec![(r"^pytorch\.", "burn."), (r"\.bias$", ".bias_param")];
        let remapper = KeyRemapper::from_patterns(patterns).expect("valid patterns");

        let tensors = vec![create_test_tensor_snapshot("pytorch.linear.bias")];

        let (remapped, _) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "burn.linear.bias_param")
        );
    }

    // With no patterns, paths pass through unchanged but a (new, old)
    // identity pair is still reported for every tensor.
    #[test]
    fn test_key_remapper_empty() {
        let remapper = KeyRemapper::new();
        assert!(remapper.is_empty());

        let tensors = vec![create_test_tensor_snapshot("test.weight")];

        let (remapped, transformations) = remapper.remap(tensors);

        assert!(remapped.iter().any(|v| v.full_path() == "test.weight"));
        assert_eq!(remapped.len(), 1);
        assert_eq!(transformations.len(), 1);
        assert_eq!(
            transformations[0],
            ("test.weight".to_string(), "test.weight".to_string())
        );
    }

    // Gapped indices 0/2/4 under one prefix collapse to contiguous 0/1/2,
    // with sibling `weight`/`bias` entries renumbered together.
    #[test]
    fn test_map_indices_contiguous_basic() {
        let tensors = vec![
            create_test_tensor_snapshot("fc.0.weight"),
            create_test_tensor_snapshot("fc.0.bias"),
            create_test_tensor_snapshot("fc.2.weight"),
            create_test_tensor_snapshot("fc.2.bias"),
            create_test_tensor_snapshot("fc.4.weight"),
            create_test_tensor_snapshot("fc.4.bias"),
        ];

        let (reindexed, transformations) = map_indices_contiguous(tensors);

        assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.bias"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.bias"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.bias"));
        assert_eq!(reindexed.len(), 6);

        let transform_2_to_1 = transformations
            .iter()
            .find(|(_, old)| old == "fc.2.weight")
            .expect("should find fc.2.weight transformation");
        assert_eq!(transform_2_to_1.0, "fc.1.weight");

        let transform_4_to_2 = transformations
            .iter()
            .find(|(_, old)| old == "fc.4.weight")
            .expect("should find fc.4.weight transformation");
        assert_eq!(transform_4_to_2.0, "fc.2.weight");
    }

    // Already-contiguous indices are a no-op: every transformation pair is
    // an identity mapping.
    #[test]
    fn test_map_indices_contiguous_already_contiguous() {
        let tensors = vec![
            create_test_tensor_snapshot("fc.0.weight"),
            create_test_tensor_snapshot("fc.1.weight"),
            create_test_tensor_snapshot("fc.2.weight"),
        ];

        let (reindexed, transformations) = map_indices_contiguous(tensors);

        assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight"));
        assert_eq!(reindexed.len(), 3);

        for (new, old) in &transformations {
            assert_eq!(new, old);
        }
    }

    // Each prefix gets its own independent renumbering: encoder.{0,2} ->
    // encoder.{0,1} and decoder.{1,5} -> decoder.{0,1}.
    #[test]
    fn test_map_indices_contiguous_multiple_prefixes() {
        let tensors = vec![
            create_test_tensor_snapshot("encoder.0.weight"),
            create_test_tensor_snapshot("encoder.2.weight"),
            create_test_tensor_snapshot("decoder.1.weight"),
            create_test_tensor_snapshot("decoder.5.weight"),
        ];

        let (reindexed, _) = map_indices_contiguous(tensors);

        assert!(
            reindexed
                .iter()
                .any(|v| v.full_path() == "encoder.0.weight")
        );
        assert!(
            reindexed
                .iter()
                .any(|v| v.full_path() == "encoder.1.weight")
        );

        assert!(
            reindexed
                .iter()
                .any(|v| v.full_path() == "decoder.0.weight")
        );
        assert!(
            reindexed
                .iter()
                .any(|v| v.full_path() == "decoder.1.weight")
        );
    }

    // Paths without numeric segments pass through untouched.
    #[test]
    fn test_map_indices_contiguous_no_indices() {
        let tensors = vec![
            create_test_tensor_snapshot("encoder.weight"),
            create_test_tensor_snapshot("decoder.bias"),
        ];

        let (reindexed, transformations) = map_indices_contiguous(tensors);

        assert!(reindexed.iter().any(|v| v.full_path() == "encoder.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "decoder.bias"));

        for (new, old) in &transformations {
            assert_eq!(new, old);
        }
    }

    // Empty input is the trivial case: nothing to map, nothing reported.
    #[test]
    fn test_map_indices_contiguous_empty() {
        let tensors: Vec<TensorSnapshot> = vec![];
        let (reindexed, transformations) = map_indices_contiguous(tensors);

        assert!(reindexed.is_empty());
        assert!(transformations.is_empty());
    }

    // Indexed and non-indexed paths can coexist; only indexed segments are
    // renumbered.
    #[test]
    fn test_map_indices_contiguous_mixed_indexed_and_non_indexed() {
        let tensors = vec![
            create_test_tensor_snapshot("fc.0.weight"),
            create_test_tensor_snapshot("fc.2.weight"),
            create_test_tensor_snapshot("output.weight"),
        ];

        let (reindexed, _) = map_indices_contiguous(tensors);

        assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight"));
        assert!(reindexed.iter().any(|v| v.full_path() == "output.weight"));
    }

    // Outer and inner sequential indices are renumbered independently, with
    // inner lookups keyed by the ORIGINAL (pre-renumbering) prefix.
    #[test]
    fn test_map_indices_contiguous_nested_sequential() {
        let tensors = vec![
            create_test_tensor_snapshot("feature.layers.0.conv_block.0.weight"),
            create_test_tensor_snapshot("feature.layers.0.conv_block.2.weight"),
            create_test_tensor_snapshot("feature.layers.2.conv_block.0.weight"),
            create_test_tensor_snapshot("feature.layers.2.conv_block.2.weight"),
        ];

        let (mapped, transformations) = map_indices_contiguous(tensors);

        assert!(
            mapped
                .iter()
                .any(|v| v.full_path() == "feature.layers.0.conv_block.0.weight"),
            "0.0 should stay as 0.0"
        );
        assert!(
            mapped
                .iter()
                .any(|v| v.full_path() == "feature.layers.0.conv_block.1.weight"),
            "0.2 should become 0.1"
        );
        assert!(
            mapped
                .iter()
                .any(|v| v.full_path() == "feature.layers.1.conv_block.0.weight"),
            "2.0 should become 1.0"
        );
        assert!(
            mapped
                .iter()
                .any(|v| v.full_path() == "feature.layers.1.conv_block.1.weight"),
            "2.2 should become 1.1"
        );

        let t1 = transformations
            .iter()
            .find(|(_, old)| old == "feature.layers.2.conv_block.2.weight");
        assert_eq!(
            t1.map(|(new, _)| new.as_str()),
            Some("feature.layers.1.conv_block.1.weight"),
            "2.2 should map to 1.1"
        );
    }

    // Three levels of nesting: every numeric segment is renumbered within
    // its own original-prefix group.
    #[test]
    fn test_map_indices_contiguous_deeply_nested() {
        let tensors = vec![
            create_test_tensor_snapshot("a.0.b.0.c.0.weight"),
            create_test_tensor_snapshot("a.0.b.0.c.2.weight"),
            create_test_tensor_snapshot("a.0.b.2.c.0.weight"),
            create_test_tensor_snapshot("a.2.b.0.c.0.weight"),
        ];

        let (mapped, _) = map_indices_contiguous(tensors);

        assert!(mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.0.weight"));
        assert!(
            mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.1.weight"),
            "a.0.b.0.c.2 should become a.0.b.0.c.1"
        );
        assert!(
            mapped.iter().any(|v| v.full_path() == "a.0.b.1.c.0.weight"),
            "a.0.b.2.c.0 should become a.0.b.1.c.0"
        );
        assert!(
            mapped.iter().any(|v| v.full_path() == "a.1.b.0.c.0.weight"),
            "a.2.b.0.c.0 should become a.1.b.0.c.0"
        );
    }
}