use super::config::{ClassificationPattern, TaskClassificationConfig};
use super::decision::RoutingContext;
use super::error::{RoutingError, TaskType};
use regex::Regex;
use std::collections::HashMap;
/// Classifies prompts into task types using configured keywords and regular expressions.
///
/// Regexes are compiled once, when the classifier is built or a pattern is added,
/// rather than on every classification call.
#[derive(Debug, Clone)]
pub struct TaskClassifier {
/// Keyword list, regex source strings and weight for each known task type.
patterns: HashMap<TaskType, ClassificationPattern>,
/// Compiled form of each entry's regex sources, stored in the same order as the sources.
compiled_patterns: HashMap<TaskType, Vec<Regex>>,
/// Configuration the classifier was built from; supplies the enabled flag,
/// the confidence threshold and the default task type.
config: TaskClassificationConfig,
}
/// Outcome of classifying a single prompt.
#[derive(Debug, Clone)]
pub struct ClassificationResult {
/// The selected task type, or the configured default when classification is
/// disabled, nothing matched, or confidence fell below the threshold.
pub task_type: TaskType,
/// Score of the best task type divided by its maximum possible score, capped at 1.0.
/// Set to 1.0 when classification is disabled and 0.0 when nothing matched.
pub confidence: f64,
/// Source strings of the regexes that matched, or a single marker entry
/// (`classification_disabled`, `no_patterns_matched`, `confidence_below_threshold`)
/// when the default task type was returned.
pub matched_patterns: Vec<String>,
/// Configured keywords found in the prompt; empty when the default task type was returned.
pub keyword_matches: Vec<String>,
}
impl TaskClassifier {
/// Builds a classifier from `config`, compiling every configured regex up front.
///
/// # Errors
///
/// Returns [`RoutingError::ConfigurationError`] for the first pattern string that is
/// not a valid regular expression; the error names the offending task type and pattern.
pub fn new(config: TaskClassificationConfig) -> Result<Self, RoutingError> {
    let mut compiled_patterns = HashMap::new();

    for (task_type, entry) in &config.patterns {
        // Collecting into `Result` stops at the first source that fails to compile.
        let compiled = entry
            .patterns
            .iter()
            .map(|source| {
                Regex::new(source).map_err(|err| RoutingError::ConfigurationError {
                    key: format!("classification.patterns.{}.patterns", task_type),
                    reason: format!("Invalid regex pattern '{}': {}", source, err),
                })
            })
            .collect::<Result<Vec<_>, _>>()?;
        compiled_patterns.insert(task_type.clone(), compiled);
    }

    // The classifier keeps its own copy of the pattern table next to the config.
    let patterns = config.patterns.clone();
    Ok(Self {
        patterns,
        compiled_patterns,
        config,
    })
}
/// Classifies `prompt` into a task type using keyword and regex matching.
///
/// Scoring, per task type:
/// - each configured keyword found in the prompt (case-insensitive substring match)
///   adds half of the entry's weight;
/// - each compiled regex that matches adds the full weight. Regexes are run against
///   the lowercased prompt, so a case-sensitive pattern containing uppercase letters
///   cannot match.
///
/// The scores are then adjusted using `context`, and the highest-scoring task type
/// wins. When several task types share the top score, the one whose `Display` name
/// sorts first is chosen, so the result does not depend on `HashMap` iteration order.
///
/// Confidence is the winner's adjusted score divided by its maximum possible
/// unadjusted score, capped at 1.0.
///
/// The configured default task type is returned when:
/// - classification is disabled (confidence 1.0),
/// - nothing in the prompt matched (confidence 0.0), or
/// - the confidence is below the configured threshold.
///
/// In those cases `matched_patterns` carries a single marker string naming the reason.
///
/// # Errors
///
/// The current implementation always returns `Ok`; the `Result` is kept for
/// interface stability.
pub fn classify_task(
    &self,
    prompt: &str,
    context: &RoutingContext,
) -> Result<ClassificationResult, RoutingError> {
    if !self.config.enabled {
        return Ok(ClassificationResult {
            task_type: self.config.default_task_type.clone(),
            confidence: 1.0,
            matched_patterns: vec!["classification_disabled".to_string()],
            keyword_matches: Vec::new(),
        });
    }

    let prompt_lower = prompt.to_lowercase();
    let mut scores = HashMap::new();
    let mut all_matches = HashMap::new();

    for (task_type, pattern) in &self.patterns {
        let mut score = 0.0;
        let mut matches = Vec::new();
        let mut keyword_matches = Vec::new();

        for keyword in &pattern.keywords {
            if prompt_lower.contains(&keyword.to_lowercase()) {
                score += pattern.weight * 0.5;
                keyword_matches.push(keyword.clone());
            }
        }

        if let Some(regexes) = self.compiled_patterns.get(task_type) {
            // Pair each compiled regex with its source string instead of indexing
            // `pattern.patterns`, so a length mismatch can never panic.
            for (regex, source) in regexes.iter().zip(&pattern.patterns) {
                if regex.is_match(&prompt_lower) {
                    score += pattern.weight;
                    matches.push(source.clone());
                }
            }
        }

        // Only task types with a positive score take part in the ranking.
        if score > 0.0 {
            scores.insert(task_type.clone(), score);
            all_matches.insert(task_type.clone(), (matches, keyword_matches));
        }
    }

    self.apply_context_adjustments(&mut scores, context);

    // `max_by` returns the last maximal element and `HashMap` iteration order is
    // arbitrary, so equal scores need an explicit tie-break to be reproducible.
    // The comparison is reversed so the alphabetically first name ranks highest.
    let best_match = scores.iter().max_by(|a, b| {
        a.1.partial_cmp(b.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| b.0.to_string().cmp(&a.0.to_string()))
    });

    let (task_type, raw_score) = match best_match {
        Some((task_type, score)) => (task_type.clone(), *score),
        None => {
            return Ok(ClassificationResult {
                task_type: self.config.default_task_type.clone(),
                confidence: 0.0,
                matched_patterns: vec!["no_patterns_matched".to_string()],
                keyword_matches: Vec::new(),
            });
        }
    };

    // Context multipliers can push the score above the unadjusted maximum, hence the cap.
    let max_possible_score = self.calculate_max_possible_score(&task_type);
    let confidence = if max_possible_score > 0.0 {
        (raw_score / max_possible_score).min(1.0)
    } else {
        0.0
    };

    if confidence < self.config.confidence_threshold {
        return Ok(ClassificationResult {
            task_type: self.config.default_task_type.clone(),
            confidence,
            matched_patterns: vec!["confidence_below_threshold".to_string()],
            keyword_matches: Vec::new(),
        });
    }

    let (matched_patterns, keyword_matches) = all_matches
        .get(&task_type)
        .cloned()
        .unwrap_or((Vec::new(), Vec::new()));

    Ok(ClassificationResult {
        task_type,
        confidence,
        matched_patterns,
        keyword_matches,
    })
}
/// Scales the scores in place according to hints carried by `context`.
///
/// Only task types already present in `scores` are touched; a task type that did
/// not score is never introduced here. Adjustments are applied in three passes:
/// expected output type, then each agent capability in order, then security level.
fn apply_context_adjustments(
    &self,
    scores: &mut HashMap<TaskType, f64>,
    context: &RoutingContext,
) {
    use super::decision::{OutputType, SecurityLevel};

    // Multiplies the score of `task` by `factor` when that task has a score.
    let mut scale = |task: TaskType, factor: f64| {
        if let Some(score) = scores.get_mut(&task) {
            *score *= factor;
        }
    };

    // Pass 1: the kind of output the caller expects.
    match context.expected_output_type {
        OutputType::Code => {
            scale(TaskType::CodeGeneration, 1.5);
            scale(TaskType::BoilerplateCode, 1.3);
        }
        OutputType::Json | OutputType::Structured => {
            scale(TaskType::Extract, 1.4);
            scale(TaskType::Analysis, 1.2);
        }
        _ => {}
    }

    // Pass 2: capabilities declared by the agent; a repeated capability applies again.
    for capability in &context.agent_capabilities {
        match capability.as_str() {
            "code_generation" => scale(TaskType::CodeGeneration, 1.2),
            "analysis" => {
                scale(TaskType::Analysis, 1.2);
                scale(TaskType::Reasoning, 1.1);
            }
            "translation" => scale(TaskType::Translation, 1.3),
            _ => {}
        }
    }

    // Pass 3: high-security agents favour intent/extraction and slightly penalise reasoning.
    match context.agent_security_level {
        SecurityLevel::Critical | SecurityLevel::High => {
            scale(TaskType::Intent, 1.1);
            scale(TaskType::Extract, 1.1);
            scale(TaskType::Reasoning, 0.9);
        }
        _ => {}
    }
}
/// Returns the score `task_type` would reach if every keyword and every regex matched,
/// before any context adjustment.
///
/// Keywords count for half the entry's weight and regexes for the full weight.
/// Falls back to `1.0` when no pattern entry exists for `task_type`.
fn calculate_max_possible_score(&self, task_type: &TaskType) -> f64 {
    match self.patterns.get(task_type) {
        Some(entry) => {
            let from_keywords = entry.keywords.len() as f64 * entry.weight * 0.5;
            let from_regexes = entry.patterns.len() as f64 * entry.weight;
            from_keywords + from_regexes
        }
        None => 1.0,
    }
}
/// Registers `pattern` for `task_type`, replacing any entry already stored for it.
///
/// All regex sources are compiled before anything is stored, so a failure leaves
/// the classifier unchanged.
///
/// # Errors
///
/// Returns [`RoutingError::ConfigurationError`] for the first pattern string that is
/// not a valid regular expression.
pub fn add_pattern(
    &mut self,
    task_type: TaskType,
    pattern: ClassificationPattern,
) -> Result<(), RoutingError> {
    // Collecting into `Result` stops at the first source that fails to compile.
    let compiled = pattern
        .patterns
        .iter()
        .map(|source| {
            Regex::new(source).map_err(|err| RoutingError::ConfigurationError {
                key: format!("pattern.{}", task_type),
                reason: format!("Invalid regex pattern '{}': {}", source, err),
            })
        })
        .collect::<Result<Vec<_>, _>>()?;

    self.compiled_patterns.insert(task_type.clone(), compiled);
    self.patterns.insert(task_type, pattern);
    Ok(())
}
/// Forgets everything stored for `task_type`: both its compiled regexes and its
/// pattern entry. Does nothing when the task type is unknown.
pub fn remove_pattern(&mut self, task_type: &TaskType) {
    self.compiled_patterns.remove(task_type);
    self.patterns.remove(task_type);
}
/// Summarises the classifier's current pattern table and the relevant config values.
///
/// The order of `task_type_coverage` follows `HashMap` iteration and is therefore
/// unspecified.
pub fn get_statistics(&self) -> ClassificationStatistics {
    // One pass over the entries accumulates both totals.
    let (total_keywords, total_regex_patterns) = self.patterns.values().fold(
        (0usize, 0usize),
        |(keywords, regexes), entry| {
            (
                keywords + entry.keywords.len(),
                regexes + entry.patterns.len(),
            )
        },
    );
    let task_type_coverage = self.patterns.keys().cloned().collect();

    ClassificationStatistics {
        total_patterns: self.patterns.len(),
        task_type_coverage,
        total_keywords,
        total_regex_patterns,
        confidence_threshold: self.config.confidence_threshold,
        default_task_type: self.config.default_task_type.clone(),
    }
}
}
/// Summary of a classifier's pattern table and thresholds.
#[derive(Debug, Clone)]
pub struct ClassificationStatistics {
/// Number of task types that have a pattern entry.
pub total_patterns: usize,
/// The task types that have a pattern entry, in unspecified order.
pub task_type_coverage: Vec<TaskType>,
/// Keyword count summed over all pattern entries.
pub total_keywords: usize,
/// Regex source-string count summed over all pattern entries.
pub total_regex_patterns: usize,
/// Minimum confidence below which classification falls back to the default task type.
pub confidence_threshold: f64,
/// Task type reported when classification is disabled, nothing matches,
/// or confidence is below the threshold.
pub default_task_type: TaskType,
}
#[cfg(test)]
mod tests {
    use super::super::decision::{OutputType, RoutingContext};
    use super::*;
    use crate::types::AgentId;

    /// Builds an enabled configuration with a code-generation and an analysis entry,
    /// a 0.3 confidence threshold and a custom "unknown" default task type.
    fn create_test_config() -> TaskClassificationConfig {
        /// Converts string literals into owned strings.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| (*item).to_string()).collect()
        }

        let mut patterns = HashMap::new();
        patterns.insert(
            TaskType::CodeGeneration,
            ClassificationPattern {
                keywords: owned(&["code", "function", "implement"]),
                patterns: owned(&[r"write.*code", r"implement.*function"]),
                weight: 1.0,
            },
        );
        patterns.insert(
            TaskType::Analysis,
            ClassificationPattern {
                keywords: owned(&["analyze", "analysis", "examine"]),
                patterns: owned(&[r"analyze.*data", r"perform.*analysis"]),
                weight: 1.0,
            },
        );

        TaskClassificationConfig {
            enabled: true,
            patterns,
            confidence_threshold: 0.3,
            default_task_type: TaskType::Custom("unknown".to_string()),
        }
    }

    /// Builds a routing context for a fresh agent with an "unknown" custom task type.
    fn create_test_context() -> RoutingContext {
        let agent = AgentId::new();
        let declared_task = TaskType::Custom("unknown".to_string());
        RoutingContext::new(agent, declared_task, "test prompt".to_string())
    }

    #[test]
    fn test_code_generation_classification() {
        let classifier = TaskClassifier::new(create_test_config()).unwrap();
        let context = create_test_context();
        let prompt = "Please write code to implement a sorting function";

        let outcome = classifier.classify_task(prompt, &context).unwrap();

        assert_eq!(outcome.task_type, TaskType::CodeGeneration);
        assert!(outcome.confidence > 0.5);
        assert!(!outcome.keyword_matches.is_empty());
    }

    #[test]
    fn test_analysis_classification() {
        let classifier = TaskClassifier::new(create_test_config()).unwrap();
        let context = create_test_context();
        let prompt = "Please analyze the data trends";

        let outcome = classifier.classify_task(prompt, &context).unwrap();

        assert_eq!(outcome.task_type, TaskType::Analysis);
        assert!(outcome.confidence > 0.3);
    }

    #[test]
    fn test_no_match_fallback() {
        let classifier = TaskClassifier::new(create_test_config()).unwrap();
        let context = create_test_context();

        let outcome = classifier.classify_task("Hello world", &context).unwrap();

        // Nothing matches, so the configured default comes back with zero confidence.
        assert_eq!(outcome.task_type, TaskType::Custom("unknown".to_string()));
        assert_eq!(outcome.confidence, 0.0);
    }

    #[test]
    fn test_context_adjustments() {
        let classifier = TaskClassifier::new(create_test_config()).unwrap();

        // Both hints favour code generation.
        let mut context = create_test_context();
        context.expected_output_type = OutputType::Code;
        context.agent_capabilities = vec!["code_generation".to_string()];

        let outcome = classifier
            .classify_task("Please write some code", &context)
            .unwrap();

        assert_eq!(outcome.task_type, TaskType::CodeGeneration);
        assert!(outcome.confidence > 0.5);
    }
}