lethe-core-rust 0.1.1

High-performance hybrid retrieval engine that combines BM25 lexical search with vector similarity using z-score fusion. Features a hero configuration tuned for parity with a SPLADE baseline, gamma boosting for code/error contexts, and a comprehensive chunking pipeline.
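At a glance, ranking follows a weighted z-score fusion: BM25 and vector scores are each z-normalized over the candidate set, combined with the α and β weights, and candidates from code/error contexts receive an extra γ boost. A minimal sketch of that rule (the additive form of the gamma boost and the function name here are illustrative assumptions, not the crate's exact API):

// Simplified sketch of the fusion rule used for ranking. Inputs are already
// z-normalized BM25 and vector scores; the gamma boost shown as additive is
// an assumption for illustration only.
fn fused_score(z_bm25: f64, z_vec: f64, alpha: f64, beta: f64, gamma: f64, is_code_or_error: bool) -> f64 {
    // Linear combination of the two normalized score lists.
    let base = alpha * z_bm25 + beta * z_vec;
    // Boost candidates coming from code or error contexts.
    if is_code_or_error { base + gamma } else { base }
}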
Documentation

The end-to-end example below exercises the hero configuration against a small in-memory repository and validates the z-score fusion math against a manual calculation:
use lethe_domain::retrieval::{HybridRetrievalConfig, HybridRetrievalService, DocumentRepository, Bm25SearchService};
use lethe_domain::embeddings::FallbackEmbeddingService;
use lethe_shared::{Candidate, Chunk, DfIdf, EmbeddingVector, Result};
use std::sync::Arc;
use async_trait::async_trait;
use uuid::Uuid;
use std::collections::HashMap;

// Test repository with controlled data to validate z-score fusion behavior
struct ValidationRepository {
    chunks: Vec<Chunk>,
    dfidf: Vec<DfIdf>,
}

#[async_trait]
impl DocumentRepository for ValidationRepository {
    async fn get_chunks_by_session(&self, _session_id: &str) -> Result<Vec<Chunk>> {
        Ok(self.chunks.clone())
    }

    async fn get_dfidf_by_session(&self, _session_id: &str) -> Result<Vec<DfIdf>> {
        Ok(self.dfidf.clone())
    }

    async fn get_chunk_by_id(&self, chunk_id: &str) -> Result<Option<Chunk>> {
        Ok(self.chunks.iter().find(|c| c.id == chunk_id).cloned())
    }

    async fn vector_search(&self, _query_vector: &EmbeddingVector, k: i32) -> Result<Vec<Candidate>> {
        // Return candidates with known vector scores for validation
        Ok(vec![
            Candidate { doc_id: "1".to_string(), score: 0.9, text: Some("high vector score".to_string()), kind: Some("code".to_string()) },
            Candidate { doc_id: "2".to_string(), score: 0.6, text: Some("medium vector score".to_string()), kind: Some("text".to_string()) },
            Candidate { doc_id: "3".to_string(), score: 0.3, text: Some("low vector score".to_string()), kind: Some("function".to_string()) },
        ])
    }
}

fn create_validation_data() -> (Vec<Chunk>, Vec<DfIdf>) {
    let chunks = vec![
        Chunk {
            id: "1".to_string(),
            message_id: Uuid::new_v4(),
            session_id: "test-session".to_string(),
            offset_start: 0,
            offset_end: 50,
            kind: "code".to_string(),
            text: "test implementation code snippet".to_string(),
            tokens: 5,
        },
        Chunk {
            id: "2".to_string(),
            message_id: Uuid::new_v4(),
            session_id: "test-session".to_string(),
            offset_start: 0,
            offset_end: 40,
            kind: "text".to_string(),
            text: "test documentation text".to_string(),
            tokens: 4,
        },
        Chunk {
            id: "3".to_string(),
            message_id: Uuid::new_v4(),
            session_id: "test-session".to_string(),
            offset_start: 0,
            offset_end: 30,
            kind: "function".to_string(),
            text: "test function definition".to_string(),
            tokens: 4,
        },
    ];

    let dfidf = vec![
        DfIdf { term: "test".to_string(), session_id: "test".to_string(), df: 3, idf: 0.48 },
        DfIdf { term: "implementation".to_string(), session_id: "test".to_string(), df: 1, idf: 1.10 },
        DfIdf { term: "code".to_string(), session_id: "test".to_string(), df: 2, idf: 0.69 },
        DfIdf { term: "documentation".to_string(), session_id: "test".to_string(), df: 1, idf: 1.10 },
        DfIdf { term: "function".to_string(), session_id: "test".to_string(), df: 2, idf: 0.69 },
    ];

    (chunks, dfidf)
}

// Manual z-score calculation to verify our implementation
fn manual_zscore_calculation(scores: &[f64]) -> Vec<f64> {
    let mean = scores.iter().sum::<f64>() / scores.len() as f64;
    let variance = scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / scores.len() as f64;
    let std_dev = variance.sqrt();
    
    if std_dev == 0.0 {
        scores.to_vec() // Return original scores if no variance
    } else {
        scores.iter().map(|&x| (x - mean) / std_dev).collect()
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    println!("🧪 Z-Score Fusion Validation Test");
    println!("==================================");
    
    let (chunks, dfidf) = create_validation_data();
    let repository = ValidationRepository { chunks, dfidf };
    
    let embedding_service = Arc::new(FallbackEmbeddingService::new(384));
    let hero_config = HybridRetrievalConfig::hero();
    let service = HybridRetrievalService::new(embedding_service, hero_config);

    println!("\n📊 Hero Configuration:");
    println!("  α (BM25 weight): {}", service.config.alpha);
    println!("  β (Vector weight): {}", service.config.beta);
    println!("  k_initial: {}", service.config.k_initial);
    println!("  k_final: {}", service.config.k_final);
    println!("  diversify_method: {}", service.config.diversify_method);

    // Test query designed to trigger both BM25 and vector matches
    let queries = vec!["test implementation code".to_string()];
    let results = service
        .retrieve(&repository, &queries, "test-session")
        .await?;

    println!("\n🔍 Retrieval Results:");
    println!("  Total results: {}", results.len());
    
    for (i, candidate) in results.iter().enumerate() {
        println!("  {}. ID: {}, Score: {:.6}, Text: {:?}", 
                 i + 1, candidate.doc_id, candidate.score, 
                 candidate.text.as_ref().unwrap_or(&"None".to_string()));
    }

    // Now let's validate the z-score calculation directly
    println!("\n🧮 Z-Score Validation:");
    
    // Test with known values
    let test_candidates = vec![
        Candidate { doc_id: "A".to_string(), score: 10.0, text: None, kind: None },
        Candidate { doc_id: "B".to_string(), score: 5.0, text: None, kind: None },
        Candidate { doc_id: "C".to_string(), score: 0.0, text: None, kind: None },
    ];
    
    let raw_scores: Vec<f64> = test_candidates.iter().map(|c| c.score).collect();
    let manual_zscores = manual_zscore_calculation(&raw_scores);
    let service_zscores = service.calculate_zscores(&test_candidates);
    let service_scores: Vec<f64> = service_zscores.iter().map(|c| c.score).collect();
    
    println!("  Raw scores: {:?}", raw_scores);
    println!("  Manual z-scores: {:?}", manual_zscores.iter().map(|&x| format!("{:.6}", x)).collect::<Vec<_>>());
    println!("  Service z-scores: {:?}", service_scores.iter().map(|&x| format!("{:.6}", x)).collect::<Vec<_>>());
    
    // Validate mathematical properties
    let mean = service_scores.iter().sum::<f64>() / service_scores.len() as f64;
    let variance = service_scores.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / service_scores.len() as f64;
    let std_dev = variance.sqrt();
    
    println!("\n📈 Z-Score Mathematical Properties:");
    println!("  Mean: {:.10} (should be ≈ 0)", mean);
    println!("  Standard Deviation: {:.10} (should be ≈ 1)", std_dev);
    
    // Validate z-score properties
    let mean_ok = mean.abs() < 1e-10;
    let stddev_ok = (std_dev - 1.0).abs() < 1e-10;
    
    println!("\n✅ Validation Results:");
    println!("  Mean ≈ 0: {} ({})", if mean_ok { "PASS" } else { "FAIL" }, mean_ok);
    println!("  Std Dev ≈ 1: {} ({})", if stddev_ok { "PASS" } else { "FAIL" }, stddev_ok);
    
    // Test BM25 scoring directly
    println!("\n⚙️  BM25 Direct Test:");
    let mut term_freqs = HashMap::new();
    term_freqs.insert("test".to_string(), 2);
    term_freqs.insert("code".to_string(), 1);
    
    let mut term_idf_map = HashMap::new();
    term_idf_map.insert("test".to_string(), 0.48);
    term_idf_map.insert("code".to_string(), 0.69);
    
    let bm25_score = Bm25SearchService::calculate_bm25(
        &term_freqs,
        5.0,  // doc_length  
        4.5,  // avg_doc_length
        &term_idf_map,
        1.2,  // k1 (hero parameter)
        0.75  // b (hero parameter)
    );
    
    println!("  BM25 score: {:.6}", bm25_score);
    println!("  Score > 0: {} ({})", if bm25_score > 0.0 { "PASS" } else { "FAIL" }, bm25_score > 0.0);

    // Test fusion calculation
    println!("\n🔀 Fusion Test:");
    let bm25_scores = vec![2.0, 1.5, 1.0];
    let vector_scores = vec![0.9, 0.6, 0.3];
    
    let bm25_zscores = manual_zscore_calculation(&bm25_scores);
    let vector_zscores = manual_zscore_calculation(&vector_scores);
    
    println!("  BM25 raw: {:?}", bm25_scores);
    println!("  BM25 z-scores: {:?}", bm25_zscores.iter().map(|&x| format!("{:.3}", x)).collect::<Vec<_>>());
    println!("  Vector raw: {:?}", vector_scores);  
    println!("  Vector z-scores: {:?}", vector_zscores.iter().map(|&x| format!("{:.3}", x)).collect::<Vec<_>>());
    
    // Calculate hybrid scores with α=0.5, β=0.5
    let hybrid_scores: Vec<f64> = bm25_zscores.iter()
        .zip(vector_zscores.iter())
        .map(|(bm25_z, vector_z)| 0.5 * bm25_z + 0.5 * vector_z)
        .collect();
    
    println!("  Hybrid scores: {:?}", hybrid_scores.iter().map(|&x| format!("{:.3}", x)).collect::<Vec<_>>());
    
    println!("\n🎯 Final Validation:");
    let all_passed = mean_ok && stddev_ok && bm25_score > 0.0 && results.len() <= 5;
    println!("  All tests passed: {} ({})", if all_passed { "✅ PASS" } else { "❌ FAIL" }, all_passed);
    
    if all_passed {
        println!("\n🎉 Z-Score Fusion Implementation VERIFIED!");
        println!("   Hero configuration properly applied");
        println!("   Mathematical properties correct");
        println!("   End-to-end retrieval working");
    } else {
        println!("\n❌ Z-Score Fusion Implementation FAILED validation");
    }

    Ok(())
}
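To try this end to end, save the program as an example in the consuming crate (for instance examples/zscore_validation.rs; the file name is illustrative) and run it with cargo run --example zscore_validation.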