// burn_store/apply_result.rs

//! Result types and diagnostics for tensor application operations

use alloc::string::String;
use alloc::vec;
use alloc::vec::Vec;

use burn_tensor::DType;

/// Error types that can occur during tensor application
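///
/// Rendered through `Display`, each variant formats as a single line, e.g.
/// (the path here is illustrative):
///
/// ```text
/// Shape mismatch for 'encoder.weight': expected [2, 3], found [3, 2]
/// ```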
#[derive(Debug, Clone)]
pub enum ApplyError {
    /// Mismatch between the expected and found tensor shapes
    ShapeMismatch {
        /// Path of the tensor
        path: String,
        /// Expected shape
        expected: Vec<usize>,
        /// Found shape
        found: Vec<usize>,
    },
    /// Mismatch between the expected and found data types
    DTypeMismatch {
        /// Path of the tensor
        path: String,
        /// Expected data type
        expected: DType,
        /// Found data type
        found: DType,
    },
    /// Error from adapter transformation
    AdapterError {
        /// Path of the tensor
        path: String,
        /// Error message
        message: String,
    },
    /// Error loading tensor data
    LoadError {
        /// Path of the tensor
        path: String,
        /// Error message
        message: String,
    },
}

impl core::fmt::Display for ApplyError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::ShapeMismatch {
                path,
                expected,
                found,
            } => {
                write!(
                    f,
                    "Shape mismatch for '{}': expected {:?}, found {:?}",
                    path, expected, found
                )
            }
            Self::DTypeMismatch {
                path,
                expected,
                found,
            } => {
                write!(
                    f,
                    "DType mismatch for '{}': expected {:?}, found {:?}",
                    path, expected, found
                )
            }
            Self::AdapterError { path, message } => {
                write!(f, "Adapter error for '{}': {}", path, message)
            }
            Self::LoadError { path, message } => {
                write!(f, "Load error for '{}': {}", path, message)
            }
        }
    }
}

impl core::error::Error for ApplyError {}

/// Result of applying tensor snapshots to a module
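///
/// # Example
///
/// A sketch of a typical post-load check; `apply_from` stands in for
/// whichever store API produced this result and is not defined in this
/// module:
///
/// ```ignore
/// let result = module.apply_from(&mut store);
/// if !result.is_success() {
///     // `Display` prints the full summary report implemented below.
///     panic!("{result}");
/// }
/// ```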
#[derive(Clone)]
pub struct ApplyResult {
    /// Successfully applied tensor paths
    pub applied: Vec<String>,
    /// Skipped tensor paths (due to filter)
    pub skipped: Vec<String>,
    /// Missing tensor paths with their container stacks in dot notation (path, container_stack).
    /// The container stack shows the hierarchy, e.g. "Struct:Model.Struct:Linear" or
    /// "Struct:Model.Enum:ConvType.Struct:Linear"
    pub missing: Vec<(String, String)>,
    /// Unused tensor paths (in snapshots but not in module)
    pub unused: Vec<String>,
    /// Errors encountered during application
    pub errors: Vec<ApplyError>,
}

impl ApplyResult {
    /// Try to strip an enum variant from a path,
    /// e.g. "field.BaseConv.weight" -> "field.weight"
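    ///
    /// Returns `None` when no CamelCase segment appears between the first
    /// and last segments, e.g. for "encoder.weight"; only the first variant
    /// found is removed.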
    fn strip_enum_variant(path: &str) -> Option<String> {
        let segments: Vec<&str> = path.split('.').collect();

        // Find segments that look like enum variants (CamelCase in middle of path)
        let variant_indices: Vec<usize> = segments
            .iter()
            .enumerate()
            .filter(|(i, segment)| {
                *i > 0 && *i < segments.len() - 1 // Not first or last
                    && !segment.is_empty()
                    && segment.chars().next().map(|c| c.is_uppercase()).unwrap_or(false)
                    && segment.len() > 1
                    && segment.chars().skip(1).any(|c| c.is_lowercase())
            })
            .map(|(i, _)| i)
            .collect();

        if variant_indices.is_empty() {
            return None;
        }

        // Remove the first found variant and return the modified path
        let mut result_segments = segments.clone();
        result_segments.remove(variant_indices[0]);
        Some(result_segments.join("."))
    }

    /// Find similar paths for a given missing path (for "Did you mean?" suggestions)
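    ///
    /// Fast path: if stripping an enum variant from `missing_path` yields an
    /// exact entry in `unused`, that entry is returned alone; otherwise paths
    /// are ranked by Jaro similarity and filtered at a 0.7 threshold.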
    fn find_similar_paths(&self, missing_path: &str, max_suggestions: usize) -> Vec<String> {
        // First, try an exact match with the enum variant stripped
        if let Some(stripped) = Self::strip_enum_variant(missing_path)
            && self.unused.contains(&stripped)
        {
            return vec![stripped];
        }

        // Fall back to Jaro similarity (the same metric Elixir uses for its
        // "did you mean?" suggestions). Jaro gives higher weight to matching
        // prefixes, which suits hierarchical tensor paths.
        let mut similarities: Vec<(String, f64)> = self
            .unused
            .iter()
            .map(|available| {
                let similarity = textdistance::nstr::jaro(missing_path, available);
                (available.clone(), similarity)
            })
            .collect();

        // Sort by similarity (higher = more similar)
        similarities
            .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(core::cmp::Ordering::Equal));

        // Only suggest paths with >= 70% similarity
        const SIMILARITY_THRESHOLD: f64 = 0.7;
        similarities
            .into_iter()
            .filter(|(_, sim)| *sim >= SIMILARITY_THRESHOLD)
            .take(max_suggestions)
            .map(|(path, _)| path)
            .collect()
    }
}

impl ApplyResult {
    /// Check if the apply operation was successful (no errors).
    /// Note: missing tensors are not considered errors when `allow_partial` is true.
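    ///
    /// A result with missing or unused entries but an empty `errors` list
    /// still counts as success:
    ///
    /// ```ignore
    /// if result.is_success() && !result.missing.is_empty() {
    ///     // Partial load: usable, but worth logging the missing tensors.
    /// }
    /// ```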
    pub fn is_success(&self) -> bool {
        self.errors.is_empty()
    }
}

impl core::fmt::Debug for ApplyResult {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Delegate to Display for comprehensive output
        core::fmt::Display::fmt(self, f)
    }
}

impl core::fmt::Display for ApplyResult {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "┌─ Tensor Loading Summary ─────────────────────────")?;
        writeln!(f, "│")?;
        writeln!(
            f,
            "│ ✓ Successfully applied: {} tensors",
            self.applied.len()
        )?;
        writeln!(f, "│ ⊘ Skipped (filtered):   {} tensors", self.skipped.len())?;
        writeln!(
            f,
            "│ ✗ Missing in source:    {} tensors",
            self.missing.len()
        )?;
        writeln!(f, "│ ? Unused in target:     {} tensors", self.unused.len())?;
        writeln!(f, "│ ! Errors:               {} errors", self.errors.len())?;

        if !self.missing.is_empty() {
            writeln!(f, "│")?;
            writeln!(
                f,
                "├─ Missing Tensors (requested by model but not found in source)"
            )?;
            writeln!(f, "│")?;

            // Use the container stack data to detect enum variants: collect
            // the missing paths whose stack contains an "Enum:" entry.
            let enum_variant_missing: Vec<_> = self
                .missing
                .iter()
                .filter(|(_, stack)| stack.contains("Enum:"))
                .collect();

            if !enum_variant_missing.is_empty() {
                writeln!(
                    f,
                    "│  ⚠️  {} paths contain enum variants (detected from container stack)",
                    enum_variant_missing.len()
                )?;
                writeln!(
                    f,
                    "│      Burn includes enum variant names in paths, but PyTorch doesn't."
                )?;
                writeln!(
                    f,
                    "│      Example: Burn has 'field.BaseConv.weight', PyTorch has 'field.weight'"
                )?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│      💡 Solution 1: Enable skip_enum_variants flag (simplest):"
                )?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│         let mut store = PytorchStore::from_file(\"model.pth\")"
                )?;
                writeln!(f, "│             .skip_enum_variants(true);  // ← Add this")?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│      💡 Solution 2: Remap enum keys in source (most precise):"
                )?;
                writeln!(f, "│")?;
                writeln!(
                    f,
                    "│         let mut store = SafetensorsStore::from_file(\"model.safetensors\")"
                )?;
                writeln!(
                    f,
                    "│             .with_key_remapping(r\"field\\.(\\w+)\", \"field.BaseConv.$1\");"
                )?;
                writeln!(f, "│")?;
            }

            writeln!(f, "│  First 10 missing tensors:")?;
            for (path, _) in self.missing.iter().take(10) {
                writeln!(f, "│    • {}", path)?;

                // Show a "Did you mean?" suggestion for this path
                let suggestions = self.find_similar_paths(path, 1);
                if !suggestions.is_empty() {
                    writeln!(f, "│        Did you mean: '{}'?", suggestions[0])?;
                }
            }
            if self.missing.len() > 10 {
                writeln!(f, "│    ... and {} more", self.missing.len() - 10)?;
            }
        }

        if !self.unused.is_empty() {
            writeln!(f, "│")?;
            writeln!(f, "├─ Unused Tensors (in source but not used by model)")?;
            writeln!(f, "│")?;
            writeln!(f, "│  First 10 unused tensors:")?;
            for path in self.unused.iter().take(10) {
                writeln!(f, "│    • {}", path)?;
            }
            if self.unused.len() > 10 {
                writeln!(f, "│    ... and {} more", self.unused.len() - 10)?;
            }
        }

        if !self.errors.is_empty() {
            writeln!(f, "│")?;
            writeln!(f, "├─ Errors")?;
            writeln!(f, "│")?;
            for error in self.errors.iter().take(10) {
                writeln!(f, "│  ⚠️  {}", error)?;
            }
            if self.errors.len() > 10 {
                writeln!(f, "│    ... and {} more", self.errors.len() - 10)?;
            }
        }

        writeln!(f, "│")?;
        write!(f, "└───────────────────────────────────────────────────")?;

        Ok(())
    }
}
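
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;

    // A minimal sketch exercising the variant-stripping heuristic and the
    // success check; the paths below are illustrative, not from a real model.
    #[test]
    fn strips_first_enum_variant() {
        assert_eq!(
            ApplyResult::strip_enum_variant("field.BaseConv.weight"),
            Some("field.weight".to_string())
        );
        // No CamelCase segment between the first and last segments.
        assert_eq!(ApplyResult::strip_enum_variant("field.weight"), None);
    }

    #[test]
    fn success_ignores_missing_and_unused() {
        let result = ApplyResult {
            applied: vec!["field.weight".to_string()],
            skipped: vec![],
            missing: vec![("head.bias".to_string(), "Struct:Model".to_string())],
            unused: vec![],
            errors: vec![],
        };
        assert!(result.is_success());
    }
}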