rrag/query/decomposer.rs

//! # Query Decomposer
//!
//! Heuristic, keyword-based decomposition of complex queries into simpler sub-queries.
//! Breaking a multi-part question into focused searches can improve retrieval.

6use crate::RragResult;
7use serde::{Deserialize, Serialize};
8
9/// Query decomposer for breaking down complex queries
10pub struct QueryDecomposer {
11    /// Configuration
12    config: DecompositionConfig,
13
14    /// Patterns for identifying decomposable queries
15    patterns: Vec<DecompositionPattern>,
16
17    /// Keywords that indicate complex queries
18    complexity_indicators: Vec<String>,
19}
20
/// Tuning options for [`QueryDecomposer`].
///
/// The `Default` value enables every strategy, returns at most five
/// sub-queries, requires clauses of at least five bytes and keeps only
/// sub-queries with confidence of at least `0.6`.
#[derive(Debug, Clone)]
pub struct DecompositionConfig {
    /// Upper bound on the number of sub-queries returned. Applied after the
    /// candidates have been sorted by descending priority.
    pub max_sub_queries: usize,

    /// Minimum byte length a split clause must have to be kept (checked by
    /// the logical strategy).
    pub min_sub_query_length: usize,

    /// Run the time-oriented strategy ("when", "before", "after", ...).
    pub enable_temporal_decomposition: bool,

    /// Run the connector-splitting strategy ("and", "or", "but", ...).
    pub enable_logical_decomposition: bool,

    /// Run the multi-topic strategy (one query per detected topic).
    pub enable_topical_decomposition: bool,

    /// Run the comparison strategy ("vs", "compare", ...).
    pub enable_comparative_decomposition: bool,

    /// Sub-queries whose confidence is below this value are discarded. The
    /// comparison is `>=`, so a confidence equal to the threshold survives.
    pub confidence_threshold: f32,
}

46impl Default for DecompositionConfig {
47    fn default() -> Self {
48        Self {
49            max_sub_queries: 5,
50            min_sub_query_length: 5,
51            enable_temporal_decomposition: true,
52            enable_logical_decomposition: true,
53            enable_topical_decomposition: true,
54            enable_comparative_decomposition: true,
55            confidence_threshold: 0.6,
56        }
57    }
58}
59
60/// Decomposition strategies
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62pub enum DecompositionStrategy {
63    /// Logical decomposition (AND, OR clauses)
64    Logical,
65    /// Temporal decomposition (time-based aspects)
66    Temporal,
67    /// Topical decomposition (different subjects)
68    Topical,
69    /// Comparative decomposition (A vs B)
70    Comparative,
71    /// Sequential decomposition (step-by-step)
72    Sequential,
73    /// Causal decomposition (cause and effect)
74    Causal,
75}
76
77/// Pattern for identifying decomposable queries
78struct DecompositionPattern {
79    /// Name of the pattern
80    name: String,
81    /// Keywords that trigger this pattern
82    triggers: Vec<String>,
83    /// Decomposition strategy to apply
84    strategy: DecompositionStrategy,
85    /// Function to extract sub-queries
86    extractor: fn(&str) -> Vec<String>,
87    /// Confidence score
88    confidence: f32,
89}
90
91/// Sub-query generated from decomposition
92#[derive(Debug, Clone)]
93pub struct SubQuery {
94    /// The sub-query text
95    pub query: String,
96
97    /// Strategy used to generate this sub-query
98    pub strategy: DecompositionStrategy,
99
100    /// Confidence score (0.0 to 1.0)
101    pub confidence: f32,
102
103    /// Priority/importance (higher = more important)
104    pub priority: f32,
105
106    /// Metadata about the sub-query
107    pub metadata: SubQueryMetadata,
108}
109
/// Bookkeeping attached to every [`SubQuery`].
#[derive(Debug, Clone)]
pub struct SubQueryMetadata {
    /// Index of this sub-query in the list returned by `decompose`.
    ///
    /// Strategies fill in a provisional value, but `decompose` overwrites it
    /// after sorting and truncation. It is not a position inside the original
    /// query text.
    pub position: usize,

    /// Free-form relationship labels, e.g. the connector word that split the
    /// clause ("and", "or", ...) or a tag such as "temporal", "topical",
    /// "comparative" or "synthesis".
    pub relationships: Vec<String>,

    /// Coarse answer category such as "definitional", "procedural",
    /// "temporal", "comparative" or "factual". `decompose` recomputes it from
    /// the final query text.
    pub expected_answer_type: String,

    /// Indices of other sub-queries this one builds on. Set by the
    /// comparative strategy's synthesis query; empty otherwise.
    pub dependencies: Vec<usize>,
}

126impl QueryDecomposer {
127    /// Create a new query decomposer
128    pub fn new() -> Self {
129        Self::with_config(DecompositionConfig::default())
130    }
131
132    /// Create with custom configuration
133    pub fn with_config(config: DecompositionConfig) -> Self {
134        let patterns = Self::init_patterns();
135        let complexity_indicators = Self::init_complexity_indicators();
136
137        Self {
138            config,
139            patterns,
140            complexity_indicators,
141        }
142    }
143
144    /// Decompose a complex query into sub-queries
145    pub async fn decompose(&self, query: &str) -> RragResult<Vec<SubQuery>> {
146        let mut sub_queries = Vec::new();
147
148        // Check if query needs decomposition
149        if !self.should_decompose(query) {
150            return Ok(sub_queries);
151        }
152
153        // Apply different decomposition strategies
154        if self.config.enable_logical_decomposition {
155            sub_queries.extend(self.logical_decomposition(query));
156        }
157
158        if self.config.enable_temporal_decomposition {
159            sub_queries.extend(self.temporal_decomposition(query));
160        }
161
162        if self.config.enable_topical_decomposition {
163            sub_queries.extend(self.topical_decomposition(query));
164        }
165
166        if self.config.enable_comparative_decomposition {
167            sub_queries.extend(self.comparative_decomposition(query));
168        }
169
170        // Filter by confidence and limit results
171        sub_queries.retain(|sq| sq.confidence >= self.config.confidence_threshold);
172        sub_queries.sort_by(|a, b| {
173            b.priority
174                .partial_cmp(&a.priority)
175                .unwrap_or(std::cmp::Ordering::Equal)
176        });
177        sub_queries.truncate(self.config.max_sub_queries);
178
179        // Add metadata and dependencies
180        self.enrich_sub_queries(&mut sub_queries);
181
182        Ok(sub_queries)
183    }
184
185    /// Check if a query should be decomposed
186    fn should_decompose(&self, query: &str) -> bool {
187        let query_lower = query.to_lowercase();
188
189        // Check for complexity indicators
190        let has_complexity_indicators = self
191            .complexity_indicators
192            .iter()
193            .any(|indicator| query_lower.contains(indicator));
194
195        // Check for multiple questions
196        let question_count = query.matches('?').count();
197
198        // Check length threshold
199        let word_count = query.split_whitespace().count();
200
201        has_complexity_indicators || question_count > 1 || word_count > 15
202    }
203
204    /// Logical decomposition (AND, OR, BUT clauses)
205    fn logical_decomposition(&self, query: &str) -> Vec<SubQuery> {
206        let mut sub_queries = Vec::new();
207
208        // Split on logical connectors
209        let logical_connectors = ["and", "or", "but", "however", "also", "additionally"];
210
211        for connector in &logical_connectors {
212            if query.to_lowercase().contains(connector) {
213                let parts: Vec<&str> = query.split(&format!(" {} ", connector)).collect();
214                if parts.len() > 1 {
215                    for (i, part) in parts.iter().enumerate() {
216                        let trimmed = part.trim();
217                        if trimmed.len() >= self.config.min_sub_query_length {
218                            sub_queries.push(SubQuery {
219                                query: self.complete_sub_query(trimmed),
220                                strategy: DecompositionStrategy::Logical,
221                                confidence: 0.8,
222                                priority: 1.0 - (i as f32 * 0.1), // First parts are more important
223                                metadata: SubQueryMetadata {
224                                    position: i,
225                                    relationships: vec![connector.to_string()],
226                                    expected_answer_type: "factual".to_string(),
227                                    dependencies: vec![],
228                                },
229                            });
230                        }
231                    }
232                    break; // Only use the first matching connector
233                }
234            }
235        }
236
237        sub_queries
238    }
239
240    /// Temporal decomposition for time-based queries
241    fn temporal_decomposition(&self, query: &str) -> Vec<SubQuery> {
242        let mut sub_queries = Vec::new();
243        let query_lower = query.to_lowercase();
244
245        // Time indicators
246        let time_indicators = [
247            "when", "before", "after", "during", "since", "until", "timeline",
248        ];
249
250        if time_indicators
251            .iter()
252            .any(|&indicator| query_lower.contains(indicator))
253        {
254            // Extract temporal aspects
255            let temporal_aspects = self.extract_temporal_aspects(query);
256
257            for (i, aspect) in temporal_aspects.iter().enumerate() {
258                sub_queries.push(SubQuery {
259                    query: aspect.clone(),
260                    strategy: DecompositionStrategy::Temporal,
261                    confidence: 0.7,
262                    priority: 0.8,
263                    metadata: SubQueryMetadata {
264                        position: i,
265                        relationships: vec!["temporal".to_string()],
266                        expected_answer_type: "temporal".to_string(),
267                        dependencies: vec![],
268                    },
269                });
270            }
271        }
272
273        sub_queries
274    }
275
276    /// Topical decomposition for multi-topic queries
277    fn topical_decomposition(&self, query: &str) -> Vec<SubQuery> {
278        let mut sub_queries = Vec::new();
279
280        // Look for multiple topics or subjects
281        let topics = self.extract_topics(query);
282
283        if topics.len() > 1 {
284            for (i, topic) in topics.iter().enumerate() {
285                let topic_query = format!("What is {}?", topic);
286                sub_queries.push(SubQuery {
287                    query: topic_query,
288                    strategy: DecompositionStrategy::Topical,
289                    confidence: 0.6,
290                    priority: 0.7,
291                    metadata: SubQueryMetadata {
292                        position: i,
293                        relationships: vec!["topical".to_string()],
294                        expected_answer_type: "conceptual".to_string(),
295                        dependencies: vec![],
296                    },
297                });
298            }
299        }
300
301        sub_queries
302    }
303
304    /// Comparative decomposition for comparison queries
305    fn comparative_decomposition(&self, query: &str) -> Vec<SubQuery> {
306        let mut sub_queries = Vec::new();
307        let query_lower = query.to_lowercase();
308
309        // Comparison indicators
310        let comparison_indicators = [
311            "vs",
312            "versus",
313            "compare",
314            "difference",
315            "similar",
316            "different",
317        ];
318
319        if comparison_indicators
320            .iter()
321            .any(|&indicator| query_lower.contains(indicator))
322        {
323            let items = self.extract_comparison_items(query);
324
325            if items.len() >= 2 {
326                for item in &items {
327                    sub_queries.push(SubQuery {
328                        query: format!("What are the features of {}?", item),
329                        strategy: DecompositionStrategy::Comparative,
330                        confidence: 0.75,
331                        priority: 0.8,
332                        metadata: SubQueryMetadata {
333                            position: 0,
334                            relationships: vec!["comparative".to_string()],
335                            expected_answer_type: "comparative".to_string(),
336                            dependencies: vec![],
337                        },
338                    });
339                }
340
341                // Add a synthesis query
342                sub_queries.push(SubQuery {
343                    query: format!("Compare {} and {}", items[0], items[1]),
344                    strategy: DecompositionStrategy::Comparative,
345                    confidence: 0.9,
346                    priority: 1.0,
347                    metadata: SubQueryMetadata {
348                        position: items.len(),
349                        relationships: vec!["synthesis".to_string()],
350                        expected_answer_type: "comparative".to_string(),
351                        dependencies: (0..items.len()).collect(),
352                    },
353                });
354            }
355        }
356
357        sub_queries
358    }
359
360    /// Complete a sub-query to make it grammatically correct
361    fn complete_sub_query(&self, partial: &str) -> String {
362        let trimmed = partial.trim();
363
364        // If it doesn't start with a question word or have proper structure, add context
365        let question_words = ["what", "how", "why", "when", "where", "who", "which"];
366        let starts_with_question = question_words
367            .iter()
368            .any(|&word| trimmed.to_lowercase().starts_with(word));
369
370        if starts_with_question || trimmed.ends_with('?') {
371            trimmed.to_string()
372        } else {
373            format!("What is {}?", trimmed)
374        }
375    }
376
377    /// Extract temporal aspects from a query
378    fn extract_temporal_aspects(&self, query: &str) -> Vec<String> {
379        let mut aspects = Vec::new();
380
381        // Simple temporal extraction - in production, this would be more sophisticated
382        if query.to_lowercase().contains("when") {
383            aspects.push(format!(
384                "When did {} happen?",
385                self.extract_main_subject(query)
386            ));
387        }
388
389        if query.to_lowercase().contains("before") {
390            aspects.push(format!(
391                "What happened before {}?",
392                self.extract_main_subject(query)
393            ));
394        }
395
396        if query.to_lowercase().contains("after") {
397            aspects.push(format!(
398                "What happened after {}?",
399                self.extract_main_subject(query)
400            ));
401        }
402
403        aspects
404    }
405
406    /// Extract topics from a query
407    fn extract_topics(&self, query: &str) -> Vec<String> {
408        let mut topics = Vec::new();
409
410        // Simple topic extraction based on nouns and capitalized words
411        let words: Vec<&str> = query.split_whitespace().collect();
412
413        for window in words.windows(2) {
414            let word = window[0];
415            // Look for capitalized words (potential proper nouns/topics)
416            if word.chars().next().map_or(false, |c| c.is_uppercase()) && word.len() > 2 {
417                topics.push(word.to_string());
418            }
419        }
420
421        // Remove duplicates
422        topics.sort();
423        topics.dedup();
424
425        topics
426    }
427
428    /// Extract comparison items from a query
429    fn extract_comparison_items(&self, query: &str) -> Vec<String> {
430        let mut items = Vec::new();
431
432        // Look for patterns like "A vs B" or "A and B"
433        if let Some(vs_pos) = query.to_lowercase().find(" vs ") {
434            let before = &query[..vs_pos].trim();
435            let after = &query[vs_pos + 4..].trim();
436
437            items.push(self.extract_last_noun(before).to_string());
438            items.push(self.extract_first_noun(after).to_string());
439        } else if query.to_lowercase().contains("compare") {
440            // Extract nouns after "compare"
441            let words: Vec<&str> = query.split_whitespace().collect();
442            let mut collecting = false;
443
444            for word in words {
445                if word.to_lowercase() == "compare" {
446                    collecting = true;
447                    continue;
448                }
449
450                if collecting
451                    && word.len() > 2
452                    && !["and", "with", "to"].contains(&word.to_lowercase().as_str())
453                {
454                    items.push(
455                        word.trim_matches(|c: char| !c.is_alphanumeric())
456                            .to_string(),
457                    );
458                    if items.len() >= 2 {
459                        break;
460                    }
461                }
462            }
463        }
464
465        items
466    }
467
468    /// Extract the main subject from a query
469    fn extract_main_subject(&self, query: &str) -> String {
470        // Simple subject extraction - would be more sophisticated in production
471        let words: Vec<&str> = query.split_whitespace().collect();
472
473        // Look for the first meaningful noun
474        for word in words {
475            if word.len() > 3
476                && !["what", "when", "where", "how", "why", "who", "the", "and"]
477                    .contains(&word.to_lowercase().as_str())
478            {
479                return word
480                    .trim_matches(|c: char| !c.is_alphanumeric())
481                    .to_string();
482            }
483        }
484
485        "this".to_string()
486    }
487
488    /// Extract the last meaningful noun from text
489    fn extract_last_noun<'a>(&self, text: &'a str) -> &'a str {
490        let words: Vec<&str> = text.split_whitespace().collect();
491        for word in words.iter().rev() {
492            if word.len() > 2
493                && !["the", "and", "or", "of", "in", "on", "at"]
494                    .contains(&word.to_lowercase().as_str())
495            {
496                return word;
497            }
498        }
499        text
500    }
501
502    /// Extract the first meaningful noun from text
503    fn extract_first_noun<'a>(&self, text: &'a str) -> &'a str {
504        let words: Vec<&str> = text.split_whitespace().collect();
505        for word in words {
506            if word.len() > 2
507                && !["the", "and", "or", "of", "in", "on", "at"]
508                    .contains(&word.to_lowercase().as_str())
509            {
510                return word;
511            }
512        }
513        text
514    }
515
516    /// Enrich sub-queries with additional metadata
517    fn enrich_sub_queries(&self, sub_queries: &mut [SubQuery]) {
518        for (i, sub_query) in sub_queries.iter_mut().enumerate() {
519            // Add position metadata
520            sub_query.metadata.position = i;
521
522            // Determine expected answer type based on query structure
523            sub_query.metadata.expected_answer_type = self.determine_answer_type(&sub_query.query);
524        }
525    }
526
527    /// Determine the expected answer type for a query
528    fn determine_answer_type(&self, query: &str) -> String {
529        let query_lower = query.to_lowercase();
530
531        if query_lower.starts_with("what is") || query_lower.starts_with("define") {
532            "definitional".to_string()
533        } else if query_lower.starts_with("how") {
534            "procedural".to_string()
535        } else if query_lower.starts_with("when") {
536            "temporal".to_string()
537        } else if query_lower.starts_with("where") {
538            "locational".to_string()
539        } else if query_lower.starts_with("why") {
540            "causal".to_string()
541        } else if query_lower.contains("compare") || query_lower.contains("vs") {
542            "comparative".to_string()
543        } else {
544            "factual".to_string()
545        }
546    }
547
548    /// Initialize decomposition patterns
549    fn init_patterns() -> Vec<DecompositionPattern> {
550        vec![
551            DecompositionPattern {
552                name: "Logical AND".to_string(),
553                triggers: vec![
554                    "and".to_string(),
555                    "also".to_string(),
556                    "additionally".to_string(),
557                ],
558                strategy: DecompositionStrategy::Logical,
559                extractor: |query| {
560                    query
561                        .split(" and ")
562                        .map(|s| s.trim().to_string())
563                        .filter(|s| s.len() > 5)
564                        .collect()
565                },
566                confidence: 0.8,
567            },
568            DecompositionPattern {
569                name: "Comparative".to_string(),
570                triggers: vec![
571                    "vs".to_string(),
572                    "compare".to_string(),
573                    "difference".to_string(),
574                ],
575                strategy: DecompositionStrategy::Comparative,
576                extractor: |query| {
577                    if query.contains(" vs ") {
578                        query
579                            .split(" vs ")
580                            .map(|s| format!("What is {}?", s.trim()))
581                            .collect()
582                    } else {
583                        vec![]
584                    }
585                },
586                confidence: 0.9,
587            },
588        ]
589    }
590
591    /// Initialize complexity indicators
592    fn init_complexity_indicators() -> Vec<String> {
593        vec![
594            "and".to_string(),
595            "or".to_string(),
596            "but".to_string(),
597            "however".to_string(),
598            "also".to_string(),
599            "additionally".to_string(),
600            "furthermore".to_string(),
601            "moreover".to_string(),
602            "vs".to_string(),
603            "versus".to_string(),
604            "compare".to_string(),
605            "difference".to_string(),
606            "similar".to_string(),
607            "different".to_string(),
608            "before".to_string(),
609            "after".to_string(),
610            "during".to_string(),
611            "while".to_string(),
612            "meanwhile".to_string(),
613        ]
614    }
615}
616
617impl Default for QueryDecomposer {
618    fn default() -> Self {
619        Self::new()
620    }
621}
622
#[cfg(test)]
mod tests {
    use super::*;

    /// Run a default-configured decomposer over `query`, panicking on error.
    async fn decompose_with_defaults(query: &str) -> Vec<SubQuery> {
        let decomposer = QueryDecomposer::new();
        decomposer.decompose(query).await.unwrap()
    }

    /// True when at least one sub-query was produced by `strategy`.
    fn has_strategy(sub_queries: &[SubQuery], strategy: DecompositionStrategy) -> bool {
        sub_queries.iter().any(|sq| sq.strategy == strategy)
    }

    #[tokio::test]
    async fn test_logical_decomposition() {
        let sub_queries =
            decompose_with_defaults("What is machine learning and how does deep learning work?")
                .await;

        assert!(!sub_queries.is_empty());
        assert!(has_strategy(&sub_queries, DecompositionStrategy::Logical));
    }

    #[tokio::test]
    async fn test_comparative_decomposition() {
        let sub_queries = decompose_with_defaults(
            "What are the differences between Python vs Rust for system programming?",
        )
        .await;

        assert!(!sub_queries.is_empty());
        assert!(has_strategy(&sub_queries, DecompositionStrategy::Comparative));
    }

    #[tokio::test]
    async fn test_should_not_decompose_simple_query() {
        // A short, single-topic question has nothing to split.
        let sub_queries = decompose_with_defaults("What is Rust?").await;

        assert!(sub_queries.is_empty());
    }

    #[tokio::test]
    async fn test_temporal_decomposition() {
        let sub_queries =
            decompose_with_defaults("When did the Renaissance start and what happened before it?")
                .await;

        assert!(!sub_queries.is_empty());
        assert!(has_strategy(&sub_queries, DecompositionStrategy::Temporal));
    }
}