// graphrag_core/monitoring/benchmark.rs
//! Benchmarking system for GraphRAG quality improvements
//!
//! This module provides comprehensive benchmarking tools to measure:
//! - Accuracy improvements from new features
//! - Token usage and cost reduction
//! - Latency and throughput
//! - Quality metrics (F1, Exact Match, BLEU)

use std::collections::HashSet;
use std::time::Instant;

use serde::{Deserialize, Serialize};

12/// Benchmark results for a single query
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct QueryBenchmark {
15    /// The query text
16    pub query: String,
17
18    /// Ground truth answer (if available)
19    pub ground_truth: Option<String>,
20
21    /// Generated answer
22    pub generated_answer: String,
23
24    /// Latency measurements
25    pub latency: LatencyMetrics,
26
27    /// Token usage
28    pub tokens: TokenMetrics,
29
30    /// Quality scores
31    pub quality: QualityMetrics,
32
33    /// Feature flags used
34    pub features_enabled: Vec<String>,
35}
36
37/// Latency breakdown by pipeline stage
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct LatencyMetrics {
40    /// Total end-to-end latency
41    pub total_ms: u64,
42
43    /// Retrieval latency
44    pub retrieval_ms: u64,
45
46    /// Reranking latency (if enabled)
47    pub reranking_ms: Option<u64>,
48
49    /// Generation latency
50    pub generation_ms: u64,
51
52    /// Other processing time
53    pub other_ms: u64,
54}
55
56/// Token usage tracking
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TokenMetrics {
59    /// Input tokens to LLM
60    pub input_tokens: usize,
61
62    /// Output tokens from LLM
63    pub output_tokens: usize,
64
65    /// Total tokens
66    pub total_tokens: usize,
67
68    /// Estimated cost (USD)
69    pub estimated_cost_usd: f64,
70}
71
72/// Quality metrics for answer evaluation
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct QualityMetrics {
75    /// Exact match with ground truth (0.0 or 1.0)
76    pub exact_match: f32,
77
78    /// F1 score (token overlap)
79    pub f1_score: f32,
80
81    /// BLEU score (n-gram similarity)
82    pub bleu_score: Option<f32>,
83
84    /// ROUGE-L score (longest common subsequence)
85    pub rouge_l: Option<f32>,
86
87    /// Semantic similarity (if embeddings available)
88    pub semantic_similarity: Option<f32>,
89}
90
91/// Dataset for benchmarking
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct BenchmarkDataset {
94    /// Dataset name (e.g., "HotpotQA", "MuSiQue")
95    pub name: String,
96
97    /// List of queries with ground truth
98    pub queries: Vec<BenchmarkQuery>,
99}
100
101/// A single query with ground truth for evaluation
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct BenchmarkQuery {
104    /// Question text
105    pub question: String,
106
107    /// Ground truth answer
108    pub answer: String,
109
110    /// Supporting documents (if applicable)
111    pub context: Option<Vec<String>>,
112
113    /// Query difficulty (easy, medium, hard)
114    pub difficulty: Option<String>,
115
116    /// Query type (factual, multi-hop, reasoning)
117    pub query_type: Option<String>,
118}
119
/// Configuration for benchmark runs: feature toggles, retrieval depth,
/// and LLM pricing used for cost estimation.
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Enable LightRAG dual-level retrieval
    pub enable_lightrag: bool,

    /// Enable Leiden community detection
    pub enable_leiden: bool,

    /// Enable cross-encoder reranking
    pub enable_cross_encoder: bool,

    /// Enable HippoRAG PPR
    pub enable_hipporag: bool,

    /// Enable semantic chunking
    pub enable_semantic_chunking: bool,

    /// Number of retrieval candidates
    pub top_k: usize,

    /// Input token price (USD per 1K tokens)
    pub input_token_price: f64,
    /// Output token price (USD per 1K tokens)
    pub output_token_price: f64,
}

147impl Default for BenchmarkConfig {
148    fn default() -> Self {
149        Self {
150            enable_lightrag: false,
151            enable_leiden: false,
152            enable_cross_encoder: false,
153            enable_hipporag: false,
154            enable_semantic_chunking: false,
155            top_k: 10,
156            input_token_price: 0.0001,   // Example: $0.10 per 1M tokens
157            output_token_price: 0.0003,  // Example: $0.30 per 1M tokens
158        }
159    }
160}
161
162/// Aggregate benchmark results across multiple queries
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct BenchmarkSummary {
165    /// Configuration used
166    pub config_name: String,
167
168    /// Number of queries evaluated
169    pub total_queries: usize,
170
171    /// Average metrics
172    pub avg_latency_ms: f64,
173    /// Average retrieval latency in milliseconds
174    pub avg_retrieval_ms: f64,
175    /// Average reranking latency in milliseconds
176    pub avg_reranking_ms: f64,
177    /// Average generation latency in milliseconds
178    pub avg_generation_ms: f64,
179
180    /// Token statistics
181    /// Total input tokens across all queries
182    pub total_input_tokens: usize,
183    /// Total output tokens across all queries
184    pub total_output_tokens: usize,
185    /// Total cost in USD
186    pub total_cost_usd: f64,
187    /// Average tokens per query
188    pub avg_tokens_per_query: f64,
189
190    /// Quality statistics
191    /// Average exact match score
192    pub avg_exact_match: f64,
193    /// Average F1 score
194    pub avg_f1_score: f64,
195    /// Average BLEU score
196    pub avg_bleu_score: f64,
197    /// Average ROUGE-L score
198    pub avg_rouge_l: f64,
199
200    /// Features enabled
201    pub features: Vec<String>,
202
203    /// Per-query results
204    pub query_results: Vec<QueryBenchmark>,
205}
206
207/// Main benchmarking coordinator
208pub struct BenchmarkRunner {
209    config: BenchmarkConfig,
210}
211
212impl BenchmarkRunner {
213    /// Create a new benchmark runner
214    pub fn new(config: BenchmarkConfig) -> Self {
215        Self {
216            config,
217        }
218    }
219
220    /// Run benchmark on a dataset
221    pub fn run_dataset(&mut self, dataset: &BenchmarkDataset) -> BenchmarkSummary {
222        println!("šŸ“Š Running benchmark on dataset: {}", dataset.name);
223        println!("šŸ“‹ Queries: {}", dataset.queries.len());
224
225        let mut results = Vec::new();
226
227        for (i, query) in dataset.queries.iter().enumerate() {
228            println!("  [{}/{}] Processing: {}...", i + 1, dataset.queries.len(),
229                &query.question.chars().take(50).collect::<String>());
230
231            let result = self.benchmark_query(query);
232            results.push(result);
233        }
234
235        self.compute_summary(dataset.name.clone(), results)
236    }
237
238    /// Benchmark a single query
239    fn benchmark_query(&self, query: &BenchmarkQuery) -> QueryBenchmark {
240        let start = Instant::now();
241
242        // Simulate retrieval
243        let retrieval_start = Instant::now();
244        // TODO: Call actual retrieval system
245        let retrieval_time = retrieval_start.elapsed();
246
247        // Simulate reranking (if enabled)
248        let reranking_time = if self.config.enable_cross_encoder {
249            let reranking_start = Instant::now();
250            // TODO: Call cross-encoder reranking
251            Some(reranking_start.elapsed())
252        } else {
253            None
254        };
255
256        // Simulate generation
257        let generation_start = Instant::now();
258        // TODO: Call actual LLM generation
259        let generated_answer = format!("Generated answer for: {}", query.question);
260        let generation_time = generation_start.elapsed();
261
262        let total_time = start.elapsed();
263
264        // Calculate token usage (estimated)
265        let estimated_input_tokens = if self.config.enable_lightrag {
266            200 // LightRAG optimization: much lower
267        } else {
268            2000 // Traditional GraphRAG: ~10x more
269        };
270
271        let estimated_output_tokens = 100;
272
273        let tokens = TokenMetrics {
274            input_tokens: estimated_input_tokens,
275            output_tokens: estimated_output_tokens,
276            total_tokens: estimated_input_tokens + estimated_output_tokens,
277            estimated_cost_usd: (estimated_input_tokens as f64 / 1000.0 * self.config.input_token_price)
278                + (estimated_output_tokens as f64 / 1000.0 * self.config.output_token_price),
279        };
280
281        // Calculate quality metrics
282        let quality = self.calculate_quality_metrics(&generated_answer, &query.answer);
283
284        // Collect enabled features
285        let mut features = Vec::new();
286        if self.config.enable_lightrag {
287            features.push("LightRAG".to_string());
288        }
289        if self.config.enable_leiden {
290            features.push("Leiden".to_string());
291        }
292        if self.config.enable_cross_encoder {
293            features.push("Cross-Encoder".to_string());
294        }
295        if self.config.enable_hipporag {
296            features.push("HippoRAG PPR".to_string());
297        }
298        if self.config.enable_semantic_chunking {
299            features.push("Semantic Chunking".to_string());
300        }
301
302        QueryBenchmark {
303            query: query.question.clone(),
304            ground_truth: Some(query.answer.clone()),
305            generated_answer,
306            latency: LatencyMetrics {
307                total_ms: total_time.as_millis() as u64,
308                retrieval_ms: retrieval_time.as_millis() as u64,
309                reranking_ms: reranking_time.map(|d| d.as_millis() as u64),
310                generation_ms: generation_time.as_millis() as u64,
311                other_ms: 0,
312            },
313            tokens,
314            quality,
315            features_enabled: features,
316        }
317    }
318
319    /// Calculate quality metrics
320    fn calculate_quality_metrics(&self, generated: &str, ground_truth: &str) -> QualityMetrics {
321        // Exact match
322        let exact_match = if generated.trim().eq_ignore_ascii_case(ground_truth.trim()) {
323            1.0
324        } else {
325            0.0
326        };
327
328        // F1 score (token overlap)
329        let f1_score = self.calculate_f1_score(generated, ground_truth);
330
331        QualityMetrics {
332            exact_match,
333            f1_score,
334            bleu_score: None,  // TODO: Implement BLEU
335            rouge_l: None,     // TODO: Implement ROUGE-L
336            semantic_similarity: None,
337        }
338    }
339
340    /// Calculate F1 score based on token overlap
341    fn calculate_f1_score(&self, generated: &str, ground_truth: &str) -> f32 {
342        let gen_tokens: Vec<String> = generated
343            .to_lowercase()
344            .split_whitespace()
345            .map(|s| s.to_string())
346            .collect();
347
348        let gt_tokens: Vec<String> = ground_truth
349            .to_lowercase()
350            .split_whitespace()
351            .map(|s| s.to_string())
352            .collect();
353
354        if gen_tokens.is_empty() || gt_tokens.is_empty() {
355            return 0.0;
356        }
357
358        // Calculate overlap
359        let mut common = 0;
360        for token in &gen_tokens {
361            if gt_tokens.contains(token) {
362                common += 1;
363            }
364        }
365
366        if common == 0 {
367            return 0.0;
368        }
369
370        let precision = common as f32 / gen_tokens.len() as f32;
371        let recall = common as f32 / gt_tokens.len() as f32;
372
373        2.0 * (precision * recall) / (precision + recall)
374    }
375
376    /// Compute aggregate summary
377    fn compute_summary(&self, config_name: String, results: Vec<QueryBenchmark>) -> BenchmarkSummary {
378        let total = results.len();
379
380        if total == 0 {
381            return BenchmarkSummary {
382                config_name,
383                total_queries: 0,
384                avg_latency_ms: 0.0,
385                avg_retrieval_ms: 0.0,
386                avg_reranking_ms: 0.0,
387                avg_generation_ms: 0.0,
388                total_input_tokens: 0,
389                total_output_tokens: 0,
390                total_cost_usd: 0.0,
391                avg_tokens_per_query: 0.0,
392                avg_exact_match: 0.0,
393                avg_f1_score: 0.0,
394                avg_bleu_score: 0.0,
395                avg_rouge_l: 0.0,
396                features: Vec::new(),
397                query_results: results,
398            };
399        }
400
401        let avg_latency_ms = results.iter().map(|r| r.latency.total_ms as f64).sum::<f64>() / total as f64;
402        let avg_retrieval_ms = results.iter().map(|r| r.latency.retrieval_ms as f64).sum::<f64>() / total as f64;
403        let avg_reranking_ms = results.iter()
404            .filter_map(|r| r.latency.reranking_ms)
405            .map(|ms| ms as f64)
406            .sum::<f64>() / total as f64;
407        let avg_generation_ms = results.iter().map(|r| r.latency.generation_ms as f64).sum::<f64>() / total as f64;
408
409        let total_input_tokens: usize = results.iter().map(|r| r.tokens.input_tokens).sum();
410        let total_output_tokens: usize = results.iter().map(|r| r.tokens.output_tokens).sum();
411        let total_cost_usd: f64 = results.iter().map(|r| r.tokens.estimated_cost_usd).sum();
412
413        let avg_exact_match = results.iter().map(|r| r.quality.exact_match as f64).sum::<f64>() / total as f64;
414        let avg_f1_score = results.iter().map(|r| r.quality.f1_score as f64).sum::<f64>() / total as f64;
415
416        let features = if !results.is_empty() {
417            results[0].features_enabled.clone()
418        } else {
419            Vec::new()
420        };
421
422        BenchmarkSummary {
423            config_name,
424            total_queries: total,
425            avg_latency_ms,
426            avg_retrieval_ms,
427            avg_reranking_ms,
428            avg_generation_ms,
429            total_input_tokens,
430            total_output_tokens,
431            total_cost_usd,
432            avg_tokens_per_query: (total_input_tokens + total_output_tokens) as f64 / total as f64,
433            avg_exact_match,
434            avg_f1_score,
435            avg_bleu_score: 0.0,  // TODO
436            avg_rouge_l: 0.0,     // TODO
437            features,
438            query_results: results,
439        }
440    }
441
442    /// Print summary results
443    pub fn print_summary(&self, summary: &BenchmarkSummary) {
444        println!("\nšŸ“Š Benchmark Results: {}", summary.config_name);
445        println!("{}", "=".repeat(60));
446
447        println!("\nšŸŽÆ Quality Metrics:");
448        println!("  Exact Match:  {:.1}%", summary.avg_exact_match * 100.0);
449        println!("  F1 Score:     {:.3}", summary.avg_f1_score);
450
451        println!("\nā±ļø  Latency Metrics (avg):");
452        println!("  Total:        {:.1} ms", summary.avg_latency_ms);
453        println!("  Retrieval:    {:.1} ms", summary.avg_retrieval_ms);
454        if summary.avg_reranking_ms > 0.0 {
455            println!("  Reranking:    {:.1} ms", summary.avg_reranking_ms);
456        }
457        println!("  Generation:   {:.1} ms", summary.avg_generation_ms);
458
459        println!("\nšŸ’° Token & Cost Metrics:");
460        println!("  Input tokens:  {}", summary.total_input_tokens);
461        println!("  Output tokens: {}", summary.total_output_tokens);
462        println!("  Total cost:    ${:.4}", summary.total_cost_usd);
463        println!("  Avg tokens/query: {:.0}", summary.avg_tokens_per_query);
464
465        println!("\n✨ Features Enabled:");
466        for feature in &summary.features {
467            println!("  āœ… {}", feature);
468        }
469
470        println!("\n{}", "=".repeat(60));
471    }
472
473    /// Compare two benchmark summaries
474    pub fn compare_summaries(&self, baseline: &BenchmarkSummary, improved: &BenchmarkSummary) {
475        println!("\nšŸ“ˆ Benchmark Comparison");
476        println!("{}", "=".repeat(60));
477
478        println!("\nConfiguration:");
479        println!("  Baseline: {}", baseline.config_name);
480        println!("  Improved: {}", improved.config_name);
481
482        println!("\nšŸŽÆ Quality Improvements:");
483        let em_improvement = ((improved.avg_exact_match - baseline.avg_exact_match) / baseline.avg_exact_match) * 100.0;
484        let f1_improvement = ((improved.avg_f1_score - baseline.avg_f1_score) / baseline.avg_f1_score) * 100.0;
485        println!("  Exact Match:  {:+.1}%", em_improvement);
486        println!("  F1 Score:     {:+.1}%", f1_improvement);
487
488        println!("\nšŸ’° Cost Savings:");
489        let token_reduction = ((baseline.total_input_tokens - improved.total_input_tokens) as f64 / baseline.total_input_tokens as f64) * 100.0;
490        let cost_savings = ((baseline.total_cost_usd - improved.total_cost_usd) / baseline.total_cost_usd) * 100.0;
491        println!("  Token reduction: {:.1}% ({} → {} tokens)",
492            token_reduction,
493            baseline.total_input_tokens,
494            improved.total_input_tokens
495        );
496        println!("  Cost savings:    {:.1}% (${:.4} → ${:.4})",
497            cost_savings,
498            baseline.total_cost_usd,
499            improved.total_cost_usd
500        );
501
502        println!("\nā±ļø  Latency Changes:");
503        let latency_change = ((improved.avg_latency_ms - baseline.avg_latency_ms) / baseline.avg_latency_ms) * 100.0;
504        println!("  Total latency: {:+.1}% ({:.1}ms → {:.1}ms)",
505            latency_change,
506            baseline.avg_latency_ms,
507            improved.avg_latency_ms
508        );
509
510        println!("\n{}", "=".repeat(60));
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f1_score_calculation() {
        // No underscore prefix: this binding IS used below.
        let runner = BenchmarkRunner::new(BenchmarkConfig::default());

        // Perfect match
        let f1 = runner.calculate_f1_score("hello world", "hello world");
        assert!((f1 - 1.0).abs() < 0.001);

        // Partial overlap
        let f1 = runner.calculate_f1_score("hello world", "hello there");
        assert!(f1 > 0.0 && f1 < 1.0);

        // No overlap
        let f1 = runner.calculate_f1_score("foo bar", "baz qux");
        assert_eq!(f1, 0.0);

        // Empty input on either side scores zero.
        assert_eq!(runner.calculate_f1_score("", "hello"), 0.0);
        assert_eq!(runner.calculate_f1_score("hello", ""), 0.0);
    }

    #[test]
    fn test_benchmark_summary() {
        let dataset = BenchmarkDataset {
            name: "Test".to_string(),
            queries: vec![
                BenchmarkQuery {
                    question: "What is 2+2?".to_string(),
                    answer: "4".to_string(),
                    context: None,
                    difficulty: None,
                    query_type: None,
                },
            ],
        };

        let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
        let summary = runner.run_dataset(&dataset);

        assert_eq!(summary.total_queries, 1);
        assert!(summary.avg_latency_ms >= 0.0);
    }
}