rexis_rag/query/
hyde.rs

1//! # HyDE (Hypothetical Document Embeddings)
2//!
3//! Generates hypothetical documents that would answer the user's query,
4//! then uses their embeddings for more effective semantic search.
5//! Based on the paper: "Precise Zero-Shot Dense Retrieval without Relevance Labels"
6
7use crate::{EmbeddingProvider, RragResult};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// HyDE generator for creating hypothetical document embeddings
12pub struct HyDEGenerator {
13    /// Configuration
14    config: HyDEConfig,
15
16    /// Embedding provider for generating embeddings
17    embedding_provider: Arc<dyn EmbeddingProvider>,
18
19    /// Document templates for different query types
20    templates: HashMap<String, Vec<DocumentTemplate>>,
21
22    /// Answer generation patterns
23    answer_patterns: Vec<AnswerPattern>,
24}
25
/// Tunable settings for [`HyDEGenerator`].
#[derive(Debug, Clone)]
pub struct HyDEConfig {
    /// How many hypothetical documents a single `generate` call may return.
    pub num_hypothetical_docs: usize,

    /// Upper length bound; longer generated documents are discarded.
    pub max_document_length: usize,

    /// Lower length bound; shorter generated documents are discarded.
    pub min_document_length: usize,

    /// Toggle for query-specific document generation.
    pub enable_query_specific_generation: bool,

    /// Toggle for domain detection during generation.
    pub enable_domain_awareness: bool,

    /// Minimum confidence a generated document needs to be returned.
    pub confidence_threshold: f32,

    /// Generation temperature (creativity vs accuracy).
    pub generation_temperature: f32,
}

impl Default for HyDEConfig {
    /// Three documents of 50 to 500 length units, both feature toggles on,
    /// confidence threshold 0.6 and temperature 0.7.
    fn default() -> Self {
        HyDEConfig {
            // Output volume and size limits.
            num_hypothetical_docs: 3,
            min_document_length: 50,
            max_document_length: 500,
            // Feature toggles.
            enable_query_specific_generation: true,
            enable_domain_awareness: true,
            // Scoring and sampling knobs.
            confidence_threshold: 0.6,
            generation_temperature: 0.7,
        }
    }
}
64
/// A text skeleton that is filled in to produce one hypothetical answer.
#[derive(Debug, Clone)]
struct DocumentTemplate {
    /// Identifier of the template.
    name: String,
    /// Template text; may contain `{query}`, `{subject}`, `{key_terms}` and `{domain}` placeholders.
    pattern: String,
    /// Query types this template is intended for.
    query_types: Vec<String>,
    /// Confidence score attached to this template.
    confidence: f32,
}
77
78/// Pattern for generating answers
79#[derive(Debug, Clone)]
80struct AnswerPattern {
81    /// Pattern name
82    name: String,
83    /// Trigger keywords
84    triggers: Vec<String>,
85    /// Generation function
86    generator: fn(&str, &HyDEConfig) -> Vec<String>,
87    /// Confidence score
88    confidence: f32,
89}
90
91/// Result of HyDE generation
92#[derive(Debug, Clone)]
93pub struct HyDEResult {
94    /// Original query
95    pub query: String,
96
97    /// Generated hypothetical answer/document
98    pub hypothetical_answer: String,
99
100    /// Embedding of the hypothetical document
101    pub embedding: Option<crate::embeddings::Embedding>,
102
103    /// Generation method used
104    pub generation_method: String,
105
106    /// Confidence score (0.0 to 1.0)
107    pub confidence: f32,
108
109    /// Generation metadata
110    pub metadata: HyDEMetadata,
111}
112
/// Bookkeeping recorded alongside each generated hypothetical document.
#[derive(Debug, Clone)]
pub struct HyDEMetadata {
    /// Milliseconds elapsed since generation started when the result was recorded.
    pub generation_time_ms: u64,

    /// Length of the generated document.
    pub document_length: usize,

    /// Approximate token count (whitespace-separated words).
    pub estimated_tokens: usize,

    /// Query type that was detected for the query.
    pub detected_query_type: String,

    /// Domain that was detected for the query, if any.
    pub detected_domain: Option<String>,

    /// Label of the template used, if any.
    pub template_used: Option<String>,
}
134
135impl HyDEGenerator {
136    /// Create a new HyDE generator
137    pub fn new(config: HyDEConfig, embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
138        let templates = Self::init_templates();
139        let answer_patterns = Self::init_answer_patterns();
140
141        Self {
142            config,
143            embedding_provider,
144            templates,
145            answer_patterns,
146        }
147    }
148
149    /// Generate hypothetical documents for a query
150    pub async fn generate(&self, query: &str) -> RragResult<Vec<HyDEResult>> {
151        let start_time = std::time::Instant::now();
152        let mut results = Vec::new();
153
154        // Detect query characteristics
155        let query_type = self.detect_query_type(query);
156        let domain = if self.config.enable_domain_awareness {
157            self.detect_domain(query)
158        } else {
159            None
160        };
161
162        // Generate hypothetical documents using different strategies
163        let hypothetical_docs = self.generate_hypothetical_documents(query, &query_type, &domain);
164
165        for (i, doc) in hypothetical_docs.iter().enumerate() {
166            if doc.len() < self.config.min_document_length
167                || doc.len() > self.config.max_document_length
168            {
169                continue;
170            }
171
172            // Generate embedding for the hypothetical document
173            let embedding = match self.embedding_provider.embed_text(doc).await {
174                Ok(emb) => Some(emb),
175                Err(_) => None, // Continue without embedding if it fails
176            };
177
178            let confidence = self.calculate_confidence(query, doc, &query_type);
179
180            if confidence >= self.config.confidence_threshold {
181                results.push(HyDEResult {
182                    query: query.to_string(),
183                    hypothetical_answer: doc.clone(),
184                    embedding,
185                    generation_method: format!("pattern_{}", i),
186                    confidence,
187                    metadata: HyDEMetadata {
188                        generation_time_ms: start_time.elapsed().as_millis() as u64,
189                        document_length: doc.len(),
190                        estimated_tokens: doc.split_whitespace().count(),
191                        detected_query_type: query_type.clone(),
192                        detected_domain: domain.clone(),
193                        template_used: Some(format!("template_{}", i)),
194                    },
195                });
196            }
197
198            if results.len() >= self.config.num_hypothetical_docs {
199                break;
200            }
201        }
202
203        Ok(results)
204    }
205
206    /// Generate hypothetical documents using various strategies
207    fn generate_hypothetical_documents(
208        &self,
209        query: &str,
210        query_type: &str,
211        domain: &Option<String>,
212    ) -> Vec<String> {
213        let mut documents = Vec::new();
214
215        // Strategy 1: Template-based generation
216        if let Some(templates) = self.templates.get(query_type) {
217            for template in templates {
218                let doc = self.apply_template(query, template, domain);
219                documents.push(doc);
220            }
221        }
222
223        // Strategy 2: Pattern-based generation
224        for pattern in &self.answer_patterns {
225            if pattern
226                .triggers
227                .iter()
228                .any(|trigger| query.to_lowercase().contains(&trigger.to_lowercase()))
229            {
230                let generated_docs = (pattern.generator)(query, &self.config);
231                documents.extend(generated_docs);
232            }
233        }
234
235        // Strategy 3: Fallback generic generation
236        if documents.is_empty() {
237            documents.extend(self.generate_generic_documents(query, query_type));
238        }
239
240        // Limit and deduplicate
241        documents.sort();
242        documents.dedup();
243        documents.truncate(self.config.num_hypothetical_docs * 2); // Generate more, filter later
244
245        documents
246    }
247
248    /// Apply a template to generate a hypothetical document
249    fn apply_template(
250        &self,
251        query: &str,
252        template: &DocumentTemplate,
253        domain: &Option<String>,
254    ) -> String {
255        let mut result = template.pattern.clone();
256
257        // Extract key terms from query
258        let key_terms = self.extract_key_terms(query);
259        let main_subject = self.extract_main_subject(query);
260
261        // Replace placeholders
262        result = result.replace("{query}", query);
263        result = result.replace("{subject}", &main_subject);
264        result = result.replace("{key_terms}", &key_terms.join(", "));
265
266        if let Some(domain_name) = domain {
267            result = result.replace("{domain}", domain_name);
268        }
269
270        // Clean up the result
271        self.clean_generated_text(&result)
272    }
273
274    /// Generate generic hypothetical documents
275    fn generate_generic_documents(&self, query: &str, query_type: &str) -> Vec<String> {
276        let mut documents = Vec::new();
277        let main_subject = self.extract_main_subject(query);
278
279        match query_type {
280            "definitional" => {
281                documents.push(format!(
282                    "{} is a concept that refers to the fundamental principles and mechanisms underlying this topic. \
283                    It encompasses various aspects including its core definition, key characteristics, and primary applications. \
284                    Understanding {} requires examining its historical development, theoretical foundations, and practical implications. \
285                    The concept plays a crucial role in its respective field and has significant impact on related areas.",
286                    main_subject, main_subject
287                ));
288            }
289            "procedural" => {
290                documents.push(format!(
291                    "To accomplish {} successfully, there are several important steps to follow. \
292                    First, it's essential to understand the underlying principles and requirements. \
293                    The process typically involves careful planning, systematic execution, and continuous monitoring. \
294                    Key considerations include proper preparation, attention to detail, and adherence to best practices. \
295                    Following these guidelines will help ensure optimal results and avoid common pitfalls.",
296                    main_subject
297                ));
298            }
299            "comparative" => {
300                documents.push(format!(
301                    "When comparing different approaches to {}, several factors must be considered. \
302                    Each option has distinct advantages and disadvantages that affect their suitability for various use cases. \
303                    The comparison involves analyzing performance characteristics, resource requirements, and implementation complexity. \
304                    Understanding these differences helps in making informed decisions based on specific needs and constraints.",
305                    main_subject
306                ));
307            }
308            "factual" => {
309                documents.push(format!(
310                    "Regarding {}, there are several important facts and key information points to consider. \
311                    The available evidence and research data provide insights into various aspects of this topic. \
312                    Historical context, current developments, and future trends all contribute to a comprehensive understanding. \
313                    These facts form the foundation for deeper analysis and informed decision-making.",
314                    main_subject
315                ));
316            }
317            _ => {
318                documents.push(format!(
319                    "{} represents an important topic that deserves careful examination. \
320                    The subject encompasses multiple dimensions including theoretical aspects, practical applications, and real-world implications. \
321                    Understanding this topic requires considering various perspectives, analyzing available information, and drawing meaningful conclusions. \
322                    This comprehensive approach ensures a thorough grasp of the subject matter.",
323                    main_subject
324                ));
325            }
326        }
327
328        documents
329    }
330
331    /// Detect the type of query
332    fn detect_query_type(&self, query: &str) -> String {
333        let query_lower = query.to_lowercase();
334
335        if query_lower.starts_with("what is") || query_lower.starts_with("define") {
336            "definitional".to_string()
337        } else if query_lower.starts_with("how to") || query_lower.contains("step") {
338            "procedural".to_string()
339        } else if query_lower.contains("compare")
340            || query_lower.contains("vs")
341            || query_lower.contains("difference")
342        {
343            "comparative".to_string()
344        } else if query_lower.starts_with("when")
345            || query_lower.starts_with("where")
346            || query_lower.starts_with("who")
347        {
348            "factual".to_string()
349        } else if query_lower.starts_with("why") {
350            "causal".to_string()
351        } else if query_lower.starts_with("list") || query_lower.contains("examples") {
352            "enumerative".to_string()
353        } else {
354            "general".to_string()
355        }
356    }
357
358    /// Detect domain from query
359    fn detect_domain(&self, query: &str) -> Option<String> {
360        let query_lower = query.to_lowercase();
361
362        let domains = [
363            (
364                "technology",
365                vec![
366                    "code",
367                    "programming",
368                    "software",
369                    "api",
370                    "database",
371                    "algorithm",
372                    "computer",
373                    "tech",
374                ],
375            ),
376            (
377                "science",
378                vec![
379                    "research",
380                    "study",
381                    "experiment",
382                    "theory",
383                    "analysis",
384                    "scientific",
385                    "hypothesis",
386                ],
387            ),
388            (
389                "business",
390                vec![
391                    "market",
392                    "sales",
393                    "revenue",
394                    "customer",
395                    "profit",
396                    "strategy",
397                    "management",
398                    "company",
399                ],
400            ),
401            (
402                "health",
403                vec![
404                    "medical",
405                    "health",
406                    "disease",
407                    "treatment",
408                    "doctor",
409                    "medicine",
410                    "patient",
411                    "healthcare",
412                ],
413            ),
414            (
415                "education",
416                vec![
417                    "learn",
418                    "study",
419                    "school",
420                    "university",
421                    "course",
422                    "education",
423                    "teach",
424                    "academic",
425                ],
426            ),
427            (
428                "finance",
429                vec![
430                    "money",
431                    "investment",
432                    "financial",
433                    "bank",
434                    "trading",
435                    "economics",
436                    "cost",
437                    "price",
438                ],
439            ),
440        ];
441
442        for (domain, keywords) in &domains {
443            let matches = keywords
444                .iter()
445                .filter(|&&keyword| query_lower.contains(keyword))
446                .count();
447
448            if matches >= 2 || (matches == 1 && query_lower.split_whitespace().count() <= 5) {
449                return Some(domain.to_string());
450            }
451        }
452
453        None
454    }
455
456    /// Extract key terms from query
457    fn extract_key_terms(&self, query: &str) -> Vec<String> {
458        let stop_words = [
459            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with",
460            "by", "is", "are", "was", "were", "be", "been", "have", "has", "had", "do", "does",
461            "did", "will", "would", "could", "should", "may", "might", "can", "what", "how", "why",
462            "when", "where", "who", "which",
463        ];
464
465        query
466            .split_whitespace()
467            .filter(|word| {
468                let clean_word = word
469                    .trim_matches(|c: char| !c.is_alphanumeric())
470                    .to_lowercase();
471                !stop_words.contains(&clean_word.as_str()) && clean_word.len() > 2
472            })
473            .map(|word| {
474                word.trim_matches(|c: char| !c.is_alphanumeric())
475                    .to_string()
476            })
477            .collect()
478    }
479
480    /// Extract main subject from query
481    fn extract_main_subject(&self, query: &str) -> String {
482        let key_terms = self.extract_key_terms(query);
483        if !key_terms.is_empty() {
484            key_terms[0].clone()
485        } else {
486            "the topic".to_string()
487        }
488    }
489
490    /// Clean generated text
491    fn clean_generated_text(&self, text: &str) -> String {
492        text.trim()
493            .replace("  ", " ")
494            .replace("\n\n", "\n")
495            .lines()
496            .filter(|line| !line.trim().is_empty())
497            .collect::<Vec<_>>()
498            .join(" ")
499    }
500
501    /// Calculate confidence score for generated document
502    fn calculate_confidence(&self, query: &str, document: &str, query_type: &str) -> f32 {
503        let mut confidence = 0.5; // Base confidence
504
505        // Check length appropriateness
506        if document.len() >= self.config.min_document_length
507            && document.len() <= self.config.max_document_length
508        {
509            confidence += 0.1;
510        }
511
512        // Check if key terms from query appear in document
513        let query_terms = self.extract_key_terms(query);
514        let document_lower = document.to_lowercase();
515        let term_matches = query_terms
516            .iter()
517            .filter(|term| document_lower.contains(&term.to_lowercase()))
518            .count();
519
520        if !query_terms.is_empty() {
521            confidence += (term_matches as f32 / query_terms.len() as f32) * 0.3;
522        }
523
524        // Bonus for appropriate query type handling
525        match query_type {
526            "definitional" if document.contains("is") || document.contains("refers to") => {
527                confidence += 0.1
528            }
529            "procedural" if document.contains("step") || document.contains("process") => {
530                confidence += 0.1
531            }
532            "comparative" if document.contains("compare") || document.contains("difference") => {
533                confidence += 0.1
534            }
535            _ => {}
536        }
537
538        confidence.min(1.0)
539    }
540
541    /// Initialize document templates
542    fn init_templates() -> HashMap<String, Vec<DocumentTemplate>> {
543        let mut templates = HashMap::new();
544
545        // Definitional templates
546        templates.insert("definitional".to_string(), vec![
547            DocumentTemplate {
548                name: "concept_definition".to_string(),
549                pattern: "{subject} is a fundamental concept in {domain} that encompasses several key aspects. It refers to the systematic approach and principles underlying this area of study. The definition includes both theoretical foundations and practical applications, making it essential for understanding related topics.".to_string(),
550                query_types: vec!["definitional".to_string()],
551                confidence: 0.8,
552            },
553        ]);
554
555        // Procedural templates
556        templates.insert("procedural".to_string(), vec![
557            DocumentTemplate {
558                name: "how_to_guide".to_string(),
559                pattern: "To effectively accomplish {subject}, follow these systematic steps and best practices. The process requires careful planning, proper execution, and continuous monitoring. Begin by understanding the requirements, then proceed with methodical implementation while considering potential challenges and solutions.".to_string(),
560                query_types: vec!["procedural".to_string()],
561                confidence: 0.8,
562            },
563        ]);
564
565        // Comparative templates
566        templates.insert("comparative".to_string(), vec![
567            DocumentTemplate {
568                name: "comparison_analysis".to_string(),
569                pattern: "When analyzing {subject}, several important factors distinguish different approaches and options. Each alternative offers unique advantages and limitations that affect performance, cost, and suitability for various use cases. The comparison reveals critical differences in functionality, efficiency, and implementation requirements.".to_string(),
570                query_types: vec!["comparative".to_string()],
571                confidence: 0.8,
572            },
573        ]);
574
575        templates
576    }
577
578    /// Initialize answer patterns
579    fn init_answer_patterns() -> Vec<AnswerPattern> {
580        vec![
581            AnswerPattern {
582                name: "technical_explanation".to_string(),
583                triggers: vec![
584                    "algorithm".to_string(),
585                    "system".to_string(),
586                    "technology".to_string(),
587                ],
588                generator: |query, _config| {
589                    vec![format!(
590                        "The technical implementation of {} involves several sophisticated components working together. \
591                        The system architecture incorporates advanced algorithms and optimized data structures to ensure \
592                        efficient performance and scalability. Key technical considerations include resource management, \
593                        error handling, and performance optimization strategies.",
594                        query
595                    )]
596                },
597                confidence: 0.7,
598            },
599            AnswerPattern {
600                name: "research_summary".to_string(),
601                triggers: vec![
602                    "research".to_string(),
603                    "study".to_string(),
604                    "analysis".to_string(),
605                ],
606                generator: |query, _config| {
607                    vec![format!(
608                        "Recent research on {} has revealed significant insights and findings that advance our understanding \
609                        of this field. Multiple studies have examined various aspects, employing rigorous methodologies \
610                        and comprehensive data analysis. The research findings contribute valuable knowledge and inform \
611                        evidence-based practices and future investigations.",
612                        query
613                    )]
614                },
615                confidence: 0.7,
616            },
617        ]
618    }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::MockEmbeddingProvider;

    /// Build a generator with default settings and the mock embedding provider.
    fn default_generator() -> HyDEGenerator {
        let provider = Arc::new(MockEmbeddingProvider::new());
        HyDEGenerator::new(HyDEConfig::default(), provider)
    }

    #[tokio::test]
    async fn test_hyde_generation() {
        let hyde = default_generator();

        let results = hyde.generate("What is machine learning?").await.unwrap();

        assert!(!results.is_empty());
        assert!(results[0].confidence > 0.0);
        assert!(results[0].hypothetical_answer.len() > 50);
        assert_eq!(results[0].metadata.detected_query_type, "definitional");
    }

    #[tokio::test]
    async fn test_procedural_query() {
        let hyde = default_generator();

        let results = hyde.generate("How to implement a REST API?").await.unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].metadata.detected_query_type, "procedural");
        assert!(
            results[0].hypothetical_answer.contains("step")
                || results[0].hypothetical_answer.contains("process")
        );
    }

    #[tokio::test]
    async fn test_comparative_query() {
        let hyde = default_generator();

        let results = hyde
            .generate("Python vs Rust performance comparison")
            .await
            .unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].metadata.detected_query_type, "comparative");
    }

    #[test]
    fn test_query_type_detection() {
        let hyde = default_generator();

        assert_eq!(hyde.detect_query_type("What is AI?"), "definitional");
        assert_eq!(hyde.detect_query_type("How to code?"), "procedural");
        assert_eq!(hyde.detect_query_type("Python vs Java"), "comparative");
        assert_eq!(hyde.detect_query_type("When was it built?"), "factual");
    }

    #[test]
    fn test_domain_detection() {
        let hyde = default_generator();

        // Each query below hits keywords of exactly one domain, so the
        // expectation does not depend on how competing domains are ranked.
        assert_eq!(
            hyde.detect_domain("machine learning algorithm"),
            Some("technology".to_string())
        );
        assert_eq!(
            hyde.detect_domain("medical treatment for a patient"),
            Some("health".to_string())
        );
        assert_eq!(
            hyde.detect_domain("market sales strategy"),
            Some("business".to_string())
        );
        assert_eq!(hyde.detect_domain("simple question"), None);

        // "research" and "study" are both science keywords and outnumber the
        // single health keyword "medical".
        assert_eq!(
            hyde.detect_domain("medical research study"),
            Some("science".to_string())
        );
    }
}