Skip to main content

entrenar/quality/failure/
types.rs

1//! Failure types and context structures.
2
3use serde::{Deserialize, Serialize};
4
5/// Categories of training failures
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum FailureCategory {
8    /// Data quality issues (corrupt files, missing features, invalid formats)
9    DataQuality,
10
11    /// Model convergence failures (NaN loss, exploding gradients, divergence)
12    ModelConvergence,
13
14    /// Resource exhaustion (OOM, disk full, timeout)
15    ResourceExhaustion,
16
17    /// Dependency failures (missing crates, version conflicts, build errors)
18    DependencyFailure,
19
20    /// Configuration errors (invalid hyperparameters, missing required fields)
21    ConfigurationError,
22
23    /// Unknown or uncategorized failure
24    Unknown,
25}
26
27impl FailureCategory {
28    /// Get a human-readable description of the category
29    pub fn description(&self) -> &'static str {
30        match self {
31            Self::DataQuality => "Data quality issue",
32            Self::ModelConvergence => "Model convergence failure",
33            Self::ResourceExhaustion => "Resource exhaustion",
34            Self::DependencyFailure => "Dependency failure",
35            Self::ConfigurationError => "Configuration error",
36            Self::Unknown => "Unknown failure",
37        }
38    }
39
40    /// Pattern table: each entry maps keywords to a failure category.
41    /// Checked in priority order (first match wins).
42    const CATEGORY_PATTERNS: &'static [(&'static [&'static str], FailureCategory)] = &[
43        (&["nan", "inf", "exploding", "diverge", "gradient"], FailureCategory::ModelConvergence),
44        (
45            &["out of memory", "oom", "memory", "timeout", "disk full", "no space"],
46            FailureCategory::ResourceExhaustion,
47        ),
48        (
49            &[
50                "corrupt",
51                "invalid data",
52                "missing feature",
53                "data format",
54                "parse error",
55                "invalid shape",
56            ],
57            FailureCategory::DataQuality,
58        ),
59        (
60            &["dependency", "crate", "version", "build error", "compile"],
61            FailureCategory::DependencyFailure,
62        ),
63        (
64            &["config", "parameter", "invalid value", "missing field", "required"],
65            FailureCategory::ConfigurationError,
66        ),
67    ];
68
69    /// Attempt to categorize from error message patterns
70    pub fn from_error_message(message: &str) -> Self {
71        let lower = message.to_lowercase();
72
73        for (patterns, category) in Self::CATEGORY_PATTERNS {
74            if patterns.iter().any(|p| lower.contains(p)) {
75                return *category;
76            }
77        }
78
79        Self::Unknown
80    }
81}
82
83impl std::fmt::Display for FailureCategory {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "{}", self.description())
86    }
87}
88
89/// Structured failure context for a training run
90#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91pub struct FailureContext {
92    /// Error code (e.g., "E001", "NAN_LOSS")
93    pub error_code: String,
94
95    /// Human-readable error message
96    pub message: String,
97
98    /// Failure category for aggregation
99    pub category: FailureCategory,
100
101    /// Optional stack trace
102    pub stack_trace: Option<String>,
103
104    /// Suggested fix or remediation
105    pub suggested_fix: Option<String>,
106
107    /// Related run IDs that may have similar issues
108    pub related_runs: Vec<String>,
109}
110
111impl FailureContext {
112    /// Create a new failure context
113    pub fn new(error_code: impl Into<String>, message: impl Into<String>) -> Self {
114        let message_str = message.into();
115        let category = FailureCategory::from_error_message(&message_str);
116
117        Self {
118            error_code: error_code.into(),
119            message: message_str,
120            category,
121            stack_trace: None,
122            suggested_fix: None,
123            related_runs: Vec::new(),
124        }
125    }
126
127    /// Create with explicit category
128    pub fn with_category(
129        error_code: impl Into<String>,
130        message: impl Into<String>,
131        category: FailureCategory,
132    ) -> Self {
133        Self {
134            error_code: error_code.into(),
135            message: message.into(),
136            category,
137            stack_trace: None,
138            suggested_fix: None,
139            related_runs: Vec::new(),
140        }
141    }
142
143    /// Add a stack trace
144    pub fn with_stack_trace(mut self, trace: impl Into<String>) -> Self {
145        self.stack_trace = Some(trace.into());
146        self
147    }
148
149    /// Add a suggested fix
150    pub fn with_suggested_fix(mut self, fix: impl Into<String>) -> Self {
151        self.suggested_fix = Some(fix.into());
152        self
153    }
154
155    /// Add related run IDs
156    pub fn with_related_runs(mut self, runs: Vec<String>) -> Self {
157        self.related_runs = runs;
158        self
159    }
160
161    /// Generate a suggested fix based on the category
162    pub fn generate_suggested_fix(&self) -> String {
163        match self.category {
164            FailureCategory::ModelConvergence => {
165                "Try reducing the learning rate, enabling gradient clipping, \
166                 or checking for NaN values in input data."
167                    .to_string()
168            }
169            FailureCategory::ResourceExhaustion => {
170                "Try reducing batch size, using gradient checkpointing, \
171                 or enabling mixed-precision training."
172                    .to_string()
173            }
174            FailureCategory::DataQuality => {
175                "Validate input data format, check for missing values, \
176                 and verify data preprocessing pipeline."
177                    .to_string()
178            }
179            FailureCategory::DependencyFailure => {
180                "Run `cargo update`, check Cargo.lock for version conflicts, \
181                 and verify all required features are enabled."
182                    .to_string()
183            }
184            FailureCategory::ConfigurationError => {
185                "Review configuration file for typos, missing required fields, \
186                 and invalid parameter values."
187                    .to_string()
188            }
189            FailureCategory::Unknown => {
190                "Review the error message and stack trace for more details. \
191                 Consider enabling debug logging."
192                    .to_string()
193            }
194        }
195    }
196}
197
198impl<E: std::error::Error> From<&E> for FailureContext {
199    fn from(error: &E) -> Self {
200        let message = error.to_string();
201        let category = FailureCategory::from_error_message(&message);
202
203        let mut context = Self::new("ERR_GENERIC", message);
204        context.category = category;
205
206        // Try to get source chain for stack trace
207        let mut trace = String::new();
208        let mut source = error.source();
209        while let Some(s) = source {
210            trace.push_str(&format!("Caused by: {s}\n"));
211            source = s.source();
212        }
213        if !trace.is_empty() {
214            context.stack_trace = Some(trace);
215        }
216
217        context
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every message in `messages` is classified as `expected`.
    fn assert_classified(expected: FailureCategory, messages: &[&str]) {
        for message in messages {
            assert_eq!(FailureCategory::from_error_message(message), expected);
        }
    }

    /// Shorthand for the context most builder tests start from.
    fn plain_context() -> FailureContext {
        FailureContext::new("E001", "error")
    }

    #[test]
    fn test_failure_category_description() {
        let expected = [
            (FailureCategory::DataQuality, "Data quality issue"),
            (FailureCategory::ModelConvergence, "Model convergence failure"),
            (FailureCategory::ResourceExhaustion, "Resource exhaustion"),
            (FailureCategory::DependencyFailure, "Dependency failure"),
            (FailureCategory::ConfigurationError, "Configuration error"),
            (FailureCategory::Unknown, "Unknown failure"),
        ];
        for (category, text) in expected {
            assert_eq!(category.description(), text);
        }
    }

    #[test]
    fn test_failure_category_display() {
        assert_eq!(FailureCategory::DataQuality.to_string(), "Data quality issue");
    }

    #[test]
    fn test_from_error_message_model_convergence() {
        assert_classified(
            FailureCategory::ModelConvergence,
            &[
                "NaN loss detected",
                "inf value in tensor",
                "exploding gradients",
                "model diverged",
                "gradient overflow",
            ],
        );
    }

    #[test]
    fn test_from_error_message_resource_exhaustion() {
        assert_classified(
            FailureCategory::ResourceExhaustion,
            &[
                "out of memory",
                "OOM killed",
                "memory allocation failed",
                "timeout exceeded",
                "disk full",
                "no space left",
            ],
        );
    }

    #[test]
    fn test_from_error_message_data_quality() {
        assert_classified(
            FailureCategory::DataQuality,
            &[
                "corrupt file",
                "invalid data format",
                "missing feature: X",
                "data format error",
                "parse error",
                "invalid shape",
            ],
        );
    }

    #[test]
    fn test_from_error_message_dependency() {
        assert_classified(
            FailureCategory::DependencyFailure,
            &[
                "dependency not found",
                "crate version conflict",
                "version mismatch",
                "build error",
                "compile failed",
            ],
        );
    }

    #[test]
    fn test_from_error_message_configuration() {
        assert_classified(
            FailureCategory::ConfigurationError,
            &[
                "config error",
                "invalid parameter",
                "invalid value for field",
                "missing field: name",
                "required field missing",
            ],
        );
    }

    #[test]
    fn test_from_error_message_unknown() {
        assert_classified(FailureCategory::Unknown, &["something weird happened", ""]);
    }

    #[test]
    fn test_failure_context_new() {
        let ctx = FailureContext::new("E001", "NaN loss detected");
        assert_eq!(ctx.error_code, "E001");
        assert_eq!(ctx.message, "NaN loss detected");
        assert_eq!(ctx.category, FailureCategory::ModelConvergence);
        assert!(ctx.stack_trace.is_none());
        assert!(ctx.suggested_fix.is_none());
        assert!(ctx.related_runs.is_empty());
    }

    #[test]
    fn test_failure_context_with_category() {
        let ctx =
            FailureContext::with_category("E002", "Custom error", FailureCategory::DataQuality);
        assert_eq!(ctx.error_code, "E002");
        assert_eq!(ctx.category, FailureCategory::DataQuality);
    }

    #[test]
    fn test_failure_context_with_stack_trace() {
        let ctx = plain_context().with_stack_trace("at line 42");
        assert_eq!(ctx.stack_trace, Some("at line 42".to_string()));
    }

    #[test]
    fn test_failure_context_with_suggested_fix() {
        let ctx = plain_context().with_suggested_fix("Try rebooting");
        assert_eq!(ctx.suggested_fix, Some("Try rebooting".to_string()));
    }

    #[test]
    fn test_failure_context_with_related_runs() {
        let runs = vec!["run1".to_string(), "run2".to_string()];
        let ctx = plain_context().with_related_runs(runs);
        assert_eq!(ctx.related_runs.len(), 2);
    }

    #[test]
    fn test_generate_suggested_fix_all_categories() {
        // Every category must yield some advice text.
        for category in [
            FailureCategory::ModelConvergence,
            FailureCategory::ResourceExhaustion,
            FailureCategory::DataQuality,
            FailureCategory::DependencyFailure,
            FailureCategory::ConfigurationError,
            FailureCategory::Unknown,
        ] {
            let fix = FailureContext::with_category("E001", "error", category)
                .generate_suggested_fix();
            assert!(!fix.is_empty());
        }
    }

    #[test]
    fn test_failure_context_from_error() {
        let err = std::io::Error::new(std::io::ErrorKind::OutOfMemory, "out of memory");
        let ctx = FailureContext::from(&err);
        assert_eq!(ctx.error_code, "ERR_GENERIC");
        assert!(ctx.message.contains("memory"));
        assert_eq!(ctx.category, FailureCategory::ResourceExhaustion);
    }

    #[test]
    fn test_failure_category_serialization() {
        let original = FailureCategory::DataQuality;
        let json = serde_json::to_string(&original).expect("JSON serialization should succeed");
        let round_tripped: FailureCategory =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(original, round_tripped);
    }

    #[test]
    fn test_failure_context_serialization() {
        let original = FailureContext::new("E001", "test error")
            .with_stack_trace("trace")
            .with_suggested_fix("fix it");
        let json = serde_json::to_string(&original).expect("JSON serialization should succeed");
        let round_tripped: FailureContext =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(original.error_code, round_tripped.error_code);
        assert_eq!(original.stack_trace, round_tripped.stack_trace);
    }

    #[test]
    fn test_failure_category_clone_copy() {
        // The category is `Copy`, so the source stays usable after each copy.
        let cat = FailureCategory::ModelConvergence;
        let (cloned, copied) = (cat, cat);
        assert_eq!(cat, cloned);
        assert_eq!(cat, copied);
    }

    #[test]
    fn test_failure_category_hash() {
        use std::collections::HashSet;
        let set = HashSet::from([
            FailureCategory::DataQuality,
            FailureCategory::ModelConvergence,
        ]);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_failure_context_builder_chain() {
        let ctx = plain_context()
            .with_stack_trace("trace")
            .with_suggested_fix("fix")
            .with_related_runs(vec!["run1".to_string()]);
        assert!(ctx.stack_trace.is_some());
        assert!(ctx.suggested_fix.is_some());
        assert_eq!(ctx.related_runs.len(), 1);
    }
}