burn_store/
keyremapper.rs

use alloc::borrow::Cow;
use alloc::string::{String, ToString};
use alloc::vec::Vec;

use regex::{self, Regex};

use crate::TensorSnapshot;
7
8/// Key remapper for transforming tensor names.
9///
10/// This allows mapping tensor names from one naming convention to another,
11/// which is useful for loading models from different frameworks or versions.
12///
13/// # Examples
14///
15/// ```rust,no_run
16/// # use burn_store::KeyRemapper;
17/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
18/// // Create a key remapper
19/// let remapper = KeyRemapper::new()
20///     .add_pattern(r"^pytorch\.(.*)", "burn.$1")?  // pytorch.layer -> burn.layer
21///     .add_pattern(r"\.gamma$", ".weight")?;       // layer.gamma -> layer.weight
22///
23/// // Use remapper with stores
24/// // store.remap(remapper)
25/// # Ok(())
26/// # }
27/// ```
28#[derive(Debug, Clone, Default)]
29pub struct KeyRemapper {
30    /// Pattern-based remapping rules (regex pattern, replacement string)
31    pub patterns: Vec<(Regex, String)>,
32}
33
34impl KeyRemapper {
35    /// Create a new empty key remapper
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Add a remapping pattern (compiles regex)
41    ///
42    /// # Arguments
43    ///
44    /// * `from` - Source pattern (regex string)
45    /// * `to` - Replacement string (can include capture groups like `$1`)
46    ///
47    /// # Returns
48    ///
49    /// * `Ok(Self)` - Updated remapping configuration
50    /// * `Err(regex::Error)` - If regex compilation fails
51    pub fn add_pattern<S1, S2>(mut self, from: S1, to: S2) -> Result<Self, regex::Error>
52    where
53        S1: AsRef<str>,
54        S2: Into<String>,
55    {
56        let regex = Regex::new(from.as_ref())?;
57        self.patterns.push((regex, to.into()));
58        Ok(self)
59    }
60
61    /// Create from a list of compiled regex patterns
62    pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self {
63        Self { patterns }
64    }
65
66    /// Create from string patterns (will compile to regex)
67    ///
68    /// # Arguments
69    ///
70    /// * `patterns` - Vector of (pattern, replacement) tuples
71    ///
72    /// # Returns
73    ///
74    /// * `Ok(Self)` - New remapping configuration
75    /// * `Err(regex::Error)` - If any regex compilation fails
76    pub fn from_patterns<S1, S2>(patterns: Vec<(S1, S2)>) -> Result<Self, regex::Error>
77    where
78        S1: AsRef<str>,
79        S2: Into<String>,
80    {
81        let mut compiled_patterns = Vec::new();
82        for (pattern, replacement) in patterns {
83            let regex = Regex::new(pattern.as_ref())?;
84            compiled_patterns.push((regex, replacement.into()));
85        }
86        Ok(Self {
87            patterns: compiled_patterns,
88        })
89    }
90
91    /// Create from an iterator of patterns
92    ///
93    /// # Arguments
94    ///
95    /// * `iter` - Iterator yielding (pattern, replacement) tuples
96    ///
97    /// # Returns
98    ///
99    /// * `Ok(Self)` - New remapping configuration
100    /// * `Err(regex::Error)` - If any regex compilation fails
101    pub fn from_pattern_iter<I, S1, S2>(iter: I) -> Result<Self, regex::Error>
102    where
103        I: IntoIterator<Item = (S1, S2)>,
104        S1: AsRef<str>,
105        S2: Into<String>,
106    {
107        let patterns: Result<Vec<_>, _> = iter
108            .into_iter()
109            .map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into())))
110            .collect();
111        Ok(Self {
112            patterns: patterns?,
113        })
114    }
115
116    /// Check if the remapping is empty
117    pub fn is_empty(&self) -> bool {
118        self.patterns.is_empty()
119    }
120
121    /// Convert to the format expected by remap_tensor_paths_with_patterns
122    pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> {
123        self.patterns.clone()
124    }
125
126    /// Remap tensor paths using the configured patterns.
127    ///
128    /// # Arguments
129    ///
130    /// * `tensors` - Vec of TensorSnapshots to remap
131    ///
132    /// # Returns
133    ///
134    /// A tuple containing:
135    /// * The remapped Vec of TensorSnapshots with updated paths
136    /// * A vector of (new_path, original_path) showing the transformations
137    pub fn remap(
138        &self,
139        mut tensors: Vec<TensorSnapshot>,
140    ) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
141        if self.patterns.is_empty() {
142            let remapped_names = tensors
143                .iter()
144                .map(|v| {
145                    let path = v.full_path();
146                    (path.clone(), path)
147                })
148                .collect();
149            return (tensors, remapped_names);
150        }
151
152        let mut remapped_snapshots = Vec::new();
153        let mut remapped_names = Vec::new();
154
155        for mut snapshot in tensors.drain(..) {
156            let original_path = snapshot.full_path();
157            let mut new_path = original_path.clone();
158
159            // Apply all patterns to get the new path
160            for (pattern, replacement) in &self.patterns {
161                if pattern.is_match(&new_path) {
162                    new_path = pattern
163                        .replace_all(&new_path, replacement.as_str())
164                        .to_string();
165                }
166            }
167
168            // Update the snapshot's internal path_stack if the path changed
169            if new_path != original_path
170                && let Some(ref mut path_stack) = snapshot.path_stack
171            {
172                *path_stack = new_path.split('.').map(|s| s.to_string()).collect();
173            }
174
175            remapped_names.push((new_path.clone(), original_path));
176            remapped_snapshots.push(snapshot);
177        }
178
179        (remapped_snapshots, remapped_names)
180    }
181}
182
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use burn_core::module::ParamId;
    use burn_tensor::TensorData;

    /// Build a 2x2 `f32` snapshot whose path is `name` split on `.`.
    fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot {
        // A 2x2 f32 tensor needs 4 elements * 4 bytes = 16 bytes of data so the
        // byte buffer agrees with the declared shape and dtype.
        let bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|value| value.to_ne_bytes())
            .collect();
        let data = TensorData {
            bytes: burn_tensor::Bytes::from_bytes_vec(bytes),
            shape: vec![2, 2],
            dtype: burn_tensor::DType::F32,
        };
        let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();
        TensorSnapshot::from_data(data, path_parts, vec!["Test".to_string()], ParamId::new())
    }

    #[test]
    fn test_key_remapper_basic() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^encoder\.", "transformer.encoder.")
            .expect("valid regex");
        assert!(!remapper.is_empty());

        let tensors = vec![
            create_test_tensor_snapshot("encoder.layer1.weight"),
            create_test_tensor_snapshot("decoder.layer1.weight"),
        ];

        let (remapped, transformations) = remapper.remap(tensors);

        // Check that remapped views exist with correct paths
        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "transformer.encoder.layer1.weight")
        );
        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "decoder.layer1.weight")
        );
        assert_eq!(remapped.len(), 2);

        // Check transformations
        let encoder_transform = transformations
            .iter()
            .find(|(_new, old)| old == "encoder.layer1.weight")
            .expect("should find encoder transformation");
        assert_eq!(encoder_transform.0, "transformer.encoder.layer1.weight");
    }

    #[test]
    fn test_key_remapper_multiple_patterns() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^encoder\.", "transformer.encoder.")
            .expect("valid regex")
            .add_pattern(r"\.gamma$", ".weight")
            .expect("valid regex");

        let tensors = vec![create_test_tensor_snapshot("encoder.layer1.gamma")];

        let (remapped, _) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "transformer.encoder.layer1.weight")
        );
        assert_eq!(remapped.len(), 1);
    }

    #[test]
    fn test_key_remapper_patterns_apply_in_order() {
        // The second rule only matches because it sees the first rule's output.
        let remapper = KeyRemapper::new()
            .add_pattern(r"^a\.", "b.")
            .expect("valid regex")
            .add_pattern(r"^b\.", "c.")
            .expect("valid regex");

        let (remapped, _) = remapper.remap(vec![create_test_tensor_snapshot("a.weight")]);

        assert_eq!(remapped[0].full_path(), "c.weight");
    }

    #[test]
    fn test_key_remapper_capture_groups() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^pytorch\.(.*)", "burn.$1")
            .expect("valid regex");

        let (remapped, transformations) =
            remapper.remap(vec![create_test_tensor_snapshot("pytorch.layer.weight")]);

        assert_eq!(remapped[0].full_path(), "burn.layer.weight");
        assert_eq!(
            transformations,
            vec![(
                "burn.layer.weight".to_string(),
                "pytorch.layer.weight".to_string()
            )]
        );
    }

    #[test]
    fn test_key_remapper_unmatched_paths_are_reported_unchanged() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^encoder\.", "enc.")
            .expect("valid regex");

        let (remapped, transformations) =
            remapper.remap(vec![create_test_tensor_snapshot("decoder.weight")]);

        assert_eq!(remapped[0].full_path(), "decoder.weight");
        assert_eq!(
            transformations,
            vec![("decoder.weight".to_string(), "decoder.weight".to_string())]
        );
    }

    #[test]
    fn test_key_remapper_from_patterns() {
        let patterns = vec![(r"^pytorch\.", "burn."), (r"\.bias$", ".bias_param")];
        let remapper = KeyRemapper::from_patterns(patterns).expect("valid patterns");

        let tensors = vec![create_test_tensor_snapshot("pytorch.linear.bias")];

        let (remapped, _) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "burn.linear.bias_param")
        );
    }

    #[test]
    fn test_key_remapper_from_pattern_iter_and_compiled() {
        let from_iter =
            KeyRemapper::from_pattern_iter([(r"\.gamma$", ".weight")]).expect("valid patterns");
        let compiled = KeyRemapper::from_compiled_patterns(vec![(
            Regex::new(r"\.gamma$").expect("valid regex"),
            ".weight".to_string(),
        )]);

        for remapper in [from_iter, compiled] {
            let (remapped, _) = remapper.remap(vec![create_test_tensor_snapshot("norm.gamma")]);
            assert_eq!(remapped[0].full_path(), "norm.weight");
        }
    }

    #[test]
    fn test_key_remapper_invalid_regex() {
        assert!(KeyRemapper::new().add_pattern("(", "x").is_err());
        assert!(KeyRemapper::from_patterns(vec![("(", "x")]).is_err());
        assert!(KeyRemapper::from_pattern_iter([("ok", "x"), ("(", "y")]).is_err());
    }

    #[test]
    fn test_key_remapper_empty() {
        let remapper = KeyRemapper::new();
        assert!(remapper.is_empty());

        let tensors = vec![create_test_tensor_snapshot("test.weight")];

        let (remapped, transformations) = remapper.remap(tensors);

        assert!(remapped.iter().any(|v| v.full_path() == "test.weight"));
        assert_eq!(remapped.len(), 1);
        assert_eq!(transformations.len(), 1);
        assert_eq!(
            transformations[0],
            ("test.weight".to_string(), "test.weight".to_string())
        );
    }
}