use super::ClaudeFlowTask;
/// Coarse category of work that a task description is classified into.
///
/// Variant order determines the values produced by `as` casts, so it is kept stable.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TaskType {
    /// Writing or generating code.
    Code,
    /// Investigating or analysing a topic.
    Research,
    /// Writing or running tests.
    Test,
    /// Reviewing or auditing existing work.
    Review,
    /// Writing documentation.
    Docs,
    /// Diagnosing and fixing defects.
    Debug,
    /// System design and structure.
    Architecture,
    /// Security-focused work.
    Security,
    /// Performance and optimisation work.
    Performance,
    /// Fallback category; no `ClaudeFlowTask` variant converts to it.
    General,
}
/// Maps a Claude Flow task category onto the classifier's [`TaskType`].
///
/// The mapping is many-to-one: code generation and refactoring both become
/// [`TaskType::Code`]; every other category maps to exactly one corresponding type.
impl From<ClaudeFlowTask> for TaskType {
    fn from(task: ClaudeFlowTask) -> Self {
        match task {
            ClaudeFlowTask::CodeGeneration | ClaudeFlowTask::Refactoring => Self::Code,
            ClaudeFlowTask::CodeReview => Self::Review,
            ClaudeFlowTask::Testing => Self::Test,
            ClaudeFlowTask::Research => Self::Research,
            ClaudeFlowTask::Documentation => Self::Docs,
            ClaudeFlowTask::Debugging => Self::Debug,
            ClaudeFlowTask::Security => Self::Security,
            ClaudeFlowTask::Performance => Self::Performance,
            ClaudeFlowTask::Architecture => Self::Architecture,
        }
    }
}
/// Outcome of classifying a single task description with `TaskClassifier::classify`.
#[derive(Clone, Debug)]
pub struct ClassificationResult {
    /// Primary task type chosen by the classifier.
    pub task_type: TaskType,
    /// Score of the best-matching type: the fraction of its keywords found (0.0 to 1.0).
    pub confidence: f32,
    /// Runner-up types with their scores, best first; at most three, each scoring above 0.1.
    pub secondary: Vec<(TaskType, f32)>,
    /// Names of languages whose patterns matched, in pattern-table order.
    pub languages: Vec<String>,
    /// Names of frameworks whose patterns matched, in pattern-table order.
    pub frameworks: Vec<String>,
    /// Heuristic complexity estimate, from 3 (simple) up to 10.
    pub complexity: u8,
    /// Suggested number of agents to assign, from 1 to 6.
    pub recommended_agents: u8,
}
/// Keyword-driven classifier that turns a free-text task description into a
/// [`ClassificationResult`].
///
/// All matching is heuristic: the description is lowercased and scanned for
/// per-category keywords (see the `score_*` methods) and for the language and
/// framework patterns held in the tables below.
pub struct TaskClassifier {
    /// `(language, patterns)` pairs; a language is reported if any of its patterns matches.
    language_patterns: Vec<(String, Vec<&'static str>)>,
    /// `(framework, patterns)` pairs; a framework is reported if any of its patterns matches.
    framework_patterns: Vec<(String, Vec<&'static str>)>,
}
impl TaskClassifier {
    /// Keywords no longer than this many bytes (`ts`, `js`, `pr`) are abbreviations
    /// and must match a whole word; longer keywords are stems that may take a suffix.
    const WHOLE_WORD_MAX_LEN: usize = 2;

    /// Creates a classifier loaded with the built-in language and framework tables.
    pub fn new() -> Self {
        Self {
            language_patterns: Self::build_language_patterns(),
            framework_patterns: Self::build_framework_patterns(),
        }
    }

    /// Built-in language table. Patterns are lowercase because they are matched
    /// against the lowercased description.
    fn build_language_patterns() -> Vec<(String, Vec<&'static str>)> {
        vec![
            (
                "rust".to_string(),
                vec!["rust", "cargo", ".rs", "tokio", "async-std", "serde"],
            ),
            (
                "typescript".to_string(),
                vec!["typescript", "ts", ".tsx", "deno", "bun"],
            ),
            (
                "javascript".to_string(),
                vec!["javascript", "js", "node", "npm", "react", "vue"],
            ),
            (
                "python".to_string(),
                vec!["python", "pip", ".py", "django", "flask", "pytorch"],
            ),
            ("go".to_string(), vec!["golang", "go ", ".go", "goroutine"]),
        ]
    }

    /// Built-in framework table (lowercase patterns, as for languages).
    fn build_framework_patterns() -> Vec<(String, Vec<&'static str>)> {
        vec![
            (
                "react".to_string(),
                vec!["react", "jsx", "tsx", "next.js", "nextjs"],
            ),
            (
                "express".to_string(),
                vec!["express", "middleware", "router"],
            ),
            (
                "tokio".to_string(),
                vec!["tokio", "async", "await", "spawn"],
            ),
            ("actix".to_string(), vec!["actix", "actix-web"]),
            (
                "jest".to_string(),
                vec!["jest", "describe", "it(", "expect("],
            ),
            ("pytest".to_string(), vec!["pytest", "test_", "fixture"]),
        ]
    }

    /// Classifies a free-text task description.
    ///
    /// Matching is case-insensitive because the description is lowercased once up
    /// front. The result has the following parts:
    ///
    /// - the primary task type and its confidence;
    /// - up to three runner-up types scoring above 0.1;
    /// - the detected languages and frameworks;
    /// - a complexity estimate in `3..=10`;
    /// - a recommended agent count in `1..=6`.
    pub fn classify(&self, description: &str) -> ClassificationResult {
        let lower = description.to_lowercase();
        let (task_type, confidence, secondary) = self.detect_task_type(&lower);
        let languages = self.detect_languages(&lower);
        let frameworks = self.detect_frameworks(&lower);
        let complexity = self.estimate_complexity(&lower, &languages);
        let recommended_agents = self.recommend_agent_count(complexity, &secondary);
        ClassificationResult {
            task_type,
            confidence,
            secondary,
            languages,
            frameworks,
            complexity,
            recommended_agents,
        }
    }

    /// Scores every specific task type and returns `(primary, confidence, secondary)`.
    ///
    /// Ties keep declaration order because the sort is stable, so `Code` wins a tie.
    /// When no keyword of any type matches, the primary type is [`TaskType::General`]
    /// with confidence `0.0`. `secondary` holds up to the next three types scoring
    /// above 0.1, best first.
    fn detect_task_type(&self, lower: &str) -> (TaskType, f32, Vec<(TaskType, f32)>) {
        let mut scores: Vec<(TaskType, f32)> = vec![
            (TaskType::Code, self.score_code(lower)),
            (TaskType::Research, self.score_research(lower)),
            (TaskType::Test, self.score_test(lower)),
            (TaskType::Review, self.score_review(lower)),
            (TaskType::Docs, self.score_docs(lower)),
            (TaskType::Debug, self.score_debug(lower)),
            (TaskType::Architecture, self.score_architecture(lower)),
            (TaskType::Security, self.score_security(lower)),
            (TaskType::Performance, self.score_performance(lower)),
        ];
        // Scores are finite ratios (`keyword_score` never divides by zero), so
        // `partial_cmp` always succeeds; `Equal` is only a defensive fallback.
        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let (top_type, top_score) = scores[0];
        let secondary: Vec<(TaskType, f32)> = scores
            .iter()
            .skip(1)
            .take(3)
            .filter(|(_, s)| *s > 0.1)
            .copied()
            .collect();
        // Nothing matched at all: report the catch-all instead of an arbitrary
        // zero-confidence `Code`.
        let primary = if top_score > 0.0 {
            top_type
        } else {
            TaskType::General
        };
        (primary, top_score, secondary)
    }

    /// Fraction of code-writing keywords present.
    fn score_code(&self, s: &str) -> f32 {
        let keywords = [
            "implement",
            "create",
            "build",
            "code",
            "write",
            "function",
            "class",
            "module",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of research keywords present.
    fn score_research(&self, s: &str) -> f32 {
        let keywords = [
            "research",
            "analyze",
            "investigate",
            "explore",
            "find",
            "understand",
            "learn",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of testing keywords present.
    fn score_test(&self, s: &str) -> f32 {
        let keywords = [
            "test",
            "verify",
            "validate",
            "assert",
            "coverage",
            "unit",
            "integration",
            "e2e",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of review keywords present.
    fn score_review(&self, s: &str) -> f32 {
        let keywords = [
            "review", "audit", "inspect", "check", "quality", "lint", "pr",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of documentation keywords present.
    fn score_docs(&self, s: &str) -> f32 {
        let keywords = [
            "document", "readme", "api docs", "comment", "explain", "describe",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of debugging keywords present.
    fn score_debug(&self, s: &str) -> f32 {
        let keywords = [
            "debug",
            "fix",
            "error",
            "bug",
            "issue",
            "crash",
            "exception",
            "trace",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of architecture keywords present.
    fn score_architecture(&self, s: &str) -> f32 {
        let keywords = [
            "architecture",
            "design",
            "structure",
            "pattern",
            "system",
            "scalable",
            "modular",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of security keywords present.
    fn score_security(&self, s: &str) -> f32 {
        let keywords = [
            "security",
            "vulnerability",
            "cve",
            "injection",
            "auth",
            "encrypt",
            "xss",
            "csrf",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction of performance keywords present.
    fn score_performance(&self, s: &str) -> f32 {
        let keywords = [
            "performance",
            "optimize",
            "speed",
            "memory",
            "benchmark",
            "profile",
            "latency",
            "throughput",
        ];
        self.keyword_score(s, &keywords)
    }

    /// Fraction (0.0 to 1.0) of `keywords` found in `text` by [`Self::contains_keyword`].
    ///
    /// Returns `0.0` for an empty keyword list. The previous unguarded `0 / 0`
    /// produced NaN, and `NaN.min(1.0)` evaluates to a bogus full score of `1.0`.
    fn keyword_score(&self, text: &str, keywords: &[&str]) -> f32 {
        if keywords.is_empty() {
            return 0.0;
        }
        let matches = keywords
            .iter()
            .filter(|&&k| Self::contains_keyword(text, k))
            .count();
        (matches as f32 / keywords.len() as f32).min(1.0)
    }

    /// Reports whether `keyword` occurs in `text` as a word rather than as an
    /// arbitrary substring.
    ///
    /// A match must begin at a word start (start of text, or after a
    /// non-alphanumeric character) unless the keyword itself begins with
    /// punctuation, as `.rs` or `.py` do. Longer keywords act as stems and may be
    /// followed by a suffix: `test` matches `tests` and `encrypt` matches
    /// `encryption`. Keywords of at most [`Self::WHOLE_WORD_MAX_LEN`] bytes must
    /// also end at a word end, so `ts` no longer fires on `tests` and `pr` no
    /// longer fires on `print` or `improve`.
    ///
    /// NOTE(review): stem matching still admits some false positives, such as `bun`
    /// in `bundle` or `auth` in `author`. Tighten individual keywords if that matters.
    fn contains_keyword(text: &str, keyword: &str) -> bool {
        fn is_word_char(c: char) -> bool {
            c.is_alphanumeric()
        }
        let first = match keyword.chars().next() {
            Some(c) => c,
            None => return false,
        };
        let needs_word_start = is_word_char(first);
        let needs_word_end = keyword.len() <= Self::WHOLE_WORD_MAX_LEN;
        text.match_indices(keyword).any(|(start, _)| {
            let end = start + keyword.len();
            let at_word_start = !needs_word_start
                || text[..start]
                    .chars()
                    .next_back()
                    .map_or(true, |c| !is_word_char(c));
            let at_word_end = !needs_word_end
                || text[end..].chars().next().map_or(true, |c| !is_word_char(c));
            at_word_start && at_word_end
        })
    }

    /// Names of the languages with at least one matching pattern, in table order.
    fn detect_languages(&self, lower: &str) -> Vec<String> {
        Self::names_matching(&self.language_patterns, lower)
    }

    /// Names of the frameworks with at least one matching pattern, in table order.
    fn detect_frameworks(&self, lower: &str) -> Vec<String> {
        Self::names_matching(&self.framework_patterns, lower)
    }

    /// Shared scan behind language and framework detection. Returns each name whose
    /// pattern list has any word-aware match in `lower`.
    fn names_matching(table: &[(String, Vec<&'static str>)], lower: &str) -> Vec<String> {
        table
            .iter()
            .filter(|(_, patterns)| patterns.iter().any(|p| Self::contains_keyword(lower, p)))
            .map(|(name, _)| name.clone())
            .collect()
    }

    /// Estimates task complexity on a `3..=10` scale.
    ///
    /// The estimate is built as follows:
    ///
    /// - start at 3;
    /// - add one per detected language beyond the first;
    /// - add 2 for distributed or concurrent work;
    /// - add 1 for a migration or refactor;
    /// - add 1 for security or authentication work;
    /// - cap the total at 10.
    fn estimate_complexity(&self, lower: &str, languages: &[String]) -> u8 {
        // Clamp before narrowing so a long language list cannot wrap via `as u8`.
        let extra_languages = languages.len().saturating_sub(1).min(u8::MAX as usize) as u8;
        let mut complexity: u8 = 3u8.saturating_add(extra_languages);
        if lower.contains("distributed") || lower.contains("concurrent") {
            complexity = complexity.saturating_add(2);
        }
        if lower.contains("migration") || lower.contains("refactor") {
            complexity = complexity.saturating_add(1);
        }
        if lower.contains("security") || lower.contains("authentication") {
            complexity = complexity.saturating_add(1);
        }
        complexity.min(10)
    }

    /// Recommends how many agents to assign, capped at 6.
    ///
    /// The base count comes from `complexity` bands:
    ///
    /// - `0..=3` gives 1;
    /// - `4..=6` gives 2;
    /// - `7..=8` gives 3;
    /// - anything higher gives 4.
    ///
    /// One agent is added per strong secondary type (score above 0.3), at most two extra.
    fn recommend_agent_count(&self, complexity: u8, secondary: &[(TaskType, f32)]) -> u8 {
        let base: u8 = match complexity {
            // 0 belongs with the simplest band, not the catch-all maximum.
            0..=3 => 1,
            4..=6 => 2,
            7..=8 => 3,
            _ => 4,
        };
        let strong_secondaries = secondary
            .iter()
            .filter(|(_, score)| *score > 0.3)
            .count()
            .min(2) as u8;
        (base + strong_secondaries).min(6)
    }
}
impl Default for TaskClassifier {
    /// Same as [`TaskClassifier::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A plainly code-oriented request is typed as `Code`, and its language and
    /// framework are picked up from the description.
    #[test]
    fn test_classification() {
        let result = TaskClassifier::default()
            .classify("implement a REST API endpoint in Rust using actix-web");

        assert_eq!(result.task_type, TaskType::Code);
        assert!(result.languages.iter().any(|lang| lang == "rust"));
        assert!(result.frameworks.iter().any(|fw| fw == "actix"));
    }

    /// Distributed, security-sensitive work must rate as more complex than a trivial fix.
    #[test]
    fn test_complexity() {
        let classifier = TaskClassifier::new();
        let (simple, complex) = (
            classifier.classify("fix a typo"),
            classifier.classify("implement distributed authentication with security audit"),
        );
        assert!(
            complex.complexity > simple.complexity,
            "expected {} > {}",
            complex.complexity,
            simple.complexity
        );
    }
}