Skip to main content

lean_ctx/core/
intent_engine.rs

1#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
2pub enum TaskType {
3    Generate,
4    FixBug,
5    Refactor,
6    Explore,
7    Test,
8    Debug,
9    Config,
10    Deploy,
11    Review,
12}
13
14impl TaskType {
15    pub fn as_str(&self) -> &'static str {
16        match self {
17            Self::Generate => "generate",
18            Self::FixBug => "fix_bug",
19            Self::Refactor => "refactor",
20            Self::Explore => "explore",
21            Self::Test => "test",
22            Self::Debug => "debug",
23            Self::Config => "config",
24            Self::Deploy => "deploy",
25            Self::Review => "review",
26        }
27    }
28
29    pub fn thinking_budget(&self) -> ThinkingBudget {
30        match self {
31            Self::Generate | Self::FixBug | Self::Test | Self::Config | Self::Deploy => {
32                ThinkingBudget::Minimal
33            }
34            Self::Refactor | Self::Explore | Self::Debug | Self::Review => ThinkingBudget::Medium,
35        }
36    }
37
38    pub fn output_format(&self) -> OutputFormat {
39        match self {
40            Self::Generate | Self::Test | Self::Config => OutputFormat::CodeOnly,
41            Self::FixBug | Self::Refactor => OutputFormat::DiffOnly,
42            Self::Explore | Self::Review => OutputFormat::ExplainConcise,
43            Self::Debug => OutputFormat::Trace,
44            Self::Deploy => OutputFormat::StepList,
45        }
46    }
47}
48
/// How much visible reasoning the model is asked to spend before answering.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ThinkingBudget {
    Minimal,
    Medium,
    Trace,
    Deep,
}

impl ThinkingBudget {
    /// Prompt instruction (prefixed with `THINKING:`) that enforces this budget.
    pub fn instruction(&self) -> &'static str {
        match *self {
            ThinkingBudget::Minimal => {
                "THINKING: Skip analysis. The task is clear — generate code directly."
            }
            ThinkingBudget::Medium => {
                "THINKING: 2-3 step analysis max. Identify what to change, then act. Do not over-analyze."
            }
            ThinkingBudget::Trace => {
                "THINKING: Short trace only. Identify root cause in 3 steps max, then generate fix."
            }
            ThinkingBudget::Deep => {
                "THINKING: Analyze structure and dependencies. Summarize findings concisely."
            }
        }
    }

    /// `true` only for [`ThinkingBudget::Minimal`].
    pub fn suppresses_thinking(&self) -> bool {
        *self == ThinkingBudget::Minimal
    }
}
71
/// Preferred shape of the model's answer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OutputFormat {
    CodeOnly,
    DiffOnly,
    ExplainConcise,
    Trace,
    StepList,
}

impl OutputFormat {
    /// Prompt hint (prefixed with `OUTPUT-HINT:`) describing the expected answer shape.
    pub fn instruction(&self) -> &'static str {
        match *self {
            OutputFormat::DiffOnly => "OUTPUT-HINT: Prefer showing only changed lines as +/- diffs.",
            OutputFormat::ExplainConcise => "OUTPUT-HINT: Brief summary, then code/data if relevant.",
            OutputFormat::Trace => "OUTPUT-HINT: Show cause→effect chain with code references.",
            OutputFormat::StepList => "OUTPUT-HINT: Numbered action list, one step at a time.",
            OutputFormat::CodeOnly => {
                "OUTPUT-HINT: Prefer code blocks. Minimize prose unless user asks for explanation."
            }
        }
    }
}
94
95#[derive(Debug)]
96pub struct TaskClassification {
97    pub task_type: TaskType,
98    pub confidence: f64,
99    pub targets: Vec<String>,
100    pub keywords: Vec<String>,
101}
102
103const PHRASE_RULES: &[(&[&str], TaskType, f64)] = &[
104    (
105        &[
106            "add",
107            "create",
108            "implement",
109            "build",
110            "write",
111            "generate",
112            "make",
113            "new feature",
114            "new",
115        ],
116        TaskType::Generate,
117        0.9,
118    ),
119    (
120        &[
121            "fix",
122            "bug",
123            "broken",
124            "crash",
125            "error in",
126            "not working",
127            "fails",
128            "wrong output",
129        ],
130        TaskType::FixBug,
131        0.95,
132    ),
133    (
134        &[
135            "refactor",
136            "clean up",
137            "restructure",
138            "rename",
139            "move",
140            "extract",
141            "simplify",
142            "split",
143        ],
144        TaskType::Refactor,
145        0.9,
146    ),
147    (
148        &[
149            "how",
150            "what",
151            "where",
152            "explain",
153            "understand",
154            "show me",
155            "describe",
156            "why does",
157        ],
158        TaskType::Explore,
159        0.85,
160    ),
161    (
162        &[
163            "test",
164            "spec",
165            "coverage",
166            "assert",
167            "unit test",
168            "integration test",
169            "mock",
170        ],
171        TaskType::Test,
172        0.9,
173    ),
174    (
175        &[
176            "debug",
177            "trace",
178            "inspect",
179            "log",
180            "breakpoint",
181            "step through",
182            "stack trace",
183        ],
184        TaskType::Debug,
185        0.9,
186    ),
187    (
188        &[
189            "config",
190            "setup",
191            "install",
192            "env",
193            "configure",
194            "settings",
195            "dotenv",
196        ],
197        TaskType::Config,
198        0.85,
199    ),
200    (
201        &[
202            "deploy", "release", "publish", "ship", "ci/cd", "pipeline", "docker",
203        ],
204        TaskType::Deploy,
205        0.85,
206    ),
207    (
208        &[
209            "review",
210            "check",
211            "audit",
212            "look at",
213            "evaluate",
214            "assess",
215            "pr review",
216        ],
217        TaskType::Review,
218        0.8,
219    ),
220];
221
222pub fn classify(query: &str) -> TaskClassification {
223    let q = query.to_lowercase();
224    let words: Vec<&str> = q.split_whitespace().collect();
225
226    let mut best_type = TaskType::Explore;
227    let mut best_score = 0.0_f64;
228
229    for &(phrases, task_type, base_confidence) in PHRASE_RULES {
230        let mut match_count = 0usize;
231        for phrase in phrases {
232            if phrase.contains(' ') {
233                if q.contains(phrase) {
234                    match_count += 2;
235                }
236            } else if words.contains(phrase) {
237                match_count += 1;
238            }
239        }
240        if match_count > 0 {
241            let score = base_confidence * (match_count as f64).min(2.0) / 2.0;
242            if score > best_score {
243                best_score = score;
244                best_type = task_type;
245            }
246        }
247    }
248
249    let targets = extract_targets(query);
250    let keywords = extract_keywords(&q);
251
252    if best_score < 0.1 {
253        best_type = TaskType::Explore;
254        best_score = 0.3;
255    }
256
257    TaskClassification {
258        task_type: best_type,
259        confidence: best_score,
260        targets,
261        keywords,
262    }
263}
264
265fn extract_targets(query: &str) -> Vec<String> {
266    let mut targets = Vec::new();
267
268    for word in query.split_whitespace() {
269        if word.contains('.') && !word.starts_with('.') {
270            let clean = word.trim_matches(|c: char| {
271                !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
272            });
273            if looks_like_path(clean) {
274                targets.push(clean.to_string());
275            }
276        }
277        if word.contains('/') && !word.starts_with("//") && !word.starts_with("http") {
278            let clean = word.trim_matches(|c: char| {
279                !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
280            });
281            if clean.len() > 2 {
282                targets.push(clean.to_string());
283            }
284        }
285    }
286
287    for word in query.split_whitespace() {
288        let w = word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');
289        if w.contains('_') && w.len() > 3 && !targets.contains(&w.to_string()) {
290            targets.push(w.to_string());
291        }
292        if w.chars().any(char::is_uppercase)
293            && w.len() > 2
294            && !is_stop_word(w)
295            && !targets.contains(&w.to_string())
296        {
297            targets.push(w.to_string());
298        }
299    }
300
301    targets.truncate(5);
302    targets
303}
304
/// Returns `true` when `s` looks like a file path: it ends with a known
/// source/config extension or contains a `/` separator.
fn looks_like_path(s: &str) -> bool {
    // Includes every extension `detect_language_hint` maps to a language
    // (rb/java/swift/zig included), so a bare mention such as `app.rb` becomes a
    // target and can drive the language hint.
    const EXTS: &[&str] = &[
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".rb", ".java", ".swift", ".zig",
        ".toml", ".yaml", ".yml", ".json", ".md",
    ];
    EXTS.iter().any(|ext| s.ends_with(ext)) || s.contains('/')
}
311
/// Case-insensitive check against a fixed list of common English filler words.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "with", "from", "into", "have", "please", "could", "would",
        "should", "also", "just", "then", "when", "what", "where", "which", "there", "here",
        "these", "those", "does", "will", "shall", "can", "may", "must", "need", "want",
        "like", "make", "take",
    ];
    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
350
351fn extract_keywords(query: &str) -> Vec<String> {
352    query
353        .split_whitespace()
354        .filter(|w| w.len() > 3)
355        .filter(|w| !is_stop_word(w))
356        .map(|w| {
357            w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
358                .to_lowercase()
359        })
360        .filter(|w| !w.is_empty())
361        .take(8)
362        .collect()
363}
364
365pub fn classify_complexity(
366    query: &str,
367    classification: &TaskClassification,
368) -> super::adaptive::TaskComplexity {
369    use super::adaptive::TaskComplexity;
370
371    let q = query.to_lowercase();
372    let word_count = q.split_whitespace().count();
373    let target_count = classification.targets.len();
374
375    let has_multi_file = target_count >= 3;
376    let has_cross_cutting = q.contains("all files")
377        || q.contains("across")
378        || q.contains("everywhere")
379        || q.contains("every")
380        || q.contains("migration")
381        || q.contains("architecture");
382
383    let is_simple = word_count < 8
384        && target_count <= 1
385        && matches!(
386            classification.task_type,
387            TaskType::Generate | TaskType::Config
388        );
389
390    if is_simple {
391        TaskComplexity::Mechanical
392    } else if has_multi_file || has_cross_cutting {
393        TaskComplexity::Architectural
394    } else {
395        TaskComplexity::Standard
396    }
397}
398
399pub fn detect_multi_intent(query: &str) -> Vec<TaskClassification> {
400    let delimiters = [" and then ", " then ", " also ", " + ", ". "];
401
402    let mut parts: Vec<&str> = vec![query];
403    for delim in &delimiters {
404        let mut new_parts = Vec::new();
405        for part in &parts {
406            for sub in part.split(delim) {
407                let trimmed = sub.trim();
408                if !trimmed.is_empty() {
409                    new_parts.push(trimmed);
410                }
411            }
412        }
413        parts = new_parts;
414    }
415
416    if parts.len() <= 1 {
417        return vec![classify(query)];
418    }
419
420    parts.iter().map(|part| classify(part)).collect()
421}
422
423pub fn format_briefing_header(classification: &TaskClassification) -> String {
424    format!(
425        "[TASK:{} CONF:{:.0}% TARGETS:{} KW:{}]",
426        classification.task_type.as_str(),
427        classification.confidence * 100.0,
428        if classification.targets.is_empty() {
429            "-".to_string()
430        } else {
431            classification.targets.join(",")
432        },
433        classification.keywords.join(","),
434    )
435}
436
437#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
438pub enum IntentScope {
439    SingleFile,
440    MultiFile,
441    CrossModule,
442    ProjectWide,
443}
444
445#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
446pub struct StructuredIntent {
447    pub task_type: TaskType,
448    pub confidence: f64,
449    pub targets: Vec<String>,
450    pub keywords: Vec<String>,
451    pub scope: IntentScope,
452    pub language_hint: Option<String>,
453    pub urgency: f64,
454    pub action_verb: Option<String>,
455}
456
457impl StructuredIntent {
458    pub fn from_query(query: &str) -> Self {
459        let classification = classify(query);
460        let complexity = classify_complexity(query, &classification);
461        let file_targets = classification
462            .targets
463            .iter()
464            .filter(|t| t.contains('.') || t.contains('/'))
465            .count();
466        let scope = match complexity {
467            super::adaptive::TaskComplexity::Mechanical => IntentScope::SingleFile,
468            super::adaptive::TaskComplexity::Standard => {
469                if file_targets > 1 {
470                    IntentScope::MultiFile
471                } else {
472                    IntentScope::SingleFile
473                }
474            }
475            super::adaptive::TaskComplexity::Architectural => {
476                let q = query.to_lowercase();
477                if q.contains("all files") || q.contains("everywhere") || q.contains("migration") {
478                    IntentScope::ProjectWide
479                } else {
480                    IntentScope::CrossModule
481                }
482            }
483        };
484
485        let language_hint = detect_language_hint(query, &classification.targets);
486        let urgency = detect_urgency(query);
487        let action_verb = extract_action_verb(query);
488
489        StructuredIntent {
490            task_type: classification.task_type,
491            confidence: classification.confidence,
492            targets: classification.targets,
493            keywords: classification.keywords,
494            scope,
495            language_hint,
496            urgency,
497            action_verb,
498        }
499    }
500
501    pub fn from_file_patterns(touched_files: &[String]) -> Self {
502        if touched_files.is_empty() {
503            return Self {
504                task_type: TaskType::Explore,
505                confidence: 0.3,
506                targets: Vec::new(),
507                keywords: Vec::new(),
508                scope: IntentScope::SingleFile,
509                language_hint: None,
510                urgency: 0.0,
511                action_verb: None,
512            };
513        }
514
515        let has_tests = touched_files
516            .iter()
517            .any(|f| f.contains("test") || f.contains("spec"));
518        let has_config = touched_files.iter().any(|f| {
519            let p = std::path::Path::new(f.as_str());
520            let is_config_ext = p.extension().is_some_and(|e| {
521                e.eq_ignore_ascii_case("toml")
522                    || e.eq_ignore_ascii_case("yaml")
523                    || e.eq_ignore_ascii_case("yml")
524                    || e.eq_ignore_ascii_case("json")
525            });
526            is_config_ext || f.contains("config") || f.contains(".env")
527        });
528
529        let dirs: std::collections::HashSet<&str> = touched_files
530            .iter()
531            .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
532            .collect();
533
534        let task_type = if has_tests && touched_files.len() <= 3 {
535            TaskType::Test
536        } else if has_config && touched_files.len() <= 2 {
537            TaskType::Config
538        } else if dirs.len() > 3 {
539            TaskType::Refactor
540        } else {
541            TaskType::Explore
542        };
543
544        let scope = match touched_files.len() {
545            1 => IntentScope::SingleFile,
546            2..=4 => IntentScope::MultiFile,
547            _ => IntentScope::CrossModule,
548        };
549
550        let language_hint = detect_language_from_files(touched_files);
551
552        Self {
553            task_type,
554            confidence: 0.5,
555            targets: touched_files.to_vec(),
556            keywords: Vec::new(),
557            scope,
558            language_hint,
559            urgency: 0.0,
560            action_verb: None,
561        }
562    }
563
564    pub fn from_query_with_session(query: &str, touched_files: &[String]) -> Self {
565        let mut intent = Self::from_query(query);
566
567        if intent.language_hint.is_none() && !touched_files.is_empty() {
568            intent.language_hint = detect_language_from_files(touched_files);
569        }
570
571        if intent.scope == IntentScope::SingleFile && touched_files.len() > 3 {
572            let dirs: std::collections::HashSet<&str> = touched_files
573                .iter()
574                .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
575                .collect();
576            if dirs.len() > 2 {
577                intent.scope = IntentScope::MultiFile;
578            }
579        }
580
581        intent
582    }
583
584    pub fn format_header(&self) -> String {
585        format!(
586            "[TASK:{} SCOPE:{} CONF:{:.0}%{}{}]",
587            self.task_type.as_str(),
588            match self.scope {
589                IntentScope::SingleFile => "single",
590                IntentScope::MultiFile => "multi",
591                IntentScope::CrossModule => "cross",
592                IntentScope::ProjectWide => "project",
593            },
594            self.confidence * 100.0,
595            self.language_hint
596                .as_ref()
597                .map(|l| format!(" LANG:{l}"))
598                .unwrap_or_default(),
599            if self.urgency > 0.5 { " URGENT" } else { "" },
600        )
601    }
602}
603
/// Guesses the primary programming language of a task.
///
/// File targets are checked first: the first target with a recognised extension
/// decides. Otherwise the lowercased query is split into alphanumeric words and the
/// first language keyword (in table order) that appears as a whole word wins.
/// Returns `None` when neither source yields a language.
fn detect_language_hint(query: &str, targets: &[String]) -> Option<String> {
    for target in targets {
        let ext = std::path::Path::new(target)
            .extension()
            .and_then(|e| e.to_str());
        let lang = match ext {
            Some("rs") => "rust",
            Some("ts" | "tsx") => "typescript",
            Some("js" | "jsx") => "javascript",
            Some("py") => "python",
            Some("go") => "go",
            Some("rb") => "ruby",
            Some("java") => "java",
            Some("swift") => "swift",
            Some("zig") => "zig",
            _ => continue,
        };
        return Some(lang.to_string());
    }

    // Whole-word matching finds keywords at the very start or end of the query
    // ("rewrite this in go", "port it to java"). It also avoids substring hits
    // such as "trust" -> rust or "swiftly" -> swift.
    const LANG_KEYWORDS: &[(&str, &str)] = &[
        ("rust", "rust"),
        ("python", "python"),
        ("typescript", "typescript"),
        ("javascript", "javascript"),
        ("golang", "go"),
        ("go", "go"),
        ("ruby", "ruby"),
        ("java", "java"),
        ("swift", "swift"),
    ];
    let q = query.to_lowercase();
    let words: Vec<&str> = q
        .split(|c: char| !c.is_alphanumeric())
        .filter(|w| !w.is_empty())
        .collect();
    for &(keyword, lang) in LANG_KEYWORDS {
        if words.contains(&keyword) {
            return Some(lang.to_string());
        }
    }

    None
}
641
/// Returns the most common language among `files`, judged by file extension.
///
/// Files with unrecognised extensions are ignored; `None` if no file is recognised.
/// Ties go to the language that appears first in `files`, so the result never
/// depends on hash-map iteration order.
fn detect_language_from_files(files: &[String]) -> Option<String> {
    // (language, count) pairs in order of first appearance.
    let mut counts: Vec<(&'static str, usize)> = Vec::new();
    for f in files {
        let ext = std::path::Path::new(f)
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("");
        let lang = match ext {
            "rs" => "rust",
            "ts" | "tsx" => "typescript",
            "js" | "jsx" => "javascript",
            "py" => "python",
            "go" => "go",
            "rb" => "ruby",
            "java" => "java",
            "swift" => "swift",
            "zig" => "zig",
            _ => continue,
        };
        let slot = counts.iter().position(|&(l, _)| l == lang);
        match slot {
            Some(i) => counts[i].1 += 1,
            None => counts.push((lang, 1)),
        }
    }
    // `max_by_key` keeps the last of several equal maxima. Scanning in reverse
    // makes ties resolve to the language seen first.
    counts
        .iter()
        .rev()
        .max_by_key(|&&(_, n)| n)
        .map(|&(lang, _)| lang.to_string())
}
666
/// Urgency score in `0.0..=1.0`: 0.4 per urgency marker found (case-insensitive
/// substring match), capped at 1.0.
fn detect_urgency(query: &str) -> f64 {
    const URGENT_MARKERS: [&str; 8] = [
        "urgent",
        "asap",
        "immediately",
        "critical",
        "hotfix",
        "emergency",
        "blocker",
        "breaking",
    ];
    let lowered = query.to_lowercase();
    let mut hits = 0usize;
    for marker in URGENT_MARKERS.iter() {
        if lowered.contains(marker) {
            hits += 1;
        }
    }
    f64::min(hits as f64 * 0.4, 1.0)
}
682
/// Finds the action verb a query leads with.
///
/// Words are lowercased and stripped of surrounding punctuation ("fix:" → "fix").
/// The lookup runs in three steps:
/// 1. The first word wins if it is a known verb.
/// 2. Otherwise the second word wins if it is a known verb ("please add …").
/// 3. Otherwise the first known verb (in list order) appearing anywhere wins.
///
/// Returns `None` when no known verb is present.
fn extract_action_verb(query: &str) -> Option<String> {
    const VERBS: &[&str] = &[
        "fix",
        "add",
        "create",
        "implement",
        "refactor",
        "debug",
        "test",
        "write",
        "update",
        "remove",
        "delete",
        "rename",
        "move",
        "extract",
        "split",
        "merge",
        "deploy",
        "review",
        "check",
        "build",
        "generate",
        "optimize",
        "clean",
    ];
    let q = query.to_lowercase();
    let words: Vec<&str> = q
        .split_whitespace()
        .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
        .filter(|w| !w.is_empty())
        .collect();

    // Position beats list order: in "test fix for auth" the leading "test" is the verb.
    if let Some(verb) = words.iter().copied().take(2).find(|w| VERBS.contains(w)) {
        return Some(verb.to_string());
    }
    VERBS
        .iter()
        .copied()
        .find(|v| words.contains(v))
        .map(String::from)
}
723
#[cfg(test)]
mod tests {
    use super::super::adaptive::TaskComplexity;
    use super::*;

    // True when any extracted target contains `needle`.
    fn has_target(c: &TaskClassification, needle: &str) -> bool {
        c.targets.iter().any(|t| t.contains(needle))
    }

    // Task type chosen by `classify` for `query`.
    fn task_of(query: &str) -> TaskType {
        classify(query).task_type
    }

    // Complexity of `query` after classifying it.
    fn complexity_of(query: &str) -> TaskComplexity {
        classify_complexity(query, &classify(query))
    }

    #[test]
    fn classify_fix_bug() {
        let c = classify("fix the bug in entropy.rs where token_entropy returns NaN");
        assert_eq!(c.task_type, TaskType::FixBug);
        assert!(c.confidence > 0.5);
        assert!(has_target(&c, "entropy.rs"));
    }

    #[test]
    fn classify_generate() {
        let c = classify("add a new function normalized_token_entropy to entropy.rs");
        assert_eq!(c.task_type, TaskType::Generate);
        assert!(c.confidence > 0.5);
    }

    #[test]
    fn classify_refactor() {
        assert_eq!(
            task_of("refactor the compression pipeline to split into smaller modules"),
            TaskType::Refactor
        );
    }

    #[test]
    fn classify_explore() {
        assert_eq!(task_of("how does the session cache work?"), TaskType::Explore);
    }

    #[test]
    fn classify_debug() {
        assert_eq!(
            task_of("debug why the compression ratio drops for large files"),
            TaskType::Debug
        );
    }

    #[test]
    fn classify_test() {
        assert_eq!(
            task_of("write unit tests for the token_optimizer module"),
            TaskType::Test
        );
    }

    #[test]
    fn targets_extract_paths() {
        let c = classify("fix entropy.rs and update core/mod.rs");
        assert!(has_target(&c, "entropy.rs"));
        assert!(has_target(&c, "core/mod.rs"));
    }

    #[test]
    fn targets_extract_identifiers() {
        let c = classify("refactor SessionCache to use LRU eviction");
        assert!(c.targets.iter().any(|t| t == "SessionCache"));
    }

    #[test]
    fn fallback_to_explore() {
        let c = classify("xyz qqq bbb");
        assert_eq!(c.task_type, TaskType::Explore);
        assert!(c.confidence < 0.5);
    }

    #[test]
    fn multi_intent_detection() {
        let intents = detect_multi_intent("fix the bug in auth.rs and then write unit tests");
        assert!(intents.len() >= 2);
        let first = &intents[0];
        let second = &intents[1];
        assert_eq!(first.task_type, TaskType::FixBug);
        assert_eq!(second.task_type, TaskType::Test);
    }

    #[test]
    fn single_intent_no_split() {
        let intents = detect_multi_intent("fix the bug in auth.rs");
        assert_eq!(intents.len(), 1);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
    }

    #[test]
    fn complexity_mechanical() {
        assert_eq!(complexity_of("add a comment"), TaskComplexity::Mechanical);
    }

    #[test]
    fn complexity_architectural() {
        assert_eq!(
            complexity_of("refactor auth across all files and update the migration"),
            TaskComplexity::Architectural
        );
    }
}