// entrenar/quality/failure/types.rs
use std::fmt::Write as _;

use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum FailureCategory {
8 DataQuality,
10
11 ModelConvergence,
13
14 ResourceExhaustion,
16
17 DependencyFailure,
19
20 ConfigurationError,
22
23 Unknown,
25}
26
27impl FailureCategory {
28 pub fn description(&self) -> &'static str {
30 match self {
31 Self::DataQuality => "Data quality issue",
32 Self::ModelConvergence => "Model convergence failure",
33 Self::ResourceExhaustion => "Resource exhaustion",
34 Self::DependencyFailure => "Dependency failure",
35 Self::ConfigurationError => "Configuration error",
36 Self::Unknown => "Unknown failure",
37 }
38 }
39
40 const CATEGORY_PATTERNS: &'static [(&'static [&'static str], FailureCategory)] = &[
43 (&["nan", "inf", "exploding", "diverge", "gradient"], FailureCategory::ModelConvergence),
44 (
45 &["out of memory", "oom", "memory", "timeout", "disk full", "no space"],
46 FailureCategory::ResourceExhaustion,
47 ),
48 (
49 &[
50 "corrupt",
51 "invalid data",
52 "missing feature",
53 "data format",
54 "parse error",
55 "invalid shape",
56 ],
57 FailureCategory::DataQuality,
58 ),
59 (
60 &["dependency", "crate", "version", "build error", "compile"],
61 FailureCategory::DependencyFailure,
62 ),
63 (
64 &["config", "parameter", "invalid value", "missing field", "required"],
65 FailureCategory::ConfigurationError,
66 ),
67 ];
68
69 pub fn from_error_message(message: &str) -> Self {
71 let lower = message.to_lowercase();
72
73 for (patterns, category) in Self::CATEGORY_PATTERNS {
74 if patterns.iter().any(|p| lower.contains(p)) {
75 return *category;
76 }
77 }
78
79 Self::Unknown
80 }
81}
82
83impl std::fmt::Display for FailureCategory {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 write!(f, "{}", self.description())
86 }
87}
88
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91pub struct FailureContext {
92 pub error_code: String,
94
95 pub message: String,
97
98 pub category: FailureCategory,
100
101 pub stack_trace: Option<String>,
103
104 pub suggested_fix: Option<String>,
106
107 pub related_runs: Vec<String>,
109}
110
111impl FailureContext {
112 pub fn new(error_code: impl Into<String>, message: impl Into<String>) -> Self {
114 let message_str = message.into();
115 let category = FailureCategory::from_error_message(&message_str);
116
117 Self {
118 error_code: error_code.into(),
119 message: message_str,
120 category,
121 stack_trace: None,
122 suggested_fix: None,
123 related_runs: Vec::new(),
124 }
125 }
126
127 pub fn with_category(
129 error_code: impl Into<String>,
130 message: impl Into<String>,
131 category: FailureCategory,
132 ) -> Self {
133 Self {
134 error_code: error_code.into(),
135 message: message.into(),
136 category,
137 stack_trace: None,
138 suggested_fix: None,
139 related_runs: Vec::new(),
140 }
141 }
142
143 pub fn with_stack_trace(mut self, trace: impl Into<String>) -> Self {
145 self.stack_trace = Some(trace.into());
146 self
147 }
148
149 pub fn with_suggested_fix(mut self, fix: impl Into<String>) -> Self {
151 self.suggested_fix = Some(fix.into());
152 self
153 }
154
155 pub fn with_related_runs(mut self, runs: Vec<String>) -> Self {
157 self.related_runs = runs;
158 self
159 }
160
161 pub fn generate_suggested_fix(&self) -> String {
163 match self.category {
164 FailureCategory::ModelConvergence => {
165 "Try reducing the learning rate, enabling gradient clipping, \
166 or checking for NaN values in input data."
167 .to_string()
168 }
169 FailureCategory::ResourceExhaustion => {
170 "Try reducing batch size, using gradient checkpointing, \
171 or enabling mixed-precision training."
172 .to_string()
173 }
174 FailureCategory::DataQuality => {
175 "Validate input data format, check for missing values, \
176 and verify data preprocessing pipeline."
177 .to_string()
178 }
179 FailureCategory::DependencyFailure => {
180 "Run `cargo update`, check Cargo.lock for version conflicts, \
181 and verify all required features are enabled."
182 .to_string()
183 }
184 FailureCategory::ConfigurationError => {
185 "Review configuration file for typos, missing required fields, \
186 and invalid parameter values."
187 .to_string()
188 }
189 FailureCategory::Unknown => {
190 "Review the error message and stack trace for more details. \
191 Consider enabling debug logging."
192 .to_string()
193 }
194 }
195 }
196}
197
198impl<E: std::error::Error> From<&E> for FailureContext {
199 fn from(error: &E) -> Self {
200 let message = error.to_string();
201 let category = FailureCategory::from_error_message(&message);
202
203 let mut context = Self::new("ERR_GENERIC", message);
204 context.category = category;
205
206 let mut trace = String::new();
208 let mut source = error.source();
209 while let Some(s) = source {
210 trace.push_str(&format!("Caused by: {s}\n"));
211 source = s.source();
212 }
213 if !trace.is_empty() {
214 context.stack_trace = Some(trace);
215 }
216
217 context
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that each message in `messages` is classified as `expected`,
    /// naming the offending message on failure.
    fn assert_classified(messages: &[&str], expected: FailureCategory) {
        for message in messages {
            assert_eq!(
                FailureCategory::from_error_message(message),
                expected,
                "misclassified message: {message:?}"
            );
        }
    }

    #[test]
    fn test_failure_category_description() {
        let expected = [
            (FailureCategory::DataQuality, "Data quality issue"),
            (FailureCategory::ModelConvergence, "Model convergence failure"),
            (FailureCategory::ResourceExhaustion, "Resource exhaustion"),
            (FailureCategory::DependencyFailure, "Dependency failure"),
            (FailureCategory::ConfigurationError, "Configuration error"),
            (FailureCategory::Unknown, "Unknown failure"),
        ];
        for (category, description) in expected {
            assert_eq!(category.description(), description);
        }
    }

    #[test]
    fn test_failure_category_display() {
        assert_eq!(format!("{}", FailureCategory::DataQuality), "Data quality issue");
    }

    #[test]
    fn test_from_error_message_model_convergence() {
        assert_classified(
            &[
                "NaN loss detected",
                "inf value in tensor",
                "exploding gradients",
                "model diverged",
                "gradient overflow",
            ],
            FailureCategory::ModelConvergence,
        );
    }

    #[test]
    fn test_from_error_message_resource_exhaustion() {
        assert_classified(
            &[
                "out of memory",
                "OOM killed",
                "memory allocation failed",
                "timeout exceeded",
                "disk full",
                "no space left",
            ],
            FailureCategory::ResourceExhaustion,
        );
    }

    #[test]
    fn test_from_error_message_data_quality() {
        assert_classified(
            &[
                "corrupt file",
                "invalid data format",
                "missing feature: X",
                "data format error",
                "parse error",
                "invalid shape",
            ],
            FailureCategory::DataQuality,
        );
    }

    #[test]
    fn test_from_error_message_dependency() {
        assert_classified(
            &[
                "dependency not found",
                "crate version conflict",
                "version mismatch",
                "build error",
                "compile failed",
            ],
            FailureCategory::DependencyFailure,
        );
    }

    #[test]
    fn test_from_error_message_configuration() {
        assert_classified(
            &[
                "config error",
                "invalid parameter",
                "invalid value for field",
                "missing field: name",
                "required field missing",
            ],
            FailureCategory::ConfigurationError,
        );
    }

    #[test]
    fn test_from_error_message_unknown() {
        assert_classified(&["something weird happened", ""], FailureCategory::Unknown);
    }

    #[test]
    fn test_failure_context_new() {
        let ctx = FailureContext::new("E001", "NaN loss detected");
        assert_eq!(ctx.error_code, "E001");
        assert_eq!(ctx.message, "NaN loss detected");
        assert_eq!(ctx.category, FailureCategory::ModelConvergence);
        assert!(ctx.stack_trace.is_none());
        assert!(ctx.suggested_fix.is_none());
        assert!(ctx.related_runs.is_empty());
    }

    #[test]
    fn test_failure_context_with_category() {
        let ctx =
            FailureContext::with_category("E002", "Custom error", FailureCategory::DataQuality);
        assert_eq!(ctx.error_code, "E002");
        assert_eq!(ctx.category, FailureCategory::DataQuality);
    }

    #[test]
    fn test_failure_context_with_stack_trace() {
        let ctx = FailureContext::new("E001", "error").with_stack_trace("at line 42");
        assert_eq!(ctx.stack_trace, Some("at line 42".to_string()));
    }

    #[test]
    fn test_failure_context_with_suggested_fix() {
        let ctx = FailureContext::new("E001", "error").with_suggested_fix("Try rebooting");
        assert_eq!(ctx.suggested_fix, Some("Try rebooting".to_string()));
    }

    #[test]
    fn test_failure_context_with_related_runs() {
        let ctx = FailureContext::new("E001", "error")
            .with_related_runs(vec!["run1".to_string(), "run2".to_string()]);
        assert_eq!(ctx.related_runs.len(), 2);
    }

    #[test]
    fn test_generate_suggested_fix_all_categories() {
        use std::collections::HashSet;

        let categories = [
            FailureCategory::ModelConvergence,
            FailureCategory::ResourceExhaustion,
            FailureCategory::DataQuality,
            FailureCategory::DependencyFailure,
            FailureCategory::ConfigurationError,
            FailureCategory::Unknown,
        ];
        let mut seen = HashSet::new();
        for category in categories {
            let fix = FailureContext::with_category("E001", "error", category)
                .generate_suggested_fix();
            assert!(!fix.is_empty());
            // Every category should get its own advice, not a shared fallback.
            assert!(seen.insert(fix), "duplicate advice for {category:?}");
        }
    }

    #[test]
    fn test_failure_context_from_error() {
        use std::io;
        let err = io::Error::new(io::ErrorKind::OutOfMemory, "out of memory");
        let ctx = FailureContext::from(&err);
        assert_eq!(ctx.error_code, "ERR_GENERIC");
        assert!(ctx.message.contains("memory"));
        assert_eq!(ctx.category, FailureCategory::ResourceExhaustion);
    }

    #[test]
    fn test_failure_category_serialization() {
        let cat = FailureCategory::DataQuality;
        let json = serde_json::to_string(&cat).expect("JSON serialization should succeed");
        let deserialized: FailureCategory =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        assert_eq!(cat, deserialized);
    }

    #[test]
    fn test_failure_context_serialization() {
        let ctx = FailureContext::new("E001", "test error")
            .with_stack_trace("trace")
            .with_suggested_fix("fix it")
            .with_related_runs(vec!["run1".to_string()]);
        let json = serde_json::to_string(&ctx).expect("JSON serialization should succeed");
        let deserialized: FailureContext =
            serde_json::from_str(&json).expect("JSON deserialization should succeed");
        // Compare the whole struct so a field lost in the round trip is caught.
        assert_eq!(ctx, deserialized);
    }

    #[test]
    fn test_failure_category_clone_copy() {
        let cat = FailureCategory::ModelConvergence;
        let copied = cat;
        // Using `cat` after the move above only compiles because the type is `Copy`.
        assert_eq!(cat, copied);
    }

    #[test]
    fn test_failure_category_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(FailureCategory::DataQuality);
        set.insert(FailureCategory::ModelConvergence);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_failure_context_builder_chain() {
        let ctx = FailureContext::new("E001", "error")
            .with_stack_trace("trace")
            .with_suggested_fix("fix")
            .with_related_runs(vec!["run1".to_string()]);
        assert!(ctx.stack_trace.is_some());
        assert!(ctx.suggested_fix.is_some());
        assert_eq!(ctx.related_runs.len(), 1);
    }
}