1#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
2#[serde(rename_all = "snake_case")]
3pub enum TaskType {
4 Generate,
5 FixBug,
6 Refactor,
7 Explore,
8 Test,
9 Debug,
10 Config,
11 Deploy,
12 Review,
13}
14
15impl TaskType {
16 pub fn as_str(&self) -> &'static str {
17 match self {
18 Self::Generate => "generate",
19 Self::FixBug => "fix_bug",
20 Self::Refactor => "refactor",
21 Self::Explore => "explore",
22 Self::Test => "test",
23 Self::Debug => "debug",
24 Self::Config => "config",
25 Self::Deploy => "deploy",
26 Self::Review => "review",
27 }
28 }
29
30 pub fn thinking_budget(&self) -> ThinkingBudget {
31 match self {
32 Self::Generate => ThinkingBudget::Minimal,
33 Self::FixBug => ThinkingBudget::Minimal,
34 Self::Refactor => ThinkingBudget::Medium,
35 Self::Explore => ThinkingBudget::Medium,
36 Self::Test => ThinkingBudget::Minimal,
37 Self::Debug => ThinkingBudget::Medium,
38 Self::Config => ThinkingBudget::Minimal,
39 Self::Deploy => ThinkingBudget::Minimal,
40 Self::Review => ThinkingBudget::Medium,
41 }
42 }
43
44 pub fn output_format(&self) -> OutputFormat {
45 match self {
46 Self::Generate => OutputFormat::CodeOnly,
47 Self::FixBug => OutputFormat::DiffOnly,
48 Self::Refactor => OutputFormat::DiffOnly,
49 Self::Explore => OutputFormat::ExplainConcise,
50 Self::Test => OutputFormat::CodeOnly,
51 Self::Debug => OutputFormat::Trace,
52 Self::Config => OutputFormat::CodeOnly,
53 Self::Deploy => OutputFormat::StepList,
54 Self::Review => OutputFormat::ExplainConcise,
55 }
56 }
57}
58
/// Amount of up-front reasoning requested from the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThinkingBudget {
    /// No analysis; go straight to the answer.
    Minimal,
    /// A short, bounded analysis before acting.
    Medium,
    /// A short root-cause trace followed by a fix.
    Trace,
    /// Structural / dependency analysis with a concise summary.
    Deep,
}

impl ThinkingBudget {
    /// Returns the `THINKING:` directive text for this budget.
    pub fn instruction(&self) -> &'static str {
        match *self {
            ThinkingBudget::Minimal => {
                "THINKING: Skip analysis. The task is clear — generate code directly."
            }
            ThinkingBudget::Medium => {
                "THINKING: 2-3 step analysis max. Identify what to change, then act. Do not over-analyze."
            }
            ThinkingBudget::Trace => {
                "THINKING: Short trace only. Identify root cause in 3 steps max, then generate fix."
            }
            ThinkingBudget::Deep => {
                "THINKING: Analyze structure and dependencies. Summarize findings concisely."
            }
        }
    }

    /// Returns `true` only for [`ThinkingBudget::Minimal`].
    pub fn suppresses_thinking(&self) -> bool {
        *self == ThinkingBudget::Minimal
    }
}
81
/// Layout hint for the model's response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Mostly code blocks, little prose.
    CodeOnly,
    /// Changed lines only, as `+/-` diffs.
    DiffOnly,
    /// Short summary, optionally followed by code or data.
    ExplainConcise,
    /// Cause-to-effect chain with code references.
    Trace,
    /// Numbered list of actions.
    StepList,
}

impl OutputFormat {
    /// Returns the `OUTPUT-HINT:` directive text for this format.
    pub fn instruction(&self) -> &'static str {
        match *self {
            OutputFormat::CodeOnly => {
                "OUTPUT-HINT: Prefer code blocks. Minimize prose unless user asks for explanation."
            }
            OutputFormat::DiffOnly => {
                "OUTPUT-HINT: Prefer showing only changed lines as +/- diffs."
            }
            OutputFormat::ExplainConcise => {
                "OUTPUT-HINT: Brief summary, then code/data if relevant."
            }
            OutputFormat::Trace => {
                "OUTPUT-HINT: Show cause→effect chain with code references."
            }
            OutputFormat::StepList => {
                "OUTPUT-HINT: Numbered action list, one step at a time."
            }
        }
    }
}
104
105#[derive(Debug)]
106pub struct TaskClassification {
107 pub task_type: TaskType,
108 pub confidence: f64,
109 pub targets: Vec<String>,
110 pub keywords: Vec<String>,
111}
112
113const PHRASE_RULES: &[(&[&str], TaskType, f64)] = &[
114 (
115 &[
116 "add",
117 "create",
118 "implement",
119 "build",
120 "write",
121 "generate",
122 "make",
123 "new feature",
124 "new",
125 ],
126 TaskType::Generate,
127 0.9,
128 ),
129 (
130 &[
131 "fix",
132 "bug",
133 "broken",
134 "crash",
135 "error in",
136 "not working",
137 "fails",
138 "wrong output",
139 ],
140 TaskType::FixBug,
141 0.95,
142 ),
143 (
144 &[
145 "refactor",
146 "clean up",
147 "restructure",
148 "rename",
149 "move",
150 "extract",
151 "simplify",
152 "split",
153 ],
154 TaskType::Refactor,
155 0.9,
156 ),
157 (
158 &[
159 "how",
160 "what",
161 "where",
162 "explain",
163 "understand",
164 "show me",
165 "describe",
166 "why does",
167 ],
168 TaskType::Explore,
169 0.85,
170 ),
171 (
172 &[
173 "test",
174 "spec",
175 "coverage",
176 "assert",
177 "unit test",
178 "integration test",
179 "mock",
180 ],
181 TaskType::Test,
182 0.9,
183 ),
184 (
185 &[
186 "debug",
187 "trace",
188 "inspect",
189 "log",
190 "breakpoint",
191 "step through",
192 "stack trace",
193 ],
194 TaskType::Debug,
195 0.9,
196 ),
197 (
198 &[
199 "config",
200 "setup",
201 "install",
202 "env",
203 "configure",
204 "settings",
205 "dotenv",
206 ],
207 TaskType::Config,
208 0.85,
209 ),
210 (
211 &[
212 "deploy", "release", "publish", "ship", "ci/cd", "pipeline", "docker",
213 ],
214 TaskType::Deploy,
215 0.85,
216 ),
217 (
218 &[
219 "review",
220 "check",
221 "audit",
222 "look at",
223 "evaluate",
224 "assess",
225 "pr review",
226 ],
227 TaskType::Review,
228 0.8,
229 ),
230];
231
232pub fn classify(query: &str) -> TaskClassification {
233 let q = query.to_lowercase();
234 let words: Vec<&str> = q.split_whitespace().collect();
235
236 let mut best_type = TaskType::Explore;
237 let mut best_score = 0.0_f64;
238
239 for &(phrases, task_type, base_confidence) in PHRASE_RULES {
240 let mut match_count = 0usize;
241 for phrase in phrases {
242 if phrase.contains(' ') {
243 if q.contains(phrase) {
244 match_count += 2;
245 }
246 } else if words.contains(phrase) {
247 match_count += 1;
248 }
249 }
250 if match_count > 0 {
251 let score = base_confidence * (match_count as f64).min(2.0) / 2.0;
252 if score > best_score {
253 best_score = score;
254 best_type = task_type;
255 }
256 }
257 }
258
259 let targets = extract_targets(query);
260 let keywords = extract_keywords(&q);
261
262 if best_score < 0.1 {
263 best_type = TaskType::Explore;
264 best_score = 0.3;
265 }
266
267 TaskClassification {
268 task_type: best_type,
269 confidence: best_score,
270 targets,
271 keywords,
272 }
273}
274
275fn extract_targets(query: &str) -> Vec<String> {
276 let mut targets = Vec::new();
277
278 for word in query.split_whitespace() {
279 if word.contains('.') && !word.starts_with('.') {
280 let clean = word.trim_matches(|c: char| {
281 !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
282 });
283 if looks_like_path(clean) {
284 targets.push(clean.to_string());
285 }
286 }
287 if word.contains('/') && !word.starts_with("//") && !word.starts_with("http") {
288 let clean = word.trim_matches(|c: char| {
289 !c.is_alphanumeric() && c != '.' && c != '/' && c != '_' && c != '-'
290 });
291 if clean.len() > 2 {
292 targets.push(clean.to_string());
293 }
294 }
295 }
296
297 for word in query.split_whitespace() {
298 let w = word.trim_matches(|c: char| !c.is_alphanumeric() && c != '_');
299 if w.contains('_') && w.len() > 3 && !targets.contains(&w.to_string()) {
300 targets.push(w.to_string());
301 }
302 if w.chars().any(|c| c.is_uppercase())
303 && w.len() > 2
304 && !is_stop_word(w)
305 && !targets.contains(&w.to_string())
306 {
307 targets.push(w.to_string());
308 }
309 }
310
311 targets.truncate(5);
312 targets
313}
314
/// Returns `true` when `s` contains a `/` or ends with one of the known
/// source/config file extensions.
fn looks_like_path(s: &str) -> bool {
    const KNOWN_EXTENSIONS: &[&str] = &[
        ".rs", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".toml", ".yaml", ".yml", ".json",
        ".md",
    ];

    if s.contains('/') {
        return true;
    }
    KNOWN_EXTENSIONS.iter().any(|ext| s.ends_with(*ext))
}
321
/// Returns `true` when `w`, compared case-insensitively, is one of the common
/// filler words that should not be treated as a keyword or identifier.
fn is_stop_word(w: &str) -> bool {
    const STOP_WORDS: &[&str] = &[
        "the", "this", "that", "with", "from", "into", "have", "please", "could", "would",
        "should", "also", "just", "then", "when", "what", "where", "which", "there", "here",
        "these", "those", "does", "will", "shall", "can", "may", "must", "need", "want", "like",
        "make", "take",
    ];

    let lowered = w.to_lowercase();
    STOP_WORDS.contains(&lowered.as_str())
}
360
361fn extract_keywords(query: &str) -> Vec<String> {
362 query
363 .split_whitespace()
364 .filter(|w| w.len() > 3)
365 .filter(|w| !is_stop_word(w))
366 .map(|w| {
367 w.trim_matches(|c: char| !c.is_alphanumeric() && c != '_')
368 .to_lowercase()
369 })
370 .filter(|w| !w.is_empty())
371 .take(8)
372 .collect()
373}
374
375pub fn classify_complexity(
376 query: &str,
377 classification: &TaskClassification,
378) -> super::adaptive::TaskComplexity {
379 use super::adaptive::TaskComplexity;
380
381 let q = query.to_lowercase();
382 let word_count = q.split_whitespace().count();
383 let target_count = classification.targets.len();
384
385 let has_multi_file = target_count >= 3;
386 let has_cross_cutting = q.contains("all files")
387 || q.contains("across")
388 || q.contains("everywhere")
389 || q.contains("every")
390 || q.contains("migration")
391 || q.contains("architecture");
392
393 let is_simple = word_count < 8
394 && target_count <= 1
395 && matches!(
396 classification.task_type,
397 TaskType::Generate | TaskType::Config
398 );
399
400 if is_simple {
401 TaskComplexity::Mechanical
402 } else if has_multi_file || has_cross_cutting {
403 TaskComplexity::Architectural
404 } else {
405 TaskComplexity::Standard
406 }
407}
408
409pub fn detect_multi_intent(query: &str) -> Vec<TaskClassification> {
410 let delimiters = [" and then ", " then ", " also ", " + ", ". "];
411
412 let mut parts: Vec<&str> = vec![query];
413 for delim in &delimiters {
414 let mut new_parts = Vec::new();
415 for part in &parts {
416 for sub in part.split(delim) {
417 let trimmed = sub.trim();
418 if !trimmed.is_empty() {
419 new_parts.push(trimmed);
420 }
421 }
422 }
423 parts = new_parts;
424 }
425
426 if parts.len() <= 1 {
427 return vec![classify(query)];
428 }
429
430 parts.iter().map(|part| classify(part)).collect()
431}
432
433pub fn format_briefing_header(classification: &TaskClassification) -> String {
434 format!(
435 "[TASK:{} CONF:{:.0}% TARGETS:{} KW:{}]",
436 classification.task_type.as_str(),
437 classification.confidence * 100.0,
438 if classification.targets.is_empty() {
439 "-".to_string()
440 } else {
441 classification.targets.join(",")
442 },
443 classification.keywords.join(","),
444 )
445}
446
447#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
448#[serde(rename_all = "snake_case")]
449pub enum IntentScope {
450 SingleFile,
451 MultiFile,
452 CrossModule,
453 ProjectWide,
454}
455
456#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
457pub struct StructuredIntent {
458 pub task_type: TaskType,
459 pub confidence: f64,
460 pub targets: Vec<String>,
461 pub keywords: Vec<String>,
462 pub scope: IntentScope,
463 pub language_hint: Option<String>,
464 pub urgency: f64,
465 pub action_verb: Option<String>,
466}
467
468impl StructuredIntent {
469 pub fn from_query(query: &str) -> Self {
470 let classification = classify(query);
471 let complexity = classify_complexity(query, &classification);
472 let file_targets = classification
473 .targets
474 .iter()
475 .filter(|t| t.contains('.') || t.contains('/'))
476 .count();
477 let scope = match complexity {
478 super::adaptive::TaskComplexity::Mechanical => IntentScope::SingleFile,
479 super::adaptive::TaskComplexity::Standard => {
480 if file_targets > 1 {
481 IntentScope::MultiFile
482 } else {
483 IntentScope::SingleFile
484 }
485 }
486 super::adaptive::TaskComplexity::Architectural => {
487 let q = query.to_lowercase();
488 if q.contains("all files") || q.contains("everywhere") || q.contains("migration") {
489 IntentScope::ProjectWide
490 } else {
491 IntentScope::CrossModule
492 }
493 }
494 };
495
496 let language_hint = detect_language_hint(query, &classification.targets);
497 let urgency = detect_urgency(query);
498 let action_verb = extract_action_verb(query);
499
500 StructuredIntent {
501 task_type: classification.task_type,
502 confidence: classification.confidence,
503 targets: classification.targets,
504 keywords: classification.keywords,
505 scope,
506 language_hint,
507 urgency,
508 action_verb,
509 }
510 }
511
512 pub fn from_query_with_session(query: &str, touched_files: &[String]) -> Self {
513 let mut intent = Self::from_query(query);
514
515 if intent.language_hint.is_none() && !touched_files.is_empty() {
516 intent.language_hint = detect_language_from_files(touched_files);
517 }
518
519 if intent.scope == IntentScope::SingleFile && touched_files.len() > 3 {
520 let dirs: std::collections::HashSet<&str> = touched_files
521 .iter()
522 .filter_map(|f| std::path::Path::new(f).parent()?.to_str())
523 .collect();
524 if dirs.len() > 2 {
525 intent.scope = IntentScope::MultiFile;
526 }
527 }
528
529 intent
530 }
531
532 pub fn format_header(&self) -> String {
533 format!(
534 "[TASK:{} SCOPE:{} CONF:{:.0}%{}{}]",
535 self.task_type.as_str(),
536 match self.scope {
537 IntentScope::SingleFile => "single",
538 IntentScope::MultiFile => "multi",
539 IntentScope::CrossModule => "cross",
540 IntentScope::ProjectWide => "project",
541 },
542 self.confidence * 100.0,
543 self.language_hint
544 .as_ref()
545 .map(|l| format!(" LANG:{l}"))
546 .unwrap_or_default(),
547 if self.urgency > 0.5 { " URGENT" } else { "" },
548 )
549 }
550}
551
/// Guesses the programming language a query is about.
///
/// The first target with a recognised file extension decides. Otherwise the
/// query is lowercased, split into alphanumeric words, and the first language
/// name from the keyword table that appears as a whole word wins.
///
/// Whole-word matching avoids false positives such as `trust` → rust and
/// also finds names at the end of the query or next to punctuation
/// (`"rewrite this in go"`, `"valid java?"`). Note that the English word
/// "go" is indistinguishable from the language name.
///
/// Returns `None` when neither source yields a language.
fn detect_language_hint(query: &str, targets: &[String]) -> Option<String> {
    // (keyword, language) in priority order.
    const KEYWORD_LANGUAGES: &[(&str, &str)] = &[
        ("rust", "rust"),
        ("python", "python"),
        ("typescript", "typescript"),
        ("javascript", "javascript"),
        ("golang", "go"),
        ("go", "go"),
        ("ruby", "ruby"),
        ("java", "java"),
        ("swift", "swift"),
        ("zig", "zig"),
    ];

    let from_extension = targets.iter().find_map(|target| {
        let ext = std::path::Path::new(target).extension()?.to_str()?;
        match ext {
            "rs" => Some("rust"),
            "ts" | "tsx" => Some("typescript"),
            "js" | "jsx" => Some("javascript"),
            "py" => Some("python"),
            "go" => Some("go"),
            "rb" => Some("ruby"),
            "java" => Some("java"),
            "swift" => Some("swift"),
            "zig" => Some("zig"),
            _ => None,
        }
    });
    if let Some(lang) = from_extension {
        return Some(lang.to_string());
    }

    let lowered = query.to_lowercase();
    let tokens: Vec<&str> = lowered
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .collect();

    KEYWORD_LANGUAGES
        .iter()
        .find(|(keyword, _)| tokens.contains(keyword))
        .map(|(_, lang)| lang.to_string())
}
589
/// Returns the language that occurs most often among `files`, judged by file
/// extension, or `None` when no file has a recognised extension.
///
/// The extension table matches the one used by `detect_language_hint`
/// (including `swift` and `zig`). Ties are resolved deterministically in
/// favour of the language that appears first in `files`; counting through a
/// `HashMap` made the winner depend on hash iteration order.
fn detect_language_from_files(files: &[String]) -> Option<String> {
    // (language, occurrences) in order of first appearance.
    let mut tally: Vec<(&'static str, usize)> = Vec::new();

    for file in files {
        let ext = std::path::Path::new(file)
            .extension()
            .and_then(|e| e.to_str());
        let lang = match ext {
            Some("rs") => "rust",
            Some("ts" | "tsx") => "typescript",
            Some("js" | "jsx") => "javascript",
            Some("py") => "python",
            Some("go") => "go",
            Some("rb") => "ruby",
            Some("java") => "java",
            Some("swift") => "swift",
            Some("zig") => "zig",
            _ => continue,
        };

        match tally.iter().position(|&(name, _)| name == lang) {
            Some(index) => tally[index].1 += 1,
            None => tally.push((lang, 1)),
        }
    }

    // `max_by_key` keeps the last maximum, so scan in reverse to prefer the
    // language that was seen first.
    tally
        .into_iter()
        .rev()
        .max_by_key(|&(_, count)| count)
        .map(|(lang, _)| lang.to_string())
}
614
/// Scores how urgent `query` sounds, from `0.0` to `1.0`.
///
/// Each distinct urgency marker found as a case-insensitive substring adds
/// `0.4`; the total is capped at `1.0`.
fn detect_urgency(query: &str) -> f64 {
    const URGENT_MARKERS: [&str; 8] = [
        "urgent",
        "asap",
        "immediately",
        "critical",
        "hotfix",
        "emergency",
        "blocker",
        "breaking",
    ];
    const WEIGHT_PER_MARKER: f64 = 0.4;

    let lowered = query.to_lowercase();

    let mut hits = 0usize;
    for marker in URGENT_MARKERS.iter() {
        if lowered.contains(*marker) {
            hits += 1;
        }
    }

    f64::min(hits as f64 * WEIGHT_PER_MARKER, 1.0)
}
630
/// Returns the first known action verb that appears in `query`, lowercased.
///
/// Words are compared case-insensitively after stripping surrounding
/// punctuation, so `"Fix: crash"` yields `fix`. The verb that occurs earliest
/// in the query wins; choosing by the verb table's order made
/// `"write test for parser"` report `test` instead of `write`.
///
/// Returns `None` when the query contains no known verb.
fn extract_action_verb(query: &str) -> Option<String> {
    const VERBS: &[&str] = &[
        "fix",
        "add",
        "create",
        "implement",
        "refactor",
        "debug",
        "test",
        "write",
        "update",
        "remove",
        "delete",
        "rename",
        "move",
        "extract",
        "split",
        "merge",
        "deploy",
        "review",
        "check",
        "build",
        "generate",
        "optimize",
        "clean",
    ];

    let lowered = query.to_lowercase();
    lowered
        .split_whitespace()
        .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
        .find(|word| VERBS.contains(word))
        .map(String::from)
}
671
#[cfg(test)]
mod tests {
    use super::super::adaptive::TaskComplexity;
    use super::*;

    /// True when any extracted target contains `needle`.
    fn mentions(targets: &[String], needle: &str) -> bool {
        targets.iter().any(|target| target.contains(needle))
    }

    // --- classify -------------------------------------------------------

    #[test]
    fn classify_fix_bug() {
        let result = classify("fix the bug in entropy.rs where token_entropy returns NaN");
        assert_eq!(result.task_type, TaskType::FixBug);
        assert!(result.confidence > 0.5);
        assert!(mentions(&result.targets, "entropy.rs"));
    }

    #[test]
    fn classify_generate() {
        let result = classify("add a new function normalized_token_entropy to entropy.rs");
        assert_eq!(result.task_type, TaskType::Generate);
        assert!(result.confidence > 0.5);
    }

    #[test]
    fn classify_refactor() {
        let result = classify("refactor the compression pipeline to split into smaller modules");
        assert_eq!(result.task_type, TaskType::Refactor);
    }

    #[test]
    fn classify_explore() {
        let result = classify("how does the session cache work?");
        assert_eq!(result.task_type, TaskType::Explore);
    }

    #[test]
    fn classify_debug() {
        let result = classify("debug why the compression ratio drops for large files");
        assert_eq!(result.task_type, TaskType::Debug);
    }

    #[test]
    fn classify_test() {
        let result = classify("write unit tests for the token_optimizer module");
        assert_eq!(result.task_type, TaskType::Test);
    }

    // --- target extraction ----------------------------------------------

    #[test]
    fn targets_extract_paths() {
        let result = classify("fix entropy.rs and update core/mod.rs");
        assert!(mentions(&result.targets, "entropy.rs"));
        assert!(mentions(&result.targets, "core/mod.rs"));
    }

    #[test]
    fn targets_extract_identifiers() {
        let result = classify("refactor SessionCache to use LRU eviction");
        assert!(result.targets.iter().any(|target| target == "SessionCache"));
    }

    #[test]
    fn fallback_to_explore() {
        let result = classify("xyz qqq bbb");
        assert_eq!(result.task_type, TaskType::Explore);
        assert!(result.confidence < 0.5);
    }

    // --- multi-intent ---------------------------------------------------

    #[test]
    fn multi_intent_detection() {
        let intents = detect_multi_intent("fix the bug in auth.rs and then write unit tests");
        assert!(intents.len() >= 2);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
        assert_eq!(intents[1].task_type, TaskType::Test);
    }

    #[test]
    fn single_intent_no_split() {
        let intents = detect_multi_intent("fix the bug in auth.rs");
        assert_eq!(intents.len(), 1);
        assert_eq!(intents[0].task_type, TaskType::FixBug);
    }

    // --- complexity -----------------------------------------------------

    #[test]
    fn complexity_mechanical() {
        let query = "add a comment";
        let classification = classify(query);
        assert_eq!(
            classify_complexity(query, &classification),
            TaskComplexity::Mechanical
        );
    }

    #[test]
    fn complexity_architectural() {
        let query = "refactor auth across all files and update the migration";
        let classification = classify(query);
        assert_eq!(
            classify_complexity(query, &classification),
            TaskComplexity::Architectural
        );
    }

    // --- structured intent ----------------------------------------------

    #[test]
    fn structured_intent_from_fixbug_query() {
        let intent = StructuredIntent::from_query("fix the NaN bug in entropy.rs");
        assert_eq!(intent.task_type, TaskType::FixBug);
        assert!(mentions(&intent.targets, "entropy.rs"));
        assert_eq!(intent.language_hint.as_deref(), Some("rust"));
        assert_eq!(intent.action_verb.as_deref(), Some("fix"));
        assert_eq!(intent.scope, IntentScope::SingleFile);
    }

    #[test]
    fn structured_intent_project_wide() {
        let intent =
            StructuredIntent::from_query("refactor auth across all files and update migration");
        assert_eq!(intent.task_type, TaskType::Refactor);
        assert_eq!(intent.scope, IntentScope::ProjectWide);
    }

    #[test]
    fn structured_intent_urgency() {
        let calm = StructuredIntent::from_query("add a new function");
        let pressing = StructuredIntent::from_query("urgent hotfix for critical auth bug");
        assert!(pressing.urgency > calm.urgency);
        assert!(pressing.urgency >= 0.8);
    }

    #[test]
    fn structured_intent_language_from_targets() {
        let intent = StructuredIntent::from_query("fix main.py auth handler");
        assert_eq!(intent.language_hint.as_deref(), Some("python"));
    }

    #[test]
    fn structured_intent_with_session() {
        let touched: Vec<String> = [
            "src/core/session.rs",
            "src/core/litm.rs",
            "src/tools/ctx_read.rs",
        ]
        .iter()
        .map(|path| path.to_string())
        .collect();
        let intent = StructuredIntent::from_query_with_session("how does this work?", &touched);
        assert_eq!(intent.language_hint.as_deref(), Some("rust"));
    }

    #[test]
    fn structured_intent_header_format() {
        let header = StructuredIntent::from_query("fix bug in entropy.rs").format_header();
        assert!(header.contains("TASK:fix_bug"));
        assert!(header.contains("SCOPE:"));
        assert!(header.contains("LANG:rust"));
    }

    // --- helpers --------------------------------------------------------

    #[test]
    fn detect_urgency_none() {
        assert_eq!(detect_urgency("add a function"), 0.0);
    }

    #[test]
    fn detect_urgency_multiple() {
        let score = detect_urgency("urgent critical blocker");
        assert!(score > 0.9);
    }

    #[test]
    fn action_verb_extraction() {
        assert_eq!(extract_action_verb("fix the bug"), Some("fix".to_string()));
        assert_eq!(
            extract_action_verb("please add logging"),
            Some("add".to_string())
        );
        assert_eq!(extract_action_verb("xyz qqq"), None);
    }
}