1use serde::{Deserialize, Serialize};
10use std::time::Instant;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct QueryBenchmark {
15 pub query: String,
17
18 pub ground_truth: Option<String>,
20
21 pub generated_answer: String,
23
24 pub latency: LatencyMetrics,
26
27 pub tokens: TokenMetrics,
29
30 pub quality: QualityMetrics,
32
33 pub features_enabled: Vec<String>,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct LatencyMetrics {
40 pub total_ms: u64,
42
43 pub retrieval_ms: u64,
45
46 pub reranking_ms: Option<u64>,
48
49 pub generation_ms: u64,
51
52 pub other_ms: u64,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TokenMetrics {
59 pub input_tokens: usize,
61
62 pub output_tokens: usize,
64
65 pub total_tokens: usize,
67
68 pub estimated_cost_usd: f64,
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct QualityMetrics {
75 pub exact_match: f32,
77
78 pub f1_score: f32,
80
81 pub bleu_score: Option<f32>,
83
84 pub rouge_l: Option<f32>,
86
87 pub semantic_similarity: Option<f32>,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct BenchmarkDataset {
94 pub name: String,
96
97 pub queries: Vec<BenchmarkQuery>,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct BenchmarkQuery {
104 pub question: String,
106
107 pub answer: String,
109
110 pub context: Option<Vec<String>>,
112
113 pub difficulty: Option<String>,
115
116 pub query_type: Option<String>,
118}
119
/// Feature switches and pricing used by a benchmark run.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Enables the "LightRAG" feature. The runner in this file also uses the
    /// smaller per-query input-token estimate when this is set.
    pub enable_lightrag: bool,

    /// Enables the "Leiden" feature (reported in the feature list).
    pub enable_leiden: bool,

    /// Enables the "Cross-Encoder" feature; the runner only executes and times
    /// the reranking stage when this is set.
    pub enable_cross_encoder: bool,

    /// Enables the "HippoRAG PPR" feature (reported in the feature list).
    pub enable_hipporag: bool,

    /// Enables the "Semantic Chunking" feature (reported in the feature list).
    pub enable_semantic_chunking: bool,

    /// Number of documents to retrieve. NOTE(review): not read by the runner
    /// in this file — presumably consumed by the retrieval callback; confirm.
    pub top_k: usize,

    /// Price in US dollars per 1,000 input tokens.
    pub input_token_price: f64,
    /// Price in US dollars per 1,000 output tokens.
    pub output_token_price: f64,
}
146
147impl Default for BenchmarkConfig {
148 fn default() -> Self {
149 Self {
150 enable_lightrag: false,
151 enable_leiden: false,
152 enable_cross_encoder: false,
153 enable_hipporag: false,
154 enable_semantic_chunking: false,
155 top_k: 10,
156 input_token_price: 0.0001, output_token_price: 0.0003, }
159 }
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct BenchmarkSummary {
165 pub config_name: String,
167
168 pub total_queries: usize,
170
171 pub avg_latency_ms: f64,
173 pub avg_retrieval_ms: f64,
175 pub avg_reranking_ms: f64,
177 pub avg_generation_ms: f64,
179
180 pub total_input_tokens: usize,
183 pub total_output_tokens: usize,
185 pub total_cost_usd: f64,
187 pub avg_tokens_per_query: f64,
189
190 pub avg_exact_match: f64,
193 pub avg_f1_score: f64,
195 pub avg_bleu_score: f64,
197 pub avg_rouge_l: f64,
199
200 pub features: Vec<String>,
202
203 pub query_results: Vec<QueryBenchmark>,
205}
206
/// Retrieval callback: maps a question to the retrieved documents.
type RetrievalFn = Box<dyn Fn(&str) -> Vec<String> + Send + Sync>;

/// Reranking callback: maps the retrieved documents to a reordered list.
type RerankerFn = Box<dyn Fn(&[String]) -> Vec<String> + Send + Sync>;

/// Generation callback: maps a question and its context documents to an answer.
type LlmFn = Box<dyn Fn(&str, &[String]) -> String + Send + Sync>;
210
211pub struct BenchmarkRunner {
213 config: BenchmarkConfig,
214 retrieval_fn: Option<RetrievalFn>,
216 reranker_fn: Option<RerankerFn>,
218 llm_fn: Option<LlmFn>,
220}
221
222impl BenchmarkRunner {
223 pub fn new(config: BenchmarkConfig) -> Self {
225 Self {
226 config,
227 retrieval_fn: None,
228 reranker_fn: None,
229 llm_fn: None,
230 }
231 }
232
233 pub fn with_retrieval<F>(&mut self, f: F) -> &mut Self
245 where
246 F: Fn(&str) -> Vec<String> + Send + Sync + 'static,
247 {
248 self.retrieval_fn = Some(Box::new(f));
249 self
250 }
251
252 pub fn with_reranker<F>(&mut self, f: F) -> &mut Self
264 where
265 F: Fn(&[String]) -> Vec<String> + Send + Sync + 'static,
266 {
267 self.reranker_fn = Some(Box::new(f));
268 self
269 }
270
271 pub fn with_llm<F>(&mut self, f: F) -> &mut Self
283 where
284 F: Fn(&str, &[String]) -> String + Send + Sync + 'static,
285 {
286 self.llm_fn = Some(Box::new(f));
287 self
288 }
289
290 pub fn run_dataset(&mut self, dataset: &BenchmarkDataset) -> BenchmarkSummary {
292 println!("📊 Running benchmark on dataset: {}", dataset.name);
293 println!("📋 Queries: {}", dataset.queries.len());
294
295 let mut results = Vec::new();
296
297 for (i, query) in dataset.queries.iter().enumerate() {
298 println!(
299 " [{}/{}] Processing: {}...",
300 i + 1,
301 dataset.queries.len(),
302 &query.question.chars().take(50).collect::<String>()
303 );
304
305 let result = self.benchmark_query(query);
306 results.push(result);
307 }
308
309 self.compute_summary(dataset.name.clone(), results)
310 }
311
312 fn benchmark_query(&self, query: &BenchmarkQuery) -> QueryBenchmark {
314 let start = Instant::now();
315
316 let retrieval_start = Instant::now();
318 let retrieved_docs = if let Some(ref retrieval_fn) = self.retrieval_fn {
319 retrieval_fn(&query.question)
321 } else {
322 vec![]
324 };
325 let retrieval_time = retrieval_start.elapsed();
326
327 let (reranked_docs, reranking_time) = if self.config.enable_cross_encoder {
329 let reranking_start = Instant::now();
330 let reranked = if let Some(ref reranker_fn) = self.reranker_fn {
331 reranker_fn(&retrieved_docs)
333 } else {
334 retrieved_docs.clone()
336 };
337 (reranked, Some(reranking_start.elapsed()))
338 } else {
339 (retrieved_docs.clone(), None)
340 };
341
342 let generation_start = Instant::now();
344 let generated_answer = if let Some(ref llm_fn) = self.llm_fn {
345 llm_fn(&query.question, &reranked_docs)
347 } else {
348 format!("Generated answer for: {}", query.question)
350 };
351 let generation_time = generation_start.elapsed();
352
353 let total_time = start.elapsed();
354
355 let estimated_input_tokens = if self.config.enable_lightrag {
357 200 } else {
359 2000 };
361
362 let estimated_output_tokens = 100;
363
364 let tokens = TokenMetrics {
365 input_tokens: estimated_input_tokens,
366 output_tokens: estimated_output_tokens,
367 total_tokens: estimated_input_tokens + estimated_output_tokens,
368 estimated_cost_usd: (estimated_input_tokens as f64 / 1000.0
369 * self.config.input_token_price)
370 + (estimated_output_tokens as f64 / 1000.0 * self.config.output_token_price),
371 };
372
373 let quality = self.calculate_quality_metrics(&generated_answer, &query.answer);
375
376 let mut features = Vec::new();
378 if self.config.enable_lightrag {
379 features.push("LightRAG".to_string());
380 }
381 if self.config.enable_leiden {
382 features.push("Leiden".to_string());
383 }
384 if self.config.enable_cross_encoder {
385 features.push("Cross-Encoder".to_string());
386 }
387 if self.config.enable_hipporag {
388 features.push("HippoRAG PPR".to_string());
389 }
390 if self.config.enable_semantic_chunking {
391 features.push("Semantic Chunking".to_string());
392 }
393
394 QueryBenchmark {
395 query: query.question.clone(),
396 ground_truth: Some(query.answer.clone()),
397 generated_answer,
398 latency: LatencyMetrics {
399 total_ms: total_time.as_millis() as u64,
400 retrieval_ms: retrieval_time.as_millis() as u64,
401 reranking_ms: reranking_time.map(|d| d.as_millis() as u64),
402 generation_ms: generation_time.as_millis() as u64,
403 other_ms: 0,
404 },
405 tokens,
406 quality,
407 features_enabled: features,
408 }
409 }
410
411 fn calculate_quality_metrics(&self, generated: &str, ground_truth: &str) -> QualityMetrics {
413 let exact_match = if generated.trim().eq_ignore_ascii_case(ground_truth.trim()) {
415 1.0
416 } else {
417 0.0
418 };
419
420 let f1_score = self.calculate_f1_score(generated, ground_truth);
422
423 let bleu_score = Some(self.calculate_bleu_score(generated, ground_truth));
425
426 let rouge_l = Some(self.calculate_rouge_l(generated, ground_truth));
428
429 QualityMetrics {
430 exact_match,
431 f1_score,
432 bleu_score,
433 rouge_l,
434 semantic_similarity: None,
435 }
436 }
437
438 fn calculate_f1_score(&self, generated: &str, ground_truth: &str) -> f32 {
440 let gen_tokens: Vec<String> = generated
441 .to_lowercase()
442 .split_whitespace()
443 .map(|s| s.to_string())
444 .collect();
445
446 let gt_tokens: Vec<String> = ground_truth
447 .to_lowercase()
448 .split_whitespace()
449 .map(|s| s.to_string())
450 .collect();
451
452 if gen_tokens.is_empty() || gt_tokens.is_empty() {
453 return 0.0;
454 }
455
456 let mut common = 0;
458 for token in &gen_tokens {
459 if gt_tokens.contains(token) {
460 common += 1;
461 }
462 }
463
464 if common == 0 {
465 return 0.0;
466 }
467
468 let precision = common as f32 / gen_tokens.len() as f32;
469 let recall = common as f32 / gt_tokens.len() as f32;
470
471 2.0 * (precision * recall) / (precision + recall)
472 }
473
474 fn calculate_bleu_score(&self, candidate: &str, reference: &str) -> f32 {
482 let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
484 let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
485
486 if candidate_tokens.is_empty() || reference_tokens.is_empty() {
487 return 0.0;
488 }
489
490 let max_n = 4;
492 let mut log_precision_sum = 0.0;
493 let mut valid_n_grams = 0;
494
495 for n in 1..=max_n {
496 let precision = self.calculate_ngram_precision(&candidate_tokens, &reference_tokens, n);
497
498 if precision > 0.0 {
499 log_precision_sum += precision.ln();
500 valid_n_grams += 1;
501 } else {
502 return 0.0;
504 }
505 }
506
507 let candidate_len = candidate_tokens.len() as f32;
509 let reference_len = reference_tokens.len() as f32;
510
511 let brevity_penalty = if candidate_len >= reference_len {
512 1.0
513 } else {
514 (1.0 - reference_len / candidate_len).exp()
515 };
516
517 let bleu = brevity_penalty * (log_precision_sum / valid_n_grams as f32).exp();
519
520 bleu.clamp(0.0, 1.0)
522 }
523
524 fn calculate_ngram_precision(&self, candidate: &[&str], reference: &[&str], n: usize) -> f32 {
526 if candidate.len() < n || reference.len() < n {
527 return 0.0;
528 }
529
530 let candidate_ngrams = self.extract_ngrams(candidate, n);
532
533 let reference_ngrams = self.extract_ngrams(reference, n);
535 let mut reference_counts = std::collections::HashMap::new();
536 for ngram in &reference_ngrams {
537 *reference_counts.entry(ngram).or_insert(0) += 1;
538 }
539
540 let mut clipped_matches = 0;
542 let mut candidate_counts = std::collections::HashMap::new();
543
544 for ngram in &candidate_ngrams {
545 let candidate_count = candidate_counts.entry(ngram).or_insert(0);
546 *candidate_count += 1;
547
548 if let Some(&ref_count) = reference_counts.get(&ngram) {
549 if *candidate_count <= ref_count {
550 clipped_matches += 1;
551 }
552 }
553 }
554
555 if candidate_ngrams.is_empty() {
557 0.0
558 } else {
559 clipped_matches as f32 / candidate_ngrams.len() as f32
560 }
561 }
562
563 fn extract_ngrams(&self, tokens: &[&str], n: usize) -> Vec<Vec<String>> {
565 if tokens.len() < n {
566 return Vec::new();
567 }
568
569 tokens
570 .windows(n)
571 .map(|window| window.iter().map(|&s| s.to_string()).collect())
572 .collect()
573 }
574
575 fn calculate_rouge_l(&self, candidate: &str, reference: &str) -> f32 {
583 let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
585 let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
586
587 if candidate_tokens.is_empty() || reference_tokens.is_empty() {
588 return 0.0;
589 }
590
591 let lcs_length = self.lcs_length(&candidate_tokens, &reference_tokens);
593
594 if lcs_length == 0 {
595 return 0.0;
596 }
597
598 let precision = lcs_length as f32 / candidate_tokens.len() as f32;
600 let recall = lcs_length as f32 / reference_tokens.len() as f32;
601
602 let beta = 1.2;
604 let beta_squared = beta * beta;
605
606 let f_score =
607 ((1.0 + beta_squared) * precision * recall) / (beta_squared * precision + recall);
608
609 f_score.clamp(0.0, 1.0)
611 }
612
613 fn lcs_length(&self, seq1: &[&str], seq2: &[&str]) -> usize {
620 let m = seq1.len();
621 let n = seq2.len();
622
623 if m == 0 || n == 0 {
624 return 0;
625 }
626
627 let mut dp = vec![vec![0; n + 1]; m + 1];
629
630 for i in 1..=m {
632 for j in 1..=n {
633 if seq1[i - 1] == seq2[j - 1] {
634 dp[i][j] = dp[i - 1][j - 1] + 1;
636 } else {
637 dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
639 }
640 }
641 }
642
643 dp[m][n]
644 }
645
646 fn compute_summary(
648 &self,
649 config_name: String,
650 results: Vec<QueryBenchmark>,
651 ) -> BenchmarkSummary {
652 let total = results.len();
653
654 if total == 0 {
655 return BenchmarkSummary {
656 config_name,
657 total_queries: 0,
658 avg_latency_ms: 0.0,
659 avg_retrieval_ms: 0.0,
660 avg_reranking_ms: 0.0,
661 avg_generation_ms: 0.0,
662 total_input_tokens: 0,
663 total_output_tokens: 0,
664 total_cost_usd: 0.0,
665 avg_tokens_per_query: 0.0,
666 avg_exact_match: 0.0,
667 avg_f1_score: 0.0,
668 avg_bleu_score: 0.0,
669 avg_rouge_l: 0.0,
670 features: Vec::new(),
671 query_results: results,
672 };
673 }
674
675 let avg_latency_ms = results
676 .iter()
677 .map(|r| r.latency.total_ms as f64)
678 .sum::<f64>()
679 / total as f64;
680 let avg_retrieval_ms = results
681 .iter()
682 .map(|r| r.latency.retrieval_ms as f64)
683 .sum::<f64>()
684 / total as f64;
685 let avg_reranking_ms = results
686 .iter()
687 .filter_map(|r| r.latency.reranking_ms)
688 .map(|ms| ms as f64)
689 .sum::<f64>()
690 / total as f64;
691 let avg_generation_ms = results
692 .iter()
693 .map(|r| r.latency.generation_ms as f64)
694 .sum::<f64>()
695 / total as f64;
696
697 let total_input_tokens: usize = results.iter().map(|r| r.tokens.input_tokens).sum();
698 let total_output_tokens: usize = results.iter().map(|r| r.tokens.output_tokens).sum();
699 let total_cost_usd: f64 = results.iter().map(|r| r.tokens.estimated_cost_usd).sum();
700
701 let avg_exact_match = results
702 .iter()
703 .map(|r| r.quality.exact_match as f64)
704 .sum::<f64>()
705 / total as f64;
706 let avg_f1_score = results
707 .iter()
708 .map(|r| r.quality.f1_score as f64)
709 .sum::<f64>()
710 / total as f64;
711
712 let bleu_scores: Vec<f64> = results
714 .iter()
715 .filter_map(|r| r.quality.bleu_score.map(|s| s as f64))
716 .collect();
717 let avg_bleu_score = if !bleu_scores.is_empty() {
718 bleu_scores.iter().sum::<f64>() / bleu_scores.len() as f64
719 } else {
720 0.0
721 };
722
723 let rouge_scores: Vec<f64> = results
725 .iter()
726 .filter_map(|r| r.quality.rouge_l.map(|s| s as f64))
727 .collect();
728 let avg_rouge_l = if !rouge_scores.is_empty() {
729 rouge_scores.iter().sum::<f64>() / rouge_scores.len() as f64
730 } else {
731 0.0
732 };
733
734 let features = if !results.is_empty() {
735 results[0].features_enabled.clone()
736 } else {
737 Vec::new()
738 };
739
740 BenchmarkSummary {
741 config_name,
742 total_queries: total,
743 avg_latency_ms,
744 avg_retrieval_ms,
745 avg_reranking_ms,
746 avg_generation_ms,
747 total_input_tokens,
748 total_output_tokens,
749 total_cost_usd,
750 avg_tokens_per_query: (total_input_tokens + total_output_tokens) as f64 / total as f64,
751 avg_exact_match,
752 avg_f1_score,
753 avg_bleu_score,
754 avg_rouge_l,
755 features,
756 query_results: results,
757 }
758 }
759
760 pub fn print_summary(&self, summary: &BenchmarkSummary) {
762 println!("\n📊 Benchmark Results: {}", summary.config_name);
763 println!("{}", "=".repeat(60));
764
765 println!("\n🎯 Quality Metrics:");
766 println!(" Exact Match: {:.1}%", summary.avg_exact_match * 100.0);
767 println!(" F1 Score: {:.3}", summary.avg_f1_score);
768 if summary.avg_bleu_score > 0.0 {
769 println!(" BLEU Score: {:.3}", summary.avg_bleu_score);
770 }
771 if summary.avg_rouge_l > 0.0 {
772 println!(" ROUGE-L: {:.3}", summary.avg_rouge_l);
773 }
774
775 println!("\n⏱️ Latency Metrics (avg):");
776 println!(" Total: {:.1} ms", summary.avg_latency_ms);
777 println!(" Retrieval: {:.1} ms", summary.avg_retrieval_ms);
778 if summary.avg_reranking_ms > 0.0 {
779 println!(" Reranking: {:.1} ms", summary.avg_reranking_ms);
780 }
781 println!(" Generation: {:.1} ms", summary.avg_generation_ms);
782
783 println!("\n💰 Token & Cost Metrics:");
784 println!(" Input tokens: {}", summary.total_input_tokens);
785 println!(" Output tokens: {}", summary.total_output_tokens);
786 println!(" Total cost: ${:.4}", summary.total_cost_usd);
787 println!(" Avg tokens/query: {:.0}", summary.avg_tokens_per_query);
788
789 println!("\n✨ Features Enabled:");
790 for feature in &summary.features {
791 println!(" ✅ {}", feature);
792 }
793
794 println!("\n{}", "=".repeat(60));
795 }
796
797 pub fn compare_summaries(&self, baseline: &BenchmarkSummary, improved: &BenchmarkSummary) {
799 println!("\n📈 Benchmark Comparison");
800 println!("{}", "=".repeat(60));
801
802 println!("\nConfiguration:");
803 println!(" Baseline: {}", baseline.config_name);
804 println!(" Improved: {}", improved.config_name);
805
806 println!("\n🎯 Quality Improvements:");
807 let em_improvement = ((improved.avg_exact_match - baseline.avg_exact_match)
808 / baseline.avg_exact_match)
809 * 100.0;
810 let f1_improvement =
811 ((improved.avg_f1_score - baseline.avg_f1_score) / baseline.avg_f1_score) * 100.0;
812 println!(" Exact Match: {:+.1}%", em_improvement);
813 println!(" F1 Score: {:+.1}%", f1_improvement);
814
815 println!("\n💰 Cost Savings:");
816 let token_reduction = ((baseline.total_input_tokens - improved.total_input_tokens) as f64
817 / baseline.total_input_tokens as f64)
818 * 100.0;
819 let cost_savings =
820 ((baseline.total_cost_usd - improved.total_cost_usd) / baseline.total_cost_usd) * 100.0;
821 println!(
822 " Token reduction: {:.1}% ({} → {} tokens)",
823 token_reduction, baseline.total_input_tokens, improved.total_input_tokens
824 );
825 println!(
826 " Cost savings: {:.1}% (${:.4} → ${:.4})",
827 cost_savings, baseline.total_cost_usd, improved.total_cost_usd
828 );
829
830 println!("\n⏱️ Latency Changes:");
831 let latency_change =
832 ((improved.avg_latency_ms - baseline.avg_latency_ms) / baseline.avg_latency_ms) * 100.0;
833 println!(
834 " Total latency: {:+.1}% ({:.1}ms → {:.1}ms)",
835 latency_change, baseline.avg_latency_ms, improved.avg_latency_ms
836 );
837
838 println!("\n{}", "=".repeat(60));
839 }
840}
841
#[cfg(test)]
mod tests {
    use super::*;

    /// F1 is 1.0 for identical answers, strictly between 0 and 1 for a partial
    /// token overlap, and 0.0 when no token is shared.
    #[test]
    fn test_f1_score_calculation() {
        let runner = BenchmarkRunner::new(BenchmarkConfig::default());

        let identical = runner.calculate_f1_score("hello world", "hello world");
        assert!((identical - 1.0).abs() < 0.001);

        let partial = runner.calculate_f1_score("hello world", "hello there");
        assert!(partial > 0.0 && partial < 1.0);

        let disjoint = runner.calculate_f1_score("foo bar", "baz qux");
        assert_eq!(disjoint, 0.0);
    }

    /// Running a one-query dataset with the default configuration (and no
    /// callbacks) yields a summary covering exactly that query.
    #[test]
    fn test_benchmark_summary() {
        let only_query = BenchmarkQuery {
            question: "What is 2+2?".to_string(),
            answer: "4".to_string(),
            context: None,
            difficulty: None,
            query_type: None,
        };
        let dataset = BenchmarkDataset {
            name: "Test".to_string(),
            queries: vec![only_query],
        };

        let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
        let summary = runner.run_dataset(&dataset);

        assert_eq!(summary.total_queries, 1);
        assert!(summary.avg_latency_ms >= 0.0);
    }
}