Skip to main content

oris_evolution/
task_class.rs

1//! Task-class abstraction for semantic-equivalent task grouping.
2//!
3//! A `TaskClass` represents a category of semantically equivalent tasks that
4//! can reuse the same learned `Gene` even when the exact signal strings differ.
5//!
6//! # Example classes
7//!
8//! | ID | Name | Signal keywords |
9//! |----|------|-----------------|
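//! | `missing-import` | Missing import / undefined symbol | `E0425`, `E0433`, `unresolved`, `undefined`, `import`, `missing`, `cannot`, `find`, `symbol` |
//! | `type-mismatch` | Type mismatch | `E0308`, `mismatched`, `expected`, `found`, `type`, `mismatch` |
//! | `borrow-conflict` | Borrow checker conflict | `E0502`, `E0505`, `borrow`, `lifetime`, `moved`, `cannot`, `conflict` |
//! | `test-failure` | Test failure | `test`, `failed`, `panic`, `assert`, `assertion`, `failure` |
//! | `performance` | Performance issue | `slow`, `latency`, `timeout`, `perf`, `performance`, `hot` |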
13//!
14//! # How matching works
15//!
16//! 1. Each signal string is tokenised into lowercase words.
17//! 2. A signal **matches** a `TaskClass` if the intersection of its word-set
18//!    with the class's `signal_keywords` is non-empty.
19//! 3. The `TaskClassMatcher::classify` method returns the class whose keywords
20//!    produce the highest overlap score with the combined signal list.
21//!
//! Keyword sets are not fully disjoint (for example, `cannot` belongs to both
//! `missing-import` and `borrow-conflict`), so one signal can score against
//! several classes. `classify` keeps the class with the highest total score,
//! so a signal that partially matches two classes maps to the one with more
//! matching keywords; on an exact tie, the class registered first wins.
26
27use serde::{Deserialize, Serialize};
28
29// ─── TaskClass ────────────────────────────────────────────────────────────────
30
31/// A named category of semantically equivalent tasks.
32///
33/// Genes are tagged with a `task_class_id` during their Solidify phase.
34/// When the Select stage cannot find an exact signal match, it falls back to
35/// `TaskClassMatcher` to surface candidates that share the same class.
36#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
37pub struct TaskClass {
38    /// Opaque, stable identifier. Genes reference this via `Gene::task_class_id`.
39    pub id: String,
40    /// Human-readable label.
41    pub name: String,
42    /// Lowercase keywords used for signal classification.
43    ///
44    /// A signal string matches this class when any of these keywords appears as
45    /// a word token (after lowercasing and splitting on non-alphanumeric chars).
46    pub signal_keywords: Vec<String>,
47}
48
49impl TaskClass {
50    /// Create a new `TaskClass`.
51    pub fn new(
52        id: impl Into<String>,
53        name: impl Into<String>,
54        signal_keywords: impl IntoIterator<Item = impl Into<String>>,
55    ) -> Self {
56        Self {
57            id: id.into(),
58            name: name.into(),
59            signal_keywords: signal_keywords
60                .into_iter()
61                .map(|k| k.into().to_lowercase())
62                .collect(),
63        }
64    }
65
66    /// Count how many keyword tokens overlap with `signal`.
67    ///
68    /// The signal is tokenised (split on non-alphanumeric characters) and each
69    /// token is compared against `signal_keywords`. Returns the overlap count.
70    pub(crate) fn overlap_score(&self, signal: &str) -> usize {
71        let tokens = tokenise(signal);
72        self.signal_keywords
73            .iter()
74            .filter(|kw| tokens.contains(*kw))
75            .count()
76    }
77}
78
79// ─── Built-in task classes ────────────────────────────────────────────────────
80
81/// Return the canonical built-in set of task classes.
82///
83/// Callers may extend this list with domain-specific classes before passing it
84/// to `TaskClassMatcher::new`.
85pub fn builtin_task_classes() -> Vec<TaskClass> {
86    vec![
87        TaskClass::new(
88            "missing-import",
89            "Missing import / undefined symbol",
90            [
91                "e0425",
92                "e0433",
93                "unresolved",
94                "undefined",
95                "import",
96                "missing",
97                "cannot",
98                "find",
99                "symbol",
100            ],
101        ),
102        TaskClass::new(
103            "type-mismatch",
104            "Type mismatch",
105            [
106                "e0308",
107                "mismatched",
108                "expected",
109                "found",
110                "type",
111                "mismatch",
112            ],
113        ),
114        TaskClass::new(
115            "borrow-conflict",
116            "Borrow checker conflict",
117            [
118                "e0502", "e0505", "borrow", "lifetime", "moved", "cannot", "conflict",
119            ],
120        ),
121        TaskClass::new(
122            "test-failure",
123            "Test failure",
124            ["test", "failed", "panic", "assert", "assertion", "failure"],
125        ),
126        TaskClass::new(
127            "performance",
128            "Performance issue",
129            ["slow", "latency", "timeout", "perf", "performance", "hot"],
130        ),
131    ]
132}
133
134// ─── TaskClassMatcher ─────────────────────────────────────────────────────────
135
136/// Classifies a list of signal strings to the best-matching `TaskClass`.
137pub struct TaskClassMatcher {
138    classes: Vec<TaskClass>,
139}
140
141impl TaskClassMatcher {
142    /// Create a matcher with the provided task-class registry.
143    pub fn new(classes: Vec<TaskClass>) -> Self {
144        Self { classes }
145    }
146
147    /// Create a matcher pre-loaded with `builtin_task_classes()`.
148    pub fn with_builtins() -> Self {
149        Self::new(builtin_task_classes())
150    }
151
152    /// Classify `signals` to the best-matching task class.
153    ///
154    /// Returns `None` when no class achieves a positive overlap score.
155    pub fn classify<'a>(&'a self, signals: &[String]) -> Option<&'a TaskClass> {
156        let mut best: Option<(&TaskClass, usize)> = None;
157
158        for class in &self.classes {
159            let total_score: usize = signals.iter().map(|s| class.overlap_score(s)).sum();
160            if total_score > 0 {
161                match best {
162                    None => best = Some((class, total_score)),
163                    Some((_, prev_score)) if total_score > prev_score => {
164                        best = Some((class, total_score));
165                    }
166                    _ => {}
167                }
168            }
169        }
170
171        best.map(|(c, _)| c)
172    }
173
174    /// Return a reference to the underlying class registry.
175    pub fn classes(&self) -> &[TaskClass] {
176        &self.classes
177    }
178}
179
180// ─── Helpers ──────────────────────────────────────────────────────────────────
181
182/// Tokenise a string into lowercase alphanumeric words.
183fn tokenise(s: &str) -> Vec<String> {
184    s.split(|c: char| !c.is_alphanumeric())
185        .filter(|t| !t.is_empty())
186        .map(|t| t.to_lowercase())
187        .collect()
188}
189
190/// Check whether `signals` match the given task-class ID in `registry`.
191///
192/// A convenience wrapper around `TaskClassMatcher::classify`.
193pub fn signals_match_class(signals: &[String], class_id: &str, registry: &[TaskClass]) -> bool {
194    let matcher = TaskClassMatcher::new(registry.to_vec());
195    matcher
196        .classify(signals)
197        .map_or(false, |c| c.id == class_id)
198}
199
200// ─── TaskClassDefinition ──────────────────────────────────────────────────────
201
202/// Extended task-class definition that adds a natural-language `description`
203/// field used by `TaskClassInferencer` for semantic matching and TOML persistence.
204#[derive(Clone, Debug, Serialize, Deserialize)]
205pub struct TaskClassDefinition {
206    /// Opaque, stable identifier.
207    pub id: String,
208    /// Human-readable label.
209    pub name: String,
210    /// Natural-language description used when scoring signal similarity.
211    pub description: String,
212    /// Lowercase keywords used for overlap-based classification.
213    pub signal_keywords: Vec<String>,
214}
215
216impl TaskClassDefinition {
217    /// Convert into a lightweight `TaskClass` (drops the description field).
218    pub fn into_task_class(self) -> TaskClass {
219        TaskClass::new(self.id, self.name, self.signal_keywords)
220    }
221}
222
223// ─── Built-in task class definitions ─────────────────────────────────────────
224
225/// Return the canonical built-in task class definitions including descriptions.
226pub fn builtin_task_class_definitions() -> Vec<TaskClassDefinition> {
227    vec![
228        TaskClassDefinition {
229            id: "missing-import".to_string(),
230            name: "Missing import / undefined symbol".to_string(),
231            description: "Compiler cannot find symbol unresolved import undefined reference \
232                          missing use declaration cannot find value in scope"
233                .to_string(),
234            signal_keywords: vec![
235                "e0425",
236                "e0433",
237                "unresolved",
238                "undefined",
239                "import",
240                "missing",
241                "cannot",
242                "find",
243                "symbol",
244            ]
245            .into_iter()
246            .map(String::from)
247            .collect(),
248        },
249        TaskClassDefinition {
250            id: "type-mismatch".to_string(),
251            name: "Type mismatch".to_string(),
252            description: "Type mismatch mismatched types expected one type found another \
253                          type annotation required"
254                .to_string(),
255            signal_keywords: vec![
256                "e0308",
257                "mismatched",
258                "expected",
259                "found",
260                "type",
261                "mismatch",
262            ]
263            .into_iter()
264            .map(String::from)
265            .collect(),
266        },
267        TaskClassDefinition {
268            id: "borrow-conflict".to_string(),
269            name: "Borrow checker conflict".to_string(),
270            description: "Borrow checker conflict cannot borrow as mutable lifetime error \
271                          value moved cannot use after move"
272                .to_string(),
273            signal_keywords: vec![
274                "e0502", "e0505", "borrow", "lifetime", "moved", "cannot", "conflict",
275            ]
276            .into_iter()
277            .map(String::from)
278            .collect(),
279        },
280        TaskClassDefinition {
281            id: "test-failure".to_string(),
282            name: "Test failure".to_string(),
283            description: "Test failure panicked assertion failed test did not pass".to_string(),
284            signal_keywords: vec!["test", "failed", "panic", "assert", "assertion", "failure"]
285                .into_iter()
286                .map(String::from)
287                .collect(),
288        },
289        TaskClassDefinition {
290            id: "performance".to_string(),
291            name: "Performance issue".to_string(),
292            description: "Performance issue slow response high latency operation timeout \
293                          hot path resource contention"
294                .to_string(),
295            signal_keywords: vec!["slow", "latency", "timeout", "perf", "performance", "hot"]
296                .into_iter()
297                .map(String::from)
298                .collect(),
299        },
300    ]
301}
302
303// ─── TOML persistence ─────────────────────────────────────────────────────────
304
305#[cfg(feature = "evolution-experimental")]
306#[derive(Deserialize)]
307struct TaskClassesToml {
308    task_classes: Vec<TaskClassDefinition>,
309}
310
311/// Load task class definitions from a TOML file.
312///
313/// The file must contain a top-level `[[task_classes]]` array whose entries
314/// each have `id`, `name`, `description`, and `signal_keywords` fields.
315///
316/// Only available with the `evolution-experimental` feature.
317#[cfg(feature = "evolution-experimental")]
318pub fn load_task_classes_from_toml(
319    path: &std::path::Path,
320) -> Result<Vec<TaskClassDefinition>, String> {
321    let content = std::fs::read_to_string(path).map_err(|e| e.to_string())?;
322    let parsed: TaskClassesToml = toml::from_str(&content).map_err(|e| e.to_string())?;
323    Ok(parsed.task_classes)
324}
325
326/// Load task class definitions.
327///
328/// When the `evolution-experimental` feature is enabled, attempts to load from
329/// `~/.oris/oris-task-classes.toml` if it exists; otherwise falls back to
330/// `builtin_task_class_definitions()`.
331pub fn load_task_classes() -> Vec<TaskClassDefinition> {
332    #[cfg(feature = "evolution-experimental")]
333    {
334        if let Some(home) = std::env::var_os("HOME") {
335            let path = std::path::Path::new(&home)
336                .join(".oris")
337                .join("oris-task-classes.toml");
338            if path.exists() {
339                if let Ok(classes) = load_task_classes_from_toml(&path) {
340                    return classes;
341                }
342            }
343        }
344    }
345    builtin_task_class_definitions()
346}
347
348// ─── TaskClassInferencer ──────────────────────────────────────────────────────
349
350/// Infers the task class for a signal description using keyword recall scoring.
351///
352/// # Scoring
353///
354/// For each registered class, the score is:
355///
356/// ```text
357/// score = |signal_tokens ∩ class_keywords| / |class_keywords|
358/// ```
359///
360/// The class with the highest score is returned when the score meets
361/// `threshold` (default `0.75`).  When no class reaches the threshold the
362/// fallback `"generic_fix"` ID is returned.
363pub struct TaskClassInferencer {
364    classes: Vec<TaskClassDefinition>,
365    threshold: f32,
366}
367
368impl TaskClassInferencer {
369    /// Create an inferencer from a custom set of definitions.
370    pub fn new(classes: Vec<TaskClassDefinition>) -> Self {
371        Self {
372            classes,
373            threshold: 0.75,
374        }
375    }
376
377    /// Create an inferencer pre-loaded with `builtin_task_class_definitions()`.
378    pub fn with_builtins() -> Self {
379        Self::new(builtin_task_class_definitions())
380    }
381
382    /// Override the similarity threshold (default `0.75`).
383    pub fn with_threshold(mut self, threshold: f32) -> Self {
384        self.threshold = threshold;
385        self
386    }
387
388    /// Infer the task class ID for the given signal description.
389    ///
390    /// Returns the ID of the best matching class when it achieves a score
391    /// ≥ `threshold`, or `"generic_fix"` otherwise.
392    pub fn infer(&self, signal_description: &str) -> String {
393        let signal_tokens = tokenise(signal_description);
394        if signal_tokens.is_empty() {
395            return "generic_fix".to_string();
396        }
397
398        let mut best_id = "generic_fix";
399        let mut best_score = 0.0f32;
400
401        for class in &self.classes {
402            let score = recall_score(&signal_tokens, &class.signal_keywords);
403            if score > best_score {
404                best_score = score;
405                best_id = &class.id;
406            }
407        }
408
409        if best_score >= self.threshold {
410            best_id.to_string()
411        } else {
412            "generic_fix".to_string()
413        }
414    }
415
416    /// Return a reference to the underlying class definitions.
417    pub fn class_definitions(&self) -> &[TaskClassDefinition] {
418        &self.classes
419    }
420}
421
422// ─── Internal similarity helper ───────────────────────────────────────────────
423
424/// Keyword recall: fraction of class keywords that appear in the signal tokens.
425fn recall_score(signal_tokens: &[String], keywords: &[String]) -> f32 {
426    if keywords.is_empty() {
427        return 0.0;
428    }
429    let intersection = keywords
430        .iter()
431        .filter(|kw| signal_tokens.contains(kw))
432        .count();
433    intersection as f32 / keywords.len() as f32
434}
435
436// ─── Tests ────────────────────────────────────────────────────────────────────
437
438#[cfg(test)]
439mod tests {
440    use super::*;
441
442    fn matcher() -> TaskClassMatcher {
443        TaskClassMatcher::with_builtins()
444    }
445
446    // ── Positive: same task-class, different signal variants ─────────────────
447
448    #[test]
449    fn test_missing_import_via_error_code() {
450        let m = matcher();
451        let signals = vec!["error[E0425]: cannot find value `foo` in scope".to_string()];
452        let cls = m.classify(&signals).expect("should classify");
453        assert_eq!(cls.id, "missing-import");
454    }
455
456    #[test]
457    fn test_missing_import_via_natural_language() {
458        let m = matcher();
459        // Different phrasing — no Rust error code, but "undefined symbol" keywords
460        let signals = vec!["undefined symbol: use_missing_fn".to_string()];
461        let cls = m.classify(&signals).expect("should classify");
462        assert_eq!(cls.id, "missing-import");
463    }
464
465    #[test]
466    fn test_missing_import_via_unresolved_import() {
467        let m = matcher();
468        let signals = vec!["unresolved import `std::collections::Missing`".to_string()];
469        let cls = m.classify(&signals).expect("should classify");
470        assert_eq!(cls.id, "missing-import");
471    }
472
473    #[test]
474    fn test_type_mismatch_classification() {
475        let m = matcher();
476        let signals =
477            vec!["error[E0308]: mismatched types: expected `u32` found `String`".to_string()];
478        let cls = m.classify(&signals).expect("should classify");
479        assert_eq!(cls.id, "type-mismatch");
480    }
481
482    #[test]
483    fn test_borrow_conflict_classification() {
484        let m = matcher();
485        let signals = vec![
486            "error[E0502]: cannot borrow `x` as mutable because it is also borrowed as immutable"
487                .to_string(),
488        ];
489        let cls = m.classify(&signals).expect("should classify");
490        assert_eq!(cls.id, "borrow-conflict");
491    }
492
493    #[test]
494    fn test_test_failure_classification() {
495        let m = matcher();
496        let signals = vec!["test panicked: assertion failed: x == y".to_string()];
497        let cls = m.classify(&signals).expect("should classify");
498        assert_eq!(cls.id, "test-failure");
499    }
500
501    #[test]
502    fn test_multiple_signals_accumulate_score() {
503        let m = matcher();
504        // Two signals both pointing at type-mismatch → still resolves correctly
505        let signals = vec![
506            "expected type `u32`".to_string(),
507            "found type `String` — type mismatch".to_string(),
508        ];
509        let cls = m.classify(&signals).expect("should classify");
510        assert_eq!(cls.id, "type-mismatch");
511    }
512
513    // ── Negative: cross-class no false positives ──────────────────────────────
514
515    #[test]
516    fn test_no_false_positive_type_vs_borrow() {
517        let m = matcher();
518        // "E0308" → type-mismatch only, not borrow-conflict
519        let signals = vec!["error[E0308]: mismatched type".to_string()];
520        let cls = m.classify(&signals).unwrap();
521        assert_ne!(
522            cls.id, "borrow-conflict",
523            "must not cross-match borrow-conflict"
524        );
525    }
526
527    #[test]
528    fn test_no_false_positive_borrow_vs_import() {
529        let m = matcher();
530        let signals = vec!["error[E0502]: cannot borrow as mutable".to_string()];
531        let cls = m.classify(&signals).unwrap();
532        assert_ne!(cls.id, "missing-import");
533    }
534
535    #[test]
536    fn test_no_match_returns_none() {
537        let m = matcher();
538        // Completely unrelated signal with no keyword overlap
539        let signals = vec!["network timeout connecting to database server".to_string()];
540        // This might match "performance/timeout" — but if it doesn't, None is fine.
541        // The key invariant is it doesn't match an unrelated class like "missing-import".
542        if let Some(cls) = m.classify(&signals) {
543            assert_ne!(cls.id, "missing-import");
544            assert_ne!(cls.id, "type-mismatch");
545            assert_ne!(cls.id, "borrow-conflict");
546        }
547        // None is also acceptable
548    }
549
550    #[test]
551    fn test_empty_signals_returns_none() {
552        let m = matcher();
553        assert!(m.classify(&[]).is_none());
554    }
555
556    // ── Boundary: custom classes ──────────────────────────────────────────────
557
558    #[test]
559    fn test_custom_class_wins_over_builtin() {
560        // A domain-specific class with high keyword density should beat builtins
561        let mut classes = builtin_task_classes();
562        classes.push(TaskClass::new(
563            "db-timeout",
564            "Database timeout",
565            ["database", "timeout", "connection", "pool", "exhausted"],
566        ));
567        let m = TaskClassMatcher::new(classes);
568        let signals = vec!["database connection pool exhausted — timeout".to_string()];
569        let cls = m.classify(&signals).expect("should classify");
570        assert_eq!(cls.id, "db-timeout");
571    }
572
573    #[test]
574    fn test_signals_match_class_helper() {
575        let registry = builtin_task_classes();
576        let signals = vec!["error[E0425]: cannot find value".to_string()];
577        assert!(signals_match_class(&signals, "missing-import", &registry));
578        assert!(!signals_match_class(&signals, "type-mismatch", &registry));
579    }
580
581    #[test]
582    fn test_overlap_score_case_insensitive() {
583        let class = TaskClass::new("tc", "Test", ["e0425", "unresolved"]);
584        let m = TaskClassMatcher::new(vec![class]);
585        // Signal contains uppercase E0425 — tokenise lowercases all tokens
586        // so the match is case-insensitive.
587        let signals = vec!["E0425 unresolved import".to_string()];
588        let cls = m
589            .classify(&signals)
590            .expect("case-insensitive classify should work");
591        assert_eq!(cls.id, "tc");
592    }
593
594    // ── TaskClassInferencer tests ─────────────────────────────────────────────
595
596    #[test]
597    fn inferencer_canonical_compiler_error_missing_import() {
598        let inferencer = TaskClassInferencer::with_builtins();
599        // Canonical signal: contains the majority of missing-import keywords.
600        let signal = "error[E0425]: cannot find value `foo`: \
601                      unresolved import symbol is undefined missing";
602        let class_id = inferencer.infer(signal);
603        assert_eq!(
604            class_id, "missing-import",
605            "canonical missing-import signal should infer correct class"
606        );
607    }
608
609    #[test]
610    fn inferencer_canonical_compiler_error_type_mismatch() {
611        let inferencer = TaskClassInferencer::with_builtins();
612        // Canonical signal: contains most type-mismatch keywords.
613        let signal = "error[E0308]: mismatched type expected u32 found String type mismatch";
614        let class_id = inferencer.infer(signal);
615        assert_eq!(class_id, "type-mismatch");
616    }
617
618    #[test]
619    fn inferencer_score_below_threshold_falls_back_to_generic_fix() {
620        let inferencer = TaskClassInferencer::with_builtins();
621        // Signal with only one matching keyword — far below 0.75 threshold.
622        let signal = "e0308";
623        let class_id = inferencer.infer(signal);
624        assert_eq!(
625            class_id, "generic_fix",
626            "low-match signal must fall back to generic_fix"
627        );
628    }
629
630    #[test]
631    fn inferencer_empty_signal_falls_back_to_generic_fix() {
632        let inferencer = TaskClassInferencer::with_builtins();
633        assert_eq!(inferencer.infer(""), "generic_fix");
634    }
635
636    #[test]
637    fn inferencer_custom_threshold_lower_accepts_partial_match() {
638        // With a lower threshold partial matches should succeed.
639        let inferencer = TaskClassInferencer::with_builtins().with_threshold(0.3);
640        // "E0308 mismatched" — 2/6 = 0.333, which is ≥ 0.30 threshold.
641        let class_id = inferencer.infer("E0308 mismatched");
642        assert_eq!(class_id, "type-mismatch");
643    }
644
645    #[test]
646    fn inferencer_builtin_definitions_are_configurable_via_load() {
647        // load_task_classes() must return at least the builtin definitions
648        // (no TOML file exists in CI — falls back to builtins).
649        let defs = load_task_classes();
650        assert!(
651            !defs.is_empty(),
652            "load_task_classes must return at least builtins"
653        );
654        let has_missing_import = defs.iter().any(|d| d.id == "missing-import");
655        assert!(
656            has_missing_import,
657            "builtin missing-import class must be present"
658        );
659    }
660}