Skip to main content

lean_ctx/core/
intent_engine.rs

1#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
2#[serde(rename_all = "snake_case")]
3pub enum TaskType {
4    Generate,
5    FixBug,
6    Refactor,
7    Explore,
8    Test,
9    Debug,
10    Config,
11    Deploy,
12    Review,
13}
14
15impl TaskType {
16    pub fn as_str(&self) -> &'static str {
17        match self {
18            Self::Generate => "generate",
19            Self::FixBug => "fix_bug",
20            Self::Refactor => "refactor",
21            Self::Explore => "explore",
22            Self::Test => "test",
23            Self::Debug => "debug",
24            Self::Config => "config",
25            Self::Deploy => "deploy",
26            Self::Review => "review",
27        }
28    }
29
30    pub fn thinking_budget(&self) -> ThinkingBudget {
31        match self {
32            Self::Generate => ThinkingBudget::Minimal,
33            Self::FixBug => ThinkingBudget::Minimal,
34            Self::Refactor => ThinkingBudget::Medium,
35            Self::Explore => ThinkingBudget::Medium,
36            Self::Test => ThinkingBudget::Minimal,
37            Self::Debug => ThinkingBudget::Medium,
38            Self::Config => ThinkingBudget::Minimal,
39            Self::Deploy => ThinkingBudget::Minimal,
40            Self::Review => ThinkingBudget::Medium,
41        }
42    }
43
44    pub fn output_format(&self) -> OutputFormat {
45        match self {
46            Self::Generate => OutputFormat::CodeOnly,
47            Self::FixBug => OutputFormat::DiffOnly,
48            Self::Refactor => OutputFormat::DiffOnly,
49            Self::Explore => OutputFormat::ExplainConcise,
50            Self::Test => OutputFormat::CodeOnly,
51            Self::Debug => OutputFormat::Trace,
52            Self::Config => OutputFormat::CodeOnly,
53            Self::Deploy => OutputFormat::StepList,
54            Self::Review => OutputFormat::ExplainConcise,
55        }
56    }
57}
58
/// Amount of up-front reasoning requested before an answer is produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingBudget {
    /// Skip analysis and act directly.
    Minimal,
    /// A short, bounded analysis.
    Medium,
    /// A short root-cause trace.
    Trace,
    /// A structural / dependency analysis.
    Deep,
}

impl ThinkingBudget {
    /// Returns the `THINKING:` instruction line for this budget.
    pub fn instruction(&self) -> &'static str {
        let text = match *self {
            ThinkingBudget::Minimal => {
                "THINKING: Skip analysis. The task is clear — generate code directly."
            }
            ThinkingBudget::Medium => {
                "THINKING: 2-3 step analysis max. Identify what to change, then act. Do not over-analyze."
            }
            ThinkingBudget::Trace => {
                "THINKING: Short trace only. Identify root cause in 3 steps max, then generate fix."
            }
            ThinkingBudget::Deep => {
                "THINKING: Analyze structure and dependencies. Summarize findings concisely."
            }
        };
        text
    }

    /// Returns `true` only for [`ThinkingBudget::Minimal`].
    pub fn suppresses_thinking(&self) -> bool {
        *self == ThinkingBudget::Minimal
    }
}
81
/// Preferred shape of the response for a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Mostly code blocks, little prose.
    CodeOnly,
    /// Only the changed lines, as a diff.
    DiffOnly,
    /// A brief summary followed by code or data.
    ExplainConcise,
    /// A cause-to-effect chain.
    Trace,
    /// A numbered list of actions.
    StepList,
}

impl OutputFormat {
    /// Returns the `OUTPUT-HINT:` instruction line for this format.
    pub fn instruction(&self) -> &'static str {
        match *self {
            OutputFormat::StepList => "OUTPUT-HINT: Numbered action list, one step at a time.",
            OutputFormat::Trace => "OUTPUT-HINT: Show cause→effect chain with code references.",
            OutputFormat::ExplainConcise => {
                "OUTPUT-HINT: Brief summary, then code/data if relevant."
            }
            OutputFormat::DiffOnly => {
                "OUTPUT-HINT: Prefer showing only changed lines as +/- diffs."
            }
            OutputFormat::CodeOnly => {
                "OUTPUT-HINT: Prefer code blocks. Minimize prose unless user asks for explanation."
            }
        }
    }
}
104
105#[derive(Debug)]
106pub struct TaskClassification {
107    pub task_type: TaskType,
108    pub confidence: f64,
109    pub targets: Vec<String>,
110    pub keywords: Vec<String>,
111}
112
113const PHRASE_RULES: &[(&[&str], TaskType, f64)] = &[
114    (
115        &[
116            "add",
117            "create",
118            "implement",
119            "build",
120            "write",
121            "generate",
122            "make",
123            "new feature",
124            "new",
125        ],
126        TaskType::Generate,
127        0.9,
128    ),
129    (
130        &[
131            "fix",
132            "bug",
133            "broken",
134            "crash",
135            "error in",
136            "not working",
137            "fails",
138            "wrong output",
139        ],
140        TaskType::FixBug,
141        0.95,
142    ),
143    (
144        &[
145            "refactor",
146            "clean up",
147            "restructure",
148            "rename",
149            "move",
150            "extract",
151            "simplify",
152            "split",
153        ],
154        TaskType::Refactor,
155        0.9,
156    ),
157    (
158        &[
159            "how",
160            "what",
161            "where",
162            "explain",
163            "understand",
164            "show me",
165            "describe",
166            "why does",
167        ],
168        TaskType::Explore,
169        0.85,
170    ),
171    (
172        &[
173            "test",
174            "spec",
175            "coverage",
176            "assert",
177            "unit test",
178            "integration test",
179            "mock",
180        ],
181        TaskType::Test,
182        0.9,
183    ),
184    (
185        &[
186            "debug",
187            "trace",
188            "inspect",
189            "log",
190            "breakpoint",
191            "step through",
192            "stack trace",
193        ],
194        TaskType::Debug,
195        0.9,
196    ),
197    (
198        &[
199            "config",
200            "setup",
201            "install",
202            "env",
203            "configure",
204            "settings",
205            "dotenv",
206        ],
207        TaskType::Config,
208        0.85,
209    ),
210    (
211        &[
212            "deploy", "release", "publish", "ship", "ci/cd", "pipeline", "docker",
213        ],
214        TaskType::Deploy,
215        0.85,
216    ),
217    (
218        &[
219            "review",
220            "check",
221            "audit",
222            "look at",
223            "evaluate",
224            "assess",
225            "pr review",
226        ],
227        TaskType::Review,
228        0.8,
229    ),
230];
231
232pub fn classify(query: &str) -> TaskClassification {
233    let q = query.to_lowercase();
234    let words: Vec<&str> = q.split_whitespace().collect();
235
236    let mut best_type = TaskType::Explore;
237    let mut best_score = 0.0_f64;
238
239    for &(phrases, task_type, base_confidence) in PHRASE_RULES {
240        let mut match_count = 0usize;
241        for phrase in phrases {
242            if phrase.contains(' ') {
243                if q.contains(phrase) {
244                    match_count += 2;
245                }
246            } else if words.contains(phrase) {
247                match_count += 1;
248            }
249        }
250        if match_count > 0 {
251            let score = base_confidence * (match_count as f64).min(2.0) / 2.0;
252            if score > best_score {
253                best_score = score;
254                best_type = task_type;
255            }
256        }
257    }
258
259    let targets = extract_targets(query);
260    let keywords = extract_keywords(&q);
261
262    if best_score < 0.1 {
263        best_type = TaskType::Explore;
264        best_score = 0.3;
265    }
266
267    TaskClassification {
268        task_type: best_type,
269        confidence: best_score,
270        targets,
271        keywords,
272    }
273}
274
/// Extracts up to five likely edit targets from `query`, without duplicates.
///
/// Two passes are made over the whitespace-separated tokens:
/// 1. path-like tokens: tokens containing `.` (not starting with `.`) that
///    pass [`looks_like_path`], and tokens containing `/` (not starting with
///    `//` or `http`) longer than two characters after cleaning;
/// 2. identifier-like tokens: tokens containing `_` (longer than three
///    characters), and tokens with an uppercase letter (longer than two
///    characters) that are not stop words.
///
/// Surrounding punctuation is stripped first, including a sentence-final
/// period, so `"main.rs."` yields `"main.rs"`.
fn extract_targets(query: &str) -> Vec<String> {
    /// Appends `candidate` unless an equal entry is already present.
    fn push_unique(targets: &mut Vec<String>, candidate: &str) {
        if !targets.iter().any(|t| t == candidate) {
            targets.push(candidate.to_string());
        }
    }

    /// Characters allowed to remain at the edges of a path token.
    fn is_path_char(c: char) -> bool {
        c.is_alphanumeric() || matches!(c, '.' | '/' | '_' | '-')
    }

    let mut targets = Vec::new();

    for word in query.split_whitespace() {
        let has_dot = word.contains('.') && !word.starts_with('.');
        let has_slash =
            word.contains('/') && !word.starts_with("//") && !word.starts_with("http");
        if !has_dot && !has_slash {
            continue;
        }

        // A trailing '.' is sentence punctuation, never part of the file name.
        let clean = word
            .trim_start_matches(|c: char| !is_path_char(c))
            .trim_end_matches(|c: char| !is_path_char(c) || c == '.');

        // A token such as "core/mod.rs" satisfies both conditions; it must
        // still be recorded only once.
        if (has_dot && looks_like_path(clean)) || (has_slash && clean.len() > 2) {
            push_unique(&mut targets, clean);
        }
    }

    for word in query.split_whitespace() {
        let w = word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');
        let is_snake_like = w.contains('_') && w.len() > 3;
        let is_cased_name =
            w.len() > 2 && w.chars().any(char::is_uppercase) && !is_stop_word(w);
        if is_snake_like || is_cased_name {
            push_unique(&mut targets, w);
        }
    }

    targets.truncate(5);
    targets
}

/// Returns `true` when `s` ends with a known source/config file extension or
/// contains a `/`.
fn looks_like_path(s: &str) -> bool {
    const KNOWN_EXTENSIONS: &[&str] = &[
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".toml", ".yaml", ".yml", ".json", ".md",
    ];
    s.contains('/') || KNOWN_EXTENSIONS.iter().any(|ext| s.ends_with(*ext))
}

/// Returns `true` when `w`, compared case-insensitively, is a common filler
/// word that should not be treated as a target or keyword.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "with", "from", "into", "have", "please", "could", "would",
        "should", "also", "just", "then", "when", "what", "where", "which", "there", "here",
        "these", "those", "does", "will", "shall", "can", "may", "must", "need", "want", "like",
        "make", "take",
    ];
    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
360
361fn extract_keywords(query: &str) -> Vec<String> {
362    query
363        .split_whitespace()
364        .filter(|w| w.len() > 3)
365        .filter(|w| !is_stop_word(w))
366        .map(|w| {
367            w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
368                .to_lowercase()
369        })
370        .filter(|w| !w.is_empty())
371        .take(8)
372        .collect()
373}
374
375pub fn classify_complexity(
376    query: &str,
377    classification: &TaskClassification,
378) -> super::adaptive::TaskComplexity {
379    use super::adaptive::TaskComplexity;
380
381    let q = query.to_lowercase();
382    let word_count = q.split_whitespace().count();
383    let target_count = classification.targets.len();
384
385    let has_multi_file = target_count >= 3;
386    let has_cross_cutting = q.contains("all files")
387        || q.contains("across")
388        || q.contains("everywhere")
389        || q.contains("every")
390        || q.contains("migration")
391        || q.contains("architecture");
392
393    let is_simple = word_count < 8
394        && target_count <= 1
395        && matches!(
396            classification.task_type,
397            TaskType::Generate | TaskType::Config
398        );
399
400    if is_simple {
401        TaskComplexity::Mechanical
402    } else if has_multi_file || has_cross_cutting {
403        TaskComplexity::Architectural
404    } else {
405        TaskComplexity::Standard
406    }
407}
408
409pub fn detect_multi_intent(query: &str) -> Vec<TaskClassification> {
410    let delimiters = [" and then ", " then ", " also ", " + ", ". "];
411
412    let mut parts: Vec<&str> = vec![query];
413    for delim in &delimiters {
414        let mut new_parts = Vec::new();
415        for part in &parts {
416            for sub in part.split(delim) {
417                let trimmed = sub.trim();
418                if !trimmed.is_empty() {
419                    new_parts.push(trimmed);
420                }
421            }
422        }
423        parts = new_parts;
424    }
425
426    if parts.len() <= 1 {
427        return vec![classify(query)];
428    }
429
430    parts.iter().map(|part| classify(part)).collect()
431}
432
433pub fn format_briefing_header(classification: &TaskClassification) -> String {
434    format!(
435        "[TASK:{} CONF:{:.0}% TARGETS:{} KW:{}]",
436        classification.task_type.as_str(),
437        classification.confidence * 100.0,
438        if classification.targets.is_empty() {
439            "-".to_string()
440        } else {
441            classification.targets.join(",")
442        },
443        classification.keywords.join(","),
444    )
445}
446
447#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
448#[serde(rename_all = "snake_case")]
449pub enum IntentScope {
450    SingleFile,
451    MultiFile,
452    CrossModule,
453    ProjectWide,
454}
455
456#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
457pub struct StructuredIntent {
458    pub task_type: TaskType,
459    pub confidence: f64,
460    pub targets: Vec<String>,
461    pub keywords: Vec<String>,
462    pub scope: IntentScope,
463    pub language_hint: Option<String>,
464    pub urgency: f64,
465    pub action_verb: Option<String>,
466}
467
468impl StructuredIntent {
469    pub fn from_query(query: &str) -> Self {
470        let classification = classify(query);
471        let complexity = classify_complexity(query, &classification);
472        let file_targets = classification
473            .targets
474            .iter()
475            .filter(|t| t.contains('.') || t.contains('/'))
476            .count();
477        let scope = match complexity {
478            super::adaptive::TaskComplexity::Mechanical => IntentScope::SingleFile,
479            super::adaptive::TaskComplexity::Standard => {
480                if file_targets > 1 {
481                    IntentScope::MultiFile
482                } else {
483                    IntentScope::SingleFile
484                }
485            }
486            super::adaptive::TaskComplexity::Architectural => {
487                let q = query.to_lowercase();
488                if q.contains("all files") || q.contains("everywhere") || q.contains("migration") {
489                    IntentScope::ProjectWide
490                } else {
491                    IntentScope::CrossModule
492                }
493            }
494        };
495
496        let language_hint = detect_language_hint(query, &classification.targets);
497        let urgency = detect_urgency(query);
498        let action_verb = extract_action_verb(query);
499
500        StructuredIntent {
501            task_type: classification.task_type,
502            confidence: classification.confidence,
503            targets: classification.targets,
504            keywords: classification.keywords,
505            scope,
506            language_hint,
507            urgency,
508            action_verb,
509        }
510    }
511
512    pub fn from_query_with_session(query: &str, touched_files: &[String]) -> Self {
513        let mut intent = Self::from_query(query);
514
515        if intent.language_hint.is_none() && !touched_files.is_empty() {
516            intent.language_hint = detect_language_from_files(touched_files);
517        }
518
519        if intent.scope == IntentScope::SingleFile && touched_files.len() > 3 {
520            let dirs: std::collections::HashSet<&str> = touched_files
521                .iter()
522                .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
523                .collect();
524            if dirs.len() > 2 {
525                intent.scope = IntentScope::MultiFile;
526            }
527        }
528
529        intent
530    }
531
532    pub fn format_header(&self) -> String {
533        format!(
534            "[TASK:{} SCOPE:{} CONF:{:.0}%{}{}]",
535            self.task_type.as_str(),
536            match self.scope {
537                IntentScope::SingleFile => "single",
538                IntentScope::MultiFile => "multi",
539                IntentScope::CrossModule => "cross",
540                IntentScope::ProjectWide => "project",
541            },
542            self.confidence * 100.0,
543            self.language_hint
544                .as_ref()
545                .map(|l| format!(" LANG:{l}"))
546                .unwrap_or_default(),
547            if self.urgency > 0.5 { " URGENT" } else { "" },
548        )
549    }
550}
551
/// Infers the programming language a query is about.
///
/// The first target whose file extension is recognised decides the language.
/// Otherwise the query text is searched, case-insensitively, for a language
/// name as a whole word; languages are tried in the order rust, python,
/// typescript, javascript, go ("golang" or "go"), ruby, java, swift.
///
/// Words are runs of alphabetic characters, so a name is recognised at the
/// end of the query or next to punctuation/digits ("in go", "Java.",
/// "python3"), while words that merely contain a name ("trust", "swiftly")
/// are not mistaken for it. Note that the English verb "go" is still
/// indistinguishable from the language name.
///
/// Returns `None` when neither source yields a language.
fn detect_language_hint(query: &str, targets: &[String]) -> Option<String> {
    const LANG_KEYWORDS: &[(&str, &str)] = &[
        ("rust", "rust"),
        ("python", "python"),
        ("typescript", "typescript"),
        ("javascript", "javascript"),
        ("golang", "go"),
        ("go", "go"),
        ("ruby", "ruby"),
        ("java", "java"),
        ("swift", "swift"),
    ];

    let from_extension = targets.iter().find_map(|target| {
        let ext = std::path::Path::new(target)
            .extension()
            .and_then(|e| e.to_str())?;
        match ext {
            "rs" => Some("rust"),
            "ts" | "tsx" => Some("typescript"),
            "js" | "jsx" => Some("javascript"),
            "py" => Some("python"),
            "go" => Some("go"),
            "rb" => Some("ruby"),
            "java" => Some("java"),
            "swift" => Some("swift"),
            "zig" => Some("zig"),
            _ => None,
        }
    });
    if let Some(lang) = from_extension {
        return Some(lang.to_string());
    }

    let lowered = query.to_lowercase();
    let words: Vec<&str> = lowered
        .split(|c: char| !c.is_alphabetic())
        .filter(|w| !w.is_empty())
        .collect();

    LANG_KEYWORDS
        .iter()
        .find(|(keyword, _)| words.contains(keyword))
        .map(|(_, lang)| (*lang).to_string())
}
589
/// Returns the language that most of `files` are written in, judged by file
/// extension.
///
/// Files with an unrecognised or missing extension are ignored. When several
/// languages are equally frequent, the one that appears first in `files`
/// wins, so the result is deterministic. Returns `None` when no file has a
/// recognised extension.
fn detect_language_from_files(files: &[String]) -> Option<String> {
    // Kept in first-seen order so that ties can be broken deterministically.
    let mut counts: Vec<(&'static str, usize)> = Vec::new();

    for file in files {
        let ext = std::path::Path::new(file)
            .extension()
            .and_then(|e| e.to_str());
        let lang = match ext {
            Some("rs") => "rust",
            Some("ts" | "tsx") => "typescript",
            Some("js" | "jsx") => "javascript",
            Some("py") => "python",
            Some("go") => "go",
            Some("rb") => "ruby",
            Some("java") => "java",
            Some("swift") => "swift",
            Some("zig") => "zig",
            _ => continue,
        };

        if let Some(idx) = counts.iter().position(|&(seen, _)| seen == lang) {
            counts[idx].1 += 1;
        } else {
            counts.push((lang, 1));
        }
    }

    // `max_by_key` returns the last maximum, so iterate in reverse to let the
    // earliest-seen language win a tie.
    counts
        .into_iter()
        .rev()
        .max_by_key(|&(_, count)| count)
        .map(|(lang, _)| lang.to_string())
}
614
/// Scores how urgent a query sounds, from `0.0` to `1.0`.
///
/// Each distinct urgency marker found (case-insensitively, as a substring)
/// contributes `0.4`; the total is capped at `1.0`.
fn detect_urgency(query: &str) -> f64 {
    const URGENCY_MARKERS: [&str; 8] = [
        "urgent",
        "asap",
        "immediately",
        "critical",
        "hotfix",
        "emergency",
        "blocker",
        "breaking",
    ];

    let lowered = query.to_lowercase();
    let mut hits = 0usize;
    for marker in URGENCY_MARKERS.iter() {
        if lowered.contains(*marker) {
            hits += 1;
        }
    }

    let raw_score = hits as f64 * 0.4;
    raw_score.min(1.0)
}
630
/// Returns the first known action verb that appears in `query`.
///
/// The query is lowercased and scanned word by word in query order, with
/// surrounding punctuation stripped from each word, so the verb closest to
/// the start of the query wins ("write test for parser" yields "write") and
/// tokens such as "Fix:" are still recognised. Returns `None` when the query
/// contains no known verb.
fn extract_action_verb(query: &str) -> Option<String> {
    const VERBS: &[&str] = &[
        "fix",
        "add",
        "create",
        "implement",
        "refactor",
        "debug",
        "test",
        "write",
        "update",
        "remove",
        "delete",
        "rename",
        "move",
        "extract",
        "split",
        "merge",
        "deploy",
        "review",
        "check",
        "build",
        "generate",
        "optimize",
        "clean",
    ];

    let lowered = query.to_lowercase();
    lowered
        .split_whitespace()
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        .find(|word| VERBS.contains(word))
        .map(|verb| verb.to_string())
}
671
#[cfg(test)]
mod tests {
    use super::super::adaptive::TaskComplexity;
    use super::*;

    /// True when any extracted target contains `needle`.
    fn mentions(targets: &[String], needle: &str) -> bool {
        targets.iter().any(|t| t.contains(needle))
    }

    // --- classify ---------------------------------------------------------

    #[test]
    fn classify_fix_bug() {
        let result = classify("fix the bug in entropy.rs where token_entropy returns NaN");
        assert_eq!(result.task_type, TaskType::FixBug);
        assert!(result.confidence > 0.5);
        assert!(mentions(&result.targets, "entropy.rs"));
    }

    #[test]
    fn classify_generate() {
        let result = classify("add a new function normalized_token_entropy to entropy.rs");
        assert_eq!(result.task_type, TaskType::Generate);
        assert!(result.confidence > 0.5);
    }

    #[test]
    fn classify_refactor() {
        let result = classify("refactor the compression pipeline to split into smaller modules");
        assert_eq!(result.task_type, TaskType::Refactor);
    }

    #[test]
    fn classify_explore() {
        let result = classify("how does the session cache work?");
        assert_eq!(result.task_type, TaskType::Explore);
    }

    #[test]
    fn classify_debug() {
        let result = classify("debug why the compression ratio drops for large files");
        assert_eq!(result.task_type, TaskType::Debug);
    }

    #[test]
    fn classify_test() {
        let result = classify("write unit tests for the token_optimizer module");
        assert_eq!(result.task_type, TaskType::Test);
    }

    #[test]
    fn targets_extract_paths() {
        let result = classify("fix entropy.rs and update core/mod.rs");
        assert!(mentions(&result.targets, "entropy.rs"));
        assert!(mentions(&result.targets, "core/mod.rs"));
    }

    #[test]
    fn targets_extract_identifiers() {
        let result = classify("refactor SessionCache to use LRU eviction");
        assert!(result.targets.iter().any(|t| t == "SessionCache"));
    }

    #[test]
    fn fallback_to_explore() {
        let result = classify("xyz qqq bbb");
        assert_eq!(result.task_type, TaskType::Explore);
        assert!(result.confidence < 0.5);
    }

    // --- multi-intent -----------------------------------------------------

    #[test]
    fn multi_intent_detection() {
        let intents = detect_multi_intent("fix the bug in auth.rs and then write unit tests");
        assert!(intents.len() >= 2);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
        assert_eq!(intents[1].task_type, TaskType::Test);
    }

    #[test]
    fn single_intent_no_split() {
        let intents = detect_multi_intent("fix the bug in auth.rs");
        assert_eq!(intents.len(), 1);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
    }

    // --- complexity -------------------------------------------------------

    #[test]
    fn complexity_mechanical() {
        let query = "add a comment";
        let classification = classify(query);
        assert_eq!(
            classify_complexity(query, &classification),
            TaskComplexity::Mechanical
        );
    }

    #[test]
    fn complexity_architectural() {
        let query = "refactor auth across all files and update the migration";
        let classification = classify(query);
        assert_eq!(
            classify_complexity(query, &classification),
            TaskComplexity::Architectural
        );
    }

    // --- structured intent --------------------------------------------------

    #[test]
    fn structured_intent_from_fixbug_query() {
        let intent = StructuredIntent::from_query("fix the NaN bug in entropy.rs");
        assert_eq!(intent.task_type, TaskType::FixBug);
        assert!(mentions(&intent.targets, "entropy.rs"));
        assert_eq!(intent.language_hint.as_deref(), Some("rust"));
        assert_eq!(intent.action_verb.as_deref(), Some("fix"));
        assert_eq!(intent.scope, IntentScope::SingleFile);
    }

    #[test]
    fn structured_intent_project_wide() {
        let intent =
            StructuredIntent::from_query("refactor auth across all files and update migration");
        assert_eq!(intent.task_type, TaskType::Refactor);
        assert_eq!(intent.scope, IntentScope::ProjectWide);
    }

    #[test]
    fn structured_intent_urgency() {
        let normal = StructuredIntent::from_query("add a new function");
        let urgent = StructuredIntent::from_query("urgent hotfix for critical auth bug");
        assert!(urgent.urgency > normal.urgency);
        assert!(urgent.urgency >= 0.8);
    }

    #[test]
    fn structured_intent_language_from_targets() {
        let intent = StructuredIntent::from_query("fix main.py auth handler");
        assert_eq!(intent.language_hint.as_deref(), Some("python"));
    }

    #[test]
    fn structured_intent_with_session() {
        let files: Vec<String> = [
            "src/core/session.rs",
            "src/core/litm.rs",
            "src/tools/ctx_read.rs",
        ]
        .iter()
        .map(|f| f.to_string())
        .collect();
        let intent = StructuredIntent::from_query_with_session("how does this work?", &files);
        assert_eq!(intent.language_hint.as_deref(), Some("rust"));
    }

    #[test]
    fn structured_intent_header_format() {
        let header = StructuredIntent::from_query("fix bug in entropy.rs").format_header();
        for expected in ["TASK:fix_bug", "SCOPE:", "LANG:rust"].iter() {
            assert!(header.contains(*expected));
        }
    }

    // --- helpers ------------------------------------------------------------

    #[test]
    fn detect_urgency_none() {
        assert_eq!(detect_urgency("add a function"), 0.0);
    }

    #[test]
    fn detect_urgency_multiple() {
        let urgency = detect_urgency("urgent critical blocker");
        assert!(urgency > 0.9);
    }

    #[test]
    fn action_verb_extraction() {
        assert_eq!(extract_action_verb("fix the bug").as_deref(), Some("fix"));
        assert_eq!(
            extract_action_verb("please add logging").as_deref(),
            Some("add")
        );
        assert_eq!(extract_action_verb("xyz qqq"), None);
    }
}
839}