1use crate::types::{Candidate, Chunk, DfIdf, EmbeddingVector};
2use crate::error::{Result, LetheError};
3use crate::utils::{TextProcessor, QueryFeatures};
4use async_trait::async_trait;
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7use crate::embeddings::{EmbeddingService, FallbackEmbeddingService};
8use sha2::{Sha256, Digest};
9use serde::Serialize;
10
/// Tunable parameters for the hybrid (BM25 + vector) retrieval pipeline.
///
/// The struct derives `Serialize` because `compute_hash` hashes its JSON
/// form; the field declaration order is therefore part of that hash.
#[derive(Debug, Clone, Serialize)]
pub struct HybridRetrievalConfig {
    /// Weight of the lexical (BM25) z-score in the fused score.
    pub alpha: f64,
    /// Weight of the vector-similarity z-score in the fused score.
    pub beta: f64,
    /// Static boost per chunk kind; `hybrid_score` adds the looked-up value
    /// to its multiplicative boost factor.
    pub gamma_kind_boost: HashMap<String, f64>,
    /// Enables the rerank step of post-processing (which currently only logs
    /// that reranking is not implemented).
    pub rerank: bool,
    /// Enables the diversification step of post-processing (which currently
    /// only logs that diversification is not implemented).
    pub diversify: bool,
    /// Name of the diversification method; not interpreted by any code
    /// visible in this file.
    pub diversify_method: String,
    /// Number of candidates requested from each of the BM25 and vector searches.
    pub k_initial: i32,
    /// Maximum number of candidates kept after post-processing.
    pub k_final: i32,
    /// Not read by any code visible in this file.
    pub fusion_dynamic: bool,
}
24
25impl Default for HybridRetrievalConfig {
26 fn default() -> Self {
27 let mut gamma_kind_boost = HashMap::new();
28 gamma_kind_boost.insert("code".to_string(), 1.2);
29 gamma_kind_boost.insert("import".to_string(), 1.1);
30 gamma_kind_boost.insert("function".to_string(), 1.15);
31 gamma_kind_boost.insert("error".to_string(), 1.3);
32
33 Self {
34 alpha: 0.5, beta: 0.5, gamma_kind_boost,
37 rerank: true,
38 diversify: true,
39 diversify_method: "entity".to_string(),
40 k_initial: 200, k_final: 5, fusion_dynamic: false,
43 }
44 }
45}
46
47impl HybridRetrievalConfig {
50 pub fn hero() -> Self {
51 let gamma_kind_boost = HashMap::new(); Self {
54 alpha: 0.5, beta: 0.5, gamma_kind_boost,
57 rerank: true,
58 diversify: true,
59 diversify_method: "splade".to_string(), k_initial: 200, k_final: 5, fusion_dynamic: false,
63 }
64 }
65
66 pub fn compute_hash(&self) -> String {
68 let json = serde_json::to_string(self).expect("Failed to serialize config");
69 let mut hasher = Sha256::new();
70 hasher.update(json.as_bytes());
71 hex::encode(hasher.finalize())
72 }
73
74 pub fn validate_hero_config_hash(&self, expected_hash: &str, allow_override: bool) -> Result<()> {
76 let actual_hash = self.compute_hash();
77
78 if actual_hash != expected_hash {
79 let error_msg = format!(
80 "Hero configuration hash mismatch! Expected: {}, Actual: {}. \
81 This indicates the configuration has been tampered with or is not the canonical hero config.",
82 expected_hash, actual_hash
83 );
84
85 if allow_override {
86 tracing::warn!("{} Override flag is set - continuing with non-canonical config.", error_msg);
87 Ok(())
88 } else {
89 Err(LetheError::config(error_msg))
90 }
91 } else {
92 tracing::info!("Hero configuration hash validated successfully: {}", actual_hash);
93 Ok(())
94 }
95 }
96
97 pub fn hero_with_validation(expected_hash: &str, allow_override: bool) -> Result<Self> {
99 let config = Self::hero();
100 config.validate_hero_config_hash(expected_hash, allow_override)?;
101 Ok(config)
102 }
103}
104
/// Storage abstraction queried by the retrieval services in this file.
#[async_trait]
pub trait DocumentRepository: Send + Sync {
    /// Returns the chunks for `session_id`; the BM25 search scores exactly
    /// this set.
    async fn get_chunks_by_session(&self, session_id: &str) -> Result<Vec<Chunk>>;

    /// Returns the per-term document-frequency / IDF entries for
    /// `session_id`; the BM25 search reads each entry's `term` and `idf`.
    async fn get_dfidf_by_session(&self, session_id: &str) -> Result<Vec<DfIdf>>;

    /// Looks up a single chunk by id, yielding `None` when there is no match.
    async fn get_chunk_by_id(&self, chunk_id: &str) -> Result<Option<Chunk>>;

    /// Returns candidates for `query_vector`.
    // NOTE(review): `k` is presumably the maximum number of candidates and
    // the scores presumably cosine similarities (the caller rescales them
    // from [-1, 1] to [0, 1]) — confirm against the implementations.
    async fn vector_search(&self, query_vector: &EmbeddingVector, k: i32) -> Result<Vec<Candidate>>;
}
120
/// Stateless namespace for the BM25 lexical search; all of its functions are
/// associated functions that take no `self`.
pub struct Bm25SearchService;
123
124impl Bm25SearchService {
125 pub async fn search<R: DocumentRepository + ?Sized>(
127 repository: &R,
128 queries: &[String],
129 session_id: &str,
130 k: i32,
131 ) -> Result<Vec<Candidate>> {
132 let chunks = repository.get_chunks_by_session(session_id).await?;
133 if chunks.is_empty() {
134 return Ok(vec![]);
135 }
136
137 let dfidf_data = repository.get_dfidf_by_session(session_id).await?;
138 let term_idf_map: HashMap<String, f64> = dfidf_data
139 .into_iter()
140 .map(|entry| (entry.term, entry.idf))
141 .collect();
142
143 let total_length: i32 = chunks
145 .iter()
146 .map(|chunk| Self::tokenize(&chunk.text).len() as i32)
147 .sum();
148 let avg_doc_length = if chunks.is_empty() {
149 0.0
150 } else {
151 total_length as f64 / chunks.len() as f64
152 };
153
154 let all_query_terms: HashSet<String> = queries
156 .iter()
157 .flat_map(|query| Self::tokenize(query))
158 .collect();
159
160 let mut candidates = Vec::new();
162
163 for chunk in chunks {
164 let doc_terms = Self::tokenize(&chunk.text);
165 let doc_length = doc_terms.len() as f64;
166
167 let mut term_freqs = HashMap::new();
169 for term in &doc_terms {
170 if all_query_terms.contains(term) {
171 *term_freqs.entry(term.clone()).or_insert(0) += 1;
172 }
173 }
174
175 if term_freqs.is_empty() {
177 continue;
178 }
179
180 let score = Self::calculate_bm25(&term_freqs, doc_length, avg_doc_length, &term_idf_map, 1.2, 0.75);
181 if score > 0.0 {
182 candidates.push(Candidate {
183 doc_id: chunk.id,
184 score,
185 text: Some(chunk.text),
186 kind: Some(chunk.kind),
187 });
188 }
189 }
190
191 candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
193 candidates.truncate(k as usize);
194
195 Ok(candidates)
196 }
197
/// Splits `text` into terms by delegating to `TextProcessor::tokenize`, so
/// queries and documents are tokenized the same way.
fn tokenize(text: &str) -> Vec<String> {
    TextProcessor::tokenize(text)
}
202
/// Computes the BM25 score of one document for the matched query terms.
///
/// * `term_freqs` – number of occurrences in the document of each query term.
/// * `doc_length` – length of the document in tokens.
/// * `avg_doc_length` – average document length over the corpus.
/// * `term_idf_map` – inverse document frequency per term. Terms that are
///   missing, or whose IDF is not strictly positive (including NaN),
///   contribute nothing.
/// * `k1` – term-frequency saturation parameter.
/// * `b` – strength of the document-length normalisation.
///
/// When `avg_doc_length` is zero, negative or non-finite, the length ratio
/// is taken as `1.0` so that the score stays finite instead of becoming NaN
/// or infinite.
fn calculate_bm25(
    term_freqs: &HashMap<String, i32>,
    doc_length: f64,
    avg_doc_length: f64,
    term_idf_map: &HashMap<String, f64>,
    k1: f64,
    b: f64,
) -> f64 {
    let length_ratio = if avg_doc_length.is_finite() && avg_doc_length > 0.0 {
        doc_length / avg_doc_length
    } else {
        1.0
    };
    // Identical for every term of the document, so computed once.
    let length_norm = k1 * (1.0 - b + b * length_ratio);

    let mut score = 0.0;

    for (term, &tf) in term_freqs {
        if tf <= 0 {
            continue;
        }

        let idf = term_idf_map.get(term).copied().unwrap_or(0.0);
        // Written as a negated `>` so that a NaN idf is rejected as well.
        if !(idf > 0.0) {
            continue;
        }

        let tf = f64::from(tf);
        let numerator = tf * (k1 + 1.0);
        let denominator = tf + length_norm;

        score += idf * (numerator / denominator);
    }

    score
}
228
/// Scores a document with `calculate_bm25` using `k1 = 1.2` and `b = 0.75`.
// `dead_code` is allowed because nothing visible in this file calls this
// helper; `search` invokes `calculate_bm25` directly with the same values.
#[allow(dead_code)]
fn calculate_bm25_default(
    term_freqs: &HashMap<String, i32>,
    doc_length: f64,
    avg_doc_length: f64,
    term_idf_map: &HashMap<String, f64>,
) -> f64 {
    Self::calculate_bm25(term_freqs, doc_length, avg_doc_length, term_idf_map, 1.2, 0.75)
}
239}
240
/// Vector-similarity search that embeds the query text and hands the
/// embedding to a `DocumentRepository`.
pub struct VectorSearchService {
    /// Service used to turn the query text into an embedding.
    embedding_service: Arc<dyn EmbeddingService>,
}
245
246impl VectorSearchService {
247 pub fn new(embedding_service: Arc<dyn EmbeddingService>) -> Self {
248 Self { embedding_service }
249 }
250
251 pub async fn search<R: DocumentRepository + ?Sized>(
253 &self,
254 repository: &R,
255 query: &str,
256 k: i32,
257 ) -> Result<Vec<Candidate>> {
258 let query_embedding = self.embedding_service.embed_single(query).await?;
259 repository.vector_search(&query_embedding, k).await
260 }
261}
262
/// Retrieval service that fuses BM25 and vector-search results according to
/// a `HybridRetrievalConfig`.
pub struct HybridRetrievalService {
    /// Vector-search side of the pipeline.
    vector_service: VectorSearchService,
    /// Fusion weights, boosts and result limits.
    config: HybridRetrievalConfig,
}
268
269impl HybridRetrievalService {
270 pub fn new(embedding_service: Arc<dyn EmbeddingService>, config: HybridRetrievalConfig) -> Self {
271 Self {
272 vector_service: VectorSearchService::new(embedding_service),
273 config,
274 }
275 }
276
277 pub async fn retrieve<R: DocumentRepository + ?Sized>(
279 &self,
280 repository: &R,
281 queries: &[String],
282 session_id: &str,
283 ) -> Result<Vec<Candidate>> {
284 let combined_query = queries.join(" ");
285
286 tracing::info!("Starting hybrid retrieval for {} queries", queries.len());
287
288 let (lexical_results, vector_results) = tokio::try_join!(
290 Bm25SearchService::search(repository, queries, session_id, self.config.k_initial),
291 self.vector_service.search(repository, &combined_query, self.config.k_initial)
292 )?;
293
294 tracing::debug!(
295 "BM25 found {} candidates, Vector search found {} candidates",
296 lexical_results.len(),
297 vector_results.len()
298 );
299
300 let candidates = self.hybrid_score(lexical_results, vector_results, &combined_query)?;
302
303 tracing::info!("Hybrid scoring produced {} candidates", candidates.len());
304
305 let final_candidates = self.post_process(candidates).await?;
307
308 tracing::info!("Final result: {} candidates", final_candidates.len());
309 Ok(final_candidates)
310 }
311
312 fn hybrid_score(
314 &self,
315 lexical_results: Vec<Candidate>,
316 vector_results: Vec<Candidate>,
317 query: &str,
318 ) -> Result<Vec<Candidate>> {
319 let lexical_normalized = self.normalize_bm25_scores(lexical_results);
321 let vector_normalized = self.normalize_cosine_scores(vector_results);
322
323 let lexical_zscores = self.calculate_zscores(&lexical_normalized);
325 let vector_zscores = self.calculate_zscores(&vector_normalized);
326
327 let lexical_map: HashMap<String, f64> = lexical_zscores
329 .into_iter()
330 .map(|c| (c.doc_id, c.score))
331 .collect();
332
333 let vector_map: HashMap<String, f64> = vector_zscores
334 .into_iter()
335 .map(|c| (c.doc_id, c.score))
336 .collect();
337
338 let all_doc_ids: HashSet<String> = lexical_map
340 .keys()
341 .chain(vector_map.keys())
342 .cloned()
343 .collect();
344
345 let query_features = QueryFeatures::extract_features(query);
347
348 let mut candidates = Vec::new();
349
350 for doc_id in all_doc_ids {
351 let lex_zscore = lexical_map.get(&doc_id).copied().unwrap_or(0.0);
352 let vec_zscore = vector_map.get(&doc_id).copied().unwrap_or(0.0);
353
354 let mut hybrid_score = self.config.alpha * lex_zscore + self.config.beta * vec_zscore;
356
357 let kind = "text"; let dynamic_boost = QueryFeatures::gamma_boost(kind, &query_features);
361 let static_boost = self.config.gamma_kind_boost.get(kind).copied().unwrap_or(0.0);
362 let total_boost = 1.0 + dynamic_boost + static_boost;
363 hybrid_score *= total_boost;
364
365 candidates.push(Candidate {
366 doc_id,
367 score: hybrid_score,
368 text: None, kind: Some(kind.to_string()),
370 });
371 }
372
373 candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
375
376 Ok(candidates)
377 }
378
379 fn normalize_bm25_scores(&self, candidates: Vec<Candidate>) -> Vec<Candidate> {
381 if candidates.is_empty() {
382 return candidates;
383 }
384
385 let max_score = candidates
386 .iter()
387 .map(|c| c.score)
388 .fold(0.0, f64::max);
389
390 if max_score == 0.0 {
391 return candidates;
392 }
393
394 candidates
395 .into_iter()
396 .map(|mut c| {
397 c.score /= max_score;
398 c
399 })
400 .collect()
401 }
402
403 fn normalize_cosine_scores(&self, candidates: Vec<Candidate>) -> Vec<Candidate> {
405 candidates
406 .into_iter()
407 .map(|mut c| {
408 c.score = (c.score + 1.0) / 2.0;
409 c
410 })
411 .collect()
412 }
413
414 pub fn calculate_zscores(&self, candidates: &[Candidate]) -> Vec<Candidate> {
416 if candidates.is_empty() {
417 return candidates.to_vec();
418 }
419
420 let scores: Vec<f64> = candidates.iter().map(|c| c.score).collect();
422 let mean = scores.iter().sum::<f64>() / scores.len() as f64;
423
424 let variance = scores.iter()
425 .map(|&score| (score - mean).powi(2))
426 .sum::<f64>() / scores.len() as f64;
427
428 let std_dev = variance.sqrt();
429
430 if std_dev == 0.0 {
432 return candidates.to_vec();
433 }
434
435 candidates.iter()
437 .map(|candidate| {
438 let zscore = (candidate.score - mean) / std_dev;
439 Candidate {
440 doc_id: candidate.doc_id.clone(),
441 score: zscore,
442 text: candidate.text.clone(),
443 kind: candidate.kind.clone(),
444 }
445 })
446 .collect()
447 }
448
449 async fn post_process(&self, mut candidates: Vec<Candidate>) -> Result<Vec<Candidate>> {
451 if self.config.rerank {
453 tracing::debug!("Reranking not implemented in basic version");
454 }
455
456 if self.config.diversify && candidates.len() > self.config.k_final as usize {
458 tracing::debug!("Diversification not implemented in basic version");
459 }
460
461 candidates.truncate(self.config.k_final as usize);
463
464 Ok(candidates)
465 }
466
467 pub fn mock_for_testing() -> Self {
469 let embedding_service = Arc::new(FallbackEmbeddingService::new(384)); Self::new(embedding_service, HybridRetrievalConfig::hero())
471 }
472}
473
474#[cfg(test)]
475mod tests {
476 use super::*;
477 use crate::embeddings::FallbackEmbeddingService;
478 use lethe_shared::Chunk;
479 use uuid::Uuid;
480 use std::sync::Arc;
481
/// In-memory `DocumentRepository` used by the tests in this module.
struct MockRepository {
    /// Chunks returned for every session.
    chunks: Vec<Chunk>,
    /// DF/IDF entries returned for every session.
    dfidf: Vec<DfIdf>,
}
487
488 #[async_trait]
489 impl DocumentRepository for MockRepository {
490 async fn get_chunks_by_session(&self, _session_id: &str) -> Result<Vec<Chunk>> {
491 Ok(self.chunks.clone())
492 }
493
494 async fn get_dfidf_by_session(&self, _session_id: &str) -> Result<Vec<DfIdf>> {
495 Ok(self.dfidf.clone())
496 }
497
498 async fn get_chunk_by_id(&self, chunk_id: &str) -> Result<Option<Chunk>> {
499 Ok(self.chunks.iter().find(|c| c.id == chunk_id).cloned())
500 }
501
502 async fn vector_search(&self, _query_vector: &EmbeddingVector, k: i32) -> Result<Vec<Candidate>> {
503 let candidates: Vec<Candidate> = self.chunks
505 .iter()
506 .take(k as usize)
507 .map(|chunk| Candidate {
508 doc_id: chunk.id.clone(),
509 score: 0.8, text: Some(chunk.text.clone()),
511 kind: Some(chunk.kind.clone()),
512 })
513 .collect();
514 Ok(candidates)
515 }
516 }
517
518 fn create_test_chunk(id: &str, text: &str, kind: &str) -> Chunk {
519 Chunk {
520 id: id.to_string(),
521 message_id: Uuid::new_v4(),
522 session_id: "test-session".to_string(),
523 offset_start: 0,
524 offset_end: text.len(),
525 kind: kind.to_string(),
526 text: text.to_string(),
527 tokens: text.split_whitespace().count() as i32,
528 }
529 }
530
531 #[tokio::test]
532 async fn test_bm25_search() {
533 let chunks = vec![
534 create_test_chunk("1", "hello world", "text"),
535 create_test_chunk("2", "world peace", "text"),
536 create_test_chunk("3", "goodbye world", "text"),
537 ];
538
539 let dfidf = vec![
540 DfIdf {
541 term: "hello".to_string(),
542 session_id: "test-session".to_string(),
543 df: 1,
544 idf: 1.0,
545 },
546 DfIdf {
547 term: "world".to_string(),
548 session_id: "test-session".to_string(),
549 df: 3,
550 idf: 0.5,
551 },
552 ];
553
554 let repository = MockRepository { chunks, dfidf };
555 let queries = vec!["hello world".to_string()];
556
557 let results = Bm25SearchService::search(&repository, &queries, "test-session", 10)
558 .await
559 .unwrap();
560
561 assert!(!results.is_empty());
562 assert_eq!(results[0].doc_id, "1"); }
564
565 #[tokio::test]
566 async fn test_hybrid_retrieval() {
567 let chunks = vec![
568 create_test_chunk("1", "async programming in rust", "text"),
569 create_test_chunk("2", "rust error handling", "text"),
570 create_test_chunk("3", "javascript async await", "text"),
571 ];
572
573 let dfidf = vec![
574 DfIdf {
575 term: "async".to_string(),
576 session_id: "test-session".to_string(),
577 df: 2,
578 idf: 0.4,
579 },
580 DfIdf {
581 term: "rust".to_string(),
582 session_id: "test-session".to_string(),
583 df: 2,
584 idf: 0.4,
585 },
586 ];
587
588 let repository = MockRepository { chunks, dfidf };
589 let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
590 let config = HybridRetrievalConfig::default();
591 let service = HybridRetrievalService::new(embedding_service, config);
592
593 let queries = vec!["rust async programming".to_string()];
594 let results = service
595 .retrieve(&repository, &queries, "test-session")
596 .await
597 .unwrap();
598
599 assert!(!results.is_empty());
600 assert!(results.len() <= 5); }
602
603 #[tokio::test]
604 async fn test_hero_configuration() {
605 let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
606 let hero_config = HybridRetrievalConfig::hero();
607 let service = HybridRetrievalService::new(embedding_service, hero_config);
608
609 assert_eq!(service.config.alpha, 0.5); assert_eq!(service.config.beta, 0.5); assert_eq!(service.config.k_initial, 200); assert_eq!(service.config.k_final, 5); assert_eq!(service.config.diversify_method, "splade");
615 }
616
/// BM25 normalisation divides every score by the maximum score.
#[test]
fn test_score_normalization() {
    let service = HybridRetrievalService::new(
        Arc::new(FallbackEmbeddingService::new(384)),
        HybridRetrievalConfig::default(),
    );

    let scored = |id: &str, score: f64| Candidate {
        doc_id: id.to_string(),
        score,
        text: None,
        kind: None,
    };
    let raw = vec![scored("1", 10.0), scored("2", 5.0)];

    let normalized = service.normalize_bm25_scores(raw);

    assert_eq!(normalized[0].score, 1.0);
    assert_eq!(normalized[1].score, 0.5);
}
642
/// A query with a call, an error token and a path sets all three feature
/// flags and yields a positive boost for `code` chunks.
#[test]
fn test_query_features() {
    let query = "function_name() error in /path/file.rs";
    let extracted = QueryFeatures::extract_features(query);

    assert!(extracted.has_code_symbol);
    assert!(extracted.has_error_token);
    assert!(extracted.has_path_or_file);

    let code_boost = QueryFeatures::gamma_boost("code", &extracted);
    assert!(code_boost > 0.0);
}
653
/// Exercises each feature detector on a query designed to trigger it, plus
/// the empty query, which must trigger none.
#[test]
fn test_query_features_comprehensive() {
    // Function call syntax.
    let call = QueryFeatures::extract_features("call myFunction() here");
    assert!(call.has_code_symbol);
    assert!(!call.has_error_token);

    // Namespaced path syntax.
    let namespaced = QueryFeatures::extract_features("use MyClass::StaticMethod");
    assert!(namespaced.has_code_symbol);

    // Exception name.
    let exception = QueryFeatures::extract_features("NullPointerException occurred");
    assert!(exception.has_error_token);
    assert!(!exception.has_code_symbol);

    // Unix-style path.
    let unix_path = QueryFeatures::extract_features("check /home/user/file.txt");
    assert!(unix_path.has_path_or_file);
    assert!(!unix_path.has_error_token);

    // Windows-style path.
    let windows_path = QueryFeatures::extract_features("see C:\\Users\\Name\\doc.docx");
    assert!(windows_path.has_path_or_file);

    // Numeric identifier.
    let numeric = QueryFeatures::extract_features("issue 1234 needs fixing");
    assert!(numeric.has_numeric_id);
    assert!(!numeric.has_code_symbol);

    // Empty query.
    let empty = QueryFeatures::extract_features("");
    assert!(!empty.has_code_symbol);
    assert!(!empty.has_error_token);
    assert!(!empty.has_path_or_file);
    assert!(!empty.has_numeric_id);
}
691
/// Checks the dynamic boost for several chunk-kind / query-feature pairs.
#[test]
fn test_gamma_boost_combinations() {
    // Code symbol in the query: code-like kinds are boosted, text is not.
    let symbol_features = QueryFeatures::extract_features("myFunction() returns value");

    assert!(QueryFeatures::gamma_boost("code", &symbol_features) > 0.0);
    assert!(QueryFeatures::gamma_boost("user_code", &symbol_features) > 0.0);
    assert_eq!(QueryFeatures::gamma_boost("text", &symbol_features), 0.0);

    // Error token in the query boosts tool results.
    let error_features = QueryFeatures::extract_features("RuntimeError in execution");
    assert!(QueryFeatures::gamma_boost("tool_result", &error_features) > 0.0);

    // Path in the query boosts code.
    let path_features = QueryFeatures::extract_features("file located at /src/main.rs");
    assert!(QueryFeatures::gamma_boost("code", &path_features) > 0.0);

    // All features at once.
    let combined_features =
        QueryFeatures::extract_features("function() error in /path/file.rs with ID 1234");
    assert!(combined_features.has_code_symbol);
    assert!(combined_features.has_error_token);
    assert!(combined_features.has_path_or_file);
    assert!(combined_features.has_numeric_id);

    let combined_boost = QueryFeatures::gamma_boost("code", &combined_features);
    assert!(combined_boost > 0.1);
}
726
/// A service built with the default configuration exposes the default
/// fusion weights and a static boost entry for `code`.
// Plain `#[test]`: the body awaits nothing, so no async runtime is needed.
#[test]
fn test_hybrid_retrieval_creation() {
    let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
    let service = HybridRetrievalService::new(embedding_service, HybridRetrievalConfig::default());

    assert_eq!(service.config.alpha, 0.5);
    assert_eq!(service.config.beta, 0.5);
    assert!(service.config.gamma_kind_boost.contains_key("code"));
}
739
/// A custom configuration must be stored by the service unchanged.
// Plain `#[test]`: the body awaits nothing, so no async runtime is needed.
#[test]
fn test_retrieval_service_configurations() {
    let embedding_service = Arc::new(FallbackEmbeddingService::new(384));

    let custom_config = HybridRetrievalConfig {
        alpha: 0.3,
        beta: 0.7,
        gamma_kind_boost: std::collections::HashMap::from([
            ("code".to_string(), 0.15),
            ("user_code".to_string(), 0.12),
        ]),
        rerank: true,
        diversify: false,
        diversify_method: "simple".to_string(),
        k_initial: 50,
        k_final: 10,
        fusion_dynamic: false,
    };

    // Both values are moved: neither is used again, so no clone is needed.
    let service = HybridRetrievalService::new(embedding_service, custom_config);

    assert_eq!(service.config.alpha, 0.3);
    assert_eq!(service.config.beta, 0.7);
    assert_eq!(service.config.gamma_kind_boost.get("code"), Some(&0.15));
    assert_eq!(service.config.k_final, 10);
}
770
/// `Bm25SearchService` is a unit struct: it can be constructed without
/// arguments and carries no state.
#[test]
fn test_bm25_service_properties() {
    let service = Bm25SearchService;

    assert_eq!(std::mem::size_of_val(&service), 0);
}
781
/// The vector search service must hold the embedding service it was given,
/// observable through that service's name and dimension.
#[test]
fn test_vector_search_service_properties() {
    let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
    // Moved rather than cloned: the handle is not used again in this test.
    let service = VectorSearchService::new(embedding_service);

    assert_eq!(service.embedding_service.name(), "fallback");
    assert_eq!(service.embedding_service.dimension(), 384);
}
795
/// Pins the values produced by `HybridRetrievalConfig::default()`.
#[test]
fn test_retrieval_config_defaults() {
    let defaults = HybridRetrievalConfig::default();

    assert_eq!(defaults.alpha, 0.5);
    assert_eq!(defaults.beta, 0.5);
    assert_eq!(defaults.k_initial, 200);
    assert_eq!(defaults.k_final, 5);
    assert!(defaults.diversify);

    let boosts = &defaults.gamma_kind_boost;
    assert!(boosts.contains_key("code"));
    assert_eq!(boosts.get("code"), Some(&1.2));
}
811
/// Z-scores must have zero mean, with the highest input score positive and
/// the lowest negative.
#[test]
fn test_zscore_calculation() {
    let service = HybridRetrievalService::new(
        Arc::new(FallbackEmbeddingService::new(384)),
        HybridRetrievalConfig::hero(),
    );

    let scored = |id: &str, score: f64| Candidate {
        doc_id: id.to_string(),
        score,
        text: None,
        kind: None,
    };
    let inputs = vec![scored("1", 10.0), scored("2", 5.0), scored("3", 0.0)];

    let zscores = service.calculate_zscores(&inputs);

    let values: Vec<f64> = zscores.iter().map(|c| c.score).collect();
    let mean = values.iter().sum::<f64>() / values.len() as f64;
    assert!((mean).abs() < 1e-10);

    assert!(zscores[0].score > 0.0);
    assert!(zscores[2].score < 0.0);
}
851
/// Validates z-score fusion with the hero configuration: both z-score lists
/// have mean ≈ 0 and population standard deviation ≈ 1, and the fused scores
/// follow `alpha * bm25_z + beta * vector_z` and keep the input order.
#[test]
fn test_zscore_fusion_end_to_end() {
    /// Population mean and standard deviation of `values`.
    fn mean_and_std(values: &[f64]) -> (f64, f64) {
        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance =
            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
        (mean, variance.sqrt())
    }

    /// Formats every value with six decimals for the diagnostic output.
    fn six_decimals(values: &[f64]) -> Vec<String> {
        values.iter().map(|&x| format!("{:.6}", x)).collect()
    }

    let service = HybridRetrievalService::new(
        Arc::new(FallbackEmbeddingService::new(384)),
        HybridRetrievalConfig::hero(),
    );

    assert_eq!(service.config.alpha, 0.5, "α must be 0.5 for z-score fusion");
    assert_eq!(service.config.beta, 0.5, "β must be 0.5 for z-score fusion");
    assert_eq!(service.config.k_initial, 200, "k_initial must be 200 (hero)");
    assert_eq!(service.config.k_final, 5, "k_final must be 5 (hero)");
    assert_eq!(service.config.diversify_method, "splade", "must use splade diversification");

    let scored = |id: &str, score: f64| Candidate {
        doc_id: id.to_string(),
        score,
        text: None,
        kind: None,
    };
    let bm25_candidates = vec![scored("A", 3.0), scored("B", 2.0), scored("C", 1.0)];
    let vector_candidates = vec![scored("A", 0.9), scored("B", 0.6), scored("C", 0.3)];

    let bm25_zscores = service.calculate_zscores(&bm25_candidates);
    let vector_zscores = service.calculate_zscores(&vector_candidates);

    let bm25_scores: Vec<f64> = bm25_zscores.iter().map(|c| c.score).collect();
    let (bm25_mean, bm25_std) = mean_and_std(&bm25_scores);

    println!("BM25 z-score validation:");
    println!(" Mean: {:.10} (should be ≈ 0)", bm25_mean);
    println!(" Std Dev: {:.10} (should be ≈ 1)", bm25_std);

    assert!((bm25_mean).abs() < 1e-10, "BM25 z-score mean must be ≈ 0");
    assert!((bm25_std - 1.0).abs() < 1e-10, "BM25 z-score std must be ≈ 1");

    let vector_scores: Vec<f64> = vector_zscores.iter().map(|c| c.score).collect();
    let (vector_mean, vector_std) = mean_and_std(&vector_scores);

    println!("Vector z-score validation:");
    println!(" Mean: {:.10} (should be ≈ 0)", vector_mean);
    println!(" Std Dev: {:.10} (should be ≈ 1)", vector_std);

    assert!((vector_mean).abs() < 1e-10, "Vector z-score mean must be ≈ 0");
    assert!((vector_std - 1.0).abs() < 1e-10, "Vector z-score std must be ≈ 1");

    let mut hybrid_scores: Vec<f64> = Vec::new();
    for (bm25, vector) in bm25_zscores.iter().zip(vector_zscores.iter()) {
        hybrid_scores.push(service.config.alpha * bm25.score + service.config.beta * vector.score);
    }

    println!("Hybrid fusion validation:");
    println!(" BM25 z-scores: {:?}", six_decimals(&bm25_scores));
    println!(" Vector z-scores: {:?}", six_decimals(&vector_scores));
    println!(" Hybrid scores (0.5 * bm25_z + 0.5 * vector_z): {:?}", six_decimals(&hybrid_scores));

    assert!(hybrid_scores.len() == 3, "Must have 3 hybrid scores");
    assert!(hybrid_scores[0] > hybrid_scores[1], "Scores should be ordered");
    assert!(hybrid_scores[1] > hybrid_scores[2], "Scores should be ordered");

    for index in 0..3 {
        let expected = 0.5 * bm25_scores[index] + 0.5 * vector_scores[index];
        assert!(
            (hybrid_scores[index] - expected).abs() < 1e-10,
            "Hybrid calculation must match expected formula"
        );
    }

    println!("✅ Z-Score Fusion End-to-End Validation PASSED");
    println!(" Hero configuration: ✓");
    println!(" Z-score normalization: ✓");
    println!(" Fusion calculation: ✓");
    println!(" Mathematical properties: ✓");
}
942}