1use crate::RragResult;
7use serde::{Deserialize, Serialize};
8
9pub struct QueryDecomposer {
11 config: DecompositionConfig,
13
14 patterns: Vec<DecompositionPattern>,
16
17 complexity_indicators: Vec<String>,
19}
20
/// Tunable settings for [`QueryDecomposer`].
#[derive(Debug, Clone)]
pub struct DecompositionConfig {
    /// Upper bound on the number of sub-queries returned for one query.
    pub max_sub_queries: usize,

    /// Minimum length, in bytes, a trimmed fragment must have for logical
    /// decomposition to keep it.
    pub min_sub_query_length: usize,

    /// Enables the temporal strategy.
    pub enable_temporal_decomposition: bool,

    /// Enables the logical-connector strategy.
    pub enable_logical_decomposition: bool,

    /// Enables the topical strategy.
    pub enable_topical_decomposition: bool,

    /// Enables the comparative strategy.
    pub enable_comparative_decomposition: bool,

    /// Sub-queries whose confidence is below this value are discarded.
    pub confidence_threshold: f32,
}
45
46impl Default for DecompositionConfig {
47 fn default() -> Self {
48 Self {
49 max_sub_queries: 5,
50 min_sub_query_length: 5,
51 enable_temporal_decomposition: true,
52 enable_logical_decomposition: true,
53 enable_topical_decomposition: true,
54 enable_comparative_decomposition: true,
55 confidence_threshold: 0.6,
56 }
57 }
58}
59
60#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub enum DecompositionStrategy {
63 Logical,
65 Temporal,
67 Topical,
69 Comparative,
71 Sequential,
73 Causal,
75}
76
77struct DecompositionPattern {
79 name: String,
81 triggers: Vec<String>,
83 strategy: DecompositionStrategy,
85 extractor: fn(&str) -> Vec<String>,
87 confidence: f32,
89}
90
91#[derive(Debug, Clone)]
93pub struct SubQuery {
94 pub query: String,
96
97 pub strategy: DecompositionStrategy,
99
100 pub confidence: f32,
102
103 pub priority: f32,
105
106 pub metadata: SubQueryMetadata,
108}
109
/// Descriptive data attached to a [`SubQuery`].
#[derive(Debug, Clone)]
pub struct SubQueryMetadata {
    /// Index of the sub-query within its list.
    pub position: usize,

    /// Relationship labels, e.g. the connector word or a strategy tag such as
    /// "temporal".
    pub relationships: Vec<String>,

    /// Label for the kind of answer expected, e.g. "factual" or "temporal".
    pub expected_answer_type: String,

    /// Positions of the sub-queries this one depends on.
    pub dependencies: Vec<usize>,
}
125
126impl QueryDecomposer {
127 pub fn new() -> Self {
129 Self::with_config(DecompositionConfig::default())
130 }
131
132 pub fn with_config(config: DecompositionConfig) -> Self {
134 let patterns = Self::init_patterns();
135 let complexity_indicators = Self::init_complexity_indicators();
136
137 Self {
138 config,
139 patterns,
140 complexity_indicators,
141 }
142 }
143
144 pub async fn decompose(&self, query: &str) -> RragResult<Vec<SubQuery>> {
146 let mut sub_queries = Vec::new();
147
148 if !self.should_decompose(query) {
150 return Ok(sub_queries);
151 }
152
153 if self.config.enable_logical_decomposition {
155 sub_queries.extend(self.logical_decomposition(query));
156 }
157
158 if self.config.enable_temporal_decomposition {
159 sub_queries.extend(self.temporal_decomposition(query));
160 }
161
162 if self.config.enable_topical_decomposition {
163 sub_queries.extend(self.topical_decomposition(query));
164 }
165
166 if self.config.enable_comparative_decomposition {
167 sub_queries.extend(self.comparative_decomposition(query));
168 }
169
170 sub_queries.retain(|sq| sq.confidence >= self.config.confidence_threshold);
172 sub_queries.sort_by(|a, b| {
173 b.priority
174 .partial_cmp(&a.priority)
175 .unwrap_or(std::cmp::Ordering::Equal)
176 });
177 sub_queries.truncate(self.config.max_sub_queries);
178
179 self.enrich_sub_queries(&mut sub_queries);
181
182 Ok(sub_queries)
183 }
184
185 fn should_decompose(&self, query: &str) -> bool {
187 let query_lower = query.to_lowercase();
188
189 let has_complexity_indicators = self
191 .complexity_indicators
192 .iter()
193 .any(|indicator| query_lower.contains(indicator));
194
195 let question_count = query.matches('?').count();
197
198 let word_count = query.split_whitespace().count();
200
201 has_complexity_indicators || question_count > 1 || word_count > 15
202 }
203
204 fn logical_decomposition(&self, query: &str) -> Vec<SubQuery> {
206 let mut sub_queries = Vec::new();
207
208 let logical_connectors = ["and", "or", "but", "however", "also", "additionally"];
210
211 for connector in &logical_connectors {
212 if query.to_lowercase().contains(connector) {
213 let parts: Vec<&str> = query.split(&format!(" {} ", connector)).collect();
214 if parts.len() > 1 {
215 for (i, part) in parts.iter().enumerate() {
216 let trimmed = part.trim();
217 if trimmed.len() >= self.config.min_sub_query_length {
218 sub_queries.push(SubQuery {
219 query: self.complete_sub_query(trimmed),
220 strategy: DecompositionStrategy::Logical,
221 confidence: 0.8,
222 priority: 1.0 - (i as f32 * 0.1), metadata: SubQueryMetadata {
224 position: i,
225 relationships: vec![connector.to_string()],
226 expected_answer_type: "factual".to_string(),
227 dependencies: vec![],
228 },
229 });
230 }
231 }
232 break; }
234 }
235 }
236
237 sub_queries
238 }
239
240 fn temporal_decomposition(&self, query: &str) -> Vec<SubQuery> {
242 let mut sub_queries = Vec::new();
243 let query_lower = query.to_lowercase();
244
245 let time_indicators = [
247 "when", "before", "after", "during", "since", "until", "timeline",
248 ];
249
250 if time_indicators
251 .iter()
252 .any(|&indicator| query_lower.contains(indicator))
253 {
254 let temporal_aspects = self.extract_temporal_aspects(query);
256
257 for (i, aspect) in temporal_aspects.iter().enumerate() {
258 sub_queries.push(SubQuery {
259 query: aspect.clone(),
260 strategy: DecompositionStrategy::Temporal,
261 confidence: 0.7,
262 priority: 0.8,
263 metadata: SubQueryMetadata {
264 position: i,
265 relationships: vec!["temporal".to_string()],
266 expected_answer_type: "temporal".to_string(),
267 dependencies: vec![],
268 },
269 });
270 }
271 }
272
273 sub_queries
274 }
275
276 fn topical_decomposition(&self, query: &str) -> Vec<SubQuery> {
278 let mut sub_queries = Vec::new();
279
280 let topics = self.extract_topics(query);
282
283 if topics.len() > 1 {
284 for (i, topic) in topics.iter().enumerate() {
285 let topic_query = format!("What is {}?", topic);
286 sub_queries.push(SubQuery {
287 query: topic_query,
288 strategy: DecompositionStrategy::Topical,
289 confidence: 0.6,
290 priority: 0.7,
291 metadata: SubQueryMetadata {
292 position: i,
293 relationships: vec!["topical".to_string()],
294 expected_answer_type: "conceptual".to_string(),
295 dependencies: vec![],
296 },
297 });
298 }
299 }
300
301 sub_queries
302 }
303
304 fn comparative_decomposition(&self, query: &str) -> Vec<SubQuery> {
306 let mut sub_queries = Vec::new();
307 let query_lower = query.to_lowercase();
308
309 let comparison_indicators = [
311 "vs",
312 "versus",
313 "compare",
314 "difference",
315 "similar",
316 "different",
317 ];
318
319 if comparison_indicators
320 .iter()
321 .any(|&indicator| query_lower.contains(indicator))
322 {
323 let items = self.extract_comparison_items(query);
324
325 if items.len() >= 2 {
326 for item in &items {
327 sub_queries.push(SubQuery {
328 query: format!("What are the features of {}?", item),
329 strategy: DecompositionStrategy::Comparative,
330 confidence: 0.75,
331 priority: 0.8,
332 metadata: SubQueryMetadata {
333 position: 0,
334 relationships: vec!["comparative".to_string()],
335 expected_answer_type: "comparative".to_string(),
336 dependencies: vec![],
337 },
338 });
339 }
340
341 sub_queries.push(SubQuery {
343 query: format!("Compare {} and {}", items[0], items[1]),
344 strategy: DecompositionStrategy::Comparative,
345 confidence: 0.9,
346 priority: 1.0,
347 metadata: SubQueryMetadata {
348 position: items.len(),
349 relationships: vec!["synthesis".to_string()],
350 expected_answer_type: "comparative".to_string(),
351 dependencies: (0..items.len()).collect(),
352 },
353 });
354 }
355 }
356
357 sub_queries
358 }
359
360 fn complete_sub_query(&self, partial: &str) -> String {
362 let trimmed = partial.trim();
363
364 let question_words = ["what", "how", "why", "when", "where", "who", "which"];
366 let starts_with_question = question_words
367 .iter()
368 .any(|&word| trimmed.to_lowercase().starts_with(word));
369
370 if starts_with_question || trimmed.ends_with('?') {
371 trimmed.to_string()
372 } else {
373 format!("What is {}?", trimmed)
374 }
375 }
376
377 fn extract_temporal_aspects(&self, query: &str) -> Vec<String> {
379 let mut aspects = Vec::new();
380
381 if query.to_lowercase().contains("when") {
383 aspects.push(format!(
384 "When did {} happen?",
385 self.extract_main_subject(query)
386 ));
387 }
388
389 if query.to_lowercase().contains("before") {
390 aspects.push(format!(
391 "What happened before {}?",
392 self.extract_main_subject(query)
393 ));
394 }
395
396 if query.to_lowercase().contains("after") {
397 aspects.push(format!(
398 "What happened after {}?",
399 self.extract_main_subject(query)
400 ));
401 }
402
403 aspects
404 }
405
406 fn extract_topics(&self, query: &str) -> Vec<String> {
408 let mut topics = Vec::new();
409
410 let words: Vec<&str> = query.split_whitespace().collect();
412
413 for window in words.windows(2) {
414 let word = window[0];
415 if word.chars().next().map_or(false, |c| c.is_uppercase()) && word.len() > 2 {
417 topics.push(word.to_string());
418 }
419 }
420
421 topics.sort();
423 topics.dedup();
424
425 topics
426 }
427
428 fn extract_comparison_items(&self, query: &str) -> Vec<String> {
430 let mut items = Vec::new();
431
432 if let Some(vs_pos) = query.to_lowercase().find(" vs ") {
434 let before = &query[..vs_pos].trim();
435 let after = &query[vs_pos + 4..].trim();
436
437 items.push(self.extract_last_noun(before).to_string());
438 items.push(self.extract_first_noun(after).to_string());
439 } else if query.to_lowercase().contains("compare") {
440 let words: Vec<&str> = query.split_whitespace().collect();
442 let mut collecting = false;
443
444 for word in words {
445 if word.to_lowercase() == "compare" {
446 collecting = true;
447 continue;
448 }
449
450 if collecting
451 && word.len() > 2
452 && !["and", "with", "to"].contains(&word.to_lowercase().as_str())
453 {
454 items.push(
455 word.trim_matches(|c: char| !c.is_alphanumeric())
456 .to_string(),
457 );
458 if items.len() >= 2 {
459 break;
460 }
461 }
462 }
463 }
464
465 items
466 }
467
468 fn extract_main_subject(&self, query: &str) -> String {
470 let words: Vec<&str> = query.split_whitespace().collect();
472
473 for word in words {
475 if word.len() > 3
476 && !["what", "when", "where", "how", "why", "who", "the", "and"]
477 .contains(&word.to_lowercase().as_str())
478 {
479 return word
480 .trim_matches(|c: char| !c.is_alphanumeric())
481 .to_string();
482 }
483 }
484
485 "this".to_string()
486 }
487
488 fn extract_last_noun<'a>(&self, text: &'a str) -> &'a str {
490 let words: Vec<&str> = text.split_whitespace().collect();
491 for word in words.iter().rev() {
492 if word.len() > 2
493 && !["the", "and", "or", "of", "in", "on", "at"]
494 .contains(&word.to_lowercase().as_str())
495 {
496 return word;
497 }
498 }
499 text
500 }
501
502 fn extract_first_noun<'a>(&self, text: &'a str) -> &'a str {
504 let words: Vec<&str> = text.split_whitespace().collect();
505 for word in words {
506 if word.len() > 2
507 && !["the", "and", "or", "of", "in", "on", "at"]
508 .contains(&word.to_lowercase().as_str())
509 {
510 return word;
511 }
512 }
513 text
514 }
515
516 fn enrich_sub_queries(&self, sub_queries: &mut [SubQuery]) {
518 for (i, sub_query) in sub_queries.iter_mut().enumerate() {
519 sub_query.metadata.position = i;
521
522 sub_query.metadata.expected_answer_type = self.determine_answer_type(&sub_query.query);
524 }
525 }
526
527 fn determine_answer_type(&self, query: &str) -> String {
529 let query_lower = query.to_lowercase();
530
531 if query_lower.starts_with("what is") || query_lower.starts_with("define") {
532 "definitional".to_string()
533 } else if query_lower.starts_with("how") {
534 "procedural".to_string()
535 } else if query_lower.starts_with("when") {
536 "temporal".to_string()
537 } else if query_lower.starts_with("where") {
538 "locational".to_string()
539 } else if query_lower.starts_with("why") {
540 "causal".to_string()
541 } else if query_lower.contains("compare") || query_lower.contains("vs") {
542 "comparative".to_string()
543 } else {
544 "factual".to_string()
545 }
546 }
547
548 fn init_patterns() -> Vec<DecompositionPattern> {
550 vec![
551 DecompositionPattern {
552 name: "Logical AND".to_string(),
553 triggers: vec![
554 "and".to_string(),
555 "also".to_string(),
556 "additionally".to_string(),
557 ],
558 strategy: DecompositionStrategy::Logical,
559 extractor: |query| {
560 query
561 .split(" and ")
562 .map(|s| s.trim().to_string())
563 .filter(|s| s.len() > 5)
564 .collect()
565 },
566 confidence: 0.8,
567 },
568 DecompositionPattern {
569 name: "Comparative".to_string(),
570 triggers: vec![
571 "vs".to_string(),
572 "compare".to_string(),
573 "difference".to_string(),
574 ],
575 strategy: DecompositionStrategy::Comparative,
576 extractor: |query| {
577 if query.contains(" vs ") {
578 query
579 .split(" vs ")
580 .map(|s| format!("What is {}?", s.trim()))
581 .collect()
582 } else {
583 vec![]
584 }
585 },
586 confidence: 0.9,
587 },
588 ]
589 }
590
591 fn init_complexity_indicators() -> Vec<String> {
593 vec![
594 "and".to_string(),
595 "or".to_string(),
596 "but".to_string(),
597 "however".to_string(),
598 "also".to_string(),
599 "additionally".to_string(),
600 "furthermore".to_string(),
601 "moreover".to_string(),
602 "vs".to_string(),
603 "versus".to_string(),
604 "compare".to_string(),
605 "difference".to_string(),
606 "similar".to_string(),
607 "different".to_string(),
608 "before".to_string(),
609 "after".to_string(),
610 "during".to_string(),
611 "while".to_string(),
612 "meanwhile".to_string(),
613 ]
614 }
615}
616
617impl Default for QueryDecomposer {
618 fn default() -> Self {
619 Self::new()
620 }
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a default-configured decomposer on `query` and unwraps the result.
    async fn decompose_with_defaults(query: &str) -> Vec<SubQuery> {
        QueryDecomposer::new().decompose(query).await.unwrap()
    }

    /// Counts the sub-queries produced by `strategy`.
    fn count_with_strategy(sub_queries: &[SubQuery], strategy: DecompositionStrategy) -> usize {
        sub_queries
            .iter()
            .filter(|sq| sq.strategy == strategy)
            .count()
    }

    /// A query joined by "and" yields at least one logical sub-query.
    #[tokio::test]
    async fn test_logical_decomposition() {
        let sub_queries =
            decompose_with_defaults("What is machine learning and how does deep learning work?")
                .await;

        assert!(!sub_queries.is_empty());
        assert!(count_with_strategy(&sub_queries, DecompositionStrategy::Logical) > 0);
    }

    /// A "vs" query yields at least one comparative sub-query.
    #[tokio::test]
    async fn test_comparative_decomposition() {
        let sub_queries = decompose_with_defaults(
            "What are the differences between Python vs Rust for system programming?",
        )
        .await;

        assert!(!sub_queries.is_empty());
        assert!(count_with_strategy(&sub_queries, DecompositionStrategy::Comparative) > 0);
    }

    /// A short, simple query is left alone.
    #[tokio::test]
    async fn test_should_not_decompose_simple_query() {
        let sub_queries = decompose_with_defaults("What is Rust?").await;

        assert!(sub_queries.is_empty());
    }

    /// A query with "when" and "before" yields at least one temporal sub-query.
    #[tokio::test]
    async fn test_temporal_decomposition() {
        let sub_queries =
            decompose_with_defaults("When did the Renaissance start and what happened before it?")
                .await;

        assert!(!sub_queries.is_empty());
        assert!(count_with_strategy(&sub_queries, DecompositionStrategy::Temporal) > 0);
    }
}