entrenar/eval/generative/
text_gen.rs1use std::collections::HashMap;
7
8pub fn bleu_score(references: &[&str], hypothesis: &str, max_n: usize) -> f64 {
18 if references.is_empty() || hypothesis.is_empty() {
19 return 0.0;
20 }
21
22 let hyp_tokens: Vec<&str> = hypothesis.split_whitespace().collect();
23 if hyp_tokens.is_empty() {
24 return 0.0;
25 }
26
27 let ref_token_lists: Vec<Vec<&str>> =
28 references.iter().map(|r| r.split_whitespace().collect()).collect();
29
30 let mut log_precisions = Vec::new();
32 for n in 1..=max_n {
33 let (clipped, total) = modified_precision(&ref_token_lists, &hyp_tokens, n);
34 if total == 0 {
35 return 0.0;
36 }
37 let precision = clipped as f64 / total as f64;
38 if precision == 0.0 {
39 return 0.0;
40 }
41 log_precisions.push(precision.max(f64::MIN_POSITIVE).ln());
42 }
43
44 let avg_log_precision: f64 =
46 log_precisions.iter().sum::<f64>() / log_precisions.len().max(1) as f64;
47
48 let hyp_len = hyp_tokens.len();
50 let closest_ref_len = ref_token_lists
51 .iter()
52 .map(Vec::len)
53 .min_by_key(|&len| (len as isize - hyp_len as isize).unsigned_abs())
54 .unwrap_or(0);
55
56 let bp = if hyp_len >= closest_ref_len {
57 1.0
58 } else if closest_ref_len == 0 {
59 0.0
60 } else {
61 (1.0 - closest_ref_len as f64 / hyp_len as f64).exp()
62 };
63
64 bp * avg_log_precision.exp()
65}
66
67fn modified_precision(references: &[Vec<&str>], hypothesis: &[&str], n: usize) -> (usize, usize) {
69 let hyp_ngrams = extract_ngrams(hypothesis, n);
70 let total: usize = hyp_ngrams.values().sum();
71
72 let mut clipped = 0usize;
73 for (ngram, &hyp_count) in &hyp_ngrams {
74 let max_ref_count = references
75 .iter()
76 .map(|r| {
77 let ref_ngrams = extract_ngrams(r, n);
78 ref_ngrams.get(ngram).copied().unwrap_or(0)
79 })
80 .max()
81 .unwrap_or(0);
82 clipped += hyp_count.min(max_ref_count);
83 }
84
85 (clipped, total)
86}
87
/// Counts every contiguous `n`-token window in `tokens`.
///
/// The returned map's keys are n-grams (owned vectors of borrowed tokens) and its
/// values are occurrence counts. The map is empty when `n` is 0 or `n` exceeds
/// `tokens.len()`.
fn extract_ngrams<'a>(tokens: &[&'a str], n: usize) -> HashMap<Vec<&'a str>, usize> {
    let mut counts = HashMap::new();
    // `slice::windows` panics on a window size of 0, so treat that as "no n-grams".
    // A window larger than the slice simply yields nothing, so no other guard is needed.
    if n == 0 {
        return counts;
    }
    for window in tokens.windows(n) {
        *counts.entry(window.to_vec()).or_default() += 1;
    }
    counts
}
98
99pub fn rouge_n(reference: &str, hypothesis: &str, n: usize) -> f64 {
103 let ref_tokens: Vec<&str> = reference.split_whitespace().collect();
104 let hyp_tokens: Vec<&str> = hypothesis.split_whitespace().collect();
105
106 if ref_tokens.len() < n || hyp_tokens.len() < n {
107 return 0.0;
108 }
109
110 let ref_ngrams = extract_ngrams(&ref_tokens, n);
111 let hyp_ngrams = extract_ngrams(&hyp_tokens, n);
112
113 let mut overlap = 0usize;
114 for (ngram, &hyp_count) in &hyp_ngrams {
115 let ref_count = ref_ngrams.get(ngram).copied().unwrap_or(0);
116 overlap += hyp_count.min(ref_count);
117 }
118
119 let ref_total: usize = ref_ngrams.values().sum();
120 let hyp_total: usize = hyp_ngrams.values().sum();
121
122 if ref_total == 0 || hyp_total == 0 {
123 return 0.0;
124 }
125
126 let precision = overlap as f64 / hyp_total as f64;
127 let recall = overlap as f64 / ref_total as f64;
128
129 if precision + recall == 0.0 {
130 return 0.0;
131 }
132
133 2.0 * precision * recall / (precision + recall)
134}
135
136pub fn rouge_l(reference: &str, hypothesis: &str) -> f64 {
140 let ref_tokens: Vec<&str> = reference.split_whitespace().collect();
141 let hyp_tokens: Vec<&str> = hypothesis.split_whitespace().collect();
142
143 if ref_tokens.is_empty() || hyp_tokens.is_empty() {
144 return 0.0;
145 }
146
147 let lcs_len = lcs_length(&ref_tokens, &hyp_tokens);
148
149 let precision = lcs_len as f64 / hyp_tokens.len() as f64;
150 let recall = lcs_len as f64 / ref_tokens.len() as f64;
151
152 if precision + recall == 0.0 {
153 return 0.0;
154 }
155
156 2.0 * precision * recall / (precision + recall)
157}
158
/// Returns the length of the longest common subsequence of token slices `a` and `b`.
///
/// The classic dynamic program is used, but only two rows are kept, each sized to the
/// shorter input. Time is O(|a|·|b|); memory is O(min(|a|, |b|)) rather than a full
/// table.
fn lcs_length(a: &[&str], b: &[&str]) -> usize {
    // LCS length is symmetric, so iterate the longer sequence in the outer loop and
    // size the DP rows to the shorter one.
    let (long, short) = if a.len() >= b.len() { (a, b) } else { (b, a) };

    // prev[j] / curr[j] hold the LCS length of the processed prefix of `long` against
    // short[..j]. Index 0 is the empty prefix and stays 0 in both rows.
    let mut prev = vec![0usize; short.len() + 1];
    let mut curr = vec![0usize; short.len() + 1];

    for x in long {
        for (j, y) in short.iter().enumerate() {
            curr[j + 1] = if x == y {
                prev[j] + 1
            } else {
                prev[j + 1].max(curr[j])
            };
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap, the last computed row is in `prev`.
    prev[short.len()]
}
177
/// Computes perplexity from per-token log-probabilities.
///
/// The result is `exp(-mean(log_probs))`. Because `exp` is used to invert the mean,
/// the inputs are expected to be natural-log (base e) values. An empty slice yields
/// `f64::INFINITY`.
pub fn perplexity(log_probs: &[f64]) -> f64 {
    match log_probs.len() {
        0 => f64::INFINITY,
        count => {
            let total: f64 = log_probs.iter().sum();
            (-total / count as f64).exp()
        }
    }
}