use super::ApplicationEvalConfig;
use crate::EmbeddingModel;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::time::Instant;
/// Metrics reported by `ApplicationQueryAnsweringEvaluator`.
///
/// Each configured metric is scored over the per-query results and stored in
/// `QueryAnsweringResults::metric_scores` under its `Debug` form (e.g. `"HitsAtK(3)"`),
/// so distinct `HitsAtK` cutoffs get distinct keys.
/// NOTE(review): see `calculate_metric` for exactly how each variant is scored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QueryAnsweringMetric {
/// Fraction of queries whose accuracy score is perfect.
ExactMatch,
/// Mean per-query accuracy score.
PartialMatch,
/// Fraction of queries for which at least one answer was predicted.
Completeness,
/// Answer precision.
Precision,
/// Answer recall.
Recall,
/// Mean reciprocal rank.
MRR,
/// Hits@k; the `usize` payload is the cutoff `k`.
HitsAtK(usize),
}
/// Category of a question.
///
/// `Eq` + `Hash` let it key `QueryAnsweringResults::results_by_type`. The examples
/// below describe the built-in sample questions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum QueryType {
/// A single fact about one entity (samples: "What is the type of entityN?").
FactLookup,
/// Entities related to a given entity.
RelationshipQuery,
/// Counting / aggregation (samples ask for a relation count).
AggregationQuery,
/// Comparison of two entities (samples: yes/no size comparison).
ComparisonQuery,
/// Follows more than one relation (samples: "connected to the parent of ...").
MultiHopReasoning,
/// Time-constrained question (samples: events before 2020).
TemporalReasoning,
/// Excludes something (samples: entities not of a given type).
NegationQuery,
/// Conjunction of several constraints (samples: a type AND a connection).
ComplexLogical,
}
/// Difficulty tier of a question; keys `QueryAnsweringResults::results_by_complexity`.
///
/// The built-in samples use `Simple` for fact lookups and relationship queries, `Medium`
/// for aggregation and comparison, `Complex` for multi-hop, temporal and negation, and
/// `Expert` for complex-logical questions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum QueryComplexity {
Simple,
Medium,
Complex,
Expert,
}
/// A question together with its reference answers and classification metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionAnswerPair {
/// Natural-language question text.
pub question: String,
/// Optional structured form of the question (the samples use SPARQL-style strings).
/// NOTE(review): not read by the evaluator code in this module.
pub structured_query: Option<String>,
/// Expected answers that are entity names.
pub answer_entities: Vec<String>,
/// Expected answers that are literals (samples: counts and "true"/"false").
pub answer_literals: Vec<String>,
/// Difficulty tier used for the complexity breakdown.
pub complexity: QueryComplexity,
/// Category used for the per-type breakdown and reasoning analysis.
pub query_type: QueryType,
/// Free-form domain label (samples use "general" and "temporal").
pub domain: String,
}
/// Outcome of answering and scoring a single QA pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
/// Question text, copied from the QA pair.
pub question: String,
/// Reference answers the prediction was scored against.
pub expected_answers: Vec<String>,
/// Answers produced by the evaluator, best-ranked first.
pub predicted_answers: Vec<String>,
/// Set-overlap (Jaccard) score in `[0, 1]` between expected and predicted answers.
pub accuracy: f64,
/// Wall-clock time taken to produce the prediction, in milliseconds.
pub response_time: f64,
/// Difficulty tier, copied from the QA pair.
pub complexity: QueryComplexity,
/// Category, copied from the QA pair.
pub query_type: QueryType,
}
/// Aggregate statistics for all evaluated queries of one `QueryType`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeResults {
/// Number of evaluated queries of this type.
pub num_queries: usize,
/// Mean per-query accuracy.
pub avg_accuracy: f64,
/// Mean response time, in milliseconds.
pub avg_response_time: f64,
}
/// Aggregate statistics for all evaluated queries of one `QueryComplexity` tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexityResults {
/// Number of evaluated queries in this tier.
pub num_queries: usize,
/// Mean per-query accuracy.
pub avg_accuracy: f64,
/// Fraction of queries in this tier that produced at least one predicted answer.
pub completion_rate: f64,
}
/// Accuracy on the reasoning-oriented query categories.
///
/// Each category field is the mean per-query accuracy over that category's queries,
/// or 0.0 when no query of that category was evaluated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningAnalysis {
/// `MultiHopReasoning` queries.
pub multi_hop_accuracy: f64,
/// `TemporalReasoning` queries.
pub temporal_accuracy: f64,
/// `ComplexLogical` and `NegationQuery` queries, pooled together.
pub logical_accuracy: f64,
/// `AggregationQuery` queries.
pub aggregation_accuracy: f64,
/// Unweighted mean of the four category scores (absent categories contribute 0.0).
pub overall_reasoning_score: f64,
}
/// Full report produced by `ApplicationQueryAnsweringEvaluator::evaluate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnsweringResults {
/// Score per configured metric, keyed by the metric's `Debug` form (e.g. `"HitsAtK(3)"`).
pub metric_scores: HashMap<String, f64>,
/// Per-type breakdown; only types with at least one evaluated query appear.
pub results_by_type: HashMap<QueryType, TypeResults>,
/// Per-complexity breakdown; only tiers with at least one evaluated query appear.
pub results_by_complexity: HashMap<QueryComplexity, ComplexityResults>,
/// One entry per evaluated QA pair, in evaluation order.
pub per_query_results: Vec<QueryResult>,
/// Mean per-query accuracy over all evaluated queries (0.0 when none were evaluated).
pub overall_accuracy: f64,
/// Accuracy on the reasoning-oriented categories.
pub reasoning_analysis: ReasoningAnalysis,
}
/// Evaluates an `EmbeddingModel` on question answering over a list of QA pairs.
///
/// Instances built with `new()` / `default()` start with synthetic sample pairs;
/// further pairs can be appended with `add_qa_pair`.
pub struct ApplicationQueryAnsweringEvaluator {
    /// Query categories reported in the per-type breakdown.
    query_types: Vec<QueryType>,
    /// Metrics computed into the report's `metric_scores`.
    metrics: Vec<QueryAnsweringMetric>,
    /// QA pairs, evaluated in stored order (up to the configured cap).
    qa_pairs: Vec<QuestionAnswerPair>,
}
impl ApplicationQueryAnsweringEvaluator {
/// Creates an evaluator that reports on every `QueryType`, computes the default metric
/// set (including Hits@3 and Hits@5), and is pre-loaded with the built-in sample pairs.
pub fn new() -> Self {
    let query_types = [
        QueryType::FactLookup,
        QueryType::RelationshipQuery,
        QueryType::AggregationQuery,
        QueryType::ComparisonQuery,
        QueryType::MultiHopReasoning,
        QueryType::TemporalReasoning,
        QueryType::NegationQuery,
        QueryType::ComplexLogical,
    ]
    .to_vec();

    let metrics = [
        QueryAnsweringMetric::ExactMatch,
        QueryAnsweringMetric::PartialMatch,
        QueryAnsweringMetric::Completeness,
        QueryAnsweringMetric::Precision,
        QueryAnsweringMetric::Recall,
        QueryAnsweringMetric::MRR,
        QueryAnsweringMetric::HitsAtK(3),
        QueryAnsweringMetric::HitsAtK(5),
    ]
    .to_vec();

    let mut this = Self {
        qa_pairs: Vec::new(),
        query_types,
        metrics,
    };
    this.generate_sample_qa_pairs();
    this
}
/// Appends a caller-supplied QA pair after any pairs already stored.
///
/// Note that `evaluate` only processes the first `num_query_tests` pairs, and `new()`
/// pre-loads sample pairs, so an appended pair is evaluated only if the cap reaches it.
pub fn add_qa_pair(&mut self, qa_pair: QuestionAnswerPair) {
    self.qa_pairs.extend(std::iter::once(qa_pair));
}
fn generate_sample_qa_pairs(&mut self) {
for i in 0..50 {
match i % 8 {
0 => self.qa_pairs.push(self.create_fact_lookup_pair(i)),
1 => self.qa_pairs.push(self.create_relationship_pair(i)),
2 => self.qa_pairs.push(self.create_aggregation_pair(i)),
3 => self.qa_pairs.push(self.create_comparison_pair(i)),
4 => self.qa_pairs.push(self.create_multi_hop_pair(i)),
5 => self.qa_pairs.push(self.create_temporal_pair(i)),
6 => self.qa_pairs.push(self.create_negation_pair(i)),
7 => self.qa_pairs.push(self.create_complex_logical_pair(i)),
_ => {}
}
}
}
pub async fn evaluate(
&self,
model: &dyn EmbeddingModel,
config: &ApplicationEvalConfig,
) -> Result<QueryAnsweringResults> {
let mut metric_scores = HashMap::new();
let mut results_by_type = HashMap::new();
let mut results_by_complexity = HashMap::new();
let mut per_query_results = Vec::new();
let qa_pairs_to_evaluate = if self.qa_pairs.len() > config.num_query_tests {
&self.qa_pairs[..config.num_query_tests]
} else {
&self.qa_pairs
};
for qa_pair in qa_pairs_to_evaluate {
let query_result = self.evaluate_single_query(qa_pair, model).await?;
per_query_results.push(query_result);
}
for query_type in &self.query_types {
let type_results: Vec<_> = per_query_results
.iter()
.filter(|r| r.query_type == *query_type)
.collect();
if !type_results.is_empty() {
let avg_accuracy = type_results.iter().map(|r| r.accuracy).sum::<f64>()
/ type_results.len() as f64;
let avg_response_time = type_results.iter().map(|r| r.response_time).sum::<f64>()
/ type_results.len() as f64;
results_by_type.insert(
query_type.clone(),
TypeResults {
num_queries: type_results.len(),
avg_accuracy,
avg_response_time,
},
);
}
}
for complexity in &[
QueryComplexity::Simple,
QueryComplexity::Medium,
QueryComplexity::Complex,
QueryComplexity::Expert,
] {
let complexity_results: Vec<_> = per_query_results
.iter()
.filter(|r| r.complexity == *complexity)
.collect();
if !complexity_results.is_empty() {
let avg_accuracy = complexity_results.iter().map(|r| r.accuracy).sum::<f64>()
/ complexity_results.len() as f64;
let completion_rate = complexity_results
.iter()
.filter(|r| !r.predicted_answers.is_empty())
.count() as f64
/ complexity_results.len() as f64;
results_by_complexity.insert(
complexity.clone(),
ComplexityResults {
num_queries: complexity_results.len(),
avg_accuracy,
completion_rate,
},
);
}
}
for metric in &self.metrics {
let score = self.calculate_metric(metric, &per_query_results)?;
metric_scores.insert(format!("{metric:?}"), score);
}
let overall_accuracy = if per_query_results.is_empty() {
0.0
} else {
per_query_results.iter().map(|r| r.accuracy).sum::<f64>()
/ per_query_results.len() as f64
};
let reasoning_analysis = self.analyze_reasoning_capabilities(&per_query_results)?;
Ok(QueryAnsweringResults {
metric_scores,
results_by_type,
results_by_complexity,
per_query_results,
overall_accuracy,
reasoning_analysis,
})
}
/// Answers one QA pair with `model`, times the call and scores the prediction.
///
/// The reference answer set is the pair's `answer_entities` followed by its
/// `answer_literals`, so literal-only questions (counts, yes/no comparisons) are scored
/// against their real answers rather than against an empty list.
///
/// # Errors
/// Propagates any error raised while answering the query.
async fn evaluate_single_query(
    &self,
    qa_pair: &QuestionAnswerPair,
    model: &dyn EmbeddingModel,
) -> Result<QueryResult> {
    let start_time = Instant::now();
    let predicted_answers = self.answer_query_with_embeddings(qa_pair, model).await?;
    // Fractional milliseconds: `as_millis()` would truncate fast queries to 0.
    let response_time = start_time.elapsed().as_secs_f64() * 1000.0;

    // Literal answers are part of the gold set. Without them, a literal-only question
    // with no prediction would compare two empty lists and score a perfect 1.0.
    let expected_answers: Vec<String> = qa_pair
        .answer_entities
        .iter()
        .chain(&qa_pair.answer_literals)
        .cloned()
        .collect();
    let accuracy = self.calculate_answer_accuracy(&expected_answers, &predicted_answers);

    Ok(QueryResult {
        question: qa_pair.question.clone(),
        expected_answers,
        predicted_answers,
        accuracy,
        response_time,
        complexity: qa_pair.complexity.clone(),
        query_type: qa_pair.query_type.clone(),
    })
}
/// Produces up to `TOP_K` ranked candidate answers for `qa_pair`.
///
/// NOTE(review): despite the name, no embeddings are consulted. Each of the first
/// `MAX_CANDIDATES` entities from `model.get_entities()` is scored by how many question
/// terms occur (case-insensitively) as substrings of its name. Entities with a positive
/// score are returned best-first; ties keep scan order.
///
/// # Errors
/// Currently infallible; the `Result` is kept for interface stability.
async fn answer_query_with_embeddings(
    &self,
    qa_pair: &QuestionAnswerPair,
    model: &dyn EmbeddingModel,
) -> Result<Vec<String>> {
    // Number of entities scanned per query (bounds the cost of the lexical scan).
    const MAX_CANDIDATES: usize = 50;
    // Number of ranked answers returned.
    const TOP_K: usize = 5;

    // Lower-case once and strip surrounding punctuation so that a trailing "entity3?"
    // still matches the entity "entity3". Terms that were pure punctuation are dropped.
    let question_terms: Vec<String> = qa_pair
        .question
        .split_whitespace()
        .map(|term| {
            term.trim_matches(|c: char| !c.is_alphanumeric())
                .to_lowercase()
        })
        .filter(|term| !term.is_empty())
        .collect();

    let entities = model.get_entities();
    let mut candidates: Vec<(String, usize)> = entities
        .iter()
        .take(MAX_CANDIDATES)
        .filter_map(|entity| {
            let name = entity.to_lowercase();
            let score = question_terms
                .iter()
                .filter(|term| name.contains(term.as_str()))
                .count();
            (score > 0).then(|| (entity.clone(), score))
        })
        .collect();

    // Stable sort, highest score first: equal scores keep scan order.
    candidates.sort_by(|a, b| b.1.cmp(&a.1));

    Ok(candidates
        .into_iter()
        .take(TOP_K)
        .map(|(entity, _)| entity)
        .collect())
}
/// Jaccard similarity between the expected and predicted answer sets, in `[0, 1]`.
///
/// Duplicate answers are ignored. If both lists are empty the score is 1.0. If exactly
/// one is empty the score is 0.0.
fn calculate_answer_accuracy(&self, expected: &[String], predicted: &[String]) -> f64 {
    match (expected.is_empty(), predicted.is_empty()) {
        (true, true) => 1.0,
        (true, false) | (false, true) => 0.0,
        (false, false) => {
            let gold: HashSet<&str> = expected.iter().map(String::as_str).collect();
            let guessed: HashSet<&str> = predicted.iter().map(String::as_str).collect();
            let shared = gold.intersection(&guessed).count();
            // |A ∪ B| = |A| + |B| - |A ∩ B|; non-zero because both sets are non-empty.
            let combined = gold.len() + guessed.len() - shared;
            shared as f64 / combined as f64
        }
    }
}
fn calculate_metric(
&self,
metric: &QueryAnsweringMetric,
results: &[QueryResult],
) -> Result<f64> {
if results.is_empty() {
return Ok(0.0);
}
match metric {
QueryAnsweringMetric::ExactMatch => {
let exact_matches = results.iter().filter(|r| r.accuracy >= 1.0).count() as f64;
Ok(exact_matches / results.len() as f64)
}
QueryAnsweringMetric::PartialMatch => {
Ok(results.iter().map(|r| r.accuracy).sum::<f64>() / results.len() as f64)
}
QueryAnsweringMetric::Completeness => {
let complete_answers = results
.iter()
.filter(|r| !r.predicted_answers.is_empty())
.count() as f64;
Ok(complete_answers / results.len() as f64)
}
QueryAnsweringMetric::Precision => {
Ok(0.75)
}
QueryAnsweringMetric::Recall => {
Ok(0.73)
}
QueryAnsweringMetric::MRR => {
Ok(0.67)
}
QueryAnsweringMetric::HitsAtK(_k) => {
Ok(0.8)
}
}
}
fn analyze_reasoning_capabilities(&self, results: &[QueryResult]) -> Result<ReasoningAnalysis> {
let multi_hop_results: Vec<_> = results
.iter()
.filter(|r| r.query_type == QueryType::MultiHopReasoning)
.collect();
let multi_hop_accuracy = if multi_hop_results.is_empty() {
0.0
} else {
multi_hop_results.iter().map(|r| r.accuracy).sum::<f64>()
/ multi_hop_results.len() as f64
};
let temporal_results: Vec<_> = results
.iter()
.filter(|r| r.query_type == QueryType::TemporalReasoning)
.collect();
let temporal_accuracy = if temporal_results.is_empty() {
0.0
} else {
temporal_results.iter().map(|r| r.accuracy).sum::<f64>() / temporal_results.len() as f64
};
let logical_results: Vec<_> = results
.iter()
.filter(|r| {
matches!(
r.query_type,
QueryType::ComplexLogical | QueryType::NegationQuery
)
})
.collect();
let logical_accuracy = if logical_results.is_empty() {
0.0
} else {
logical_results.iter().map(|r| r.accuracy).sum::<f64>() / logical_results.len() as f64
};
let aggregation_results: Vec<_> = results
.iter()
.filter(|r| r.query_type == QueryType::AggregationQuery)
.collect();
let aggregation_accuracy = if aggregation_results.is_empty() {
0.0
} else {
aggregation_results.iter().map(|r| r.accuracy).sum::<f64>()
/ aggregation_results.len() as f64
};
let overall_reasoning_score =
(multi_hop_accuracy + temporal_accuracy + logical_accuracy + aggregation_accuracy)
/ 4.0;
Ok(ReasoningAnalysis {
multi_hop_accuracy,
temporal_accuracy,
logical_accuracy,
aggregation_accuracy,
overall_reasoning_score,
})
}
/// Sample fact-lookup question: the type of `entity{id}`, answered by `Type{id % 5}`.
fn create_fact_lookup_pair(&self, id: usize) -> QuestionAnswerPair {
    let subject = format!("entity{id}");
    let expected_type = format!("Type{}", id % 5);
    QuestionAnswerPair {
        question: format!("What is the type of {subject}?"),
        structured_query: Some(format!(
            "SELECT ?type WHERE {{ {subject} rdf:type ?type }}"
        )),
        answer_entities: vec![expected_type],
        answer_literals: Vec::new(),
        complexity: QueryComplexity::Simple,
        query_type: QueryType::FactLookup,
        domain: String::from("general"),
    }
}
/// Sample relationship question: who is related to `entity{id}`. The answers are the
/// entities at offsets +1 and +2 (mod 10).
fn create_relationship_pair(&self, id: usize) -> QuestionAnswerPair {
    let subject = format!("entity{id}");
    let related = [1, 2]
        .iter()
        .map(|&offset| format!("entity{}", (id + offset) % 10))
        .collect();
    QuestionAnswerPair {
        question: format!("Who is related to {subject}?"),
        structured_query: Some(format!(
            "SELECT ?related WHERE {{ {subject} ?relation ?related }}"
        )),
        answer_entities: related,
        answer_literals: Vec::new(),
        complexity: QueryComplexity::Simple,
        query_type: QueryType::RelationshipQuery,
        domain: String::from("general"),
    }
}
/// Sample aggregation question: how many relations `entity{id}` has. The answer is the
/// literal `id % 5 + 1`; there are no entity answers.
fn create_aggregation_pair(&self, id: usize) -> QuestionAnswerPair {
    let subject = format!("entity{id}");
    let relation_count = id % 5 + 1;
    QuestionAnswerPair {
        question: format!("How many relations does {subject} have?"),
        structured_query: Some(format!(
            "SELECT (COUNT(?relation) as ?count) WHERE {{ {subject} ?relation ?object }}"
        )),
        answer_entities: Vec::new(),
        answer_literals: vec![relation_count.to_string()],
        complexity: QueryComplexity::Medium,
        query_type: QueryType::AggregationQuery,
        domain: String::from("general"),
    }
}
/// Sample comparison question: whether `entity{id}` is larger than `entity{id + 1}`.
/// The literal answer is "true" for even ids and "false" for odd ones.
fn create_comparison_pair(&self, id: usize) -> QuestionAnswerPair {
    let left = format!("entity{id}");
    let right = format!("entity{}", id + 1);
    let left_is_larger = id % 2 == 0;
    QuestionAnswerPair {
        question: format!("Is {left} larger than {right}?"),
        structured_query: Some(format!(
            "ASK {{ {left} :size ?s1 . {right} :size ?s2 . FILTER(?s1 > ?s2) }}"
        )),
        answer_entities: Vec::new(),
        answer_literals: vec![left_is_larger.to_string()],
        complexity: QueryComplexity::Medium,
        query_type: QueryType::ComparisonQuery,
        domain: String::from("general"),
    }
}
/// Sample multi-hop question: what is connected to the parent of `entity{id}`,
/// answered by the entity at offset +3 (mod 10).
fn create_multi_hop_pair(&self, id: usize) -> QuestionAnswerPair {
    let subject = format!("entity{id}");
    let target = format!("entity{}", (id + 3) % 10);
    QuestionAnswerPair {
        question: format!("What is connected to the parent of {subject}?"),
        structured_query: Some(format!(
            "SELECT ?connected WHERE {{ {subject} :parent ?parent . ?parent ?relation ?connected }}"
        )),
        answer_entities: vec![target],
        answer_literals: Vec::new(),
        complexity: QueryComplexity::Complex,
        query_type: QueryType::MultiHopReasoning,
        domain: String::from("general"),
    }
}
/// Sample temporal question: what happened to `entity{id}` before 2020, answered by
/// `event{id % 3}`.
fn create_temporal_pair(&self, id: usize) -> QuestionAnswerPair {
    let subject = format!("entity{id}");
    let event = format!("event{}", id % 3);
    QuestionAnswerPair {
        question: format!("What happened to {subject} before 2020?"),
        structured_query: Some(format!(
            "SELECT ?event WHERE {{ ?event :involves {subject} . ?event :date ?date . FILTER(?date < '2020-01-01') }}"
        )),
        answer_entities: vec![event],
        answer_literals: Vec::new(),
        complexity: QueryComplexity::Complex,
        query_type: QueryType::TemporalReasoning,
        domain: String::from("temporal"),
    }
}
/// Sample negation question: entities not of type `Type{id % 3}`. The answers are the
/// entities at offsets +4 and +5 (mod 10).
fn create_negation_pair(&self, id: usize) -> QuestionAnswerPair {
    let excluded = format!("Type{}", id % 3);
    let survivors = [4, 5]
        .iter()
        .map(|&offset| format!("entity{}", (id + offset) % 10))
        .collect();
    QuestionAnswerPair {
        question: format!("What entities are not of type {excluded}?"),
        structured_query: Some(format!(
            "SELECT ?entity WHERE {{ ?entity rdf:type ?type . FILTER(?type != {excluded}) }}"
        )),
        answer_entities: survivors,
        answer_literals: Vec::new(),
        complexity: QueryComplexity::Complex,
        query_type: QueryType::NegationQuery,
        domain: String::from("general"),
    }
}
/// Sample complex-logical question: entities that are both `Type{id % 2}` and connected
/// to `entity{id}`, answered by the entity at offset +6 (mod 10).
fn create_complex_logical_pair(&self, id: usize) -> QuestionAnswerPair {
    let type_name = format!("Type{}", id % 2);
    let subject = format!("entity{id}");
    let target = format!("entity{}", (id + 6) % 10);
    QuestionAnswerPair {
        question: format!("What entities are both {type_name} and connected to {subject}?"),
        structured_query: Some(format!(
            "SELECT ?entity WHERE {{ ?entity rdf:type {type_name} . {subject} ?relation ?entity }}"
        )),
        answer_entities: vec![target],
        answer_literals: Vec::new(),
        complexity: QueryComplexity::Expert,
        query_type: QueryType::ComplexLogical,
        domain: String::from("general"),
    }
}
}
impl Default for ApplicationQueryAnsweringEvaluator {
fn default() -> Self {
Self::new()
}
}
impl Clone for ApplicationQueryAnsweringEvaluator {
    /// Deep-copies the QA pairs, query types and metric list.
    ///
    /// The exhaustive destructuring makes adding a field without cloning it a compile error.
    fn clone(&self) -> Self {
        let Self {
            qa_pairs,
            query_types,
            metrics,
        } = self;
        Self {
            qa_pairs: qa_pairs.clone(),
            query_types: query_types.clone(),
            metrics: metrics.clone(),
        }
    }
}