//! lean_ctx/core/intent_engine.rs: heuristic intent classification of free-text task
//! requests (task type, confidence, targets, keywords).

/// The broad kind of work a query asks for, as chosen by `classify`.
///
/// Each variant has a stable string tag (`TaskType::as_str`), a thinking budget and an
/// output-format hint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskType {
    /// Triggered by phrases such as "add", "create", "implement", "new feature".
    Generate,
    /// Triggered by phrases such as "fix", "bug", "broken", "not working".
    FixBug,
    /// Triggered by phrases such as "refactor", "rename", "extract", "split".
    Refactor,
    /// Triggered by phrases such as "how", "explain", "show me"; also the fallback when no
    /// rule matches.
    Explore,
    /// Triggered by phrases such as "test", "coverage", "mock", "unit test".
    Test,
    /// Triggered by phrases such as "debug", "trace", "breakpoint", "stack trace".
    Debug,
    /// Triggered by phrases such as "config", "install", "settings", "dotenv".
    Config,
    /// Triggered by phrases such as "deploy", "release", "ci/cd", "docker".
    Deploy,
    /// Triggered by phrases such as "review", "audit", "look at", "pr review".
    Review,
}

14impl TaskType {
15    pub fn as_str(&self) -> &'static str {
16        match self {
17            Self::Generate => "generate",
18            Self::FixBug => "fix_bug",
19            Self::Refactor => "refactor",
20            Self::Explore => "explore",
21            Self::Test => "test",
22            Self::Debug => "debug",
23            Self::Config => "config",
24            Self::Deploy => "deploy",
25            Self::Review => "review",
26        }
27    }
28
29    pub fn thinking_budget(&self) -> ThinkingBudget {
30        match self {
31            Self::Generate => ThinkingBudget::Minimal,
32            Self::FixBug => ThinkingBudget::Minimal,
33            Self::Refactor => ThinkingBudget::Medium,
34            Self::Explore => ThinkingBudget::Medium,
35            Self::Test => ThinkingBudget::Minimal,
36            Self::Debug => ThinkingBudget::Medium,
37            Self::Config => ThinkingBudget::Minimal,
38            Self::Deploy => ThinkingBudget::Minimal,
39            Self::Review => ThinkingBudget::Medium,
40        }
41    }
42
43    pub fn output_format(&self) -> OutputFormat {
44        match self {
45            Self::Generate => OutputFormat::CodeOnly,
46            Self::FixBug => OutputFormat::DiffOnly,
47            Self::Refactor => OutputFormat::DiffOnly,
48            Self::Explore => OutputFormat::ExplainConcise,
49            Self::Test => OutputFormat::CodeOnly,
50            Self::Debug => OutputFormat::Trace,
51            Self::Config => OutputFormat::CodeOnly,
52            Self::Deploy => OutputFormat::StepList,
53            Self::Review => OutputFormat::ExplainConcise,
54        }
55    }
56}
57
/// How much explicit reasoning the model is asked to do before answering.
///
/// Each budget has a fixed instruction line; see [`ThinkingBudget::instruction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingBudget {
    /// Skip analysis and produce code directly.
    Minimal,
    /// At most a 2-3 step analysis.
    Medium,
    /// A short root-cause trace of at most 3 steps.
    Trace,
    /// Structure/dependency analysis with a concise summary.
    Deep,
}

impl ThinkingBudget {
    /// The `THINKING:` instruction line for this budget.
    pub fn instruction(&self) -> &'static str {
        match *self {
            ThinkingBudget::Minimal => {
                "THINKING: Skip analysis. The task is clear — generate code directly."
            }
            ThinkingBudget::Medium => {
                "THINKING: 2-3 step analysis max. Identify what to change, then act. Do not over-analyze."
            }
            ThinkingBudget::Trace => {
                "THINKING: Short trace only. Identify root cause in 3 steps max, then generate fix."
            }
            ThinkingBudget::Deep => {
                "THINKING: Analyze structure and dependencies. Summarize findings concisely."
            }
        }
    }

    /// Returns `true` only for [`ThinkingBudget::Minimal`], the budget that asks the model
    /// to skip analysis.
    pub fn suppresses_thinking(&self) -> bool {
        *self == ThinkingBudget::Minimal
    }
}

/// Preferred shape of the model's answer, expressed as an `OUTPUT-HINT:` line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Mostly code blocks, minimal prose.
    CodeOnly,
    /// Only the changed lines, as `+`/`-` diffs.
    DiffOnly,
    /// A brief summary followed by code/data if relevant.
    ExplainConcise,
    /// A cause→effect chain with code references.
    Trace,
    /// A numbered action list.
    StepList,
}

impl OutputFormat {
    /// The `OUTPUT-HINT:` instruction line for this format.
    pub fn instruction(&self) -> &'static str {
        match *self {
            OutputFormat::CodeOnly => {
                "OUTPUT-HINT: Prefer code blocks. Minimize prose unless user asks for explanation."
            }
            OutputFormat::DiffOnly => {
                "OUTPUT-HINT: Prefer showing only changed lines as +/- diffs."
            }
            OutputFormat::ExplainConcise => {
                "OUTPUT-HINT: Brief summary, then code/data if relevant."
            }
            OutputFormat::Trace => "OUTPUT-HINT: Show cause→effect chain with code references.",
            OutputFormat::StepList => "OUTPUT-HINT: Numbered action list, one step at a time.",
        }
    }
}

104#[derive(Debug)]
105pub struct TaskClassification {
106    pub task_type: TaskType,
107    pub confidence: f64,
108    pub targets: Vec<String>,
109    pub keywords: Vec<String>,
110}
111
112const PHRASE_RULES: &[(&[&str], TaskType, f64)] = &[
113    (
114        &[
115            "add",
116            "create",
117            "implement",
118            "build",
119            "write",
120            "generate",
121            "make",
122            "new feature",
123            "new",
124        ],
125        TaskType::Generate,
126        0.9,
127    ),
128    (
129        &[
130            "fix",
131            "bug",
132            "broken",
133            "crash",
134            "error in",
135            "not working",
136            "fails",
137            "wrong output",
138        ],
139        TaskType::FixBug,
140        0.95,
141    ),
142    (
143        &[
144            "refactor",
145            "clean up",
146            "restructure",
147            "rename",
148            "move",
149            "extract",
150            "simplify",
151            "split",
152        ],
153        TaskType::Refactor,
154        0.9,
155    ),
156    (
157        &[
158            "how",
159            "what",
160            "where",
161            "explain",
162            "understand",
163            "show me",
164            "describe",
165            "why does",
166        ],
167        TaskType::Explore,
168        0.85,
169    ),
170    (
171        &[
172            "test",
173            "spec",
174            "coverage",
175            "assert",
176            "unit test",
177            "integration test",
178            "mock",
179        ],
180        TaskType::Test,
181        0.9,
182    ),
183    (
184        &[
185            "debug",
186            "trace",
187            "inspect",
188            "log",
189            "breakpoint",
190            "step through",
191            "stack trace",
192        ],
193        TaskType::Debug,
194        0.9,
195    ),
196    (
197        &[
198            "config",
199            "setup",
200            "install",
201            "env",
202            "configure",
203            "settings",
204            "dotenv",
205        ],
206        TaskType::Config,
207        0.85,
208    ),
209    (
210        &[
211            "deploy", "release", "publish", "ship", "ci/cd", "pipeline", "docker",
212        ],
213        TaskType::Deploy,
214        0.85,
215    ),
216    (
217        &[
218            "review",
219            "check",
220            "audit",
221            "look at",
222            "evaluate",
223            "assess",
224            "pr review",
225        ],
226        TaskType::Review,
227        0.8,
228    ),
229];
230
231pub fn classify(query: &str) -> TaskClassification {
232    let q = query.to_lowercase();
233    let words: Vec<&str> = q.split_whitespace().collect();
234
235    let mut best_type = TaskType::Explore;
236    let mut best_score = 0.0_f64;
237
238    for &(phrases, task_type, base_confidence) in PHRASE_RULES {
239        let mut match_count = 0usize;
240        for phrase in phrases {
241            if phrase.contains(' ') {
242                if q.contains(phrase) {
243                    match_count += 2;
244                }
245            } else if words.contains(phrase) {
246                match_count += 1;
247            }
248        }
249        if match_count > 0 {
250            let score = base_confidence * (match_count as f64).min(2.0) / 2.0;
251            if score > best_score {
252                best_score = score;
253                best_type = task_type;
254            }
255        }
256    }
257
258    let targets = extract_targets(query);
259    let keywords = extract_keywords(&q);
260
261    if best_score < 0.1 {
262        best_type = TaskType::Explore;
263        best_score = 0.3;
264    }
265
266    TaskClassification {
267        task_type: best_type,
268        confidence: best_score,
269        targets,
270        keywords,
271    }
272}
273
274fn extract_targets(query: &str) -> Vec<String> {
275    let mut targets = Vec::new();
276
277    for word in query.split_whitespace() {
278        if word.contains('.') && !word.starts_with('.') {
279            let clean = word.trim_matches(|c: char| {
280                !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
281            });
282            if looks_like_path(clean) {
283                targets.push(clean.to_string());
284            }
285        }
286        if word.contains('/') && !word.starts_with("//") && !word.starts_with("http") {
287            let clean = word.trim_matches(|c: char| {
288                !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
289            });
290            if clean.len() > 2 {
291                targets.push(clean.to_string());
292            }
293        }
294    }
295
296    for word in query.split_whitespace() {
297        let w = word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');
298        if w.contains('_') && w.len() > 3 && !targets.contains(&w.to_string()) {
299            targets.push(w.to_string());
300        }
301        if w.chars().any(|c| c.is_uppercase())
302            && w.len() > 2
303            && !is_stop_word(w)
304            && !targets.contains(&w.to_string())
305        {
306            targets.push(w.to_string());
307        }
308    }
309
310    targets.truncate(5);
311    targets
312}
313
/// Returns `true` if `s` contains a `/` or ends with one of a fixed set of source/config
/// file extensions (case-sensitive).
fn looks_like_path(s: &str) -> bool {
    const KNOWN_EXTENSIONS: [&str; 12] = [
        "rs", "ts", "tsx", "js", "jsx", "py", "go", "toml", "yaml", "yml", "json", "md",
    ];
    if s.contains('/') {
        return true;
    }
    // The text after the last '.' is the extension; no known extension contains a '.'.
    match s.rsplit_once('.') {
        Some((_, ext)) => KNOWN_EXTENSIONS.contains(&ext),
        None => false,
    }
}

/// Returns `true` if `w`, compared case-insensitively and without trimming punctuation, is
/// one of the filler words excluded from targets and keywords.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "with", "from", "into", "have", "please", "could", "would",
        "should", "also", "just", "then", "when", "what", "where", "which", "there", "here",
        "these", "those", "does", "will", "shall", "can", "may", "must", "need", "want",
        "like", "make", "take",
    ];
    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}

360fn extract_keywords(query: &str) -> Vec<String> {
361    query
362        .split_whitespace()
363        .filter(|w| w.len() > 3)
364        .filter(|w| !is_stop_word(w))
365        .map(|w| {
366            w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
367                .to_lowercase()
368        })
369        .filter(|w| !w.is_empty())
370        .take(8)
371        .collect()
372}
373
374pub fn format_briefing_header(classification: &TaskClassification) -> String {
375    format!(
376        "[TASK:{} CONF:{:.0}% TARGETS:{} KW:{}]",
377        classification.task_type.as_str(),
378        classification.confidence * 100.0,
379        if classification.targets.is_empty() {
380            "-".to_string()
381        } else {
382            classification.targets.join(",")
383        },
384        classification.keywords.join(","),
385    )
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// True when at least one extracted target satisfies `pred`.
    fn any_target(c: &TaskClassification, pred: impl Fn(&str) -> bool) -> bool {
        c.targets.iter().any(|t| pred(t.as_str()))
    }

    #[test]
    fn classify_fix_bug() {
        let c = classify("fix the bug in entropy.rs where token_entropy returns NaN");
        assert_eq!(c.task_type, TaskType::FixBug);
        assert!(c.confidence > 0.5);
        assert!(any_target(&c, |t| t.contains("entropy.rs")));
    }

    #[test]
    fn classify_generate() {
        let c = classify("add a new function normalized_token_entropy to entropy.rs");
        assert_eq!(c.task_type, TaskType::Generate);
        assert!(c.confidence > 0.5);
    }

    #[test]
    fn classify_refactor() {
        let kind =
            classify("refactor the compression pipeline to split into smaller modules").task_type;
        assert_eq!(kind, TaskType::Refactor);
    }

    #[test]
    fn classify_explore() {
        let kind = classify("how does the session cache work?").task_type;
        assert_eq!(kind, TaskType::Explore);
    }

    #[test]
    fn classify_debug() {
        let kind = classify("debug why the compression ratio drops for large files").task_type;
        assert_eq!(kind, TaskType::Debug);
    }

    #[test]
    fn classify_test() {
        let kind = classify("write unit tests for the token_optimizer module").task_type;
        assert_eq!(kind, TaskType::Test);
    }

    #[test]
    fn targets_extract_paths() {
        let c = classify("fix entropy.rs and update core/mod.rs");
        assert!(any_target(&c, |t| t.contains("entropy.rs")));
        assert!(any_target(&c, |t| t.contains("core/mod.rs")));
    }

    #[test]
    fn targets_extract_identifiers() {
        let c = classify("refactor SessionCache to use LRU eviction");
        assert!(any_target(&c, |t| t == "SessionCache"));
    }

    #[test]
    fn fallback_to_explore() {
        let c = classify("xyz qqq bbb");
        assert_eq!(c.task_type, TaskType::Explore);
        assert!(c.confidence < 0.5);
    }
}