/// Kind of work a prompt asks for, as chosen by [`classify`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskCategory {
    /// Questions and look-ups; also the fallback when no rule matches and no files are involved.
    Research,
    /// Small mechanical edits such as renames or typo fixes.
    SimpleEdit,
    /// New implementation work; also the fallback when no rule matches but files are involved.
    ComplexImpl,
    /// UI / web front-end work.
    Frontend,
    /// Investigating errors, crashes or panics.
    Debugging,
    /// Work on tests, mocks, fixtures or benchmarks.
    Testing,
    /// Restructuring existing code.
    Refactoring,
    /// Work on docs, READMEs, changelogs or comments.
    Documentation,
}

/// Coarse size estimate for a task, as chosen by [`classify`] from the prompt length, the file
/// count and the presence of scope words.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Complexity {
    /// Short prompt touching at most one file.
    Low,
    /// Anything that is neither `Low` nor `High`.
    Medium,
    /// Long prompt, many files, or a prompt that mentions a wide scope.
    High,
}

/// Result of [`classify`]: what kind of task a prompt describes and how large it looks.
///
/// Derives the same traits as its two fields so a profile can be logged, compared in tests and
/// passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TaskProfile {
    /// Kind of work requested.
    pub category: TaskCategory,
    /// Estimated size of the work.
    pub complexity: Complexity,
}
impl TaskCategory {
    /// Returns the lowercase, hyphen-separated name of this category (for example
    /// `"simple-edit"` for [`TaskCategory::SimpleEdit`]).
    pub fn label(&self) -> &'static str {
        // Exhaustive on purpose: a new variant must be given a label before this compiles.
        match *self {
            Self::Research => "research",
            Self::SimpleEdit => "simple-edit",
            Self::ComplexImpl => "complex-impl",
            Self::Frontend => "frontend",
            Self::Debugging => "debugging",
            Self::Testing => "testing",
            Self::Refactoring => "refactoring",
            Self::Documentation => "documentation",
        }
    }
}
impl Complexity {
    /// Returns the lowercase name of this complexity level: `"low"`, `"medium"` or `"high"`.
    pub fn label(&self) -> &'static str {
        // Exhaustive on purpose: a new variant must be given a label before this compiles.
        match *self {
            Self::Low => "low",
            Self::Medium => "medium",
            Self::High => "high",
        }
    }
}
// Keyword tables used by `classify`. Every entry is lowercase because `classify` lowercases the
// prompt before matching.

/// Prompt openings that `classify` treats as research requests; tested with `starts_with`
/// against the trimmed, lowercased prompt.
const RESEARCH_PREFIXES: &[&str] = &[
"research:",
"what is",
"how does",
"explain",
"find",
"list",
];
/// Substrings that mark a prompt as research. A bare `?` anywhere in the prompt qualifies.
const RESEARCH_TERMS: &[&str] = &["?", "documentation", "compare", "analyze"];
/// Substrings that mark a prompt as a simple edit. Matching is plain substring search, so
/// `change` also matches inside `changelog`.
const SIMPLE_EDIT_TERMS: &[&str] = &[
"rename",
"change",
"update",
"fix typo",
"add type",
"annotation",
];
/// Terms that mark a prompt as frontend work. Unlike the other tables these are matched as
/// whole words (via `contains_any_word`), so `ui` does not fire inside `build`.
const FRONTEND_TERMS: &[&str] = &[
"ui",
"frontend",
"css",
"html",
"react",
"component",
"layout",
"design",
"responsive",
];
/// Substrings that mark a prompt as new implementation work.
const COMPLEX_IMPL_TERMS: &[&str] = &["implement", "create", "build"];
/// Substrings that mark a prompt as debugging work.
const DEBUGGING_TERMS: &[&str] = &[
"debug",
"fix bug",
"investigate",
"error",
"crash",
"panic",
"trace",
"root cause",
];
/// Substrings that mark a prompt as testing work.
const TESTING_TERMS: &[&str] = &[
"test",
"spec",
"coverage",
"assertion",
"mock",
"fixture",
"benchmark",
];
/// Substrings that mark a prompt as refactoring work.
const REFACTORING_TERMS: &[&str] = &[
"refactor",
"restructure",
"extract",
"split",
"reorganize",
"decouple",
"modularize",
];
/// Substrings that mark a prompt as documentation work.
const DOCUMENTATION_TERMS: &[&str] = &[
"document",
"readme",
"changelog",
"comment",
"docstring",
"api doc",
"jsdoc",
];
/// Phrases for routine chores (running tests, formatting, linting, small doc updates).
// NOTE(review): not referenced by the code visible in this file; it is `pub(crate)`, so it is
// presumably consumed by another module. Confirm at the call sites.
pub(crate) const LOW_VALUE_TERMS: &[&str] = &[
"run test",
"cargo test",
"cargo fmt",
"cargo clippy",
"format code",
"lint",
"update docs",
"update readme",
"update changelog",
"add comment",
"add docstring",
"type annotation",
];
/// File extensions that `count_file_mentions` accepts as evidence that a token names a file.
/// They are compared with `ends_with` against the token as written (no lowercasing).
const FILE_SUFFIXES: &[&str] = &[
".rs", ".toml", ".md", ".json", ".yaml", ".yml", ".ts", ".tsx", ".js", ".jsx", ".css", ".html",
];
pub fn classify(prompt: &str, file_count: usize, prompt_len: usize) -> TaskProfile {
let norm = prompt.trim().to_lowercase();
let category = if contains_any_word(&norm, FRONTEND_TERMS) {
TaskCategory::Frontend
} else if RESEARCH_PREFIXES.iter().any(|p| norm.starts_with(p))
|| contains_any(&norm, RESEARCH_TERMS)
{
TaskCategory::Research
} else if contains_any(&norm, SIMPLE_EDIT_TERMS) {
TaskCategory::SimpleEdit
} else if contains_any(&norm, COMPLEX_IMPL_TERMS) {
TaskCategory::ComplexImpl
} else if contains_any(&norm, TESTING_TERMS) {
TaskCategory::Testing
} else if contains_any(&norm, DEBUGGING_TERMS) {
TaskCategory::Debugging
} else if contains_any(&norm, DOCUMENTATION_TERMS) {
TaskCategory::Documentation
} else if contains_any(&norm, REFACTORING_TERMS) {
TaskCategory::Refactoring
} else if file_count > 0 {
TaskCategory::ComplexImpl
} else {
TaskCategory::Research
};
let has_scope = contains_any(&norm, &["across", "all files", "entire"]);
let complexity = if prompt_len > 500 || file_count > 3 || has_scope {
Complexity::High
} else if prompt_len < 150 && file_count <= 1 {
Complexity::Low
} else {
Complexity::Medium
};
TaskProfile {
category,
complexity,
}
}
pub(crate) fn count_file_mentions(prompt: &str) -> usize {
prompt
.split_whitespace()
.map(trim_token)
.filter(|tok| tok.contains('/') || FILE_SUFFIXES.iter().any(|s| tok.ends_with(s)))
.count()
}
/// Strips surrounding punctuation from a whitespace-delimited token so it can be tested as a
/// file reference.
///
/// Characters other than alphanumerics, `.`, `_` and `/` are removed from both ends. Trailing
/// dots are removed as well, even when mixed with closing punctuation, so a file name that ends
/// a sentence (`"main.rs."`, `"(main.rs)."`) keeps a recognizable suffix. Leading dots are
/// preserved so `./src` and `.gitignore` come through unchanged.
fn trim_token(token: &str) -> &str {
    let is_path_char = |ch: char| ch.is_alphanumeric() || matches!(ch, '.' | '_' | '/');
    token
        .trim_start_matches(|ch: char| !is_path_char(ch))
        // A dot at the very end is sentence punctuation, never part of a file suffix.
        .trim_end_matches(|ch: char| ch == '.' || !is_path_char(ch))
}
/// Returns `true` when `prompt` contains at least one of `terms` as a substring.
///
/// Matching is case-sensitive and ignores word boundaries (`"test"` matches inside
/// `"latest"`). An empty `terms` slice yields `false`.
pub(crate) fn contains_any(prompt: &str, terms: &[&str]) -> bool {
    for &needle in terms {
        if prompt.contains(needle) {
            return true;
        }
    }
    false
}
/// Returns `true` when any of `terms` occurs in `text` as a standalone word.
///
/// An occurrence counts only if the byte before it and the byte after it (when present) are not
/// ASCII letters or digits. Every other byte is a boundary, including `_`, punctuation and any
/// byte of a non-ASCII character, so an ASCII term embedded in non-ASCII text still matches.
fn contains_any_word(text: &str, terms: &[&str]) -> bool {
    let raw = text.as_bytes();
    // Out-of-range positions (past the end of the text) are treated as "not a word byte".
    let is_word_byte =
        |pos: usize| matches!(raw.get(pos), Some(byte) if byte.is_ascii_alphanumeric());

    for &term in terms {
        for (start, hit) in text.match_indices(term) {
            let stop = start + hit.len();
            let open_left = start == 0 || !is_word_byte(start - 1);
            let open_right = !is_word_byte(stop);
            if open_left && open_right {
                return true;
            }
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Classifies `prompt` with `prompt_len` taken from the prompt itself, so the length handed
    /// to `classify` cannot drift from the text under test.
    fn classify_prompt(prompt: &str, file_count: usize) -> TaskProfile {
        classify(prompt, file_count, prompt.len())
    }

    #[test]
    fn research_question() {
        let p = classify_prompt("Explain the authentication flow?", 0);
        assert_eq!(p.category, TaskCategory::Research);
    }

    #[test]
    fn simple_edit_rename() {
        let p = classify_prompt("rename field in types.rs", 1);
        assert_eq!(p.category, TaskCategory::SimpleEdit);
    }

    #[test]
    fn frontend_react() {
        let p = classify_prompt("Create responsive React component", 0);
        assert_eq!(p.category, TaskCategory::Frontend);
    }

    #[test]
    fn frontend_requires_whole_word() {
        // "build" contains "ui" as a substring but not as a word.
        let p = classify_prompt("build the parser", 0);
        assert_eq!(p.category, TaskCategory::ComplexImpl);
        assert!(contains_any_word("new ui layout", &["ui"]));
        assert!(!contains_any_word("build", &["ui"]));
    }

    #[test]
    fn complex_impl_long() {
        let prompt = "Implement a multi-file feature across many modules. ".repeat(12);
        let p = classify_prompt(&prompt, 5);
        assert_eq!(p.category, TaskCategory::ComplexImpl);
        assert_eq!(p.complexity, Complexity::High);
    }

    #[test]
    fn debugging_category() {
        let p = classify_prompt("debug the panic in parser", 0);
        assert_eq!(p.category, TaskCategory::Debugging);
    }

    #[test]
    fn testing_category() {
        let p = classify_prompt("add unit tests for auth module", 0);
        assert_eq!(p.category, TaskCategory::Testing);
    }

    #[test]
    fn refactoring_category() {
        let p = classify_prompt("refactor the dispatch module", 0);
        assert_eq!(p.category, TaskCategory::Refactoring);
    }

    #[test]
    fn documentation_category() {
        let p = classify_prompt("write a docstring for the parser", 0);
        assert_eq!(p.category, TaskCategory::Documentation);
    }

    #[test]
    fn fallback_depends_on_file_count() {
        // No keyword rule matches this prompt, so only the file count decides.
        assert_eq!(
            classify_prompt("tidy lib.rs", 1).category,
            TaskCategory::ComplexImpl
        );
        assert_eq!(
            classify_prompt("tidy lib.rs", 0).category,
            TaskCategory::Research
        );
    }

    #[test]
    fn low_complexity_short() {
        let p = classify_prompt("fix typo in name", 0);
        assert_eq!(p.complexity, Complexity::Low);
    }

    #[test]
    fn medium_complexity_two_files() {
        // Short prompt, but more than one file and no scope word.
        let p = classify_prompt("refactor the dispatch module", 2);
        assert_eq!(p.complexity, Complexity::Medium);
    }

    #[test]
    fn scope_word_forces_high_complexity() {
        let p = classify_prompt("rename the field across the crate", 1);
        assert_eq!(p.complexity, Complexity::High);
    }

    #[test]
    fn high_complexity_long() {
        let prompt = "x".repeat(600);
        let p = classify_prompt(&prompt, 5);
        assert_eq!(p.complexity, Complexity::High);
    }

    #[test]
    fn counts_paths_and_known_suffixes() {
        assert_eq!(
            count_file_mentions("edit src/main.rs and Cargo.toml, then README"),
            2
        );
        assert_eq!(count_file_mentions("no files here"), 0);
    }
}