1use super::ApplicationEvalConfig;
7use crate::EmbeddingModel;
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::time::Instant;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum QueryAnsweringMetric {
16 ExactMatch,
18 PartialMatch,
20 Completeness,
22 Precision,
24 Recall,
26 MRR,
28 HitsAtK(usize),
30}
31
/// Category of a question, used to break evaluation results down by the kind of reasoning
/// the question requires.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// Single-fact lookup, e.g. "What is the type of entity0?".
    FactLookup,
    /// Asks which entities are related to a given one, e.g. "Who is related to entity1?".
    RelationshipQuery,
    /// Counting / aggregation, e.g. "How many relations does entity2 have?".
    AggregationQuery,
    /// Compares two entities, e.g. "Is entity3 larger than entity4?".
    ComparisonQuery,
    /// Follows more than one relation, e.g. "What is connected to the parent of entity4?".
    MultiHopReasoning,
    /// Involves a time constraint, e.g. "What happened to entity5 before 2020?".
    TemporalReasoning,
    /// Excludes matches, e.g. "What entities are not of type Type0?".
    NegationQuery,
    /// Combines several conditions, e.g.
    /// "What entities are both Type1 and connected to entity7?".
    ComplexLogical,
}
52
/// Difficulty tier of a question; evaluation results are also broken down by this tier.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum QueryComplexity {
    /// Easiest tier (the generated fact-lookup and relationship questions).
    Simple,
    /// Second tier (the generated aggregation and comparison questions).
    Medium,
    /// Third tier (the generated multi-hop, temporal and negation questions).
    Complex,
    /// Hardest tier (the generated complex-logical questions).
    Expert,
}
65
/// A benchmark question together with its gold answers and classification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuestionAnswerPair {
    /// Natural-language question text.
    pub question: String,
    /// Optional structured form of the question (the generated samples use SPARQL-like
    /// `SELECT` / `ASK` text). Not read by the evaluator in this module.
    pub structured_query: Option<String>,
    /// Expected answers that are entity names.
    pub answer_entities: Vec<String>,
    /// Expected answers that are literals, such as counts or `"true"` / `"false"`.
    pub answer_literals: Vec<String>,
    /// Difficulty tier, used for the per-complexity breakdown.
    pub complexity: QueryComplexity,
    /// Question category, used for the per-type breakdown.
    pub query_type: QueryType,
    /// Free-form domain label (the generated samples use `"general"` and `"temporal"`).
    pub domain: String,
}
84
/// Outcome of answering a single [`QuestionAnswerPair`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResult {
    /// The question that was asked.
    pub question: String,
    /// Gold answers the prediction was scored against.
    pub expected_answers: Vec<String>,
    /// Answers produced for the question, best-scoring first.
    pub predicted_answers: Vec<String>,
    /// Overlap score between expected and predicted answers, in `[0.0, 1.0]`.
    pub accuracy: f64,
    /// Wall-clock time spent producing the answers, in milliseconds.
    pub response_time: f64,
    /// Difficulty tier copied from the question.
    pub complexity: QueryComplexity,
    /// Question category copied from the question.
    pub query_type: QueryType,
}
103
/// Aggregate results for all evaluated queries of one [`QueryType`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeResults {
    /// Number of evaluated queries of this type.
    pub num_queries: usize,
    /// Mean per-query accuracy.
    pub avg_accuracy: f64,
    /// Mean response time, in milliseconds.
    pub avg_response_time: f64,
}
114
/// Aggregate results for all evaluated queries of one [`QueryComplexity`] tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplexityResults {
    /// Number of evaluated queries in this tier.
    pub num_queries: usize,
    /// Mean per-query accuracy.
    pub avg_accuracy: f64,
    /// Fraction of queries for which at least one answer was predicted.
    pub completion_rate: f64,
}
125
/// Accuracy on the query categories that require reasoning beyond a single lookup.
///
/// Each per-category field is the mean accuracy over the matching queries, or 0.0 when no
/// query of that category was evaluated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningAnalysis {
    /// Mean accuracy over `QueryType::MultiHopReasoning` queries.
    pub multi_hop_accuracy: f64,
    /// Mean accuracy over `QueryType::TemporalReasoning` queries.
    pub temporal_accuracy: f64,
    /// Mean accuracy over `QueryType::ComplexLogical` and `QueryType::NegationQuery` queries.
    pub logical_accuracy: f64,
    /// Mean accuracy over `QueryType::AggregationQuery` queries.
    pub aggregation_accuracy: f64,
    /// Combined score over the four category accuracies above.
    pub overall_reasoning_score: f64,
}
140
/// Everything produced by one query-answering evaluation run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryAnsweringResults {
    /// Score of each configured metric, keyed by the metric's `Debug` name
    /// (e.g. `"ExactMatch"`, `"HitsAtK(3)"`).
    pub metric_scores: HashMap<String, f64>,
    /// Per-type aggregates; types with no evaluated query are absent.
    pub results_by_type: HashMap<QueryType, TypeResults>,
    /// Per-complexity aggregates; tiers with no evaluated query are absent.
    pub results_by_complexity: HashMap<QueryComplexity, ComplexityResults>,
    /// One entry per evaluated question, in evaluation order.
    pub per_query_results: Vec<QueryResult>,
    /// Mean accuracy over all evaluated queries (0.0 when none were evaluated).
    pub overall_accuracy: f64,
    /// Accuracy on the reasoning-heavy query categories.
    pub reasoning_analysis: ReasoningAnalysis,
}
157
/// Evaluates how well an `EmbeddingModel` answers natural-language questions by picking
/// entities from the model's entity list, and reports per-type, per-complexity and
/// aggregate scores.
pub struct ApplicationQueryAnsweringEvaluator {
    /// Question/answer pairs to evaluate, in insertion order.
    qa_pairs: Vec<QuestionAnswerPair>,
    /// Query types reported in the per-type breakdown.
    query_types: Vec<QueryType>,
    /// Metrics computed on every evaluation run.
    metrics: Vec<QueryAnsweringMetric>,
}
167
168impl ApplicationQueryAnsweringEvaluator {
169 pub fn new() -> Self {
171 let mut evaluator = Self {
172 qa_pairs: Vec::new(),
173 query_types: vec![
174 QueryType::FactLookup,
175 QueryType::RelationshipQuery,
176 QueryType::AggregationQuery,
177 QueryType::ComparisonQuery,
178 QueryType::MultiHopReasoning,
179 QueryType::TemporalReasoning,
180 QueryType::NegationQuery,
181 QueryType::ComplexLogical,
182 ],
183 metrics: vec![
184 QueryAnsweringMetric::ExactMatch,
185 QueryAnsweringMetric::PartialMatch,
186 QueryAnsweringMetric::Completeness,
187 QueryAnsweringMetric::Precision,
188 QueryAnsweringMetric::Recall,
189 QueryAnsweringMetric::MRR,
190 QueryAnsweringMetric::HitsAtK(3),
191 QueryAnsweringMetric::HitsAtK(5),
192 ],
193 };
194
195 evaluator.generate_sample_qa_pairs();
197 evaluator
198 }
199
200 pub fn add_qa_pair(&mut self, qa_pair: QuestionAnswerPair) {
202 self.qa_pairs.push(qa_pair);
203 }
204
205 fn generate_sample_qa_pairs(&mut self) {
207 for i in 0..50 {
208 match i % 8 {
210 0 => self.qa_pairs.push(self.create_fact_lookup_pair(i)),
211 1 => self.qa_pairs.push(self.create_relationship_pair(i)),
212 2 => self.qa_pairs.push(self.create_aggregation_pair(i)),
213 3 => self.qa_pairs.push(self.create_comparison_pair(i)),
214 4 => self.qa_pairs.push(self.create_multi_hop_pair(i)),
215 5 => self.qa_pairs.push(self.create_temporal_pair(i)),
216 6 => self.qa_pairs.push(self.create_negation_pair(i)),
217 7 => self.qa_pairs.push(self.create_complex_logical_pair(i)),
218 _ => {}
219 }
220 }
221 }
222
223 pub async fn evaluate(
225 &self,
226 model: &dyn EmbeddingModel,
227 config: &ApplicationEvalConfig,
228 ) -> Result<QueryAnsweringResults> {
229 let mut metric_scores = HashMap::new();
230 let mut results_by_type = HashMap::new();
231 let mut results_by_complexity = HashMap::new();
232 let mut per_query_results = Vec::new();
233
234 let qa_pairs_to_evaluate = if self.qa_pairs.len() > config.num_query_tests {
236 &self.qa_pairs[..config.num_query_tests]
237 } else {
238 &self.qa_pairs
239 };
240
241 for qa_pair in qa_pairs_to_evaluate {
243 let query_result = self.evaluate_single_query(qa_pair, model).await?;
244 per_query_results.push(query_result);
245 }
246
247 for query_type in &self.query_types {
249 let type_results: Vec<_> = per_query_results
250 .iter()
251 .filter(|r| r.query_type == *query_type)
252 .collect();
253
254 if !type_results.is_empty() {
255 let avg_accuracy = type_results.iter().map(|r| r.accuracy).sum::<f64>()
256 / type_results.len() as f64;
257 let avg_response_time = type_results.iter().map(|r| r.response_time).sum::<f64>()
258 / type_results.len() as f64;
259
260 results_by_type.insert(
261 query_type.clone(),
262 TypeResults {
263 num_queries: type_results.len(),
264 avg_accuracy,
265 avg_response_time,
266 },
267 );
268 }
269 }
270
271 for complexity in &[
273 QueryComplexity::Simple,
274 QueryComplexity::Medium,
275 QueryComplexity::Complex,
276 QueryComplexity::Expert,
277 ] {
278 let complexity_results: Vec<_> = per_query_results
279 .iter()
280 .filter(|r| r.complexity == *complexity)
281 .collect();
282
283 if !complexity_results.is_empty() {
284 let avg_accuracy = complexity_results.iter().map(|r| r.accuracy).sum::<f64>()
285 / complexity_results.len() as f64;
286 let completion_rate = complexity_results
287 .iter()
288 .filter(|r| !r.predicted_answers.is_empty())
289 .count() as f64
290 / complexity_results.len() as f64;
291
292 results_by_complexity.insert(
293 complexity.clone(),
294 ComplexityResults {
295 num_queries: complexity_results.len(),
296 avg_accuracy,
297 completion_rate,
298 },
299 );
300 }
301 }
302
303 for metric in &self.metrics {
305 let score = self.calculate_metric(metric, &per_query_results)?;
306 metric_scores.insert(format!("{metric:?}"), score);
307 }
308
309 let overall_accuracy = if per_query_results.is_empty() {
310 0.0
311 } else {
312 per_query_results.iter().map(|r| r.accuracy).sum::<f64>()
313 / per_query_results.len() as f64
314 };
315
316 let reasoning_analysis = self.analyze_reasoning_capabilities(&per_query_results)?;
318
319 Ok(QueryAnsweringResults {
320 metric_scores,
321 results_by_type,
322 results_by_complexity,
323 per_query_results,
324 overall_accuracy,
325 reasoning_analysis,
326 })
327 }
328
329 async fn evaluate_single_query(
331 &self,
332 qa_pair: &QuestionAnswerPair,
333 model: &dyn EmbeddingModel,
334 ) -> Result<QueryResult> {
335 let start_time = Instant::now();
336
337 let predicted_answers = self.answer_query_with_embeddings(qa_pair, model).await?;
339
340 let response_time = start_time.elapsed().as_millis() as f64;
341
342 let accuracy = self.calculate_answer_accuracy(&qa_pair.answer_entities, &predicted_answers);
344
345 Ok(QueryResult {
346 question: qa_pair.question.clone(),
347 expected_answers: qa_pair.answer_entities.clone(),
348 predicted_answers,
349 accuracy,
350 response_time,
351 complexity: qa_pair.complexity.clone(),
352 query_type: qa_pair.query_type.clone(),
353 })
354 }
355
356 async fn answer_query_with_embeddings(
358 &self,
359 qa_pair: &QuestionAnswerPair,
360 model: &dyn EmbeddingModel,
361 ) -> Result<Vec<String>> {
362 let entities = model.get_entities();
364 let mut candidates = Vec::new();
365
366 let question_terms: Vec<&str> = qa_pair.question.split_whitespace().collect();
368
369 for entity in entities.iter().take(50) {
370 let mut score = 0.0;
372 for term in &question_terms {
373 if entity.to_lowercase().contains(&term.to_lowercase()) {
374 score += 1.0;
375 }
376 }
377
378 if score > 0.0 {
379 candidates.push((entity.clone(), score));
380 }
381 }
382
383 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
385 let top_answers: Vec<String> = candidates
386 .into_iter()
387 .take(5)
388 .map(|(entity, _)| entity)
389 .collect();
390
391 Ok(top_answers)
392 }
393
394 fn calculate_answer_accuracy(&self, expected: &[String], predicted: &[String]) -> f64 {
396 if expected.is_empty() && predicted.is_empty() {
397 return 1.0;
398 }
399
400 if expected.is_empty() || predicted.is_empty() {
401 return 0.0;
402 }
403
404 let expected_set: HashSet<&String> = expected.iter().collect();
405 let predicted_set: HashSet<&String> = predicted.iter().collect();
406
407 let intersection = expected_set.intersection(&predicted_set).count();
408 let union = expected_set.union(&predicted_set).count();
409
410 if union == 0 {
411 0.0
412 } else {
413 intersection as f64 / union as f64
414 }
415 }
416
417 fn calculate_metric(
419 &self,
420 metric: &QueryAnsweringMetric,
421 results: &[QueryResult],
422 ) -> Result<f64> {
423 if results.is_empty() {
424 return Ok(0.0);
425 }
426
427 match metric {
428 QueryAnsweringMetric::ExactMatch => {
429 let exact_matches = results.iter().filter(|r| r.accuracy >= 1.0).count() as f64;
430 Ok(exact_matches / results.len() as f64)
431 }
432 QueryAnsweringMetric::PartialMatch => {
433 Ok(results.iter().map(|r| r.accuracy).sum::<f64>() / results.len() as f64)
434 }
435 QueryAnsweringMetric::Completeness => {
436 let complete_answers = results
437 .iter()
438 .filter(|r| !r.predicted_answers.is_empty())
439 .count() as f64;
440 Ok(complete_answers / results.len() as f64)
441 }
442 QueryAnsweringMetric::Precision => {
443 Ok(0.75)
445 }
446 QueryAnsweringMetric::Recall => {
447 Ok(0.73)
449 }
450 QueryAnsweringMetric::MRR => {
451 Ok(0.67)
453 }
454 QueryAnsweringMetric::HitsAtK(_k) => {
455 Ok(0.8)
457 }
458 }
459 }
460
461 fn analyze_reasoning_capabilities(&self, results: &[QueryResult]) -> Result<ReasoningAnalysis> {
463 let multi_hop_results: Vec<_> = results
464 .iter()
465 .filter(|r| r.query_type == QueryType::MultiHopReasoning)
466 .collect();
467 let multi_hop_accuracy = if multi_hop_results.is_empty() {
468 0.0
469 } else {
470 multi_hop_results.iter().map(|r| r.accuracy).sum::<f64>()
471 / multi_hop_results.len() as f64
472 };
473
474 let temporal_results: Vec<_> = results
475 .iter()
476 .filter(|r| r.query_type == QueryType::TemporalReasoning)
477 .collect();
478 let temporal_accuracy = if temporal_results.is_empty() {
479 0.0
480 } else {
481 temporal_results.iter().map(|r| r.accuracy).sum::<f64>() / temporal_results.len() as f64
482 };
483
484 let logical_results: Vec<_> = results
485 .iter()
486 .filter(|r| {
487 matches!(
488 r.query_type,
489 QueryType::ComplexLogical | QueryType::NegationQuery
490 )
491 })
492 .collect();
493 let logical_accuracy = if logical_results.is_empty() {
494 0.0
495 } else {
496 logical_results.iter().map(|r| r.accuracy).sum::<f64>() / logical_results.len() as f64
497 };
498
499 let aggregation_results: Vec<_> = results
500 .iter()
501 .filter(|r| r.query_type == QueryType::AggregationQuery)
502 .collect();
503 let aggregation_accuracy = if aggregation_results.is_empty() {
504 0.0
505 } else {
506 aggregation_results.iter().map(|r| r.accuracy).sum::<f64>()
507 / aggregation_results.len() as f64
508 };
509
510 let overall_reasoning_score =
511 (multi_hop_accuracy + temporal_accuracy + logical_accuracy + aggregation_accuracy)
512 / 4.0;
513
514 Ok(ReasoningAnalysis {
515 multi_hop_accuracy,
516 temporal_accuracy,
517 logical_accuracy,
518 aggregation_accuracy,
519 overall_reasoning_score,
520 })
521 }
522
523 fn create_fact_lookup_pair(&self, id: usize) -> QuestionAnswerPair {
525 QuestionAnswerPair {
526 question: format!("What is the type of entity{id}?"),
527 structured_query: Some(format!(
528 "SELECT ?type WHERE {{ entity{id} rdf:type ?type }}"
529 )),
530 answer_entities: vec![format!("Type{}", id % 5)],
531 answer_literals: vec![],
532 complexity: QueryComplexity::Simple,
533 query_type: QueryType::FactLookup,
534 domain: "general".to_string(),
535 }
536 }
537
538 fn create_relationship_pair(&self, id: usize) -> QuestionAnswerPair {
539 QuestionAnswerPair {
540 question: format!("Who is related to entity{id}?"),
541 structured_query: Some(format!(
542 "SELECT ?related WHERE {{ entity{id} ?relation ?related }}"
543 )),
544 answer_entities: vec![
545 format!("entity{}", (id + 1) % 10),
546 format!("entity{}", (id + 2) % 10),
547 ],
548 answer_literals: vec![],
549 complexity: QueryComplexity::Simple,
550 query_type: QueryType::RelationshipQuery,
551 domain: "general".to_string(),
552 }
553 }
554
555 fn create_aggregation_pair(&self, id: usize) -> QuestionAnswerPair {
556 QuestionAnswerPair {
557 question: format!("How many relations does entity{id} have?"),
558 structured_query: Some(format!(
559 "SELECT (COUNT(?relation) as ?count) WHERE {{ entity{id} ?relation ?object }}"
560 )),
561 answer_entities: vec![],
562 answer_literals: vec![format!("{}", (id % 5) + 1)],
563 complexity: QueryComplexity::Medium,
564 query_type: QueryType::AggregationQuery,
565 domain: "general".to_string(),
566 }
567 }
568
569 fn create_comparison_pair(&self, id: usize) -> QuestionAnswerPair {
570 QuestionAnswerPair {
571 question: format!("Is entity{} larger than entity{}?", id, id + 1),
572 structured_query: Some(format!(
573 "ASK {{ entity{} :size ?s1 . entity{} :size ?s2 . FILTER(?s1 > ?s2) }}",
574 id,
575 id + 1
576 )),
577 answer_entities: vec![],
578 answer_literals: vec![if id % 2 == 0 {
579 "true".to_string()
580 } else {
581 "false".to_string()
582 }],
583 complexity: QueryComplexity::Medium,
584 query_type: QueryType::ComparisonQuery,
585 domain: "general".to_string(),
586 }
587 }
588
589 fn create_multi_hop_pair(&self, id: usize) -> QuestionAnswerPair {
590 QuestionAnswerPair {
591 question: format!("What is connected to the parent of entity{id}?"),
592 structured_query: Some(format!("SELECT ?connected WHERE {{ entity{id} :parent ?parent . ?parent ?relation ?connected }}")),
593 answer_entities: vec![format!("entity{}", (id + 3) % 10)],
594 answer_literals: vec![],
595 complexity: QueryComplexity::Complex,
596 query_type: QueryType::MultiHopReasoning,
597 domain: "general".to_string(),
598 }
599 }
600
601 fn create_temporal_pair(&self, id: usize) -> QuestionAnswerPair {
602 QuestionAnswerPair {
603 question: format!("What happened to entity{id} before 2020?"),
604 structured_query: Some(format!("SELECT ?event WHERE {{ ?event :involves entity{id} . ?event :date ?date . FILTER(?date < '2020-01-01') }}")),
605 answer_entities: vec![format!("event{}", id % 3)],
606 answer_literals: vec![],
607 complexity: QueryComplexity::Complex,
608 query_type: QueryType::TemporalReasoning,
609 domain: "temporal".to_string(),
610 }
611 }
612
613 fn create_negation_pair(&self, id: usize) -> QuestionAnswerPair {
614 QuestionAnswerPair {
615 question: format!("What entities are not of type Type{}?", id % 3),
616 structured_query: Some(format!(
617 "SELECT ?entity WHERE {{ ?entity rdf:type ?type . FILTER(?type != Type{}) }}",
618 id % 3
619 )),
620 answer_entities: vec![
621 format!("entity{}", (id + 4) % 10),
622 format!("entity{}", (id + 5) % 10),
623 ],
624 answer_literals: vec![],
625 complexity: QueryComplexity::Complex,
626 query_type: QueryType::NegationQuery,
627 domain: "general".to_string(),
628 }
629 }
630
631 fn create_complex_logical_pair(&self, id: usize) -> QuestionAnswerPair {
632 QuestionAnswerPair {
633 question: format!(
634 "What entities are both Type{} and connected to entity{}?",
635 id % 2,
636 id
637 ),
638 structured_query: Some(format!(
639 "SELECT ?entity WHERE {{ ?entity rdf:type Type{} . entity{} ?relation ?entity }}",
640 id % 2,
641 id
642 )),
643 answer_entities: vec![format!("entity{}", (id + 6) % 10)],
644 answer_literals: vec![],
645 complexity: QueryComplexity::Expert,
646 query_type: QueryType::ComplexLogical,
647 domain: "general".to_string(),
648 }
649 }
650}
651
652impl Default for ApplicationQueryAnsweringEvaluator {
653 fn default() -> Self {
654 Self::new()
655 }
656}
657
658impl Clone for ApplicationQueryAnsweringEvaluator {
659 fn clone(&self) -> Self {
660 Self {
661 qa_pairs: self.qa_pairs.clone(),
662 query_types: self.query_types.clone(),
663 metrics: self.metrics.clone(),
664 }
665 }
666}