use super::{
EvaluationData, EvaluationMetadata, EvaluationResult, EvaluationSummary, Evaluator,
EvaluatorConfig, EvaluatorPerformance, GroundTruth, PerformanceStats, SystemResponse,
TestQuery,
};
use crate::RragResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
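/// Evaluator that runs a suite of standard RAG benchmarks (general
/// knowledge, factual accuracy, reasoning, domain-specific, multi-hop,
/// and conversational) and aggregates their scores.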
pub struct BenchmarkEvaluator {
benchmarks: Vec<Box<dyn Benchmark>>,
}
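/// A single benchmark suite. Implementors supply their own test data and
/// score a set of system responses against it. `Send + Sync` is required
/// so benchmarks can be shared across threads.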
pub trait Benchmark: Send + Sync {
fn name(&self) -> &str;
fn generate_test_data(&self) -> RragResult<EvaluationData>;
fn evaluate_benchmark(
&self,
system_responses: &[SystemResponse],
) -> RragResult<BenchmarkResult>;
fn get_config(&self) -> BenchmarkConfig;
}
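/// Static description of a benchmark: its size, difficulty, domain, and
/// the metrics it reports.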
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
pub name: String,
pub num_queries: usize,
pub difficulty: DifficultyLevel,
pub domain: BenchmarkDomain,
pub metrics: Vec<String>,
}
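/// Relative difficulty of a benchmark's queries.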
#[derive(Debug, Clone)]
pub enum DifficultyLevel {
Easy,
Medium,
Hard,
Expert,
}
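/// Subject domain a benchmark targets.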
#[derive(Debug, Clone)]
pub enum BenchmarkDomain {
General,
Science,
Technology,
History,
Literature,
Medicine,
Law,
Finance,
Education,
News,
}
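/// Outcome of running one benchmark: an overall score plus per-metric
/// scores, ranking context, a qualitative analysis, and notable failures.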
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
pub benchmark_name: String,
pub overall_score: f32,
pub detailed_scores: HashMap<String, f32>,
pub ranking_info: RankingInfo,
pub performance_analysis: PerformanceAnalysis,
pub failure_cases: Vec<FailureCase>,
}
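/// How a score ranks against baselines, with an uncertainty estimate.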
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankingInfo {
pub percentile: f32,
pub baseline_comparisons: HashMap<String, f32>,
pub confidence_interval: (f32, f32),
}
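/// Qualitative breakdown of where the system does well and where it
/// struggles, including recurring error patterns.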
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceAnalysis {
pub strengths: Vec<String>,
pub weaknesses: Vec<String>,
pub category_performance: HashMap<String, f32>,
pub error_patterns: Vec<ErrorPattern>,
}
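/// A recurring class of errors, with its frequency and suggested fixes.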
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorPattern {
pub description: String,
pub frequency: f32,
pub example_queries: Vec<String>,
pub improvements: Vec<String>,
}
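/// A single query the system got wrong, with expected vs. actual output.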
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailureCase {
pub query: String,
pub expected: String,
pub actual: String,
pub failure_reason: String,
pub severity: FailureSeverity,
}
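/// Severity of an individual failure case.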
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FailureSeverity {
Low,
Medium,
High,
Critical,
}
impl BenchmarkEvaluator {
pub fn new() -> Self {
let mut evaluator = Self {
benchmarks: Vec::new(),
};
evaluator.initialize_benchmarks();
evaluator
}
fn initialize_benchmarks(&mut self) {
self.benchmarks
.push(Box::new(GeneralKnowledgeBenchmark::new()));
self.benchmarks
.push(Box::new(FactualAccuracyBenchmark::new()));
self.benchmarks.push(Box::new(ReasoningBenchmark::new()));
self.benchmarks
.push(Box::new(DomainSpecificBenchmark::new()));
self.benchmarks.push(Box::new(MultiHopBenchmark::new()));
self.benchmarks
.push(Box::new(ConversationalBenchmark::new()));
}
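    /// Runs every registered benchmark against mock system responses and
    /// collects the results; a benchmark that errors is logged and skipped
    /// rather than aborting the run. The method is `async` to fit an async
    /// caller API, though the current mock path performs no awaiting.
    ///
    /// A minimal usage sketch (assuming an async runtime such as Tokio):
    ///
    /// ```ignore
    /// let evaluator = BenchmarkEvaluator::new();
    /// let results = evaluator.run_all_benchmarks().await?;
    /// for (name, result) in &results {
    ///     println!("{}: {:.2}", name, result.overall_score);
    /// }
    /// ```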
pub async fn run_all_benchmarks(&self) -> RragResult<HashMap<String, BenchmarkResult>> {
let mut results = HashMap::new();
for benchmark in &self.benchmarks {
println!("Running benchmark: {}", benchmark.name());
let test_data = benchmark.generate_test_data()?;
let system_responses = self.create_mock_responses(&test_data);
match benchmark.evaluate_benchmark(&system_responses) {
Ok(result) => {
results.insert(benchmark.name().to_string(), result);
println!("✅ {} completed", benchmark.name());
}
Err(e) => {
eprintln!("❌ {} failed: {}", benchmark.name(), e);
}
}
}
Ok(results)
}
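    /// Builds synthetic responses for a benchmark's queries so the suite can
    /// run end-to-end without a live system. Scores are fixed stand-ins and
    /// timings vary deterministically with the query id length.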
fn create_mock_responses(&self, test_data: &EvaluationData) -> Vec<SystemResponse> {
use super::{RetrievedDocument, SystemTiming};
test_data
.queries
.iter()
.map(|query| SystemResponse {
query_id: query.id.clone(),
retrieved_docs: vec![RetrievedDocument {
doc_id: format!("doc_{}", query.id),
content: format!("Relevant content for: {}", query.query),
score: 0.8,
rank: 0,
metadata: HashMap::new(),
}],
generated_answer: Some(format!("Generated answer for: {}", query.query)),
timing: SystemTiming {
total_time_ms: 1000.0 + (query.id.len() as f32 * 100.0),
retrieval_time_ms: 600.0,
generation_time_ms: Some(300.0),
reranking_time_ms: Some(100.0),
},
metadata: HashMap::new(),
})
.collect()
}
}
impl Evaluator for BenchmarkEvaluator {
fn name(&self) -> &str {
"Benchmark"
}
fn evaluate(&self, data: &EvaluationData) -> RragResult<EvaluationResult> {
let start_time = std::time::Instant::now();
let mut overall_scores = HashMap::new();
        // Benchmarks report only aggregate scores, so no per-query results
        // are produced here.
        let per_query_results = Vec::new();
for benchmark in &self.benchmarks {
match benchmark.evaluate_benchmark(&data.system_responses) {
Ok(result) => {
overall_scores.insert(benchmark.name().to_string(), result.overall_score);
}
Err(e) => {
eprintln!("Warning: Benchmark {} failed: {}", benchmark.name(), e);
}
}
}
let overall_score = if overall_scores.is_empty() {
0.0
} else {
overall_scores.values().sum::<f32>() / overall_scores.len() as f32
};
overall_scores.insert("overall_benchmark_score".to_string(), overall_score);
let total_time = start_time.elapsed().as_millis() as f32;
let insights = self.generate_insights(&overall_scores);
let recommendations = self.generate_recommendations(&overall_scores);
Ok(EvaluationResult {
id: uuid::Uuid::new_v4().to_string(),
evaluation_type: "Benchmark".to_string(),
overall_scores: overall_scores.clone(),
per_query_results,
summary: EvaluationSummary {
total_queries: data.queries.len(),
avg_scores: overall_scores.clone(),
std_deviations: HashMap::new(),
                performance_stats: PerformanceStats {
                    avg_eval_time_ms: total_time / data.queries.len().max(1) as f32,
                    total_eval_time_ms: total_time,
                    // Rough static estimate; memory profiling is not wired in.
                    peak_memory_usage_mb: 200.0,
                    // Guard against division by zero for near-instant runs.
                    throughput_qps: if total_time > 0.0 {
                        data.queries.len() as f32 / (total_time / 1000.0)
                    } else {
                        0.0
                    },
                },
insights,
recommendations,
},
metadata: EvaluationMetadata {
timestamp: chrono::Utc::now(),
evaluation_version: "1.0.0".to_string(),
system_config: HashMap::new(),
                // Captures the full process environment; this may include
                // sensitive variables if results are persisted or shared.
                environment: std::env::vars().collect(),
git_commit: None,
},
})
}
fn supported_metrics(&self) -> Vec<String> {
self.benchmarks
.iter()
.map(|b| b.name().to_string())
.collect()
}
fn get_config(&self) -> EvaluatorConfig {
EvaluatorConfig {
name: "Benchmark".to_string(),
version: "1.0.0".to_string(),
metrics: self.supported_metrics(),
            // Nominal figures; these values are not measured at runtime.
            performance: EvaluatorPerformance {
avg_time_per_sample_ms: 500.0,
memory_usage_mb: 200.0,
accuracy: 0.95,
},
}
}
}
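// Insight and recommendation helpers derived from aggregated benchmark scores.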
impl BenchmarkEvaluator {
fn generate_insights(&self, scores: &HashMap<String, f32>) -> Vec<String> {
let mut insights = Vec::new();
if let Some(&overall_score) = scores.get("overall_benchmark_score") {
if overall_score > 0.8 {
insights.push("🏆 Excellent performance across benchmarks".to_string());
} else if overall_score < 0.6 {
insights.push("⚠️ Below-average performance on standard benchmarks".to_string());
}
}
if let Some(&general_score) = scores.get("GeneralKnowledge") {
if general_score < 0.6 {
insights.push("📚 General knowledge capabilities need improvement".to_string());
}
}
if let Some(&factual_score) = scores.get("FactualAccuracy") {
if factual_score < 0.7 {
insights.push("📊 Factual accuracy is below acceptable threshold".to_string());
}
}
if let Some(&reasoning_score) = scores.get("Reasoning") {
if reasoning_score < 0.6 {
insights.push("🧠 Reasoning capabilities require enhancement".to_string());
}
}
insights
}
fn generate_recommendations(&self, scores: &HashMap<String, f32>) -> Vec<String> {
let mut recommendations = Vec::new();
if let Some(&general_score) = scores.get("GeneralKnowledge") {
if general_score < 0.6 {
recommendations.push(
"📖 Expand knowledge base with diverse, high-quality sources".to_string(),
);
recommendations.push(
"🔍 Improve retrieval to find relevant background information".to_string(),
);
}
}
if let Some(&factual_score) = scores.get("FactualAccuracy") {
if factual_score < 0.7 {
recommendations
.push("✅ Implement fact-checking and verification mechanisms".to_string());
recommendations.push("📑 Use more authoritative and recent sources".to_string());
}
}
if let Some(&reasoning_score) = scores.get("Reasoning") {
if reasoning_score < 0.6 {
recommendations.push("🔄 Implement chain-of-thought reasoning prompts".to_string());
recommendations.push("🧩 Add step-by-step problem decomposition".to_string());
}
}
recommendations
}
}
impl Default for BenchmarkEvaluator {
fn default() -> Self {
Self::new()
}
}
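/// Benchmark of basic factual and conceptual questions with known answers.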
struct GeneralKnowledgeBenchmark;
impl GeneralKnowledgeBenchmark {
fn new() -> Self {
Self
}
}
impl Benchmark for GeneralKnowledgeBenchmark {
fn name(&self) -> &str {
"GeneralKnowledge"
}
fn generate_test_data(&self) -> RragResult<EvaluationData> {
let queries = vec![
TestQuery {
id: "gk_1".to_string(),
query: "What is the capital of France?".to_string(),
query_type: Some("factual".to_string()),
metadata: HashMap::new(),
},
TestQuery {
id: "gk_2".to_string(),
query: "Who wrote Romeo and Juliet?".to_string(),
query_type: Some("factual".to_string()),
metadata: HashMap::new(),
},
TestQuery {
id: "gk_3".to_string(),
query: "What is photosynthesis?".to_string(),
query_type: Some("conceptual".to_string()),
metadata: HashMap::new(),
},
];
let ground_truth = vec![
GroundTruth {
query_id: "gk_1".to_string(),
relevant_docs: vec!["france_capital".to_string()],
expected_answer: Some("Paris".to_string()),
relevance_judgments: HashMap::new(),
metadata: HashMap::new(),
},
GroundTruth {
query_id: "gk_2".to_string(),
relevant_docs: vec!["shakespeare_works".to_string()],
expected_answer: Some("William Shakespeare".to_string()),
relevance_judgments: HashMap::new(),
metadata: HashMap::new(),
},
GroundTruth {
query_id: "gk_3".to_string(),
relevant_docs: vec!["biology_photosynthesis".to_string()],
expected_answer: Some(
"Process by which plants convert light energy into chemical energy".to_string(),
),
relevance_judgments: HashMap::new(),
metadata: HashMap::new(),
},
];
Ok(EvaluationData {
queries,
ground_truth,
system_responses: Vec::new(),
context: HashMap::new(),
})
}
fn evaluate_benchmark(&self, responses: &[SystemResponse]) -> RragResult<BenchmarkResult> {
        // Placeholder scoring: any non-empty generated answer counts as correct;
        // a real implementation would compare answers against the ground truth.
        let mut correct_answers = 0;
let total_questions = responses.len();
for response in responses {
if let Some(answer) = &response.generated_answer {
if !answer.trim().is_empty() {
correct_answers += 1;
}
}
}
let overall_score = if total_questions > 0 {
correct_answers as f32 / total_questions as f32
} else {
0.0
};
let mut detailed_scores = HashMap::new();
detailed_scores.insert("accuracy".to_string(), overall_score);
detailed_scores.insert("coverage".to_string(), 1.0);
Ok(BenchmarkResult {
benchmark_name: self.name().to_string(),
overall_score,
detailed_scores,
            ranking_info: RankingInfo {
                // Naive mapping: the raw score doubles as a percentile until
                // real baseline distributions are available.
                percentile: overall_score * 100.0,
                baseline_comparisons: HashMap::new(),
                // Clamp the ad-hoc ±0.1 interval to the valid [0, 1] range.
                confidence_interval: (
                    (overall_score - 0.1).max(0.0),
                    (overall_score + 0.1).min(1.0),
                ),
            },
performance_analysis: PerformanceAnalysis {
strengths: vec!["Good response generation".to_string()],
weaknesses: if overall_score < 0.7 {
vec!["Factual accuracy needs improvement".to_string()]
} else {
vec![]
},
category_performance: HashMap::new(),
error_patterns: Vec::new(),
},
failure_cases: Vec::new(),
})
}
fn get_config(&self) -> BenchmarkConfig {
BenchmarkConfig {
name: self.name().to_string(),
num_queries: 3,
difficulty: DifficultyLevel::Easy,
domain: BenchmarkDomain::General,
metrics: vec!["accuracy".to_string(), "coverage".to_string()],
}
}
}
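/// Generates a stub benchmark with a single test query and a fixed
/// placeholder score of 0.75; real scoring logic is still to be implemented.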
macro_rules! impl_simple_benchmark {
($name:ident, $benchmark_name:literal, $difficulty:expr, $domain:expr) => {
struct $name;
impl $name {
fn new() -> Self {
Self
}
}
impl Benchmark for $name {
fn name(&self) -> &str {
$benchmark_name
}
fn generate_test_data(&self) -> RragResult<EvaluationData> {
let queries = vec![TestQuery {
id: format!("{}_1", $benchmark_name.to_lowercase()),
query: format!("Sample query for {}", $benchmark_name),
query_type: Some("test".to_string()),
metadata: HashMap::new(),
}];
let ground_truth = vec![GroundTruth {
query_id: format!("{}_1", $benchmark_name.to_lowercase()),
relevant_docs: vec!["test_doc".to_string()],
expected_answer: Some("Test answer".to_string()),
relevance_judgments: HashMap::new(),
metadata: HashMap::new(),
}];
Ok(EvaluationData {
queries,
ground_truth,
system_responses: Vec::new(),
context: HashMap::new(),
})
}
fn evaluate_benchmark(
&self,
_responses: &[SystemResponse],
) -> RragResult<BenchmarkResult> {
                // Fixed placeholder score until real scoring is implemented.
                let overall_score = 0.75;
let mut detailed_scores = HashMap::new();
detailed_scores.insert("placeholder_score".to_string(), overall_score);
Ok(BenchmarkResult {
benchmark_name: self.name().to_string(),
overall_score,
detailed_scores,
ranking_info: RankingInfo {
percentile: 75.0,
baseline_comparisons: HashMap::new(),
confidence_interval: (0.65, 0.85),
},
performance_analysis: PerformanceAnalysis {
strengths: vec!["Placeholder performance".to_string()],
weaknesses: vec!["Needs real implementation".to_string()],
category_performance: HashMap::new(),
error_patterns: Vec::new(),
},
failure_cases: Vec::new(),
})
}
fn get_config(&self) -> BenchmarkConfig {
BenchmarkConfig {
name: self.name().to_string(),
num_queries: 1,
difficulty: $difficulty,
domain: $domain,
metrics: vec!["placeholder_score".to_string()],
}
}
}
};
}
impl_simple_benchmark!(
FactualAccuracyBenchmark,
"FactualAccuracy",
DifficultyLevel::Medium,
BenchmarkDomain::General
);
impl_simple_benchmark!(
ReasoningBenchmark,
"Reasoning",
DifficultyLevel::Hard,
BenchmarkDomain::General
);
impl_simple_benchmark!(
DomainSpecificBenchmark,
"DomainSpecific",
DifficultyLevel::Medium,
BenchmarkDomain::Science
);
impl_simple_benchmark!(
MultiHopBenchmark,
"MultiHop",
DifficultyLevel::Hard,
BenchmarkDomain::General
);
impl_simple_benchmark!(
ConversationalBenchmark,
"Conversational",
DifficultyLevel::Medium,
BenchmarkDomain::General
);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_general_knowledge_benchmark() {
let benchmark = GeneralKnowledgeBenchmark::new();
assert_eq!(benchmark.name(), "GeneralKnowledge");
let test_data = benchmark.generate_test_data().unwrap();
assert_eq!(test_data.queries.len(), 3);
assert_eq!(test_data.ground_truth.len(), 3);
}
#[test]
fn test_benchmark_evaluator() {
let evaluator = BenchmarkEvaluator::new();
assert_eq!(evaluator.name(), "Benchmark");
assert!(!evaluator.supported_metrics().is_empty());
}
#[test]
fn test_benchmark_evaluation() {
let benchmark = GeneralKnowledgeBenchmark::new();
let responses = vec![SystemResponse {
query_id: "test".to_string(),
retrieved_docs: vec![],
generated_answer: Some("Test answer".to_string()),
timing: super::super::SystemTiming {
total_time_ms: 1000.0,
retrieval_time_ms: 500.0,
generation_time_ms: Some(400.0),
reranking_time_ms: Some(100.0),
},
metadata: HashMap::new(),
}];
let result = benchmark.evaluate_benchmark(&responses).unwrap();
assert!(result.overall_score > 0.0);
assert_eq!(result.benchmark_name, "GeneralKnowledge");
}
}