alimentar/quality/
decontaminate.rs1use std::collections::HashSet;
20
/// Outcome of comparing one training sample against every reference sample.
///
/// [`check_contamination`] only records a result for samples whose best
/// overlap exceeded the configured threshold.
#[derive(Debug, Clone, PartialEq)]
pub struct ContaminationResult {
    /// Position of the sample within the `training_samples` slice.
    pub sample_index: usize,
    /// Highest overlap ratio found against any single reference, computed as
    /// shared n-grams divided by the sample's distinct n-grams (0.0 to 1.0).
    pub max_overlap: f64,
    /// Position, within the `reference_samples` slice, of the reference that
    /// produced `max_overlap`.
    pub matched_reference: usize,
    /// Whether the sample is considered contaminated; [`check_contamination`]
    /// always sets this to `true` for the results it records.
    pub contaminated: bool,
}

/// Summary returned by [`check_contamination`] for a whole training set.
#[derive(Debug, Clone, PartialEq)]
pub struct DecontaminationReport {
    /// N-gram length (in characters) that was used for the comparison.
    pub ngram_size: usize,
    /// Overlap ratio a sample had to exceed in order to be flagged.
    pub threshold: f64,
    /// Number of training samples passed in, including samples that were too
    /// short to produce any n-gram.
    pub total_samples: usize,
    /// Number of flagged samples; equal to `flagged.len()`.
    pub contaminated_count: usize,
    /// `contaminated_count / total_samples`, or `0.0` when there were no
    /// training samples.
    pub contamination_rate: f64,
    /// One entry per flagged sample, in training-set order.
    pub flagged: Vec<ContaminationResult>,
}
50
/// Collects the distinct character n-grams of `text`.
///
/// The text is normalized first: whitespace characters are dropped and every
/// remaining character is lowercased (a single character may lowercase to
/// several). Each run of `n` consecutive normalized characters becomes one
/// n-gram; duplicates collapse because the result is a set.
///
/// Returns an empty set when `n` is zero or when fewer than `n` characters
/// remain after normalization.
fn extract_ngrams(text: &str, n: usize) -> HashSet<Vec<char>> {
    // `slice::windows` panics on a zero window size, so reject it up front.
    if n == 0 {
        return HashSet::new();
    }

    let normalized: Vec<char> = text
        .chars()
        .filter(|c| !c.is_whitespace())
        .flat_map(char::to_lowercase)
        .collect();

    // `windows` yields nothing when the slice is shorter than `n`, so short
    // inputs naturally produce an empty set.
    normalized.windows(n).map(|window| window.to_vec()).collect()
}
65
66pub fn ngram_overlap(candidate: &str, reference: &str, n: usize) -> f64 {
71 let cand_ngrams = extract_ngrams(candidate, n);
72 if cand_ngrams.is_empty() {
73 return 0.0;
74 }
75
76 let ref_ngrams = extract_ngrams(reference, n);
77 let intersection = cand_ngrams.intersection(&ref_ngrams).count();
78
79 intersection as f64 / cand_ngrams.len() as f64
80}
81
82pub fn check_contamination(
95 training_samples: &[&str],
96 reference_samples: &[&str],
97 ngram_size: usize,
98 threshold: f64,
99) -> DecontaminationReport {
100 let ref_ngram_sets: Vec<HashSet<Vec<char>>> = reference_samples
102 .iter()
103 .map(|s| extract_ngrams(s, ngram_size))
104 .collect();
105
106 let mut flagged = Vec::new();
107
108 for (i, sample) in training_samples.iter().enumerate() {
109 let cand_ngrams = extract_ngrams(sample, ngram_size);
110 if cand_ngrams.is_empty() {
111 continue;
112 }
113
114 let mut max_overlap = 0.0_f64;
115 let mut matched_ref = 0;
116
117 for (j, ref_set) in ref_ngram_sets.iter().enumerate() {
118 let intersection = cand_ngrams.intersection(ref_set).count();
119 let overlap = intersection as f64 / cand_ngrams.len() as f64;
120
121 if overlap > max_overlap {
122 max_overlap = overlap;
123 matched_ref = j;
124 }
125 }
126
127 if max_overlap > threshold {
128 flagged.push(ContaminationResult {
129 sample_index: i,
130 max_overlap,
131 matched_reference: matched_ref,
132 contaminated: true,
133 });
134 }
135 }
136
137 let contaminated_count = flagged.len();
138 let total = training_samples.len();
139 let rate = if total > 0 {
140 contaminated_count as f64 / total as f64
141 } else {
142 0.0
143 };
144
145 DecontaminationReport {
146 ngram_size,
147 threshold,
148 total_samples: total,
149 contaminated_count,
150 contamination_rate: rate,
151 flagged,
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Recursive Fibonacci one-liner used as the benchmark reference text.
    const FIB_REFERENCE: &str =
        "def fibonacci(n): return n if n < 2 else fibonacci(n-1) + fibonacci(n-2)";

    /// True when `value` is within machine epsilon of a full (1.0) overlap.
    fn is_full_overlap(value: f64) -> bool {
        (value - 1.0).abs() < f64::EPSILON
    }

    #[test]
    fn test_extract_ngrams() {
        // Whitespace is dropped, leaving "helloworld": 10 chars, 8 trigrams.
        let trigram_count = extract_ngrams("hello world", 3).len();
        assert_eq!(trigram_count, 8);
    }

    #[test]
    fn test_extract_ngrams_short_text() {
        // Text shorter than the n-gram size yields nothing.
        assert!(extract_ngrams("hi", 10).is_empty());
    }

    #[test]
    fn test_ngram_overlap_identical() {
        let text = "def fibonacci(n):";
        assert!(is_full_overlap(ngram_overlap(text, text, 5)));
    }

    #[test]
    fn test_ngram_overlap_no_match() {
        let candidate = "completely different text about cooking";
        let reference = "def fibonacci(n): return n if n < 2";
        assert!(ngram_overlap(candidate, reference, 10) < 0.1);
    }

    #[test]
    fn test_ngram_overlap_partial() {
        let candidate = "def fibonacci(n): return n if n < 2 else fibonacci(n-1)";
        let reference = "def fibonacci(n): return fib(n-1) + fib(n-2)";
        let score = ngram_overlap(candidate, reference, 5);

        // Shared prefix gives some overlap, diverging bodies keep it partial.
        assert!(score > 0.0);
        assert!(score < 1.0);
    }

    #[test]
    fn test_check_contamination_clean() {
        let training = [
            "def sort_list(lst): return sorted(lst)",
            "def reverse_string(s): return s[::-1]",
        ];
        let reference = [FIB_REFERENCE];

        let outcome = check_contamination(&training, &reference, 10, 0.5);
        assert_eq!(outcome.contaminated_count, 0);
        assert!(outcome.contamination_rate < 0.01);
    }

    #[test]
    fn test_check_contamination_flagged() {
        // The second training sample is an exact copy of the reference.
        let training = ["def sort_list(lst): return sorted(lst)", FIB_REFERENCE];
        let reference = [FIB_REFERENCE];

        let outcome = check_contamination(&training, &reference, 10, 0.5);
        assert_eq!(outcome.contaminated_count, 1);

        let hit = &outcome.flagged[0];
        assert_eq!(hit.sample_index, 1);
        assert!(is_full_overlap(hit.max_overlap));
    }

    #[test]
    fn test_check_contamination_threshold() {
        let training = [FIB_REFERENCE];
        let reference = ["def fibonacci(n): return n if n < 2 else fib(n-1) + fib(n-2)"];

        // Short n-grams with a low threshold versus long n-grams with a high one.
        let strict_count = check_contamination(&training, &reference, 5, 0.3).contaminated_count;
        let lenient_count = check_contamination(&training, &reference, 10, 0.9).contaminated_count;

        assert!(strict_count >= lenient_count);
    }

    #[test]
    fn test_empty_inputs() {
        let nothing: [&str; 0] = [];

        // No training samples at all.
        let without_training = check_contamination(&nothing, &["some reference"], 10, 0.5);
        assert_eq!(without_training.total_samples, 0);
        assert_eq!(without_training.contaminated_count, 0);

        // No reference samples to compare against.
        let without_reference = check_contamination(&["some training"], &nothing, 10, 0.5);
        assert_eq!(without_reference.contaminated_count, 0);
    }
}