use serde::{Deserialize, Serialize};
use std::time::Instant;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryBenchmark {
pub query: String,
pub ground_truth: Option<String>,
pub generated_answer: String,
pub latency: LatencyMetrics,
pub tokens: TokenMetrics,
pub quality: QualityMetrics,
pub features_enabled: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyMetrics {
pub total_ms: u64,
pub retrieval_ms: u64,
pub reranking_ms: Option<u64>,
pub generation_ms: u64,
pub other_ms: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenMetrics {
pub input_tokens: usize,
pub output_tokens: usize,
pub total_tokens: usize,
pub estimated_cost_usd: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
pub exact_match: f32,
pub f1_score: f32,
pub bleu_score: Option<f32>,
pub rouge_l: Option<f32>,
pub semantic_similarity: Option<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkDataset {
pub name: String,
pub queries: Vec<BenchmarkQuery>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkQuery {
pub question: String,
pub answer: String,
pub context: Option<Vec<String>>,
pub difficulty: Option<String>,
pub query_type: Option<String>,
}
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
pub enable_lightrag: bool,
pub enable_leiden: bool,
pub enable_cross_encoder: bool,
pub enable_hipporag: bool,
pub enable_semantic_chunking: bool,
pub top_k: usize,
pub input_token_price: f64,
pub output_token_price: f64,
}
impl Default for BenchmarkConfig {
fn default() -> Self {
Self {
enable_lightrag: false,
enable_leiden: false,
enable_cross_encoder: false,
enable_hipporag: false,
enable_semantic_chunking: false,
top_k: 10,
input_token_price: 0.0001, output_token_price: 0.0003, }
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSummary {
pub config_name: String,
pub total_queries: usize,
pub avg_latency_ms: f64,
pub avg_retrieval_ms: f64,
pub avg_reranking_ms: f64,
pub avg_generation_ms: f64,
pub total_input_tokens: usize,
pub total_output_tokens: usize,
pub total_cost_usd: f64,
pub avg_tokens_per_query: f64,
pub avg_exact_match: f64,
pub avg_f1_score: f64,
pub avg_bleu_score: f64,
pub avg_rouge_l: f64,
pub features: Vec<String>,
pub query_results: Vec<QueryBenchmark>,
}
type RetrievalFn = Box<dyn Fn(&str) -> Vec<String> + Send + Sync>;
type RerankerFn = Box<dyn Fn(&[String]) -> Vec<String> + Send + Sync>;
type LlmFn = Box<dyn Fn(&str, &[String]) -> String + Send + Sync>;
pub struct BenchmarkRunner {
config: BenchmarkConfig,
retrieval_fn: Option<RetrievalFn>,
reranker_fn: Option<RerankerFn>,
llm_fn: Option<LlmFn>,
}
impl BenchmarkRunner {
pub fn new(config: BenchmarkConfig) -> Self {
Self {
config,
retrieval_fn: None,
reranker_fn: None,
llm_fn: None,
}
}
pub fn with_retrieval<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&str) -> Vec<String> + Send + Sync + 'static,
{
self.retrieval_fn = Some(Box::new(f));
self
}
pub fn with_reranker<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&[String]) -> Vec<String> + Send + Sync + 'static,
{
self.reranker_fn = Some(Box::new(f));
self
}
pub fn with_llm<F>(&mut self, f: F) -> &mut Self
where
F: Fn(&str, &[String]) -> String + Send + Sync + 'static,
{
self.llm_fn = Some(Box::new(f));
self
}
pub fn run_dataset(&mut self, dataset: &BenchmarkDataset) -> BenchmarkSummary {
println!("📊 Running benchmark on dataset: {}", dataset.name);
println!("📋 Queries: {}", dataset.queries.len());
let mut results = Vec::new();
for (i, query) in dataset.queries.iter().enumerate() {
println!(
" [{}/{}] Processing: {}...",
i + 1,
dataset.queries.len(),
&query.question.chars().take(50).collect::<String>()
);
let result = self.benchmark_query(query);
results.push(result);
}
self.compute_summary(dataset.name.clone(), results)
}
fn benchmark_query(&self, query: &BenchmarkQuery) -> QueryBenchmark {
let start = Instant::now();
let retrieval_start = Instant::now();
let retrieved_docs = if let Some(ref retrieval_fn) = self.retrieval_fn {
retrieval_fn(&query.question)
} else {
vec![]
};
let retrieval_time = retrieval_start.elapsed();
let (reranked_docs, reranking_time) = if self.config.enable_cross_encoder {
let reranking_start = Instant::now();
let reranked = if let Some(ref reranker_fn) = self.reranker_fn {
reranker_fn(&retrieved_docs)
} else {
retrieved_docs.clone()
};
(reranked, Some(reranking_start.elapsed()))
} else {
(retrieved_docs.clone(), None)
};
let generation_start = Instant::now();
let generated_answer = if let Some(ref llm_fn) = self.llm_fn {
llm_fn(&query.question, &reranked_docs)
} else {
format!("Generated answer for: {}", query.question)
};
let generation_time = generation_start.elapsed();
let total_time = start.elapsed();
let estimated_input_tokens = if self.config.enable_lightrag {
200 } else {
2000 };
let estimated_output_tokens = 100;
let tokens = TokenMetrics {
input_tokens: estimated_input_tokens,
output_tokens: estimated_output_tokens,
total_tokens: estimated_input_tokens + estimated_output_tokens,
estimated_cost_usd: (estimated_input_tokens as f64 / 1000.0
* self.config.input_token_price)
+ (estimated_output_tokens as f64 / 1000.0 * self.config.output_token_price),
};
let quality = self.calculate_quality_metrics(&generated_answer, &query.answer);
let mut features = Vec::new();
if self.config.enable_lightrag {
features.push("LightRAG".to_string());
}
if self.config.enable_leiden {
features.push("Leiden".to_string());
}
if self.config.enable_cross_encoder {
features.push("Cross-Encoder".to_string());
}
if self.config.enable_hipporag {
features.push("HippoRAG PPR".to_string());
}
if self.config.enable_semantic_chunking {
features.push("Semantic Chunking".to_string());
}
QueryBenchmark {
query: query.question.clone(),
ground_truth: Some(query.answer.clone()),
generated_answer,
latency: LatencyMetrics {
total_ms: total_time.as_millis() as u64,
retrieval_ms: retrieval_time.as_millis() as u64,
reranking_ms: reranking_time.map(|d| d.as_millis() as u64),
generation_ms: generation_time.as_millis() as u64,
other_ms: 0,
},
tokens,
quality,
features_enabled: features,
}
}
fn calculate_quality_metrics(&self, generated: &str, ground_truth: &str) -> QualityMetrics {
let exact_match = if generated.trim().eq_ignore_ascii_case(ground_truth.trim()) {
1.0
} else {
0.0
};
let f1_score = self.calculate_f1_score(generated, ground_truth);
let bleu_score = Some(self.calculate_bleu_score(generated, ground_truth));
let rouge_l = Some(self.calculate_rouge_l(generated, ground_truth));
QualityMetrics {
exact_match,
f1_score,
bleu_score,
rouge_l,
semantic_similarity: None,
}
}
fn calculate_f1_score(&self, generated: &str, ground_truth: &str) -> f32 {
let gen_tokens: Vec<String> = generated
.to_lowercase()
.split_whitespace()
.map(|s| s.to_string())
.collect();
let gt_tokens: Vec<String> = ground_truth
.to_lowercase()
.split_whitespace()
.map(|s| s.to_string())
.collect();
if gen_tokens.is_empty() || gt_tokens.is_empty() {
return 0.0;
}
let mut common = 0;
for token in &gen_tokens {
if gt_tokens.contains(token) {
common += 1;
}
}
if common == 0 {
return 0.0;
}
let precision = common as f32 / gen_tokens.len() as f32;
let recall = common as f32 / gt_tokens.len() as f32;
2.0 * (precision * recall) / (precision + recall)
}
fn calculate_bleu_score(&self, candidate: &str, reference: &str) -> f32 {
let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
if candidate_tokens.is_empty() || reference_tokens.is_empty() {
return 0.0;
}
let max_n = 4;
let mut log_precision_sum = 0.0;
let mut valid_n_grams = 0;
for n in 1..=max_n {
let precision = self.calculate_ngram_precision(&candidate_tokens, &reference_tokens, n);
if precision > 0.0 {
log_precision_sum += precision.ln();
valid_n_grams += 1;
} else {
return 0.0;
}
}
let candidate_len = candidate_tokens.len() as f32;
let reference_len = reference_tokens.len() as f32;
let brevity_penalty = if candidate_len >= reference_len {
1.0
} else {
(1.0 - reference_len / candidate_len).exp()
};
let bleu = brevity_penalty * (log_precision_sum / valid_n_grams as f32).exp();
bleu.clamp(0.0, 1.0)
}
fn calculate_ngram_precision(&self, candidate: &[&str], reference: &[&str], n: usize) -> f32 {
if candidate.len() < n || reference.len() < n {
return 0.0;
}
let candidate_ngrams = self.extract_ngrams(candidate, n);
let reference_ngrams = self.extract_ngrams(reference, n);
let mut reference_counts = std::collections::HashMap::new();
for ngram in &reference_ngrams {
*reference_counts.entry(ngram).or_insert(0) += 1;
}
let mut clipped_matches = 0;
let mut candidate_counts = std::collections::HashMap::new();
for ngram in &candidate_ngrams {
let candidate_count = candidate_counts.entry(ngram).or_insert(0);
*candidate_count += 1;
if let Some(&ref_count) = reference_counts.get(&ngram) {
if *candidate_count <= ref_count {
clipped_matches += 1;
}
}
}
if candidate_ngrams.is_empty() {
0.0
} else {
clipped_matches as f32 / candidate_ngrams.len() as f32
}
}
fn extract_ngrams(&self, tokens: &[&str], n: usize) -> Vec<Vec<String>> {
if tokens.len() < n {
return Vec::new();
}
tokens
.windows(n)
.map(|window| window.iter().map(|&s| s.to_string()).collect())
.collect()
}
fn calculate_rouge_l(&self, candidate: &str, reference: &str) -> f32 {
let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
if candidate_tokens.is_empty() || reference_tokens.is_empty() {
return 0.0;
}
let lcs_length = self.lcs_length(&candidate_tokens, &reference_tokens);
if lcs_length == 0 {
return 0.0;
}
let precision = lcs_length as f32 / candidate_tokens.len() as f32;
let recall = lcs_length as f32 / reference_tokens.len() as f32;
let beta = 1.2;
let beta_squared = beta * beta;
let f_score =
((1.0 + beta_squared) * precision * recall) / (beta_squared * precision + recall);
f_score.clamp(0.0, 1.0)
}
fn lcs_length(&self, seq1: &[&str], seq2: &[&str]) -> usize {
let m = seq1.len();
let n = seq2.len();
if m == 0 || n == 0 {
return 0;
}
let mut dp = vec![vec![0; n + 1]; m + 1];
for i in 1..=m {
for j in 1..=n {
if seq1[i - 1] == seq2[j - 1] {
dp[i][j] = dp[i - 1][j - 1] + 1;
} else {
dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
}
}
}
dp[m][n]
}
fn compute_summary(
&self,
config_name: String,
results: Vec<QueryBenchmark>,
) -> BenchmarkSummary {
let total = results.len();
if total == 0 {
return BenchmarkSummary {
config_name,
total_queries: 0,
avg_latency_ms: 0.0,
avg_retrieval_ms: 0.0,
avg_reranking_ms: 0.0,
avg_generation_ms: 0.0,
total_input_tokens: 0,
total_output_tokens: 0,
total_cost_usd: 0.0,
avg_tokens_per_query: 0.0,
avg_exact_match: 0.0,
avg_f1_score: 0.0,
avg_bleu_score: 0.0,
avg_rouge_l: 0.0,
features: Vec::new(),
query_results: results,
};
}
let avg_latency_ms = results
.iter()
.map(|r| r.latency.total_ms as f64)
.sum::<f64>()
/ total as f64;
let avg_retrieval_ms = results
.iter()
.map(|r| r.latency.retrieval_ms as f64)
.sum::<f64>()
/ total as f64;
let avg_reranking_ms = results
.iter()
.filter_map(|r| r.latency.reranking_ms)
.map(|ms| ms as f64)
.sum::<f64>()
/ total as f64;
let avg_generation_ms = results
.iter()
.map(|r| r.latency.generation_ms as f64)
.sum::<f64>()
/ total as f64;
let total_input_tokens: usize = results.iter().map(|r| r.tokens.input_tokens).sum();
let total_output_tokens: usize = results.iter().map(|r| r.tokens.output_tokens).sum();
let total_cost_usd: f64 = results.iter().map(|r| r.tokens.estimated_cost_usd).sum();
let avg_exact_match = results
.iter()
.map(|r| r.quality.exact_match as f64)
.sum::<f64>()
/ total as f64;
let avg_f1_score = results
.iter()
.map(|r| r.quality.f1_score as f64)
.sum::<f64>()
/ total as f64;
let bleu_scores: Vec<f64> = results
.iter()
.filter_map(|r| r.quality.bleu_score.map(|s| s as f64))
.collect();
let avg_bleu_score = if !bleu_scores.is_empty() {
bleu_scores.iter().sum::<f64>() / bleu_scores.len() as f64
} else {
0.0
};
let rouge_scores: Vec<f64> = results
.iter()
.filter_map(|r| r.quality.rouge_l.map(|s| s as f64))
.collect();
let avg_rouge_l = if !rouge_scores.is_empty() {
rouge_scores.iter().sum::<f64>() / rouge_scores.len() as f64
} else {
0.0
};
let features = if !results.is_empty() {
results[0].features_enabled.clone()
} else {
Vec::new()
};
BenchmarkSummary {
config_name,
total_queries: total,
avg_latency_ms,
avg_retrieval_ms,
avg_reranking_ms,
avg_generation_ms,
total_input_tokens,
total_output_tokens,
total_cost_usd,
avg_tokens_per_query: (total_input_tokens + total_output_tokens) as f64 / total as f64,
avg_exact_match,
avg_f1_score,
avg_bleu_score,
avg_rouge_l,
features,
query_results: results,
}
}
pub fn print_summary(&self, summary: &BenchmarkSummary) {
println!("\n📊 Benchmark Results: {}", summary.config_name);
println!("{}", "=".repeat(60));
println!("\n🎯 Quality Metrics:");
println!(" Exact Match: {:.1}%", summary.avg_exact_match * 100.0);
println!(" F1 Score: {:.3}", summary.avg_f1_score);
if summary.avg_bleu_score > 0.0 {
println!(" BLEU Score: {:.3}", summary.avg_bleu_score);
}
if summary.avg_rouge_l > 0.0 {
println!(" ROUGE-L: {:.3}", summary.avg_rouge_l);
}
println!("\n⏱️ Latency Metrics (avg):");
println!(" Total: {:.1} ms", summary.avg_latency_ms);
println!(" Retrieval: {:.1} ms", summary.avg_retrieval_ms);
if summary.avg_reranking_ms > 0.0 {
println!(" Reranking: {:.1} ms", summary.avg_reranking_ms);
}
println!(" Generation: {:.1} ms", summary.avg_generation_ms);
println!("\n💰 Token & Cost Metrics:");
println!(" Input tokens: {}", summary.total_input_tokens);
println!(" Output tokens: {}", summary.total_output_tokens);
println!(" Total cost: ${:.4}", summary.total_cost_usd);
println!(" Avg tokens/query: {:.0}", summary.avg_tokens_per_query);
println!("\n✨ Features Enabled:");
for feature in &summary.features {
println!(" ✅ {}", feature);
}
println!("\n{}", "=".repeat(60));
}
pub fn compare_summaries(&self, baseline: &BenchmarkSummary, improved: &BenchmarkSummary) {
println!("\n📈 Benchmark Comparison");
println!("{}", "=".repeat(60));
println!("\nConfiguration:");
println!(" Baseline: {}", baseline.config_name);
println!(" Improved: {}", improved.config_name);
println!("\n🎯 Quality Improvements:");
let em_improvement = ((improved.avg_exact_match - baseline.avg_exact_match)
/ baseline.avg_exact_match)
* 100.0;
let f1_improvement =
((improved.avg_f1_score - baseline.avg_f1_score) / baseline.avg_f1_score) * 100.0;
println!(" Exact Match: {:+.1}%", em_improvement);
println!(" F1 Score: {:+.1}%", f1_improvement);
println!("\n💰 Cost Savings:");
let token_reduction = ((baseline.total_input_tokens - improved.total_input_tokens) as f64
/ baseline.total_input_tokens as f64)
* 100.0;
let cost_savings =
((baseline.total_cost_usd - improved.total_cost_usd) / baseline.total_cost_usd) * 100.0;
println!(
" Token reduction: {:.1}% ({} → {} tokens)",
token_reduction, baseline.total_input_tokens, improved.total_input_tokens
);
println!(
" Cost savings: {:.1}% (${:.4} → ${:.4})",
cost_savings, baseline.total_cost_usd, improved.total_cost_usd
);
println!("\n⏱️ Latency Changes:");
let latency_change =
((improved.avg_latency_ms - baseline.avg_latency_ms) / baseline.avg_latency_ms) * 100.0;
println!(
" Total latency: {:+.1}% ({:.1}ms → {:.1}ms)",
latency_change, baseline.avg_latency_ms, improved.avg_latency_ms
);
println!("\n{}", "=".repeat(60));
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_f1_score_calculation() {
let _runner = BenchmarkRunner::new(BenchmarkConfig::default());
let f1 = _runner.calculate_f1_score("hello world", "hello world");
assert!((f1 - 1.0).abs() < 0.001);
let f1 = _runner.calculate_f1_score("hello world", "hello there");
assert!(f1 > 0.0 && f1 < 1.0);
let f1 = _runner.calculate_f1_score("foo bar", "baz qux");
assert_eq!(f1, 0.0);
}
#[test]
fn test_benchmark_summary() {
let dataset = BenchmarkDataset {
name: "Test".to_string(),
queries: vec![BenchmarkQuery {
question: "What is 2+2?".to_string(),
answer: "4".to_string(),
context: None,
difficulty: None,
query_type: None,
}],
};
let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
let summary = runner.run_dataset(&dataset);
assert_eq!(summary.total_queries, 1);
assert!(summary.avg_latency_ms >= 0.0);
}
}