1#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
2pub enum TaskType {
3 Generate,
4 FixBug,
5 Refactor,
6 Explore,
7 Test,
8 Debug,
9 Config,
10 Deploy,
11 Review,
12}
13
14impl TaskType {
15 pub fn as_str(&self) -> &'static str {
16 match self {
17 Self::Generate => "generate",
18 Self::FixBug => "fix_bug",
19 Self::Refactor => "refactor",
20 Self::Explore => "explore",
21 Self::Test => "test",
22 Self::Debug => "debug",
23 Self::Config => "config",
24 Self::Deploy => "deploy",
25 Self::Review => "review",
26 }
27 }
28
29 pub fn thinking_budget(&self) -> ThinkingBudget {
30 match self {
31 Self::Generate | Self::FixBug | Self::Test | Self::Config | Self::Deploy => {
32 ThinkingBudget::Minimal
33 }
34 Self::Refactor | Self::Explore | Self::Debug | Self::Review => ThinkingBudget::Medium,
35 }
36 }
37
38 pub fn output_format(&self) -> OutputFormat {
39 match self {
40 Self::Generate | Self::Test | Self::Config => OutputFormat::CodeOnly,
41 Self::FixBug | Self::Refactor => OutputFormat::DiffOnly,
42 Self::Explore | Self::Review => OutputFormat::ExplainConcise,
43 Self::Debug => OutputFormat::Trace,
44 Self::Deploy => OutputFormat::StepList,
45 }
46 }
47}
48
/// Amount of up-front analysis requested before answering.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingBudget {
    Minimal,
    Medium,
    Trace,
    Deep,
}

impl ThinkingBudget {
    /// Returns the `THINKING:` directive text associated with this budget.
    pub fn instruction(&self) -> &'static str {
        match *self {
            ThinkingBudget::Minimal => {
                "THINKING: Skip analysis. The task is clear — generate code directly."
            }
            ThinkingBudget::Medium => {
                "THINKING: 2-3 step analysis max. Identify what to change, then act. Do not over-analyze."
            }
            ThinkingBudget::Trace => {
                "THINKING: Short trace only. Identify root cause in 3 steps max, then generate fix."
            }
            ThinkingBudget::Deep => {
                "THINKING: Analyze structure and dependencies. Summarize findings concisely."
            }
        }
    }

    /// Returns `true` only for [`ThinkingBudget::Minimal`].
    pub fn suppresses_thinking(&self) -> bool {
        *self == ThinkingBudget::Minimal
    }
}
71
/// Preferred shape of the response for a task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    CodeOnly,
    DiffOnly,
    ExplainConcise,
    Trace,
    StepList,
}

impl OutputFormat {
    /// Returns the `OUTPUT-HINT:` directive text associated with this format.
    pub fn instruction(&self) -> &'static str {
        match *self {
            OutputFormat::CodeOnly => {
                "OUTPUT-HINT: Prefer code blocks. Minimize prose unless user asks for explanation."
            }
            OutputFormat::DiffOnly => {
                "OUTPUT-HINT: Prefer showing only changed lines as +/- diffs."
            }
            OutputFormat::ExplainConcise => {
                "OUTPUT-HINT: Brief summary, then code/data if relevant."
            }
            OutputFormat::Trace => "OUTPUT-HINT: Show cause→effect chain with code references.",
            OutputFormat::StepList => "OUTPUT-HINT: Numbered action list, one step at a time.",
        }
    }
}
94
95#[derive(Debug)]
96pub struct TaskClassification {
97 pub task_type: TaskType,
98 pub confidence: f64,
99 pub targets: Vec<String>,
100 pub keywords: Vec<String>,
101}
102
103const PHRASE_RULES: &[(&[&str], TaskType, f64)] = &[
104 (
105 &[
106 "add",
107 "create",
108 "implement",
109 "build",
110 "write",
111 "generate",
112 "make",
113 "new feature",
114 "new",
115 ],
116 TaskType::Generate,
117 0.9,
118 ),
119 (
120 &[
121 "fix",
122 "bug",
123 "broken",
124 "crash",
125 "error in",
126 "not working",
127 "fails",
128 "wrong output",
129 ],
130 TaskType::FixBug,
131 0.95,
132 ),
133 (
134 &[
135 "refactor",
136 "clean up",
137 "restructure",
138 "rename",
139 "move",
140 "extract",
141 "simplify",
142 "split",
143 ],
144 TaskType::Refactor,
145 0.9,
146 ),
147 (
148 &[
149 "how",
150 "what",
151 "where",
152 "explain",
153 "understand",
154 "show me",
155 "describe",
156 "why does",
157 ],
158 TaskType::Explore,
159 0.85,
160 ),
161 (
162 &[
163 "test",
164 "spec",
165 "coverage",
166 "assert",
167 "unit test",
168 "integration test",
169 "mock",
170 ],
171 TaskType::Test,
172 0.9,
173 ),
174 (
175 &[
176 "debug",
177 "trace",
178 "inspect",
179 "log",
180 "breakpoint",
181 "step through",
182 "stack trace",
183 ],
184 TaskType::Debug,
185 0.9,
186 ),
187 (
188 &[
189 "config",
190 "setup",
191 "install",
192 "env",
193 "configure",
194 "settings",
195 "dotenv",
196 ],
197 TaskType::Config,
198 0.85,
199 ),
200 (
201 &[
202 "deploy", "release", "publish", "ship", "ci/cd", "pipeline", "docker",
203 ],
204 TaskType::Deploy,
205 0.85,
206 ),
207 (
208 &[
209 "review",
210 "check",
211 "audit",
212 "look at",
213 "evaluate",
214 "assess",
215 "pr review",
216 ],
217 TaskType::Review,
218 0.8,
219 ),
220];
221
222pub fn classify(query: &str) -> TaskClassification {
223 let q = query.to_lowercase();
224 let words: Vec<&str> = q.split_whitespace().collect();
225
226 let mut best_type = TaskType::Explore;
227 let mut best_score = 0.0_f64;
228
229 for &(phrases, task_type, base_confidence) in PHRASE_RULES {
230 let mut match_count = 0usize;
231 for phrase in phrases {
232 if phrase.contains(' ') {
233 if q.contains(phrase) {
234 match_count += 2;
235 }
236 } else if words.contains(phrase) {
237 match_count += 1;
238 }
239 }
240 if match_count > 0 {
241 let score = base_confidence * (match_count as f64).min(2.0) / 2.0;
242 if score > best_score {
243 best_score = score;
244 best_type = task_type;
245 }
246 }
247 }
248
249 let targets = extract_targets(query);
250 let keywords = extract_keywords(&q);
251
252 if best_score < 0.1 {
253 best_type = TaskType::Explore;
254 best_score = 0.3;
255 }
256
257 TaskClassification {
258 task_type: best_type,
259 confidence: best_score,
260 targets,
261 keywords,
262 }
263}
264
/// Extracts up to five distinct targets (file paths and identifiers) from `query`.
///
/// Two passes are made over the whitespace-separated tokens:
///
/// 1. Path-like tokens: a token containing `.` (and not starting with `.`) whose cleaned
///    form passes [`looks_like_path`], or a token containing `/` that does not start
///    with `//` or `http` and is longer than two bytes once cleaned.
/// 2. Identifier-like tokens: tokens containing `_` (longer than three bytes) or any
///    uppercase character (longer than two bytes and not a stop word).
///
/// Each target is recorded once, in order of first appearance. A token such as
/// `core/mod.rs` satisfies both path rules but is still reported a single time, and a
/// sentence-ending period (`"fix entropy.rs."`) does not hide the file name.
fn extract_targets(query: &str) -> Vec<String> {
    /// Appends `candidate` unless it is already present.
    fn push_unique(targets: &mut Vec<String>, candidate: &str) {
        if !targets.iter().any(|t| t == candidate) {
            targets.push(candidate.to_string());
        }
    }

    let mut targets: Vec<String> = Vec::new();

    // Pass 1: path-like tokens.
    for word in query.split_whitespace() {
        let clean = word
            .trim_matches(|c: char| !c.is_alphanumeric() && !matches!(c, '.' | '/' | '_' | '-'))
            // A trailing '.' is sentence punctuation, not part of the path.
            .trim_end_matches('.');

        let dotted_path = word.contains('.') && !word.starts_with('.') && looks_like_path(clean);
        let slashed_path = word.contains('/')
            && !word.starts_with("//")
            && !word.starts_with("http")
            && clean.len() > 2;

        if dotted_path || slashed_path {
            push_unique(&mut targets, clean);
        }
    }

    // Pass 2: snake_case and capitalised identifiers.
    for word in query.split_whitespace() {
        let w = word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');

        let snake_case = w.contains('_') && w.len() > 3;
        let capitalised = w.chars().any(char::is_uppercase) && w.len() > 2 && !is_stop_word(w);

        if snake_case || capitalised {
            push_unique(&mut targets, w);
        }
    }

    targets.truncate(5);
    targets
}

/// Returns `true` when `s` contains a `/` or ends with one of the known source/config
/// file extensions.
fn looks_like_path(s: &str) -> bool {
    const KNOWN_EXTENSIONS: &[&str] = &[
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".toml", ".yaml", ".yml", ".json",
        ".md",
    ];
    s.contains('/') || KNOWN_EXTENSIONS.iter().any(|ext| s.ends_with(*ext))
}

/// Returns `true` when `w`, compared case-insensitively, is one of the filler words that
/// should never be reported as a target or keyword.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "with", "from", "into", "have", "please", "could", "would",
        "should", "also", "just", "then", "when", "what", "where", "which", "there", "here",
        "these", "those", "does", "will", "shall", "can", "may", "must", "need", "want", "like",
        "make", "take",
    ];
    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
350
351fn extract_keywords(query: &str) -> Vec<String> {
352 query
353 .split_whitespace()
354 .filter(|w| w.len() > 3)
355 .filter(|w| !is_stop_word(w))
356 .map(|w| {
357 w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
358 .to_lowercase()
359 })
360 .filter(|w| !w.is_empty())
361 .take(8)
362 .collect()
363}
364
365pub fn classify_complexity(
366 query: &str,
367 classification: &TaskClassification,
368) -> super::adaptive::TaskComplexity {
369 use super::adaptive::TaskComplexity;
370
371 let q = query.to_lowercase();
372 let word_count = q.split_whitespace().count();
373 let target_count = classification.targets.len();
374
375 let has_multi_file = target_count >= 3;
376 let has_cross_cutting = q.contains("all files")
377 || q.contains("across")
378 || q.contains("everywhere")
379 || q.contains("every")
380 || q.contains("migration")
381 || q.contains("architecture");
382
383 let is_simple = word_count < 8
384 && target_count <= 1
385 && matches!(
386 classification.task_type,
387 TaskType::Generate | TaskType::Config
388 );
389
390 if is_simple {
391 TaskComplexity::Mechanical
392 } else if has_multi_file || has_cross_cutting {
393 TaskComplexity::Architectural
394 } else {
395 TaskComplexity::Standard
396 }
397}
398
399pub fn detect_multi_intent(query: &str) -> Vec<TaskClassification> {
400 let delimiters = [" and then ", " then ", " also ", " + ", ". "];
401
402 let mut parts: Vec<&str> = vec![query];
403 for delim in &delimiters {
404 let mut new_parts = Vec::new();
405 for part in &parts {
406 for sub in part.split(delim) {
407 let trimmed = sub.trim();
408 if !trimmed.is_empty() {
409 new_parts.push(trimmed);
410 }
411 }
412 }
413 parts = new_parts;
414 }
415
416 if parts.len() <= 1 {
417 return vec![classify(query)];
418 }
419
420 parts.iter().map(|part| classify(part)).collect()
421}
422
423pub fn format_briefing_header(classification: &TaskClassification) -> String {
424 format!(
425 "[TASK:{} CONF:{:.0}% TARGETS:{} KW:{}]",
426 classification.task_type.as_str(),
427 classification.confidence * 100.0,
428 if classification.targets.is_empty() {
429 "-".to_string()
430 } else {
431 classification.targets.join(",")
432 },
433 classification.keywords.join(","),
434 )
435}
436
/// How far across the codebase a request is expected to reach.
///
/// `StructuredIntent::format_header` renders the variants as `single`, `multi`, `cross`
/// and `project` respectively.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum IntentScope {
    /// One file; also the default when nothing suggests a wider reach.
    SingleFile,
    /// A handful of files (two to four touched files in `from_file_patterns`).
    MultiFile,
    /// Several modules (five or more touched files in `from_file_patterns`, or an
    /// architectural query without a project-wide marker in `from_query`).
    CrossModule,
    /// The whole project (architectural query mentioning "all files", "everywhere" or
    /// "migration" in `from_query`).
    ProjectWide,
}
444
445#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
446pub struct StructuredIntent {
447 pub task_type: TaskType,
448 pub confidence: f64,
449 pub targets: Vec<String>,
450 pub keywords: Vec<String>,
451 pub scope: IntentScope,
452 pub language_hint: Option<String>,
453 pub urgency: f64,
454 pub action_verb: Option<String>,
455}
456
457impl StructuredIntent {
458 pub fn from_query(query: &str) -> Self {
459 let classification = classify(query);
460 let complexity = classify_complexity(query, &classification);
461 let file_targets = classification
462 .targets
463 .iter()
464 .filter(|t| t.contains('.') || t.contains('/'))
465 .count();
466 let scope = match complexity {
467 super::adaptive::TaskComplexity::Mechanical => IntentScope::SingleFile,
468 super::adaptive::TaskComplexity::Standard => {
469 if file_targets > 1 {
470 IntentScope::MultiFile
471 } else {
472 IntentScope::SingleFile
473 }
474 }
475 super::adaptive::TaskComplexity::Architectural => {
476 let q = query.to_lowercase();
477 if q.contains("all files") || q.contains("everywhere") || q.contains("migration") {
478 IntentScope::ProjectWide
479 } else {
480 IntentScope::CrossModule
481 }
482 }
483 };
484
485 let language_hint = detect_language_hint(query, &classification.targets);
486 let urgency = detect_urgency(query);
487 let action_verb = extract_action_verb(query);
488
489 StructuredIntent {
490 task_type: classification.task_type,
491 confidence: classification.confidence,
492 targets: classification.targets,
493 keywords: classification.keywords,
494 scope,
495 language_hint,
496 urgency,
497 action_verb,
498 }
499 }
500
501 pub fn from_file_patterns(touched_files: &[String]) -> Self {
502 if touched_files.is_empty() {
503 return Self {
504 task_type: TaskType::Explore,
505 confidence: 0.3,
506 targets: Vec::new(),
507 keywords: Vec::new(),
508 scope: IntentScope::SingleFile,
509 language_hint: None,
510 urgency: 0.0,
511 action_verb: None,
512 };
513 }
514
515 let has_tests = touched_files
516 .iter()
517 .any(|f| f.contains("test") || f.contains("spec"));
518 let has_config = touched_files.iter().any(|f| {
519 let p = std::path::Path::new(f.as_str());
520 let is_config_ext = p.extension().is_some_and(|e| {
521 e.eq_ignore_ascii_case("toml")
522 || e.eq_ignore_ascii_case("yaml")
523 || e.eq_ignore_ascii_case("yml")
524 || e.eq_ignore_ascii_case("json")
525 });
526 is_config_ext || f.contains("config") || f.contains(".env")
527 });
528
529 let dirs: std::collections::HashSet<&str> = touched_files
530 .iter()
531 .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
532 .collect();
533
534 let task_type = if has_tests && touched_files.len() <= 3 {
535 TaskType::Test
536 } else if has_config && touched_files.len() <= 2 {
537 TaskType::Config
538 } else if dirs.len() > 3 {
539 TaskType::Refactor
540 } else {
541 TaskType::Explore
542 };
543
544 let scope = match touched_files.len() {
545 1 => IntentScope::SingleFile,
546 2..=4 => IntentScope::MultiFile,
547 _ => IntentScope::CrossModule,
548 };
549
550 let language_hint = detect_language_from_files(touched_files);
551
552 Self {
553 task_type,
554 confidence: 0.5,
555 targets: touched_files.to_vec(),
556 keywords: Vec::new(),
557 scope,
558 language_hint,
559 urgency: 0.0,
560 action_verb: None,
561 }
562 }
563
564 pub fn from_query_with_session(query: &str, touched_files: &[String]) -> Self {
565 let mut intent = Self::from_query(query);
566
567 if intent.language_hint.is_none() && !touched_files.is_empty() {
568 intent.language_hint = detect_language_from_files(touched_files);
569 }
570
571 if intent.scope == IntentScope::SingleFile && touched_files.len() > 3 {
572 let dirs: std::collections::HashSet<&str> = touched_files
573 .iter()
574 .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
575 .collect();
576 if dirs.len() > 2 {
577 intent.scope = IntentScope::MultiFile;
578 }
579 }
580
581 intent
582 }
583
584 pub fn format_header(&self) -> String {
585 format!(
586 "[TASK:{} SCOPE:{} CONF:{:.0}%{}{}]",
587 self.task_type.as_str(),
588 match self.scope {
589 IntentScope::SingleFile => "single",
590 IntentScope::MultiFile => "multi",
591 IntentScope::CrossModule => "cross",
592 IntentScope::ProjectWide => "project",
593 },
594 self.confidence * 100.0,
595 self.language_hint
596 .as_ref()
597 .map(|l| format!(" LANG:{l}"))
598 .unwrap_or_default(),
599 if self.urgency > 0.5 { " URGENT" } else { "" },
600 )
601 }
602}
603
/// Infers the programming language a request is about.
///
/// The first target with a recognised file extension decides. Failing that, the query is
/// lowercased, split into alphanumeric words, and checked against a list of language
/// names in priority order (so a query naming both Python and Rust yields `rust`).
///
/// Names are matched as whole words: "trust" does not imply Rust, "javascript" does not
/// imply Java, and a name at the very start or end of the query ("rewrite it in java")
/// is recognised.
///
/// Returns `None` when neither the targets nor the wording name a known language.
fn detect_language_hint(query: &str, targets: &[String]) -> Option<String> {
    // (word in the query, language label), in priority order.
    const LANG_KEYWORDS: &[(&str, &str)] = &[
        ("rust", "rust"),
        ("python", "python"),
        ("typescript", "typescript"),
        ("javascript", "javascript"),
        ("golang", "go"),
        ("go", "go"),
        ("ruby", "ruby"),
        ("java", "java"),
        ("swift", "swift"),
        ("zig", "zig"),
    ];

    // An explicit file extension is the strongest signal.
    let by_extension = targets.iter().find_map(|t| {
        let ext = std::path::Path::new(t).extension()?.to_str()?;
        match ext {
            "rs" => Some("rust"),
            "ts" | "tsx" => Some("typescript"),
            "js" | "jsx" => Some("javascript"),
            "py" => Some("python"),
            "go" => Some("go"),
            "rb" => Some("ruby"),
            "java" => Some("java"),
            "swift" => Some("swift"),
            "zig" => Some("zig"),
            _ => None,
        }
    });
    if let Some(lang) = by_extension {
        return Some(lang.to_string());
    }

    let q = query.to_lowercase();
    let words: Vec<&str> = q
        .split(|c: char| !c.is_alphanumeric())
        .filter(|w| !w.is_empty())
        .collect();

    LANG_KEYWORDS
        .iter()
        .find(|(keyword, _)| words.contains(keyword))
        .map(|(_, lang)| lang.to_string())
}
641
/// Returns the language most represented among `files`, judged by file extension.
///
/// Files with an unrecognised or missing extension are ignored; `None` is returned when
/// no file has a recognised extension. When several languages are tied, the one whose
/// first file appears earliest in `files` wins, so the result is the same on every call.
fn detect_language_from_files(files: &[String]) -> Option<String> {
    // Tallies in first-seen order. The set of languages is tiny, so a linear scan is
    // enough and keeps the order stable (a hash map would not).
    let mut tallies: Vec<(&'static str, usize)> = Vec::new();

    for f in files {
        let Some(ext) = std::path::Path::new(f).extension().and_then(|e| e.to_str()) else {
            continue;
        };
        let lang = match ext {
            "rs" => "rust",
            "ts" | "tsx" => "typescript",
            "js" | "jsx" => "javascript",
            "py" => "python",
            "go" => "go",
            "rb" => "ruby",
            "java" => "java",
            "swift" => "swift",
            "zig" => "zig",
            _ => continue,
        };
        match tallies.iter_mut().find(|(seen, _)| *seen == lang) {
            Some((_, count)) => *count += 1,
            None => tallies.push((lang, 1)),
        }
    }

    // `max_by_key` keeps the LAST maximum, so scan in reverse to let the first-seen
    // language win ties.
    tallies
        .into_iter()
        .rev()
        .max_by_key(|&(_, count)| count)
        .map(|(lang, _)| lang.to_string())
}
666
/// Scores how urgent a query sounds.
///
/// Each distinct urgency marker found in the lowercased query (as a substring) adds
/// `0.4`; the total is capped at `1.0`.
fn detect_urgency(query: &str) -> f64 {
    const URGENCY_MARKERS: &[&str] = &[
        "urgent",
        "asap",
        "immediately",
        "critical",
        "hotfix",
        "emergency",
        "blocker",
        "breaking",
    ];

    let lowered = query.to_lowercase();
    let mut hits = 0usize;
    for marker in URGENCY_MARKERS {
        if lowered.contains(*marker) {
            hits += 1;
        }
    }

    let raw_score = hits as f64 * 0.4;
    raw_score.min(1.0)
}
682
/// Returns the first known action verb in `query`, lowercased.
///
/// Words are taken in query order, so the verb the user wrote first wins: "write test
/// for the parser" yields `write`, not `test`. Punctuation around a word is ignored, so
/// "Fix: crash" yields `fix`. Returns `None` when no known verb is present.
fn extract_action_verb(query: &str) -> Option<String> {
    const VERBS: &[&str] = &[
        "fix",
        "add",
        "create",
        "implement",
        "refactor",
        "debug",
        "test",
        "write",
        "update",
        "remove",
        "delete",
        "rename",
        "move",
        "extract",
        "split",
        "merge",
        "deploy",
        "review",
        "check",
        "build",
        "generate",
        "optimize",
        "clean",
    ];

    let lowered = query.to_lowercase();
    lowered
        .split_whitespace()
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        .find(|word| VERBS.contains(word))
        .map(str::to_string)
}
723
#[cfg(test)]
mod tests {
    use super::super::adaptive::TaskComplexity;
    use super::*;

    /// True when any extracted target contains `needle`.
    fn mentions_target(classification: &TaskClassification, needle: &str) -> bool {
        classification.targets.iter().any(|t| t.contains(needle))
    }

    #[test]
    fn classify_fix_bug() {
        let outcome = classify("fix the bug in entropy.rs where token_entropy returns NaN");
        assert_eq!(outcome.task_type, TaskType::FixBug);
        assert!(outcome.confidence > 0.5);
        assert!(mentions_target(&outcome, "entropy.rs"));
    }

    #[test]
    fn classify_generate() {
        let outcome = classify("add a new function normalized_token_entropy to entropy.rs");
        assert_eq!(outcome.task_type, TaskType::Generate);
        assert!(outcome.confidence > 0.5);
    }

    #[test]
    fn classify_refactor() {
        let outcome = classify("refactor the compression pipeline to split into smaller modules");
        assert_eq!(outcome.task_type, TaskType::Refactor);
    }

    #[test]
    fn classify_explore() {
        let outcome = classify("how does the session cache work?");
        assert_eq!(outcome.task_type, TaskType::Explore);
    }

    #[test]
    fn classify_debug() {
        let outcome = classify("debug why the compression ratio drops for large files");
        assert_eq!(outcome.task_type, TaskType::Debug);
    }

    #[test]
    fn classify_test() {
        let outcome = classify("write unit tests for the token_optimizer module");
        assert_eq!(outcome.task_type, TaskType::Test);
    }

    #[test]
    fn targets_extract_paths() {
        let outcome = classify("fix entropy.rs and update core/mod.rs");
        assert!(mentions_target(&outcome, "entropy.rs"));
        assert!(mentions_target(&outcome, "core/mod.rs"));
    }

    #[test]
    fn targets_extract_identifiers() {
        let outcome = classify("refactor SessionCache to use LRU eviction");
        assert!(outcome.targets.iter().any(|t| t == "SessionCache"));
    }

    #[test]
    fn fallback_to_explore() {
        let outcome = classify("xyz qqq bbb");
        assert_eq!(outcome.task_type, TaskType::Explore);
        assert!(outcome.confidence < 0.5);
    }

    #[test]
    fn multi_intent_detection() {
        let intents = detect_multi_intent("fix the bug in auth.rs and then write unit tests");
        assert!(intents.len() >= 2);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
        assert_eq!(intents[1].task_type, TaskType::Test);
    }

    #[test]
    fn single_intent_no_split() {
        let intents = detect_multi_intent("fix the bug in auth.rs");
        assert_eq!(intents.len(), 1);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
    }

    #[test]
    fn complexity_mechanical() {
        let query = "add a comment";
        let classification = classify(query);
        assert_eq!(
            classify_complexity(query, &classification),
            TaskComplexity::Mechanical
        );
    }

    #[test]
    fn complexity_architectural() {
        let query = "refactor auth across all files and update the migration";
        let classification = classify(query);
        assert_eq!(
            classify_complexity(query, &classification),
            TaskComplexity::Architectural
        );
    }
}