Skip to main content

graphrag_core/monitoring/
benchmark.rs

1//! Benchmarking system for GraphRAG quality improvements
2//!
3//! This module provides comprehensive benchmarking tools to measure:
4//! - Accuracy improvements from new features
5//! - Token usage and cost reduction
6//! - Latency and throughput
7//! - Quality metrics (F1, Exact Match, BLEU)
8
9use serde::{Deserialize, Serialize};
10use std::time::Instant;
11
12/// Benchmark results for a single query
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct QueryBenchmark {
15    /// The query text
16    pub query: String,
17
18    /// Ground truth answer (if available)
19    pub ground_truth: Option<String>,
20
21    /// Generated answer
22    pub generated_answer: String,
23
24    /// Latency measurements
25    pub latency: LatencyMetrics,
26
27    /// Token usage
28    pub tokens: TokenMetrics,
29
30    /// Quality scores
31    pub quality: QualityMetrics,
32
33    /// Feature flags used
34    pub features_enabled: Vec<String>,
35}
36
37/// Latency breakdown by pipeline stage
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct LatencyMetrics {
40    /// Total end-to-end latency
41    pub total_ms: u64,
42
43    /// Retrieval latency
44    pub retrieval_ms: u64,
45
46    /// Reranking latency (if enabled)
47    pub reranking_ms: Option<u64>,
48
49    /// Generation latency
50    pub generation_ms: u64,
51
52    /// Other processing time
53    pub other_ms: u64,
54}
55
56/// Token usage tracking
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct TokenMetrics {
59    /// Input tokens to LLM
60    pub input_tokens: usize,
61
62    /// Output tokens from LLM
63    pub output_tokens: usize,
64
65    /// Total tokens
66    pub total_tokens: usize,
67
68    /// Estimated cost (USD)
69    pub estimated_cost_usd: f64,
70}
71
72/// Quality metrics for answer evaluation
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct QualityMetrics {
75    /// Exact match with ground truth (0.0 or 1.0)
76    pub exact_match: f32,
77
78    /// F1 score (token overlap)
79    pub f1_score: f32,
80
81    /// BLEU score (n-gram similarity)
82    pub bleu_score: Option<f32>,
83
84    /// ROUGE-L score (longest common subsequence)
85    pub rouge_l: Option<f32>,
86
87    /// Semantic similarity (if embeddings available)
88    pub semantic_similarity: Option<f32>,
89}
90
91/// Dataset for benchmarking
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct BenchmarkDataset {
94    /// Dataset name (e.g., "HotpotQA", "MuSiQue")
95    pub name: String,
96
97    /// List of queries with ground truth
98    pub queries: Vec<BenchmarkQuery>,
99}
100
101/// A single query with ground truth for evaluation
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct BenchmarkQuery {
104    /// Question text
105    pub question: String,
106
107    /// Ground truth answer
108    pub answer: String,
109
110    /// Supporting documents (if applicable)
111    pub context: Option<Vec<String>>,
112
113    /// Query difficulty (easy, medium, hard)
114    pub difficulty: Option<String>,
115
116    /// Query type (factual, multi-hop, reasoning)
117    pub query_type: Option<String>,
118}
119
120/// Configuration for benchmark runs
121#[derive(Debug, Clone)]
122pub struct BenchmarkConfig {
123    /// Enable LightRAG dual-level retrieval
124    pub enable_lightrag: bool,
125
126    /// Enable Leiden community detection
127    pub enable_leiden: bool,
128
129    /// Enable cross-encoder reranking
130    pub enable_cross_encoder: bool,
131
132    /// Enable HippoRAG PPR
133    pub enable_hipporag: bool,
134
135    /// Enable semantic chunking
136    pub enable_semantic_chunking: bool,
137
138    /// Number of retrieval candidates
139    pub top_k: usize,
140
141    /// LLM pricing (USD per 1K tokens)
142    pub input_token_price: f64,
143    /// Output token pricing (USD per 1K tokens)
144    pub output_token_price: f64,
145}
146
147impl Default for BenchmarkConfig {
148    fn default() -> Self {
149        Self {
150            enable_lightrag: false,
151            enable_leiden: false,
152            enable_cross_encoder: false,
153            enable_hipporag: false,
154            enable_semantic_chunking: false,
155            top_k: 10,
156            input_token_price: 0.0001,  // Example: $0.10 per 1M tokens
157            output_token_price: 0.0003, // Example: $0.30 per 1M tokens
158        }
159    }
160}
161
162/// Aggregate benchmark results across multiple queries
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct BenchmarkSummary {
165    /// Configuration used
166    pub config_name: String,
167
168    /// Number of queries evaluated
169    pub total_queries: usize,
170
171    /// Average metrics
172    pub avg_latency_ms: f64,
173    /// Average retrieval latency in milliseconds
174    pub avg_retrieval_ms: f64,
175    /// Average reranking latency in milliseconds
176    pub avg_reranking_ms: f64,
177    /// Average generation latency in milliseconds
178    pub avg_generation_ms: f64,
179
180    /// Token statistics
181    /// Total input tokens across all queries
182    pub total_input_tokens: usize,
183    /// Total output tokens across all queries
184    pub total_output_tokens: usize,
185    /// Total cost in USD
186    pub total_cost_usd: f64,
187    /// Average tokens per query
188    pub avg_tokens_per_query: f64,
189
190    /// Quality statistics
191    /// Average exact match score
192    pub avg_exact_match: f64,
193    /// Average F1 score
194    pub avg_f1_score: f64,
195    /// Average BLEU score
196    pub avg_bleu_score: f64,
197    /// Average ROUGE-L score
198    pub avg_rouge_l: f64,
199
200    /// Features enabled
201    pub features: Vec<String>,
202
203    /// Per-query results
204    pub query_results: Vec<QueryBenchmark>,
205}
206
207type RetrievalFn = Box<dyn Fn(&str) -> Vec<String> + Send + Sync>;
208type RerankerFn = Box<dyn Fn(&[String]) -> Vec<String> + Send + Sync>;
209type LlmFn = Box<dyn Fn(&str, &[String]) -> String + Send + Sync>;
210
211/// Main benchmarking coordinator
212pub struct BenchmarkRunner {
213    config: BenchmarkConfig,
214    /// Optional retrieval system for actual benchmarking
215    retrieval_fn: Option<RetrievalFn>,
216    /// Optional reranker function
217    reranker_fn: Option<RerankerFn>,
218    /// Optional LLM generation function
219    llm_fn: Option<LlmFn>,
220}
221
222impl BenchmarkRunner {
223    /// Create a new benchmark runner with simulation mode
224    pub fn new(config: BenchmarkConfig) -> Self {
225        Self {
226            config,
227            retrieval_fn: None,
228            reranker_fn: None,
229            llm_fn: None,
230        }
231    }
232
233    /// Set a custom retrieval function for actual benchmarking
234    ///
235    /// # Example
236    /// ```no_run
237    /// # use graphrag_core::monitoring::benchmark::{BenchmarkRunner, BenchmarkConfig};
238    /// let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
239    /// runner.with_retrieval(|query| {
240    ///     // Your retrieval implementation
241    ///     vec!["doc1".to_string(), "doc2".to_string()]
242    /// });
243    /// ```
244    pub fn with_retrieval<F>(&mut self, f: F) -> &mut Self
245    where
246        F: Fn(&str) -> Vec<String> + Send + Sync + 'static,
247    {
248        self.retrieval_fn = Some(Box::new(f));
249        self
250    }
251
252    /// Set a custom reranker function
253    ///
254    /// # Example
255    /// ```no_run
256    /// # use graphrag_core::monitoring::benchmark::{BenchmarkRunner, BenchmarkConfig};
257    /// let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
258    /// runner.with_reranker(|docs| {
259    ///     // Your reranking implementation
260    ///     docs.to_vec()
261    /// });
262    /// ```
263    pub fn with_reranker<F>(&mut self, f: F) -> &mut Self
264    where
265        F: Fn(&[String]) -> Vec<String> + Send + Sync + 'static,
266    {
267        self.reranker_fn = Some(Box::new(f));
268        self
269    }
270
271    /// Set a custom LLM generation function
272    ///
273    /// # Example
274    /// ```no_run
275    /// # use graphrag_core::monitoring::benchmark::{BenchmarkRunner, BenchmarkConfig};
276    /// let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
277    /// runner.with_llm(|query, context| {
278    ///     // Your LLM implementation
279    ///     format!("Generated answer for: {}", query)
280    /// });
281    /// ```
282    pub fn with_llm<F>(&mut self, f: F) -> &mut Self
283    where
284        F: Fn(&str, &[String]) -> String + Send + Sync + 'static,
285    {
286        self.llm_fn = Some(Box::new(f));
287        self
288    }
289
290    /// Run benchmark on a dataset
291    pub fn run_dataset(&mut self, dataset: &BenchmarkDataset) -> BenchmarkSummary {
292        println!("📊 Running benchmark on dataset: {}", dataset.name);
293        println!("📋 Queries: {}", dataset.queries.len());
294
295        let mut results = Vec::new();
296
297        for (i, query) in dataset.queries.iter().enumerate() {
298            println!(
299                "  [{}/{}] Processing: {}...",
300                i + 1,
301                dataset.queries.len(),
302                &query.question.chars().take(50).collect::<String>()
303            );
304
305            let result = self.benchmark_query(query);
306            results.push(result);
307        }
308
309        self.compute_summary(dataset.name.clone(), results)
310    }
311
312    /// Benchmark a single query
313    fn benchmark_query(&self, query: &BenchmarkQuery) -> QueryBenchmark {
314        let start = Instant::now();
315
316        // Retrieval phase
317        let retrieval_start = Instant::now();
318        let retrieved_docs = if let Some(ref retrieval_fn) = self.retrieval_fn {
319            // Call actual retrieval system
320            retrieval_fn(&query.question)
321        } else {
322            // Simulation mode: return empty results
323            vec![]
324        };
325        let retrieval_time = retrieval_start.elapsed();
326
327        // Reranking phase (if enabled)
328        let (reranked_docs, reranking_time) = if self.config.enable_cross_encoder {
329            let reranking_start = Instant::now();
330            let reranked = if let Some(ref reranker_fn) = self.reranker_fn {
331                // Call actual cross-encoder reranking
332                reranker_fn(&retrieved_docs)
333            } else {
334                // Simulation mode: no reranking
335                retrieved_docs.clone()
336            };
337            (reranked, Some(reranking_start.elapsed()))
338        } else {
339            (retrieved_docs.clone(), None)
340        };
341
342        // Generation phase
343        let generation_start = Instant::now();
344        let generated_answer = if let Some(ref llm_fn) = self.llm_fn {
345            // Call actual LLM generation with context
346            llm_fn(&query.question, &reranked_docs)
347        } else {
348            // Simulation mode: generate placeholder
349            format!("Generated answer for: {}", query.question)
350        };
351        let generation_time = generation_start.elapsed();
352
353        let total_time = start.elapsed();
354
355        // Calculate token usage (estimated)
356        let estimated_input_tokens = if self.config.enable_lightrag {
357            200 // LightRAG optimization: much lower
358        } else {
359            2000 // Traditional GraphRAG: ~10x more
360        };
361
362        let estimated_output_tokens = 100;
363
364        let tokens = TokenMetrics {
365            input_tokens: estimated_input_tokens,
366            output_tokens: estimated_output_tokens,
367            total_tokens: estimated_input_tokens + estimated_output_tokens,
368            estimated_cost_usd: (estimated_input_tokens as f64 / 1000.0
369                * self.config.input_token_price)
370                + (estimated_output_tokens as f64 / 1000.0 * self.config.output_token_price),
371        };
372
373        // Calculate quality metrics
374        let quality = self.calculate_quality_metrics(&generated_answer, &query.answer);
375
376        // Collect enabled features
377        let mut features = Vec::new();
378        if self.config.enable_lightrag {
379            features.push("LightRAG".to_string());
380        }
381        if self.config.enable_leiden {
382            features.push("Leiden".to_string());
383        }
384        if self.config.enable_cross_encoder {
385            features.push("Cross-Encoder".to_string());
386        }
387        if self.config.enable_hipporag {
388            features.push("HippoRAG PPR".to_string());
389        }
390        if self.config.enable_semantic_chunking {
391            features.push("Semantic Chunking".to_string());
392        }
393
394        QueryBenchmark {
395            query: query.question.clone(),
396            ground_truth: Some(query.answer.clone()),
397            generated_answer,
398            latency: LatencyMetrics {
399                total_ms: total_time.as_millis() as u64,
400                retrieval_ms: retrieval_time.as_millis() as u64,
401                reranking_ms: reranking_time.map(|d| d.as_millis() as u64),
402                generation_ms: generation_time.as_millis() as u64,
403                other_ms: 0,
404            },
405            tokens,
406            quality,
407            features_enabled: features,
408        }
409    }
410
411    /// Calculate quality metrics
412    fn calculate_quality_metrics(&self, generated: &str, ground_truth: &str) -> QualityMetrics {
413        // Exact match
414        let exact_match = if generated.trim().eq_ignore_ascii_case(ground_truth.trim()) {
415            1.0
416        } else {
417            0.0
418        };
419
420        // F1 score (token overlap)
421        let f1_score = self.calculate_f1_score(generated, ground_truth);
422
423        // BLEU score (n-gram overlap with brevity penalty)
424        let bleu_score = Some(self.calculate_bleu_score(generated, ground_truth));
425
426        // ROUGE-L score (Longest Common Subsequence F-score)
427        let rouge_l = Some(self.calculate_rouge_l(generated, ground_truth));
428
429        QualityMetrics {
430            exact_match,
431            f1_score,
432            bleu_score,
433            rouge_l,
434            semantic_similarity: None,
435        }
436    }
437
438    /// Calculate F1 score based on token overlap
439    fn calculate_f1_score(&self, generated: &str, ground_truth: &str) -> f32 {
440        let gen_tokens: Vec<String> = generated
441            .to_lowercase()
442            .split_whitespace()
443            .map(|s| s.to_string())
444            .collect();
445
446        let gt_tokens: Vec<String> = ground_truth
447            .to_lowercase()
448            .split_whitespace()
449            .map(|s| s.to_string())
450            .collect();
451
452        if gen_tokens.is_empty() || gt_tokens.is_empty() {
453            return 0.0;
454        }
455
456        // Calculate overlap
457        let mut common = 0;
458        for token in &gen_tokens {
459            if gt_tokens.contains(token) {
460                common += 1;
461            }
462        }
463
464        if common == 0 {
465            return 0.0;
466        }
467
468        let precision = common as f32 / gen_tokens.len() as f32;
469        let recall = common as f32 / gt_tokens.len() as f32;
470
471        2.0 * (precision * recall) / (precision + recall)
472    }
473
474    /// Calculate BLEU score (BiLingual Evaluation Understudy)
475    ///
476    /// BLEU score measures n-gram overlap between generated and reference text,
477    /// with a brevity penalty for overly short outputs.
478    ///
479    /// Formula: BLEU = BP * exp(1/N * sum(log(P_n)))
480    /// where P_n is the precision for n-grams and BP is the brevity penalty.
481    fn calculate_bleu_score(&self, candidate: &str, reference: &str) -> f32 {
482        // Tokenize candidate and reference
483        let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
484        let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
485
486        if candidate_tokens.is_empty() || reference_tokens.is_empty() {
487            return 0.0;
488        }
489
490        // Calculate n-gram precisions (n=1 to 4)
491        let max_n = 4;
492        let mut log_precision_sum = 0.0;
493        let mut valid_n_grams = 0;
494
495        for n in 1..=max_n {
496            let precision = self.calculate_ngram_precision(&candidate_tokens, &reference_tokens, n);
497
498            if precision > 0.0 {
499                log_precision_sum += precision.ln();
500                valid_n_grams += 1;
501            } else {
502                // If any n-gram precision is 0, BLEU score is 0
503                return 0.0;
504            }
505        }
506
507        // Calculate brevity penalty
508        let candidate_len = candidate_tokens.len() as f32;
509        let reference_len = reference_tokens.len() as f32;
510
511        let brevity_penalty = if candidate_len >= reference_len {
512            1.0
513        } else {
514            (1.0 - reference_len / candidate_len).exp()
515        };
516
517        // Final BLEU score: BP * exp(1/N * sum(log(P_n)))
518        let bleu = brevity_penalty * (log_precision_sum / valid_n_grams as f32).exp();
519
520        // Clamp to [0, 1] range
521        bleu.clamp(0.0, 1.0)
522    }
523
524    /// Calculate precision for n-grams with clipping
525    fn calculate_ngram_precision(&self, candidate: &[&str], reference: &[&str], n: usize) -> f32 {
526        if candidate.len() < n || reference.len() < n {
527            return 0.0;
528        }
529
530        // Extract n-grams from candidate
531        let candidate_ngrams = self.extract_ngrams(candidate, n);
532
533        // Extract n-grams from reference and count frequencies
534        let reference_ngrams = self.extract_ngrams(reference, n);
535        let mut reference_counts = std::collections::HashMap::new();
536        for ngram in &reference_ngrams {
537            *reference_counts.entry(ngram).or_insert(0) += 1;
538        }
539
540        // Count clipped matches (clip to max count in reference)
541        let mut clipped_matches = 0;
542        let mut candidate_counts = std::collections::HashMap::new();
543
544        for ngram in &candidate_ngrams {
545            let candidate_count = candidate_counts.entry(ngram).or_insert(0);
546            *candidate_count += 1;
547
548            if let Some(&ref_count) = reference_counts.get(&ngram) {
549                if *candidate_count <= ref_count {
550                    clipped_matches += 1;
551                }
552            }
553        }
554
555        // Precision = clipped_matches / total_candidate_ngrams
556        if candidate_ngrams.is_empty() {
557            0.0
558        } else {
559            clipped_matches as f32 / candidate_ngrams.len() as f32
560        }
561    }
562
563    /// Extract all n-grams from a token sequence
564    fn extract_ngrams(&self, tokens: &[&str], n: usize) -> Vec<Vec<String>> {
565        if tokens.len() < n {
566            return Vec::new();
567        }
568
569        tokens
570            .windows(n)
571            .map(|window| window.iter().map(|&s| s.to_string()).collect())
572            .collect()
573    }
574
575    /// Calculate ROUGE-L score (Recall-Oriented Understudy for Gisting Evaluation - Longest Common Subsequence)
576    ///
577    /// ROUGE-L measures the similarity between candidate and reference text using
578    /// the Longest Common Subsequence (LCS) to compute precision, recall, and F-score.
579    ///
580    /// Formula: F = ((1 + β²) * precision * recall) / (β² * precision + recall)
581    /// where β controls the importance of recall (typically β=1.2)
582    fn calculate_rouge_l(&self, candidate: &str, reference: &str) -> f32 {
583        // Tokenize candidate and reference
584        let candidate_tokens: Vec<&str> = candidate.split_whitespace().collect();
585        let reference_tokens: Vec<&str> = reference.split_whitespace().collect();
586
587        if candidate_tokens.is_empty() || reference_tokens.is_empty() {
588            return 0.0;
589        }
590
591        // Calculate LCS length
592        let lcs_length = self.lcs_length(&candidate_tokens, &reference_tokens);
593
594        if lcs_length == 0 {
595            return 0.0;
596        }
597
598        // Calculate precision and recall
599        let precision = lcs_length as f32 / candidate_tokens.len() as f32;
600        let recall = lcs_length as f32 / reference_tokens.len() as f32;
601
602        // Calculate F-score with β=1.2 (slightly favors recall)
603        let beta = 1.2;
604        let beta_squared = beta * beta;
605
606        let f_score =
607            ((1.0 + beta_squared) * precision * recall) / (beta_squared * precision + recall);
608
609        // Clamp to [0, 1] range
610        f_score.clamp(0.0, 1.0)
611    }
612
613    /// Calculate the length of the Longest Common Subsequence (LCS) using dynamic programming
614    ///
615    /// LCS is the longest sequence of tokens that appear in both texts in the same order
616    /// (but not necessarily consecutively).
617    ///
618    /// Time complexity: O(m * n) where m and n are the lengths of the two sequences
619    fn lcs_length(&self, seq1: &[&str], seq2: &[&str]) -> usize {
620        let m = seq1.len();
621        let n = seq2.len();
622
623        if m == 0 || n == 0 {
624            return 0;
625        }
626
627        // Create DP table: dp[i][j] = LCS length of seq1[0..i] and seq2[0..j]
628        let mut dp = vec![vec![0; n + 1]; m + 1];
629
630        // Fill the DP table
631        for i in 1..=m {
632            for j in 1..=n {
633                if seq1[i - 1] == seq2[j - 1] {
634                    // Characters match: extend LCS by 1
635                    dp[i][j] = dp[i - 1][j - 1] + 1;
636                } else {
637                    // Characters don't match: take max of excluding either character
638                    dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
639                }
640            }
641        }
642
643        dp[m][n]
644    }
645
646    /// Compute aggregate summary
647    fn compute_summary(
648        &self,
649        config_name: String,
650        results: Vec<QueryBenchmark>,
651    ) -> BenchmarkSummary {
652        let total = results.len();
653
654        if total == 0 {
655            return BenchmarkSummary {
656                config_name,
657                total_queries: 0,
658                avg_latency_ms: 0.0,
659                avg_retrieval_ms: 0.0,
660                avg_reranking_ms: 0.0,
661                avg_generation_ms: 0.0,
662                total_input_tokens: 0,
663                total_output_tokens: 0,
664                total_cost_usd: 0.0,
665                avg_tokens_per_query: 0.0,
666                avg_exact_match: 0.0,
667                avg_f1_score: 0.0,
668                avg_bleu_score: 0.0,
669                avg_rouge_l: 0.0,
670                features: Vec::new(),
671                query_results: results,
672            };
673        }
674
675        let avg_latency_ms = results
676            .iter()
677            .map(|r| r.latency.total_ms as f64)
678            .sum::<f64>()
679            / total as f64;
680        let avg_retrieval_ms = results
681            .iter()
682            .map(|r| r.latency.retrieval_ms as f64)
683            .sum::<f64>()
684            / total as f64;
685        let avg_reranking_ms = results
686            .iter()
687            .filter_map(|r| r.latency.reranking_ms)
688            .map(|ms| ms as f64)
689            .sum::<f64>()
690            / total as f64;
691        let avg_generation_ms = results
692            .iter()
693            .map(|r| r.latency.generation_ms as f64)
694            .sum::<f64>()
695            / total as f64;
696
697        let total_input_tokens: usize = results.iter().map(|r| r.tokens.input_tokens).sum();
698        let total_output_tokens: usize = results.iter().map(|r| r.tokens.output_tokens).sum();
699        let total_cost_usd: f64 = results.iter().map(|r| r.tokens.estimated_cost_usd).sum();
700
701        let avg_exact_match = results
702            .iter()
703            .map(|r| r.quality.exact_match as f64)
704            .sum::<f64>()
705            / total as f64;
706        let avg_f1_score = results
707            .iter()
708            .map(|r| r.quality.f1_score as f64)
709            .sum::<f64>()
710            / total as f64;
711
712        // Calculate average BLEU score (only count queries where BLEU was computed)
713        let bleu_scores: Vec<f64> = results
714            .iter()
715            .filter_map(|r| r.quality.bleu_score.map(|s| s as f64))
716            .collect();
717        let avg_bleu_score = if !bleu_scores.is_empty() {
718            bleu_scores.iter().sum::<f64>() / bleu_scores.len() as f64
719        } else {
720            0.0
721        };
722
723        // Calculate average ROUGE-L score (only count queries where ROUGE-L was computed)
724        let rouge_scores: Vec<f64> = results
725            .iter()
726            .filter_map(|r| r.quality.rouge_l.map(|s| s as f64))
727            .collect();
728        let avg_rouge_l = if !rouge_scores.is_empty() {
729            rouge_scores.iter().sum::<f64>() / rouge_scores.len() as f64
730        } else {
731            0.0
732        };
733
734        let features = if !results.is_empty() {
735            results[0].features_enabled.clone()
736        } else {
737            Vec::new()
738        };
739
740        BenchmarkSummary {
741            config_name,
742            total_queries: total,
743            avg_latency_ms,
744            avg_retrieval_ms,
745            avg_reranking_ms,
746            avg_generation_ms,
747            total_input_tokens,
748            total_output_tokens,
749            total_cost_usd,
750            avg_tokens_per_query: (total_input_tokens + total_output_tokens) as f64 / total as f64,
751            avg_exact_match,
752            avg_f1_score,
753            avg_bleu_score,
754            avg_rouge_l,
755            features,
756            query_results: results,
757        }
758    }
759
760    /// Print summary results
761    pub fn print_summary(&self, summary: &BenchmarkSummary) {
762        println!("\n📊 Benchmark Results: {}", summary.config_name);
763        println!("{}", "=".repeat(60));
764
765        println!("\n🎯 Quality Metrics:");
766        println!("  Exact Match:  {:.1}%", summary.avg_exact_match * 100.0);
767        println!("  F1 Score:     {:.3}", summary.avg_f1_score);
768        if summary.avg_bleu_score > 0.0 {
769            println!("  BLEU Score:   {:.3}", summary.avg_bleu_score);
770        }
771        if summary.avg_rouge_l > 0.0 {
772            println!("  ROUGE-L:      {:.3}", summary.avg_rouge_l);
773        }
774
775        println!("\n⏱️  Latency Metrics (avg):");
776        println!("  Total:        {:.1} ms", summary.avg_latency_ms);
777        println!("  Retrieval:    {:.1} ms", summary.avg_retrieval_ms);
778        if summary.avg_reranking_ms > 0.0 {
779            println!("  Reranking:    {:.1} ms", summary.avg_reranking_ms);
780        }
781        println!("  Generation:   {:.1} ms", summary.avg_generation_ms);
782
783        println!("\n💰 Token & Cost Metrics:");
784        println!("  Input tokens:  {}", summary.total_input_tokens);
785        println!("  Output tokens: {}", summary.total_output_tokens);
786        println!("  Total cost:    ${:.4}", summary.total_cost_usd);
787        println!("  Avg tokens/query: {:.0}", summary.avg_tokens_per_query);
788
789        println!("\n✨ Features Enabled:");
790        for feature in &summary.features {
791            println!("  ✅ {}", feature);
792        }
793
794        println!("\n{}", "=".repeat(60));
795    }
796
797    /// Compare two benchmark summaries
798    pub fn compare_summaries(&self, baseline: &BenchmarkSummary, improved: &BenchmarkSummary) {
799        println!("\n📈 Benchmark Comparison");
800        println!("{}", "=".repeat(60));
801
802        println!("\nConfiguration:");
803        println!("  Baseline: {}", baseline.config_name);
804        println!("  Improved: {}", improved.config_name);
805
806        println!("\n🎯 Quality Improvements:");
807        let em_improvement = ((improved.avg_exact_match - baseline.avg_exact_match)
808            / baseline.avg_exact_match)
809            * 100.0;
810        let f1_improvement =
811            ((improved.avg_f1_score - baseline.avg_f1_score) / baseline.avg_f1_score) * 100.0;
812        println!("  Exact Match:  {:+.1}%", em_improvement);
813        println!("  F1 Score:     {:+.1}%", f1_improvement);
814
815        println!("\n💰 Cost Savings:");
816        let token_reduction = ((baseline.total_input_tokens - improved.total_input_tokens) as f64
817            / baseline.total_input_tokens as f64)
818            * 100.0;
819        let cost_savings =
820            ((baseline.total_cost_usd - improved.total_cost_usd) / baseline.total_cost_usd) * 100.0;
821        println!(
822            "  Token reduction: {:.1}% ({} → {} tokens)",
823            token_reduction, baseline.total_input_tokens, improved.total_input_tokens
824        );
825        println!(
826            "  Cost savings:    {:.1}% (${:.4} → ${:.4})",
827            cost_savings, baseline.total_cost_usd, improved.total_cost_usd
828        );
829
830        println!("\n⏱️  Latency Changes:");
831        let latency_change =
832            ((improved.avg_latency_ms - baseline.avg_latency_ms) / baseline.avg_latency_ms) * 100.0;
833        println!(
834            "  Total latency: {:+.1}% ({:.1}ms → {:.1}ms)",
835            latency_change, baseline.avg_latency_ms, improved.avg_latency_ms
836        );
837
838        println!("\n{}", "=".repeat(60));
839    }
840}
841
842#[cfg(test)]
843mod tests {
844    use super::*;
845
846    #[test]
847    fn test_f1_score_calculation() {
848        let _runner = BenchmarkRunner::new(BenchmarkConfig::default());
849
850        // Perfect match
851        let f1 = _runner.calculate_f1_score("hello world", "hello world");
852        assert!((f1 - 1.0).abs() < 0.001);
853
854        // Partial overlap
855        let f1 = _runner.calculate_f1_score("hello world", "hello there");
856        assert!(f1 > 0.0 && f1 < 1.0);
857
858        // No overlap
859        let f1 = _runner.calculate_f1_score("foo bar", "baz qux");
860        assert_eq!(f1, 0.0);
861    }
862
863    #[test]
864    fn test_benchmark_summary() {
865        let dataset = BenchmarkDataset {
866            name: "Test".to_string(),
867            queries: vec![BenchmarkQuery {
868                question: "What is 2+2?".to_string(),
869                answer: "4".to_string(),
870                context: None,
871                difficulty: None,
872                query_type: None,
873            }],
874        };
875
876        let mut runner = BenchmarkRunner::new(BenchmarkConfig::default());
877        let summary = runner.run_dataset(&dataset);
878
879        assert_eq!(summary.total_queries, 1);
880        assert!(summary.avg_latency_ms >= 0.0);
881    }
882}