use crate::types::{Candidate, Chunk, DfIdf, EmbeddingVector};
use crate::error::Result;
use crate::utils::{TextProcessor, QueryFeatures};
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::embeddings::{EmbeddingService, FallbackEmbeddingService};

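/// Configuration for hybrid (lexical + vector) retrieval.
///
/// `alpha` and `beta` weight the lexical and vector z-scores during fusion,
/// `gamma_kind_boost` holds static per-kind boosts, and `k_initial` / `k_final`
/// bound the candidate pool before and after post-processing.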
#[derive(Debug, Clone)]
pub struct HybridRetrievalConfig {
    pub alpha: f64,
    pub beta: f64,
    pub gamma_kind_boost: HashMap<String, f64>,
    pub rerank: bool,
    pub diversify: bool,
    pub diversify_method: String,
    pub k_initial: i32,
    pub k_final: i32,
    pub fusion_dynamic: bool,
}

impl Default for HybridRetrievalConfig {
    fn default() -> Self {
        let mut gamma_kind_boost = HashMap::new();
        gamma_kind_boost.insert("code".to_string(), 1.2);
        gamma_kind_boost.insert("import".to_string(), 1.1);
        gamma_kind_boost.insert("function".to_string(), 1.15);
        gamma_kind_boost.insert("error".to_string(), 1.3);

        Self {
            alpha: 0.5,
            beta: 0.5,
            gamma_kind_boost,
            rerank: true,
            diversify: true,
            diversify_method: "entity".to_string(),
            k_initial: 200,
            k_final: 5,
            fusion_dynamic: false,
        }
    }
}

impl HybridRetrievalConfig {
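    /// "Hero" configuration exercised by the end-to-end fusion tests: equal
    /// fusion weights, SPLADE diversification, and no static kind boosts.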
    pub fn hero() -> Self {
        let gamma_kind_boost = HashMap::new();

        Self {
            alpha: 0.5,
            beta: 0.5,
            gamma_kind_boost,
            rerank: true,
            diversify: true,
            diversify_method: "splade".to_string(),
            k_initial: 200,
            k_final: 5,
            fusion_dynamic: false,
        }
    }
}

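/// Storage interface the retrieval pipeline reads from: per-session chunks and
/// DF/IDF statistics, chunk lookup by id, and vector similarity search.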
#[async_trait]
pub trait DocumentRepository: Send + Sync {
    async fn get_chunks_by_session(&self, session_id: &str) -> Result<Vec<Chunk>>;

    async fn get_dfidf_by_session(&self, session_id: &str) -> Result<Vec<DfIdf>>;

    async fn get_chunk_by_id(&self, chunk_id: &str) -> Result<Option<Chunk>>;

    async fn vector_search(&self, query_vector: &EmbeddingVector, k: i32) -> Result<Vec<Candidate>>;
}

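/// Lexical retrieval over a session's chunks using BM25 with the per-session
/// IDF statistics provided by the repository.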
pub struct Bm25SearchService;

impl Bm25SearchService {
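    /// Scores every chunk against the union of terms from all `queries`,
    /// skipping chunks that share no query term, and returns the top `k`
    /// candidates by BM25 score.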
    pub async fn search<R: DocumentRepository + ?Sized>(
        repository: &R,
        queries: &[String],
        session_id: &str,
        k: i32,
    ) -> Result<Vec<Candidate>> {
        let chunks = repository.get_chunks_by_session(session_id).await?;
        if chunks.is_empty() {
            return Ok(vec![]);
        }

        let dfidf_data = repository.get_dfidf_by_session(session_id).await?;
        let term_idf_map: HashMap<String, f64> = dfidf_data
            .into_iter()
            .map(|entry| (entry.term, entry.idf))
            .collect();

        let total_length: i32 = chunks
            .iter()
            .map(|chunk| Self::tokenize(&chunk.text).len() as i32)
            .sum();
        let avg_doc_length = if chunks.is_empty() {
            0.0
        } else {
            total_length as f64 / chunks.len() as f64
        };

        let all_query_terms: HashSet<String> = queries
            .iter()
            .flat_map(|query| Self::tokenize(query))
            .collect();

        let mut candidates = Vec::new();

        for chunk in chunks {
            let doc_terms = Self::tokenize(&chunk.text);
            let doc_length = doc_terms.len() as f64;

            let mut term_freqs = HashMap::new();
            for term in &doc_terms {
                if all_query_terms.contains(term) {
                    *term_freqs.entry(term.clone()).or_insert(0) += 1;
                }
            }

            if term_freqs.is_empty() {
                continue;
            }

            let score = Self::calculate_bm25(&term_freqs, doc_length, avg_doc_length, &term_idf_map, 1.2, 0.75);
            if score > 0.0 {
                candidates.push(Candidate {
                    doc_id: chunk.id,
                    score,
                    text: Some(chunk.text),
                    kind: Some(chunk.kind),
                });
            }
        }

        candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
        candidates.truncate(k as usize);

        Ok(candidates)
    }

    fn tokenize(text: &str) -> Vec<String> {
        TextProcessor::tokenize(text)
    }

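    /// Okapi BM25 contribution summed over the query terms present in the
    /// document: `idf(t) * tf * (k1 + 1) / (tf + k1 * (1 - b + b * dl / avgdl))`.
    /// Terms without a positive IDF entry are skipped.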
    fn calculate_bm25(
        term_freqs: &HashMap<String, i32>,
        doc_length: f64,
        avg_doc_length: f64,
        term_idf_map: &HashMap<String, f64>,
        k1: f64,
        b: f64,
    ) -> f64 {
        let mut score = 0.0;

        for (term, &tf) in term_freqs {
            let idf = term_idf_map.get(term).copied().unwrap_or(0.0);
            if idf <= 0.0 {
                continue;
            }

            let numerator = (tf as f64) * (k1 + 1.0);
            let denominator = (tf as f64) + k1 * (1.0 - b + b * (doc_length / avg_doc_length));

            score += idf * (numerator / denominator);
        }

        score
    }

    #[allow(dead_code)]
    fn calculate_bm25_default(
        term_freqs: &HashMap<String, i32>,
        doc_length: f64,
        avg_doc_length: f64,
        term_idf_map: &HashMap<String, f64>,
    ) -> f64 {
        Self::calculate_bm25(term_freqs, doc_length, avg_doc_length, term_idf_map, 1.2, 0.75)
    }
}

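/// Dense retrieval: embeds the query with the configured `EmbeddingService`
/// and delegates the similarity search to the repository.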
pub struct VectorSearchService {
    embedding_service: Arc<dyn EmbeddingService>,
}

impl VectorSearchService {
    pub fn new(embedding_service: Arc<dyn EmbeddingService>) -> Self {
        Self { embedding_service }
    }

    pub async fn search<R: DocumentRepository + ?Sized>(
        &self,
        repository: &R,
        query: &str,
        k: i32,
    ) -> Result<Vec<Candidate>> {
        let query_embedding = self.embedding_service.embed_single(query).await?;
        repository.vector_search(&query_embedding, k).await
    }
}

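/// Orchestrates the full pipeline: BM25 and vector search run in parallel,
/// their scores are fused, and the result is post-processed down to `k_final`.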
pub struct HybridRetrievalService {
    vector_service: VectorSearchService,
    config: HybridRetrievalConfig,
}

impl HybridRetrievalService {
    pub fn new(embedding_service: Arc<dyn EmbeddingService>, config: HybridRetrievalConfig) -> Self {
        Self {
            vector_service: VectorSearchService::new(embedding_service),
            config,
        }
    }

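    /// Runs lexical and vector retrieval concurrently (each capped at
    /// `k_initial`), fuses the two candidate lists with weighted z-scores, and
    /// post-processes the result down to at most `k_final` candidates.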
    pub async fn retrieve<R: DocumentRepository + ?Sized>(
        &self,
        repository: &R,
        queries: &[String],
        session_id: &str,
    ) -> Result<Vec<Candidate>> {
        let combined_query = queries.join(" ");

        tracing::info!("Starting hybrid retrieval for {} queries", queries.len());

        let (lexical_results, vector_results) = tokio::try_join!(
            Bm25SearchService::search(repository, queries, session_id, self.config.k_initial),
            self.vector_service.search(repository, &combined_query, self.config.k_initial)
        )?;

        tracing::debug!(
            "BM25 found {} candidates, Vector search found {} candidates",
            lexical_results.len(),
            vector_results.len()
        );

        let candidates = self.hybrid_score(lexical_results, vector_results, &combined_query)?;

        tracing::info!("Hybrid scoring produced {} candidates", candidates.len());

        let final_candidates = self.post_process(candidates).await?;

        tracing::info!("Final result: {} candidates", final_candidates.len());
        Ok(final_candidates)
    }

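    /// Z-score fusion: both candidate lists are normalized (BM25 by max,
    /// cosine to [0, 1]) and standardized, then combined as
    /// `alpha * z_lexical + beta * z_vector`; documents missing from one list
    /// contribute a z-score of 0.0 for it, and a kind boost is applied on top.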
    fn hybrid_score(
        &self,
        lexical_results: Vec<Candidate>,
        vector_results: Vec<Candidate>,
        query: &str,
    ) -> Result<Vec<Candidate>> {
        let lexical_normalized = self.normalize_bm25_scores(lexical_results);
        let vector_normalized = self.normalize_cosine_scores(vector_results);

        let lexical_zscores = self.calculate_zscores(&lexical_normalized);
        let vector_zscores = self.calculate_zscores(&vector_normalized);

        let lexical_map: HashMap<String, f64> = lexical_zscores
            .into_iter()
            .map(|c| (c.doc_id, c.score))
            .collect();

        let vector_map: HashMap<String, f64> = vector_zscores
            .into_iter()
            .map(|c| (c.doc_id, c.score))
            .collect();

        let all_doc_ids: HashSet<String> = lexical_map
            .keys()
            .chain(vector_map.keys())
            .cloned()
            .collect();

        let query_features = QueryFeatures::extract_features(query);

        let mut candidates = Vec::new();

        for doc_id in all_doc_ids {
            let lex_zscore = lexical_map.get(&doc_id).copied().unwrap_or(0.0);
            let vec_zscore = vector_map.get(&doc_id).copied().unwrap_or(0.0);

            let mut hybrid_score = self.config.alpha * lex_zscore + self.config.beta * vec_zscore;

            // The score maps only retain doc_id and score, so the chunk kind is
            // no longer available here; default to "text" for boosting.
            let kind = "text";
            let dynamic_boost = QueryFeatures::gamma_boost(kind, &query_features);
            let static_boost = self.config.gamma_kind_boost.get(kind).copied().unwrap_or(0.0);
            let total_boost = 1.0 + dynamic_boost + static_boost;
            hybrid_score *= total_boost;

            candidates.push(Candidate {
                doc_id,
                score: hybrid_score,
                text: None,
                kind: Some(kind.to_string()),
            });
        }

        candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());

        Ok(candidates)
    }

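    /// Scales BM25 scores by the maximum score so the best candidate maps to 1.0.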
    fn normalize_bm25_scores(&self, candidates: Vec<Candidate>) -> Vec<Candidate> {
        if candidates.is_empty() {
            return candidates;
        }

        let max_score = candidates
            .iter()
            .map(|c| c.score)
            .fold(0.0, f64::max);

        if max_score == 0.0 {
            return candidates;
        }

        candidates
            .into_iter()
            .map(|mut c| {
                c.score /= max_score;
                c
            })
            .collect()
    }

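    /// Maps cosine similarity from [-1, 1] onto [0, 1].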
    fn normalize_cosine_scores(&self, candidates: Vec<Candidate>) -> Vec<Candidate> {
        candidates
            .into_iter()
            .map(|mut c| {
                c.score = (c.score + 1.0) / 2.0;
                c
            })
            .collect()
    }

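    /// Standardizes scores to zero mean and unit variance (population standard
    /// deviation). The input is returned unchanged when all scores are
    /// identical, avoiding a division by zero.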
    pub fn calculate_zscores(&self, candidates: &[Candidate]) -> Vec<Candidate> {
        if candidates.is_empty() {
            return candidates.to_vec();
        }

        let scores: Vec<f64> = candidates.iter().map(|c| c.score).collect();
        let mean = scores.iter().sum::<f64>() / scores.len() as f64;

        let variance = scores.iter()
            .map(|&score| (score - mean).powi(2))
            .sum::<f64>() / scores.len() as f64;

        let std_dev = variance.sqrt();

        if std_dev == 0.0 {
            return candidates.to_vec();
        }

        candidates.iter()
            .map(|candidate| {
                let zscore = (candidate.score - mean) / std_dev;
                Candidate {
                    doc_id: candidate.doc_id.clone(),
                    score: zscore,
                    text: candidate.text.clone(),
                    kind: candidate.kind.clone(),
                }
            })
            .collect()
    }

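    /// Applies the configured rerank and diversify steps (currently stubs in
    /// this basic version) and truncates to `k_final`.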
    async fn post_process(&self, mut candidates: Vec<Candidate>) -> Result<Vec<Candidate>> {
        if self.config.rerank {
            tracing::debug!("Reranking not implemented in basic version");
        }

        if self.config.diversify && candidates.len() > self.config.k_final as usize {
            tracing::debug!("Diversification not implemented in basic version");
        }

        candidates.truncate(self.config.k_final as usize);

        Ok(candidates)
    }

    pub fn mock_for_testing() -> Self {
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        Self::new(embedding_service, HybridRetrievalConfig::hero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FallbackEmbeddingService;
    use lethe_shared::Chunk;
    use uuid::Uuid;
    use std::sync::Arc;

    struct MockRepository {
        chunks: Vec<Chunk>,
        dfidf: Vec<DfIdf>,
    }

    #[async_trait]
    impl DocumentRepository for MockRepository {
        async fn get_chunks_by_session(&self, _session_id: &str) -> Result<Vec<Chunk>> {
            Ok(self.chunks.clone())
        }

        async fn get_dfidf_by_session(&self, _session_id: &str) -> Result<Vec<DfIdf>> {
            Ok(self.dfidf.clone())
        }

        async fn get_chunk_by_id(&self, chunk_id: &str) -> Result<Option<Chunk>> {
            Ok(self.chunks.iter().find(|c| c.id == chunk_id).cloned())
        }

        async fn vector_search(&self, _query_vector: &EmbeddingVector, k: i32) -> Result<Vec<Candidate>> {
            let candidates: Vec<Candidate> = self.chunks
                .iter()
                .take(k as usize)
                .map(|chunk| Candidate {
                    doc_id: chunk.id.clone(),
                    score: 0.8,
                    text: Some(chunk.text.clone()),
                    kind: Some(chunk.kind.clone()),
                })
                .collect();
            Ok(candidates)
        }
    }

    fn create_test_chunk(id: &str, text: &str, kind: &str) -> Chunk {
        Chunk {
            id: id.to_string(),
            message_id: Uuid::new_v4(),
            session_id: "test-session".to_string(),
            offset_start: 0,
            offset_end: text.len(),
            kind: kind.to_string(),
            text: text.to_string(),
            tokens: text.split_whitespace().count() as i32,
        }
    }

    #[tokio::test]
    async fn test_bm25_search() {
        let chunks = vec![
            create_test_chunk("1", "hello world", "text"),
            create_test_chunk("2", "world peace", "text"),
            create_test_chunk("3", "goodbye world", "text"),
        ];

        let dfidf = vec![
            DfIdf {
                term: "hello".to_string(),
                session_id: "test-session".to_string(),
                df: 1,
                idf: 1.0,
            },
            DfIdf {
                term: "world".to_string(),
                session_id: "test-session".to_string(),
                df: 3,
                idf: 0.5,
            },
        ];

        let repository = MockRepository { chunks, dfidf };
        let queries = vec!["hello world".to_string()];

        let results = Bm25SearchService::search(&repository, &queries, "test-session", 10)
            .await
            .unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].doc_id, "1");
    }

    #[tokio::test]
    async fn test_hybrid_retrieval() {
        let chunks = vec![
            create_test_chunk("1", "async programming in rust", "text"),
            create_test_chunk("2", "rust error handling", "text"),
            create_test_chunk("3", "javascript async await", "text"),
        ];

        let dfidf = vec![
            DfIdf {
                term: "async".to_string(),
                session_id: "test-session".to_string(),
                df: 2,
                idf: 0.4,
            },
            DfIdf {
                term: "rust".to_string(),
                session_id: "test-session".to_string(),
                df: 2,
                idf: 0.4,
            },
        ];

        let repository = MockRepository { chunks, dfidf };
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let config = HybridRetrievalConfig::default();
        let service = HybridRetrievalService::new(embedding_service, config);

        let queries = vec!["rust async programming".to_string()];
        let results = service
            .retrieve(&repository, &queries, "test-session")
            .await
            .unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 5);
    }

    #[tokio::test]
    async fn test_hero_configuration() {
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let hero_config = HybridRetrievalConfig::hero();
        let service = HybridRetrievalService::new(embedding_service, hero_config);

        assert_eq!(service.config.alpha, 0.5);
        assert_eq!(service.config.beta, 0.5);
        assert_eq!(service.config.k_initial, 200);
        assert_eq!(service.config.k_final, 5);
        assert_eq!(service.config.diversify_method, "splade");
    }

    #[test]
    fn test_score_normalization() {
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let config = HybridRetrievalConfig::default();
        let service = HybridRetrievalService::new(embedding_service, config);

        let candidates = vec![
            Candidate {
                doc_id: "1".to_string(),
                score: 10.0,
                text: None,
                kind: None,
            },
            Candidate {
                doc_id: "2".to_string(),
                score: 5.0,
                text: None,
                kind: None,
            },
        ];

        let normalized = service.normalize_bm25_scores(candidates);
        assert_eq!(normalized[0].score, 1.0);
        assert_eq!(normalized[1].score, 0.5);
    }

    #[test]
    fn test_query_features() {
        let features = QueryFeatures::extract_features("function_name() error in /path/file.rs");
        assert!(features.has_code_symbol);
        assert!(features.has_error_token);
        assert!(features.has_path_or_file);

        let boost = QueryFeatures::gamma_boost("code", &features);
        assert!(boost > 0.0);
    }

    #[test]
    fn test_query_features_comprehensive() {
        let features1 = QueryFeatures::extract_features("call myFunction() here");
        assert!(features1.has_code_symbol);
        assert!(!features1.has_error_token);

        let features2 = QueryFeatures::extract_features("use MyClass::StaticMethod");
        assert!(features2.has_code_symbol);

        let features3 = QueryFeatures::extract_features("NullPointerException occurred");
        assert!(features3.has_error_token);
        assert!(!features3.has_code_symbol);

        let features4 = QueryFeatures::extract_features("check /home/user/file.txt");
        assert!(features4.has_path_or_file);
        assert!(!features4.has_error_token);

        let features5 = QueryFeatures::extract_features("see C:\\Users\\Name\\doc.docx");
        assert!(features5.has_path_or_file);

        let features6 = QueryFeatures::extract_features("issue 1234 needs fixing");
        assert!(features6.has_numeric_id);
        assert!(!features6.has_code_symbol);

        let features7 = QueryFeatures::extract_features("");
        assert!(!features7.has_code_symbol);
        assert!(!features7.has_error_token);
        assert!(!features7.has_path_or_file);
        assert!(!features7.has_numeric_id);
    }

    #[test]
    fn test_gamma_boost_combinations() {
        let features = QueryFeatures::extract_features("myFunction() returns value");

        let code_boost = QueryFeatures::gamma_boost("code", &features);
        assert!(code_boost > 0.0);

        let user_code_boost = QueryFeatures::gamma_boost("user_code", &features);
        assert!(user_code_boost > 0.0);

        let text_boost = QueryFeatures::gamma_boost("text", &features);
        assert_eq!(text_boost, 0.0);

        let error_features = QueryFeatures::extract_features("RuntimeError in execution");
        let tool_boost = QueryFeatures::gamma_boost("tool_result", &error_features);
        assert!(tool_boost > 0.0);

        let path_features = QueryFeatures::extract_features("file located at /src/main.rs");
        let code_path_boost = QueryFeatures::gamma_boost("code", &path_features);
        assert!(code_path_boost > 0.0);

        let combined_features = QueryFeatures::extract_features("function() error in /path/file.rs with ID 1234");
        assert!(combined_features.has_code_symbol);
        assert!(combined_features.has_error_token);
        assert!(combined_features.has_path_or_file);
        assert!(combined_features.has_numeric_id);

        let combined_boost = QueryFeatures::gamma_boost("code", &combined_features);
        assert!(combined_boost > 0.1);
    }

    #[tokio::test]
    async fn test_hybrid_retrieval_creation() {
        use crate::embeddings::FallbackEmbeddingService;

        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let service = HybridRetrievalService::new(embedding_service.clone(), HybridRetrievalConfig::default());

        assert_eq!(service.config.alpha, 0.5);
        assert_eq!(service.config.beta, 0.5);
        assert!(service.config.gamma_kind_boost.contains_key("code"));
    }

    #[tokio::test]
    async fn test_retrieval_service_configurations() {
        use crate::embeddings::FallbackEmbeddingService;

        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));

        let custom_config = HybridRetrievalConfig {
            alpha: 0.3,
            beta: 0.7,
            gamma_kind_boost: std::collections::HashMap::from([
                ("code".to_string(), 0.15),
                ("user_code".to_string(), 0.12),
            ]),
            rerank: true,
            diversify: false,
            diversify_method: "simple".to_string(),
            k_initial: 50,
            k_final: 10,
            fusion_dynamic: false,
        };

        let service = HybridRetrievalService::new(embedding_service.clone(), custom_config.clone());

        assert_eq!(service.config.alpha, 0.3);
        assert_eq!(service.config.beta, 0.7);
        assert_eq!(service.config.gamma_kind_boost.get("code"), Some(&0.15));
        assert_eq!(service.config.k_final, 10);
    }

    #[test]
    fn test_bm25_service_properties() {
        // Stateless unit struct; just confirm it can be constructed and moved.
        let service = Bm25SearchService;

        let _ = service;
    }

    #[test]
    fn test_vector_search_service_properties() {
        use crate::embeddings::FallbackEmbeddingService;

        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let service = VectorSearchService::new(embedding_service.clone());

        assert_eq!(service.embedding_service.name(), "fallback");

        assert_eq!(service.embedding_service.dimension(), 384);
    }

    #[test]
    fn test_retrieval_config_defaults() {
        let config = HybridRetrievalConfig::default();

        assert_eq!(config.alpha, 0.5);
        assert_eq!(config.beta, 0.5);
        assert_eq!(config.k_initial, 200);
        assert_eq!(config.k_final, 5);
        assert!(config.diversify);
        assert!(config.gamma_kind_boost.contains_key("code"));

        assert_eq!(config.gamma_kind_boost.get("code"), Some(&1.2));
    }

    #[test]
    fn test_zscore_calculation() {
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let config = HybridRetrievalConfig::hero();
        let service = HybridRetrievalService::new(embedding_service, config);

        let candidates = vec![
            Candidate {
                doc_id: "1".to_string(),
                score: 10.0,
                text: None,
                kind: None,
            },
            Candidate {
                doc_id: "2".to_string(),
                score: 5.0,
                text: None,
                kind: None,
            },
            Candidate {
                doc_id: "3".to_string(),
                score: 0.0,
                text: None,
                kind: None,
            },
        ];

        let zscores = service.calculate_zscores(&candidates);

        let scores: Vec<f64> = zscores.iter().map(|c| c.score).collect();
        let mean = scores.iter().sum::<f64>() / scores.len() as f64;
        assert!((mean).abs() < 1e-10);

        assert!(zscores[0].score > 0.0);
        assert!(zscores[2].score < 0.0);
    }

    #[test]
    fn test_zscore_fusion_end_to_end() {
        let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
        let hero_config = HybridRetrievalConfig::hero();
        let service = HybridRetrievalService::new(embedding_service, hero_config);

        assert_eq!(service.config.alpha, 0.5, "α must be 0.5 for z-score fusion");
        assert_eq!(service.config.beta, 0.5, "β must be 0.5 for z-score fusion");
        assert_eq!(service.config.k_initial, 200, "k_initial must be 200 (hero)");
        assert_eq!(service.config.k_final, 5, "k_final must be 5 (hero)");
        assert_eq!(service.config.diversify_method, "splade", "must use splade diversification");

        let bm25_candidates = vec![
            Candidate { doc_id: "A".to_string(), score: 3.0, text: None, kind: None },
            Candidate { doc_id: "B".to_string(), score: 2.0, text: None, kind: None },
            Candidate { doc_id: "C".to_string(), score: 1.0, text: None, kind: None },
        ];

        let vector_candidates = vec![
            Candidate { doc_id: "A".to_string(), score: 0.9, text: None, kind: None },
            Candidate { doc_id: "B".to_string(), score: 0.6, text: None, kind: None },
            Candidate { doc_id: "C".to_string(), score: 0.3, text: None, kind: None },
        ];

        let bm25_zscores = service.calculate_zscores(&bm25_candidates);
        let vector_zscores = service.calculate_zscores(&vector_candidates);

        let bm25_scores: Vec<f64> = bm25_zscores.iter().map(|c| c.score).collect();
        let bm25_mean = bm25_scores.iter().sum::<f64>() / bm25_scores.len() as f64;
        let bm25_var = bm25_scores.iter().map(|&x| (x - bm25_mean).powi(2)).sum::<f64>() / bm25_scores.len() as f64;
        let bm25_std = bm25_var.sqrt();

        println!("BM25 z-score validation:");
        println!(" Mean: {:.10} (should be ≈ 0)", bm25_mean);
        println!(" Std Dev: {:.10} (should be ≈ 1)", bm25_std);

        assert!((bm25_mean).abs() < 1e-10, "BM25 z-score mean must be ≈ 0");
        assert!((bm25_std - 1.0).abs() < 1e-10, "BM25 z-score std must be ≈ 1");

        let vector_scores: Vec<f64> = vector_zscores.iter().map(|c| c.score).collect();
        let vector_mean = vector_scores.iter().sum::<f64>() / vector_scores.len() as f64;
        let vector_var = vector_scores.iter().map(|&x| (x - vector_mean).powi(2)).sum::<f64>() / vector_scores.len() as f64;
        let vector_std = vector_var.sqrt();

        println!("Vector z-score validation:");
        println!(" Mean: {:.10} (should be ≈ 0)", vector_mean);
        println!(" Std Dev: {:.10} (should be ≈ 1)", vector_std);

        assert!((vector_mean).abs() < 1e-10, "Vector z-score mean must be ≈ 0");
        assert!((vector_std - 1.0).abs() < 1e-10, "Vector z-score std must be ≈ 1");

        let hybrid_scores: Vec<f64> = bm25_zscores.iter()
            .zip(vector_zscores.iter())
            .map(|(bm25, vector)| service.config.alpha * bm25.score + service.config.beta * vector.score)
            .collect();

        println!("Hybrid fusion validation:");
        println!(" BM25 z-scores: {:?}", bm25_scores.iter().map(|&x| format!("{:.6}", x)).collect::<Vec<_>>());
        println!(" Vector z-scores: {:?}", vector_scores.iter().map(|&x| format!("{:.6}", x)).collect::<Vec<_>>());
        println!(" Hybrid scores (0.5 * bm25_z + 0.5 * vector_z): {:?}", hybrid_scores.iter().map(|&x| format!("{:.6}", x)).collect::<Vec<_>>());

        assert!(hybrid_scores.len() == 3, "Must have 3 hybrid scores");
        assert!(hybrid_scores[0] > hybrid_scores[1], "Scores should be ordered");
        assert!(hybrid_scores[1] > hybrid_scores[2], "Scores should be ordered");

        let expected_0 = 0.5 * bm25_scores[0] + 0.5 * vector_scores[0];
        let expected_1 = 0.5 * bm25_scores[1] + 0.5 * vector_scores[1];
        let expected_2 = 0.5 * bm25_scores[2] + 0.5 * vector_scores[2];

        assert!((hybrid_scores[0] - expected_0).abs() < 1e-10, "Hybrid calculation must match expected formula");
        assert!((hybrid_scores[1] - expected_1).abs() < 1e-10, "Hybrid calculation must match expected formula");
        assert!((hybrid_scores[2] - expected_2).abs() < 1e-10, "Hybrid calculation must match expected formula");

        println!("✅ Z-Score Fusion End-to-End Validation PASSED");
        println!(" Hero configuration: ✓");
        println!(" Z-score normalization: ✓");
        println!(" Fusion calculation: ✓");
        println!(" Mathematical properties: ✓");
    }
}