use std::collections::HashSet;
use super::QueryComplexity;
/// Tunable thresholds and indicator phrases consulted by [`ComplexityDetector`].
///
/// Indicator phrases are compared against the lowercased query text, so they
/// should be written in lowercase.
#[derive(Debug, Clone)]
pub struct ComplexityConfig {
    /// Queries with at most this many words fall back to `Simple` when no
    /// other rule has fired.
    pub simple_max_words: usize,
    /// Queries with at most this many words fall back to `Medium` (longer ones
    /// to `Complex`); also the longest query a simple indicator may mark `Simple`.
    pub medium_max_words: usize,
    /// Phrases whose presence classifies a query as `Complex`.
    pub complex_indicators: Vec<String>,
    /// Phrases whose presence classifies a sufficiently short query as `Simple`.
    pub simple_indicators: Vec<String>,
}

impl Default for ComplexityConfig {
    /// Five-word simple limit, fifteen-word medium limit, and a built-in set
    /// of English indicator phrases.
    fn default() -> Self {
        // Turns a list of string literals into owned phrases, preserving order.
        fn owned(phrases: &[&str]) -> Vec<String> {
            phrases.iter().copied().map(String::from).collect()
        }

        let complex_indicators = owned(&[
            "compare",
            "contrast",
            "analyze",
            "evaluate",
            "synthesize",
            "explain why",
            "how does",
            "what are the implications",
            "relationship between",
            "cause and effect",
        ]);
        let simple_indicators = owned(&["what is", "define", "list", "who", "when", "where"]);

        Self {
            simple_max_words: 5,
            medium_max_words: 15,
            complex_indicators,
            simple_indicators,
        }
    }
}
/// Heuristic classifier that buckets free-text queries into
/// [`QueryComplexity`] tiers using indicator phrases, punctuation and length.
#[derive(Debug, Clone, Default)]
pub struct ComplexityDetector {
    config: ComplexityConfig,
}

impl ComplexityDetector {
    /// Creates a detector using [`ComplexityConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a detector with caller-supplied thresholds and indicator phrases.
    pub fn with_config(config: ComplexityConfig) -> Self {
        Self { config }
    }

    /// Classifies `query` into a [`QueryComplexity`] tier.
    ///
    /// Rules are applied in order and the first one that fires wins:
    /// 1. any configured complex indicator phrase → `Complex`;
    /// 2. any simple indicator phrase, when the query has at most
    ///    `medium_max_words` words → `Simple`;
    /// 3. more than one `?` → `Complex`;
    /// 4. at least two *distinct* conjunctions (and/or/but/however/although) → `Complex`;
    /// 5. a depth phrase ("in the context of", "with respect to", "regarding",
    ///    "about") → `Medium`;
    /// 6. otherwise by word count against `simple_max_words` / `medium_max_words`.
    ///
    /// Phrases are matched case-insensitively against whole words, ignoring
    /// punctuation at word edges: "who" does not fire on "whole", and "about"
    /// does not fire on "roundabout". Inflected forms (e.g. "compared") must be
    /// listed explicitly. Empty phrases never match.
    pub fn detect(&self, query: &str) -> QueryComplexity {
        const CONJUNCTIONS: [&str; 5] = ["and", "or", "but", "however", "although"];
        const DEPTH_INDICATORS: [&str; 4] =
            ["in the context of", "with respect to", "regarding", "about"];

        let query_lower = query.to_lowercase();
        let words = Self::tokenize(&query_lower);
        // Length thresholds use whitespace-delimited words, as `analyze` reports.
        let word_count = query.split_whitespace().count();

        if Self::mentions_any(&words, &self.config.complex_indicators) {
            return QueryComplexity::Complex;
        }
        // A simple indicator only wins for queries that are not overly long.
        if word_count <= self.config.medium_max_words
            && Self::mentions_any(&words, &self.config.simple_indicators)
        {
            return QueryComplexity::Simple;
        }
        // More than one question mark is treated as a multi-part query.
        if query.matches('?').count() > 1 {
            return QueryComplexity::Complex;
        }
        // Counts how many different conjunctions appear, not repetitions of one.
        let conjunction_count = CONJUNCTIONS
            .iter()
            .filter(|conjunction| words.contains(*conjunction))
            .count();
        if conjunction_count >= 2 {
            return QueryComplexity::Complex;
        }
        if Self::mentions_any(&words, &DEPTH_INDICATORS) {
            return QueryComplexity::Medium;
        }

        if word_count <= self.config.simple_max_words {
            QueryComplexity::Simple
        } else if word_count <= self.config.medium_max_words {
            QueryComplexity::Medium
        } else {
            QueryComplexity::Complex
        }
    }

    /// Returns the numeric score for `detect(query)`: 0.2 for `Simple`,
    /// 0.5 for `Medium`, 0.8 for `Complex`.
    pub fn complexity_score(&self, query: &str) -> f32 {
        Self::score_for(&self.detect(query))
    }

    /// Computes word statistics for `query` together with its complexity tier
    /// and score.
    ///
    /// Words are whitespace-separated. `unique_word_ratio` is the number of
    /// distinct words (compared case-insensitively) divided by the word count,
    /// or `0.0` for an empty query.
    pub fn analyze(&self, query: &str) -> QueryAnalysis {
        let query_lower = query.to_lowercase();
        let words: Vec<&str> = query_lower.split_whitespace().collect();
        let unique_words: HashSet<&str> = words.iter().copied().collect();
        let unique_word_ratio = if words.is_empty() {
            0.0
        } else {
            unique_words.len() as f32 / words.len() as f32
        };

        // Classify once and derive the score from that result, rather than
        // running the whole detection a second time via `complexity_score`.
        let complexity = self.detect(query);
        let complexity_score = Self::score_for(&complexity);

        QueryAnalysis {
            word_count: words.len(),
            unique_word_ratio,
            has_question_mark: query.contains('?'),
            question_count: query.matches('?').count(),
            complexity,
            complexity_score,
        }
    }

    /// Maps a complexity tier to its numeric score.
    fn score_for(complexity: &QueryComplexity) -> f32 {
        match complexity {
            QueryComplexity::Simple => 0.2,
            QueryComplexity::Medium => 0.5,
            QueryComplexity::Complex => 0.8,
        }
    }

    /// Splits `text` on whitespace and strips non-alphanumeric characters from
    /// both ends of each piece, dropping pieces that end up empty.
    fn tokenize(text: &str) -> Vec<&str> {
        text.split_whitespace()
            .map(|word| word.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|word| !word.is_empty())
            .collect()
    }

    /// Returns `true` if any of `phrases` occurs in `words` (see `contains_phrase`).
    fn mentions_any<S: AsRef<str>>(words: &[&str], phrases: &[S]) -> bool {
        phrases
            .iter()
            .any(|phrase| Self::contains_phrase(words, phrase.as_ref()))
    }

    /// Returns `true` if the lowercased, tokenized `phrase` appears as a
    /// contiguous run of `words`. A phrase with no word characters never matches
    /// (this also avoids `windows(0)`, which panics).
    fn contains_phrase(words: &[&str], phrase: &str) -> bool {
        let phrase_lower = phrase.to_lowercase();
        let needle = Self::tokenize(&phrase_lower);
        !needle.is_empty()
            && words
                .windows(needle.len())
                .any(|window| window == needle.as_slice())
    }
}
/// Statistics and classification for a single query, as produced by
/// [`ComplexityDetector::analyze`].
#[derive(Debug, Clone)]
pub struct QueryAnalysis {
    /// Number of whitespace-separated words in the query.
    pub word_count: usize,
    /// Ratio of distinct words to `word_count`; `0.0` for an empty query.
    pub unique_word_ratio: f32,
    /// Whether the query contains at least one `?`.
    pub has_question_mark: bool,
    /// Number of `?` characters in the query.
    pub question_count: usize,
    /// Tier assigned by [`ComplexityDetector::detect`].
    pub complexity: QueryComplexity,
    /// Numeric score for `complexity`, as returned by
    /// [`ComplexityDetector::complexity_score`].
    pub complexity_score: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs every `(query, expected)` pair, in order, through one default detector.
    fn check(cases: &[(&str, QueryComplexity)]) {
        let detector = ComplexityDetector::new();
        for (query, expected) in cases {
            assert_eq!(detector.detect(query), *expected);
        }
    }

    #[test]
    fn test_simple_queries() {
        check(&[
            ("What is Rust?", QueryComplexity::Simple),
            ("Define async", QueryComplexity::Simple),
            ("List features", QueryComplexity::Simple),
        ]);
    }

    #[test]
    fn test_complex_queries() {
        check(&[
            (
                "Compare and contrast the different approaches to async programming",
                QueryComplexity::Complex,
            ),
            (
                "What is the relationship between ownership and borrowing?",
                QueryComplexity::Complex,
            ),
        ]);
    }

    #[test]
    fn test_medium_queries() {
        check(&[(
            "How do I implement a simple web server with error handling?",
            QueryComplexity::Medium,
        )]);
    }
}