// rexis_rag/evaluation/metrics.rs

1//! # Common Evaluation Metrics
2//!
3//! Shared metrics and utilities used across different evaluation modules.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Common statistical metrics
pub struct StatisticalMetrics;

impl StatisticalMetrics {
    /// Calculate the arithmetic mean of a slice.
    ///
    /// Returns 0.0 for an empty slice rather than NaN.
    pub fn mean(values: &[f32]) -> f32 {
        if values.is_empty() {
            0.0
        } else {
            values.iter().sum::<f32>() / values.len() as f32
        }
    }

    /// Calculate the *sample* standard deviation (Bessel-corrected, n - 1).
    ///
    /// Returns 0.0 when fewer than two values are given, where the sample
    /// estimator is undefined.
    pub fn std_dev(values: &[f32]) -> f32 {
        if values.len() < 2 {
            return 0.0;
        }

        let mean = Self::mean(values);
        // Divide by (n - 1), not n: the `< 2` guard above and this file's
        // unit test (std_dev of [1..5] ≈ 1.5811) both expect the unbiased
        // sample estimator; dividing by n yields the population value 1.4142.
        let variance =
            values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / (values.len() - 1) as f32;

        variance.sqrt()
    }

    /// Calculate the median. Returns 0.0 for an empty slice.
    pub fn median(values: &[f32]) -> f32 {
        if values.is_empty() {
            return 0.0;
        }

        let mut sorted = values.to_vec();
        // NaN-tolerant comparator: incomparable pairs are treated as equal.
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let mid = sorted.len() / 2;
        if sorted.len() % 2 == 0 {
            (sorted[mid - 1] + sorted[mid]) / 2.0
        } else {
            sorted[mid]
        }
    }

    /// Calculate the `p`-th percentile (p in 0–100) using nearest-rank
    /// rounding. Returns 0.0 for an empty slice.
    pub fn percentile(values: &[f32], p: f32) -> f32 {
        if values.is_empty() {
            return 0.0;
        }

        let mut sorted = values.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let index = ((p / 100.0) * (sorted.len() - 1) as f32).round() as usize;
        // Clamp in case p > 100 rounds past the last element.
        sorted[index.min(sorted.len() - 1)]
    }

    /// Calculate the Pearson correlation coefficient of two equal-length
    /// slices. Returns 0.0 for mismatched lengths, fewer than two points,
    /// or zero variance in either input.
    pub fn correlation(x: &[f32], y: &[f32]) -> f32 {
        if x.len() != y.len() || x.len() < 2 {
            return 0.0;
        }

        let mean_x = Self::mean(x);
        let mean_y = Self::mean(y);

        let numerator: f32 = x
            .iter()
            .zip(y.iter())
            .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y))
            .sum();

        let sum_sq_x: f32 = x.iter().map(|&xi| (xi - mean_x).powi(2)).sum();
        let sum_sq_y: f32 = y.iter().map(|&yi| (yi - mean_y).powi(2)).sum();

        let denominator = (sum_sq_x * sum_sq_y).sqrt();

        if denominator == 0.0 {
            0.0
        } else {
            numerator / denominator
        }
    }
}
90
/// Text similarity metrics
pub struct TextSimilarityMetrics;

impl TextSimilarityMetrics {
    /// Calculate Jaccard similarity over whitespace-separated word sets.
    ///
    /// Returns |A ∩ B| / |A ∪ B|, or 0.0 when both texts are empty.
    pub fn jaccard_similarity(text1: &str, text2: &str) -> f32 {
        let words1: std::collections::HashSet<&str> = text1.split_whitespace().collect();
        let words2: std::collections::HashSet<&str> = text2.split_whitespace().collect();

        let intersection = words1.intersection(&words2).count();
        let union = words1.union(&words2).count();

        if union == 0 {
            0.0
        } else {
            intersection as f32 / union as f32
        }
    }

    /// Calculate cosine similarity of raw term-frequency vectors
    /// (simplified, word-based — no IDF weighting).
    ///
    /// Returns 0.0 when either text has no words.
    pub fn cosine_similarity(text1: &str, text2: &str) -> f32 {
        let words1: Vec<&str> = text1.split_whitespace().collect();
        let words2: Vec<&str> = text2.split_whitespace().collect();

        if words1.is_empty() || words2.is_empty() {
            return 0.0;
        }

        // Create term frequency maps
        let mut tf1: HashMap<&str, f32> = HashMap::new();
        let mut tf2: HashMap<&str, f32> = HashMap::new();

        for word in &words1 {
            *tf1.entry(word).or_insert(0.0) += 1.0;
        }

        for word in &words2 {
            *tf2.entry(word).or_insert(0.0) += 1.0;
        }

        // Dot product over the shared vocabulary only.
        let mut dot_product = 0.0;
        for (word, freq1) in &tf1 {
            if let Some(freq2) = tf2.get(word) {
                dot_product += freq1 * freq2;
            }
        }

        // Euclidean magnitudes of each frequency vector.
        let magnitude1: f32 = tf1.values().map(|f| f * f).sum::<f32>().sqrt();
        let magnitude2: f32 = tf2.values().map(|f| f * f).sum::<f32>().sqrt();

        if magnitude1 == 0.0 || magnitude2 == 0.0 {
            0.0
        } else {
            dot_product / (magnitude1 * magnitude2)
        }
    }

    /// Calculate BLEU score (simplified): the fraction of candidate n-grams
    /// that appear anywhere in the reference.
    ///
    /// NOTE(review): unlike full BLEU this applies no count clipping and no
    /// brevity penalty — repeated candidate n-grams each count if the n-gram
    /// occurs at least once in the reference.
    ///
    /// Returns 0.0 when either text has fewer than `n` words.
    pub fn bleu_score(candidate: &str, reference: &str, n: usize) -> f32 {
        let candidate_words: Vec<&str> = candidate.split_whitespace().collect();
        let reference_words: Vec<&str> = reference.split_whitespace().collect();

        if candidate_words.len() < n || reference_words.len() < n {
            return 0.0;
        }

        // Generate n-grams for both sides.
        let candidate_ngrams: Vec<Vec<&str>> = (0..=candidate_words.len() - n)
            .map(|i| candidate_words[i..i + n].to_vec())
            .collect();

        let reference_ngrams: Vec<Vec<&str>> = (0..=reference_words.len() - n)
            .map(|i| reference_words[i..i + n].to_vec())
            .collect();

        // Count candidate n-grams present in the reference (unclipped).
        let mut matches = 0;
        for candidate_ngram in &candidate_ngrams {
            if reference_ngrams.contains(candidate_ngram) {
                matches += 1;
            }
        }

        if candidate_ngrams.is_empty() {
            0.0
        } else {
            matches as f32 / candidate_ngrams.len() as f32
        }
    }

    /// Calculate ROUGE-L (simplified): the LCS-based F1 of candidate and
    /// reference word sequences.
    ///
    /// Two empty texts score 1.0; exactly one empty text scores 0.0.
    pub fn rouge_l_score(candidate: &str, reference: &str) -> f32 {
        let candidate_words: Vec<&str> = candidate.split_whitespace().collect();
        let reference_words: Vec<&str> = reference.split_whitespace().collect();

        // Find longest common subsequence
        let lcs_length = Self::lcs_length(&candidate_words, &reference_words);

        if candidate_words.is_empty() && reference_words.is_empty() {
            1.0
        } else if candidate_words.is_empty() || reference_words.is_empty() {
            0.0
        } else {
            let recall = lcs_length as f32 / reference_words.len() as f32;
            let precision = lcs_length as f32 / candidate_words.len() as f32;

            if recall + precision == 0.0 {
                0.0
            } else {
                // Harmonic mean of LCS recall and precision.
                2.0 * recall * precision / (recall + precision)
            }
        }
    }

    /// Calculate longest common subsequence length via the classic
    /// O(m·n) dynamic program.
    fn lcs_length(x: &[&str], y: &[&str]) -> usize {
        let m = x.len();
        let n = y.len();

        if m == 0 || n == 0 {
            return 0;
        }

        // dp[i][j] = LCS length of x[..i] and y[..j].
        let mut dp = vec![vec![0; n + 1]; m + 1];

        for i in 1..=m {
            for j in 1..=n {
                if x[i - 1] == y[j - 1] {
                    dp[i][j] = dp[i - 1][j - 1] + 1;
                } else {
                    dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
                }
            }
        }

        dp[m][n]
    }

    /// Calculate edit distance (Levenshtein) over Unicode scalar values
    /// (`char`s), not bytes.
    pub fn edit_distance(s1: &str, s2: &str) -> usize {
        let chars1: Vec<char> = s1.chars().collect();
        let chars2: Vec<char> = s2.chars().collect();

        let m = chars1.len();
        let n = chars2.len();

        if m == 0 {
            return n;
        }
        if n == 0 {
            return m;
        }

        // dp[i][j] = distance between chars1[..i] and chars2[..j].
        let mut dp = vec![vec![0; n + 1]; m + 1];

        // Base cases: transforming to/from the empty string.
        for i in 0..=m {
            dp[i][0] = i;
        }
        for j in 0..=n {
            dp[0][j] = j;
        }

        // Fill the DP table
        for i in 1..=m {
            for j in 1..=n {
                let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };

                dp[i][j] = (dp[i - 1][j] + 1) // deletion
                    .min(dp[i][j - 1] + 1) // insertion
                    .min(dp[i - 1][j - 1] + cost); // substitution
            }
        }

        dp[m][n]
    }

    /// Calculate edit distance normalized to [0, 1] by the longer input's
    /// length. Returns 0.0 when both strings are empty.
    pub fn normalized_edit_distance(s1: &str, s2: &str) -> f32 {
        let distance = Self::edit_distance(s1, s2);
        // Normalize by *character* count to match `edit_distance`, which
        // operates on chars. The previous byte-length normalization
        // (`str::len`) under-reported the distance for multi-byte strings
        // (e.g. "héllo" vs "hello" gave 1/6 instead of 1/5).
        let max_len = s1.chars().count().max(s2.chars().count());

        if max_len == 0 {
            0.0
        } else {
            distance as f32 / max_len as f32
        }
    }
}
282
/// Information Retrieval metrics
pub struct IRMetrics;

impl IRMetrics {
    /// Calculate precision at K: the fraction of the top-K retrieved
    /// documents that are relevant.
    ///
    /// Returns 0.0 when `k == 0` or when no documents were retrieved
    /// (previously the empty-result case produced NaN via 0/0).
    pub fn precision_at_k(relevant_docs: &[bool], k: usize) -> f32 {
        if k == 0 || relevant_docs.is_empty() {
            return 0.0;
        }

        let top_k = &relevant_docs[..k.min(relevant_docs.len())];
        let relevant_count = top_k.iter().filter(|&&r| r).count();

        relevant_count as f32 / top_k.len() as f32
    }

    /// Calculate recall at K: the fraction of all relevant documents that
    /// appear in the top-K results.
    ///
    /// Returns 1.0 when `total_relevant == 0` (nothing to miss).
    pub fn recall_at_k(relevant_docs: &[bool], k: usize, total_relevant: usize) -> f32 {
        if total_relevant == 0 {
            return 1.0;
        }

        let top_k = &relevant_docs[..k.min(relevant_docs.len())];
        let retrieved_relevant = top_k.iter().filter(|&&r| r).count();

        retrieved_relevant as f32 / total_relevant as f32
    }

    /// Calculate F1 score at K: harmonic mean of precision@K and recall@K.
    /// Returns 0.0 when both are zero.
    pub fn f1_at_k(relevant_docs: &[bool], k: usize, total_relevant: usize) -> f32 {
        let precision = Self::precision_at_k(relevant_docs, k);
        let recall = Self::recall_at_k(relevant_docs, k, total_relevant);

        if precision + recall == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        }
    }

    /// Calculate Average Precision: mean of precision@i over every rank i
    /// at which a relevant document occurs.
    ///
    /// Returns 0.0 when the ranking contains no relevant documents.
    pub fn average_precision(relevant_docs: &[bool]) -> f32 {
        let total_relevant = relevant_docs.iter().filter(|&&r| r).count();

        if total_relevant == 0 {
            return 0.0;
        }

        let mut sum_precision = 0.0;
        let mut relevant_count = 0;

        for (i, &is_relevant) in relevant_docs.iter().enumerate() {
            if is_relevant {
                relevant_count += 1;
                // Precision at this (1-based) rank.
                sum_precision += relevant_count as f32 / (i + 1) as f32;
            }
        }

        sum_precision / total_relevant as f32
    }

    /// Calculate Reciprocal Rank: 1 / rank of the first relevant document
    /// (1-based), or 0.0 when none is relevant.
    pub fn reciprocal_rank(relevant_docs: &[bool]) -> f32 {
        for (i, &is_relevant) in relevant_docs.iter().enumerate() {
            if is_relevant {
                return 1.0 / (i + 1) as f32;
            }
        }
        0.0
    }

    /// Calculate NDCG at K with linear gain (score / log2(rank + 1)),
    /// normalized by the DCG of the ideal (descending-score) ordering.
    ///
    /// Returns 0.0 when `k == 0`, the input is empty, or the ideal DCG is 0.
    pub fn ndcg_at_k(relevance_scores: &[f32], k: usize) -> f32 {
        if k == 0 || relevance_scores.is_empty() {
            return 0.0;
        }

        let k = k.min(relevance_scores.len());

        // Calculate DCG of the given ordering.
        let dcg = Self::dcg(&relevance_scores[..k]);

        // Calculate IDCG (DCG of the best possible ordering).
        let mut ideal_scores = relevance_scores.to_vec();
        ideal_scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        let idcg = Self::dcg(&ideal_scores[..k]);

        if idcg == 0.0 {
            0.0
        } else {
            dcg / idcg
        }
    }

    /// Calculate Discounted Cumulative Gain: sum of score_i / log2(i + 2)
    /// over 0-based positions i (so position 0 has discount log2(2) = 1).
    fn dcg(relevance_scores: &[f32]) -> f32 {
        relevance_scores
            .iter()
            .enumerate()
            .map(|(i, &score)| score / (i as f32 + 2.0).log2())
            .sum()
    }
}
386
387/// Quality assessment metrics
388pub struct QualityMetrics;
389
390impl QualityMetrics {
391    /// Calculate perplexity (simplified)
392    pub fn perplexity(text: &str) -> f32 {
393        let words: Vec<&str> = text.split_whitespace().collect();
394        if words.is_empty() {
395            return f32::INFINITY;
396        }
397
398        // Simplified perplexity calculation based on word frequency
399        let mut word_counts: HashMap<&str, usize> = HashMap::new();
400        for word in &words {
401            *word_counts.entry(word).or_insert(0) += 1;
402        }
403
404        let mut log_prob_sum = 0.0;
405        let vocab_size = word_counts.len() as f32;
406
407        for count in word_counts.values() {
408            let prob = *count as f32 / words.len() as f32;
409            log_prob_sum += prob * prob.ln();
410        }
411
412        // Add smoothing
413        let avg_log_prob = log_prob_sum / vocab_size;
414        (-avg_log_prob).exp()
415    }
416
417    /// Calculate readability score (simplified Flesch)
418    pub fn readability_score(text: &str) -> f32 {
419        let sentences: Vec<&str> = text.split(&['.', '!', '?'][..]).collect();
420        let words: Vec<&str> = text.split_whitespace().collect();
421        let syllables = Self::count_syllables(text);
422
423        if sentences.is_empty() || words.is_empty() {
424            return 0.0;
425        }
426
427        let avg_sentence_length = words.len() as f32 / sentences.len() as f32;
428        let avg_syllables_per_word = syllables as f32 / words.len() as f32;
429
430        // Simplified Flesch Reading Ease
431        let score = 206.835 - 1.015 * avg_sentence_length - 84.6 * avg_syllables_per_word;
432        score.max(0.0).min(100.0)
433    }
434
435    /// Count syllables in text (simplified)
436    fn count_syllables(text: &str) -> usize {
437        let vowels = ['a', 'e', 'i', 'o', 'u', 'y'];
438        let mut syllable_count = 0;
439
440        for word in text.split_whitespace() {
441            let mut word_syllables = 0;
442            let mut previous_was_vowel = false;
443
444            for ch in word.to_lowercase().chars() {
445                if vowels.contains(&ch) {
446                    if !previous_was_vowel {
447                        word_syllables += 1;
448                    }
449                    previous_was_vowel = true;
450                } else {
451                    previous_was_vowel = false;
452                }
453            }
454
455            // Every word has at least one syllable
456            if word_syllables == 0 {
457                word_syllables = 1;
458            }
459
460            syllable_count += word_syllables;
461        }
462
463        syllable_count
464    }
465
466    /// Calculate lexical diversity (Type-Token Ratio)
467    pub fn lexical_diversity(text: &str) -> f32 {
468        let words: Vec<&str> = text.split_whitespace().collect();
469        if words.is_empty() {
470            return 0.0;
471        }
472
473        let unique_words: std::collections::HashSet<&str> = words.iter().cloned().collect();
474        unique_words.len() as f32 / words.len() as f32
475    }
476
477    /// Calculate semantic coherence (simplified)
478    pub fn semantic_coherence(sentences: &[&str]) -> f32 {
479        if sentences.len() < 2 {
480            return 1.0;
481        }
482
483        let mut coherence_scores = Vec::new();
484
485        for i in 0..sentences.len() - 1 {
486            let similarity =
487                TextSimilarityMetrics::jaccard_similarity(sentences[i], sentences[i + 1]);
488            coherence_scores.push(similarity);
489        }
490
491        StatisticalMetrics::mean(&coherence_scores)
492    }
493}
494
/// Specialized evaluation metrics for different domains
pub struct DomainMetrics;

impl DomainMetrics {
    /// Calculate factual accuracy (simplified).
    ///
    /// A reference fact counts as "supported" when more than 70% of its
    /// whitespace-separated words occur (case-insensitively, as substrings)
    /// in the generated text. Returns the supported fraction, or 1.0 when
    /// there are no reference facts to check.
    pub fn factual_accuracy(generated_text: &str, reference_facts: &[&str]) -> f32 {
        if reference_facts.is_empty() {
            return 1.0; // No facts to check against
        }

        let haystack = generated_text.to_lowercase();

        // Very simplified fact checking - look for key terms of each fact.
        let supported = reference_facts
            .iter()
            .filter(|fact| {
                let terms: Vec<&str> = fact.split_whitespace().collect();
                let hits = terms
                    .iter()
                    .filter(|term| haystack.contains(&term.to_lowercase()))
                    .count();

                // Consider the fact supported if most of its words appear.
                hits as f32 / terms.len() as f32 > 0.7
            })
            .count();

        supported as f32 / reference_facts.len() as f32
    }

    /// Calculate bias score (simplified): fraction of words containing an
    /// absolutist or overgeneralizing term.
    pub fn bias_score(text: &str) -> f32 {
        // Simplified bias detection based on certain terms
        let biased_terms = [
            "always",
            "never",
            "all",
            "none",
            "everyone",
            "nobody",
            "obviously",
            "clearly",
            "definitely",
            "certainly",
        ];

        Self::flagged_word_ratio(text, &biased_terms)
    }

    /// Calculate toxicity score (simplified): fraction of words containing
    /// a flagged pattern.
    pub fn toxicity_score(text: &str) -> f32 {
        // Very simplified toxicity detection
        let toxic_patterns = ["hate", "stupid", "idiot", "kill", "die"];

        Self::flagged_word_ratio(text, &toxic_patterns)
    }

    /// Shared helper: fraction of lowercased, whitespace-separated words
    /// that contain any of the given substrings; 0.0 for empty text.
    fn flagged_word_ratio(text: &str, patterns: &[&str]) -> f32 {
        let lowered = text.to_lowercase();
        let words: Vec<&str> = lowered.split_whitespace().collect();

        if words.is_empty() {
            return 0.0;
        }

        let flagged = words
            .iter()
            .filter(|word| patterns.iter().any(|pattern| word.contains(pattern)))
            .count();

        flagged as f32 / words.len() as f32
    }
}
580
/// Metric aggregation utilities
///
/// Collects repeated observations of named metrics so summary statistics
/// (mean, percentiles, confidence intervals) can be derived later.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricAggregator {
    /// Collected metrics, keyed by metric name; each entry holds every
    /// observed value in insertion order.
    pub metrics: HashMap<String, Vec<f32>>,
}
587
588impl MetricAggregator {
589    /// Create new metric aggregator
590    pub fn new() -> Self {
591        Self {
592            metrics: HashMap::new(),
593        }
594    }
595
596    /// Add a metric value
597    pub fn add_metric(&mut self, name: &str, value: f32) {
598        self.metrics
599            .entry(name.to_string())
600            .or_insert_with(Vec::new)
601            .push(value);
602    }
603
604    /// Get summary statistics for all metrics
605    pub fn get_summary(&self) -> HashMap<String, MetricSummary> {
606        let mut summaries = HashMap::new();
607
608        for (name, values) in &self.metrics {
609            let summary = MetricSummary {
610                count: values.len(),
611                mean: StatisticalMetrics::mean(values),
612                std_dev: StatisticalMetrics::std_dev(values),
613                median: StatisticalMetrics::median(values),
614                min: values.iter().fold(f32::INFINITY, |a, &b| a.min(b)),
615                max: values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)),
616                percentile_25: StatisticalMetrics::percentile(values, 25.0),
617                percentile_75: StatisticalMetrics::percentile(values, 75.0),
618                percentile_95: StatisticalMetrics::percentile(values, 95.0),
619            };
620
621            summaries.insert(name.clone(), summary);
622        }
623
624        summaries
625    }
626
627    /// Calculate confidence interval
628    pub fn confidence_interval(
629        &self,
630        metric_name: &str,
631        confidence_level: f32,
632    ) -> Option<(f32, f32)> {
633        if let Some(values) = self.metrics.get(metric_name) {
634            if values.len() < 2 {
635                return None;
636            }
637
638            let mean = StatisticalMetrics::mean(values);
639            let std_dev = StatisticalMetrics::std_dev(values);
640            let n = values.len() as f32;
641
642            // Simplified confidence interval (assuming normal distribution)
643            let z_score = match confidence_level {
644                0.90 => 1.645,
645                0.95 => 1.96,
646                0.99 => 2.576,
647                _ => 1.96, // Default to 95%
648            };
649
650            let margin_of_error = z_score * std_dev / n.sqrt();
651            Some((mean - margin_of_error, mean + margin_of_error))
652        } else {
653            None
654        }
655    }
656}
657
impl Default for MetricAggregator {
    /// Equivalent to [`MetricAggregator::new`]: an aggregator with no
    /// collected metrics.
    fn default() -> Self {
        Self::new()
    }
}
663
/// Summary statistics for a metric
///
/// Produced by [`MetricAggregator::get_summary`] — one instance per
/// metric series. Note: for an empty series `min`/`max` are ±INFINITY
/// (the fold identities used when summarizing).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricSummary {
    /// Number of observations
    pub count: usize,
    /// Mean value
    pub mean: f32,
    /// Standard deviation
    pub std_dev: f32,
    /// Median value
    pub median: f32,
    /// Minimum value
    pub min: f32,
    /// Maximum value
    pub max: f32,
    /// 25th percentile
    pub percentile_25: f32,
    /// 75th percentile
    pub percentile_75: f32,
    /// 95th percentile
    pub percentile_95: f32,
}
686
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_statistical_metrics() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_eq!(StatisticalMetrics::mean(&values), 3.0);
        assert_eq!(StatisticalMetrics::median(&values), 3.0);
        // 1.5811 is the *sample* standard deviation (divide by n - 1);
        // the population value for this data would be ~1.4142.
        assert!((StatisticalMetrics::std_dev(&values) - 1.5811).abs() < 0.01);
    }

    #[test]
    fn test_text_similarity() {
        // Texts share 3 of 5 distinct words ("the", "quick", "brown").
        let text1 = "the quick brown fox";
        let text2 = "the quick brown dog";

        let jaccard = TextSimilarityMetrics::jaccard_similarity(text1, text2);
        assert!(jaccard > 0.5); // Should be similar

        let cosine = TextSimilarityMetrics::cosine_similarity(text1, text2);
        assert!(cosine > 0.5); // Should be similar
    }

    #[test]
    fn test_ir_metrics() {
        // Relevant documents at ranks 1, 3 and 5.
        let relevant_docs = vec![true, false, true, false, true];

        assert_eq!(IRMetrics::precision_at_k(&relevant_docs, 3), 2.0 / 3.0);
        assert_eq!(IRMetrics::recall_at_k(&relevant_docs, 3, 3), 2.0 / 3.0);
        // First relevant document is at rank 1, so RR = 1/1.
        assert_eq!(IRMetrics::reciprocal_rank(&relevant_docs), 1.0);
    }

    #[test]
    fn test_metric_aggregator() {
        let mut aggregator = MetricAggregator::new();

        aggregator.add_metric("precision", 0.8);
        aggregator.add_metric("precision", 0.9);
        aggregator.add_metric("recall", 0.7);

        let summary = aggregator.get_summary();

        // Each metric keeps its own series; means are per metric.
        assert_eq!(summary["precision"].count, 2);
        assert_eq!(summary["precision"].mean, 0.85);
        assert_eq!(summary["recall"].count, 1);
    }

    #[test]
    fn test_quality_metrics() {
        let text = "This is a simple test sentence.";

        let diversity = QualityMetrics::lexical_diversity(text);
        assert!(diversity > 0.8); // Most words are unique

        let readability = QualityMetrics::readability_score(text);
        assert!(readability > 0.0); // Should have some readability score
    }
}