/// The kind of work a user query is asking for.
///
/// Each variant maps to a thinking budget and an output format via the
/// `impl TaskType` methods below, which downstream prompt assembly uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum TaskType {
    Generate,
    FixBug,
    Refactor,
    Explore,
    Test,
    Debug,
    Config,
    Deploy,
    Review,
}
13
14impl TaskType {
15 pub fn as_str(&self) -> &'static str {
16 match self {
17 Self::Generate => "generate",
18 Self::FixBug => "fix_bug",
19 Self::Refactor => "refactor",
20 Self::Explore => "explore",
21 Self::Test => "test",
22 Self::Debug => "debug",
23 Self::Config => "config",
24 Self::Deploy => "deploy",
25 Self::Review => "review",
26 }
27 }
28
29 pub fn thinking_budget(&self) -> ThinkingBudget {
30 match self {
31 Self::Generate => ThinkingBudget::Minimal,
32 Self::FixBug => ThinkingBudget::Minimal,
33 Self::Refactor => ThinkingBudget::Medium,
34 Self::Explore => ThinkingBudget::Medium,
35 Self::Test => ThinkingBudget::Minimal,
36 Self::Debug => ThinkingBudget::Medium,
37 Self::Config => ThinkingBudget::Minimal,
38 Self::Deploy => ThinkingBudget::Minimal,
39 Self::Review => ThinkingBudget::Medium,
40 }
41 }
42
43 pub fn output_format(&self) -> OutputFormat {
44 match self {
45 Self::Generate => OutputFormat::CodeOnly,
46 Self::FixBug => OutputFormat::DiffOnly,
47 Self::Refactor => OutputFormat::DiffOnly,
48 Self::Explore => OutputFormat::ExplainConcise,
49 Self::Test => OutputFormat::CodeOnly,
50 Self::Debug => OutputFormat::Trace,
51 Self::Config => OutputFormat::CodeOnly,
52 Self::Deploy => OutputFormat::StepList,
53 Self::Review => OutputFormat::ExplainConcise,
54 }
55 }
56}
57
/// How much analysis the model should perform before producing output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingBudget {
    /// No analysis phase at all; go straight to output.
    Minimal,
    /// A couple of analysis steps, then act.
    Medium,
    /// A short causal trace before producing a fix.
    Trace,
    /// Structural/dependency analysis with a concise summary.
    Deep,
}

impl ThinkingBudget {
    /// Prompt fragment instructing the model how much to think.
    pub fn instruction(&self) -> &'static str {
        match *self {
            ThinkingBudget::Minimal => {
                "THINKING: Skip analysis. The task is clear — generate code directly."
            }
            ThinkingBudget::Medium => {
                "THINKING: 2-3 step analysis max. Identify what to change, then act. Do not over-analyze."
            }
            ThinkingBudget::Trace => {
                "THINKING: Short trace only. Identify root cause in 3 steps max, then generate fix."
            }
            ThinkingBudget::Deep => {
                "THINKING: Analyze structure and dependencies. Summarize findings concisely."
            }
        }
    }

    /// True when this budget disables the thinking phase entirely.
    pub fn suppresses_thinking(&self) -> bool {
        *self == Self::Minimal
    }
}
80
/// Preferred shape of the model's response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Mostly code, minimal prose.
    CodeOnly,
    /// Only changed lines, as a +/- diff.
    DiffOnly,
    /// Short explanation first, then code/data.
    ExplainConcise,
    /// Cause-and-effect chain with code references.
    Trace,
    /// Numbered action steps.
    StepList,
}

impl OutputFormat {
    /// Prompt fragment describing the preferred response shape.
    pub fn instruction(&self) -> &'static str {
        match *self {
            OutputFormat::CodeOnly => {
                "OUTPUT-HINT: Prefer code blocks. Minimize prose unless user asks for explanation."
            }
            OutputFormat::DiffOnly => {
                "OUTPUT-HINT: Prefer showing only changed lines as +/- diffs."
            }
            OutputFormat::ExplainConcise => {
                "OUTPUT-HINT: Brief summary, then code/data if relevant."
            }
            OutputFormat::Trace => "OUTPUT-HINT: Show cause→effect chain with code references.",
            OutputFormat::StepList => "OUTPUT-HINT: Numbered action list, one step at a time.",
        }
    }
}
103
/// Result of classifying a query: the inferred task type plus the evidence
/// extracted from the query text.
#[derive(Debug)]
pub struct TaskClassification {
    // Best-matching task category; falls back to `Explore` when nothing matches.
    pub task_type: TaskType,
    // Match strength in [0, 1], derived from the winning rule's base confidence.
    pub confidence: f64,
    // Likely files/identifiers mentioned in the query (at most 5).
    pub targets: Vec<String>,
    // Content keywords from the query (at most 8, lowercased).
    pub keywords: Vec<String>,
}
111
/// Keyword rules for classification: (trigger phrases, task type, base confidence).
///
/// Multi-word phrases are matched as substrings of the lowercased query and
/// weighted double; single words are matched against whitespace-separated
/// tokens (see `classify`). Base confidence is the score awarded when a rule
/// matches at full strength (two or more match points).
const PHRASE_RULES: &[(&[&str], TaskType, f64)] = &[
    (
        &[
            "add",
            "create",
            "implement",
            "build",
            "write",
            "generate",
            "make",
            "new feature",
            "new",
        ],
        TaskType::Generate,
        0.9,
    ),
    (
        &[
            "fix",
            "bug",
            "broken",
            "crash",
            "error in",
            "not working",
            "fails",
            "wrong output",
        ],
        TaskType::FixBug,
        0.95,
    ),
    (
        &[
            "refactor",
            "clean up",
            "restructure",
            "rename",
            "move",
            "extract",
            "simplify",
            "split",
        ],
        TaskType::Refactor,
        0.9,
    ),
    (
        &[
            "how",
            "what",
            "where",
            "explain",
            "understand",
            "show me",
            "describe",
            "why does",
        ],
        TaskType::Explore,
        0.85,
    ),
    (
        &[
            "test",
            "spec",
            "coverage",
            "assert",
            "unit test",
            "integration test",
            "mock",
        ],
        TaskType::Test,
        0.9,
    ),
    (
        &[
            "debug",
            "trace",
            "inspect",
            "log",
            "breakpoint",
            "step through",
            "stack trace",
        ],
        TaskType::Debug,
        0.9,
    ),
    (
        &[
            "config",
            "setup",
            "install",
            "env",
            "configure",
            "settings",
            "dotenv",
        ],
        TaskType::Config,
        0.85,
    ),
    (
        &[
            "deploy", "release", "publish", "ship", "ci/cd", "pipeline", "docker",
        ],
        TaskType::Deploy,
        0.85,
    ),
    (
        &[
            "review",
            "check",
            "audit",
            "look at",
            "evaluate",
            "assess",
            "pr review",
        ],
        TaskType::Review,
        0.8,
    ),
];
230
231pub fn classify(query: &str) -> TaskClassification {
232 let q = query.to_lowercase();
233 let words: Vec<&str> = q.split_whitespace().collect();
234
235 let mut best_type = TaskType::Explore;
236 let mut best_score = 0.0_f64;
237
238 for &(phrases, task_type, base_confidence) in PHRASE_RULES {
239 let mut match_count = 0usize;
240 for phrase in phrases {
241 if phrase.contains(' ') {
242 if q.contains(phrase) {
243 match_count += 2;
244 }
245 } else if words.contains(phrase) {
246 match_count += 1;
247 }
248 }
249 if match_count > 0 {
250 let score = base_confidence * (match_count as f64).min(2.0) / 2.0;
251 if score > best_score {
252 best_score = score;
253 best_type = task_type;
254 }
255 }
256 }
257
258 let targets = extract_targets(query);
259 let keywords = extract_keywords(&q);
260
261 if best_score < 0.1 {
262 best_type = TaskType::Explore;
263 best_score = 0.3;
264 }
265
266 TaskClassification {
267 task_type: best_type,
268 confidence: best_score,
269 targets,
270 keywords,
271 }
272}
273
274fn extract_targets(query: &str) -> Vec<String> {
275 let mut targets = Vec::new();
276
277 for word in query.split_whitespace() {
278 if word.contains('.') && !word.starts_with('.') {
279 let clean = word.trim_matches(|c: char| {
280 !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
281 });
282 if looks_like_path(clean) {
283 targets.push(clean.to_string());
284 }
285 }
286 if word.contains('/') && !word.starts_with("//") && !word.starts_with("http") {
287 let clean = word.trim_matches(|c: char| {
288 !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
289 });
290 if clean.len() > 2 {
291 targets.push(clean.to_string());
292 }
293 }
294 }
295
296 for word in query.split_whitespace() {
297 let w = word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');
298 if w.contains('_') && w.len() > 3 && !targets.contains(&w.to_string()) {
299 targets.push(w.to_string());
300 }
301 if w.chars().any(|c| c.is_uppercase())
302 && w.len() > 2
303 && !is_stop_word(w)
304 && !targets.contains(&w.to_string())
305 {
306 targets.push(w.to_string());
307 }
308 }
309
310 targets.truncate(5);
311 targets
312}
313
/// Heuristic: a token is path-like when it contains a directory separator or
/// ends in a recognized source/config file extension.
fn looks_like_path(s: &str) -> bool {
    const EXTS: [&str; 12] = [
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".toml", ".yaml", ".yml", ".json",
        ".md",
    ];
    s.contains('/') || EXTS.iter().any(|ext| s.ends_with(ext))
}
320
/// Common English filler words that should not be treated as identifiers or
/// keywords. Comparison is case-insensitive.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "with", "from", "into", "have", "please", "could", "would",
        "should", "also", "just", "then", "when", "what", "where", "which", "there", "here",
        "these", "those", "does", "will", "shall", "can", "may", "must", "need", "want", "like",
        "make", "take",
    ];
    let lower = w.to_lowercase();
    STOP_WORDS.contains(&lower.as_str())
}
359
360fn extract_keywords(query: &str) -> Vec<String> {
361 query
362 .split_whitespace()
363 .filter(|w| w.len() > 3)
364 .filter(|w| !is_stop_word(w))
365 .map(|w| {
366 w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
367 .to_lowercase()
368 })
369 .filter(|w| !w.is_empty())
370 .take(8)
371 .collect()
372}
373
374pub fn classify_complexity(
375 query: &str,
376 classification: &TaskClassification,
377) -> super::adaptive::TaskComplexity {
378 use super::adaptive::TaskComplexity;
379
380 let q = query.to_lowercase();
381 let word_count = q.split_whitespace().count();
382 let target_count = classification.targets.len();
383
384 let has_multi_file = target_count >= 3;
385 let has_cross_cutting = q.contains("all files")
386 || q.contains("across")
387 || q.contains("everywhere")
388 || q.contains("every")
389 || q.contains("migration")
390 || q.contains("architecture");
391
392 let is_simple = word_count < 8
393 && target_count <= 1
394 && matches!(
395 classification.task_type,
396 TaskType::Generate | TaskType::Config
397 );
398
399 if is_simple {
400 TaskComplexity::Mechanical
401 } else if has_multi_file || has_cross_cutting {
402 TaskComplexity::Architectural
403 } else {
404 TaskComplexity::Standard
405 }
406}
407
408pub fn detect_multi_intent(query: &str) -> Vec<TaskClassification> {
409 let delimiters = [" and then ", " then ", " also ", " + ", ". "];
410
411 let mut parts: Vec<&str> = vec![query];
412 for delim in &delimiters {
413 let mut new_parts = Vec::new();
414 for part in &parts {
415 for sub in part.split(delim) {
416 let trimmed = sub.trim();
417 if !trimmed.is_empty() {
418 new_parts.push(trimmed);
419 }
420 }
421 }
422 parts = new_parts;
423 }
424
425 if parts.len() <= 1 {
426 return vec![classify(query)];
427 }
428
429 parts.iter().map(|part| classify(part)).collect()
430}
431
432pub fn format_briefing_header(classification: &TaskClassification) -> String {
433 format!(
434 "[TASK:{} CONF:{:.0}% TARGETS:{} KW:{}]",
435 classification.task_type.as_str(),
436 classification.confidence * 100.0,
437 if classification.targets.is_empty() {
438 "-".to_string()
439 } else {
440 classification.targets.join(",")
441 },
442 classification.keywords.join(","),
443 )
444}
445
/// How widely an intent is expected to touch the codebase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum IntentScope {
    // One file.
    SingleFile,
    // A handful of files.
    MultiFile,
    // Several modules/directories.
    CrossModule,
    // The whole project (e.g. migrations, repo-wide renames).
    ProjectWide,
}
453
/// Richer intent than `TaskClassification`: adds scope, a language hint,
/// an urgency score, and the query's primary action verb.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StructuredIntent {
    pub task_type: TaskType,
    // Match strength in [0, 1].
    pub confidence: f64,
    // Likely files/identifiers (from the query or session file history).
    pub targets: Vec<String>,
    // Content keywords (empty when built from file patterns alone).
    pub keywords: Vec<String>,
    pub scope: IntentScope,
    // Detected implementation language, e.g. "rust", when a signal exists.
    pub language_hint: Option<String>,
    // Urgency in [0, 1]; 0.0 when no urgency vocabulary is present.
    pub urgency: f64,
    // First recognized action verb of the query, lowercased.
    pub action_verb: Option<String>,
}
465
466impl StructuredIntent {
467 pub fn from_query(query: &str) -> Self {
468 let classification = classify(query);
469 let complexity = classify_complexity(query, &classification);
470 let file_targets = classification
471 .targets
472 .iter()
473 .filter(|t| t.contains('.') || t.contains('/'))
474 .count();
475 let scope = match complexity {
476 super::adaptive::TaskComplexity::Mechanical => IntentScope::SingleFile,
477 super::adaptive::TaskComplexity::Standard => {
478 if file_targets > 1 {
479 IntentScope::MultiFile
480 } else {
481 IntentScope::SingleFile
482 }
483 }
484 super::adaptive::TaskComplexity::Architectural => {
485 let q = query.to_lowercase();
486 if q.contains("all files") || q.contains("everywhere") || q.contains("migration") {
487 IntentScope::ProjectWide
488 } else {
489 IntentScope::CrossModule
490 }
491 }
492 };
493
494 let language_hint = detect_language_hint(query, &classification.targets);
495 let urgency = detect_urgency(query);
496 let action_verb = extract_action_verb(query);
497
498 StructuredIntent {
499 task_type: classification.task_type,
500 confidence: classification.confidence,
501 targets: classification.targets,
502 keywords: classification.keywords,
503 scope,
504 language_hint,
505 urgency,
506 action_verb,
507 }
508 }
509
510 pub fn from_file_patterns(touched_files: &[String]) -> Self {
511 if touched_files.is_empty() {
512 return Self {
513 task_type: TaskType::Explore,
514 confidence: 0.3,
515 targets: Vec::new(),
516 keywords: Vec::new(),
517 scope: IntentScope::SingleFile,
518 language_hint: None,
519 urgency: 0.0,
520 action_verb: None,
521 };
522 }
523
524 let has_tests = touched_files
525 .iter()
526 .any(|f| f.contains("test") || f.contains("spec"));
527 let has_config = touched_files.iter().any(|f| {
528 let lower = f.to_lowercase();
529 lower.ends_with(".toml")
530 || lower.ends_with(".yaml")
531 || lower.ends_with(".yml")
532 || lower.ends_with(".json")
533 || lower.contains("config")
534 || lower.contains(".env")
535 });
536
537 let dirs: std::collections::HashSet<&str> = touched_files
538 .iter()
539 .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
540 .collect();
541
542 let task_type = if has_tests && touched_files.len() <= 3 {
543 TaskType::Test
544 } else if has_config && touched_files.len() <= 2 {
545 TaskType::Config
546 } else if dirs.len() > 3 {
547 TaskType::Refactor
548 } else {
549 TaskType::Explore
550 };
551
552 let scope = match touched_files.len() {
553 1 => IntentScope::SingleFile,
554 2..=4 => IntentScope::MultiFile,
555 _ => IntentScope::CrossModule,
556 };
557
558 let language_hint = detect_language_from_files(touched_files);
559
560 Self {
561 task_type,
562 confidence: 0.5,
563 targets: touched_files.to_vec(),
564 keywords: Vec::new(),
565 scope,
566 language_hint,
567 urgency: 0.0,
568 action_verb: None,
569 }
570 }
571
572 pub fn from_query_with_session(query: &str, touched_files: &[String]) -> Self {
573 let mut intent = Self::from_query(query);
574
575 if intent.language_hint.is_none() && !touched_files.is_empty() {
576 intent.language_hint = detect_language_from_files(touched_files);
577 }
578
579 if intent.scope == IntentScope::SingleFile && touched_files.len() > 3 {
580 let dirs: std::collections::HashSet<&str> = touched_files
581 .iter()
582 .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
583 .collect();
584 if dirs.len() > 2 {
585 intent.scope = IntentScope::MultiFile;
586 }
587 }
588
589 intent
590 }
591
592 pub fn format_header(&self) -> String {
593 format!(
594 "[TASK:{} SCOPE:{} CONF:{:.0}%{}{}]",
595 self.task_type.as_str(),
596 match self.scope {
597 IntentScope::SingleFile => "single",
598 IntentScope::MultiFile => "multi",
599 IntentScope::CrossModule => "cross",
600 IntentScope::ProjectWide => "project",
601 },
602 self.confidence * 100.0,
603 self.language_hint
604 .as_ref()
605 .map(|l| format!(" LANG:{l}"))
606 .unwrap_or_default(),
607 if self.urgency > 0.5 { " URGENT" } else { "" },
608 )
609 }
610}
611
/// Guess the implementation language from file-path targets, falling back to
/// language names mentioned in the query text. Returns `None` when no signal
/// is found.
fn detect_language_hint(query: &str, targets: &[String]) -> Option<String> {
    // File extensions are the strongest signal: first recognized one wins.
    for t in targets {
        let ext = std::path::Path::new(t).extension().and_then(|e| e.to_str());
        match ext {
            Some("rs") => return Some("rust".into()),
            Some("ts" | "tsx") => return Some("typescript".into()),
            Some("js" | "jsx") => return Some("javascript".into()),
            Some("py") => return Some("python".into()),
            Some("go") => return Some("go".into()),
            Some("rb") => return Some("ruby".into()),
            Some("java") => return Some("java".into()),
            Some("swift") => return Some("swift".into()),
            Some("zig") => return Some("zig".into()),
            _ => {}
        }
    }

    let q = query.to_lowercase();

    // Unambiguous language names can be matched as plain substrings.
    let substring_keywords: &[(&str, &str)] = &[
        ("rust", "rust"),
        ("python", "python"),
        ("typescript", "typescript"),
        ("javascript", "javascript"),
        ("golang", "go"),
        ("ruby", "ruby"),
        ("swift", "swift"),
    ];
    for &(kw, lang) in substring_keywords {
        if q.contains(kw) {
            return Some(lang.into());
        }
    }

    // "go" and "java" need word-boundary matching: the old space-padded
    // substrings (" go ", "java ") missed them at the start or end of the
    // query, and a bare substring would false-positive on e.g. "javadoc".
    // Note "javascript" was already handled above, so "java" here is safe.
    for word in q.split_whitespace() {
        let w = word.trim_matches(|c: char| !c.is_alphanumeric());
        match w {
            "go" => return Some("go".into()),
            "java" => return Some("java".into()),
            _ => {}
        }
    }

    None
}
649
/// Pick the dominant language among the given files by extension count.
///
/// Uses a `BTreeMap` so that ties are broken deterministically (the
/// lexicographically last language name wins) instead of varying with
/// `HashMap` iteration order between runs.
fn detect_language_from_files(files: &[String]) -> Option<String> {
    let mut counts: std::collections::BTreeMap<&str, usize> = std::collections::BTreeMap::new();
    for f in files {
        let ext = std::path::Path::new(f)
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("");
        let lang = match ext {
            "rs" => "rust",
            "ts" | "tsx" => "typescript",
            "js" | "jsx" => "javascript",
            "py" => "python",
            "go" => "go",
            "rb" => "ruby",
            "java" => "java",
            // Unrecognized extensions don't vote.
            _ => continue,
        };
        *counts.entry(lang).or_insert(0) += 1;
    }
    // `max_by_key` keeps the last maximum in iteration order, which is now a
    // stable (sorted-by-name) order.
    counts
        .into_iter()
        .max_by_key(|&(_, c)| c)
        .map(|(lang, _)| lang.to_string())
}
674
/// Score query urgency in [0, 1]: 0.4 per urgency keyword found, capped at 1.0.
fn detect_urgency(query: &str) -> f64 {
    const URGENT_WORDS: [&str; 8] = [
        "urgent",
        "asap",
        "immediately",
        "critical",
        "hotfix",
        "emergency",
        "blocker",
        "breaking",
    ];
    let q = query.to_lowercase();
    let mut hits = 0usize;
    for w in URGENT_WORDS {
        if q.contains(w) {
            hits += 1;
        }
    }
    f64::min(hits as f64 * 0.4, 1.0)
}
690
/// Pick the primary action verb of the query.
///
/// A known verb in the first two word positions wins first (checked in
/// verb-priority order); otherwise the first verb in priority order that
/// appears anywhere in the query is used.
fn extract_action_verb(query: &str) -> Option<String> {
    const VERBS: [&str; 23] = [
        "fix", "add", "create", "implement", "refactor", "debug", "test", "write", "update",
        "remove", "delete", "rename", "move", "extract", "split", "merge", "deploy", "review",
        "check", "build", "generate", "optimize", "clean",
    ];

    let q = query.to_lowercase();
    let words: Vec<&str> = q.split_whitespace().collect();

    // Pass 1: verb in a leading position (first or second word).
    for v in VERBS {
        if words.iter().take(2).any(|w| *w == v) {
            return Some(v.to_string());
        }
    }

    // Pass 2: verb anywhere in the query.
    for v in VERBS {
        if words.contains(&v) {
            return Some(v.to_string());
        }
    }

    None
}
731
// Unit tests exercise the public classification entry points end-to-end with
// realistic developer phrasing; expectations are tied to the confidence
// weights declared in PHRASE_RULES.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classify_fix_bug() {
        let r = classify("fix the bug in entropy.rs where token_entropy returns NaN");
        assert_eq!(r.task_type, TaskType::FixBug);
        assert!(r.confidence > 0.5);
        // File path mentioned in the query should surface as a target.
        assert!(r.targets.iter().any(|t| t.contains("entropy.rs")));
    }

    #[test]
    fn classify_generate() {
        let r = classify("add a new function normalized_token_entropy to entropy.rs");
        assert_eq!(r.task_type, TaskType::Generate);
        assert!(r.confidence > 0.5);
    }

    #[test]
    fn classify_refactor() {
        let r = classify("refactor the compression pipeline to split into smaller modules");
        assert_eq!(r.task_type, TaskType::Refactor);
    }

    #[test]
    fn classify_explore() {
        let r = classify("how does the session cache work?");
        assert_eq!(r.task_type, TaskType::Explore);
    }

    #[test]
    fn classify_debug() {
        let r = classify("debug why the compression ratio drops for large files");
        assert_eq!(r.task_type, TaskType::Debug);
    }

    #[test]
    fn classify_test() {
        let r = classify("write unit tests for the token_optimizer module");
        assert_eq!(r.task_type, TaskType::Test);
    }

    #[test]
    fn targets_extract_paths() {
        let r = classify("fix entropy.rs and update core/mod.rs");
        assert!(r.targets.iter().any(|t| t.contains("entropy.rs")));
        assert!(r.targets.iter().any(|t| t.contains("core/mod.rs")));
    }

    #[test]
    fn targets_extract_identifiers() {
        // CamelCase tokens should be picked up as identifier targets.
        let r = classify("refactor SessionCache to use LRU eviction");
        assert!(r.targets.iter().any(|t| t == "SessionCache"));
    }

    #[test]
    fn fallback_to_explore() {
        // Gibberish matches no rule: low-confidence Explore fallback.
        let r = classify("xyz qqq bbb");
        assert_eq!(r.task_type, TaskType::Explore);
        assert!(r.confidence < 0.5);
    }

    #[test]
    fn multi_intent_detection() {
        // " and then " splits the query into two separately classified parts.
        let results = detect_multi_intent("fix the bug in auth.rs and then write unit tests");
        assert!(results.len() >= 2);
        assert_eq!(results[0].task_type, TaskType::FixBug);
        assert_eq!(results[1].task_type, TaskType::Test);
    }

    #[test]
    fn single_intent_no_split() {
        let results = detect_multi_intent("fix the bug in auth.rs");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].task_type, TaskType::FixBug);
    }

    #[test]
    fn complexity_mechanical() {
        // Short, zero-target Generate query: mechanical tier.
        let r = classify("add a comment");
        let c = classify_complexity("add a comment", &r);
        assert_eq!(c, super::super::adaptive::TaskComplexity::Mechanical);
    }

    #[test]
    fn complexity_architectural() {
        // Cross-cutting vocabulary ("across", "all files", "migration").
        let r = classify("refactor auth across all files and update the migration");
        let c = classify_complexity(
            "refactor auth across all files and update the migration",
            &r,
        );
        assert_eq!(c, super::super::adaptive::TaskComplexity::Architectural);
    }
}