burn_store/
keyremapper.rs1use alloc::string::{String, ToString};
2use alloc::vec::Vec;
3
4use regex::{self, Regex};
5
6use crate::TensorSnapshot;
7
8#[derive(Debug, Clone, Default)]
29pub struct KeyRemapper {
30 pub patterns: Vec<(Regex, String)>,
32}
33
34impl KeyRemapper {
35 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn add_pattern<S1, S2>(mut self, from: S1, to: S2) -> Result<Self, regex::Error>
52 where
53 S1: AsRef<str>,
54 S2: Into<String>,
55 {
56 let regex = Regex::new(from.as_ref())?;
57 self.patterns.push((regex, to.into()));
58 Ok(self)
59 }
60
61 pub fn from_compiled_patterns(patterns: Vec<(Regex, String)>) -> Self {
63 Self { patterns }
64 }
65
66 pub fn from_patterns<S1, S2>(patterns: Vec<(S1, S2)>) -> Result<Self, regex::Error>
77 where
78 S1: AsRef<str>,
79 S2: Into<String>,
80 {
81 let mut compiled_patterns = Vec::new();
82 for (pattern, replacement) in patterns {
83 let regex = Regex::new(pattern.as_ref())?;
84 compiled_patterns.push((regex, replacement.into()));
85 }
86 Ok(Self {
87 patterns: compiled_patterns,
88 })
89 }
90
91 pub fn from_pattern_iter<I, S1, S2>(iter: I) -> Result<Self, regex::Error>
102 where
103 I: IntoIterator<Item = (S1, S2)>,
104 S1: AsRef<str>,
105 S2: Into<String>,
106 {
107 let patterns: Result<Vec<_>, _> = iter
108 .into_iter()
109 .map(|(from, to)| Ok((Regex::new(from.as_ref())?, to.into())))
110 .collect();
111 Ok(Self {
112 patterns: patterns?,
113 })
114 }
115
116 pub fn is_empty(&self) -> bool {
118 self.patterns.is_empty()
119 }
120
121 pub fn to_regex_pairs(&self) -> Vec<(Regex, String)> {
123 self.patterns.clone()
124 }
125
126 pub fn remap(
138 &self,
139 mut tensors: Vec<TensorSnapshot>,
140 ) -> (Vec<TensorSnapshot>, Vec<(String, String)>) {
141 if self.patterns.is_empty() {
142 let remapped_names = tensors
143 .iter()
144 .map(|v| {
145 let path = v.full_path();
146 (path.clone(), path)
147 })
148 .collect();
149 return (tensors, remapped_names);
150 }
151
152 let mut remapped_snapshots = Vec::new();
153 let mut remapped_names = Vec::new();
154
155 for mut snapshot in tensors.drain(..) {
156 let original_path = snapshot.full_path();
157 let mut new_path = original_path.clone();
158
159 for (pattern, replacement) in &self.patterns {
161 if pattern.is_match(&new_path) {
162 new_path = pattern
163 .replace_all(&new_path, replacement.as_str())
164 .to_string();
165 }
166 }
167
168 if new_path != original_path
170 && let Some(ref mut path_stack) = snapshot.path_stack
171 {
172 *path_stack = new_path.split('.').map(|s| s.to_string()).collect();
173 }
174
175 remapped_names.push((new_path.clone(), original_path));
176 remapped_snapshots.push(snapshot);
177 }
178
179 (remapped_snapshots, remapped_names)
180 }
181}
182
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use burn_core::module::ParamId;
    use burn_tensor::TensorData;

    /// Builds a 2x2 `f32` snapshot whose path stack is `name` split on `.`.
    fn create_test_tensor_snapshot(name: &str) -> TensorSnapshot {
        // `TensorData::new` derives bytes, shape and dtype from the values, so the
        // fixture stays self-consistent. A 2x2 f32 tensor needs 16 bytes, not 4.
        let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]);
        let path_parts: Vec<String> = name.split('.').map(|s| s.to_string()).collect();
        TensorSnapshot::from_data(data, path_parts, vec!["Test".to_string()], ParamId::new())
    }

    #[test]
    fn test_key_remapper_basic() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^encoder\.", "transformer.encoder.")
            .expect("valid regex");

        let tensors = vec![
            create_test_tensor_snapshot("encoder.layer1.weight"),
            create_test_tensor_snapshot("decoder.layer1.weight"),
        ];

        let (remapped, transformations) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "transformer.encoder.layer1.weight")
        );
        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "decoder.layer1.weight")
        );
        assert_eq!(remapped.len(), 2);

        let encoder_transform = transformations
            .iter()
            .find(|(_new, old)| old == "encoder.layer1.weight")
            .expect("should find encoder transformation");
        assert_eq!(encoder_transform.0, "transformer.encoder.layer1.weight");
    }

    #[test]
    fn test_key_remapper_multiple_patterns() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^encoder\.", "transformer.encoder.")
            .expect("valid regex")
            .add_pattern(r"\.gamma$", ".weight")
            .expect("valid regex");

        let tensors = vec![create_test_tensor_snapshot("encoder.layer1.gamma")];

        let (remapped, _) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "transformer.encoder.layer1.weight")
        );
        assert_eq!(remapped.len(), 1);
    }

    #[test]
    fn test_key_remapper_from_patterns() {
        let patterns = vec![(r"^pytorch\.", "burn."), (r"\.bias$", ".bias_param")];
        let remapper = KeyRemapper::from_patterns(patterns).expect("valid patterns");

        let tensors = vec![create_test_tensor_snapshot("pytorch.linear.bias")];

        let (remapped, _) = remapper.remap(tensors);

        assert!(
            remapped
                .iter()
                .any(|v| v.full_path() == "burn.linear.bias_param")
        );
    }

    #[test]
    fn test_key_remapper_empty() {
        let remapper = KeyRemapper::new();
        assert!(remapper.is_empty());

        let tensors = vec![create_test_tensor_snapshot("test.weight")];

        let (remapped, transformations) = remapper.remap(tensors);

        assert!(remapped.iter().any(|v| v.full_path() == "test.weight"));
        assert_eq!(remapped.len(), 1);
        assert_eq!(transformations.len(), 1);
        assert_eq!(
            transformations[0],
            ("test.weight".to_string(), "test.weight".to_string())
        );
    }

    #[test]
    fn test_key_remapper_capture_groups() {
        let remapper = KeyRemapper::new()
            .add_pattern(r"^layers\.(\d+)\.", "blocks.${1}.")
            .expect("valid regex");
        assert!(!remapper.is_empty());

        let (remapped, transformations) =
            remapper.remap(vec![create_test_tensor_snapshot("layers.3.weight")]);

        assert_eq!(remapped[0].full_path(), "blocks.3.weight");
        assert_eq!(
            transformations,
            vec![("blocks.3.weight".to_string(), "layers.3.weight".to_string())]
        );
    }

    #[test]
    fn test_key_remapper_preserves_order_and_unmatched() {
        let remapper =
            KeyRemapper::from_pattern_iter([(r"^old\.", "new.")]).expect("valid regex");

        let (remapped, transformations) = remapper.remap(vec![
            create_test_tensor_snapshot("other.bias"),
            create_test_tensor_snapshot("old.weight"),
        ]);

        let paths: Vec<String> = remapped.iter().map(|v| v.full_path()).collect();
        assert_eq!(
            paths,
            vec!["other.bias".to_string(), "new.weight".to_string()]
        );
        assert_eq!(
            transformations,
            vec![
                ("other.bias".to_string(), "other.bias".to_string()),
                ("new.weight".to_string(), "old.weight".to_string()),
            ]
        );
    }

    #[test]
    fn test_key_remapper_invalid_regex() {
        assert!(KeyRemapper::new().add_pattern("(", "x").is_err());
        assert!(KeyRemapper::from_patterns(vec![("ok", "x"), ("[", "y")]).is_err());
        assert!(KeyRemapper::from_pattern_iter([("(", "z")]).is_err());
    }
}