1use crate::{EmbeddingProvider, RragResult};
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Generates hypothetical answer documents for a query (HyDE-style) and
/// embeds them with the configured embedding provider.
///
/// Candidate documents come from query-type templates, trigger-word answer
/// patterns, or a generic fallback text; see `HyDEGenerator::generate`.
pub struct HyDEGenerator {
    /// Tuning parameters controlling generation and filtering.
    config: HyDEConfig,

    /// Provider used to embed generated hypothetical documents.
    embedding_provider: Arc<dyn EmbeddingProvider>,

    /// Document templates keyed by detected query type (e.g. "definitional").
    templates: HashMap<String, Vec<DocumentTemplate>>,

    /// Answer generators selected by trigger words found in the query.
    answer_patterns: Vec<AnswerPattern>,
}
25
/// Tuning parameters for [`HyDEGenerator`].
///
/// Derives `PartialEq` so callers and tests can compare configurations.
#[derive(Debug, Clone, PartialEq)]
pub struct HyDEConfig {
    /// Maximum number of hypothetical documents returned by
    /// `HyDEGenerator::generate`.
    pub num_hypothetical_docs: usize,

    /// Upper bound on a document's length, measured in bytes (`str::len`);
    /// longer documents are skipped.
    pub max_document_length: usize,

    /// Lower bound on a document's length, measured in bytes (`str::len`);
    /// shorter documents are skipped.
    pub min_document_length: usize,

    /// NOTE(review): not read by `HyDEGenerator` in this module; presumably
    /// reserved for query-specific generation — confirm before relying on it.
    pub enable_query_specific_generation: bool,

    /// When `true`, the query's subject domain (technology, science, ...) is
    /// detected, reported in result metadata and substituted into templates.
    pub enable_domain_awareness: bool,

    /// Minimum heuristic confidence (the score is capped at 1.0) a document
    /// must reach to be returned.
    pub confidence_threshold: f32,

    /// NOTE(review): not read by `HyDEGenerator` in this module; presumably a
    /// sampling temperature for a model-backed generator — confirm.
    pub generation_temperature: f32,
}
50
51impl Default for HyDEConfig {
52 fn default() -> Self {
53 Self {
54 num_hypothetical_docs: 3,
55 max_document_length: 500,
56 min_document_length: 50,
57 enable_query_specific_generation: true,
58 enable_domain_awareness: true,
59 confidence_threshold: 0.6,
60 generation_temperature: 0.7,
61 }
62 }
63}
64
/// A fill-in-the-blanks document pattern associated with a query type.
#[derive(Clone, Debug)]
struct DocumentTemplate {
    /// Identifier of the template.
    name: String,
    /// Document text containing `{placeholder}` markers such as `{subject}`
    /// and `{domain}`.
    pattern: String,
    /// Query types this template is intended for.
    query_types: Vec<String>,
    /// Prior confidence of the template.
    /// NOTE(review): not read by the confidence-scoring code.
    confidence: f32,
}
77
78#[derive(Debug, Clone)]
80struct AnswerPattern {
81 name: String,
83 triggers: Vec<String>,
85 generator: fn(&str, &HyDEConfig) -> Vec<String>,
87 confidence: f32,
89}
90
91#[derive(Debug, Clone)]
93pub struct HyDEResult {
94 pub query: String,
96
97 pub hypothetical_answer: String,
99
100 pub embedding: Option<crate::embeddings::Embedding>,
102
103 pub generation_method: String,
105
106 pub confidence: f32,
108
109 pub metadata: HyDEMetadata,
111}
112
/// Diagnostic details about how a [`HyDEResult`] was produced.
///
/// Every field is `Eq`, so the type derives `PartialEq`/`Eq` for easy
/// comparison by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HyDEMetadata {
    /// Milliseconds elapsed since the `generate` call started, measured when
    /// this result was recorded. The value is cumulative across the call,
    /// not a per-document timing.
    pub generation_time_ms: u64,

    /// Length of the hypothetical answer in bytes.
    pub document_length: usize,

    /// Rough token estimate: the number of whitespace-separated words.
    pub estimated_tokens: usize,

    /// Query category detected for the query (e.g. "definitional", "general").
    pub detected_query_type: String,

    /// Subject domain detected for the query, if any.
    pub detected_domain: Option<String>,

    /// Identifier of the template used to produce the document, if recorded.
    pub template_used: Option<String>,
}
134
135impl HyDEGenerator {
136 pub fn new(config: HyDEConfig, embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
138 let templates = Self::init_templates();
139 let answer_patterns = Self::init_answer_patterns();
140
141 Self {
142 config,
143 embedding_provider,
144 templates,
145 answer_patterns,
146 }
147 }
148
149 pub async fn generate(&self, query: &str) -> RragResult<Vec<HyDEResult>> {
151 let start_time = std::time::Instant::now();
152 let mut results = Vec::new();
153
154 let query_type = self.detect_query_type(query);
156 let domain = if self.config.enable_domain_awareness {
157 self.detect_domain(query)
158 } else {
159 None
160 };
161
162 let hypothetical_docs = self.generate_hypothetical_documents(query, &query_type, &domain);
164
165 for (i, doc) in hypothetical_docs.iter().enumerate() {
166 if doc.len() < self.config.min_document_length
167 || doc.len() > self.config.max_document_length
168 {
169 continue;
170 }
171
172 let embedding = match self.embedding_provider.embed_text(doc).await {
174 Ok(emb) => Some(emb),
175 Err(_) => None, };
177
178 let confidence = self.calculate_confidence(query, doc, &query_type);
179
180 if confidence >= self.config.confidence_threshold {
181 results.push(HyDEResult {
182 query: query.to_string(),
183 hypothetical_answer: doc.clone(),
184 embedding,
185 generation_method: format!("pattern_{}", i),
186 confidence,
187 metadata: HyDEMetadata {
188 generation_time_ms: start_time.elapsed().as_millis() as u64,
189 document_length: doc.len(),
190 estimated_tokens: doc.split_whitespace().count(),
191 detected_query_type: query_type.clone(),
192 detected_domain: domain.clone(),
193 template_used: Some(format!("template_{}", i)),
194 },
195 });
196 }
197
198 if results.len() >= self.config.num_hypothetical_docs {
199 break;
200 }
201 }
202
203 Ok(results)
204 }
205
206 fn generate_hypothetical_documents(
208 &self,
209 query: &str,
210 query_type: &str,
211 domain: &Option<String>,
212 ) -> Vec<String> {
213 let mut documents = Vec::new();
214
215 if let Some(templates) = self.templates.get(query_type) {
217 for template in templates {
218 let doc = self.apply_template(query, template, domain);
219 documents.push(doc);
220 }
221 }
222
223 for pattern in &self.answer_patterns {
225 if pattern
226 .triggers
227 .iter()
228 .any(|trigger| query.to_lowercase().contains(&trigger.to_lowercase()))
229 {
230 let generated_docs = (pattern.generator)(query, &self.config);
231 documents.extend(generated_docs);
232 }
233 }
234
235 if documents.is_empty() {
237 documents.extend(self.generate_generic_documents(query, query_type));
238 }
239
240 documents.sort();
242 documents.dedup();
243 documents.truncate(self.config.num_hypothetical_docs * 2); documents
246 }
247
248 fn apply_template(
250 &self,
251 query: &str,
252 template: &DocumentTemplate,
253 domain: &Option<String>,
254 ) -> String {
255 let mut result = template.pattern.clone();
256
257 let key_terms = self.extract_key_terms(query);
259 let main_subject = self.extract_main_subject(query);
260
261 result = result.replace("{query}", query);
263 result = result.replace("{subject}", &main_subject);
264 result = result.replace("{key_terms}", &key_terms.join(", "));
265
266 if let Some(domain_name) = domain {
267 result = result.replace("{domain}", domain_name);
268 }
269
270 self.clean_generated_text(&result)
272 }
273
274 fn generate_generic_documents(&self, query: &str, query_type: &str) -> Vec<String> {
276 let mut documents = Vec::new();
277 let main_subject = self.extract_main_subject(query);
278
279 match query_type {
280 "definitional" => {
281 documents.push(format!(
282 "{} is a concept that refers to the fundamental principles and mechanisms underlying this topic. \
283 It encompasses various aspects including its core definition, key characteristics, and primary applications. \
284 Understanding {} requires examining its historical development, theoretical foundations, and practical implications. \
285 The concept plays a crucial role in its respective field and has significant impact on related areas.",
286 main_subject, main_subject
287 ));
288 }
289 "procedural" => {
290 documents.push(format!(
291 "To accomplish {} successfully, there are several important steps to follow. \
292 First, it's essential to understand the underlying principles and requirements. \
293 The process typically involves careful planning, systematic execution, and continuous monitoring. \
294 Key considerations include proper preparation, attention to detail, and adherence to best practices. \
295 Following these guidelines will help ensure optimal results and avoid common pitfalls.",
296 main_subject
297 ));
298 }
299 "comparative" => {
300 documents.push(format!(
301 "When comparing different approaches to {}, several factors must be considered. \
302 Each option has distinct advantages and disadvantages that affect their suitability for various use cases. \
303 The comparison involves analyzing performance characteristics, resource requirements, and implementation complexity. \
304 Understanding these differences helps in making informed decisions based on specific needs and constraints.",
305 main_subject
306 ));
307 }
308 "factual" => {
309 documents.push(format!(
310 "Regarding {}, there are several important facts and key information points to consider. \
311 The available evidence and research data provide insights into various aspects of this topic. \
312 Historical context, current developments, and future trends all contribute to a comprehensive understanding. \
313 These facts form the foundation for deeper analysis and informed decision-making.",
314 main_subject
315 ));
316 }
317 _ => {
318 documents.push(format!(
319 "{} represents an important topic that deserves careful examination. \
320 The subject encompasses multiple dimensions including theoretical aspects, practical applications, and real-world implications. \
321 Understanding this topic requires considering various perspectives, analyzing available information, and drawing meaningful conclusions. \
322 This comprehensive approach ensures a thorough grasp of the subject matter.",
323 main_subject
324 ));
325 }
326 }
327
328 documents
329 }
330
331 fn detect_query_type(&self, query: &str) -> String {
333 let query_lower = query.to_lowercase();
334
335 if query_lower.starts_with("what is") || query_lower.starts_with("define") {
336 "definitional".to_string()
337 } else if query_lower.starts_with("how to") || query_lower.contains("step") {
338 "procedural".to_string()
339 } else if query_lower.contains("compare")
340 || query_lower.contains("vs")
341 || query_lower.contains("difference")
342 {
343 "comparative".to_string()
344 } else if query_lower.starts_with("when")
345 || query_lower.starts_with("where")
346 || query_lower.starts_with("who")
347 {
348 "factual".to_string()
349 } else if query_lower.starts_with("why") {
350 "causal".to_string()
351 } else if query_lower.starts_with("list") || query_lower.contains("examples") {
352 "enumerative".to_string()
353 } else {
354 "general".to_string()
355 }
356 }
357
358 fn detect_domain(&self, query: &str) -> Option<String> {
360 let query_lower = query.to_lowercase();
361
362 let domains = [
363 (
364 "technology",
365 vec![
366 "code",
367 "programming",
368 "software",
369 "api",
370 "database",
371 "algorithm",
372 "computer",
373 "tech",
374 ],
375 ),
376 (
377 "science",
378 vec![
379 "research",
380 "study",
381 "experiment",
382 "theory",
383 "analysis",
384 "scientific",
385 "hypothesis",
386 ],
387 ),
388 (
389 "business",
390 vec![
391 "market",
392 "sales",
393 "revenue",
394 "customer",
395 "profit",
396 "strategy",
397 "management",
398 "company",
399 ],
400 ),
401 (
402 "health",
403 vec![
404 "medical",
405 "health",
406 "disease",
407 "treatment",
408 "doctor",
409 "medicine",
410 "patient",
411 "healthcare",
412 ],
413 ),
414 (
415 "education",
416 vec![
417 "learn",
418 "study",
419 "school",
420 "university",
421 "course",
422 "education",
423 "teach",
424 "academic",
425 ],
426 ),
427 (
428 "finance",
429 vec![
430 "money",
431 "investment",
432 "financial",
433 "bank",
434 "trading",
435 "economics",
436 "cost",
437 "price",
438 ],
439 ),
440 ];
441
442 for (domain, keywords) in &domains {
443 let matches = keywords
444 .iter()
445 .filter(|&&keyword| query_lower.contains(keyword))
446 .count();
447
448 if matches >= 2 || (matches == 1 && query_lower.split_whitespace().count() <= 5) {
449 return Some(domain.to_string());
450 }
451 }
452
453 None
454 }
455
456 fn extract_key_terms(&self, query: &str) -> Vec<String> {
458 let stop_words = [
459 "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with",
460 "by", "is", "are", "was", "were", "be", "been", "have", "has", "had", "do", "does",
461 "did", "will", "would", "could", "should", "may", "might", "can", "what", "how", "why",
462 "when", "where", "who", "which",
463 ];
464
465 query
466 .split_whitespace()
467 .filter(|word| {
468 let clean_word = word
469 .trim_matches(|c: char| !c.is_alphanumeric())
470 .to_lowercase();
471 !stop_words.contains(&clean_word.as_str()) && clean_word.len() > 2
472 })
473 .map(|word| {
474 word.trim_matches(|c: char| !c.is_alphanumeric())
475 .to_string()
476 })
477 .collect()
478 }
479
480 fn extract_main_subject(&self, query: &str) -> String {
482 let key_terms = self.extract_key_terms(query);
483 if !key_terms.is_empty() {
484 key_terms[0].clone()
485 } else {
486 "the topic".to_string()
487 }
488 }
489
490 fn clean_generated_text(&self, text: &str) -> String {
492 text.trim()
493 .replace(" ", " ")
494 .replace("\n\n", "\n")
495 .lines()
496 .filter(|line| !line.trim().is_empty())
497 .collect::<Vec<_>>()
498 .join(" ")
499 }
500
501 fn calculate_confidence(&self, query: &str, document: &str, query_type: &str) -> f32 {
503 let mut confidence = 0.5; if document.len() >= self.config.min_document_length
507 && document.len() <= self.config.max_document_length
508 {
509 confidence += 0.1;
510 }
511
512 let query_terms = self.extract_key_terms(query);
514 let document_lower = document.to_lowercase();
515 let term_matches = query_terms
516 .iter()
517 .filter(|term| document_lower.contains(&term.to_lowercase()))
518 .count();
519
520 if !query_terms.is_empty() {
521 confidence += (term_matches as f32 / query_terms.len() as f32) * 0.3;
522 }
523
524 match query_type {
526 "definitional" if document.contains("is") || document.contains("refers to") => {
527 confidence += 0.1
528 }
529 "procedural" if document.contains("step") || document.contains("process") => {
530 confidence += 0.1
531 }
532 "comparative" if document.contains("compare") || document.contains("difference") => {
533 confidence += 0.1
534 }
535 _ => {}
536 }
537
538 confidence.min(1.0)
539 }
540
541 fn init_templates() -> HashMap<String, Vec<DocumentTemplate>> {
543 let mut templates = HashMap::new();
544
545 templates.insert("definitional".to_string(), vec![
547 DocumentTemplate {
548 name: "concept_definition".to_string(),
549 pattern: "{subject} is a fundamental concept in {domain} that encompasses several key aspects. It refers to the systematic approach and principles underlying this area of study. The definition includes both theoretical foundations and practical applications, making it essential for understanding related topics.".to_string(),
550 query_types: vec!["definitional".to_string()],
551 confidence: 0.8,
552 },
553 ]);
554
555 templates.insert("procedural".to_string(), vec![
557 DocumentTemplate {
558 name: "how_to_guide".to_string(),
559 pattern: "To effectively accomplish {subject}, follow these systematic steps and best practices. The process requires careful planning, proper execution, and continuous monitoring. Begin by understanding the requirements, then proceed with methodical implementation while considering potential challenges and solutions.".to_string(),
560 query_types: vec!["procedural".to_string()],
561 confidence: 0.8,
562 },
563 ]);
564
565 templates.insert("comparative".to_string(), vec![
567 DocumentTemplate {
568 name: "comparison_analysis".to_string(),
569 pattern: "When analyzing {subject}, several important factors distinguish different approaches and options. Each alternative offers unique advantages and limitations that affect performance, cost, and suitability for various use cases. The comparison reveals critical differences in functionality, efficiency, and implementation requirements.".to_string(),
570 query_types: vec!["comparative".to_string()],
571 confidence: 0.8,
572 },
573 ]);
574
575 templates
576 }
577
578 fn init_answer_patterns() -> Vec<AnswerPattern> {
580 vec![
581 AnswerPattern {
582 name: "technical_explanation".to_string(),
583 triggers: vec![
584 "algorithm".to_string(),
585 "system".to_string(),
586 "technology".to_string(),
587 ],
588 generator: |query, _config| {
589 vec![format!(
590 "The technical implementation of {} involves several sophisticated components working together. \
591 The system architecture incorporates advanced algorithms and optimized data structures to ensure \
592 efficient performance and scalability. Key technical considerations include resource management, \
593 error handling, and performance optimization strategies.",
594 query
595 )]
596 },
597 confidence: 0.7,
598 },
599 AnswerPattern {
600 name: "research_summary".to_string(),
601 triggers: vec![
602 "research".to_string(),
603 "study".to_string(),
604 "analysis".to_string(),
605 ],
606 generator: |query, _config| {
607 vec![format!(
608 "Recent research on {} has revealed significant insights and findings that advance our understanding \
609 of this field. Multiple studies have examined various aspects, employing rigorous methodologies \
610 and comprehensive data analysis. The research findings contribute valuable knowledge and inform \
611 evidence-based practices and future investigations.",
612 query
613 )]
614 },
615 confidence: 0.7,
616 },
617 ]
618 }
619}
620
621#[cfg(test)]
622mod tests {
623 use super::*;
624 use crate::embeddings::MockEmbeddingProvider;
625
626 #[tokio::test]
627 async fn test_hyde_generation() {
628 let provider = Arc::new(MockEmbeddingProvider::new());
629 let hyde = HyDEGenerator::new(HyDEConfig::default(), provider);
630
631 let results = hyde.generate("What is machine learning?").await.unwrap();
632
633 assert!(!results.is_empty());
634 assert!(results[0].confidence > 0.0);
635 assert!(results[0].hypothetical_answer.len() > 50);
636 assert_eq!(results[0].metadata.detected_query_type, "definitional");
637 }
638
639 #[tokio::test]
640 async fn test_procedural_query() {
641 let provider = Arc::new(MockEmbeddingProvider::new());
642 let hyde = HyDEGenerator::new(HyDEConfig::default(), provider);
643
644 let results = hyde.generate("How to implement a REST API?").await.unwrap();
645
646 assert!(!results.is_empty());
647 assert_eq!(results[0].metadata.detected_query_type, "procedural");
648 assert!(
649 results[0].hypothetical_answer.contains("step")
650 || results[0].hypothetical_answer.contains("process")
651 );
652 }
653
654 #[tokio::test]
655 async fn test_comparative_query() {
656 let provider = Arc::new(MockEmbeddingProvider::new());
657 let hyde = HyDEGenerator::new(HyDEConfig::default(), provider);
658
659 let results = hyde
660 .generate("Python vs Rust performance comparison")
661 .await
662 .unwrap();
663
664 assert!(!results.is_empty());
665 assert_eq!(results[0].metadata.detected_query_type, "comparative");
666 }
667
668 #[test]
669 fn test_query_type_detection() {
670 let provider = Arc::new(MockEmbeddingProvider::new());
671 let hyde = HyDEGenerator::new(HyDEConfig::default(), provider);
672
673 assert_eq!(hyde.detect_query_type("What is AI?"), "definitional");
674 assert_eq!(hyde.detect_query_type("How to code?"), "procedural");
675 assert_eq!(hyde.detect_query_type("Python vs Java"), "comparative");
676 assert_eq!(hyde.detect_query_type("When was it built?"), "factual");
677 }
678
679 #[test]
680 fn test_domain_detection() {
681 let provider = Arc::new(MockEmbeddingProvider::new());
682 let hyde = HyDEGenerator::new(HyDEConfig::default(), provider);
683
684 assert_eq!(
685 hyde.detect_domain("machine learning algorithm"),
686 Some("technology".to_string())
687 );
688 assert_eq!(
689 hyde.detect_domain("medical research study"),
690 Some("health".to_string())
691 );
692 assert_eq!(
693 hyde.detect_domain("market analysis strategy"),
694 Some("business".to_string())
695 );
696 assert_eq!(hyde.detect_domain("simple question"), None);
697 }
698}