Skip to main content

verificar/ml/
augmentation.rs

//! Synthetic data augmentation for code
//!
//! Implements Easy Data Augmentation (EDA) techniques adapted for code,
//! based on Wei & Zou (2019).
//!
//! # Operations
//!
//! - **Synonym Replacement (SR)**: Rename variables consistently
//! - **Random Insertion (RI)**: Insert comments or `pass` statements
//! - **Random Swap (RS)**: Reorder independent adjacent statements
//! - **Random Deletion (RD)**: Remove comments or redundant `pass` statements
12
13use rand::prelude::*;
14use std::collections::HashSet;
15
/// Tuning knobs for the code EDA augmenter.
///
/// Each `*_prob` field is the chance that the corresponding operation is
/// applied to a candidate; the four chances are checked independently.
#[derive(Debug, Clone)]
pub struct CodeEDAConfig {
    /// Chance of synonym replacement (consistent variable renaming).
    pub sr_prob: f32,
    /// Chance of random insertion (adding a comment or `pass` line).
    pub ri_prob: f32,
    /// Chance of random swap (reordering two adjacent statements).
    pub rs_prob: f32,
    /// Chance of random deletion (dropping a comment or `pass` line).
    pub rd_prob: f32,
    /// Candidates whose quality score falls below this value are discarded.
    pub quality_threshold: f32,
    /// Seed for the augmenter's RNG, making runs reproducible.
    pub seed: u64,
}

impl Default for CodeEDAConfig {
    /// Light-touch defaults: each operation fires rarely, deletion half as
    /// often as the others, and only high-overlap candidates are kept.
    fn default() -> Self {
        Self {
            seed: 42,
            quality_threshold: 0.75,
            rd_prob: 0.05,
            rs_prob: 0.1,
            ri_prob: 0.1,
            sr_prob: 0.1,
        }
    }
}
45
46/// Easy Data Augmentation for code
47#[derive(Debug)]
48pub struct CodeEDA {
49    config: CodeEDAConfig,
50    rng: StdRng,
51}
52
53impl CodeEDA {
54    /// Create a new CodeEDA augmenter with default config
55    #[must_use]
56    pub fn new() -> Self {
57        Self::with_config(CodeEDAConfig::default())
58    }
59
60    /// Create a new CodeEDA augmenter with custom config
61    #[must_use]
62    pub fn with_config(config: CodeEDAConfig) -> Self {
63        let rng = StdRng::seed_from_u64(config.seed);
64        Self { config, rng }
65    }
66
67    /// Generate augmented versions of the input code
68    ///
69    /// Returns a vector of augmented code strings that pass quality threshold.
70    pub fn augment(&mut self, code: &str, n_augmentations: usize) -> Vec<String> {
71        let mut results = Vec::with_capacity(n_augmentations);
72
73        for _ in 0..n_augmentations * 2 {
74            // Generate more to account for quality filtering
75            let augmented = self.apply_augmentations(code);
76            let quality = self.quality_score(&augmented, code);
77
78            if quality >= self.config.quality_threshold {
79                results.push(augmented);
80                if results.len() >= n_augmentations {
81                    break;
82                }
83            }
84        }
85
86        results
87    }
88
89    /// Apply all augmentation operations probabilistically
90    fn apply_augmentations(&mut self, code: &str) -> String {
91        let mut result = code.to_string();
92
93        if self.rng.random::<f32>() < self.config.sr_prob {
94            result = self.synonym_replacement(&result);
95        }
96        if self.rng.random::<f32>() < self.config.ri_prob {
97            result = self.random_insertion(&result);
98        }
99        if self.rng.random::<f32>() < self.config.rs_prob {
100            result = self.random_swap(&result);
101        }
102        if self.rng.random::<f32>() < self.config.rd_prob {
103            result = self.random_deletion(&result);
104        }
105
106        result
107    }
108
109    /// Synonym Replacement: Rename variables consistently
110    fn synonym_replacement(&mut self, code: &str) -> String {
111        let variables = self.extract_variables(code);
112        if variables.is_empty() {
113            return code.to_string();
114        }
115
116        // Pick a random variable to rename (sorted for determinism)
117        let mut var_list: Vec<_> = variables.into_iter().collect();
118        var_list.sort();
119        let idx = self.rng.random_range(0..var_list.len());
120        let old_var = &var_list[idx];
121
122        // Generate new name
123        let new_var = self.generate_variable_name(old_var);
124
125        // Replace all occurrences (simple word boundary replacement)
126        self.replace_identifier(code, old_var, &new_var)
127    }
128
129    /// Random Insertion: Add comments or pass statements
130    fn random_insertion(&mut self, code: &str) -> String {
131        let lines: Vec<&str> = code.lines().collect();
132        if lines.is_empty() {
133            return code.to_string();
134        }
135
136        let insert_idx = self.rng.random_range(0..=lines.len());
137        let insert_type = self.rng.random_range(0..3);
138
139        let insertion = match insert_type {
140            0 => "    # augmented".to_string(),
141            1 => "    pass  # placeholder".to_string(),
142            _ => format!("    # line {}", insert_idx + 1),
143        };
144
145        let mut result_lines: Vec<String> = lines.iter().map(|s| (*s).to_string()).collect();
146        result_lines.insert(insert_idx, insertion);
147        result_lines.join("\n")
148    }
149
150    /// Random Swap: Reorder independent statements
151    fn random_swap(&mut self, code: &str) -> String {
152        let lines: Vec<&str> = code.lines().collect();
153        if lines.len() < 2 {
154            return code.to_string();
155        }
156
157        // Find swappable pairs (same indentation, no dependencies)
158        let swappable = self.find_swappable_pairs(&lines);
159        if swappable.is_empty() {
160            return code.to_string();
161        }
162
163        let (i, j) = swappable[self.rng.random_range(0..swappable.len())];
164        let mut result_lines: Vec<String> = lines.iter().map(|s| (*s).to_string()).collect();
165        result_lines.swap(i, j);
166        result_lines.join("\n")
167    }
168
169    /// Random Deletion: Remove redundant statements
170    fn random_deletion(&mut self, code: &str) -> String {
171        let lines: Vec<&str> = code.lines().collect();
172        if lines.len() <= 2 {
173            return code.to_string();
174        }
175
176        // Only delete comments or pass statements
177        let deletable: Vec<usize> = lines
178            .iter()
179            .enumerate()
180            .filter(|(_, line)| {
181                let trimmed = line.trim();
182                trimmed.starts_with('#') || trimmed == "pass"
183            })
184            .map(|(i, _)| i)
185            .collect();
186
187        if deletable.is_empty() {
188            return code.to_string();
189        }
190
191        let del_idx = deletable[self.rng.random_range(0..deletable.len())];
192        let result_lines: Vec<&str> = lines
193            .iter()
194            .enumerate()
195            .filter(|(i, _)| *i != del_idx)
196            .map(|(_, line)| *line)
197            .collect();
198        result_lines.join("\n")
199    }
200
201    /// Extract variable names from Python code (simple heuristic)
202    fn extract_variables(&self, code: &str) -> HashSet<String> {
203        let mut vars = HashSet::new();
204
205        // Match assignment patterns: var_name = ...
206        for line in code.lines() {
207            let trimmed = line.trim();
208            if let Some(eq_pos) = trimmed.find('=') {
209                if eq_pos > 0 && !trimmed[..eq_pos].contains('(') {
210                    let lhs = trimmed[..eq_pos].trim();
211                    // Skip if it's a comparison (==, !=, etc.)
212                    if !lhs.ends_with(['!', '<', '>', '=']) {
213                        // Handle tuple unpacking
214                        for var in lhs.split(',') {
215                            let var = var.trim();
216                            if is_valid_identifier(var) && !is_keyword(var) {
217                                vars.insert(var.to_string());
218                            }
219                        }
220                    }
221                }
222            }
223        }
224
225        vars
226    }
227
228    /// Generate a new variable name based on old one
229    fn generate_variable_name(&mut self, old: &str) -> String {
230        let suffixes = ["_new", "_v2", "_alt", "_mod", "2"];
231        let suffix = suffixes[self.rng.random_range(0..suffixes.len())];
232        format!("{old}{suffix}")
233    }
234
235    /// Replace identifier with word boundary awareness
236    fn replace_identifier(&self, code: &str, old: &str, new: &str) -> String {
237        let mut result = String::with_capacity(code.len() + 32);
238        let chars: Vec<char> = code.chars().collect();
239        let old_chars: Vec<char> = old.chars().collect();
240        let mut i = 0;
241
242        while i < chars.len() {
243            if i + old_chars.len() <= chars.len() {
244                let matches = chars[i..i + old_chars.len()]
245                    .iter()
246                    .zip(old_chars.iter())
247                    .all(|(a, b)| a == b);
248
249                if matches {
250                    // Check word boundaries
251                    let before_ok =
252                        i == 0 || !chars[i - 1].is_alphanumeric() && chars[i - 1] != '_';
253                    let after_ok = i + old_chars.len() >= chars.len()
254                        || !chars[i + old_chars.len()].is_alphanumeric()
255                            && chars[i + old_chars.len()] != '_';
256
257                    if before_ok && after_ok {
258                        result.push_str(new);
259                        i += old_chars.len();
260                        continue;
261                    }
262                }
263            }
264            result.push(chars[i]);
265            i += 1;
266        }
267
268        result
269    }
270
271    /// Find pairs of lines that can be safely swapped
272    fn find_swappable_pairs(&self, lines: &[&str]) -> Vec<(usize, usize)> {
273        let mut pairs = Vec::new();
274
275        for i in 0..lines.len().saturating_sub(1) {
276            let indent_i = lines[i].len() - lines[i].trim_start().len();
277            let indent_j = lines[i + 1].len() - lines[i + 1].trim_start().len();
278
279            // Same indentation, both are simple statements
280            if indent_i == indent_j {
281                let line_i = lines[i].trim();
282                let line_j = lines[i + 1].trim();
283
284                // Skip control flow, function defs, class defs
285                let is_simple_i = !line_i.starts_with("if ")
286                    && !line_i.starts_with("for ")
287                    && !line_i.starts_with("while ")
288                    && !line_i.starts_with("def ")
289                    && !line_i.starts_with("class ")
290                    && !line_i.starts_with("return")
291                    && !line_i.is_empty();
292
293                let is_simple_j = !line_j.starts_with("if ")
294                    && !line_j.starts_with("for ")
295                    && !line_j.starts_with("while ")
296                    && !line_j.starts_with("def ")
297                    && !line_j.starts_with("class ")
298                    && !line_j.starts_with("return")
299                    && !line_j.is_empty();
300
301                if is_simple_i && is_simple_j {
302                    pairs.push((i, i + 1));
303                }
304            }
305        }
306
307        pairs
308    }
309
310    /// Calculate quality score for augmented code
311    ///
312    /// Returns score in [0.0, 1.0] based on:
313    /// - Syntactic validity (must parse)
314    /// - Token overlap with original
315    #[must_use]
316    pub fn quality_score(&self, augmented: &str, original: &str) -> f32 {
317        // Basic syntactic check: balanced parentheses, quotes
318        if !self.basic_syntax_check(augmented) {
319            return 0.0;
320        }
321
322        // Token overlap score
323        let orig_tokens: HashSet<_> = tokenize(original).collect();
324        let aug_tokens: HashSet<_> = tokenize(augmented).collect();
325
326        if orig_tokens.is_empty() {
327            return 1.0;
328        }
329
330        let overlap = orig_tokens.intersection(&aug_tokens).count();
331        overlap as f32 / orig_tokens.len() as f32
332    }
333
334    /// Calculate diversity score for a batch of augmented code
335    ///
336    /// Returns score in [0.0, 1.0], higher means more diverse
337    #[must_use]
338    pub fn diversity_score(&self, batch: &[String]) -> f32 {
339        if batch.is_empty() {
340            return 0.0;
341        }
342
343        let unique: HashSet<_> = batch.iter().collect();
344        unique.len() as f32 / batch.len() as f32
345    }
346
347    /// Basic syntax validation
348    fn basic_syntax_check(&self, code: &str) -> bool {
349        let mut paren_depth = 0i32;
350        let mut bracket_depth = 0i32;
351        let mut brace_depth = 0i32;
352        let mut in_string = false;
353        let mut string_char = ' ';
354
355        for c in code.chars() {
356            if in_string {
357                if c == string_char {
358                    in_string = false;
359                }
360                continue;
361            }
362
363            match c {
364                '"' | '\'' => {
365                    in_string = true;
366                    string_char = c;
367                }
368                '(' => paren_depth += 1,
369                ')' => paren_depth -= 1,
370                '[' => bracket_depth += 1,
371                ']' => bracket_depth -= 1,
372                '{' => brace_depth += 1,
373                '}' => brace_depth -= 1,
374                _ => {}
375            }
376
377            if paren_depth < 0 || bracket_depth < 0 || brace_depth < 0 {
378                return false;
379            }
380        }
381
382        paren_depth == 0 && bracket_depth == 0 && brace_depth == 0 && !in_string
383    }
384}
385
386impl Default for CodeEDA {
387    fn default() -> Self {
388        Self::new()
389    }
390}
391
/// Split code into identifier-like tokens.
///
/// Every character that is neither alphanumeric nor `_` acts as a separator,
/// and the empty fragments between adjacent separators are dropped, so
/// operators, punctuation and whitespace never appear in the output.
fn tokenize(code: &str) -> impl Iterator<Item = &str> {
    let is_separator = |ch: char| !(ch.is_alphanumeric() || ch == '_');
    code.split(is_separator).filter(|token| !token.is_empty())
}
397
/// Check whether `s` is a valid Python-style identifier.
///
/// The first character must be alphabetic (Unicode-aware) or `_`, and every
/// remaining character must be alphanumeric or `_`. The empty string is
/// rejected.
fn is_valid_identifier(s: &str) -> bool {
    let mut chars = s.chars();
    match chars.next() {
        Some(head) if head.is_alphabetic() || head == '_' => {
            chars.all(|c| c.is_alphanumeric() || c == '_')
        }
        _ => false,
    }
}
408
/// Check whether `s` is one of Python 3's 35 hard keywords.
///
/// Soft keywords such as `match`, `case` and `type` are not included, since
/// they remain usable as ordinary identifiers.
fn is_keyword(s: &str) -> bool {
    const PYTHON_KEYWORDS: &[&str] = &[
        "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
        "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
        "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return",
        "try", "while", "with", "yield",
    ];
    PYTHON_KEYWORDS.contains(&s)
}
450
/// Outcome of augmenting one code sample in a batch.
#[derive(Debug, Clone)]
pub struct AugmentationResult {
    /// The sample exactly as it was supplied.
    pub original: String,
    /// Augmented versions that passed the quality filter.
    pub variants: Vec<String>,
    /// Quality score of each entry in `variants`, index for index.
    pub quality_scores: Vec<f32>,
    /// Fraction of distinct strings among `variants`.
    pub diversity_score: f32,
}
463
464/// Batch augmenter for processing multiple code samples
465#[derive(Debug)]
466pub struct BatchAugmenter {
467    eda: CodeEDA,
468    /// Augmentation factor (e.g., 5.0 = 5x more samples)
469    pub factor: f32,
470}
471
472impl BatchAugmenter {
473    /// Create a new batch augmenter
474    #[must_use]
475    pub fn new(config: CodeEDAConfig, factor: f32) -> Self {
476        Self {
477            eda: CodeEDA::with_config(config),
478            factor,
479        }
480    }
481
482    /// Augment a batch of code samples
483    pub fn augment_batch(&mut self, samples: &[String]) -> Vec<AugmentationResult> {
484        #[allow(clippy::cast_sign_loss)]
485        let n_aug = (self.factor.max(0.0) as usize).max(1);
486
487        samples
488            .iter()
489            .map(|code| {
490                let variants = self.eda.augment(code, n_aug);
491                let quality_scores: Vec<f32> = variants
492                    .iter()
493                    .map(|v| self.eda.quality_score(v, code))
494                    .collect();
495                let diversity_score = self.eda.diversity_score(&variants);
496
497                AugmentationResult {
498                    original: code.clone(),
499                    variants,
500                    quality_scores,
501                    diversity_score,
502                }
503            })
504            .collect()
505    }
506}
507
508#[cfg(test)]
509mod tests {
510    use super::*;
511
512    #[test]
513    fn test_code_eda_basic() {
514        let mut eda = CodeEDA::new();
515        let code = "x = 1\ny = 2\nz = x + y";
516        let augmented = eda.augment(code, 3);
517
518        assert!(!augmented.is_empty());
519        for aug in &augmented {
520            let quality = eda.quality_score(aug, code);
521            assert!(quality >= 0.75);
522        }
523    }
524
525    #[test]
526    fn test_synonym_replacement() {
527        let mut eda = CodeEDA::with_config(CodeEDAConfig {
528            sr_prob: 1.0,
529            ri_prob: 0.0,
530            rs_prob: 0.0,
531            rd_prob: 0.0,
532            ..Default::default()
533        });
534
535        let code = "foo = 1\nbar = foo + 2";
536        let augmented = eda.augment(code, 1);
537
538        assert!(!augmented.is_empty());
539        // Should have renamed a variable
540        let aug = &augmented[0];
541        assert!(aug.contains("_new") || aug.contains("_v2") || aug.contains("2"));
542    }
543
544    #[test]
545    fn test_random_insertion() {
546        let mut eda = CodeEDA::with_config(CodeEDAConfig {
547            sr_prob: 0.0,
548            ri_prob: 1.0,
549            rs_prob: 0.0,
550            rd_prob: 0.0,
551            ..Default::default()
552        });
553
554        let code = "x = 1";
555        let augmented = eda.augment(code, 1);
556
557        assert!(!augmented.is_empty());
558        // Should have added a line
559        assert!(augmented[0].lines().count() > code.lines().count());
560    }
561
562    #[test]
563    fn test_quality_score() {
564        let eda = CodeEDA::new();
565
566        // High quality: similar code
567        let score = eda.quality_score("x = 1\ny = 2", "x = 1\ny = 2");
568        assert!((score - 1.0).abs() < f32::EPSILON);
569
570        // Medium quality: some overlap
571        let score = eda.quality_score("x_new = 1\ny = 2", "x = 1\ny = 2");
572        assert!(score > 0.5);
573
574        // Zero quality: unbalanced parens
575        let score = eda.quality_score("x = (1", "x = 1");
576        assert!(score < f32::EPSILON);
577    }
578
579    #[test]
580    fn test_diversity_score() {
581        let eda = CodeEDA::new();
582
583        // All unique
584        let batch = vec!["a".to_string(), "b".to_string(), "c".to_string()];
585        assert!((eda.diversity_score(&batch) - 1.0).abs() < f32::EPSILON);
586
587        // All same
588        let batch = vec!["a".to_string(), "a".to_string(), "a".to_string()];
589        assert!((eda.diversity_score(&batch) - 1.0 / 3.0).abs() < f32::EPSILON);
590
591        // Empty
592        let batch: Vec<String> = vec![];
593        assert!(eda.diversity_score(&batch) < f32::EPSILON);
594    }
595
596    #[test]
597    fn test_batch_augmenter() {
598        let config = CodeEDAConfig::default();
599        let mut augmenter = BatchAugmenter::new(config, 2.0);
600
601        let samples = vec!["x = 1".to_string(), "y = 2".to_string()];
602        let results = augmenter.augment_batch(&samples);
603
604        assert_eq!(results.len(), 2);
605        for result in &results {
606            assert!(result.diversity_score >= 0.0);
607        }
608    }
609
610    #[test]
611    fn test_extract_variables() {
612        let eda = CodeEDA::new();
613
614        let vars = eda.extract_variables("x = 1\ny = 2\nif x == y: pass");
615        assert!(vars.contains("x"));
616        assert!(vars.contains("y"));
617        assert!(!vars.contains("if")); // Keyword
618    }
619
620    #[test]
621    fn test_basic_syntax_check() {
622        let eda = CodeEDA::new();
623
624        assert!(eda.basic_syntax_check("x = (1 + 2)"));
625        assert!(eda.basic_syntax_check("x = [1, 2, 3]"));
626        assert!(eda.basic_syntax_check("x = {'a': 1}"));
627        assert!(eda.basic_syntax_check("x = \"hello\""));
628
629        assert!(!eda.basic_syntax_check("x = (1 + 2"));
630        assert!(!eda.basic_syntax_check("x = [1, 2"));
631        assert!(!eda.basic_syntax_check("x = \"hello"));
632    }
633
634    #[test]
635    fn test_is_valid_identifier() {
636        assert!(is_valid_identifier("foo"));
637        assert!(is_valid_identifier("_bar"));
638        assert!(is_valid_identifier("baz123"));
639        assert!(is_valid_identifier("__init__"));
640
641        assert!(!is_valid_identifier("123abc"));
642        assert!(!is_valid_identifier(""));
643        assert!(!is_valid_identifier("foo-bar"));
644    }
645
646    #[test]
647    fn test_is_keyword() {
648        assert!(is_keyword("if"));
649        assert!(is_keyword("for"));
650        assert!(is_keyword("return"));
651        assert!(is_keyword("True"));
652
653        assert!(!is_keyword("foo"));
654        assert!(!is_keyword("bar"));
655    }
656
657    // ========== EDGE CASE TESTS (Extreme TDD) ==========
658
659    #[test]
660    fn test_augment_empty_code() {
661        let mut eda = CodeEDA::new();
662        let augmented = eda.augment("", 3);
663        // Empty code should still produce valid augmentations
664        for aug in &augmented {
665            assert!(eda.basic_syntax_check(aug));
666        }
667    }
668
669    #[test]
670    fn test_augment_single_char() {
671        let mut eda = CodeEDA::new();
672        let augmented = eda.augment("x", 3);
673        assert!(augmented.is_empty() || augmented.iter().all(|a| eda.basic_syntax_check(a)));
674    }
675
676    #[test]
677    fn test_augment_whitespace_only() {
678        let mut eda = CodeEDA::new();
679        let augmented = eda.augment("   \n\t\n   ", 3);
680        for aug in &augmented {
681            assert!(eda.basic_syntax_check(aug));
682        }
683    }
684
685    #[test]
686    fn test_extract_variables_tuple_unpacking() {
687        let eda = CodeEDA::new();
688        let vars = eda.extract_variables("a, b, c = 1, 2, 3");
689        assert!(vars.contains("a"));
690        assert!(vars.contains("b"));
691        assert!(vars.contains("c"));
692    }
693
694    #[test]
695    fn test_extract_variables_no_assignments() {
696        let eda = CodeEDA::new();
697        let vars = eda.extract_variables("print('hello')\nfoo()");
698        assert!(vars.is_empty());
699    }
700
701    #[test]
702    fn test_extract_variables_with_comparison() {
703        let eda = CodeEDA::new();
704        let vars = eda.extract_variables("if x == y:\n    pass");
705        // Should not extract from comparisons
706        assert!(!vars.contains("x"));
707    }
708
709    #[test]
710    fn test_synonym_replacement_no_variables() {
711        let mut eda = CodeEDA::with_config(CodeEDAConfig {
712            sr_prob: 1.0,
713            ri_prob: 0.0,
714            rs_prob: 0.0,
715            rd_prob: 0.0,
716            ..Default::default()
717        });
718
719        let code = "print('hello')";
720        let augmented = eda.augment(code, 1);
721        // Should not crash, just return original or valid augmentation
722        assert!(augmented.is_empty() || eda.basic_syntax_check(&augmented[0]));
723    }
724
725    #[test]
726    fn test_random_swap_single_line() {
727        let mut eda = CodeEDA::with_config(CodeEDAConfig {
728            sr_prob: 0.0,
729            ri_prob: 0.0,
730            rs_prob: 1.0,
731            rd_prob: 0.0,
732            ..Default::default()
733        });
734
735        let code = "x = 1";
736        let augmented = eda.augment(code, 1);
737        // Single line can't be swapped
738        assert!(augmented.is_empty() || augmented[0] == code);
739    }
740
741    #[test]
742    fn test_random_deletion_minimal_code() {
743        let mut eda = CodeEDA::with_config(CodeEDAConfig {
744            sr_prob: 0.0,
745            ri_prob: 0.0,
746            rs_prob: 0.0,
747            rd_prob: 1.0,
748            ..Default::default()
749        });
750
751        let code = "x = 1\ny = 2"; // Only 2 lines
752        let augmented = eda.augment(code, 1);
753        // Should not delete from minimal code
754        assert!(augmented.is_empty() || augmented[0].lines().count() >= 2);
755    }
756
757    #[test]
758    fn test_random_deletion_removes_comment() {
759        let mut eda = CodeEDA::with_config(CodeEDAConfig {
760            sr_prob: 0.0,
761            ri_prob: 0.0,
762            rs_prob: 0.0,
763            rd_prob: 1.0,
764            quality_threshold: 0.0, // Accept any quality
765            ..Default::default()
766        });
767
768        let code = "x = 1\n# comment\ny = 2\nz = 3";
769        let augmented = eda.augment(code, 1);
770        if !augmented.is_empty() {
771            // Should have removed the comment
772            assert!(!augmented[0].contains("# comment") || augmented[0].lines().count() < 4);
773        }
774    }
775
776    #[test]
777    fn test_quality_score_empty_original() {
778        let eda = CodeEDA::new();
779        let score = eda.quality_score("x = 1", "");
780        // Empty original means no tokens to compare
781        assert!((score - 1.0).abs() < f32::EPSILON);
782    }
783
784    #[test]
785    fn test_quality_score_nested_brackets() {
786        let eda = CodeEDA::new();
787        let code = "x = [[1, 2], [3, 4]]";
788        let score = eda.quality_score(code, code);
789        assert!((score - 1.0).abs() < f32::EPSILON);
790        assert!(eda.basic_syntax_check(code));
791    }
792
793    #[test]
794    fn test_quality_score_unbalanced_nested() {
795        let eda = CodeEDA::new();
796        let score = eda.quality_score("x = [[1, 2]", "x = 1");
797        assert!(score < f32::EPSILON);
798    }
799
800    #[test]
801    fn test_replace_identifier_word_boundary() {
802        let eda = CodeEDA::new();
803
804        // Should not replace 'x' in 'max'
805        let result = eda.replace_identifier("max = x + max_value", "x", "y");
806        assert!(result.contains("max"));
807        assert!(result.contains("y"));
808        assert!(result.contains("max_value")); // Should not become may_value
809    }
810
811    #[test]
812    fn test_replace_identifier_at_start() {
813        let eda = CodeEDA::new();
814        let result = eda.replace_identifier("foo = 1", "foo", "bar");
815        assert_eq!(result, "bar = 1");
816    }
817
818    #[test]
819    fn test_replace_identifier_at_end() {
820        let eda = CodeEDA::new();
821        let result = eda.replace_identifier("x = foo", "foo", "bar");
822        assert_eq!(result, "x = bar");
823    }
824
825    #[test]
826    fn test_find_swappable_pairs_control_flow() {
827        let eda = CodeEDA::new();
828        let lines: Vec<&str> = vec!["if x:", "    y = 1"];
829        let pairs = eda.find_swappable_pairs(&lines);
830        // Control flow should not be swappable
831        assert!(pairs.is_empty());
832    }
833
834    #[test]
835    fn test_find_swappable_pairs_different_indent() {
836        let eda = CodeEDA::new();
837        let lines: Vec<&str> = vec!["x = 1", "    y = 2"];
838        let pairs = eda.find_swappable_pairs(&lines);
839        // Different indentation should not be swappable
840        assert!(pairs.is_empty());
841    }
842
843    #[test]
844    fn test_find_swappable_pairs_valid() {
845        let eda = CodeEDA::new();
846        let lines: Vec<&str> = vec!["x = 1", "y = 2", "z = 3"];
847        let pairs = eda.find_swappable_pairs(&lines);
848        // Adjacent pairs with same indent should be swappable
849        assert!(!pairs.is_empty());
850    }
851
852    #[test]
853    fn test_basic_syntax_check_escaped_quotes() {
854        let eda = CodeEDA::new();
855        // Note: Our simple parser doesn't handle escapes, but shouldn't crash
856        let result = eda.basic_syntax_check(r#"x = "hello""#);
857        assert!(result);
858    }
859
860    #[test]
861    fn test_basic_syntax_check_mixed_brackets() {
862        let eda = CodeEDA::new();
863        assert!(eda.basic_syntax_check("x = ([1, 2], {3: 4})"));
864        assert!(!eda.basic_syntax_check("x = ([1, 2}, {3: 4])"));
865    }
866
867    #[test]
868    fn test_config_probabilities_boundary() {
869        let config = CodeEDAConfig {
870            sr_prob: 0.0,
871            ri_prob: 0.0,
872            rs_prob: 0.0,
873            rd_prob: 0.0,
874            quality_threshold: 0.0,
875            seed: 42,
876        };
877        let mut eda = CodeEDA::with_config(config);
878        let code = "x = 1";
879        let augmented = eda.augment(code, 5);
880        // With all probs at 0, should return original code
881        for aug in &augmented {
882            assert_eq!(aug, code);
883        }
884    }
885
886    #[test]
887    fn test_config_all_ops_enabled() {
888        let config = CodeEDAConfig {
889            sr_prob: 1.0,
890            ri_prob: 1.0,
891            rs_prob: 1.0,
892            rd_prob: 1.0,
893            quality_threshold: 0.0, // Accept any
894            seed: 42,
895        };
896        let mut eda = CodeEDA::with_config(config);
897        let code = "x = 1\n# comment\ny = 2\nz = 3";
898        let augmented = eda.augment(code, 3);
899        // Should produce varied augmentations
900        assert!(!augmented.is_empty());
901    }
902
903    #[test]
904    fn test_batch_augmenter_empty_samples() {
905        let config = CodeEDAConfig::default();
906        let mut augmenter = BatchAugmenter::new(config, 2.0);
907        let samples: Vec<String> = vec![];
908        let results = augmenter.augment_batch(&samples);
909        assert!(results.is_empty());
910    }
911
912    #[test]
913    fn test_batch_augmenter_factor_zero() {
914        let config = CodeEDAConfig::default();
915        let mut augmenter = BatchAugmenter::new(config, 0.0);
916        let samples = vec!["x = 1".to_string()];
917        let results = augmenter.augment_batch(&samples);
918        // Factor 0.0 should be treated as at least 1
919        assert_eq!(results.len(), 1);
920    }
921
922    #[test]
923    fn test_augmentation_result_fields() {
924        let result = AugmentationResult {
925            original: "x = 1".to_string(),
926            variants: vec!["x_new = 1".to_string()],
927            quality_scores: vec![0.8],
928            diversity_score: 1.0,
929        };
930        assert_eq!(result.original, "x = 1");
931        assert_eq!(result.variants.len(), 1);
932        assert_eq!(result.quality_scores.len(), 1);
933        assert!((result.diversity_score - 1.0).abs() < f32::EPSILON);
934    }
935
/// Tokenizing an arithmetic expression keeps identifiers and numeric literals
/// but drops the `+` and `*` operators.
#[test]
fn test_tokenize_special_chars() {
    let tokens: Vec<_> = tokenize("x = 1 + y * 2").collect();

    for kept in ["x", "1", "y", "2"].iter().copied() {
        assert!(tokens.contains(&kept));
    }
    for dropped in ["+", "*"].iter().copied() {
        assert!(!tokens.contains(&dropped));
    }
}
947
/// Identifier validation accepts a leading alphabetic character, including
/// non-ASCII letters (Rust's `is_alphabetic` is Unicode-aware), and rejects a
/// leading digit.
#[test]
fn test_is_valid_identifier_unicode() {
    let accepted = ["über", "x123"];
    let rejected = ["123über"];

    assert!(accepted.iter().copied().all(is_valid_identifier));
    assert!(!rejected.iter().copied().any(is_valid_identifier));
}
955
/// Two augmenters built from the same seeded config must emit identical
/// variants for the same input.
#[test]
fn test_code_eda_deterministic_with_seed() {
    const CODE: &str = "x = 1\ny = 2\nz = 3";

    let config = CodeEDAConfig {
        seed: 12345,
        ..Default::default()
    };
    // Each call builds a fresh augmenter, so the two runs share no RNG state.
    let run = |cfg: CodeEDAConfig| CodeEDA::with_config(cfg).augment(CODE, 3);

    let first = run(config.clone());
    let second = run(config);

    assert_eq!(first, second, "Same seed should produce same augmentations");
}
971
/// Distinct seeds with renaming (SR) forced on.
///
/// Different seeds are likely, but not guaranteed, to produce different
/// renamings, so that is deliberately not asserted. What must hold is that
/// each seed yields output and that each seed is honored: rebuilding an
/// augmenter with the same seed reproduces its output exactly.
#[test]
fn test_code_eda_different_seeds() {
    let code = "foo = 1\nbar = foo + 2";
    let make = |seed: u64| {
        CodeEDA::with_config(CodeEDAConfig {
            seed,
            sr_prob: 1.0,
            ..Default::default()
        })
    };

    let aug1 = make(1).augment(code, 1);
    let aug2 = make(2).augment(code, 1);

    assert!(!aug1.is_empty());
    assert!(!aug2.is_empty());

    // Each seed must be individually reproducible; otherwise "different
    // seeds" would be indistinguishable from unseeded randomness.
    assert_eq!(aug1, make(1).augment(code, 1), "seed 1 should be reproducible");
    assert_eq!(aug2, make(2).augment(code, 1), "seed 2 should be reproducible");
}
994
// ========== KEYWORD EXHAUSTIVE TEST ==========

/// Every Python 3 hard keyword (35 in total) must be recognized by
/// `is_keyword`.
#[test]
fn test_all_python_keywords() {
    const PYTHON_KEYWORDS: [&str; 35] = [
        "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
        "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
        "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise",
        "return", "try", "while", "with", "yield",
    ];

    PYTHON_KEYWORDS
        .iter()
        .copied()
        .for_each(|kw| assert!(is_keyword(kw), "{kw} should be a keyword"));
}
1009
/// Ordinary identifiers and common builtins (`print`, `len`, `str`) are not
/// reported as keywords.
#[test]
fn test_non_keywords() {
    const ORDINARY_NAMES: [&str; 9] = ["foo", "bar", "baz", "x", "y", "z", "print", "len", "str"];

    ORDINARY_NAMES
        .iter()
        .copied()
        .for_each(|nk| assert!(!is_keyword(nk), "{nk} should not be a keyword"));
}
1017}
1018
/// Property-based tests for CodeEDA using proptest
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Generate Python-like code snippets: 1 to 9 newline-joined lines, each a
    /// simple assignment, a comment, or a call with numeric arguments.
    ///
    /// The call form is only loosely Python (e.g. `f(,)` can be produced), so
    /// use this for properties that must hold for *any* input, such as
    /// determinism, not for properties that require valid Python.
    fn python_code_strategy() -> impl Strategy<Value = String> {
        prop::collection::vec(
            prop_oneof![
                // Simple assignments
                "[a-z][a-z0-9_]{0,10} = [0-9]{1,5}",
                // Comments
                "# [a-zA-Z0-9 ]{0,20}",
                // Function calls
                "[a-z]+\\([0-9, ]*\\)",
            ],
            1..10,
        )
        .prop_map(|lines| lines.join("\n"))
    }

    proptest! {
        /// Augmented code always passes basic syntax check
        #[test]
        fn prop_augmented_code_is_syntactically_valid(
            seed in 0u64..10000,
            n_aug in 1usize..5,
        ) {
            let config = CodeEDAConfig {
                seed,
                quality_threshold: 0.5,
                ..Default::default()
            };
            let mut eda = CodeEDA::with_config(config);
            let code = "x = 1\ny = 2\nz = 3";
            let augmented = eda.augment(code, n_aug);

            for aug in &augmented {
                prop_assert!(eda.basic_syntax_check(aug));
            }
        }

        /// Quality score is always in [0.0, 1.0]
        #[test]
        fn prop_quality_score_bounded(
            code in "[a-z]+ = [0-9]+",
            aug in "[a-z]+ = [0-9]+",
        ) {
            let eda = CodeEDA::new();
            let score = eda.quality_score(&aug, &code);
            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        /// Diversity score is always in [0.0, 1.0]
        #[test]
        fn prop_diversity_score_bounded(
            batch in prop::collection::vec("[a-z]+", 1..10),
        ) {
            let eda = CodeEDA::new();
            let score = eda.diversity_score(&batch);
            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        /// Deterministic: same seed + same input = same output
        #[test]
        fn prop_deterministic_with_seed(
            seed in 0u64..10000,
            code in "[a-z]+ = [0-9]+\n[a-z]+ = [0-9]+",
        ) {
            let config = CodeEDAConfig {
                seed,
                ..Default::default()
            };
            let mut eda1 = CodeEDA::with_config(config.clone());
            let mut eda2 = CodeEDA::with_config(config);

            let aug1 = eda1.augment(&code, 3);
            let aug2 = eda2.augment(&code, 3);

            prop_assert_eq!(aug1, aug2);
        }

        /// Determinism also holds on richer multi-line inputs that mix
        /// assignments, comments and calls (from `python_code_strategy`).
        #[test]
        fn prop_deterministic_on_generated_code(
            seed in 0u64..10000,
            code in python_code_strategy(),
        ) {
            let config = CodeEDAConfig {
                seed,
                ..Default::default()
            };
            let mut eda1 = CodeEDA::with_config(config.clone());
            let mut eda2 = CodeEDA::with_config(config);

            prop_assert_eq!(eda1.augment(&code, 3), eda2.augment(&code, 3));
        }

        /// Extracted variables are valid identifiers
        #[test]
        fn prop_extracted_vars_are_valid_identifiers(
            var in "[a-z][a-z0-9_]{0,10}",
        ) {
            let eda = CodeEDA::new();
            let code = format!("{var} = 42");
            let vars = eda.extract_variables(&code);

            for v in vars {
                prop_assert!(is_valid_identifier(&v));
                prop_assert!(!is_keyword(&v));
            }
        }

        /// Replace identifier preserves code length approximately
        #[test]
        fn prop_replace_identifier_similar_length(
            old in "[a-z]{3,6}",
            new in "[a-z]{3,6}",
        ) {
            let eda = CodeEDA::new();
            let code = format!("{old} = 1\n{old} + 2");
            let result = eda.replace_identifier(&code, &old, &new);

            // `old` occurs twice in `code`, so the total length change is
            // bounded by twice the per-occurrence difference (+1 slack).
            let len_diff = (result.len() as i64 - code.len() as i64).unsigned_abs();
            let replacement_diff = (new.len() as i64 - old.len() as i64).unsigned_abs();
            prop_assert!(len_diff <= replacement_diff * 2 + 1);
        }

        /// Balanced brackets: unbalanced code always scores 0
        #[test]
        fn prop_unbalanced_scores_zero(
            n_open in 1usize..5,
        ) {
            let eda = CodeEDA::new();
            let unbalanced = "(".repeat(n_open);
            let score = eda.quality_score(&unbalanced, "x = 1");
            prop_assert!(score < f32::EPSILON);
        }

        /// Swappable pairs have same indentation
        #[test]
        fn prop_swappable_pairs_same_indent(
            indent in 0usize..4,
            n_lines in 2usize..6,
        ) {
            let eda = CodeEDA::new();
            let space = " ".repeat(indent * 4);
            let lines: Vec<String> = (0..n_lines)
                .map(|i| format!("{space}x{i} = {i}"))
                .collect();
            let lines_ref: Vec<&str> = lines.iter().map(String::as_str).collect();

            let pairs = eda.find_swappable_pairs(&lines_ref);

            // Leading-whitespace width of a line.
            let indent_of = |line: &str| line.len() - line.trim_start().len();
            for (i, j) in pairs {
                prop_assert_eq!(indent_of(lines_ref[i]), indent_of(lines_ref[j]));
            }
        }
    }
}