1use serde::{Deserialize, Serialize};
28
29#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
37pub struct TaskClass {
38 pub id: String,
40 pub name: String,
42 pub signal_keywords: Vec<String>,
47}
48
49impl TaskClass {
50 pub fn new(
52 id: impl Into<String>,
53 name: impl Into<String>,
54 signal_keywords: impl IntoIterator<Item = impl Into<String>>,
55 ) -> Self {
56 Self {
57 id: id.into(),
58 name: name.into(),
59 signal_keywords: signal_keywords
60 .into_iter()
61 .map(|k| k.into().to_lowercase())
62 .collect(),
63 }
64 }
65
66 pub(crate) fn overlap_score(&self, signal: &str) -> usize {
71 let tokens = tokenise(signal);
72 self.signal_keywords
73 .iter()
74 .filter(|kw| tokens.contains(*kw))
75 .count()
76 }
77}
78
79pub fn builtin_task_classes() -> Vec<TaskClass> {
86 vec![
87 TaskClass::new(
88 "missing-import",
89 "Missing import / undefined symbol",
90 [
91 "e0425",
92 "e0433",
93 "unresolved",
94 "undefined",
95 "import",
96 "missing",
97 "cannot",
98 "find",
99 "symbol",
100 ],
101 ),
102 TaskClass::new(
103 "type-mismatch",
104 "Type mismatch",
105 [
106 "e0308",
107 "mismatched",
108 "expected",
109 "found",
110 "type",
111 "mismatch",
112 ],
113 ),
114 TaskClass::new(
115 "borrow-conflict",
116 "Borrow checker conflict",
117 [
118 "e0502", "e0505", "borrow", "lifetime", "moved", "cannot", "conflict",
119 ],
120 ),
121 TaskClass::new(
122 "test-failure",
123 "Test failure",
124 ["test", "failed", "panic", "assert", "assertion", "failure"],
125 ),
126 TaskClass::new(
127 "performance",
128 "Performance issue",
129 ["slow", "latency", "timeout", "perf", "performance", "hot"],
130 ),
131 ]
132}
133
134pub struct TaskClassMatcher {
138 classes: Vec<TaskClass>,
139}
140
141impl TaskClassMatcher {
142 pub fn new(classes: Vec<TaskClass>) -> Self {
144 Self { classes }
145 }
146
147 pub fn with_builtins() -> Self {
149 Self::new(builtin_task_classes())
150 }
151
152 pub fn classify<'a>(&'a self, signals: &[String]) -> Option<&'a TaskClass> {
156 let mut best: Option<(&TaskClass, usize)> = None;
157
158 for class in &self.classes {
159 let total_score: usize = signals.iter().map(|s| class.overlap_score(s)).sum();
160 if total_score > 0 {
161 match best {
162 None => best = Some((class, total_score)),
163 Some((_, prev_score)) if total_score > prev_score => {
164 best = Some((class, total_score));
165 }
166 _ => {}
167 }
168 }
169 }
170
171 best.map(|(c, _)| c)
172 }
173
174 pub fn classes(&self) -> &[TaskClass] {
176 &self.classes
177 }
178}
179
/// Splits `s` into lowercase tokens.
///
/// Every non-alphanumeric character (punctuation, whitespace, `_`, …) acts as a
/// separator; empty fragments between consecutive separators are dropped.
fn tokenise(s: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for fragment in s.split(|ch: char| !ch.is_alphanumeric()) {
        if fragment.is_empty() {
            continue;
        }
        tokens.push(fragment.to_lowercase());
    }
    tokens
}
189
190pub fn signals_match_class(signals: &[String], class_id: &str, registry: &[TaskClass]) -> bool {
194 let matcher = TaskClassMatcher::new(registry.to_vec());
195 matcher
196 .classify(signals)
197 .map_or(false, |c| c.id == class_id)
198}
199
200#[derive(Clone, Debug, Serialize, Deserialize)]
205pub struct TaskClassDefinition {
206 pub id: String,
208 pub name: String,
210 pub description: String,
212 pub signal_keywords: Vec<String>,
214}
215
216impl TaskClassDefinition {
217 pub fn into_task_class(self) -> TaskClass {
219 TaskClass::new(self.id, self.name, self.signal_keywords)
220 }
221}
222
223pub fn builtin_task_class_definitions() -> Vec<TaskClassDefinition> {
227 vec![
228 TaskClassDefinition {
229 id: "missing-import".to_string(),
230 name: "Missing import / undefined symbol".to_string(),
231 description: "Compiler cannot find symbol unresolved import undefined reference \
232 missing use declaration cannot find value in scope"
233 .to_string(),
234 signal_keywords: vec![
235 "e0425",
236 "e0433",
237 "unresolved",
238 "undefined",
239 "import",
240 "missing",
241 "cannot",
242 "find",
243 "symbol",
244 ]
245 .into_iter()
246 .map(String::from)
247 .collect(),
248 },
249 TaskClassDefinition {
250 id: "type-mismatch".to_string(),
251 name: "Type mismatch".to_string(),
252 description: "Type mismatch mismatched types expected one type found another \
253 type annotation required"
254 .to_string(),
255 signal_keywords: vec![
256 "e0308",
257 "mismatched",
258 "expected",
259 "found",
260 "type",
261 "mismatch",
262 ]
263 .into_iter()
264 .map(String::from)
265 .collect(),
266 },
267 TaskClassDefinition {
268 id: "borrow-conflict".to_string(),
269 name: "Borrow checker conflict".to_string(),
270 description: "Borrow checker conflict cannot borrow as mutable lifetime error \
271 value moved cannot use after move"
272 .to_string(),
273 signal_keywords: vec![
274 "e0502", "e0505", "borrow", "lifetime", "moved", "cannot", "conflict",
275 ]
276 .into_iter()
277 .map(String::from)
278 .collect(),
279 },
280 TaskClassDefinition {
281 id: "test-failure".to_string(),
282 name: "Test failure".to_string(),
283 description: "Test failure panicked assertion failed test did not pass".to_string(),
284 signal_keywords: vec!["test", "failed", "panic", "assert", "assertion", "failure"]
285 .into_iter()
286 .map(String::from)
287 .collect(),
288 },
289 TaskClassDefinition {
290 id: "performance".to_string(),
291 name: "Performance issue".to_string(),
292 description: "Performance issue slow response high latency operation timeout \
293 hot path resource contention"
294 .to_string(),
295 signal_keywords: vec!["slow", "latency", "timeout", "perf", "performance", "hot"]
296 .into_iter()
297 .map(String::from)
298 .collect(),
299 },
300 ]
301}
302
/// Top-level shape of a task-classes TOML document.
#[cfg(feature = "evolution-experimental")]
#[derive(Deserialize)]
struct TaskClassesToml {
    /// The class definitions listed under the `task_classes` key.
    task_classes: Vec<TaskClassDefinition>,
}
310
/// Reads task class definitions from the TOML file at `path`.
///
/// # Errors
///
/// Returns the stringified I/O error when the file cannot be read, or the
/// stringified parse error when its content does not deserialise into the
/// expected `task_classes` list.
#[cfg(feature = "evolution-experimental")]
pub fn load_task_classes_from_toml(
    path: &std::path::Path,
) -> Result<Vec<TaskClassDefinition>, String> {
    let raw = match std::fs::read_to_string(path) {
        Ok(text) => text,
        Err(io_err) => return Err(io_err.to_string()),
    };

    match toml::from_str::<TaskClassesToml>(&raw) {
        Ok(document) => Ok(document.task_classes),
        Err(parse_err) => Err(parse_err.to_string()),
    }
}
325
326pub fn load_task_classes() -> Vec<TaskClassDefinition> {
332 #[cfg(feature = "evolution-experimental")]
333 {
334 if let Some(home) = std::env::var_os("HOME") {
335 let path = std::path::Path::new(&home)
336 .join(".oris")
337 .join("oris-task-classes.toml");
338 if path.exists() {
339 if let Ok(classes) = load_task_classes_from_toml(&path) {
340 return classes;
341 }
342 }
343 }
344 }
345 builtin_task_class_definitions()
346}
347
348pub struct TaskClassInferencer {
364 classes: Vec<TaskClassDefinition>,
365 threshold: f32,
366}
367
368impl TaskClassInferencer {
369 pub fn new(classes: Vec<TaskClassDefinition>) -> Self {
371 Self {
372 classes,
373 threshold: 0.75,
374 }
375 }
376
377 pub fn with_builtins() -> Self {
379 Self::new(builtin_task_class_definitions())
380 }
381
382 pub fn with_threshold(mut self, threshold: f32) -> Self {
384 self.threshold = threshold;
385 self
386 }
387
388 pub fn infer(&self, signal_description: &str) -> String {
393 let signal_tokens = tokenise(signal_description);
394 if signal_tokens.is_empty() {
395 return "generic_fix".to_string();
396 }
397
398 let mut best_id = "generic_fix";
399 let mut best_score = 0.0f32;
400
401 for class in &self.classes {
402 let score = recall_score(&signal_tokens, &class.signal_keywords);
403 if score > best_score {
404 best_score = score;
405 best_id = &class.id;
406 }
407 }
408
409 if best_score >= self.threshold {
410 best_id.to_string()
411 } else {
412 "generic_fix".to_string()
413 }
414 }
415
416 pub fn class_definitions(&self) -> &[TaskClassDefinition] {
418 &self.classes
419 }
420}
421
/// Computes the fraction of `keywords` that occur among `signal_tokens`.
///
/// The result lies in `0.0..=1.0`; an empty keyword list scores `0.0`.
///
/// `signal_tokens` are expected to be lowercase already (as produced by
/// `tokenise`). Keywords are lowercased here before comparison: definition
/// keywords are not normalised when they are constructed or deserialised, so a
/// keyword written as `"E0308"` would otherwise never match any token.
fn recall_score(signal_tokens: &[String], keywords: &[String]) -> f32 {
    if keywords.is_empty() {
        return 0.0;
    }
    let matched = keywords
        .iter()
        .filter(|keyword| {
            let keyword = keyword.to_lowercase();
            signal_tokens.iter().any(|token| *token == keyword)
        })
        .count();
    matched as f32 / keywords.len() as f32
}
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Matcher over the built-in classes, shared by the classification tests.
    fn matcher() -> TaskClassMatcher {
        TaskClassMatcher::with_builtins()
    }

    // ── TaskClassMatcher: positive classification ──────────────────────────

    #[test]
    fn test_missing_import_via_error_code() {
        let m = matcher();
        let signals = vec!["error[E0425]: cannot find value `foo` in scope".to_string()];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "missing-import");
    }

    #[test]
    fn test_missing_import_via_natural_language() {
        let m = matcher();
        // `_` is a token separator, so "use_missing_fn" also contributes "missing".
        let signals = vec!["undefined symbol: use_missing_fn".to_string()];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "missing-import");
    }

    #[test]
    fn test_missing_import_via_unresolved_import() {
        let m = matcher();
        let signals = vec!["unresolved import `std::collections::Missing`".to_string()];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "missing-import");
    }

    #[test]
    fn test_type_mismatch_classification() {
        let m = matcher();
        let signals =
            vec!["error[E0308]: mismatched types: expected `u32` found `String`".to_string()];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "type-mismatch");
    }

    #[test]
    fn test_borrow_conflict_classification() {
        let m = matcher();
        let signals = vec![
            "error[E0502]: cannot borrow `x` as mutable because it is also borrowed as immutable"
                .to_string(),
        ];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "borrow-conflict");
    }

    #[test]
    fn test_test_failure_classification() {
        let m = matcher();
        let signals = vec!["test panicked: assertion failed: x == y".to_string()];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "test-failure");
    }

    #[test]
    fn test_multiple_signals_accumulate_score() {
        let m = matcher();
        // Neither signal is conclusive alone; their scores are summed per class.
        let signals = vec![
            "expected type `u32`".to_string(),
            "found type `String` — type mismatch".to_string(),
        ];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "type-mismatch");
    }

    // ── TaskClassMatcher: cross-class false positives ──────────────────────

    #[test]
    fn test_no_false_positive_type_vs_borrow() {
        let m = matcher();
        let signals = vec!["error[E0308]: mismatched type".to_string()];
        let cls = m.classify(&signals).unwrap();
        assert_ne!(
            cls.id, "borrow-conflict",
            "must not cross-match borrow-conflict"
        );
    }

    #[test]
    fn test_no_false_positive_borrow_vs_import() {
        let m = matcher();
        // "cannot" is shared with missing-import; "e0502" and "borrow" must tip it.
        let signals = vec!["error[E0502]: cannot borrow as mutable".to_string()];
        let cls = m.classify(&signals).unwrap();
        assert_ne!(cls.id, "missing-import");
    }

    #[test]
    fn test_no_match_returns_none() {
        let m = matcher();
        let signals = vec!["network timeout connecting to database server".to_string()];
        // "timeout" is a keyword of the built-in performance class, so a match is
        // possible here; the test only requires that no compiler-error class wins.
        if let Some(cls) = m.classify(&signals) {
            assert_ne!(cls.id, "missing-import");
            assert_ne!(cls.id, "type-mismatch");
            assert_ne!(cls.id, "borrow-conflict");
        }
    }

    #[test]
    fn test_empty_signals_returns_none() {
        let m = matcher();
        assert!(m.classify(&[]).is_none());
    }

    // ── TaskClassMatcher: custom registries and helpers ────────────────────

    #[test]
    fn test_custom_class_wins_over_builtin() {
        let mut classes = builtin_task_classes();
        // The custom class matches five keywords; the built-in performance class
        // only matches "timeout".
        classes.push(TaskClass::new(
            "db-timeout",
            "Database timeout",
            ["database", "timeout", "connection", "pool", "exhausted"],
        ));
        let m = TaskClassMatcher::new(classes);
        let signals = vec!["database connection pool exhausted — timeout".to_string()];
        let cls = m.classify(&signals).expect("should classify");
        assert_eq!(cls.id, "db-timeout");
    }

    #[test]
    fn test_signals_match_class_helper() {
        let registry = builtin_task_classes();
        let signals = vec!["error[E0425]: cannot find value".to_string()];
        assert!(signals_match_class(&signals, "missing-import", &registry));
        assert!(!signals_match_class(&signals, "type-mismatch", &registry));
    }

    #[test]
    fn test_overlap_score_case_insensitive() {
        let class = TaskClass::new("tc", "Test", ["e0425", "unresolved"]);
        let m = TaskClassMatcher::new(vec![class]);
        // Upper-case "E0425" in the signal must match the lowercase keyword.
        let signals = vec!["E0425 unresolved import".to_string()];
        let cls = m
            .classify(&signals)
            .expect("case-insensitive classify should work");
        assert_eq!(cls.id, "tc");
    }

    // ── TaskClassInferencer ────────────────────────────────────────────────

    #[test]
    fn inferencer_canonical_compiler_error_missing_import() {
        let inferencer = TaskClassInferencer::with_builtins();
        // Covers 8 of the 9 missing-import keywords, well above the 0.75 default.
        let signal = "error[E0425]: cannot find value `foo`: \
                      unresolved import symbol is undefined missing";
        let class_id = inferencer.infer(signal);
        assert_eq!(
            class_id, "missing-import",
            "canonical missing-import signal should infer correct class"
        );
    }

    #[test]
    fn inferencer_canonical_compiler_error_type_mismatch() {
        let inferencer = TaskClassInferencer::with_builtins();
        // Covers all six type-mismatch keywords.
        let signal = "error[E0308]: mismatched type expected u32 found String type mismatch";
        let class_id = inferencer.infer(signal);
        assert_eq!(class_id, "type-mismatch");
    }

    #[test]
    fn inferencer_score_below_threshold_falls_back_to_generic_fix() {
        let inferencer = TaskClassInferencer::with_builtins();
        // One of six type-mismatch keywords: far below the 0.75 default.
        let signal = "e0308";
        let class_id = inferencer.infer(signal);
        assert_eq!(
            class_id, "generic_fix",
            "low-match signal must fall back to generic_fix"
        );
    }

    #[test]
    fn inferencer_empty_signal_falls_back_to_generic_fix() {
        let inferencer = TaskClassInferencer::with_builtins();
        assert_eq!(inferencer.infer(""), "generic_fix");
    }

    #[test]
    fn inferencer_custom_threshold_lower_accepts_partial_match() {
        let inferencer = TaskClassInferencer::with_builtins().with_threshold(0.3);
        // Two of six type-mismatch keywords (~0.33) clears the lowered threshold.
        let class_id = inferencer.infer("E0308 mismatched");
        assert_eq!(class_id, "type-mismatch");
    }

    #[test]
    fn inferencer_builtin_definitions_are_configurable_via_load() {
        let defs = load_task_classes();
        // With the `evolution-experimental` feature enabled this may read a user
        // override file instead of the built-ins.
        assert!(
            !defs.is_empty(),
            "load_task_classes must return at least builtins"
        );
        let has_missing_import = defs.iter().any(|d| d.id == "missing-import");
        assert!(
            has_missing_import,
            "builtin missing-import class must be present"
        );
    }
}