//! HyDE (Hypothetical Document Embeddings) query expansion.
//!
//! An LLM drafts hypothetical answer passages for a query, each passage is
//! embedded, and the passage embeddings are blended with the query embedding
//! to produce a richer retrieval vector.

use crate::embeddings::EmbeddingService;
use crate::error::{LetheError, Result};
use crate::types::EmbeddingVector;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

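/// Configuration for HyDE query expansion.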
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HydeConfig {
    /// Number of hypothetical documents to generate per query.
    pub num_documents: usize,
    /// Sampling temperature for LLM generation.
    pub temperature: f32,
    /// Maximum tokens the LLM may generate per request.
    pub max_tokens: usize,
    /// Whether to blend the original query embedding into the combined embedding.
    pub combine_with_query: bool,
}

impl Default for HydeConfig {
    fn default() -> Self {
        Self {
            num_documents: 3,
            temperature: 0.7,
            max_tokens: 256,
            combine_with_query: true,
        }
    }
}

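/// A single LLM-generated passage, with its embedding and a confidence score.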
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HypotheticalDocument {
    pub id: String,
    pub text: String,
    pub embedding: Option<EmbeddingVector>,
    pub confidence: f32,
}

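/// The result of expanding a query: the generated passages, an optional
/// combined embedding, and an overall quality score.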
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HydeExpansion {
    pub original_query: String,
    pub hypothetical_documents: Vec<HypotheticalDocument>,
    pub combined_embedding: Option<EmbeddingVector>,
    pub expansion_quality: f32,
}

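/// Minimal text-generation interface the HyDE service depends on.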
#[async_trait]
pub trait LlmService: Send + Sync {
    async fn generate_text(&self, prompt: &str, config: &HydeConfig) -> Result<Vec<String>>;
}

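/// Orchestrates HyDE query expansion: prompts an LLM for hypothetical answer
/// passages, embeds them, and optionally blends them with the query embedding.
///
/// A minimal usage sketch, assuming `LlmService` and `EmbeddingService`
/// implementations named `llm` and `embedder` are in scope:
///
/// ```ignore
/// let hyde = HydeService::new(llm, embedder, HydeConfig::default());
/// let expansion = hyde.expand_query("What is machine learning?").await?;
/// for doc in hyde.get_best_documents(&expansion, 2) {
///     println!("{} ({:.2})", doc.text, doc.confidence);
/// }
/// ```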
pub struct HydeService {
    llm_service: Arc<dyn LlmService>,
    embedding_service: Arc<dyn EmbeddingService>,
    config: HydeConfig,
}

impl HydeService {
    /// Creates a new `HydeService` from its LLM and embedding dependencies.
    pub fn new(
        llm_service: Arc<dyn LlmService>,
        embedding_service: Arc<dyn EmbeddingService>,
        config: HydeConfig,
    ) -> Self {
        Self {
            llm_service,
            embedding_service,
            config,
        }
    }

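    /// Runs the full HyDE pipeline for `query`: generate hypothetical
    /// documents, embed and score them, and optionally build a combined
    /// embedding for retrieval.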
    pub async fn expand_query(&self, query: &str) -> Result<HydeExpansion> {
        // Ask the LLM for hypothetical answer passages.
        let hypothetical_texts = self.generate_hypothetical_documents(query).await?;

        // Embed each passage and score it against the original query.
        let mut hypothetical_documents = Vec::new();
        for (i, text) in hypothetical_texts.into_iter().enumerate() {
            let id = format!("hyde_{}", i);
            let embeddings = self.embedding_service.embed(&[text.clone()]).await?;
            let embedding = embeddings.into_iter().next().ok_or_else(|| {
                LetheError::validation("embedding", "Embedding service returned no vectors")
            })?;
            let confidence = self.calculate_confidence(&text, query);

            hypothetical_documents.push(HypotheticalDocument {
                id,
                text,
                embedding: Some(embedding),
                confidence,
            });
        }

        // Optionally blend the query embedding with the passage embeddings.
        let combined_embedding = if self.config.combine_with_query {
            Some(self.create_combined_embedding(query, &hypothetical_documents).await?)
        } else {
            None
        };

        let expansion_quality = self.calculate_expansion_quality(&hypothetical_documents);

        Ok(HydeExpansion {
            original_query: query.to_string(),
            hypothetical_documents,
            combined_embedding,
            expansion_quality,
        })
    }

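    /// Prompts the configured LLM for hypothetical answer passages.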
    async fn generate_hypothetical_documents(&self, query: &str) -> Result<Vec<String>> {
        let prompt = self.build_hyde_prompt(query);
        self.llm_service.generate_text(&prompt, &self.config).await
    }

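    /// Builds the generation prompt, asking for `num_documents` passages.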
    fn build_hyde_prompt(&self, query: &str) -> String {
        format!(
            r#"Given the following query, write {num_docs} high-quality, detailed document passages that would contain the answer to this query. Each passage should be informative, well-structured, and directly relevant to the query.

Query: {query}

Generate {num_docs} hypothetical document passages:

1."#,
            num_docs = self.config.num_documents,
            query = query,
        )
    }

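    /// Heuristic relevance score in [0, 1]: 60% query-word overlap plus
    /// 40% length score (saturating at 500 characters).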
    fn calculate_confidence(&self, document: &str, query: &str) -> f32 {
        // Case-insensitive word overlap between the query and the document.
        let query_lower = query.to_lowercase();
        let query_words: std::collections::HashSet<&str> = query_lower
            .split_whitespace()
            .collect();

        let doc_lower = document.to_lowercase();
        let doc_words: std::collections::HashSet<&str> = doc_lower
            .split_whitespace()
            .collect();

        let overlap = query_words.intersection(&doc_words).count();
        let total_query_words = query_words.len();

        if total_query_words == 0 {
            return 0.0;
        }

        let overlap_score = overlap as f32 / total_query_words as f32;

        // Favor longer passages, saturating at 500 characters.
        let length_score = (document.len() as f32 / 500.0).min(1.0);

        (overlap_score * 0.6 + length_score * 0.4).min(1.0)
    }

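    /// Builds a single retrieval vector: the query embedding (weight 1.0)
    /// averaged with each passage embedding weighted by its confidence.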
    async fn create_combined_embedding(
        &self,
        query: &str,
        hypothetical_documents: &[HypotheticalDocument],
    ) -> Result<EmbeddingVector> {
        let query_embeddings = self.embedding_service.embed(&[query.to_string()]).await?;
        let query_embedding = query_embeddings.into_iter().next().ok_or_else(|| {
            LetheError::validation("embedding", "Embedding service returned no vectors")
        })?;

        // The query gets full weight; each passage is weighted by its confidence.
        let mut weighted_embeddings = Vec::new();
        weighted_embeddings.push((query_embedding, 1.0));

        for doc in hypothetical_documents {
            if let Some(ref embedding) = doc.embedding {
                weighted_embeddings.push((embedding.clone(), doc.confidence));
            }
        }

        self.calculate_weighted_average(&weighted_embeddings)
    }

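    /// Computes the weighted element-wise average of a set of embeddings:
    /// `result[i] = Σ_j (weight_j · e_j[i]) / Σ_j weight_j`.
    ///
    /// Fails if the slice is empty or the vectors disagree on dimension.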
    fn calculate_weighted_average(&self, embeddings: &[(EmbeddingVector, f32)]) -> Result<EmbeddingVector> {
        if embeddings.is_empty() {
            return Err(LetheError::validation("embeddings", "No embeddings to average"));
        }

        let dimension = embeddings[0].0.data.len();
        let mut result = vec![0.0; dimension];
        let mut total_weight = 0.0;

        for (embedding, weight) in embeddings {
            if embedding.data.len() != dimension {
                return Err(LetheError::validation("dimension", "Embedding dimension mismatch"));
            }

            for (i, &value) in embedding.data.iter().enumerate() {
                result[i] += value * weight;
            }
            total_weight += weight;
        }

        // Normalize by the total weight (a no-op if all weights were zero).
        if total_weight > 0.0 {
            for value in &mut result {
                *value /= total_weight;
            }
        }

        Ok(EmbeddingVector {
            data: result,
            dimension,
        })
    }

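    /// Scores an expansion as 0.8 · mean confidence + 0.2 · length diversity,
    /// where diversity is the length variance normalized by mean length and
    /// clamped to 1.0.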
    fn calculate_expansion_quality(&self, hypothetical_documents: &[HypotheticalDocument]) -> f32 {
        if hypothetical_documents.is_empty() {
            return 0.0;
        }

        let avg_confidence: f32 = hypothetical_documents
            .iter()
            .map(|doc| doc.confidence)
            .sum::<f32>() / hypothetical_documents.len() as f32;

        // Length variance across documents is used as a cheap diversity proxy.
        let lengths: Vec<f32> = hypothetical_documents
            .iter()
            .map(|doc| doc.text.len() as f32)
            .collect();

        let avg_length = lengths.iter().sum::<f32>() / lengths.len() as f32;
        let variance = lengths
            .iter()
            .map(|&len| (len - avg_length).powi(2))
            .sum::<f32>() / lengths.len() as f32;

        // Guard against division by zero when every document is empty.
        let diversity_score = if avg_length > 0.0 {
            (variance / avg_length).min(1.0)
        } else {
            0.0
        };

        avg_confidence * 0.8 + diversity_score * 0.2
    }

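    /// Returns up to `limit` hypothetical documents, highest confidence first.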
    pub fn get_best_documents<'a>(&self, expansion: &'a HydeExpansion, limit: usize) -> Vec<&'a HypotheticalDocument> {
        let mut documents = expansion.hypothetical_documents.iter().collect::<Vec<_>>();
        documents.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal));
        documents.into_iter().take(limit).collect()
    }
}

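/// Deterministic LLM stub for tests: returns canned responses, with optional
/// exact-match overrides registered via `add_response`.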
#[cfg(test)]
pub struct MockLlmService {
    responses: std::collections::HashMap<String, Vec<String>>,
}

#[cfg(test)]
impl MockLlmService {
    pub fn new() -> Self {
        Self {
            responses: std::collections::HashMap::new(),
        }
    }

    pub fn add_response(&mut self, prompt: String, responses: Vec<String>) {
        self.responses.insert(prompt, responses);
    }
}

#[cfg(test)]
#[async_trait]
impl LlmService for MockLlmService {
    async fn generate_text(&self, prompt: &str, _config: &HydeConfig) -> Result<Vec<String>> {
        // Exact-match canned responses take priority over the keyword fallback.
        if let Some(responses) = self.responses.get(prompt) {
            return Ok(responses.clone());
        }

        if prompt.contains("machine learning") {
            Ok(vec![
                "Machine learning is a subset of artificial intelligence that enables computers to learn and make decisions from data without explicit programming.".to_string(),
                "Modern machine learning algorithms include deep learning neural networks, random forests, and support vector machines.".to_string(),
                "Applications of machine learning span computer vision, natural language processing, and predictive analytics.".to_string(),
            ])
        } else {
            Ok(vec![
                "This is a hypothetical document about the query topic.".to_string(),
                "Another relevant document with detailed information.".to_string(),
                "A third document providing additional context.".to_string(),
            ])
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FallbackEmbeddingService;

    #[tokio::test]
    async fn test_hyde_expansion() {
        let llm_service = Arc::new(MockLlmService::new());
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let config = HydeConfig::default();

        let hyde_service = HydeService::new(llm_service, embedding_service, config);

        let expansion = hyde_service.expand_query("What is machine learning?").await.unwrap();

        assert_eq!(expansion.original_query, "What is machine learning?");
        assert_eq!(expansion.hypothetical_documents.len(), 3);
        assert!(expansion.expansion_quality > 0.0);

        for doc in &expansion.hypothetical_documents {
            assert!(!doc.text.is_empty());
            assert!(doc.confidence >= 0.0 && doc.confidence <= 1.0);
            assert!(doc.embedding.is_some());
        }
    }

    #[test]
    fn test_confidence_calculation() {
        let llm_service = Arc::new(MockLlmService::new());
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let config = HydeConfig::default();

        let hyde_service = HydeService::new(llm_service, embedding_service, config);

        let query = "machine learning algorithms";
        let document = "Machine learning algorithms are used to build predictive models and analyze data patterns.";

        let confidence = hyde_service.calculate_confidence(document, query);
        assert!(confidence > 0.0 && confidence <= 1.0);
    }

    #[test]
    fn test_weighted_average_embeddings() {
        let llm_service = Arc::new(MockLlmService::new());
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let config = HydeConfig::default();

        let hyde_service = HydeService::new(llm_service, embedding_service, config);

        let embeddings = vec![
            (EmbeddingVector { data: vec![1.0, 0.0, 0.0], dimension: 3 }, 1.0),
            (EmbeddingVector { data: vec![0.0, 1.0, 0.0], dimension: 3 }, 1.0),
        ];

        let result = hyde_service.calculate_weighted_average(&embeddings).unwrap();
        assert_eq!(result.data, vec![0.5, 0.5, 0.0]);
        assert_eq!(result.dimension, 3);
    }

    #[test]
    fn test_best_documents_selection() {
        let expansion = HydeExpansion {
            original_query: "test".to_string(),
            hypothetical_documents: vec![
                HypotheticalDocument {
                    id: "1".to_string(),
                    text: "doc1".to_string(),
                    embedding: None,
                    confidence: 0.9,
                },
                HypotheticalDocument {
                    id: "2".to_string(),
                    text: "doc2".to_string(),
                    embedding: None,
                    confidence: 0.7,
                },
                HypotheticalDocument {
                    id: "3".to_string(),
                    text: "doc3".to_string(),
                    embedding: None,
                    confidence: 0.8,
                },
            ],
            combined_embedding: None,
            expansion_quality: 0.8,
        };

        let llm_service = Arc::new(MockLlmService::new());
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let config = HydeConfig::default();

        let hyde_service = HydeService::new(llm_service, embedding_service, config);

        let best = hyde_service.get_best_documents(&expansion, 2);
        assert_eq!(best.len(), 2);
        assert_eq!(best[0].confidence, 0.9);
        assert_eq!(best[1].confidence, 0.8);
    }
}