1use serde::{Deserialize, Serialize};
10use std::time::Instant;
11
/// Complete measurement record for one benchmarked query: the generated
/// answer plus latency, token, and quality metrics for that run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryBenchmark {
    /// The question that was posed to the pipeline.
    pub query: String,
    /// Reference answer from the dataset, if one was provided.
    pub ground_truth: Option<String>,
    /// Answer produced by the LLM stage (or the built-in fallback).
    pub generated_answer: String,
    /// Wall-clock timing broken down by pipeline stage.
    pub latency: LatencyMetrics,
    /// Token usage and estimated cost for this query.
    pub tokens: TokenMetrics,
    /// Lexical quality scores versus `ground_truth`.
    pub quality: QualityMetrics,
    /// Names of the pipeline features that were active for this run.
    pub features_enabled: Vec<String>,
}
36
/// Per-stage wall-clock latency for a single query, in milliseconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyMetrics {
    /// End-to-end time for the whole query.
    pub total_ms: u64,
    /// Time spent in the retrieval stage.
    pub retrieval_ms: u64,
    /// Time spent reranking; `None` when the cross-encoder was disabled.
    pub reranking_ms: Option<u64>,
    /// Time spent generating the answer.
    pub generation_ms: u64,
    /// Remaining time not attributed to the stages above.
    pub other_ms: u64,
}
55
/// Token usage and estimated dollar cost for one query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenMetrics {
    /// Tokens sent to the model (prompt + context).
    pub input_tokens: usize,
    /// Tokens produced by the model.
    pub output_tokens: usize,
    /// `input_tokens + output_tokens`.
    pub total_tokens: usize,
    /// Estimated cost in USD, derived from per-1K-token prices in the config.
    pub estimated_cost_usd: f64,
}
71
/// Lexical answer-quality scores; optional fields are `None` when the
/// corresponding metric was not computed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// 1.0 on a case-insensitive exact match (after trimming), else 0.0.
    pub exact_match: f32,
    /// Token-overlap F1 between generated and reference answers.
    pub f1_score: f32,
    /// BLEU (up to 4-grams) of the generated answer vs. the reference.
    pub bleu_score: Option<f32>,
    /// ROUGE-L (LCS-based F-measure) vs. the reference.
    pub rouge_l: Option<f32>,
    /// Embedding-based similarity; left `None` by the built-in scorer.
    pub semantic_similarity: Option<f32>,
}
90
/// A named collection of evaluation queries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkDataset {
    /// Human-readable dataset name (echoed in console output).
    pub name: String,
    /// The queries to benchmark.
    pub queries: Vec<BenchmarkQuery>,
}
100
/// One evaluation item: a question, its reference answer, and optional
/// dataset-supplied metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkQuery {
    /// The question text.
    pub question: String,
    /// Ground-truth answer used for scoring.
    pub answer: String,
    /// Optional supporting passages from the dataset (not consumed by the
    /// runner itself; presumably for the retrieval callback — TODO confirm).
    pub context: Option<Vec<String>>,
    /// Optional free-form difficulty label.
    pub difficulty: Option<String>,
    /// Optional free-form query category label.
    pub query_type: Option<String>,
}
119
/// Feature flags and pricing knobs controlling a benchmark run.
///
/// Within this module, `enable_cross_encoder` gates the reranking stage and
/// `enable_lightrag` shrinks the estimated input-token count; the remaining
/// flags are recorded in `features_enabled` and presumably steer the real
/// pipeline elsewhere — verify against the callers.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// LightRAG mode: reduces the estimated input tokens (2000 -> 200).
    pub enable_lightrag: bool,
    /// Toggle the "Leiden" feature label.
    pub enable_leiden: bool,
    /// Run the cross-encoder reranking stage between retrieval and generation.
    pub enable_cross_encoder: bool,
    /// Toggle the "HippoRAG PPR" feature label.
    pub enable_hipporag: bool,
    /// Toggle the "Semantic Chunking" feature label.
    pub enable_semantic_chunking: bool,
    /// Retrieval depth; not consumed by this module directly — presumably
    /// used by the retrieval callback (TODO confirm).
    pub top_k: usize,
    /// Price in USD per 1,000 input tokens.
    pub input_token_price: f64,
    /// Price in USD per 1,000 output tokens.
    pub output_token_price: f64,
}
146
147impl Default for BenchmarkConfig {
148 fn default() -> Self {
149 Self {
150 enable_lightrag: false,
151 enable_leiden: false,
152 enable_cross_encoder: false,
153 enable_hipporag: false,
154 enable_semantic_chunking: false,
155 top_k: 10,
156 input_token_price: 0.0001, output_token_price: 0.0003, }
159 }
160}
161
/// Aggregated results over an entire dataset run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSummary {
    /// Name identifying this run (the dataset name is used by `run_dataset`).
    pub config_name: String,
    /// Number of queries that were benchmarked.
    pub total_queries: usize,
    /// Mean end-to-end latency per query (ms).
    pub avg_latency_ms: f64,
    /// Mean retrieval latency per query (ms).
    pub avg_retrieval_ms: f64,
    /// Mean reranking latency (ms); 0.0 when no reranking occurred.
    pub avg_reranking_ms: f64,
    /// Mean generation latency per query (ms).
    pub avg_generation_ms: f64,
    /// Sum of input tokens across all queries.
    pub total_input_tokens: usize,
    /// Sum of output tokens across all queries.
    pub total_output_tokens: usize,
    /// Total estimated cost in USD across all queries.
    pub total_cost_usd: f64,
    /// Mean (input + output) tokens per query.
    pub avg_tokens_per_query: f64,
    /// Mean exact-match score in [0, 1].
    pub avg_exact_match: f64,
    /// Mean token-overlap F1 score.
    pub avg_f1_score: f64,
    /// Mean BLEU over queries where it was computed; 0.0 otherwise.
    pub avg_bleu_score: f64,
    /// Mean ROUGE-L over queries where it was computed; 0.0 otherwise.
    pub avg_rouge_l: f64,
    /// Feature labels active during the run (taken from the first result).
    pub features: Vec<String>,
    /// The per-query records this summary was computed from.
    pub query_results: Vec<QueryBenchmark>,
}
206
/// Drives a benchmark run: holds the configuration plus pluggable
/// retrieval / reranking / generation callbacks. Any callback left as
/// `None` falls back to trivial behavior at query time.
pub struct BenchmarkRunner {
    config: BenchmarkConfig,
    /// Maps a query string to retrieved documents.
    retrieval_fn: Option<Box<dyn Fn(&str) -> Vec<String> + Send + Sync>>,
    /// Reorders retrieved documents (cross-encoder stage).
    reranker_fn: Option<Box<dyn Fn(&[String]) -> Vec<String> + Send + Sync>>,
    /// Produces an answer from (question, documents).
    llm_fn: Option<Box<dyn Fn(&str, &[String]) -> String + Send + Sync>>,
}
217
218impl BenchmarkRunner {
219 pub fn new(config: BenchmarkConfig) -> Self {
221 Self {
222 config,
223 retrieval_fn: None,
224 reranker_fn: None,
225 llm_fn: None,
226 }
227 }
228
229 pub fn with_retrieval<F>(&mut self, f: F) -> &mut Self
241 where
242 F: Fn(&str) -> Vec<String> + Send + Sync + 'static,
243 {
244 self.retrieval_fn = Some(Box::new(f));
245 self
246 }
247
248 pub fn with_reranker<F>(&mut self, f: F) -> &mut Self
260 where
261 F: Fn(&[String]) -> Vec<String> + Send + Sync + 'static,
262 {
263 self.reranker_fn = Some(Box::new(f));
264 self
265 }
266
267 pub fn with_llm<F>(&mut self, f: F) -> &mut Self
279 where
280 F: Fn(&str, &[String]) -> String + Send + Sync + 'static,
281 {
282 self.llm_fn = Some(Box::new(f));
283 self
284 }
285
286 pub fn run_dataset(&mut self, dataset: &BenchmarkDataset) -> BenchmarkSummary {
288 println!("📊 Running benchmark on dataset: {}", dataset.name);
289 println!("📋 Queries: {}", dataset.queries.len());
290
291 let mut results = Vec::new();
292
293 for (i, query) in dataset.queries.iter().enumerate() {
294 println!(
295 " [{}/{}] Processing: {}...",
296 i + 1,
297 dataset.queries.len(),
298 &query.question.chars().take(50).collect::<String>()
299 );
300
301 let result = self.benchmark_query(query);
302 results.push(result);
303 }
304
305 self.compute_summary(dataset.name.clone(), results)
306 }
307
308 fn benchmark_query(&self, query: &BenchmarkQuery) -> QueryBenchmark {
310 let start = Instant::now();
311
312 let retrieval_start = Instant::now();
314 let retrieved_docs = if let Some(ref retrieval_fn) = self.retrieval_fn {
315 retrieval_fn(&query.question)
317 } else {
318 vec![]
320 };
321 let retrieval_time = retrieval_start.elapsed();
322
323 let (reranked_docs, reranking_time) = if self.config.enable_cross_encoder {
325 let reranking_start = Instant::now();
326 let reranked = if let Some(ref reranker_fn) = self.reranker_fn {
327 reranker_fn(&retrieved_docs)
329 } else {
330 retrieved_docs.clone()
332 };
333 (reranked, Some(reranking_start.elapsed()))
334 } else {
335 (retrieved_docs.clone(), None)
336 };
337
338 let generation_start = Instant::now();
340 let generated_answer = if let Some(ref llm_fn) = self.llm_fn {
341 llm_fn(&query.question, &reranked_docs)
343 } else {
344 format!("Generated answer for: {}", query.question)
346 };
347 let generation_time = generation_start.elapsed();
348
349 let total_time = start.elapsed();
350
351 let estimated_input_tokens = if self.config.enable_lightrag {
353 200 } else {
355 2000 };
357
358 let estimated_output_tokens = 100;
359
360 let tokens = TokenMetrics {
361 input_tokens: estimated_input_tokens,
362 output_tokens: estimated_output_tokens,
363 total_tokens: estimated_input_tokens + estimated_output_tokens,
364 estimated_cost_usd: (estimated_input_tokens as f64 / 1000.0
365 * self.config.input_token_price)
366 + (estimated_output_tokens as f64 / 1000.0 * self.config.output_token_price),
367 };
368
369 let quality = self.calculate_quality_metrics(&generated_answer, &query.answer);
371
372 let mut features = Vec::new();
374 if self.config.enable_lightrag {
375 features.push("LightRAG".to_string());
376 }
377 if self.config.enable_leiden {
378 features.push("Leiden".to_string());
379 }
380 if self.config.enable_cross_encoder {
381 features.push("Cross-Encoder".to_string());
382 }
383 if self.config.enable_hipporag {
384 features.push("HippoRAG PPR".to_string());
385 }
386 if self.config.enable_semantic_chunking {
387 features.push("Semantic Chunking".to_string());
388 }
389
390 QueryBenchmark {
391 query: query.question.clone(),
392 ground_truth: Some(query.answer.clone()),
393 generated_answer,
394 latency: LatencyMetrics {
395 total_ms: total_time.as_millis() as u64,
396 retrieval_ms: retrieval_time.as_millis() as u64,
397 reranking_ms: reranking_time.map(|d| d.as_millis() as u64),
398 generation_ms: generation_time.as_millis() as u64,
399 other_ms: 0,
400 },
401 tokens,
402 quality,
403 features_enabled: features,
404 }
405 }
406
407 fn calculate_quality_metrics(&self, generated: &str, ground_truth: &str) -> QualityMetrics {
409 let exact_match = if generated.trim().eq_ignore_ascii_case(ground_truth.trim()) {
411 1.0
412 } else {
413 0.0
414 };
415
416 let f1_score = self.calculate_f1_score(generated, ground_truth);
418
419 let bleu_score = Some(self.calculate_bleu_score(generated, ground_truth));
421
422 let rouge_l = Some(self.calculate_rouge_l(generated, ground_truth));
424
425 QualityMetrics {
426 exact_match,
427 f1_score,
428 bleu_score,
429 rouge_l,
430 semantic_similarity: None,
431 }
432 }
433
434 fn calculate_f1_score(&self, generated: &str, ground_truth: &str) -> f32 {
436 let gen_tokens: Vec<String> = generated
437 .to_lowercase()
438 .split_whitespace()
439 .map(|s| s.to_string())
440 .collect();
441
442 let gt_tokens: Vec<String> = ground_truth
443 .to_lowercase()
444 .split_whitespace()
445 .map(|s| s.to_string())
446 .collect();
447
448 if gen_tokens.is_empty() || gt_tokens.is_empty() {
449 return 0.0;
450 }
451
452 let mut common = 0;
454 for token in &gen_tokens {
455 if gt_tokens.contains(token) {
456 common += 1;
457 }
458 }
459
460 if common == 0 {
461 return 0.0;
462 }
463
464 let precision = common as f32 / gen_tokens.len() as f32;
465 let recall = common as f32 / gt_tokens.len() as f32;
466
467 2.0 * (precision * recall) / (precision + recall)
468 }
469
470 fn calculate_bleu_score(&self, candidate: &str, reference: &str) -> f32 {
478 let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
480 let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
481
482 if candidate_tokens.is_empty() || reference_tokens.is_empty() {
483 return 0.0;
484 }
485
486 let max_n = 4;
488 let mut log_precision_sum = 0.0;
489 let mut valid_n_grams = 0;
490
491 for n in 1..=max_n {
492 let precision = self.calculate_ngram_precision(&candidate_tokens, &reference_tokens, n);
493
494 if precision > 0.0 {
495 log_precision_sum += precision.ln();
496 valid_n_grams += 1;
497 } else {
498 return 0.0;
500 }
501 }
502
503 let candidate_len = candidate_tokens.len() as f32;
505 let reference_len = reference_tokens.len() as f32;
506
507 let brevity_penalty = if candidate_len >= reference_len {
508 1.0
509 } else {
510 (1.0 - reference_len / candidate_len).exp()
511 };
512
513 let bleu = brevity_penalty * (log_precision_sum / valid_n_grams as f32).exp();
515
516 bleu.max(0.0).min(1.0)
518 }
519
520 fn calculate_ngram_precision(&self, candidate: &[&str], reference: &[&str], n: usize) -> f32 {
522 if candidate.len() < n || reference.len() < n {
523 return 0.0;
524 }
525
526 let candidate_ngrams = self.extract_ngrams(candidate, n);
528
529 let reference_ngrams = self.extract_ngrams(reference, n);
531 let mut reference_counts = std::collections::HashMap::new();
532 for ngram in &reference_ngrams {
533 *reference_counts.entry(ngram).or_insert(0) += 1;
534 }
535
536 let mut clipped_matches = 0;
538 let mut candidate_counts = std::collections::HashMap::new();
539
540 for ngram in &candidate_ngrams {
541 let candidate_count = candidate_counts.entry(ngram).or_insert(0);
542 *candidate_count += 1;
543
544 if let Some(&ref_count) = reference_counts.get(&ngram) {
545 if *candidate_count <= ref_count {
546 clipped_matches += 1;
547 }
548 }
549 }
550
551 if candidate_ngrams.is_empty() {
553 0.0
554 } else {
555 clipped_matches as f32 / candidate_ngrams.len() as f32
556 }
557 }
558
559 fn extract_ngrams(&self, tokens: &[&str], n: usize) -> Vec<Vec<String>> {
561 if tokens.len() < n {
562 return Vec::new();
563 }
564
565 tokens
566 .windows(n)
567 .map(|window| window.iter().map(|&s| s.to_string()).collect())
568 .collect()
569 }
570
571 fn calculate_rouge_l(&self, candidate: &str, reference: &str) -> f32 {
579 let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
581 let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
582
583 if candidate_tokens.is_empty() || reference_tokens.is_empty() {
584 return 0.0;
585 }
586
587 let lcs_length = self.lcs_length(&candidate_tokens, &reference_tokens);
589
590 if lcs_length == 0 {
591 return 0.0;
592 }
593
594 let precision = lcs_length as f32 / candidate_tokens.len() as f32;
596 let recall = lcs_length as f32 / reference_tokens.len() as f32;
597
598 let beta = 1.2;
600 let beta_squared = beta * beta;
601
602 let f_score =
603 ((1.0 + beta_squared) * precision * recall) / (beta_squared * precision + recall);
604
605 f_score.max(0.0).min(1.0)
607 }
608
609 fn lcs_length(&self, seq1: &[&str], seq2: &[&str]) -> usize {
616 let m = seq1.len();
617 let n = seq2.len();
618
619 if m == 0 || n == 0 {
620 return 0;
621 }
622
623 let mut dp = vec![vec![0; n + 1]; m + 1];
625
626 for i in 1..=m {
628 for j in 1..=n {
629 if seq1[i - 1] == seq2[j - 1] {
630 dp[i][j] = dp[i - 1][j - 1] + 1;
632 } else {
633 dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
635 }
636 }
637 }
638
639 dp[m][n]
640 }
641
642 fn compute_summary(
644 &self,
645 config_name: String,
646 results: Vec<QueryBenchmark>,
647 ) -> BenchmarkSummary {
648 let total = results.len();
649
650 if total == 0 {
651 return BenchmarkSummary {
652 config_name,
653 total_queries: 0,
654 avg_latency_ms: 0.0,
655 avg_retrieval_ms: 0.0,
656 avg_reranking_ms: 0.0,
657 avg_generation_ms: 0.0,
658 total_input_tokens: 0,
659 total_output_tokens: 0,
660 total_cost_usd: 0.0,
661 avg_tokens_per_query: 0.0,
662 avg_exact_match: 0.0,
663 avg_f1_score: 0.0,
664 avg_bleu_score: 0.0,
665 avg_rouge_l: 0.0,
666 features: Vec::new(),
667 query_results: results,
668 };
669 }
670
671 let avg_latency_ms = results
672 .iter()
673 .map(|r| r.latency.total_ms as f64)
674 .sum::<f64>()
675 / total as f64;
676 let avg_retrieval_ms = results
677 .iter()
678 .map(|r| r.latency.retrieval_ms as f64)
679 .sum::<f64>()
680 / total as f64;
681 let avg_reranking_ms = results
682 .iter()
683 .filter_map(|r| r.latency.reranking_ms)
684 .map(|ms| ms as f64)
685 .sum::<f64>()
686 / total as f64;
687 let avg_generation_ms = results
688 .iter()
689 .map(|r| r.latency.generation_ms as f64)
690 .sum::<f64>()
691 / total as f64;
692
693 let total_input_tokens: usize = results.iter().map(|r| r.tokens.input_tokens).sum();
694 let total_output_tokens: usize = results.iter().map(|r| r.tokens.output_tokens).sum();
695 let total_cost_usd: f64 = results.iter().map(|r| r.tokens.estimated_cost_usd).sum();
696
697 let avg_exact_match = results
698 .iter()
699 .map(|r| r.quality.exact_match as f64)
700 .sum::<f64>()
701 / total as f64;
702 let avg_f1_score = results
703 .iter()
704 .map(|r| r.quality.f1_score as f64)
705 .sum::<f64>()
706 / total as f64;
707
708 let bleu_scores: Vec<f64> = results
710 .iter()
711 .filter_map(|r| r.quality.bleu_score.map(|s| s as f64))
712 .collect();
713 let avg_bleu_score = if !bleu_scores.is_empty() {
714 bleu_scores.iter().sum::<f64>() / bleu_scores.len() as f64
715 } else {
716 0.0
717 };
718
719 let rouge_scores: Vec<f64> = results
721 .iter()
722 .filter_map(|r| r.quality.rouge_l.map(|s| s as f64))
723 .collect();
724 let avg_rouge_l = if !rouge_scores.is_empty() {
725 rouge_scores.iter().sum::<f64>() / rouge_scores.len() as f64
726 } else {
727 0.0
728 };
729
730 let features = if !results.is_empty() {
731 results[0].features_enabled.clone()
732 } else {
733 Vec::new()
734 };
735
736 BenchmarkSummary {
737 config_name,
738 total_queries: total,
739 avg_latency_ms,
740 avg_retrieval_ms,
741 avg_reranking_ms,
742 avg_generation_ms,
743 total_input_tokens,
744 total_output_tokens,
745 total_cost_usd,
746 avg_tokens_per_query: (total_input_tokens + total_output_tokens) as f64 / total as f64,
747 avg_exact_match,
748 avg_f1_score,
749 avg_bleu_score,
750 avg_rouge_l,
751 features,
752 query_results: results,
753 }
754 }
755
756 pub fn print_summary(&self, summary: &BenchmarkSummary) {
758 println!("\n📊 Benchmark Results: {}", summary.config_name);
759 println!("{}", "=".repeat(60));
760
761 println!("\n🎯 Quality Metrics:");
762 println!(" Exact Match: {:.1}%", summary.avg_exact_match * 100.0);
763 println!(" F1 Score: {:.3}", summary.avg_f1_score);
764 if summary.avg_bleu_score > 0.0 {
765 println!(" BLEU Score: {:.3}", summary.avg_bleu_score);
766 }
767 if summary.avg_rouge_l > 0.0 {
768 println!(" ROUGE-L: {:.3}", summary.avg_rouge_l);
769 }
770
771 println!("\n⏱️ Latency Metrics (avg):");
772 println!(" Total: {:.1} ms", summary.avg_latency_ms);
773 println!(" Retrieval: {:.1} ms", summary.avg_retrieval_ms);
774 if summary.avg_reranking_ms > 0.0 {
775 println!(" Reranking: {:.1} ms", summary.avg_reranking_ms);
776 }
777 println!(" Generation: {:.1} ms", summary.avg_generation_ms);
778
779 println!("\n💰 Token & Cost Metrics:");
780 println!(" Input tokens: {}", summary.total_input_tokens);
781 println!(" Output tokens: {}", summary.total_output_tokens);
782 println!(" Total cost: ${:.4}", summary.total_cost_usd);
783 println!(" Avg tokens/query: {:.0}", summary.avg_tokens_per_query);
784
785 println!("\n✨ Features Enabled:");
786 for feature in &summary.features {
787 println!(" ✅ {}", feature);
788 }
789
790 println!("\n{}", "=".repeat(60));
791 }
792
793 pub fn compare_summaries(&self, baseline: &BenchmarkSummary, improved: &BenchmarkSummary) {
795 println!("\n📈 Benchmark Comparison");
796 println!("{}", "=".repeat(60));
797
798 println!("\nConfiguration:");
799 println!(" Baseline: {}", baseline.config_name);
800 println!(" Improved: {}", improved.config_name);
801
802 println!("\n🎯 Quality Improvements:");
803 let em_improvement = ((improved.avg_exact_match - baseline.avg_exact_match)
804 / baseline.avg_exact_match)
805 * 100.0;
806 let f1_improvement =
807 ((improved.avg_f1_score - baseline.avg_f1_score) / baseline.avg_f1_score) * 100.0;
808 println!(" Exact Match: {:+.1}%", em_improvement);
809 println!(" F1 Score: {:+.1}%", f1_improvement);
810
811 println!("\n💰 Cost Savings:");
812 let token_reduction = ((baseline.total_input_tokens - improved.total_input_tokens) as f64
813 / baseline.total_input_tokens as f64)
814 * 100.0;
815 let cost_savings =
816 ((baseline.total_cost_usd - improved.total_cost_usd) / baseline.total_cost_usd) * 100.0;
817 println!(
818 " Token reduction: {:.1}% ({} → {} tokens)",
819 token_reduction, baseline.total_input_tokens, improved.total_input_tokens
820 );
821 println!(
822 " Cost savings: {:.1}% (${:.4} → ${:.4})",
823 cost_savings, baseline.total_cost_usd, improved.total_cost_usd
824 );
825
826 println!("\n⏱️ Latency Changes:");
827 let latency_change =
828 ((improved.avg_latency_ms - baseline.avg_latency_ms) / baseline.avg_latency_ms) * 100.0;
829 println!(
830 " Total latency: {:+.1}% ({:.1}ms → {:.1}ms)",
831 latency_change, baseline.avg_latency_ms, improved.avg_latency_ms
832 );
833
834 println!("\n{}", "=".repeat(60));
835 }
836}
837
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f1_score_calculation() {
        let runner = BenchmarkRunner::new(BenchmarkConfig::default());

        // Identical strings -> perfect overlap.
        let perfect = runner.calculate_f1_score("hello world", "hello world");
        assert!((perfect - 1.0).abs() < 0.001);

        // Partial overlap -> strictly between 0 and 1.
        let partial = runner.calculate_f1_score("hello world", "hello there");
        assert!(partial > 0.0 && partial < 1.0);

        // Disjoint token sets -> zero.
        let disjoint = runner.calculate_f1_score("foo bar", "baz qux");
        assert_eq!(disjoint, 0.0);
    }

    #[test]
    fn test_benchmark_summary() {
        let queries = vec![BenchmarkQuery {
            question: "What is 2+2?".to_string(),
            answer: "4".to_string(),
            context: None,
            difficulty: None,
            query_type: None,
        }];
        let dataset = BenchmarkDataset {
            name: "Test".to_string(),
            queries,
        };

        let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
        let summary = runner.run_dataset(&dataset);

        assert_eq!(summary.total_queries, 1);
        assert!(summary.avg_latency_ms >= 0.0);
    }
}