use lethe_domain::retrieval::{HybridRetrievalConfig, HybridRetrievalService, DocumentRepository, Bm25SearchService};
use lethe_domain::embeddings::FallbackEmbeddingService;
use lethe_shared::{Candidate, Chunk, DfIdf, EmbeddingVector, Result};
use std::sync::Arc;
use async_trait::async_trait;
use uuid::Uuid;
use std::collections::HashMap;
/// In-memory test double for `DocumentRepository`, holding a fixed corpus
/// that is served back verbatim by its trait implementation.
struct ValidationRepository {
    /// Every chunk in the fixture corpus.
    chunks: Vec<Chunk>,
    /// Document-frequency / IDF statistics for the fixture terms.
    dfidf: Vec<DfIdf>,
}
/// In-memory [`DocumentRepository`] over the fixture data. Session ids are
/// ignored by every lookup, so all sessions see the same corpus.
#[async_trait]
impl DocumentRepository for ValidationRepository {
    /// Returns a clone of every fixture chunk, regardless of session.
    async fn get_chunks_by_session(&self, _session_id: &str) -> Result<Vec<Chunk>> {
        Ok(self.chunks.clone())
    }

    /// Returns a clone of the whole DF/IDF table, regardless of session.
    async fn get_dfidf_by_session(&self, _session_id: &str) -> Result<Vec<DfIdf>> {
        Ok(self.dfidf.clone())
    }

    /// Looks up a single fixture chunk by `id`; `Ok(None)` when absent.
    async fn get_chunk_by_id(&self, chunk_id: &str) -> Result<Option<Chunk>> {
        Ok(self.chunks.iter().find(|c| c.id == chunk_id).cloned())
    }

    /// Returns canned vector hits in descending score order, truncated to at
    /// most `k` entries so the caller's requested limit is honored.
    ///
    /// The query vector is ignored. A non-positive `k` yields no hits.
    async fn vector_search(&self, _query_vector: &EmbeddingVector, k: i32) -> Result<Vec<Candidate>> {
        // Clamp before casting so a negative `k` cannot wrap to a huge usize.
        let limit = k.max(0) as usize;
        let hits = vec![
            Candidate { doc_id: "1".to_string(), score: 0.9, text: Some("high vector score".to_string()), kind: Some("code".to_string()) },
            Candidate { doc_id: "2".to_string(), score: 0.6, text: Some("medium vector score".to_string()), kind: Some("text".to_string()) },
            Candidate { doc_id: "3".to_string(), score: 0.3, text: Some("low vector score".to_string()), kind: Some("function".to_string()) },
        ];
        Ok(hits.into_iter().take(limit).collect())
    }
}
/// Builds a three-chunk fixture corpus and its DF/IDF table for the validation run.
///
/// Every row (chunks and DF/IDF entries alike) uses the `"test-session"` session
/// id that `main` queries with. Chunk ids `"1"`–`"3"` line up with the doc ids
/// returned by the mock `vector_search`. Document frequencies reflect which
/// chunk texts actually contain each term.
fn create_validation_data() -> (Vec<Chunk>, Vec<DfIdf>) {
    const SESSION_ID: &str = "test-session";

    // Offsets and token counts are derived from the text so the fixture stays
    // self-consistent. The token count is a whitespace split.
    // NOTE(review): the real tokenizer may count differently; confirm if BM25
    // length normalization needs exact values.
    let chunk = |id: &str, kind: &str, text: &str| Chunk {
        id: id.to_string(),
        message_id: Uuid::new_v4(),
        session_id: SESSION_ID.to_string(),
        offset_start: 0,
        offset_end: text.len() as _,
        kind: kind.to_string(),
        text: text.to_string(),
        tokens: text.split_whitespace().count() as _,
    };

    let chunks = vec![
        chunk("1", "code", "test implementation code snippet"),
        chunk("2", "text", "test documentation text"),
        chunk("3", "function", "test function definition"),
    ];

    // "test" occurs in all three chunks; every other listed term occurs in
    // exactly one chunk, so they share the same (higher) IDF value.
    let dfidf = vec![
        DfIdf { term: "test".to_string(), session_id: SESSION_ID.to_string(), df: 3, idf: 0.48 },
        DfIdf { term: "implementation".to_string(), session_id: SESSION_ID.to_string(), df: 1, idf: 1.10 },
        DfIdf { term: "code".to_string(), session_id: SESSION_ID.to_string(), df: 1, idf: 1.10 },
        DfIdf { term: "documentation".to_string(), session_id: SESSION_ID.to_string(), df: 1, idf: 1.10 },
        DfIdf { term: "function".to_string(), session_id: SESSION_ID.to_string(), df: 1, idf: 1.10 },
    ];

    (chunks, dfidf)
}
/// Standardizes `scores` to z-scores using the population standard deviation
/// (divides by `n`, not `n - 1`).
///
/// Returns an empty vector for empty input. When the scores have numerically
/// zero spread (all values equal, up to floating-point rounding), no meaningful
/// standardization exists, so the input is returned unchanged.
fn manual_zscore_calculation(scores: &[f64]) -> Vec<f64> {
    if scores.is_empty() {
        return Vec::new();
    }

    let n = scores.len() as f64;
    let mean = scores.iter().sum::<f64>() / n;
    let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
    let std_dev = variance.sqrt();

    // An exact `== 0.0` test misses constant inputs whose mean picks up rounding
    // error. For example, [0.1, 0.1, 0.1] averages to one ulp above 0.1, and the
    // resulting rounding noise would otherwise be "standardized" to -1. Treat any
    // spread within the accumulated rounding error of the mean (about n·ε·max|x|)
    // as zero.
    let max_abs = scores.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
    if std_dev <= n * f64::EPSILON * max_abs {
        scores.to_vec()
    } else {
        scores.iter().map(|&x| (x - mean) / std_dev).collect()
    }
}
/// Runs the end-to-end validation and prints a human-readable report.
///
/// The run covers:
/// - a hybrid retrieval against the in-memory fixture;
/// - z-score property checks (mean ≈ 0, std dev ≈ 1, agreement with
///   [`manual_zscore_calculation`]);
/// - a direct BM25 scoring check;
/// - an illustrative z-score fusion printout.
///
/// Exits the process with status 1 if any check fails, so the binary can gate
/// scripts/CI. Retrieval errors are propagated through `?`.
#[tokio::main]
async fn main() -> Result<()> {
    // Absolute tolerance for every floating-point comparison below.
    const TOLERANCE: f64 = 1e-10;

    println!("🧪 Z-Score Fusion Validation Test");
    println!("==================================");

    let (chunks, dfidf) = create_validation_data();
    let repository = ValidationRepository { chunks, dfidf };
    let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
    let hero_config = HybridRetrievalConfig::hero();
    let service = HybridRetrievalService::new(embedding_service, hero_config);

    println!("\n📊 Hero Configuration:");
    println!(" α (BM25 weight): {}", service.config.alpha);
    println!(" β (Vector weight): {}", service.config.beta);
    println!(" k_initial: {}", service.config.k_initial);
    println!(" k_final: {}", service.config.k_final);
    println!(" diversify_method: {}", service.config.diversify_method);

    let queries = vec!["test implementation code".to_string()];
    let results = service
        .retrieve(&repository, &queries, "test-session")
        .await?;

    println!("\n🔍 Retrieval Results:");
    println!(" Total results: {}", results.len());
    for (i, candidate) in results.iter().enumerate() {
        println!(
            " {}. ID: {}, Score: {:.6}, Text: {:?}",
            i + 1,
            candidate.doc_id,
            candidate.score,
            candidate.text.as_deref().unwrap_or("None")
        );
    }

    println!("\n🧮 Z-Score Validation:");
    let test_candidates = vec![
        Candidate { doc_id: "A".to_string(), score: 10.0, text: None, kind: None },
        Candidate { doc_id: "B".to_string(), score: 5.0, text: None, kind: None },
        Candidate { doc_id: "C".to_string(), score: 0.0, text: None, kind: None },
    ];
    let raw_scores: Vec<f64> = test_candidates.iter().map(|c| c.score).collect();
    let manual_zscores = manual_zscore_calculation(&raw_scores);
    let service_zscores = service.calculate_zscores(&test_candidates);
    let service_scores: Vec<f64> = service_zscores.iter().map(|c| c.score).collect();
    println!(" Raw scores: {:?}", raw_scores);
    println!(" Manual z-scores: {:?}", format_scores(&manual_zscores, 6));
    println!(" Service z-scores: {:?}", format_scores(&service_scores, 6));

    let (mean, std_dev) = population_mean_std(&service_scores);
    println!("\n📈 Z-Score Mathematical Properties:");
    println!(" Mean: {:.10} (should be ≈ 0)", mean);
    println!(" Standard Deviation: {:.10} (should be ≈ 1)", std_dev);

    let mean_ok = mean.abs() < TOLERANCE;
    let stddev_ok = (std_dev - 1.0).abs() < TOLERANCE;
    // Cross-check the service against the independent manual implementation.
    // The comparison is element-wise, which assumes `calculate_zscores` preserves
    // input order (the side-by-side printout above assumes the same).
    let zscores_match = manual_zscores.len() == service_scores.len()
        && manual_zscores
            .iter()
            .zip(&service_scores)
            .all(|(manual, served)| (manual - served).abs() < TOLERANCE);
    println!("\n✅ Validation Results:");
    println!(" Mean ≈ 0: {} ({})", pass_fail(mean_ok), mean_ok);
    println!(" Std Dev ≈ 1: {} ({})", pass_fail(stddev_ok), stddev_ok);
    println!(" Manual ≈ Service: {} ({})", pass_fail(zscores_match), zscores_match);

    println!("\n⚙️ BM25 Direct Test:");
    let term_freqs = HashMap::from([("test".to_string(), 2), ("code".to_string(), 1)]);
    let term_idf_map = HashMap::from([("test".to_string(), 0.48), ("code".to_string(), 0.69)]);
    // NOTE(review): the numeric arguments look like doc length, average doc
    // length, k1 and b; confirm against `calculate_bm25`'s signature.
    let bm25_score = Bm25SearchService::calculate_bm25(
        &term_freqs,
        5.0,
        4.5,
        &term_idf_map,
        1.2,
        0.75,
    );
    let bm25_ok = bm25_score > 0.0;
    println!(" BM25 score: {:.6}", bm25_score);
    println!(" Score > 0: {} ({})", pass_fail(bm25_ok), bm25_ok);

    println!("\n🔀 Fusion Test:");
    let bm25_scores = [2.0, 1.5, 1.0];
    let vector_scores = [0.9, 0.6, 0.3];
    let bm25_zscores = manual_zscore_calculation(&bm25_scores);
    let vector_zscores = manual_zscore_calculation(&vector_scores);
    println!(" BM25 raw: {:?}", bm25_scores);
    println!(" BM25 z-scores: {:?}", format_scores(&bm25_zscores, 3));
    println!(" Vector raw: {:?}", vector_scores);
    println!(" Vector z-scores: {:?}", format_scores(&vector_zscores, 3));
    // Fixed equal weights for illustration only. This printout is not checked
    // and does not read the configured α/β.
    let hybrid_scores: Vec<f64> = bm25_zscores
        .iter()
        .zip(&vector_zscores)
        .map(|(bm25_z, vector_z)| 0.5 * bm25_z + 0.5 * vector_z)
        .collect();
    println!(" Hybrid scores: {:?}", format_scores(&hybrid_scores, 3));

    println!("\n🎯 Final Validation:");
    // Retrieval must return something and respect the configured final cut-off.
    // Both sides are widened to i64 because the config's integer type isn't
    // visible from here.
    let retrieval_ok =
        !results.is_empty() && (results.len() as i64) <= (service.config.k_final as i64);
    println!(" Retrieval within k_final: {} ({})", pass_fail(retrieval_ok), retrieval_ok);

    let all_passed = mean_ok && stddev_ok && zscores_match && bm25_ok && retrieval_ok;
    println!(" All tests passed: {} ({})", if all_passed { "✅ PASS" } else { "❌ FAIL" }, all_passed);
    if !all_passed {
        println!("\n❌ Z-Score Fusion Implementation FAILED validation");
        // Non-zero exit status so scripts/CI notice the failure.
        std::process::exit(1);
    }

    println!("\n🎉 Z-Score Fusion Implementation VERIFIED!");
    println!(" Hero configuration properly applied");
    println!(" Mathematical properties correct");
    println!(" End-to-end retrieval working");
    Ok(())
}

/// Maps a check outcome to the `"PASS"` / `"FAIL"` label used in the report.
fn pass_fail(ok: bool) -> &'static str {
    if ok { "PASS" } else { "FAIL" }
}

/// Formats each score with `precision` decimal places, for `{:?}` printing.
fn format_scores(scores: &[f64], precision: usize) -> Vec<String> {
    scores.iter().map(|&x| format!("{:.*}", precision, x)).collect()
}

/// Returns the mean and population standard deviation of `scores`
/// (both NaN for an empty slice).
fn population_mean_std(scores: &[f64]) -> (f64, f64) {
    let n = scores.len() as f64;
    let mean = scores.iter().sum::<f64>() / n;
    let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n;
    (mean, variance.sqrt())
}