use alloc::borrow::Cow;
use alloc::collections::{BTreeMap, BTreeSet};
use alloc::string::{String, ToString};
use alloc::vec::Vec;

use regex::{self, Regex};

use crate::TensorSnapshot;
#[derive(Debug, Clone, Default)]
pub struct KeyRemapper {
pub patterns: Vec<(Regex, String)>,
}
impl KeyRemapper {
pub fn new() -> Self {
Self::default()
}
pub fn add_pattern<S1, S2>(mut self, from: S1, to: S2) -> Result<Self, regex::Error>
where
S1: AsRef<str>,
S2: Into<String>,
{
let regex = Regex::new(from.as_ref())?;
self.patterns.push((regex, to.into()));
Ok(self)
}
pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self {
Self { patterns }
}
pub fn from_patterns<S1, S2>(patterns: Vec<(S1, S2)>) -> Result<Self, regex::Error>
where
S1: AsRef<str>,
S2: Into<String>,
{
let mut compiled_patterns = Vec::new();
for (pattern, replacement) in patterns {
let regex = Regex::new(pattern.as_ref())?;
compiled_patterns.push((regex, replacement.into()));
}
Ok(Self {
patterns: compiled_patterns,
})
}
pub fn from_pattern_iter<I, S1, S2>(iter: I) -> Result<Self, regex::Error>
where
I: IntoIterator<Item = (S1, S2)>,
S1: AsRef<str>,
S2: Into<String>,
{
let patterns: Result<Vec<_>, _> = iter
.into_iter()
.map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into())))
.collect();
Ok(Self {
patterns: patterns?,
})
}
pub fn is_empty(&self) -> bool {
self.patterns.is_empty()
}
pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> {
self.patterns.clone()
}
pub fn remap(
&self,
mut tensors: Vec<TensorSnapshot>,
) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
if self.patterns.is_empty() {
let remapped_names = tensors
.iter()
.map(|v| {
let path = v.full_path();
(path.clone(), path)
})
.collect();
return (tensors, remapped_names);
}
let mut remapped_snapshots = Vec::new();
let mut remapped_names = Vec::new();
for mut snapshot in tensors.drain(..) {
let original_path = snapshot.full_path();
let mut new_path = original_path.clone();
for (pattern, replacement) in &self.patterns {
if pattern.is_match(&new_path) {
new_path = pattern
.replace_all(&new_path, replacement.as_str())
.to_string();
}
}
if new_path != original_path
&& let Some(ref mut path_stack) = snapshot.path_stack
{
*path_stack = new_path.split('.').map(|s| s.to_string()).collect();
}
remapped_names.push((new_path.clone(), original_path));
remapped_snapshots.push(snapshot);
}
(remapped_snapshots, remapped_names)
}
}
pub fn map_indices_contiguous(
mut tensors: Vec<TensorSnapshot>,
) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
if tensors.is_empty() {
return (tensors, Vec::new());
}
let mut index_maps: BTreeMap<String, BTreeMap<usize, usize>> = BTreeMap::new();
for snapshot in &tensors {
let path = snapshot.full_path();
let parts: Vec<&str> = path.split('.').collect();
for (i, part) in parts.iter().enumerate() {
if let Ok(index) = part.parse::<usize>() {
let prefix = if i > 0 {
format!("{}.", parts[..i].join("."))
} else {
String::new()
};
index_maps
.entry(prefix)
.or_default()
.entry(index)
.or_insert(usize::MAX); }
}
}
for indices in index_maps.values_mut() {
let mut sorted_indices: Vec<usize> = indices.keys().cloned().collect();
sorted_indices.sort();
for (new_idx, old_idx) in sorted_indices.into_iter().enumerate() {
indices.insert(old_idx, new_idx);
}
}
let mut mapped_snapshots = Vec::new();
let mut transformations = Vec::new();
for mut snapshot in tensors.drain(..) {
let original_path = snapshot.full_path();
let new_path = remap_all_indices_with_original_prefix(&original_path, &index_maps);
if new_path != original_path
&& let Some(ref mut path_stack) = snapshot.path_stack
{
*path_stack = new_path.split('.').map(|s| s.to_string()).collect();
}
transformations.push((new_path, original_path));
mapped_snapshots.push(snapshot);
}
(mapped_snapshots, transformations)
}
fn remap_all_indices_with_original_prefix(
path: &str,
index_maps: &BTreeMap<String, BTreeMap<usize, usize>>,
) -> String {
let parts: Vec<&str> = path.split('.').collect();
let mut result_parts: Vec<String> = Vec::with_capacity(parts.len());
for (i, part) in parts.iter().enumerate() {
if let Ok(index) = part.parse::<usize>() {
let prefix = if i > 0 {
format!("{}.", parts[..i].join("."))
} else {
String::new()
};
if let Some(index_map) = index_maps.get(&prefix)
&& let Some(&new_index) = index_map.get(&index)
{
result_parts.push(new_index.to_string());
continue;
}
}
result_parts.push((*part).to_string());
}
result_parts.join(".")
}
#[cfg(all(test, feature = "std"))]
mod tests {
use super::*;
use burn_core::module::ParamId;
use burn_tensor::{TensorData, shape};
fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot {
let data = TensorData {
bytes: burn_tensor::Bytes::from_bytes_vec(vec![1, 2, 3, 4]),
shape: shape![2, 2],
dtype: burn_tensor::DType::F32,
};
let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();
TensorSnapshot::from_data(data, path_parts, vec!["Test".to_string()], ParamId::new())
}
#[test]
fn test_key_remapper_basic() {
let remapper = KeyRemapper::new()
.add_pattern(r"^encoder\.", "transformer.encoder.")
.expect("valid regex");
let tensors = vec![
create_test_tensor_snapshot("encoder.layer1.weight"),
create_test_tensor_snapshot("decoder.layer1.weight"),
];
let (remapped, transformations) = remapper.remap(tensors);
assert!(
remapped
.iter()
.any(|v| v.full_path() == "transformer.encoder.layer1.weight")
);
assert!(
remapped
.iter()
.any(|v| v.full_path() == "decoder.layer1.weight")
);
assert_eq!(remapped.len(), 2);
let encoder_transform = transformations
.iter()
.find(|(_new, old)| old == "encoder.layer1.weight")
.expect("should find encoder transformation");
assert_eq!(encoder_transform.0, "transformer.encoder.layer1.weight");
}
#[test]
fn test_key_remapper_multiple_patterns() {
let remapper = KeyRemapper::new()
.add_pattern(r"^encoder\.", "transformer.encoder.")
.expect("valid regex")
.add_pattern(r"\.gamma$", ".weight")
.expect("valid regex");
let tensors = vec![create_test_tensor_snapshot("encoder.layer1.gamma")];
let (remapped, _) = remapper.remap(tensors);
assert!(
remapped
.iter()
.any(|v| v.full_path() == "transformer.encoder.layer1.weight")
);
assert_eq!(remapped.len(), 1);
}
#[test]
fn test_key_remapper_from_patterns() {
let patterns = vec![(r"^pytorch\.", "burn."), (r"\.bias$", ".bias_param")];
let remapper = KeyRemapper::from_patterns(patterns).expect("valid patterns");
let tensors = vec![create_test_tensor_snapshot("pytorch.linear.bias")];
let (remapped, _) = remapper.remap(tensors);
assert!(
remapped
.iter()
.any(|v| v.full_path() == "burn.linear.bias_param")
);
}
#[test]
fn test_key_remapper_empty() {
let remapper = KeyRemapper::new();
assert!(remapper.is_empty());
let tensors = vec![create_test_tensor_snapshot("test.weight")];
let (remapped, transformations) = remapper.remap(tensors);
assert!(remapped.iter().any(|v| v.full_path() == "test.weight"));
assert_eq!(remapped.len(), 1);
assert_eq!(transformations.len(), 1);
assert_eq!(
transformations[0],
("test.weight".to_string(), "test.weight".to_string())
);
}
#[test]
fn test_map_indices_contiguous_basic() {
let tensors = vec![
create_test_tensor_snapshot("fc.0.weight"),
create_test_tensor_snapshot("fc.0.bias"),
create_test_tensor_snapshot("fc.2.weight"),
create_test_tensor_snapshot("fc.2.bias"),
create_test_tensor_snapshot("fc.4.weight"),
create_test_tensor_snapshot("fc.4.bias"),
];
let (reindexed, transformations) = map_indices_contiguous(tensors);
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.bias"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.bias"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.bias"));
assert_eq!(reindexed.len(), 6);
let transform_2_to_1 = transformations
.iter()
.find(|(_, old)| old == "fc.2.weight")
.expect("should find fc.2.weight transformation");
assert_eq!(transform_2_to_1.0, "fc.1.weight");
let transform_4_to_2 = transformations
.iter()
.find(|(_, old)| old == "fc.4.weight")
.expect("should find fc.4.weight transformation");
assert_eq!(transform_4_to_2.0, "fc.2.weight");
}
#[test]
fn test_map_indices_contiguous_already_contiguous() {
let tensors = vec![
create_test_tensor_snapshot("fc.0.weight"),
create_test_tensor_snapshot("fc.1.weight"),
create_test_tensor_snapshot("fc.2.weight"),
];
let (reindexed, transformations) = map_indices_contiguous(tensors);
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.2.weight"));
assert_eq!(reindexed.len(), 3);
for (new, old) in &transformations {
assert_eq!(new, old);
}
}
#[test]
fn test_map_indices_contiguous_multiple_prefixes() {
let tensors = vec![
create_test_tensor_snapshot("encoder.0.weight"),
create_test_tensor_snapshot("encoder.2.weight"),
create_test_tensor_snapshot("decoder.1.weight"),
create_test_tensor_snapshot("decoder.5.weight"),
];
let (reindexed, _) = map_indices_contiguous(tensors);
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "encoder.0.weight")
);
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "encoder.1.weight")
);
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "decoder.0.weight")
);
assert!(
reindexed
.iter()
.any(|v| v.full_path() == "decoder.1.weight")
);
}
#[test]
fn test_map_indices_contiguous_no_indices() {
let tensors = vec![
create_test_tensor_snapshot("encoder.weight"),
create_test_tensor_snapshot("decoder.bias"),
];
let (reindexed, transformations) = map_indices_contiguous(tensors);
assert!(reindexed.iter().any(|v| v.full_path() == "encoder.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "decoder.bias"));
for (new, old) in &transformations {
assert_eq!(new, old);
}
}
#[test]
fn test_map_indices_contiguous_empty() {
let tensors: Vec<TensorSnapshot> = vec![];
let (reindexed, transformations) = map_indices_contiguous(tensors);
assert!(reindexed.is_empty());
assert!(transformations.is_empty());
}
#[test]
fn test_map_indices_contiguous_mixed_indexed_and_non_indexed() {
let tensors = vec![
create_test_tensor_snapshot("fc.0.weight"),
create_test_tensor_snapshot("fc.2.weight"),
create_test_tensor_snapshot("output.weight"), ];
let (reindexed, _) = map_indices_contiguous(tensors);
assert!(reindexed.iter().any(|v| v.full_path() == "fc.0.weight"));
assert!(reindexed.iter().any(|v| v.full_path() == "fc.1.weight")); assert!(reindexed.iter().any(|v| v.full_path() == "output.weight")); }
#[test]
fn test_map_indices_contiguous_nested_sequential() {
let tensors = vec![
create_test_tensor_snapshot("feature.layers.0.conv_block.0.weight"),
create_test_tensor_snapshot("feature.layers.0.conv_block.2.weight"),
create_test_tensor_snapshot("feature.layers.2.conv_block.0.weight"),
create_test_tensor_snapshot("feature.layers.2.conv_block.2.weight"),
];
let (mapped, transformations) = map_indices_contiguous(tensors);
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.0.conv_block.0.weight"),
"0.0 should stay as 0.0"
);
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.0.conv_block.1.weight"),
"0.2 should become 0.1"
);
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.1.conv_block.0.weight"),
"2.0 should become 1.0"
);
assert!(
mapped
.iter()
.any(|v| v.full_path() == "feature.layers.1.conv_block.1.weight"),
"2.2 should become 1.1"
);
let t1 = transformations
.iter()
.find(|(_, old)| old == "feature.layers.2.conv_block.2.weight");
assert_eq!(
t1.map(|(new, _)| new.as_str()),
Some("feature.layers.1.conv_block.1.weight"),
"2.2 should map to 1.1"
);
}
#[test]
fn test_map_indices_contiguous_deeply_nested() {
let tensors = vec![
create_test_tensor_snapshot("a.0.b.0.c.0.weight"),
create_test_tensor_snapshot("a.0.b.0.c.2.weight"),
create_test_tensor_snapshot("a.0.b.2.c.0.weight"),
create_test_tensor_snapshot("a.2.b.0.c.0.weight"),
];
let (mapped, _) = map_indices_contiguous(tensors);
assert!(mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.0.weight"));
assert!(
mapped.iter().any(|v| v.full_path() == "a.0.b.0.c.1.weight"),
"a.0.b.0.c.2 should become a.0.b.0.c.1"
);
assert!(
mapped.iter().any(|v| v.full_path() == "a.0.b.1.c.0.weight"),
"a.0.b.2.c.0 should become a.0.b.1.c.0"
);
assert!(
mapped.iter().any(|v| v.full_path() == "a.1.b.0.c.0.weight"),
"a.2.b.0.c.0 should become a.1.b.0.c.0"
);
}
}