Skip to main content

alimentar/quality/
decontaminate.rs

1//! N-gram decontamination for benchmark safety.
2//!
3//! Detects overlap between training data and evaluation benchmarks
4//! using n-gram fingerprinting. Training samples that exceed the
5//! overlap threshold are flagged for removal.
6//!
7//! # Algorithm
8//!
9//! 1. Build n-gram set from each reference benchmark sample
10//! 2. For each training sample, compute n-gram overlap ratio
11//! 3. Flag samples exceeding threshold (default 50%)
12//!
13//! # References
14//!
15//! - Spec §12.1: Decontamination Protocol
16//! - AC-016: <1% n-gram overlap between training and eval data
17//! - GH-9: `apr validate --decontaminate`
18
19use std::collections::HashSet;
20
/// Outcome of the contamination check for one training sample.
#[derive(Debug, Clone)]
pub struct ContaminationResult {
    /// Position of the training sample this result describes.
    pub sample_index: usize,
    /// Largest overlap ratio observed against any single reference sample,
    /// expressed as a fraction between 0.0 and 1.0.
    pub max_overlap: f64,
    /// Position of the reference sample that produced `max_overlap`.
    pub matched_reference: usize,
    /// `true` when the sample's overlap is above the contamination threshold.
    pub contaminated: bool,
}
33
34/// Summary report of decontamination check.
35#[derive(Debug, Clone)]
36pub struct DecontaminationReport {
37    /// N-gram size used
38    pub ngram_size: usize,
39    /// Overlap threshold used
40    pub threshold: f64,
41    /// Total training samples checked
42    pub total_samples: usize,
43    /// Number of contaminated samples
44    pub contaminated_count: usize,
45    /// Contamination rate (0.0 to 1.0)
46    pub contamination_rate: f64,
47    /// Per-sample results (only contaminated samples included)
48    pub flagged: Vec<ContaminationResult>,
49}
50
/// Extract the set of distinct character-level n-grams from `text`.
///
/// The text is normalized before windowing: whitespace characters are
/// dropped and every remaining character is lowercased. Each n-gram is a
/// run of `n` consecutive normalized characters.
///
/// Returns an empty set when `n` is zero or when the normalized text
/// contains fewer than `n` characters.
fn extract_ngrams(text: &str, n: usize) -> HashSet<Vec<char>> {
    // `slice::windows` panics on a zero window size; a zero-length n-gram
    // carries no information, so report "no n-grams" instead of panicking.
    if n == 0 {
        return HashSet::new();
    }

    let normalized: Vec<char> = text
        .chars()
        .filter(|c| !c.is_whitespace())
        .flat_map(char::to_lowercase)
        .collect();

    // `windows` yields nothing when the slice is shorter than `n`, so short
    // inputs naturally produce an empty set.
    normalized.windows(n).map(|window| window.to_vec()).collect()
}
65
66/// Compute n-gram overlap ratio between two texts.
67///
68/// Returns the fraction of n-grams in `candidate` that also
69/// appear in `reference`. Range: 0.0 (no overlap) to 1.0 (complete).
70pub fn ngram_overlap(candidate: &str, reference: &str, n: usize) -> f64 {
71    let cand_ngrams = extract_ngrams(candidate, n);
72    if cand_ngrams.is_empty() {
73        return 0.0;
74    }
75
76    let ref_ngrams = extract_ngrams(reference, n);
77    let intersection = cand_ngrams.intersection(&ref_ngrams).count();
78
79    intersection as f64 / cand_ngrams.len() as f64
80}
81
82/// Check training data against reference benchmarks for contamination.
83///
84/// # Arguments
85///
86/// * `training_samples` - Training data texts
87/// * `reference_samples` - Benchmark/eval texts to check against
88/// * `ngram_size` - Size of n-grams (default: 10)
89/// * `threshold` - Overlap ratio above which a sample is contaminated
90///
91/// # Returns
92///
93/// `DecontaminationReport` with per-sample results and summary stats.
94pub fn check_contamination(
95    training_samples: &[&str],
96    reference_samples: &[&str],
97    ngram_size: usize,
98    threshold: f64,
99) -> DecontaminationReport {
100    // Pre-compute reference n-gram sets
101    let ref_ngram_sets: Vec<HashSet<Vec<char>>> = reference_samples
102        .iter()
103        .map(|s| extract_ngrams(s, ngram_size))
104        .collect();
105
106    let mut flagged = Vec::new();
107
108    for (i, sample) in training_samples.iter().enumerate() {
109        let cand_ngrams = extract_ngrams(sample, ngram_size);
110        if cand_ngrams.is_empty() {
111            continue;
112        }
113
114        let mut max_overlap = 0.0_f64;
115        let mut matched_ref = 0;
116
117        for (j, ref_set) in ref_ngram_sets.iter().enumerate() {
118            let intersection = cand_ngrams.intersection(ref_set).count();
119            let overlap = intersection as f64 / cand_ngrams.len() as f64;
120
121            if overlap > max_overlap {
122                max_overlap = overlap;
123                matched_ref = j;
124            }
125        }
126
127        if max_overlap > threshold {
128            flagged.push(ContaminationResult {
129                sample_index: i,
130                max_overlap,
131                matched_reference: matched_ref,
132                contaminated: true,
133            });
134        }
135    }
136
137    let contaminated_count = flagged.len();
138    let total = training_samples.len();
139    let rate = if total > 0 {
140        contaminated_count as f64 / total as f64
141    } else {
142        0.0
143    };
144
145    DecontaminationReport {
146        ngram_size,
147        threshold,
148        total_samples: total,
149        contaminated_count,
150        contamination_rate: rate,
151        flagged,
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Whitespace is removed before windowing, so "hello world" is handled
    /// as "helloworld": hel, ell, llo, low, owo, wor, orl, rld.
    #[test]
    fn test_extract_ngrams() {
        let trigrams = extract_ngrams("hello world", 3);
        assert_eq!(trigrams.len(), 8);
    }

    /// Text shorter than the n-gram size produces no n-grams at all.
    #[test]
    fn test_extract_ngrams_short_text() {
        assert!(extract_ngrams("hi", 10).is_empty());
    }

    /// A text compared with itself overlaps completely.
    #[test]
    fn test_ngram_overlap_identical() {
        let snippet = "def fibonacci(n):";
        let ratio = ngram_overlap(snippet, snippet, 5);
        assert!((ratio - 1.0).abs() < f64::EPSILON);
    }

    /// Unrelated texts share (almost) no 10-grams.
    #[test]
    fn test_ngram_overlap_no_match() {
        let candidate = "completely different text about cooking";
        let reference = "def fibonacci(n): return n if n < 2";
        let ratio = ngram_overlap(candidate, reference, 10);
        assert!(ratio < 0.1);
    }

    /// Texts with a common prefix overlap partially: more than nothing,
    /// less than everything.
    #[test]
    fn test_ngram_overlap_partial() {
        let candidate = "def fibonacci(n): return n if n < 2 else fibonacci(n-1)";
        let reference = "def fibonacci(n): return fib(n-1) + fib(n-2)";
        let ratio = ngram_overlap(candidate, reference, 5);
        assert!(ratio > 0.0 && ratio < 1.0);
    }

    /// Training data unrelated to the benchmark yields no flagged samples.
    #[test]
    fn test_check_contamination_clean() {
        let training_set = [
            "def sort_list(lst): return sorted(lst)",
            "def reverse_string(s): return s[::-1]",
        ];
        let benchmark =
            ["def fibonacci(n): return n if n < 2 else fibonacci(n-1) + fibonacci(n-2)"];

        let report = check_contamination(&training_set, &benchmark, 10, 0.5);
        assert_eq!(report.contaminated_count, 0);
        assert!(report.contamination_rate < 0.01);
    }

    /// A verbatim copy of a benchmark sample is flagged with full overlap.
    #[test]
    fn test_check_contamination_flagged() {
        let benchmark_text =
            "def fibonacci(n): return n if n < 2 else fibonacci(n-1) + fibonacci(n-2)";
        // Second training entry is an exact copy of the benchmark text.
        let training_set = ["def sort_list(lst): return sorted(lst)", benchmark_text];
        let benchmark = [benchmark_text];

        let report = check_contamination(&training_set, &benchmark, 10, 0.5);
        assert_eq!(report.contaminated_count, 1);

        let hit = &report.flagged[0];
        assert_eq!(hit.sample_index, 1);
        assert!((hit.max_overlap - 1.0).abs() < f64::EPSILON);
    }

    /// A stricter configuration never flags fewer samples than a lenient one.
    #[test]
    fn test_check_contamination_threshold() {
        let training_set =
            ["def fibonacci(n): return n if n < 2 else fibonacci(n-1) + fibonacci(n-2)"];
        let benchmark = ["def fibonacci(n): return n if n < 2 else fib(n-1) + fib(n-2)"];

        // Short n-grams with a low threshold: should catch partial overlap.
        let strict_report = check_contamination(&training_set, &benchmark, 5, 0.3);
        // Long n-grams with a high threshold: should let it pass.
        let lenient_report = check_contamination(&training_set, &benchmark, 10, 0.9);

        assert!(strict_report.contaminated_count >= lenient_report.contaminated_count);
    }

    /// Empty training or reference sets produce an empty, clean report.
    #[test]
    fn test_empty_inputs() {
        let no_training = check_contamination(&[], &["some reference"], 10, 0.5);
        assert_eq!(no_training.total_samples, 0);
        assert_eq!(no_training.contaminated_count, 0);

        let no_reference = check_contamination(&["some training"], &[], 10, 0.5);
        assert_eq!(no_reference.contaminated_count, 0);
    }
}