use serde::{Deserialize, Serialize};
use tracing::{debug, info};
use crate::llm::{LlmClient, LlmExecutor};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubQuery {
pub text: String,
pub complexity: SubQueryComplexity,
pub priority: u8,
pub depends_on: Vec<usize>,
pub query_type: SubQueryType,
pub path_constraint: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SubQueryComplexity {
Simple,
Medium,
Complex,
}
impl Default for SubQueryComplexity {
fn default() -> Self {
Self::Simple
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SubQueryType {
Fact,
Explanation,
Comparison,
Synthesis,
Navigation,
}
impl Default for SubQueryType {
fn default() -> Self {
Self::Fact
}
}
/// Outcome of decomposing a query.
#[derive(Debug, Clone)]
pub struct DecompositionResult {
    /// The query as originally submitted.
    pub original: String,
    /// The resulting sub-queries; holds the original query alone when no
    /// decomposition was performed.
    pub sub_queries: Vec<SubQuery>,
    /// Whether the query was actually split.
    pub was_decomposed: bool,
    /// Human-readable explanation of the decomposition decision.
    pub reason: String,
    /// Complexity score of the original query (0.0–1.0).
    pub total_complexity: f32,
}
impl DecompositionResult {
    /// Wraps `query` unchanged as a single-entry result, recording `reason`
    /// for why no decomposition was performed.
    pub fn no_decomposition(query: &str, reason: &str) -> Self {
        let passthrough = SubQuery {
            text: query.to_string(),
            complexity: SubQueryComplexity::Simple,
            priority: 0,
            depends_on: vec![],
            query_type: SubQueryType::Fact,
            path_constraint: None,
        };
        Self {
            original: query.to_string(),
            sub_queries: vec![passthrough],
            was_decomposed: false,
            reason: reason.to_string(),
            total_complexity: 0.5,
        }
    }

    /// True when the query was split into more than one sub-query.
    pub fn is_multi_turn(&self) -> bool {
        self.sub_queries.len() > 1
    }

    /// Indices into `sub_queries` in execution order: fewest declared
    /// dependencies first, ties broken by ascending priority.
    ///
    /// NOTE(review): this is a heuristic ordering, not a topological sort —
    /// it does not guarantee a sub-query runs after the specific indices it
    /// depends on; confirm that callers only need the heuristic.
    pub fn execution_order(&self) -> Vec<usize> {
        if self.sub_queries.len() <= 1 {
            return vec![0];
        }
        let mut order: Vec<usize> = (0..self.sub_queries.len()).collect();
        // Lexicographic key == "compare dep count, then priority".
        order.sort_by_key(|&idx| {
            let sq = &self.sub_queries[idx];
            (sq.depends_on.len(), sq.priority)
        });
        order
    }
}
/// Tuning knobs for query decomposition.
#[derive(Debug, Clone)]
pub struct DecompositionConfig {
    /// Upper bound on the number of sub-queries produced.
    pub max_sub_queries: usize,
    /// Queries shorter than this (measured in bytes via `str::len`) are
    /// never decomposed.
    pub min_query_length: usize,
    /// Whether to attempt LLM-based decomposition before the rule-based
    /// fallback.
    pub use_llm: bool,
    /// Minimum complexity score (0.0–1.0) required before decomposing.
    pub complexity_threshold: f32,
    /// Whether rule-based conjunction splitting chains each part onto the
    /// previous one as a dependency.
    pub detect_dependencies: bool,
}

impl Default for DecompositionConfig {
    /// Defaults: at most 5 sub-queries, 20-byte minimum length, LLM enabled,
    /// 0.7 complexity threshold, dependency chaining on.
    fn default() -> Self {
        Self {
            max_sub_queries: 5,
            min_query_length: 20,
            use_llm: true,
            complexity_threshold: 0.7,
            detect_dependencies: true,
        }
    }
}
/// Splits complex queries into simpler sub-queries, optionally via an LLM.
pub struct QueryDecomposer {
    // Behavior tuning; see `DecompositionConfig`.
    config: DecompositionConfig,
    // Optional LLM client, used only when `config.use_llm` is set.
    llm_client: Option<LlmClient>,
    // Optional LLM executor; takes precedence over `llm_client` when both
    // are configured.
    llm_executor: Option<LlmExecutor>,
}

impl Default for QueryDecomposer {
    /// A decomposer with default configuration and no LLM attached.
    fn default() -> Self {
        Self::new(DecompositionConfig::default())
    }
}
impl QueryDecomposer {
    /// Creates a decomposer with `config` and no LLM attached.
    pub fn new(config: DecompositionConfig) -> Self {
        Self {
            config,
            llm_client: None,
            llm_executor: None,
        }
    }

    /// Attaches an LLM client for LLM-based decomposition.
    pub fn with_llm_client(mut self, client: LlmClient) -> Self {
        self.llm_client = Some(client);
        self
    }

    /// Attaches an LLM executor; preferred over the client when both are set.
    pub fn with_llm_executor(mut self, executor: LlmExecutor) -> Self {
        self.llm_executor = Some(executor);
        self
    }

    /// Decomposes `query` into sub-queries.
    ///
    /// Simple queries are returned unchanged. Otherwise an LLM decomposition
    /// is attempted first (when enabled and an LLM is configured), falling
    /// back silently to rule-based splitting on any LLM error.
    pub async fn decompose(&self, query: &str) -> crate::error::Result<DecompositionResult> {
        if !self.should_decompose(query) {
            return Ok(DecompositionResult::no_decomposition(
                query,
                "Query is simple enough, no decomposition needed",
            ));
        }
        info!("Decomposing complex query: '{}'", query);
        if self.config.use_llm && (self.llm_client.is_some() || self.llm_executor.is_some()) {
            match self.llm_decompose(query).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    debug!(
                        "LLM decomposition failed, falling back to rule-based: {}",
                        e
                    );
                }
            }
        }
        self.rule_based_decompose(query)
    }

    /// A query is decomposed only when it is long enough and its complexity
    /// score meets the configured threshold.
    fn should_decompose(&self, query: &str) -> bool {
        if query.len() < self.config.min_query_length {
            return false;
        }
        self.calculate_complexity(query) >= self.config.complexity_threshold
    }

    /// Scores query complexity in [0.0, 1.0] from surface features:
    /// question marks, conjunctions, analytical keywords, and word count.
    fn calculate_complexity(&self, query: &str) -> f32 {
        let mut score = 0.0;
        let query_lower = query.to_lowercase();
        // Each question mark suggests a distinct question (contribution
        // capped at 1.0).
        let question_count = query.matches('?').count();
        score += (question_count as f32 * 0.3).min(1.0);
        // Conjunctions hint at multiple clauses (capped at 0.6).
        let conjunctions = [" and ", " or ", " but ", " also ", " plus "];
        let conjunction_count = conjunctions
            .iter()
            .filter(|c| query_lower.contains(*c))
            .count();
        score += (conjunction_count as f32 * 0.2).min(0.6);
        // Analytical phrasing adds 0.2 per matched indicator.
        let complex_indicators = [
            "compare",
            "contrast",
            "difference between",
            "relationship between",
            "how does",
            "why does",
            "explain how",
            "analyze",
            "evaluate",
            "synthesize",
        ];
        for indicator in &complex_indicators {
            if query_lower.contains(indicator) {
                score += 0.2;
            }
        }
        // Long queries (more than 15 words) add up to an extra 0.1.
        let word_count = query.split_whitespace().count();
        if word_count > 15 {
            score += 0.1 * ((word_count - 15) as f32 / 10.0).min(1.0);
        }
        score.min(1.0)
    }

    /// Splits `query` heuristically: first on question marks, then on
    /// conjunction phrases.
    ///
    /// Returns a no-decomposition result when the split yields fewer than
    /// two sub-queries (a single "sub-query" is not a decomposition) or
    /// more than the configured maximum.
    fn rule_based_decompose(&self, query: &str) -> crate::error::Result<DecompositionResult> {
        let mut sub_queries = Vec::new();
        let query_lower = query.to_lowercase();
        // Conjunction splitters, tried only when question-mark splitting
        // produced nothing. The former "? " pattern was dead code: any query
        // containing it also contains '?' and takes the branch below.
        let patterns = [" and ", " also ", " as well as "];
        if query.contains('?') {
            let parts: Vec<&str> = query.split('?').filter(|s| !s.trim().is_empty()).collect();
            for (i, part) in parts.iter().enumerate() {
                sub_queries.push(SubQuery {
                    text: format!("{}?", part.trim()),
                    complexity: self.estimate_sub_query_complexity(part),
                    priority: i as u8,
                    depends_on: vec![],
                    query_type: self.detect_query_type(part),
                    path_constraint: None,
                });
            }
        }
        if sub_queries.is_empty() {
            for pattern in &patterns {
                // NOTE(review): the containment check is case-insensitive but
                // the split below is case-sensitive, so mixed-case conjunctions
                // are detected yet not split; preserved from original behavior.
                if query_lower.contains(pattern) {
                    let parts: Vec<&str> = query
                        .split(pattern)
                        .filter(|s| !s.trim().is_empty())
                        .collect();
                    if parts.len() > 1 {
                        for (i, part) in parts.iter().enumerate() {
                            sub_queries.push(SubQuery {
                                text: part.trim().to_string(),
                                complexity: self.estimate_sub_query_complexity(part),
                                priority: i as u8,
                                // Conjunction-joined clauses often build on the
                                // previous one, so chain a linear dependency.
                                depends_on: if i > 0 && self.config.detect_dependencies {
                                    vec![i - 1]
                                } else {
                                    vec![]
                                },
                                query_type: self.detect_query_type(part),
                                path_constraint: None,
                            });
                        }
                        break;
                    }
                }
            }
        }
        // A decomposition only helps when it yields at least two sub-queries
        // and stays within the configured maximum. (The original code also
        // reported `was_decomposed: true` for a single sub-query.)
        if sub_queries.len() < 2 || sub_queries.len() > self.config.max_sub_queries {
            return Ok(DecompositionResult::no_decomposition(
                query,
                "No clear decomposition patterns found",
            ));
        }
        Ok(DecompositionResult {
            original: query.to_string(),
            sub_queries,
            was_decomposed: true,
            reason: "Rule-based decomposition".to_string(),
            total_complexity: self.calculate_complexity(query),
        })
    }

    /// Asks the configured LLM to decompose `query` and parses its JSON reply.
    ///
    /// Prefers the executor over the client; errors if neither is configured,
    /// the call fails, or the response cannot be parsed.
    async fn llm_decompose(&self, query: &str) -> crate::error::Result<DecompositionResult> {
        let system = r#"You are a query decomposition expert. Break down complex queries into simpler sub-queries.
Rules:
1. Each sub-query should be answerable independently when possible
2. Preserve the original intent
3. Maximum 5 sub-queries
4. Return JSON format: {"sub_queries": [{"text": "...", "complexity": "simple|medium|complex", "priority": 0-4, "depends_on": [], "query_type": "fact|explanation|comparison|synthesis|navigation"}], "reason": "..."}
If the query is simple enough, return just one sub-query."#;
        let user = format!("Decompose this query: {}", query);
        let response = if let Some(ref executor) = self.llm_executor {
            executor
                .complete(system, &user)
                .await
                .map_err(|e| crate::error::Error::Llm(format!("LLM executor error: {}", e)))?
        } else if let Some(ref client) = self.llm_client {
            client
                .complete(system, &user)
                .await
                .map_err(|e| crate::error::Error::Llm(format!("LLM client error: {}", e)))?
        } else {
            return Err(crate::error::Error::Config(
                "No LLM client or executor configured".to_string(),
            ));
        };
        /// Shape of the JSON object the system prompt asks the model to emit.
        #[derive(Deserialize)]
        struct DecompositionResponse {
            sub_queries: Vec<SubQuery>,
            reason: String,
        }
        let parsed: DecompositionResponse = serde_json::from_str(&extract_json(&response))
            .map_err(|e| {
                crate::error::Error::Llm(format!("Failed to parse decomposition: {}", e))
            })?;
        if parsed.sub_queries.is_empty() {
            return Ok(DecompositionResult::no_decomposition(
                query,
                "LLM returned empty decomposition",
            ));
        }
        // Enforce the configured cap even if the model ignored the prompt.
        let sub_queries: Vec<SubQuery> = parsed
            .sub_queries
            .into_iter()
            .take(self.config.max_sub_queries)
            .collect();
        Ok(DecompositionResult {
            original: query.to_string(),
            sub_queries,
            was_decomposed: true,
            reason: parsed.reason,
            total_complexity: self.calculate_complexity(query),
        })
    }

    /// Keyword-based complexity estimate for one sub-query fragment.
    fn estimate_sub_query_complexity(&self, text: &str) -> SubQueryComplexity {
        let text_lower = text.to_lowercase();
        if text_lower.contains("compare")
            || text_lower.contains("contrast")
            || text_lower.contains("analyze")
            || text_lower.contains("evaluate")
            || text_lower.contains("synthesize")
        {
            return SubQueryComplexity::Complex;
        }
        if text_lower.contains("how")
            || text_lower.contains("why")
            || text_lower.contains("explain")
            || text_lower.contains("describe")
        {
            return SubQueryComplexity::Medium;
        }
        SubQueryComplexity::Simple
    }

    /// Keyword-based intent classification for one sub-query fragment.
    /// Checks comparison first, then explanation, synthesis, navigation;
    /// anything else is treated as a factual lookup.
    fn detect_query_type(&self, text: &str) -> SubQueryType {
        let text_lower = text.to_lowercase();
        if text_lower.contains("compare")
            || text_lower.contains("difference")
            || text_lower.contains("versus")
            || text_lower.contains(" vs ")
        {
            return SubQueryType::Comparison;
        }
        if text_lower.contains("why")
            || text_lower.contains("how")
            || text_lower.contains("explain")
        {
            return SubQueryType::Explanation;
        }
        if text_lower.contains("summarize")
            || text_lower.contains("combine")
            || text_lower.contains("synthesize")
            || text_lower.contains("overall")
        {
            return SubQueryType::Synthesis;
        }
        if text_lower.contains("where")
            || text_lower.contains("which section")
            || text_lower.contains("find")
        {
            return SubQueryType::Navigation;
        }
        SubQueryType::Fact
    }
}
/// Extracts the outermost `{ ... }` span from `text`.
///
/// Returns the slice from the first `{` through the last `}` when both exist
/// in that order; otherwise returns `text` unchanged. Used to strip prose the
/// LLM may wrap around its JSON reply.
fn extract_json(text: &str) -> String {
    match (text.find('{'), text.rfind('}')) {
        (Some(start), Some(end)) if end > start => text[start..=end].to_string(),
        _ => text.to_string(),
    }
}
/// The answer produced for one sub-query.
#[derive(Debug, Clone)]
pub struct SubQueryResult {
    /// The sub-query this result answers.
    pub query: SubQuery,
    /// The answer text.
    pub content: String,
    /// Relevance/confidence score for this answer.
    pub score: f32,
    /// Identifiers of the source nodes the answer was drawn from.
    pub source_nodes: Vec<String>,
}

/// Combines sub-query results into a single, size-bounded answer.
pub struct ResultAggregator {
    /// Approximate output budget in tokens (estimated at ~4 bytes/token).
    pub max_tokens: usize,
    /// Weight given to sub-query priority when aggregating.
    /// NOTE(review): not read anywhere in this file — confirm whether it is
    /// used elsewhere or is dead configuration.
    pub priority_weight: f32,
}

impl Default for ResultAggregator {
    /// Defaults: ~4000-token budget, 0.3 priority weight.
    fn default() -> Self {
        Self {
            max_tokens: 4000,
            priority_weight: 0.3,
        }
    }
}
impl ResultAggregator {
pub fn new() -> Self {
Self::default()
}
pub fn aggregate(
&self,
results: &[SubQueryResult],
decomposition: &DecompositionResult,
) -> String {
if results.is_empty() {
return String::new();
}
if results.len() == 1 {
return results[0].content.clone();
}
let order = decomposition.execution_order();
let sorted_results: Vec<_> = order
.iter()
.filter_map(|&i| {
results
.iter()
.find(|r| r.query.text == decomposition.sub_queries[i].text)
})
.collect();
let mut combined = String::new();
let mut total_tokens = 0;
for result in sorted_results {
let section = format!("\n### {}\n\n{}\n", result.query.text, result.content);
let section_tokens = section.len() / 4; if total_tokens + section_tokens > self.max_tokens {
let remaining = self.max_tokens - total_tokens;
if remaining > 100 {
let end_pos = (remaining * 4).min(result.content.len());
combined.push_str(&format!(
"\n### {}\n\n{}\n",
result.query.text,
&result.content[..end_pos]
));
}
break;
}
combined.push_str(§ion);
total_tokens += section_tokens;
}
combined
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `SubQuery` fixture with the given text, priority,
    /// and dependency list.
    fn make_sub_query(text: &str, priority: u8, depends_on: Vec<usize>) -> SubQuery {
        SubQuery {
            text: text.to_string(),
            priority,
            depends_on,
            query_type: SubQueryType::Fact,
            complexity: SubQueryComplexity::Simple,
            path_constraint: None,
        }
    }

    #[test]
    fn test_complexity_calculation() {
        let decomposer = QueryDecomposer::default();
        let simple_score = decomposer.calculate_complexity("What is the architecture?");
        assert!(simple_score < 0.5);
        let complex_score = decomposer.calculate_complexity(
            "What is the architecture and how does it compare to other systems?",
        );
        assert!(complex_score > simple_score);
    }

    #[test]
    fn test_rule_based_decomposition() {
        let decomposer = QueryDecomposer::default();
        let result = decomposer
            .rule_based_decompose("What is the architecture? How does caching work?")
            .unwrap();
        assert!(result.was_decomposed);
        assert_eq!(result.sub_queries.len(), 2);
    }

    #[test]
    fn test_no_decomposition() {
        let result = DecompositionResult::no_decomposition("What is this?", "Query is simple");
        assert!(!result.was_decomposed);
        assert!(!result.is_multi_turn());
    }

    #[test]
    fn test_execution_order() {
        let mut result = DecompositionResult::no_decomposition("test", "test");
        // "First" has no deps (priority 2); "Second" depends on it
        // (priority 1) — dependency count outranks priority.
        result.sub_queries = vec![
            make_sub_query("First", 2, vec![]),
            make_sub_query("Second", 1, vec![0]),
        ];
        result.was_decomposed = true;
        assert_eq!(result.execution_order(), vec![0, 1]);
    }

    #[test]
    fn test_query_type_detection() {
        let decomposer = QueryDecomposer::default();
        let cases = [
            ("Compare A and B", SubQueryType::Comparison),
            ("Why does this happen?", SubQueryType::Explanation),
            ("Where is the config?", SubQueryType::Navigation),
        ];
        for (text, expected) in cases {
            assert_eq!(decomposer.detect_query_type(text), expected);
        }
    }

    #[test]
    fn test_result_aggregator() {
        let aggregator = ResultAggregator::new();
        let results = vec![
            SubQueryResult {
                query: make_sub_query("First question?", 0, vec![]),
                content: "Answer 1".to_string(),
                score: 0.9,
                source_nodes: vec![],
            },
            SubQueryResult {
                query: make_sub_query("Second question?", 1, vec![0]),
                content: "Answer 2".to_string(),
                score: 0.8,
                source_nodes: vec![],
            },
        ];
        let mut decomposition = DecompositionResult::no_decomposition("test", "test");
        decomposition.sub_queries = results.iter().map(|r| r.query.clone()).collect();
        decomposition.was_decomposed = true;
        let combined = aggregator.aggregate(&results, &decomposition);
        assert!(combined.contains("First question"));
        assert!(combined.contains("Answer 1"));
    }
}