Skip to main content

depyler_oracle/
features.rs

1//! Feature extraction from error messages.
2
3use aprender::primitives::Matrix;
4use serde::{Deserialize, Serialize};
5
6/// Features extracted from an error message for ML classification.
7#[derive(Clone, Debug, Default, Serialize, Deserialize)]
8pub struct ErrorFeatures {
9    /// Message length (normalized)
10    pub message_length: f32,
11    /// Number of type-related keywords
12    pub type_keywords: f32,
13    /// Number of borrow-related keywords
14    pub borrow_keywords: f32,
15    /// Number of import-related keywords
16    pub import_keywords: f32,
17    /// Number of lifetime-related keywords
18    pub lifetime_keywords: f32,
19    /// Number of trait-related keywords
20    pub trait_keywords: f32,
21    /// Contains line number
22    pub has_line_number: f32,
23    /// Contains column number
24    pub has_column: f32,
25    /// Contains backticks (code snippets)
26    pub has_code_snippets: f32,
27    /// Contains arrow indicators
28    pub has_arrows: f32,
29    /// Error code present (e.g., E0308)
30    pub has_error_code: f32,
31    /// Number of suggestions in message
32    pub suggestion_count: f32,
33}
34
35impl ErrorFeatures {
36    /// Feature dimension.
37    pub const DIM: usize = 12;
38
39    /// Extract features from an error message.
40    #[must_use]
41    pub fn from_error_message(message: &str) -> Self {
42        let lower = message.to_lowercase();
43
44        Self {
45            message_length: (message.len() as f32 / 500.0).min(1.0),
46
47            type_keywords: count_keywords(
48                &lower,
49                &[
50                    "expected",
51                    "found",
52                    "mismatched",
53                    "type",
54                    "cannot coerce",
55                    "incompatible",
56                ],
57            ),
58
59            borrow_keywords: count_keywords(
60                &lower,
61                &[
62                    "borrow",
63                    "borrowed",
64                    "move",
65                    "moved",
66                    "ownership",
67                    "cannot move",
68                ],
69            ),
70
71            import_keywords: count_keywords(
72                &lower,
73                &[
74                    "not found",
75                    "unresolved",
76                    "cannot find",
77                    "undefined",
78                    "undeclared",
79                ],
80            ),
81
82            lifetime_keywords: count_keywords(
83                &lower,
84                &[
85                    "lifetime",
86                    "'a",
87                    "'static",
88                    "live long enough",
89                    "dangling",
90                    "borrowed value",
91                ],
92            ),
93
94            trait_keywords: count_keywords(
95                &lower,
96                &[
97                    "trait",
98                    "impl",
99                    "not implemented",
100                    "bound",
101                    "doesn't implement",
102                ],
103            ),
104
105            has_line_number: if message.contains(':') && message.chars().any(|c| c.is_ascii_digit())
106            {
107                1.0
108            } else {
109                0.0
110            },
111
112            has_column: if message.matches(':').count() > 1 {
113                1.0
114            } else {
115                0.0
116            },
117
118            has_code_snippets: (message.matches('`').count() as f32 / 10.0).min(1.0),
119
120            has_arrows: if message.contains("-->") || message.contains("^^^") {
121                1.0
122            } else {
123                0.0
124            },
125
126            has_error_code: if message.contains("E0") || message.contains("[E") {
127                1.0
128            } else {
129                0.0
130            },
131
132            suggestion_count: count_keywords(
133                &lower,
134                &["help:", "suggestion:", "consider", "try", "perhaps"],
135            ),
136        }
137    }
138
139    /// Convert features to a row matrix for ML model.
140    #[must_use]
141    pub fn to_matrix(&self) -> Matrix<f32> {
142        Matrix::from_vec(1, Self::DIM, self.to_vec()).expect("Feature dimensions are correct")
143    }
144
145    /// Convert features to a vector.
146    #[must_use]
147    pub fn to_vec(&self) -> Vec<f32> {
148        vec![
149            self.message_length,
150            self.type_keywords,
151            self.borrow_keywords,
152            self.import_keywords,
153            self.lifetime_keywords,
154            self.trait_keywords,
155            self.has_line_number,
156            self.has_column,
157            self.has_code_snippets,
158            self.has_arrows,
159            self.has_error_code,
160            self.suggestion_count,
161        ]
162    }
163
164    /// Create features from a vector.
165    ///
166    /// # Panics
167    ///
168    /// Panics if vector length doesn't match DIM.
169    #[must_use]
170    pub fn from_vec(v: &[f32]) -> Self {
171        assert_eq!(
172            v.len(),
173            Self::DIM,
174            "Feature vector must have {} elements",
175            Self::DIM
176        );
177
178        Self {
179            message_length: v[0],
180            type_keywords: v[1],
181            borrow_keywords: v[2],
182            import_keywords: v[3],
183            lifetime_keywords: v[4],
184            trait_keywords: v[5],
185            has_line_number: v[6],
186            has_column: v[7],
187            has_code_snippets: v[8],
188            has_arrows: v[9],
189            has_error_code: v[10],
190            suggestion_count: v[11],
191        }
192    }
193}
194
/// Fraction of `keywords` that occur in `text`, in `[0.0, 1.0]`.
///
/// Each keyword is matched as a raw, case-sensitive substring and counts at
/// most once however often it appears. Callers in this file pass an
/// already lower-cased message. An empty keyword list yields `0.0`.
fn count_keywords(text: &str, keywords: &[&str]) -> f32 {
    // Guard: without it the ratio is 0.0 / 0.0 = NaN, and `f32::min(NaN, 1.0)`
    // returns 1.0 — i.e. "every keyword matched" for an empty list.
    if keywords.is_empty() {
        return 0.0;
    }
    let hits = keywords.iter().filter(|k| text.contains(*k)).count();
    (hits as f32 / keywords.len() as f32).min(1.0)
}
200
201// GH-210: Enhanced feature extraction for Code2Vec & GNN upgrade
202// Expands from 12 to 73 dimensions for better ML classification
203
/// Top 25 Rust error codes for one-hot encoding.
///
/// Based on frequency analysis from the reprorusted-python-cli corpus.
/// A code's position in this array is its slot in
/// `EnhancedErrorFeatures::error_code_onehot`, so reordering, inserting or
/// removing entries changes the meaning of existing one-hot slots. When a
/// message contains several listed codes, the extractor picks the one that
/// comes first in this array.
pub const ERROR_CODES: [&str; 25] = [
    "E0308", // mismatched types (most common)
    "E0425", // cannot find value
    "E0433", // failed to resolve
    "E0277", // trait bound not satisfied
    "E0599", // no method named
    "E0382", // use of moved value
    "E0502", // cannot borrow as mutable
    "E0503", // cannot use while mutably borrowed
    "E0505", // cannot move out of borrowed
    "E0506", // cannot assign to borrowed
    "E0507", // cannot move out of
    "E0106", // missing lifetime specifier
    "E0495", // cannot infer lifetime
    "E0621", // explicit lifetime required
    "E0282", // type annotations needed
    "E0283", // type annotations required
    "E0412", // cannot find type
    "E0432", // unresolved import
    "E0603", // private item
    "E0609", // no field
    "E0614", // cannot be dereferenced
    "E0615", // attempted to take value
    "E0616", // field is private
    "E0618", // expected function
    "E0620", // cast to unsized type
];
233
/// Extended keyword categories for detailed feature extraction.
///
/// Each entry is `(category_name, keywords)`. The enhanced extractor emits
/// four values per category (presence, match ratio, first-occurrence
/// position, density), in this table's order. Adding or removing a category
/// therefore changes the size of the keyword feature block, and
/// `EnhancedErrorFeatures::DIM` and its default vectors must stay consistent
/// with it.
///
/// NOTE(review): the extractor matches these entries with `str::contains` on
/// the lower-cased message, i.e. as raw substrings, not whole words. Short
/// entries also fire inside unrelated words. For example, "err" matches
/// "error", "as" matches "has"/"was", "ref" matches "reference", "some"
/// matches "something", and "&" matches any reference type. Some categories
/// (notably `result_option`) are therefore likely to be active for almost
/// every rustc message. Consider whole-word matching if that is unintended.
pub const KEYWORD_CATEGORIES: [(&str, &[&str]); 9] = [
    ("type_coercion", &["as", "into", "from", "convert", "cast"]),
    ("ownership", &["owned", "clone", "copy", "drop", "take"]),
    ("reference", &["ref", "&", "deref", "borrow"]),
    ("mutability", &["mut", "immutable", "mutable"]),
    ("generic", &["generic", "parameter", "constraint", "where"]),
    ("async", &["async", "await", "future", "poll"]),
    ("closure", &["closure", "capture", "fn", "move"]),
    ("derive", &["derive", "debug", "clone", "default"]),
    (
        "result_option",
        &["result", "option", "some", "none", "ok", "err", "unwrap"],
    ),
];
249
/// GH-210: Enhanced error features with 73 dimensions.
///
/// Combines base features (12) + error code one-hot (25) + keyword features
/// (36). `to_vec` concatenates the three parts in exactly that order.
///
/// NOTE(review): the vector lengths are not enforced by the type. A value
/// built by hand or deserialized with other lengths still converts with
/// `to_vec`. `to_matrix` then calls `Matrix::from_vec(1, DIM, ..).expect(..)`,
/// which presumably panics on a length mismatch — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct EnhancedErrorFeatures {
    /// Base 12 features from ErrorFeatures.
    pub base: ErrorFeatures,
    /// One-hot encoding over the 25 entries of `ERROR_CODES`
    /// (slot index = position in that table); all zeros if no code matches.
    pub error_code_onehot: Vec<f32>,
    /// Per-category keyword features: 9 categories × 4 values, stored as
    /// `[presence, match ratio, first-occurrence position, density]` for each
    /// category in `KEYWORD_CATEGORIES` order.
    pub keyword_counts: Vec<f32>,
}
261
262impl Default for EnhancedErrorFeatures {
263    fn default() -> Self {
264        Self {
265            base: ErrorFeatures::default(),
266            error_code_onehot: vec![0.0; 25],
267            keyword_counts: vec![0.0; 36],
268        }
269    }
270}
271
272impl EnhancedErrorFeatures {
273    /// Enhanced feature dimension: 12 + 25 + 36 = 73
274    pub const DIM: usize = 73;
275
276    /// Extract enhanced features from error message
277    #[must_use]
278    pub fn from_error_message(message: &str) -> Self {
279        let lower = message.to_lowercase();
280
281        // Base features
282        let base = ErrorFeatures::from_error_message(message);
283
284        // Error code one-hot encoding
285        let mut error_code_onehot = vec![0.0f32; 25];
286        for (i, code) in ERROR_CODES.iter().enumerate() {
287            if message.contains(code) {
288                error_code_onehot[i] = 1.0;
289                break; // Only one error code per message
290            }
291        }
292
293        // Extended keyword counts (9 categories × 4 features each)
294        let mut keyword_counts = vec![0.0f32; 36];
295        for (i, (_name, keywords)) in KEYWORD_CATEGORIES.iter().enumerate() {
296            let base_idx = i * 4;
297            // Feature 1: presence (0 or 1)
298            let present = keywords.iter().any(|k| lower.contains(k));
299            keyword_counts[base_idx] = if present { 1.0 } else { 0.0 };
300
301            // Feature 2: count ratio (normalized)
302            let count = keywords.iter().filter(|k| lower.contains(*k)).count();
303            keyword_counts[base_idx + 1] = (count as f32 / keywords.len() as f32).min(1.0);
304
305            // Feature 3: first occurrence position (normalized by message length)
306            let first_pos = keywords
307                .iter()
308                .filter_map(|k| lower.find(k))
309                .min()
310                .unwrap_or(lower.len());
311            keyword_counts[base_idx + 2] = 1.0 - (first_pos as f32 / lower.len().max(1) as f32);
312
313            // Feature 4: keyword density (occurrences per 100 chars)
314            let total_occurrences: usize = keywords.iter().map(|k| lower.matches(k).count()).sum();
315            keyword_counts[base_idx + 3] =
316                (total_occurrences as f32 * 100.0 / lower.len().max(1) as f32).min(1.0);
317        }
318
319        Self {
320            base,
321            error_code_onehot,
322            keyword_counts,
323        }
324    }
325
326    /// Convert to feature vector for ML model
327    #[must_use]
328    pub fn to_vec(&self) -> Vec<f32> {
329        let mut vec = Vec::with_capacity(Self::DIM);
330        vec.extend(self.base.to_vec());
331        vec.extend(self.error_code_onehot.iter());
332        vec.extend(self.keyword_counts.iter());
333        vec
334    }
335
336    /// Convert to matrix for ML model
337    #[must_use]
338    pub fn to_matrix(&self) -> Matrix<f32> {
339        Matrix::from_vec(1, Self::DIM, self.to_vec()).expect("Feature dimensions are correct")
340    }
341}
342
343/// Batch extraction for enhanced features
344pub struct EnhancedFeatureExtractor;
345
346impl EnhancedFeatureExtractor {
347    /// Extract enhanced features from multiple error messages
348    #[must_use]
349    pub fn extract_batch(messages: &[&str]) -> Matrix<f32> {
350        let features: Vec<f32> = messages
351            .iter()
352            .flat_map(|msg| EnhancedErrorFeatures::from_error_message(msg).to_vec())
353            .collect();
354
355        Matrix::from_vec(messages.len(), EnhancedErrorFeatures::DIM, features)
356            .expect("Feature batch dimensions are correct")
357    }
358}
359
360/// Batch feature extraction for training data.
361pub struct FeatureExtractor;
362
363impl FeatureExtractor {
364    /// Extract features from multiple error messages.
365    #[must_use]
366    pub fn extract_batch(messages: &[&str]) -> Matrix<f32> {
367        let features: Vec<f32> = messages
368            .iter()
369            .flat_map(|msg| ErrorFeatures::from_error_message(msg).to_vec())
370            .collect();
371
372        Matrix::from_vec(messages.len(), ErrorFeatures::DIM, features)
373            .expect("Feature batch dimensions are correct")
374    }
375}
376
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every value in `values` lies in the closed interval [0, 1].
    fn assert_unit_range(values: &[f32]) {
        for (i, v) in values.iter().enumerate() {
            assert!(
                (0.0..=1.0).contains(v),
                "feature {} = {} is outside [0, 1]",
                i,
                v
            );
        }
    }

    #[test]
    fn test_feature_extraction() {
        let msg = "error[E0308]: mismatched types\n  --> src/main.rs:10:5\n   |\n10 |     foo(bar)\n   |         ^^^ expected `i32`, found `&str`";

        let features = ErrorFeatures::from_error_message(msg);

        assert!(features.message_length > 0.0);
        assert!(features.type_keywords > 0.0);
        assert!(features.has_error_code > 0.0);
        assert!(features.has_line_number > 0.0);
        assert!(features.has_arrows > 0.0);
    }

    #[test]
    fn test_borrow_features() {
        let msg = "error: cannot move out of borrowed content";
        let features = ErrorFeatures::from_error_message(msg);

        assert!(features.borrow_keywords > 0.0);
        // No type-related keyword appears, so the fraction is exactly zero.
        assert_eq!(features.type_keywords, 0.0);
    }

    #[test]
    fn test_to_matrix() {
        let msg = "error: expected i32";
        let features = ErrorFeatures::from_error_message(msg);
        let matrix = features.to_matrix();

        assert_eq!(matrix.n_rows(), 1);
        assert_eq!(matrix.n_cols(), ErrorFeatures::DIM);
    }

    #[test]
    fn test_vec_roundtrip() {
        let msg = "error[E0308]: mismatched types\n  --> src/main.rs:10:5\nhelp: try `x.into()`";
        let features = ErrorFeatures::from_error_message(msg);
        let vec = features.to_vec();
        let restored = ErrorFeatures::from_vec(&vec);

        // Every field must survive the round trip, not just one of them.
        assert_eq!(restored.to_vec(), vec);
    }

    #[test]
    #[should_panic(expected = "Feature vector must have")]
    fn test_from_vec_rejects_wrong_length() {
        let _ = ErrorFeatures::from_vec(&[0.0; ErrorFeatures::DIM - 1]);
    }

    #[test]
    fn test_batch_extraction() {
        let messages = vec![
            "error: expected i32",
            "error: cannot move",
            "error: not found",
        ];

        let matrix = FeatureExtractor::extract_batch(&messages);

        assert_eq!(matrix.n_rows(), 3);
        assert_eq!(matrix.n_cols(), ErrorFeatures::DIM);
    }

    #[test]
    fn test_lifetime_features() {
        let msg = "error: `x` does not live long enough";
        let features = ErrorFeatures::from_error_message(msg);

        assert!(features.lifetime_keywords > 0.0);
    }

    #[test]
    fn test_trait_features() {
        let msg = "error: the trait bound `T: Clone` is not satisfied";
        let features = ErrorFeatures::from_error_message(msg);

        assert!(features.trait_keywords > 0.0);
    }

    #[test]
    fn test_suggestion_count() {
        let msg = "error: type mismatch\nhelp: try using `.into()`\nhelp: consider adding type annotation";
        let features = ErrorFeatures::from_error_message(msg);

        assert!(features.suggestion_count > 0.0);
    }

    #[test]
    fn test_features_are_normalized() {
        let msg = "error[E0308]: mismatched types\n  --> src/main.rs:10:5\n   |\n10 |     foo(bar)\n   |         ^^^ expected `i32`, found `&str`\nhelp: consider using `.into()`";

        assert_unit_range(&ErrorFeatures::from_error_message(msg).to_vec());
        assert_unit_range(&EnhancedErrorFeatures::from_error_message(msg).to_vec());
    }

    // GH-210: Tests for enhanced features

    #[test]
    fn test_enhanced_feature_dimension() {
        let msg = "error[E0308]: mismatched types";
        let features = EnhancedErrorFeatures::from_error_message(msg);
        let vec = features.to_vec();

        assert_eq!(vec.len(), EnhancedErrorFeatures::DIM);
        assert_eq!(vec.len(), 73);
    }

    #[test]
    fn test_enhanced_default_matches_dim() {
        let features = EnhancedErrorFeatures::default();
        let vec = features.to_vec();

        assert_eq!(vec.len(), EnhancedErrorFeatures::DIM);
        assert!(vec.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_enhanced_error_code_onehot() {
        let msg = "error[E0308]: mismatched types\n  --> src/main.rs:10:5";
        let features = EnhancedErrorFeatures::from_error_message(msg);

        // E0308 is at index 0
        assert_eq!(features.error_code_onehot[0], 1.0);
        // All others should be 0
        assert_eq!(features.error_code_onehot[1..].iter().sum::<f32>(), 0.0);
    }

    #[test]
    fn test_enhanced_e0425_onehot() {
        let msg = "error[E0425]: cannot find value `foo` in this scope";
        let features = EnhancedErrorFeatures::from_error_message(msg);

        // E0425 is at index 1
        assert_eq!(features.error_code_onehot[1], 1.0);
    }

    #[test]
    fn test_enhanced_onehot_single_hot_for_multiple_codes() {
        let msg = "error[E0382]: use of moved value\nerror[E0308]: mismatched types";
        let features = EnhancedErrorFeatures::from_error_message(msg);

        // Only one slot may be hot even when several known codes appear.
        assert_eq!(features.error_code_onehot.iter().sum::<f32>(), 1.0);
    }

    #[test]
    fn test_enhanced_unknown_code_is_all_zero() {
        let msg = "error[E9999]: something odd";
        let features = EnhancedErrorFeatures::from_error_message(msg);

        assert!(features.error_code_onehot.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_enhanced_keyword_categories() {
        let msg = "error: cannot convert `&str` into `String`";
        let features = EnhancedErrorFeatures::from_error_message(msg);

        // type_coercion category (index 0-3) should have hits
        // "into" and "convert" are both present
        assert!(features.keyword_counts[0] > 0.0, "type_coercion presence");
        assert!(
            features.keyword_counts[1] > 0.0,
            "type_coercion count ratio"
        );
    }

    #[test]
    fn test_enhanced_result_option_keywords() {
        let msg = "error: cannot call `.unwrap()` on `Result<T, E>`";
        let features = EnhancedErrorFeatures::from_error_message(msg);

        // result_option category is index 8 (8 * 4 = 32-35)
        assert!(features.keyword_counts[32] > 0.0, "result_option presence");
    }

    #[test]
    fn test_enhanced_batch_extraction() {
        let messages = vec![
            "error[E0308]: expected i32, found &str",
            "error[E0382]: use of moved value",
            "error[E0277]: trait bound not satisfied",
        ];

        let matrix = EnhancedFeatureExtractor::extract_batch(&messages);

        assert_eq!(matrix.n_rows(), 3);
        assert_eq!(matrix.n_cols(), EnhancedErrorFeatures::DIM);
    }

    #[test]
    fn test_enhanced_to_matrix() {
        let msg = "error[E0599]: no method named `foo` found";
        let features = EnhancedErrorFeatures::from_error_message(msg);
        let matrix = features.to_matrix();

        assert_eq!(matrix.n_rows(), 1);
        assert_eq!(matrix.n_cols(), EnhancedErrorFeatures::DIM);
    }
}