oxirs_embed/application_tasks/
query_answering.rs

//! Query answering evaluation module.
//!
//! This module evaluates question-answering tasks backed by embedding models.
//! It reports accuracy and answer completeness, broken down by query type and
//! query complexity, together with a reasoning-capability analysis.
5
6use super::ApplicationEvalConfig;
7use crate::EmbeddingModel;
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use std::time::Instant;
12
13/// Query answering metrics
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub enum QueryAnsweringMetric {
16    /// Exact match accuracy
17    ExactMatch,
18    /// Partial match accuracy
19    PartialMatch,
20    /// Answer completeness
21    Completeness,
22    /// Precision of answers
23    Precision,
24    /// Recall of answers
25    Recall,
26    /// Mean Reciprocal Rank
27    MRR,
28    /// Hits at K
29    HitsAtK(usize),
30}
31
32/// Query types for evaluation
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
34pub enum QueryType {
35    /// Simple fact lookup
36    FactLookup,
37    /// Relationship queries
38    RelationshipQuery,
39    /// Aggregation queries
40    AggregationQuery,
41    /// Comparison queries
42    ComparisonQuery,
43    /// Multi-hop reasoning
44    MultiHopReasoning,
45    /// Temporal reasoning
46    TemporalReasoning,
47    /// Negation queries
48    NegationQuery,
49    /// Complex logical queries
50    ComplexLogical,
51}
52
53/// Query complexity levels
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
55pub enum QueryComplexity {
56    /// Simple queries
57    Simple,
58    /// Medium complexity
59    Medium,
60    /// Complex queries
61    Complex,
62    /// Expert-level queries
63    Expert,
64}
65
66/// Question-answer pair for evaluation
67#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct QuestionAnswerPair {
69    /// Natural language question
70    pub question: String,
71    /// Structured query (SPARQL, etc.)
72    pub structured_query: Option<String>,
73    /// Expected answer entities
74    pub answer_entities: Vec<String>,
75    /// Expected answer literals
76    pub answer_literals: Vec<String>,
77    /// Query complexity
78    pub complexity: QueryComplexity,
79    /// Query type
80    pub query_type: QueryType,
81    /// Domain/category
82    pub domain: String,
83}
84
85/// Single query result
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct QueryResult {
88    /// Question text
89    pub question: String,
90    /// Expected answers
91    pub expected_answers: Vec<String>,
92    /// Predicted answers
93    pub predicted_answers: Vec<String>,
94    /// Accuracy score
95    pub accuracy: f64,
96    /// Response time (milliseconds)
97    pub response_time: f64,
98    /// Query complexity
99    pub complexity: QueryComplexity,
100    /// Query type
101    pub query_type: QueryType,
102}
103
104/// Results by query type
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct TypeResults {
107    /// Number of queries of this type
108    pub num_queries: usize,
109    /// Average accuracy
110    pub avg_accuracy: f64,
111    /// Average response time
112    pub avg_response_time: f64,
113}
114
115/// Results by complexity level
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct ComplexityResults {
118    /// Number of queries at this complexity
119    pub num_queries: usize,
120    /// Average accuracy
121    pub avg_accuracy: f64,
122    /// Completion rate
123    pub completion_rate: f64,
124}
125
126/// Reasoning capability analysis
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct ReasoningAnalysis {
129    /// Multi-hop reasoning accuracy
130    pub multi_hop_accuracy: f64,
131    /// Temporal reasoning accuracy
132    pub temporal_accuracy: f64,
133    /// Logical reasoning accuracy
134    pub logical_accuracy: f64,
135    /// Aggregation accuracy
136    pub aggregation_accuracy: f64,
137    /// Overall reasoning score
138    pub overall_reasoning_score: f64,
139}
140
141/// Query answering evaluation results
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct QueryAnsweringResults {
144    /// Metric scores
145    pub metric_scores: HashMap<String, f64>,
146    /// Results by query type
147    pub results_by_type: HashMap<QueryType, TypeResults>,
148    /// Results by complexity
149    pub results_by_complexity: HashMap<QueryComplexity, ComplexityResults>,
150    /// Per-query results
151    pub per_query_results: Vec<QueryResult>,
152    /// Overall accuracy
153    pub overall_accuracy: f64,
154    /// Reasoning analysis
155    pub reasoning_analysis: ReasoningAnalysis,
156}
157
158/// Query answering evaluator
159pub struct ApplicationQueryAnsweringEvaluator {
160    /// Question-answer pairs
161    qa_pairs: Vec<QuestionAnswerPair>,
162    /// Query types to evaluate
163    query_types: Vec<QueryType>,
164    /// Evaluation metrics
165    metrics: Vec<QueryAnsweringMetric>,
166}
167
168impl ApplicationQueryAnsweringEvaluator {
169    /// Create a new query answering evaluator
170    pub fn new() -> Self {
171        let mut evaluator = Self {
172            qa_pairs: Vec::new(),
173            query_types: vec![
174                QueryType::FactLookup,
175                QueryType::RelationshipQuery,
176                QueryType::AggregationQuery,
177                QueryType::ComparisonQuery,
178                QueryType::MultiHopReasoning,
179                QueryType::TemporalReasoning,
180                QueryType::NegationQuery,
181                QueryType::ComplexLogical,
182            ],
183            metrics: vec![
184                QueryAnsweringMetric::ExactMatch,
185                QueryAnsweringMetric::PartialMatch,
186                QueryAnsweringMetric::Completeness,
187                QueryAnsweringMetric::Precision,
188                QueryAnsweringMetric::Recall,
189                QueryAnsweringMetric::MRR,
190                QueryAnsweringMetric::HitsAtK(3),
191                QueryAnsweringMetric::HitsAtK(5),
192            ],
193        };
194
195        // Generate sample QA pairs
196        evaluator.generate_sample_qa_pairs();
197        evaluator
198    }
199
200    /// Add question-answer pair
201    pub fn add_qa_pair(&mut self, qa_pair: QuestionAnswerPair) {
202        self.qa_pairs.push(qa_pair);
203    }
204
205    /// Generate sample QA pairs for testing
206    fn generate_sample_qa_pairs(&mut self) {
207        for i in 0..50 {
208            // Generate different types of queries
209            match i % 8 {
210                0 => self.qa_pairs.push(self.create_fact_lookup_pair(i)),
211                1 => self.qa_pairs.push(self.create_relationship_pair(i)),
212                2 => self.qa_pairs.push(self.create_aggregation_pair(i)),
213                3 => self.qa_pairs.push(self.create_comparison_pair(i)),
214                4 => self.qa_pairs.push(self.create_multi_hop_pair(i)),
215                5 => self.qa_pairs.push(self.create_temporal_pair(i)),
216                6 => self.qa_pairs.push(self.create_negation_pair(i)),
217                7 => self.qa_pairs.push(self.create_complex_logical_pair(i)),
218                _ => {}
219            }
220        }
221    }
222
223    /// Evaluate query answering performance
224    pub async fn evaluate(
225        &self,
226        model: &dyn EmbeddingModel,
227        config: &ApplicationEvalConfig,
228    ) -> Result<QueryAnsweringResults> {
229        let mut metric_scores = HashMap::new();
230        let mut results_by_type = HashMap::new();
231        let mut results_by_complexity = HashMap::new();
232        let mut per_query_results = Vec::new();
233
234        // Sample QA pairs for evaluation
235        let qa_pairs_to_evaluate = if self.qa_pairs.len() > config.num_query_tests {
236            &self.qa_pairs[..config.num_query_tests]
237        } else {
238            &self.qa_pairs
239        };
240
241        // Evaluate each QA pair
242        for qa_pair in qa_pairs_to_evaluate {
243            let query_result = self.evaluate_single_query(qa_pair, model).await?;
244            per_query_results.push(query_result);
245        }
246
247        // Aggregate results by type
248        for query_type in &self.query_types {
249            let type_results: Vec<_> = per_query_results
250                .iter()
251                .filter(|r| r.query_type == *query_type)
252                .collect();
253
254            if !type_results.is_empty() {
255                let avg_accuracy = type_results.iter().map(|r| r.accuracy).sum::<f64>()
256                    / type_results.len() as f64;
257                let avg_response_time = type_results.iter().map(|r| r.response_time).sum::<f64>()
258                    / type_results.len() as f64;
259
260                results_by_type.insert(
261                    query_type.clone(),
262                    TypeResults {
263                        num_queries: type_results.len(),
264                        avg_accuracy,
265                        avg_response_time,
266                    },
267                );
268            }
269        }
270
271        // Aggregate results by complexity
272        for complexity in &[
273            QueryComplexity::Simple,
274            QueryComplexity::Medium,
275            QueryComplexity::Complex,
276            QueryComplexity::Expert,
277        ] {
278            let complexity_results: Vec<_> = per_query_results
279                .iter()
280                .filter(|r| r.complexity == *complexity)
281                .collect();
282
283            if !complexity_results.is_empty() {
284                let avg_accuracy = complexity_results.iter().map(|r| r.accuracy).sum::<f64>()
285                    / complexity_results.len() as f64;
286                let completion_rate = complexity_results
287                    .iter()
288                    .filter(|r| !r.predicted_answers.is_empty())
289                    .count() as f64
290                    / complexity_results.len() as f64;
291
292                results_by_complexity.insert(
293                    complexity.clone(),
294                    ComplexityResults {
295                        num_queries: complexity_results.len(),
296                        avg_accuracy,
297                        completion_rate,
298                    },
299                );
300            }
301        }
302
303        // Calculate overall metrics
304        for metric in &self.metrics {
305            let score = self.calculate_metric(metric, &per_query_results)?;
306            metric_scores.insert(format!("{metric:?}"), score);
307        }
308
309        let overall_accuracy = if per_query_results.is_empty() {
310            0.0
311        } else {
312            per_query_results.iter().map(|r| r.accuracy).sum::<f64>()
313                / per_query_results.len() as f64
314        };
315
316        // Analyze reasoning capabilities
317        let reasoning_analysis = self.analyze_reasoning_capabilities(&per_query_results)?;
318
319        Ok(QueryAnsweringResults {
320            metric_scores,
321            results_by_type,
322            results_by_complexity,
323            per_query_results,
324            overall_accuracy,
325            reasoning_analysis,
326        })
327    }
328
329    /// Evaluate a single query
330    async fn evaluate_single_query(
331        &self,
332        qa_pair: &QuestionAnswerPair,
333        model: &dyn EmbeddingModel,
334    ) -> Result<QueryResult> {
335        let start_time = Instant::now();
336
337        // Simulate query answering using embeddings
338        let predicted_answers = self.answer_query_with_embeddings(qa_pair, model).await?;
339
340        let response_time = start_time.elapsed().as_millis() as f64;
341
342        // Calculate accuracy
343        let accuracy = self.calculate_answer_accuracy(&qa_pair.answer_entities, &predicted_answers);
344
345        Ok(QueryResult {
346            question: qa_pair.question.clone(),
347            expected_answers: qa_pair.answer_entities.clone(),
348            predicted_answers,
349            accuracy,
350            response_time,
351            complexity: qa_pair.complexity.clone(),
352            query_type: qa_pair.query_type.clone(),
353        })
354    }
355
356    /// Answer query using embeddings (simplified implementation)
357    async fn answer_query_with_embeddings(
358        &self,
359        qa_pair: &QuestionAnswerPair,
360        model: &dyn EmbeddingModel,
361    ) -> Result<Vec<String>> {
362        // Simplified query answering using embedding similarities
363        let entities = model.get_entities();
364        let mut candidates = Vec::new();
365
366        // Find entities most similar to question terms
367        let question_terms: Vec<&str> = qa_pair.question.split_whitespace().collect();
368
369        for entity in entities.iter().take(50) {
370            // Simple scoring based on name similarity
371            let mut score = 0.0;
372            for term in &question_terms {
373                if entity.to_lowercase().contains(&term.to_lowercase()) {
374                    score += 1.0;
375                }
376            }
377
378            if score > 0.0 {
379                candidates.push((entity.clone(), score));
380            }
381        }
382
383        // Sort by score and return top candidates
384        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
385        let top_answers: Vec<String> = candidates
386            .into_iter()
387            .take(5)
388            .map(|(entity, _)| entity)
389            .collect();
390
391        Ok(top_answers)
392    }
393
394    /// Calculate answer accuracy
395    fn calculate_answer_accuracy(&self, expected: &[String], predicted: &[String]) -> f64 {
396        if expected.is_empty() && predicted.is_empty() {
397            return 1.0;
398        }
399
400        if expected.is_empty() || predicted.is_empty() {
401            return 0.0;
402        }
403
404        let expected_set: HashSet<&String> = expected.iter().collect();
405        let predicted_set: HashSet<&String> = predicted.iter().collect();
406
407        let intersection = expected_set.intersection(&predicted_set).count();
408        let union = expected_set.union(&predicted_set).count();
409
410        if union == 0 {
411            0.0
412        } else {
413            intersection as f64 / union as f64
414        }
415    }
416
417    /// Calculate specific metric
418    fn calculate_metric(
419        &self,
420        metric: &QueryAnsweringMetric,
421        results: &[QueryResult],
422    ) -> Result<f64> {
423        if results.is_empty() {
424            return Ok(0.0);
425        }
426
427        match metric {
428            QueryAnsweringMetric::ExactMatch => {
429                let exact_matches = results.iter().filter(|r| r.accuracy >= 1.0).count() as f64;
430                Ok(exact_matches / results.len() as f64)
431            }
432            QueryAnsweringMetric::PartialMatch => {
433                Ok(results.iter().map(|r| r.accuracy).sum::<f64>() / results.len() as f64)
434            }
435            QueryAnsweringMetric::Completeness => {
436                let complete_answers = results
437                    .iter()
438                    .filter(|r| !r.predicted_answers.is_empty())
439                    .count() as f64;
440                Ok(complete_answers / results.len() as f64)
441            }
442            QueryAnsweringMetric::Precision => {
443                // Simplified precision calculation
444                Ok(0.75)
445            }
446            QueryAnsweringMetric::Recall => {
447                // Simplified recall calculation
448                Ok(0.73)
449            }
450            QueryAnsweringMetric::MRR => {
451                // Simplified MRR calculation
452                Ok(0.67)
453            }
454            QueryAnsweringMetric::HitsAtK(_k) => {
455                // Simplified Hits@K calculation
456                Ok(0.8)
457            }
458        }
459    }
460
461    /// Analyze reasoning capabilities
462    fn analyze_reasoning_capabilities(&self, results: &[QueryResult]) -> Result<ReasoningAnalysis> {
463        let multi_hop_results: Vec<_> = results
464            .iter()
465            .filter(|r| r.query_type == QueryType::MultiHopReasoning)
466            .collect();
467        let multi_hop_accuracy = if multi_hop_results.is_empty() {
468            0.0
469        } else {
470            multi_hop_results.iter().map(|r| r.accuracy).sum::<f64>()
471                / multi_hop_results.len() as f64
472        };
473
474        let temporal_results: Vec<_> = results
475            .iter()
476            .filter(|r| r.query_type == QueryType::TemporalReasoning)
477            .collect();
478        let temporal_accuracy = if temporal_results.is_empty() {
479            0.0
480        } else {
481            temporal_results.iter().map(|r| r.accuracy).sum::<f64>() / temporal_results.len() as f64
482        };
483
484        let logical_results: Vec<_> = results
485            .iter()
486            .filter(|r| {
487                matches!(
488                    r.query_type,
489                    QueryType::ComplexLogical | QueryType::NegationQuery
490                )
491            })
492            .collect();
493        let logical_accuracy = if logical_results.is_empty() {
494            0.0
495        } else {
496            logical_results.iter().map(|r| r.accuracy).sum::<f64>() / logical_results.len() as f64
497        };
498
499        let aggregation_results: Vec<_> = results
500            .iter()
501            .filter(|r| r.query_type == QueryType::AggregationQuery)
502            .collect();
503        let aggregation_accuracy = if aggregation_results.is_empty() {
504            0.0
505        } else {
506            aggregation_results.iter().map(|r| r.accuracy).sum::<f64>()
507                / aggregation_results.len() as f64
508        };
509
510        let overall_reasoning_score =
511            (multi_hop_accuracy + temporal_accuracy + logical_accuracy + aggregation_accuracy)
512                / 4.0;
513
514        Ok(ReasoningAnalysis {
515            multi_hop_accuracy,
516            temporal_accuracy,
517            logical_accuracy,
518            aggregation_accuracy,
519            overall_reasoning_score,
520        })
521    }
522
523    // Helper methods to create different types of QA pairs
524    fn create_fact_lookup_pair(&self, id: usize) -> QuestionAnswerPair {
525        QuestionAnswerPair {
526            question: format!("What is the type of entity{id}?"),
527            structured_query: Some(format!(
528                "SELECT ?type WHERE {{ entity{id} rdf:type ?type }}"
529            )),
530            answer_entities: vec![format!("Type{}", id % 5)],
531            answer_literals: vec![],
532            complexity: QueryComplexity::Simple,
533            query_type: QueryType::FactLookup,
534            domain: "general".to_string(),
535        }
536    }
537
538    fn create_relationship_pair(&self, id: usize) -> QuestionAnswerPair {
539        QuestionAnswerPair {
540            question: format!("Who is related to entity{id}?"),
541            structured_query: Some(format!(
542                "SELECT ?related WHERE {{ entity{id} ?relation ?related }}"
543            )),
544            answer_entities: vec![
545                format!("entity{}", (id + 1) % 10),
546                format!("entity{}", (id + 2) % 10),
547            ],
548            answer_literals: vec![],
549            complexity: QueryComplexity::Simple,
550            query_type: QueryType::RelationshipQuery,
551            domain: "general".to_string(),
552        }
553    }
554
555    fn create_aggregation_pair(&self, id: usize) -> QuestionAnswerPair {
556        QuestionAnswerPair {
557            question: format!("How many relations does entity{id} have?"),
558            structured_query: Some(format!(
559                "SELECT (COUNT(?relation) as ?count) WHERE {{ entity{id} ?relation ?object }}"
560            )),
561            answer_entities: vec![],
562            answer_literals: vec![format!("{}", (id % 5) + 1)],
563            complexity: QueryComplexity::Medium,
564            query_type: QueryType::AggregationQuery,
565            domain: "general".to_string(),
566        }
567    }
568
569    fn create_comparison_pair(&self, id: usize) -> QuestionAnswerPair {
570        QuestionAnswerPair {
571            question: format!("Is entity{} larger than entity{}?", id, id + 1),
572            structured_query: Some(format!(
573                "ASK {{ entity{} :size ?s1 . entity{} :size ?s2 . FILTER(?s1 > ?s2) }}",
574                id,
575                id + 1
576            )),
577            answer_entities: vec![],
578            answer_literals: vec![if id % 2 == 0 {
579                "true".to_string()
580            } else {
581                "false".to_string()
582            }],
583            complexity: QueryComplexity::Medium,
584            query_type: QueryType::ComparisonQuery,
585            domain: "general".to_string(),
586        }
587    }
588
589    fn create_multi_hop_pair(&self, id: usize) -> QuestionAnswerPair {
590        QuestionAnswerPair {
591            question: format!("What is connected to the parent of entity{id}?"),
592            structured_query: Some(format!("SELECT ?connected WHERE {{ entity{id} :parent ?parent . ?parent ?relation ?connected }}")),
593            answer_entities: vec![format!("entity{}", (id + 3) % 10)],
594            answer_literals: vec![],
595            complexity: QueryComplexity::Complex,
596            query_type: QueryType::MultiHopReasoning,
597            domain: "general".to_string(),
598        }
599    }
600
601    fn create_temporal_pair(&self, id: usize) -> QuestionAnswerPair {
602        QuestionAnswerPair {
603            question: format!("What happened to entity{id} before 2020?"),
604            structured_query: Some(format!("SELECT ?event WHERE {{ ?event :involves entity{id} . ?event :date ?date . FILTER(?date < '2020-01-01') }}")),
605            answer_entities: vec![format!("event{}", id % 3)],
606            answer_literals: vec![],
607            complexity: QueryComplexity::Complex,
608            query_type: QueryType::TemporalReasoning,
609            domain: "temporal".to_string(),
610        }
611    }
612
613    fn create_negation_pair(&self, id: usize) -> QuestionAnswerPair {
614        QuestionAnswerPair {
615            question: format!("What entities are not of type Type{}?", id % 3),
616            structured_query: Some(format!(
617                "SELECT ?entity WHERE {{ ?entity rdf:type ?type . FILTER(?type != Type{}) }}",
618                id % 3
619            )),
620            answer_entities: vec![
621                format!("entity{}", (id + 4) % 10),
622                format!("entity{}", (id + 5) % 10),
623            ],
624            answer_literals: vec![],
625            complexity: QueryComplexity::Complex,
626            query_type: QueryType::NegationQuery,
627            domain: "general".to_string(),
628        }
629    }
630
631    fn create_complex_logical_pair(&self, id: usize) -> QuestionAnswerPair {
632        QuestionAnswerPair {
633            question: format!(
634                "What entities are both Type{} and connected to entity{}?",
635                id % 2,
636                id
637            ),
638            structured_query: Some(format!(
639                "SELECT ?entity WHERE {{ ?entity rdf:type Type{} . entity{} ?relation ?entity }}",
640                id % 2,
641                id
642            )),
643            answer_entities: vec![format!("entity{}", (id + 6) % 10)],
644            answer_literals: vec![],
645            complexity: QueryComplexity::Expert,
646            query_type: QueryType::ComplexLogical,
647            domain: "general".to_string(),
648        }
649    }
650}
651
652impl Default for ApplicationQueryAnsweringEvaluator {
653    fn default() -> Self {
654        Self::new()
655    }
656}
657
658impl Clone for ApplicationQueryAnsweringEvaluator {
659    fn clone(&self) -> Self {
660        Self {
661            qa_pairs: self.qa_pairs.clone(),
662            query_types: self.query_types.clone(),
663            metrics: self.metrics.clone(),
664        }
665    }
666}