// rexis_rag/evaluation/benchmarks.rs

1//! # Evaluation Benchmarks
2//!
3//! Standard benchmarks and datasets for RAG system evaluation.
4
5use super::{
6    EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
7    EvaluatorConfig, EvaluatorPerformance, GroundTruth, PerformanceStats, SystemResponse,
8    TestQuery,
9};
10use crate::RragResult;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use tracing::{error, info, warn};
14
/// Benchmark evaluator
///
/// Holds a collection of [`Benchmark`] trait objects and runs each one
/// against system responses, aggregating per-benchmark scores into a single
/// evaluation result.
pub struct BenchmarkEvaluator {
    // Registered benchmarks, executed in insertion order.
    benchmarks: Vec<Box<dyn Benchmark>>,
}
19
/// Trait for evaluation benchmarks
///
/// A benchmark supplies its own test data (queries plus ground truth) and
/// scores a set of system responses against that data. Implementations must
/// be `Send + Sync` so they can be held behind boxed trait objects.
pub trait Benchmark: Send + Sync {
    /// Benchmark name (also used as the key in aggregated score maps)
    fn name(&self) -> &str;

    /// Generate test queries and ground truth
    fn generate_test_data(&self) -> RragResult<EvaluationData>;

    /// Evaluate system against this benchmark
    fn evaluate_benchmark(
        &self,
        system_responses: &[SystemResponse],
    ) -> RragResult<BenchmarkResult>;

    /// Get benchmark configuration
    fn get_config(&self) -> BenchmarkConfig;
}
37
/// Configuration for benchmarks
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Benchmark name
    pub name: String,

    /// Number of test queries produced by `generate_test_data`
    pub num_queries: usize,

    /// Difficulty level
    pub difficulty: DifficultyLevel,

    /// Domain focus
    pub domain: BenchmarkDomain,

    /// Names of the metrics reported in `BenchmarkResult::detailed_scores`
    pub metrics: Vec<String>,
}
56
/// Difficulty levels for benchmarks, ordered from easiest to hardest.
///
/// Fieldless enum: derives `Copy`, `PartialEq`, `Eq` and `Hash` so levels can
/// be compared, copied freely and used as map keys (the previous derive list
/// only allowed cloning and debug-printing).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DifficultyLevel {
    Easy,
    Medium,
    Hard,
    Expert,
}
65
/// Benchmark domains
///
/// Fieldless enum: derives `Copy`, `PartialEq`, `Eq` and `Hash` so domains
/// can be compared, copied freely and used as map keys (the previous derive
/// list only allowed cloning and debug-printing).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BenchmarkDomain {
    General,
    Science,
    Technology,
    History,
    Literature,
    Medicine,
    Law,
    Finance,
    Education,
    News,
}
80
/// Result from benchmark evaluation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// Benchmark name
    pub benchmark_name: String,

    /// Overall score (the benchmarks in this file produce values in 0.0–1.0)
    pub overall_score: f32,

    /// Detailed scores, keyed by metric name (e.g. "accuracy", "coverage")
    pub detailed_scores: HashMap<String, f32>,

    /// Ranking compared to baseline
    pub ranking_info: RankingInfo,

    /// Performance analysis
    pub performance_analysis: PerformanceAnalysis,

    /// Failure cases
    pub failure_cases: Vec<FailureCase>,
}
102
/// Ranking information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankingInfo {
    /// Percentile ranking (0–100)
    pub percentile: f32,

    /// Comparison to baseline systems, keyed by baseline name
    pub baseline_comparisons: HashMap<String, f32>,

    /// Confidence interval as (lower bound, upper bound)
    pub confidence_interval: (f32, f32),
}
115
/// Performance analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceAnalysis {
    /// Strengths identified (human-readable descriptions)
    pub strengths: Vec<String>,

    /// Weaknesses identified (human-readable descriptions)
    pub weaknesses: Vec<String>,

    /// Performance by category, keyed by category name
    pub category_performance: HashMap<String, f32>,

    /// Error patterns observed across queries
    pub error_patterns: Vec<ErrorPattern>,
}
131
/// Error pattern analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorPattern {
    /// Pattern description
    pub description: String,

    /// Frequency (fraction of queries exhibiting this pattern — TODO confirm units)
    pub frequency: f32,

    /// Example queries exhibiting the pattern
    pub example_queries: Vec<String>,

    /// Suggested improvements to address the pattern
    pub improvements: Vec<String>,
}
147
/// Failure case analysis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureCase {
    /// Query that failed
    pub query: String,

    /// Expected result
    pub expected: String,

    /// Actual result produced by the system
    pub actual: String,

    /// Failure reason (human-readable explanation)
    pub failure_reason: String,

    /// Severity of the failure
    pub severity: FailureSeverity,
}
166
167/// Failure severity levels
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub enum FailureSeverity {
170    Low,
171    Medium,
172    High,
173    Critical,
174}
175
176impl BenchmarkEvaluator {
177    /// Create new benchmark evaluator
178    pub fn new() -> Self {
179        let mut evaluator = Self {
180            benchmarks: Vec::new(),
181        };
182
183        // Initialize standard benchmarks
184        evaluator.initialize_benchmarks();
185
186        evaluator
187    }
188
189    /// Initialize standard benchmarks
190    fn initialize_benchmarks(&mut self) {
191        self.benchmarks
192            .push(Box::new(GeneralKnowledgeBenchmark::new()));
193        self.benchmarks
194            .push(Box::new(FactualAccuracyBenchmark::new()));
195        self.benchmarks.push(Box::new(ReasoningBenchmark::new()));
196        self.benchmarks
197            .push(Box::new(DomainSpecificBenchmark::new()));
198        self.benchmarks.push(Box::new(MultiHopBenchmark::new()));
199        self.benchmarks
200            .push(Box::new(ConversationalBenchmark::new()));
201    }
202
203    /// Run all benchmarks
204    pub async fn run_all_benchmarks(&self) -> RragResult<HashMap<String, BenchmarkResult>> {
205        let mut results = HashMap::new();
206
207        for benchmark in &self.benchmarks {
208            tracing::debug!("Running benchmark: {}", benchmark.name());
209
210            // Generate test data
211            let test_data = benchmark.generate_test_data()?;
212
213            // For demonstration, create mock system responses
214            let system_responses = self.create_mock_responses(&test_data);
215
216            // Evaluate benchmark
217            match benchmark.evaluate_benchmark(&system_responses) {
218                Ok(result) => {
219                    results.insert(benchmark.name().to_string(), result);
220                    info!(" {} completed", benchmark.name());
221                }
222                Err(e) => {
223                    error!(" {} failed: {}", benchmark.name(), e);
224                }
225            }
226        }
227
228        Ok(results)
229    }
230
231    /// Create mock system responses for demonstration
232    fn create_mock_responses(&self, test_data: &EvaluationData) -> Vec<SystemResponse> {
233        use super::{RetrievedDocument, SystemTiming};
234
235        test_data
236            .queries
237            .iter()
238            .map(|query| SystemResponse {
239                query_id: query.id.clone(),
240                retrieved_docs: vec![RetrievedDocument {
241                    doc_id: format!("doc_{}", query.id),
242                    content: format!("Relevant content for: {}", query.query),
243                    score: 0.8,
244                    rank: 0,
245                    metadata: HashMap::new(),
246                }],
247                generated_answer: Some(format!("Generated answer for: {}", query.query)),
248                timing: SystemTiming {
249                    total_time_ms: 1000.0 + (query.id.len() as f32 * 100.0),
250                    retrieval_time_ms: 600.0,
251                    generation_time_ms: Some(300.0),
252                    reranking_time_ms: Some(100.0),
253                },
254                metadata: HashMap::new(),
255            })
256            .collect()
257    }
258}
259
260impl Evaluator for BenchmarkEvaluator {
261    fn name(&self) -> &str {
262        "Benchmark"
263    }
264
265    fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
266        let start_time = std::time::Instant::now();
267        let mut overall_scores = HashMap::new();
268        let per_query_results = Vec::new();
269
270        // Evaluate against each benchmark
271        for benchmark in &self.benchmarks {
272            match benchmark.evaluate_benchmark(&data.system_responses) {
273                Ok(result) => {
274                    overall_scores.insert(benchmark.name().to_string(), result.overall_score);
275                }
276                Err(e) => {
277                    warn!(" Benchmark {} failed: {}", benchmark.name(), e);
278                }
279            }
280        }
281
282        // Calculate overall benchmark score
283        let overall_score = if overall_scores.is_empty() {
284            0.0
285        } else {
286            overall_scores.values().sum::<f32>() / overall_scores.len() as f32
287        };
288
289        overall_scores.insert("overall_benchmark_score".to_string(), overall_score);
290
291        let total_time = start_time.elapsed().as_millis() as f32;
292
293        // Generate insights
294        let insights = self.generate_insights(&overall_scores);
295        let recommendations = self.generate_recommendations(&overall_scores);
296
297        Ok(EvaluationResult {
298            id: uuid::Uuid::new_v4().to_string(),
299            evaluation_type: "Benchmark".to_string(),
300            overall_scores: overall_scores.clone(),
301            per_query_results,
302            summary: EvaluationSummary {
303                total_queries: data.queries.len(),
304                avg_scores: overall_scores.clone(),
305                std_deviations: HashMap::new(),
306                performance_stats: PerformanceStats {
307                    avg_eval_time_ms: total_time,
308                    total_eval_time_ms: total_time,
309                    peak_memory_usage_mb: 200.0,
310                    throughput_qps: data.queries.len() as f32 / (total_time / 1000.0),
311                },
312                insights,
313                recommendations,
314            },
315            metadata: EvaluationMetadata {
316                timestamp: chrono::Utc::now(),
317                evaluation_version: "1.0.0".to_string(),
318                system_config: HashMap::new(),
319                environment: std::env::vars().collect(),
320                git_commit: None,
321            },
322        })
323    }
324
325    fn supported_metrics(&self) -> Vec<String> {
326        self.benchmarks
327            .iter()
328            .map(|b| b.name().to_string())
329            .collect()
330    }
331
332    fn get_config(&self) -> EvaluatorConfig {
333        EvaluatorConfig {
334            name: "Benchmark".to_string(),
335            version: "1.0.0".to_string(),
336            metrics: self.supported_metrics(),
337            performance: EvaluatorPerformance {
338                avg_time_per_sample_ms: 500.0,
339                memory_usage_mb: 200.0,
340                accuracy: 0.95,
341            },
342        }
343    }
344}
345
346impl BenchmarkEvaluator {
347    /// Generate insights based on benchmark results
348    fn generate_insights(&self, scores: &HashMap<String, f32>) -> Vec<String> {
349        let mut insights = Vec::new();
350
351        // Overall performance insights
352        if let Some(&overall_score) = scores.get("overall_benchmark_score") {
353            if overall_score > 0.8 {
354                insights.push("🏆 Excellent performance across benchmarks".to_string());
355            } else if overall_score < 0.6 {
356                insights.push("⚠️ Below-average performance on standard benchmarks".to_string());
357            }
358        }
359
360        // Specific benchmark insights
361        if let Some(&general_score) = scores.get("GeneralKnowledge") {
362            if general_score < 0.6 {
363                insights.push("📚 General knowledge capabilities need improvement".to_string());
364            }
365        }
366
367        if let Some(&factual_score) = scores.get("FactualAccuracy") {
368            if factual_score < 0.7 {
369                insights.push("📊 Factual accuracy is below acceptable threshold".to_string());
370            }
371        }
372
373        if let Some(&reasoning_score) = scores.get("Reasoning") {
374            if reasoning_score < 0.6 {
375                insights.push("🧠 Reasoning capabilities require enhancement".to_string());
376            }
377        }
378
379        insights
380    }
381
382    /// Generate recommendations based on benchmark results
383    fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
384        let mut recommendations = Vec::new();
385
386        if let Some(&general_score) = scores.get("GeneralKnowledge") {
387            if general_score < 0.6 {
388                recommendations.push(
389                    "📖 Expand knowledge base with diverse, high-quality sources".to_string(),
390                );
391                recommendations.push(
392                    "🔍 Improve retrieval to find relevant background information".to_string(),
393                );
394            }
395        }
396
397        if let Some(&factual_score) = scores.get("FactualAccuracy") {
398            if factual_score < 0.7 {
399                recommendations
400                    .push("✅ Implement fact-checking and verification mechanisms".to_string());
401                recommendations.push("📑 Use more authoritative and recent sources".to_string());
402            }
403        }
404
405        if let Some(&reasoning_score) = scores.get("Reasoning") {
406            if reasoning_score < 0.6 {
407                recommendations.push("🔄 Implement chain-of-thought reasoning prompts".to_string());
408                recommendations.push("🧩 Add step-by-step problem decomposition".to_string());
409            }
410        }
411
412        recommendations
413    }
414}
415
impl Default for BenchmarkEvaluator {
    /// Equivalent to [`BenchmarkEvaluator::new`]: an evaluator pre-loaded
    /// with the standard benchmark suite.
    fn default() -> Self {
        Self::new()
    }
}
421
422// Individual benchmark implementations
/// Benchmark of broad, easy general-knowledge questions (capitals, authors,
/// basic science concepts).
struct GeneralKnowledgeBenchmark;

impl GeneralKnowledgeBenchmark {
    /// Create the benchmark (unit struct, no configuration needed).
    fn new() -> Self {
        Self
    }
}
430
431impl Benchmark for GeneralKnowledgeBenchmark {
432    fn name(&self) -> &str {
433        "GeneralKnowledge"
434    }
435
436    fn generate_test_data(&self) -> RragResult<EvaluationData> {
437        let queries = vec![
438            TestQuery {
439                id: "gk_1".to_string(),
440                query: "What is the capital of France?".to_string(),
441                query_type: Some("factual".to_string()),
442                metadata: HashMap::new(),
443            },
444            TestQuery {
445                id: "gk_2".to_string(),
446                query: "Who wrote Romeo and Juliet?".to_string(),
447                query_type: Some("factual".to_string()),
448                metadata: HashMap::new(),
449            },
450            TestQuery {
451                id: "gk_3".to_string(),
452                query: "What is photosynthesis?".to_string(),
453                query_type: Some("conceptual".to_string()),
454                metadata: HashMap::new(),
455            },
456        ];
457
458        let ground_truth = vec![
459            GroundTruth {
460                query_id: "gk_1".to_string(),
461                relevant_docs: vec!["france_capital".to_string()],
462                expected_answer: Some("Paris".to_string()),
463                relevance_judgments: HashMap::new(),
464                metadata: HashMap::new(),
465            },
466            GroundTruth {
467                query_id: "gk_2".to_string(),
468                relevant_docs: vec!["shakespeare_works".to_string()],
469                expected_answer: Some("William Shakespeare".to_string()),
470                relevance_judgments: HashMap::new(),
471                metadata: HashMap::new(),
472            },
473            GroundTruth {
474                query_id: "gk_3".to_string(),
475                relevant_docs: vec!["biology_photosynthesis".to_string()],
476                expected_answer: Some(
477                    "Process by which plants convert light energy into chemical energy".to_string(),
478                ),
479                relevance_judgments: HashMap::new(),
480                metadata: HashMap::new(),
481            },
482        ];
483
484        Ok(EvaluationData {
485            queries,
486            ground_truth,
487            system_responses: Vec::new(),
488            context: HashMap::new(),
489        })
490    }
491
492    fn evaluate_benchmark(&self, responses: &[SystemResponse]) -> RragResult<BenchmarkResult> {
493        let mut correct_answers = 0;
494        let total_questions = responses.len();
495
496        // Simplified evaluation - check if answer is present
497        for response in responses {
498            if let Some(answer) = &response.generated_answer {
499                if !answer.trim().is_empty() {
500                    correct_answers += 1;
501                }
502            }
503        }
504
505        let overall_score = if total_questions > 0 {
506            correct_answers as f32 / total_questions as f32
507        } else {
508            0.0
509        };
510
511        let mut detailed_scores = HashMap::new();
512        detailed_scores.insert("accuracy".to_string(), overall_score);
513        detailed_scores.insert("coverage".to_string(), 1.0); // All questions attempted
514
515        Ok(BenchmarkResult {
516            benchmark_name: self.name().to_string(),
517            overall_score,
518            detailed_scores,
519            ranking_info: RankingInfo {
520                percentile: overall_score * 100.0,
521                baseline_comparisons: HashMap::new(),
522                confidence_interval: (overall_score - 0.1, overall_score + 0.1),
523            },
524            performance_analysis: PerformanceAnalysis {
525                strengths: vec!["Good response generation".to_string()],
526                weaknesses: if overall_score < 0.7 {
527                    vec!["Factual accuracy needs improvement".to_string()]
528                } else {
529                    vec![]
530                },
531                category_performance: HashMap::new(),
532                error_patterns: Vec::new(),
533            },
534            failure_cases: Vec::new(),
535        })
536    }
537
538    fn get_config(&self) -> BenchmarkConfig {
539        BenchmarkConfig {
540            name: self.name().to_string(),
541            num_queries: 3,
542            difficulty: DifficultyLevel::Easy,
543            domain: BenchmarkDomain::General,
544            metrics: vec!["accuracy".to_string(), "coverage".to_string()],
545        }
546    }
547}
548
549// Placeholder implementations for other benchmarks
/// Generates a minimal placeholder [`Benchmark`] implementation.
///
/// Each generated benchmark is a unit struct producing a single synthetic
/// query/ground-truth pair and a fixed 0.75 score, pending a real
/// implementation.
macro_rules! impl_simple_benchmark {
    ($name:ident, $benchmark_name:literal, $difficulty:expr, $domain:expr) => {
        struct $name;

        impl $name {
            // Unit struct; no configuration needed.
            fn new() -> Self {
                Self
            }
        }

        impl Benchmark for $name {
            fn name(&self) -> &str {
                $benchmark_name
            }

            fn generate_test_data(&self) -> RragResult<EvaluationData> {
                // Generate simple test data
                let queries = vec![TestQuery {
                    id: format!("{}_1", $benchmark_name.to_lowercase()),
                    query: format!("Sample query for {}", $benchmark_name),
                    query_type: Some("test".to_string()),
                    metadata: HashMap::new(),
                }];

                let ground_truth = vec![GroundTruth {
                    // Must match the query id generated above.
                    query_id: format!("{}_1", $benchmark_name.to_lowercase()),
                    relevant_docs: vec!["test_doc".to_string()],
                    expected_answer: Some("Test answer".to_string()),
                    relevance_judgments: HashMap::new(),
                    metadata: HashMap::new(),
                }];

                Ok(EvaluationData {
                    queries,
                    ground_truth,
                    system_responses: Vec::new(),
                    context: HashMap::new(),
                })
            }

            fn evaluate_benchmark(
                &self,
                _responses: &[SystemResponse],
            ) -> RragResult<BenchmarkResult> {
                let overall_score = 0.75; // Default score for placeholder

                let mut detailed_scores = HashMap::new();
                detailed_scores.insert("placeholder_score".to_string(), overall_score);

                Ok(BenchmarkResult {
                    benchmark_name: self.name().to_string(),
                    overall_score,
                    detailed_scores,
                    ranking_info: RankingInfo {
                        percentile: 75.0,
                        baseline_comparisons: HashMap::new(),
                        confidence_interval: (0.65, 0.85),
                    },
                    performance_analysis: PerformanceAnalysis {
                        strengths: vec!["Placeholder performance".to_string()],
                        weaknesses: vec!["Needs real implementation".to_string()],
                        category_performance: HashMap::new(),
                        error_patterns: Vec::new(),
                    },
                    failure_cases: Vec::new(),
                })
            }

            fn get_config(&self) -> BenchmarkConfig {
                BenchmarkConfig {
                    name: self.name().to_string(),
                    num_queries: 1,
                    difficulty: $difficulty,
                    domain: $domain,
                    metrics: vec!["placeholder_score".to_string()],
                }
            }
        }
    };
}
630
// Factual accuracy: medium difficulty, general domain (placeholder).
impl_simple_benchmark!(
    FactualAccuracyBenchmark,
    "FactualAccuracy",
    DifficultyLevel::Medium,
    BenchmarkDomain::General
);
// Reasoning: hard difficulty, general domain (placeholder).
impl_simple_benchmark!(
    ReasoningBenchmark,
    "Reasoning",
    DifficultyLevel::Hard,
    BenchmarkDomain::General
);
// Domain-specific: medium difficulty, science domain (placeholder).
impl_simple_benchmark!(
    DomainSpecificBenchmark,
    "DomainSpecific",
    DifficultyLevel::Medium,
    BenchmarkDomain::Science
);
// Multi-hop: hard difficulty, general domain (placeholder).
impl_simple_benchmark!(
    MultiHopBenchmark,
    "MultiHop",
    DifficultyLevel::Hard,
    BenchmarkDomain::General
);
// Conversational: medium difficulty, general domain (placeholder).
impl_simple_benchmark!(
    ConversationalBenchmark,
    "Conversational",
    DifficultyLevel::Medium,
    BenchmarkDomain::General
);
661
#[cfg(test)]
mod tests {
    use super::*;

    // The general-knowledge benchmark should report its name and produce
    // matching counts of queries and ground-truth entries.
    #[test]
    fn test_general_knowledge_benchmark() {
        let benchmark = GeneralKnowledgeBenchmark::new();

        assert_eq!(benchmark.name(), "GeneralKnowledge");

        let test_data = benchmark.generate_test_data().unwrap();
        assert_eq!(test_data.queries.len(), 3);
        assert_eq!(test_data.ground_truth.len(), 3);
    }

    // The evaluator should expose its name and one metric per registered
    // benchmark.
    #[test]
    fn test_benchmark_evaluator() {
        let evaluator = BenchmarkEvaluator::new();

        assert_eq!(evaluator.name(), "Benchmark");
        assert!(!evaluator.supported_metrics().is_empty());
    }

    // A single non-empty answer should yield a positive overall score.
    #[test]
    fn test_benchmark_evaluation() {
        let benchmark = GeneralKnowledgeBenchmark::new();
        let responses = vec![SystemResponse {
            query_id: "test".to_string(),
            retrieved_docs: vec![],
            generated_answer: Some("Test answer".to_string()),
            timing: super::super::SystemTiming {
                total_time_ms: 1000.0,
                retrieval_time_ms: 500.0,
                generation_time_ms: Some(400.0),
                reranking_time_ms: Some(100.0),
            },
            metadata: HashMap::new(),
        }];

        let result = benchmark.evaluate_benchmark(&responses).unwrap();
        assert!(result.overall_score > 0.0);
        assert_eq!(result.benchmark_name, "GeneralKnowledge");
    }
}