Skip to main content

lean_ctx/core/
intent_engine.rs

1#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
2pub enum TaskType {
3    Generate,
4    FixBug,
5    Refactor,
6    Explore,
7    Test,
8    Debug,
9    Config,
10    Deploy,
11    Review,
12}
13
14impl TaskType {
15    pub fn as_str(&self) -> &'static str {
16        match self {
17            Self::Generate => "generate",
18            Self::FixBug => "fix_bug",
19            Self::Refactor => "refactor",
20            Self::Explore => "explore",
21            Self::Test => "test",
22            Self::Debug => "debug",
23            Self::Config => "config",
24            Self::Deploy => "deploy",
25            Self::Review => "review",
26        }
27    }
28
29    pub fn thinking_budget(&self) -> ThinkingBudget {
30        match self {
31            Self::Generate => ThinkingBudget::Minimal,
32            Self::FixBug => ThinkingBudget::Minimal,
33            Self::Refactor => ThinkingBudget::Medium,
34            Self::Explore => ThinkingBudget::Medium,
35            Self::Test => ThinkingBudget::Minimal,
36            Self::Debug => ThinkingBudget::Medium,
37            Self::Config => ThinkingBudget::Minimal,
38            Self::Deploy => ThinkingBudget::Minimal,
39            Self::Review => ThinkingBudget::Medium,
40        }
41    }
42
43    pub fn output_format(&self) -> OutputFormat {
44        match self {
45            Self::Generate => OutputFormat::CodeOnly,
46            Self::FixBug => OutputFormat::DiffOnly,
47            Self::Refactor => OutputFormat::DiffOnly,
48            Self::Explore => OutputFormat::ExplainConcise,
49            Self::Test => OutputFormat::CodeOnly,
50            Self::Debug => OutputFormat::Trace,
51            Self::Config => OutputFormat::CodeOnly,
52            Self::Deploy => OutputFormat::StepList,
53            Self::Review => OutputFormat::ExplainConcise,
54        }
55    }
56}
57
/// Amount of deliberate analysis requested before answering.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingBudget {
    /// No analysis; answer directly.
    Minimal,
    /// A short, bounded analysis (2-3 steps).
    Medium,
    /// A short root-cause trace (3 steps max) followed by a fix.
    Trace,
    /// Structural / dependency analysis with a concise summary.
    Deep,
}

impl ThinkingBudget {
    /// Returns the literal `THINKING: …` instruction line for this budget.
    pub fn instruction(&self) -> &'static str {
        let text = match *self {
            ThinkingBudget::Minimal => {
                "THINKING: Skip analysis. The task is clear — generate code directly."
            }
            ThinkingBudget::Medium => {
                "THINKING: 2-3 step analysis max. Identify what to change, then act. Do not over-analyze."
            }
            ThinkingBudget::Trace => {
                "THINKING: Short trace only. Identify root cause in 3 steps max, then generate fix."
            }
            ThinkingBudget::Deep => {
                "THINKING: Analyze structure and dependencies. Summarize findings concisely."
            }
        };
        text
    }

    /// Returns `true` only for [`ThinkingBudget::Minimal`].
    pub fn suppresses_thinking(&self) -> bool {
        *self == ThinkingBudget::Minimal
    }
}
80
/// Preferred shape of the answer for a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Code blocks with minimal prose.
    CodeOnly,
    /// Only the changed lines, as +/- diffs.
    DiffOnly,
    /// A brief summary, followed by code/data when relevant.
    ExplainConcise,
    /// A cause→effect chain with code references.
    Trace,
    /// A numbered list of actions.
    StepList,
}

impl OutputFormat {
    /// Returns the literal `OUTPUT-HINT: …` instruction line for this format.
    pub fn instruction(&self) -> &'static str {
        let hint = match *self {
            OutputFormat::StepList => "OUTPUT-HINT: Numbered action list, one step at a time.",
            OutputFormat::Trace => "OUTPUT-HINT: Show cause→effect chain with code references.",
            OutputFormat::ExplainConcise => {
                "OUTPUT-HINT: Brief summary, then code/data if relevant."
            }
            OutputFormat::DiffOnly => {
                "OUTPUT-HINT: Prefer showing only changed lines as +/- diffs."
            }
            OutputFormat::CodeOnly => {
                "OUTPUT-HINT: Prefer code blocks. Minimize prose unless user asks for explanation."
            }
        };
        hint
    }
}
103
104#[derive(Debug)]
105pub struct TaskClassification {
106    pub task_type: TaskType,
107    pub confidence: f64,
108    pub targets: Vec<String>,
109    pub keywords: Vec<String>,
110}
111
112const PHRASE_RULES: &[(&[&str], TaskType, f64)] = &[
113    (
114        &[
115            "add",
116            "create",
117            "implement",
118            "build",
119            "write",
120            "generate",
121            "make",
122            "new feature",
123            "new",
124        ],
125        TaskType::Generate,
126        0.9,
127    ),
128    (
129        &[
130            "fix",
131            "bug",
132            "broken",
133            "crash",
134            "error in",
135            "not working",
136            "fails",
137            "wrong output",
138        ],
139        TaskType::FixBug,
140        0.95,
141    ),
142    (
143        &[
144            "refactor",
145            "clean up",
146            "restructure",
147            "rename",
148            "move",
149            "extract",
150            "simplify",
151            "split",
152        ],
153        TaskType::Refactor,
154        0.9,
155    ),
156    (
157        &[
158            "how",
159            "what",
160            "where",
161            "explain",
162            "understand",
163            "show me",
164            "describe",
165            "why does",
166        ],
167        TaskType::Explore,
168        0.85,
169    ),
170    (
171        &[
172            "test",
173            "spec",
174            "coverage",
175            "assert",
176            "unit test",
177            "integration test",
178            "mock",
179        ],
180        TaskType::Test,
181        0.9,
182    ),
183    (
184        &[
185            "debug",
186            "trace",
187            "inspect",
188            "log",
189            "breakpoint",
190            "step through",
191            "stack trace",
192        ],
193        TaskType::Debug,
194        0.9,
195    ),
196    (
197        &[
198            "config",
199            "setup",
200            "install",
201            "env",
202            "configure",
203            "settings",
204            "dotenv",
205        ],
206        TaskType::Config,
207        0.85,
208    ),
209    (
210        &[
211            "deploy", "release", "publish", "ship", "ci/cd", "pipeline", "docker",
212        ],
213        TaskType::Deploy,
214        0.85,
215    ),
216    (
217        &[
218            "review",
219            "check",
220            "audit",
221            "look at",
222            "evaluate",
223            "assess",
224            "pr review",
225        ],
226        TaskType::Review,
227        0.8,
228    ),
229];
230
231pub fn classify(query: &str) -> TaskClassification {
232    let q = query.to_lowercase();
233    let words: Vec<&str> = q.split_whitespace().collect();
234
235    let mut best_type = TaskType::Explore;
236    let mut best_score = 0.0_f64;
237
238    for &(phrases, task_type, base_confidence) in PHRASE_RULES {
239        let mut match_count = 0usize;
240        for phrase in phrases {
241            if phrase.contains(' ') {
242                if q.contains(phrase) {
243                    match_count += 2;
244                }
245            } else if words.contains(phrase) {
246                match_count += 1;
247            }
248        }
249        if match_count > 0 {
250            let score = base_confidence * (match_count as f64).min(2.0) / 2.0;
251            if score > best_score {
252                best_score = score;
253                best_type = task_type;
254            }
255        }
256    }
257
258    let targets = extract_targets(query);
259    let keywords = extract_keywords(&q);
260
261    if best_score < 0.1 {
262        best_type = TaskType::Explore;
263        best_score = 0.3;
264    }
265
266    TaskClassification {
267        task_type: best_type,
268        confidence: best_score,
269        targets,
270        keywords,
271    }
272}
273
274fn extract_targets(query: &str) -> Vec<String> {
275    let mut targets = Vec::new();
276
277    for word in query.split_whitespace() {
278        if word.contains('.') && !word.starts_with('.') {
279            let clean = word.trim_matches(|c: char| {
280                !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
281            });
282            if looks_like_path(clean) {
283                targets.push(clean.to_string());
284            }
285        }
286        if word.contains('/') && !word.starts_with("//") && !word.starts_with("http") {
287            let clean = word.trim_matches(|c: char| {
288                !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
289            });
290            if clean.len() > 2 {
291                targets.push(clean.to_string());
292            }
293        }
294    }
295
296    for word in query.split_whitespace() {
297        let w = word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');
298        if w.contains('_') && w.len() > 3 && !targets.contains(&w.to_string()) {
299            targets.push(w.to_string());
300        }
301        if w.chars().any(|c| c.is_uppercase())
302            && w.len() > 2
303            && !is_stop_word(w)
304            && !targets.contains(&w.to_string())
305        {
306            targets.push(w.to_string());
307        }
308    }
309
310    targets.truncate(5);
311    targets
312}
313
/// Returns `true` when `s` ends with a recognised source/config file extension
/// or contains a `/`.
///
/// The extension list covers every language that `detect_language_hint` can
/// infer from a file extension, so a bare `main.rb` or `App.swift` is accepted.
fn looks_like_path(s: &str) -> bool {
    const EXTENSIONS: &[&str] = &[
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".rb", ".java", ".swift", ".zig",
        ".toml", ".yaml", ".yml", ".json", ".md",
    ];
    s.contains('/') || EXTENSIONS.iter().any(|ext| s.ends_with(ext))
}
320
/// Returns `true` when `w`, compared case-insensitively, is one of the common
/// filler words that should not be treated as a target or keyword.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "with", "from", "into", "have", "please", "could", "would",
        "should", "also", "just", "then", "when", "what", "where", "which", "there", "here",
        "these", "those", "does", "will", "shall", "can", "may", "must", "need", "want", "like",
        "make", "take",
    ];
    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
359
360fn extract_keywords(query: &str) -> Vec<String> {
361    query
362        .split_whitespace()
363        .filter(|w| w.len() > 3)
364        .filter(|w| !is_stop_word(w))
365        .map(|w| {
366            w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
367                .to_lowercase()
368        })
369        .filter(|w| !w.is_empty())
370        .take(8)
371        .collect()
372}
373
374pub fn classify_complexity(
375    query: &str,
376    classification: &TaskClassification,
377) -> super::adaptive::TaskComplexity {
378    use super::adaptive::TaskComplexity;
379
380    let q = query.to_lowercase();
381    let word_count = q.split_whitespace().count();
382    let target_count = classification.targets.len();
383
384    let has_multi_file = target_count >= 3;
385    let has_cross_cutting = q.contains("all files")
386        || q.contains("across")
387        || q.contains("everywhere")
388        || q.contains("every")
389        || q.contains("migration")
390        || q.contains("architecture");
391
392    let is_simple = word_count < 8
393        && target_count <= 1
394        && matches!(
395            classification.task_type,
396            TaskType::Generate | TaskType::Config
397        );
398
399    if is_simple {
400        TaskComplexity::Mechanical
401    } else if has_multi_file || has_cross_cutting {
402        TaskComplexity::Architectural
403    } else {
404        TaskComplexity::Standard
405    }
406}
407
408pub fn detect_multi_intent(query: &str) -> Vec<TaskClassification> {
409    let delimiters = [" and then ", " then ", " also ", " + ", ". "];
410
411    let mut parts: Vec<&str> = vec![query];
412    for delim in &delimiters {
413        let mut new_parts = Vec::new();
414        for part in &parts {
415            for sub in part.split(delim) {
416                let trimmed = sub.trim();
417                if !trimmed.is_empty() {
418                    new_parts.push(trimmed);
419                }
420            }
421        }
422        parts = new_parts;
423    }
424
425    if parts.len() <= 1 {
426        return vec![classify(query)];
427    }
428
429    parts.iter().map(|part| classify(part)).collect()
430}
431
432pub fn format_briefing_header(classification: &TaskClassification) -> String {
433    format!(
434        "[TASK:{} CONF:{:.0}% TARGETS:{} KW:{}]",
435        classification.task_type.as_str(),
436        classification.confidence * 100.0,
437        if classification.targets.is_empty() {
438            "-".to_string()
439        } else {
440            classification.targets.join(",")
441        },
442        classification.keywords.join(","),
443    )
444}
445
446#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
447pub enum IntentScope {
448    SingleFile,
449    MultiFile,
450    CrossModule,
451    ProjectWide,
452}
453
454#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
455pub struct StructuredIntent {
456    pub task_type: TaskType,
457    pub confidence: f64,
458    pub targets: Vec<String>,
459    pub keywords: Vec<String>,
460    pub scope: IntentScope,
461    pub language_hint: Option<String>,
462    pub urgency: f64,
463    pub action_verb: Option<String>,
464}
465
466impl StructuredIntent {
467    pub fn from_query(query: &str) -> Self {
468        let classification = classify(query);
469        let complexity = classify_complexity(query, &classification);
470        let file_targets = classification
471            .targets
472            .iter()
473            .filter(|t| t.contains('.') || t.contains('/'))
474            .count();
475        let scope = match complexity {
476            super::adaptive::TaskComplexity::Mechanical => IntentScope::SingleFile,
477            super::adaptive::TaskComplexity::Standard => {
478                if file_targets > 1 {
479                    IntentScope::MultiFile
480                } else {
481                    IntentScope::SingleFile
482                }
483            }
484            super::adaptive::TaskComplexity::Architectural => {
485                let q = query.to_lowercase();
486                if q.contains("all files") || q.contains("everywhere") || q.contains("migration") {
487                    IntentScope::ProjectWide
488                } else {
489                    IntentScope::CrossModule
490                }
491            }
492        };
493
494        let language_hint = detect_language_hint(query, &classification.targets);
495        let urgency = detect_urgency(query);
496        let action_verb = extract_action_verb(query);
497
498        StructuredIntent {
499            task_type: classification.task_type,
500            confidence: classification.confidence,
501            targets: classification.targets,
502            keywords: classification.keywords,
503            scope,
504            language_hint,
505            urgency,
506            action_verb,
507        }
508    }
509
510    pub fn from_file_patterns(touched_files: &[String]) -> Self {
511        if touched_files.is_empty() {
512            return Self {
513                task_type: TaskType::Explore,
514                confidence: 0.3,
515                targets: Vec::new(),
516                keywords: Vec::new(),
517                scope: IntentScope::SingleFile,
518                language_hint: None,
519                urgency: 0.0,
520                action_verb: None,
521            };
522        }
523
524        let has_tests = touched_files
525            .iter()
526            .any(|f| f.contains("test") || f.contains("spec"));
527        let has_config = touched_files.iter().any(|f| {
528            let lower = f.to_lowercase();
529            lower.ends_with(".toml")
530                || lower.ends_with(".yaml")
531                || lower.ends_with(".yml")
532                || lower.ends_with(".json")
533                || lower.contains("config")
534                || lower.contains(".env")
535        });
536
537        let dirs: std::collections::HashSet<&str> = touched_files
538            .iter()
539            .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
540            .collect();
541
542        let task_type = if has_tests && touched_files.len() <= 3 {
543            TaskType::Test
544        } else if has_config && touched_files.len() <= 2 {
545            TaskType::Config
546        } else if dirs.len() > 3 {
547            TaskType::Refactor
548        } else {
549            TaskType::Explore
550        };
551
552        let scope = match touched_files.len() {
553            1 => IntentScope::SingleFile,
554            2..=4 => IntentScope::MultiFile,
555            _ => IntentScope::CrossModule,
556        };
557
558        let language_hint = detect_language_from_files(touched_files);
559
560        Self {
561            task_type,
562            confidence: 0.5,
563            targets: touched_files.to_vec(),
564            keywords: Vec::new(),
565            scope,
566            language_hint,
567            urgency: 0.0,
568            action_verb: None,
569        }
570    }
571
572    pub fn from_query_with_session(query: &str, touched_files: &[String]) -> Self {
573        let mut intent = Self::from_query(query);
574
575        if intent.language_hint.is_none() && !touched_files.is_empty() {
576            intent.language_hint = detect_language_from_files(touched_files);
577        }
578
579        if intent.scope == IntentScope::SingleFile && touched_files.len() > 3 {
580            let dirs: std::collections::HashSet<&str> = touched_files
581                .iter()
582                .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
583                .collect();
584            if dirs.len() > 2 {
585                intent.scope = IntentScope::MultiFile;
586            }
587        }
588
589        intent
590    }
591
592    pub fn format_header(&self) -> String {
593        format!(
594            "[TASK:{} SCOPE:{} CONF:{:.0}%{}{}]",
595            self.task_type.as_str(),
596            match self.scope {
597                IntentScope::SingleFile => "single",
598                IntentScope::MultiFile => "multi",
599                IntentScope::CrossModule => "cross",
600                IntentScope::ProjectWide => "project",
601            },
602            self.confidence * 100.0,
603            self.language_hint
604                .as_ref()
605                .map(|l| format!(" LANG:{l}"))
606                .unwrap_or_default(),
607            if self.urgency > 0.5 { " URGENT" } else { "" },
608        )
609    }
610}
611
/// Infers the programming language a query is about.
///
/// First looks at `targets`: the first target whose file extension is known
/// decides. Otherwise the query is lowercased, split into alphanumeric tokens,
/// and checked for a language name as a *whole word* — so "trusted" no longer
/// reads as Rust, and a language named at the very start or end of the query
/// ("rewrite this in go") is still found.
///
/// Returns `None` when neither source yields a language.
fn detect_language_hint(query: &str, targets: &[String]) -> Option<String> {
    let from_targets = targets.iter().find_map(|t| {
        let ext = std::path::Path::new(t).extension()?.to_str()?;
        match ext {
            "rs" => Some("rust"),
            "ts" | "tsx" => Some("typescript"),
            "js" | "jsx" => Some("javascript"),
            "py" => Some("python"),
            "go" => Some("go"),
            "rb" => Some("ruby"),
            "java" => Some("java"),
            "swift" => Some("swift"),
            "zig" => Some("zig"),
            _ => None,
        }
    });
    if let Some(lang) = from_targets {
        return Some(lang.to_string());
    }

    // (word in query, language reported). Order is the priority when several
    // languages are mentioned; the ambiguous English word "go" comes late.
    const LANG_KEYWORDS: &[(&str, &str)] = &[
        ("rust", "rust"),
        ("python", "python"),
        ("typescript", "typescript"),
        ("javascript", "javascript"),
        ("golang", "go"),
        ("go", "go"),
        ("ruby", "ruby"),
        ("java", "java"),
        ("swift", "swift"),
        ("zig", "zig"),
    ];

    let q = query.to_lowercase();
    let tokens: Vec<&str> = q
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .collect();

    LANG_KEYWORDS
        .iter()
        .find(|entry| tokens.contains(&entry.0))
        .map(|entry| entry.1.to_string())
}
649
/// Returns the most common programming language among `files`, judged by file
/// extension, or `None` when no file has a recognised extension.
///
/// Ties are broken deterministically: the language that appears first in
/// `files` wins. The recognised extensions match those used by
/// `detect_language_hint`.
fn detect_language_from_files(files: &[String]) -> Option<String> {
    // Insertion-ordered tally; the handful of languages makes a Vec sufficient
    // and, unlike a HashMap, gives a stable order for tie-breaking.
    let mut counts: Vec<(&'static str, usize)> = Vec::new();
    for f in files {
        let lang = match std::path::Path::new(f).extension().and_then(|e| e.to_str()) {
            Some("rs") => "rust",
            Some("ts" | "tsx") => "typescript",
            Some("js" | "jsx") => "javascript",
            Some("py") => "python",
            Some("go") => "go",
            Some("rb") => "ruby",
            Some("java") => "java",
            Some("swift") => "swift",
            Some("zig") => "zig",
            _ => continue,
        };
        match counts.iter_mut().find(|(seen, _)| *seen == lang) {
            Some((_, n)) => *n += 1,
            None => counts.push((lang, 1)),
        }
    }

    // `max_by_key` keeps the LAST maximum, so scan in reverse to let the
    // first-seen language win a tie.
    counts
        .into_iter()
        .rev()
        .max_by_key(|&(_, n)| n)
        .map(|(lang, _)| lang.to_string())
}
674
/// Scores how urgent a query sounds, from `0.0` to `1.0`.
///
/// Each distinct urgency marker found as a substring of the lowercased query
/// adds `0.4`; the total is capped at `1.0`.
fn detect_urgency(query: &str) -> f64 {
    const URGENT_MARKERS: [&str; 8] = [
        "urgent",
        "asap",
        "immediately",
        "critical",
        "hotfix",
        "emergency",
        "blocker",
        "breaking",
    ];

    let lowered = query.to_lowercase();
    let mut hits = 0usize;
    for marker in URGENT_MARKERS.iter() {
        if lowered.contains(*marker) {
            hits += 1;
        }
    }
    (hits as f64 * 0.4).min(1.0)
}
690
/// Returns the first known action verb in `query`, lowercased, or `None`.
///
/// The query is lowercased and split on whitespace; surrounding punctuation is
/// stripped from each token so "Fix:" still counts as "fix". The verb that
/// appears *earliest in the query* wins, so "write test cases" yields "write"
/// regardless of the order of the verb table.
fn extract_action_verb(query: &str) -> Option<String> {
    const VERBS: &[&str] = &[
        "fix",
        "add",
        "create",
        "implement",
        "refactor",
        "debug",
        "test",
        "write",
        "update",
        "remove",
        "delete",
        "rename",
        "move",
        "extract",
        "split",
        "merge",
        "deploy",
        "review",
        "check",
        "build",
        "generate",
        "optimize",
        "clean",
    ];

    let q = query.to_lowercase();
    q.split_whitespace()
        .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
        .find(|w| VERBS.contains(w))
        .map(str::to_owned)
}
731
#[cfg(test)]
mod tests {
    use super::super::adaptive::TaskComplexity;
    use super::*;

    /// Classifies `query`, asserts the resulting task type, and hands back the
    /// classification for further checks.
    fn expect_task(query: &str, expected: TaskType) -> TaskClassification {
        let classification = classify(query);
        assert_eq!(classification.task_type, expected);
        classification
    }

    /// True when any extracted target contains `needle`.
    fn mentions_target(classification: &TaskClassification, needle: &str) -> bool {
        classification.targets.iter().any(|t| t.contains(needle))
    }

    #[test]
    fn classify_fix_bug() {
        let c = expect_task(
            "fix the bug in entropy.rs where token_entropy returns NaN",
            TaskType::FixBug,
        );
        assert!(c.confidence > 0.5);
        assert!(mentions_target(&c, "entropy.rs"));
    }

    #[test]
    fn classify_generate() {
        let c = expect_task(
            "add a new function normalized_token_entropy to entropy.rs",
            TaskType::Generate,
        );
        assert!(c.confidence > 0.5);
    }

    #[test]
    fn classify_refactor() {
        expect_task(
            "refactor the compression pipeline to split into smaller modules",
            TaskType::Refactor,
        );
    }

    #[test]
    fn classify_explore() {
        expect_task("how does the session cache work?", TaskType::Explore);
    }

    #[test]
    fn classify_debug() {
        expect_task(
            "debug why the compression ratio drops for large files",
            TaskType::Debug,
        );
    }

    #[test]
    fn classify_test() {
        expect_task(
            "write unit tests for the token_optimizer module",
            TaskType::Test,
        );
    }

    #[test]
    fn targets_extract_paths() {
        let c = classify("fix entropy.rs and update core/mod.rs");
        assert!(mentions_target(&c, "entropy.rs"));
        assert!(mentions_target(&c, "core/mod.rs"));
    }

    #[test]
    fn targets_extract_identifiers() {
        let c = classify("refactor SessionCache to use LRU eviction");
        assert!(c.targets.iter().any(|t| t == "SessionCache"));
    }

    #[test]
    fn fallback_to_explore() {
        // No keyword matches: expect the low-confidence Explore fallback.
        let c = expect_task("xyz qqq bbb", TaskType::Explore);
        assert!(c.confidence < 0.5);
    }

    #[test]
    fn multi_intent_detection() {
        let intents = detect_multi_intent("fix the bug in auth.rs and then write unit tests");
        assert!(intents.len() >= 2);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
        assert_eq!(intents[1].task_type, TaskType::Test);
    }

    #[test]
    fn single_intent_no_split() {
        let intents = detect_multi_intent("fix the bug in auth.rs");
        assert_eq!(intents.len(), 1);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
    }

    #[test]
    fn complexity_mechanical() {
        let query = "add a comment";
        let c = classify(query);
        assert_eq!(classify_complexity(query, &c), TaskComplexity::Mechanical);
    }

    #[test]
    fn complexity_architectural() {
        let query = "refactor auth across all files and update the migration";
        let c = classify(query);
        assert_eq!(
            classify_complexity(query, &c),
            TaskComplexity::Architectural
        );
    }
}