1use aprender::primitives::Matrix;
4use serde::{Deserialize, Serialize};
5
6#[derive(Clone, Debug, Default, Serialize, Deserialize)]
8pub struct ErrorFeatures {
9 pub message_length: f32,
11 pub type_keywords: f32,
13 pub borrow_keywords: f32,
15 pub import_keywords: f32,
17 pub lifetime_keywords: f32,
19 pub trait_keywords: f32,
21 pub has_line_number: f32,
23 pub has_column: f32,
25 pub has_code_snippets: f32,
27 pub has_arrows: f32,
29 pub has_error_code: f32,
31 pub suggestion_count: f32,
33}
34
35impl ErrorFeatures {
36 pub const DIM: usize = 12;
38
39 #[must_use]
41 pub fn from_error_message(message: &str) -> Self {
42 let lower = message.to_lowercase();
43
44 Self {
45 message_length: (message.len() as f32 / 500.0).min(1.0),
46
47 type_keywords: count_keywords(
48 &lower,
49 &[
50 "expected",
51 "found",
52 "mismatched",
53 "type",
54 "cannot coerce",
55 "incompatible",
56 ],
57 ),
58
59 borrow_keywords: count_keywords(
60 &lower,
61 &[
62 "borrow",
63 "borrowed",
64 "move",
65 "moved",
66 "ownership",
67 "cannot move",
68 ],
69 ),
70
71 import_keywords: count_keywords(
72 &lower,
73 &[
74 "not found",
75 "unresolved",
76 "cannot find",
77 "undefined",
78 "undeclared",
79 ],
80 ),
81
82 lifetime_keywords: count_keywords(
83 &lower,
84 &[
85 "lifetime",
86 "'a",
87 "'static",
88 "live long enough",
89 "dangling",
90 "borrowed value",
91 ],
92 ),
93
94 trait_keywords: count_keywords(
95 &lower,
96 &[
97 "trait",
98 "impl",
99 "not implemented",
100 "bound",
101 "doesn't implement",
102 ],
103 ),
104
105 has_line_number: if message.contains(':') && message.chars().any(|c| c.is_ascii_digit())
106 {
107 1.0
108 } else {
109 0.0
110 },
111
112 has_column: if message.matches(':').count() > 1 {
113 1.0
114 } else {
115 0.0
116 },
117
118 has_code_snippets: (message.matches('`').count() as f32 / 10.0).min(1.0),
119
120 has_arrows: if message.contains("-->") || message.contains("^^^") {
121 1.0
122 } else {
123 0.0
124 },
125
126 has_error_code: if message.contains("E0") || message.contains("[E") {
127 1.0
128 } else {
129 0.0
130 },
131
132 suggestion_count: count_keywords(
133 &lower,
134 &["help:", "suggestion:", "consider", "try", "perhaps"],
135 ),
136 }
137 }
138
139 #[must_use]
141 pub fn to_matrix(&self) -> Matrix<f32> {
142 Matrix::from_vec(1, Self::DIM, self.to_vec()).expect("Feature dimensions are correct")
143 }
144
145 #[must_use]
147 pub fn to_vec(&self) -> Vec<f32> {
148 vec![
149 self.message_length,
150 self.type_keywords,
151 self.borrow_keywords,
152 self.import_keywords,
153 self.lifetime_keywords,
154 self.trait_keywords,
155 self.has_line_number,
156 self.has_column,
157 self.has_code_snippets,
158 self.has_arrows,
159 self.has_error_code,
160 self.suggestion_count,
161 ]
162 }
163
164 #[must_use]
170 pub fn from_vec(v: &[f32]) -> Self {
171 assert_eq!(
172 v.len(),
173 Self::DIM,
174 "Feature vector must have {} elements",
175 Self::DIM
176 );
177
178 Self {
179 message_length: v[0],
180 type_keywords: v[1],
181 borrow_keywords: v[2],
182 import_keywords: v[3],
183 lifetime_keywords: v[4],
184 trait_keywords: v[5],
185 has_line_number: v[6],
186 has_column: v[7],
187 has_code_snippets: v[8],
188 has_arrows: v[9],
189 has_error_code: v[10],
190 suggestion_count: v[11],
191 }
192 }
193}
194
/// Returns the fraction of `keywords` that occur in `text` as substrings.
///
/// Each keyword counts at most once, however often it appears, so the result
/// is always in `0.0..=1.0`. Matching is case-sensitive; callers that want
/// case-insensitive matching pass lower-cased text and lower-case keywords.
/// An empty keyword list yields `0.0`.
fn count_keywords(text: &str, keywords: &[&str]) -> f32 {
    // Guard the division: 0/0 is NaN, and `NaN.min(1.0)` evaluates to 1.0,
    // which would report a full match for an empty keyword list.
    if keywords.is_empty() {
        return 0.0;
    }

    let matched = keywords.iter().filter(|k| text.contains(**k)).count();
    // `matched` never exceeds `keywords.len()`, so the ratio needs no clamping.
    matched as f32 / keywords.len() as f32
}
200
/// Compiler error codes that receive a dedicated slot in the one-hot encoding.
///
/// The position of a code in this array is the index that is set in
/// `EnhancedErrorFeatures::error_code_onehot`, so the order is significant.
pub const ERROR_CODES: [&str; 25] = [
    "E0308", "E0425", "E0433", "E0277", "E0599",
    "E0382", "E0502", "E0503", "E0505", "E0506",
    "E0507", "E0106", "E0495", "E0621", "E0282",
    "E0283", "E0412", "E0432", "E0603", "E0609",
    "E0614", "E0615", "E0616", "E0618", "E0620",
];
233
/// Named keyword groups scanned by `EnhancedErrorFeatures::from_error_message`.
///
/// Each entry is `(category name, keywords)`. The entry at index `i` feeds the
/// four consecutive `keyword_counts` slots starting at `i * 4`, so the order
/// of the entries is significant. Keywords are matched as plain substrings of
/// the lower-cased message.
pub const KEYWORD_CATEGORIES: [(&str, &[&str]); 9] = [
    (
        "type_coercion",
        &["as", "into", "from", "convert", "cast"],
    ),
    (
        "ownership",
        &["owned", "clone", "copy", "drop", "take"],
    ),
    ("reference", &["ref", "&", "deref", "borrow"]),
    ("mutability", &["mut", "immutable", "mutable"]),
    (
        "generic",
        &["generic", "parameter", "constraint", "where"],
    ),
    ("async", &["async", "await", "future", "poll"]),
    ("closure", &["closure", "capture", "fn", "move"]),
    ("derive", &["derive", "debug", "clone", "default"]),
    (
        "result_option",
        &["result", "option", "some", "none", "ok", "err", "unwrap"],
    ),
];
249
250#[derive(Clone, Debug, Serialize, Deserialize)]
253pub struct EnhancedErrorFeatures {
254 pub base: ErrorFeatures,
256 pub error_code_onehot: Vec<f32>,
258 pub keyword_counts: Vec<f32>,
260}
261
262impl Default for EnhancedErrorFeatures {
263 fn default() -> Self {
264 Self {
265 base: ErrorFeatures::default(),
266 error_code_onehot: vec![0.0; 25],
267 keyword_counts: vec![0.0; 36],
268 }
269 }
270}
271
272impl EnhancedErrorFeatures {
273 pub const DIM: usize = 73;
275
276 #[must_use]
278 pub fn from_error_message(message: &str) -> Self {
279 let lower = message.to_lowercase();
280
281 let base = ErrorFeatures::from_error_message(message);
283
284 let mut error_code_onehot = vec![0.0f32; 25];
286 for (i, code) in ERROR_CODES.iter().enumerate() {
287 if message.contains(code) {
288 error_code_onehot[i] = 1.0;
289 break; }
291 }
292
293 let mut keyword_counts = vec![0.0f32; 36];
295 for (i, (_name, keywords)) in KEYWORD_CATEGORIES.iter().enumerate() {
296 let base_idx = i * 4;
297 let present = keywords.iter().any(|k| lower.contains(k));
299 keyword_counts[base_idx] = if present { 1.0 } else { 0.0 };
300
301 let count = keywords.iter().filter(|k| lower.contains(*k)).count();
303 keyword_counts[base_idx + 1] = (count as f32 / keywords.len() as f32).min(1.0);
304
305 let first_pos = keywords
307 .iter()
308 .filter_map(|k| lower.find(k))
309 .min()
310 .unwrap_or(lower.len());
311 keyword_counts[base_idx + 2] = 1.0 - (first_pos as f32 / lower.len().max(1) as f32);
312
313 let total_occurrences: usize = keywords.iter().map(|k| lower.matches(k).count()).sum();
315 keyword_counts[base_idx + 3] =
316 (total_occurrences as f32 * 100.0 / lower.len().max(1) as f32).min(1.0);
317 }
318
319 Self {
320 base,
321 error_code_onehot,
322 keyword_counts,
323 }
324 }
325
326 #[must_use]
328 pub fn to_vec(&self) -> Vec<f32> {
329 let mut vec = Vec::with_capacity(Self::DIM);
330 vec.extend(self.base.to_vec());
331 vec.extend(self.error_code_onehot.iter());
332 vec.extend(self.keyword_counts.iter());
333 vec
334 }
335
336 #[must_use]
338 pub fn to_matrix(&self) -> Matrix<f32> {
339 Matrix::from_vec(1, Self::DIM, self.to_vec()).expect("Feature dimensions are correct")
340 }
341}
342
343pub struct EnhancedFeatureExtractor;
345
346impl EnhancedFeatureExtractor {
347 #[must_use]
349 pub fn extract_batch(messages: &[&str]) -> Matrix<f32> {
350 let features: Vec<f32> = messages
351 .iter()
352 .flat_map(|msg| EnhancedErrorFeatures::from_error_message(msg).to_vec())
353 .collect();
354
355 Matrix::from_vec(messages.len(), EnhancedErrorFeatures::DIM, features)
356 .expect("Feature batch dimensions are correct")
357 }
358}
359
360pub struct FeatureExtractor;
362
363impl FeatureExtractor {
364 #[must_use]
366 pub fn extract_batch(messages: &[&str]) -> Matrix<f32> {
367 let features: Vec<f32> = messages
368 .iter()
369 .flat_map(|msg| ErrorFeatures::from_error_message(msg).to_vec())
370 .collect();
371
372 Matrix::from_vec(messages.len(), ErrorFeatures::DIM, features)
373 .expect("Feature batch dimensions are correct")
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Base feature extraction ----

    #[test]
    fn test_feature_extraction() {
        let message = "error[E0308]: mismatched types\n --> src/main.rs:10:5\n |\n10 | foo(bar)\n | ^^^ expected `i32`, found `&str`";

        let extracted = ErrorFeatures::from_error_message(message);

        assert!(extracted.message_length > 0.0);
        assert!(extracted.type_keywords > 0.0);
        assert!(extracted.has_error_code > 0.0);
        assert!(extracted.has_line_number > 0.0);
        assert!(extracted.has_arrows > 0.0);
    }

    #[test]
    fn test_borrow_features() {
        let extracted = ErrorFeatures::from_error_message("error: cannot move out of borrowed content");

        assert!(extracted.borrow_keywords > 0.0);
        // No type-related keyword occurs in this message.
        assert!((extracted.type_keywords - 0.0).abs() < 0.1);
    }

    #[test]
    fn test_to_matrix() {
        let row = ErrorFeatures::from_error_message("error: expected i32").to_matrix();

        assert_eq!(row.n_rows(), 1);
        assert_eq!(row.n_cols(), ErrorFeatures::DIM);
    }

    #[test]
    fn test_vec_roundtrip() {
        let original = ErrorFeatures::from_error_message("error: mismatched types");
        let flat = original.to_vec();
        let rebuilt = ErrorFeatures::from_vec(&flat);

        assert!((original.type_keywords - rebuilt.type_keywords).abs() < 1e-6);
    }

    #[test]
    fn test_batch_extraction() {
        let batch = vec![
            "error: expected i32",
            "error: cannot move",
            "error: not found",
        ];

        let stacked = FeatureExtractor::extract_batch(&batch);

        assert_eq!(stacked.n_rows(), 3);
        assert_eq!(stacked.n_cols(), ErrorFeatures::DIM);
    }

    #[test]
    fn test_lifetime_features() {
        let extracted = ErrorFeatures::from_error_message("error: `x` does not live long enough");

        assert!(extracted.lifetime_keywords > 0.0);
    }

    #[test]
    fn test_trait_features() {
        let extracted =
            ErrorFeatures::from_error_message("error: the trait bound `T: Clone` is not satisfied");

        assert!(extracted.trait_keywords > 0.0);
    }

    #[test]
    fn test_suggestion_count() {
        let message = "error: type mismatch\nhelp: try using `.into()`\nhelp: consider adding type annotation";
        let extracted = ErrorFeatures::from_error_message(message);

        assert!(extracted.suggestion_count > 0.0);
    }

    // ---- Enhanced feature extraction ----

    #[test]
    fn test_enhanced_feature_dimension() {
        let flat = EnhancedErrorFeatures::from_error_message("error[E0308]: mismatched types").to_vec();

        assert_eq!(flat.len(), EnhancedErrorFeatures::DIM);
        assert_eq!(flat.len(), 73);
    }

    #[test]
    fn test_enhanced_error_code_onehot() {
        let message = "error[E0308]: mismatched types\n --> src/main.rs:10:5";
        let extracted = EnhancedErrorFeatures::from_error_message(message);

        // E0308 is the first entry of the code table; every other slot stays zero.
        assert_eq!(extracted.error_code_onehot[0], 1.0);
        assert_eq!(extracted.error_code_onehot[1..].iter().sum::<f32>(), 0.0);
    }

    #[test]
    fn test_enhanced_e0425_onehot() {
        let extracted = EnhancedErrorFeatures::from_error_message(
            "error[E0425]: cannot find value `foo` in this scope",
        );

        // E0425 is the second entry of the code table.
        assert_eq!(extracted.error_code_onehot[1], 1.0);
    }

    #[test]
    fn test_enhanced_keyword_categories() {
        let extracted =
            EnhancedErrorFeatures::from_error_message("error: cannot convert `&str` into `String`");

        // Slots 0 and 1 belong to the first category, `type_coercion`.
        assert!(extracted.keyword_counts[0] > 0.0, "type_coercion presence");
        assert!(
            extracted.keyword_counts[1] > 0.0,
            "type_coercion count ratio"
        );
    }

    #[test]
    fn test_enhanced_result_option_keywords() {
        let extracted = EnhancedErrorFeatures::from_error_message(
            "error: cannot call `.unwrap()` on `Result<T, E>`",
        );

        // `result_option` is category 8, so its presence slot is 8 * 4 = 32.
        assert!(extracted.keyword_counts[32] > 0.0, "result_option presence");
    }

    #[test]
    fn test_enhanced_batch_extraction() {
        let batch = vec![
            "error[E0308]: expected i32, found &str",
            "error[E0382]: use of moved value",
            "error[E0277]: trait bound not satisfied",
        ];

        let stacked = EnhancedFeatureExtractor::extract_batch(&batch);

        assert_eq!(stacked.n_rows(), 3);
        assert_eq!(stacked.n_cols(), EnhancedErrorFeatures::DIM);
    }

    #[test]
    fn test_enhanced_to_matrix() {
        let row =
            EnhancedErrorFeatures::from_error_message("error[E0599]: no method named `foo` found")
                .to_matrix();

        assert_eq!(row.n_rows(), 1);
        assert_eq!(row.n_cols(), EnhancedErrorFeatures::DIM);
    }
}