use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;

use burn_tensor::DType;

/// An error raised while applying loaded tensor data to a module.
#[derive(Debug, Clone)]
pub enum ApplyError {
    /// The source tensor's shape does not match the shape expected by the target.
    ShapeMismatch {
        /// Path of the tensor in the module.
        path: String,
        /// Shape expected by the target.
        expected: Vec<usize>,
        /// Shape found in the source.
        found: Vec<usize>,
    },
    /// The source tensor's data type does not match the type expected by the target.
    DTypeMismatch {
        /// Path of the tensor in the module.
        path: String,
        /// Data type expected by the target.
        expected: DType,
        /// Data type found in the source.
        found: DType,
    },
    /// An adapter failed while transforming a tensor.
    AdapterError {
        /// Path of the tensor in the module.
        path: String,
        /// Description of the adapter failure.
        message: String,
    },
    /// The source data could not be loaded.
    LoadError {
        /// Path of the tensor in the module.
        path: String,
        /// Description of the load failure.
        message: String,
    },
}

impl core::fmt::Display for ApplyError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::ShapeMismatch {
                path,
                expected,
                found,
            } => {
                write!(
                    f,
                    "Shape mismatch for '{}': expected {:?}, found {:?}",
                    path, expected, found
                )
            }
            Self::DTypeMismatch {
                path,
                expected,
                found,
            } => {
                write!(
                    f,
                    "DType mismatch for '{}': expected {:?}, found {:?}",
                    path, expected, found
                )
            }
            Self::AdapterError { path, message } => {
                write!(f, "Adapter error for '{}': {}", path, message)
            }
            Self::LoadError { path, message } => {
                write!(f, "Load error for '{}': {}", path, message)
            }
        }
    }
}

impl core::error::Error for ApplyError {}

/// Result of applying loaded tensor data to a module.
#[derive(Clone)]
pub struct ApplyResult {
    /// Paths of tensors that were successfully applied.
    pub applied: Vec<String>,
    /// Paths of tensors that were skipped by a filter.
    pub skipped: Vec<String>,
    /// Paths requested by the module but not found in the source,
    /// paired with the container stack recorded at the point of the lookup.
    pub missing: Vec<(String, String)>,
    /// Paths present in the source but not used by the module.
    pub unused: Vec<String>,
    /// Errors encountered while applying tensors.
    pub errors: Vec<ApplyError>,
}

impl ApplyResult {
    /// Strips the first path segment that looks like an enum variant name,
    /// e.g. `field.BaseConv.weight` becomes `field.weight`.
    ///
    /// Returns `None` when no such segment is found.
    fn strip_enum_variant(path: &str) -> Option<String> {
        let segments: Vec<&str> = path.split('.').collect();

        // A segment counts as an enum variant when it is neither the first nor the last
        // segment and is written in PascalCase (uppercase first letter followed by at
        // least one lowercase letter).
        let variant_indices: Vec<usize> = segments
            .iter()
            .enumerate()
            .filter(|(i, segment)| {
                *i > 0
                    && *i < segments.len() - 1
                    && !segment.is_empty()
                    && segment
                        .chars()
                        .next()
                        .map(|c| c.is_uppercase())
                        .unwrap_or(false)
                    && segment.len() > 1
                    && segment.chars().skip(1).any(|c| c.is_lowercase())
            })
            .map(|(i, _)| i)
            .collect();

        if variant_indices.is_empty() {
            return None;
        }

        // Remove only the first detected variant segment.
        let mut result_segments = segments.clone();
        result_segments.remove(variant_indices[0]);
        Some(result_segments.join("."))
    }

    /// Returns up to `max_suggestions` unused source paths that closely resemble
    /// `missing_path`, used for the "did you mean" hints in the summary output.
    fn find_similar_paths(&self, missing_path: &str, max_suggestions: usize) -> Vec<String> {
        // An exact match after stripping an enum variant segment is the best suggestion.
        if let Some(stripped) = Self::strip_enum_variant(missing_path)
            && self.unused.contains(&stripped)
        {
            return vec![stripped];
        }

        // Otherwise rank the unused paths by Jaro similarity.
        let mut similarities: Vec<(String, f64)> = self
            .unused
            .iter()
            .map(|available| {
                let similarity = textdistance::nstr::jaro(missing_path, available);
                (available.clone(), similarity)
            })
            .collect();

        similarities
            .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Equal));

        const SIMILARITY_THRESHOLD: f64 = 0.7;
        similarities
            .into_iter()
            .filter(|(_, sim)| *sim >= SIMILARITY_THRESHOLD)
            .take(max_suggestions)
            .map(|(path, _)| path)
            .collect()
    }
}

impl ApplyResult {
    /// Returns `true` when no errors occurred while applying tensors.
    pub fn is_success(&self) -> bool {
        self.errors.is_empty()
    }
}

impl core::fmt::Debug for ApplyResult {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Debug output reuses the human-readable summary from Display.
        core::fmt::Display::fmt(self, f)
    }
}

impl core::fmt::Display for ApplyResult {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "┌─ Tensor Loading Summary ─────────────────────────")?;
        writeln!(f, "│")?;
        writeln!(
            f,
            "│ ✓ Successfully applied: {} tensors",
            self.applied.len()
        )?;
        writeln!(f, "│ ⊘ Skipped (filtered): {} tensors", self.skipped.len())?;
        writeln!(
            f,
            "│ ✗ Missing in source: {} tensors",
            self.missing.len()
        )?;
        writeln!(f, "│ ? Unused in target: {} tensors", self.unused.len())?;
        writeln!(f, "│ ! Errors: {} errors", self.errors.len())?;

        if !self.missing.is_empty() {
            writeln!(f, "│")?;
            writeln!(
                f,
                "├─ Missing Tensors (requested by model but not found in source)"
            )?;
            writeln!(f, "│")?;

            let enum_variant_missing: Vec<_> = self
                .missing
                .iter()
                .filter(|(_, stack)| stack.contains("Enum:"))
                .collect();

            if !enum_variant_missing.is_empty() {
                writeln!(
                    f,
                    "│ ⚠️ {} paths contain enum variants (detected from container stack)",
                    enum_variant_missing.len()
                )?;
                writeln!(
                    f,
                    "│ Burn includes enum variant names in paths, but PyTorch doesn't."
                )?;
                writeln!(
                    f,
                    "│ Example: Burn has 'field.BaseConv.weight', PyTorch has 'field.weight'"
                )?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│ 💡 Solution 1: Enable skip_enum_variants flag (simplest):"
                )?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│     let mut store = PytorchStore::from_file(\"model.pth\")"
                )?;
                writeln!(f, "│         .skip_enum_variants(true); // ← Add this")?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│ 💡 Solution 2: Remap enum keys in source (most precise):"
                )?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│     let mut store = SafetensorsStore::from_file(\"model.safetensors\")"
                )?;
                writeln!(
                    f,
                    "│         .with_key_remapping(r\"field\\.(\\w+)\", \"field.BaseConv.$1\");"
                )?;
                writeln!(f, "│")?;
            }

            writeln!(f, "│ First 10 missing tensors:")?;
            for (path, _) in self.missing.iter().take(10) {
                writeln!(f, "│ • {}", path)?;

                let suggestions = self.find_similar_paths(path, 1);
                if !suggestions.is_empty() {
                    writeln!(f, "│ Did you mean: '{}'?", suggestions[0])?;
                }
            }
            if self.missing.len() > 10 {
                writeln!(f, "│ ... and {} more", self.missing.len() - 10)?;
            }
        }

        if !self.unused.is_empty() {
            writeln!(f, "│")?;
            writeln!(f, "├─ Unused Tensors (in source but not used by model)")?;
            writeln!(f, "│")?;
            writeln!(f, "│ First 10 unused tensors:")?;
            for path in self.unused.iter().take(10) {
                writeln!(f, "│ • {}", path)?;
            }
            if self.unused.len() > 10 {
                writeln!(f, "│ ... and {} more", self.unused.len() - 10)?;
            }
        }

        if !self.errors.is_empty() {
            writeln!(f, "│")?;
            writeln!(f, "├─ Errors")?;
            writeln!(f, "│")?;
            for error in self.errors.iter().take(10) {
                writeln!(f, "│ ⚠️ {}", error)?;
            }
            if self.errors.len() > 10 {
                writeln!(f, "│ ... and {} more", self.errors.len() - 10)?;
            }
        }

        writeln!(f, "│")?;
        write!(f, "└───────────────────────────────────────────────────")?;

        Ok(())
    }
}
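
// A minimal test sketch for the path heuristics above. The tensor paths used here
// ("encoder.BaseConv.weight", "encoder.weight") are illustrative placeholders, not
// names taken from any particular model.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::format;
    use alloc::string::ToString;
    use alloc::vec;

    #[test]
    fn strip_enum_variant_removes_pascal_case_middle_segment() {
        assert_eq!(
            ApplyResult::strip_enum_variant("encoder.BaseConv.weight"),
            Some("encoder.weight".to_string())
        );
        // First and last segments are never treated as enum variants.
        assert_eq!(ApplyResult::strip_enum_variant("encoder.weight"), None);
    }

    #[test]
    fn find_similar_paths_suggests_unused_path_differing_only_by_variant() {
        let result = ApplyResult {
            applied: vec![],
            skipped: vec![],
            missing: vec![],
            unused: vec!["encoder.weight".to_string()],
            errors: vec![],
        };

        assert!(result.is_success());
        assert_eq!(
            result.find_similar_paths("encoder.BaseConv.weight", 1),
            vec!["encoder.weight"]
        );
    }

    #[test]
    fn display_renders_summary_header() {
        let result = ApplyResult {
            applied: vec!["encoder.weight".to_string()],
            skipped: vec![],
            missing: vec![],
            unused: vec![],
            errors: vec![],
        };

        assert!(format!("{result}").contains("Tensor Loading Summary"));
    }
}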