use std::collections::HashSet;
use super::QueryComplexity;
/// Classifies queries into [`QueryComplexity`] levels.
///
/// A detector optionally holds an LLM client. When one is present, `detect` asks the
/// LLM first. Otherwise, or when the LLM gives no answer, a keyword/length heuristic
/// is used.
pub struct ComplexityDetector {
    /// Client used by `detect` for LLM-based classification; `None` means heuristic only.
    llm_client: Option<crate::llm::LlmClient>,
}
impl ComplexityDetector {
    /// Creates a detector without an LLM client, so `detect` always uses the
    /// keyword/length heuristic.
    pub fn new() -> Self {
        Self { llm_client: None }
    }

    /// Creates a detector that first asks `client` to classify queries and falls back
    /// to the heuristic when that yields no answer.
    pub fn with_llm_client(client: crate::llm::LlmClient) -> Self {
        Self {
            llm_client: Some(client),
        }
    }

    /// Classifies `query`.
    ///
    /// If an LLM client is configured, `crate::retrieval::pilot::detect_with_llm` is
    /// tried first. When it returns `None`, a warning is logged and the heuristic
    /// result is returned instead. Without a client, the heuristic is used directly.
    pub async fn detect(&self, query: &str) -> QueryComplexity {
        if let Some(client) = &self.llm_client {
            if let Some(complexity) =
                crate::retrieval::pilot::detect_with_llm(client, query).await
            {
                return complexity;
            }
            tracing::warn!("LLM complexity detection failed, falling back to heuristic");
        }
        self.detect_heuristic(query)
    }

    /// Keyword- and length-based classification of `query`.
    ///
    /// Checks are applied in this order; the first that matches wins:
    /// 1. Any complex indicator occurs in the lower-cased query → `Complex`.
    /// 2. Any simple indicator occurs and the estimated word count is at most
    ///    `SIMPLE_INDICATOR_MAX_WORDS` → `Simple`.
    /// 3. More than one question mark (ASCII `?` or full-width `？`) → `Complex`.
    /// 4. Otherwise, by estimated word count (see `estimate_word_count`):
    ///    ≤ `SIMPLE_MAX_WORDS` → `Simple`, ≤ `MEDIUM_MAX_WORDS` → `Medium`,
    ///    else `Complex`.
    fn detect_heuristic(&self, query: &str) -> QueryComplexity {
        // English and Chinese substrings suggesting comparison, analysis or causal
        // reasoning. They are matched as substrings, so inflected forms such as
        // "compared" also hit "compare".
        const COMPLEX_INDICATORS: &[&str] = &[
            "compare",
            "contrast",
            "analyze",
            "evaluate",
            "synthesize",
            "explain why",
            "how does",
            "relationship between",
            "cause and effect",
            "对比",
            "分析",
            "评估",
            "综合",
            "为什么",
            "原因",
            "关系",
            "影响",
            "区别",
            "异同",
        ];
        // English and Chinese substrings typical of short factual lookups.
        const SIMPLE_INDICATORS: &[&str] = &[
            "what is",
            "define",
            "list",
            "who",
            "when",
            "where",
            "什么是",
            "定义",
            "列表",
            "谁",
            "何时",
            "哪里",
            "在哪",
        ];
        // A simple indicator is only trusted for queries up to this length.
        const SIMPLE_INDICATOR_MAX_WORDS: usize = 15;
        const SIMPLE_MAX_WORDS: usize = 5;
        const MEDIUM_MAX_WORDS: usize = 15;

        let query_lower = query.to_lowercase();
        let word_count = estimate_word_count(query);

        if COMPLEX_INDICATORS
            .iter()
            .any(|indicator| query_lower.contains(indicator))
        {
            return QueryComplexity::Complex;
        }

        if word_count <= SIMPLE_INDICATOR_MAX_WORDS
            && SIMPLE_INDICATORS
                .iter()
                .any(|indicator| query_lower.contains(indicator))
        {
            return QueryComplexity::Simple;
        }

        // Several questions in one query usually require a multi-part answer.
        if Self::count_question_marks(query) > 1 {
            return QueryComplexity::Complex;
        }

        if word_count <= SIMPLE_MAX_WORDS {
            QueryComplexity::Simple
        } else if word_count <= MEDIUM_MAX_WORDS {
            QueryComplexity::Medium
        } else {
            QueryComplexity::Complex
        }
    }

    /// Maps a complexity level to a numeric score: 0.2 (simple), 0.5 (medium),
    /// 0.8 (complex).
    pub fn complexity_score(&self, complexity: QueryComplexity) -> f32 {
        Self::score_for(&complexity)
    }

    /// Score table shared by `complexity_score` and `analyze`. It takes a reference,
    /// so `analyze` can keep the level after scoring it.
    fn score_for(complexity: &QueryComplexity) -> f32 {
        match complexity {
            QueryComplexity::Simple => 0.2,
            QueryComplexity::Medium => 0.5,
            QueryComplexity::Complex => 0.8,
        }
    }

    /// Counts question marks in `text`: ASCII `?` plus the full-width `？` (U+FF1F)
    /// used in Chinese/Japanese text. Each character is counted exactly once.
    fn count_question_marks(text: &str) -> usize {
        // The full-width mark is written as an escape so it cannot be silently
        // normalised to ASCII `?` by editors or tooling.
        text.chars()
            .filter(|&c| matches!(c, '?' | '\u{FF1F}'))
            .count()
    }

    /// Computes surface statistics for `query` together with its heuristic
    /// classification. This never consults the LLM client.
    ///
    /// `word_count` and `unique_word_ratio` are based on whitespace-separated tokens.
    /// They therefore differ from the CJK-aware estimate used inside
    /// `detect_heuristic`. `unique_word_ratio` is 0.0 for an empty query.
    pub fn analyze(&self, query: &str) -> QueryAnalysis {
        let words: Vec<&str> = query.split_whitespace().collect();
        let unique_words: HashSet<&str> = words.iter().copied().collect();
        let unique_word_ratio = if words.is_empty() {
            0.0
        } else {
            unique_words.len() as f32 / words.len() as f32
        };

        let question_count = Self::count_question_marks(query);
        // Classify once and derive the score from that same result.
        let complexity = self.detect_heuristic(query);
        let complexity_score = Self::score_for(&complexity);

        QueryAnalysis {
            word_count: words.len(),
            unique_word_ratio,
            has_question_mark: question_count > 0,
            question_count,
            complexity,
            complexity_score,
        }
    }
}
impl Default for ComplexityDetector {
fn default() -> Self {
Self::new()
}
}
/// Estimates the number of words in `text` for the length-based heuristic.
///
/// * Every alphanumeric CJK character (ideograph or kana, per `is_cjk_char`) counts
///   as one word, since those scripts are not space-delimited.
/// * Every maximal run of other alphanumeric characters counts as one word. This
///   covers Latin (including accented letters), Cyrillic, Greek, Hangul and digits.
/// * Whitespace, punctuation (including CJK punctuation such as `。`) and symbols
///   only separate words; they are never counted themselves.
///
/// Returns 0 for empty or punctuation-only input.
fn estimate_word_count(text: &str) -> usize {
    let mut count = 0usize;
    // True while inside a not-yet-counted run of non-CJK alphanumeric characters.
    let mut in_word = false;
    for ch in text.chars() {
        if !ch.is_alphanumeric() {
            // Separator: close any open run.
            count += usize::from(in_word);
            in_word = false;
        } else if is_cjk_char(ch) {
            // Close any open run, then count this character as a word of its own.
            count += usize::from(in_word) + 1;
            in_word = false;
        } else {
            in_word = true;
        }
    }
    count + usize::from(in_word)
}
/// Returns `true` if `ch` is in a CJK ideograph, kana, or CJK symbols/punctuation block.
///
/// Hangul is intentionally excluded: Korean is space-delimited, so its words are
/// handled like Latin-script words by the word counter.
fn is_cjk_char(ch: char) -> bool {
    matches!(
        ch,
        // CJK Unified Ideographs
        '\u{4E00}'..='\u{9FFF}'
        // CJK Unified Ideographs Extension A
        | '\u{3400}'..='\u{4DBF}'
        // Extension B
        | '\u{20000}'..='\u{2A6DF}'
        // Extensions C, D, E, F and I (contiguous blocks)
        | '\u{2A700}'..='\u{2EE5F}'
        // Extensions G and H (contiguous blocks)
        | '\u{30000}'..='\u{323AF}'
        // CJK Compatibility Ideographs
        | '\u{F900}'..='\u{FAFF}'
        // CJK Compatibility Ideographs Supplement
        | '\u{2F800}'..='\u{2FA1F}'
        // CJK Symbols and Punctuation
        | '\u{3000}'..='\u{303F}'
        // Hiragana
        | '\u{3040}'..='\u{309F}'
        // Katakana
        | '\u{30A0}'..='\u{30FF}'
    )
}
/// Surface statistics and heuristic classification of a query, as produced by
/// `ComplexityDetector::analyze`.
#[derive(Debug, Clone)]
pub struct QueryAnalysis {
    /// Number of whitespace-separated tokens in the query.
    pub word_count: usize,
    /// Distinct tokens divided by total tokens; 0.0 when the query has no tokens.
    pub unique_word_ratio: f32,
    /// Whether a question mark was found in the query.
    pub has_question_mark: bool,
    /// Number of question marks counted in the query.
    pub question_count: usize,
    /// Heuristic (non-LLM) complexity classification of the query.
    pub complexity: QueryComplexity,
    /// Numeric score for `complexity`, as returned by `ComplexityDetector::complexity_score`.
    pub complexity_score: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the heuristic classifies every query in `queries` as `expected`.
    fn expect_all(expected: QueryComplexity, queries: &[&str]) {
        let detector = ComplexityDetector::new();
        for &query in queries {
            assert_eq!(
                detector.detect_heuristic(query),
                expected,
                "unexpected complexity for {:?}",
                query
            );
        }
    }

    #[test]
    fn test_simple_queries() {
        expect_all(
            QueryComplexity::Simple,
            &["What is Rust?", "Define async", "什么是向量检索"],
        );
    }

    #[test]
    fn test_complex_queries() {
        expect_all(
            QueryComplexity::Complex,
            &[
                "Compare and contrast the different approaches to async programming",
                "What is the relationship between ownership and borrowing?",
                "对比A和B的区别",
                "分析索引和检索的关系",
            ],
        );
    }
}