use rand::prelude::*;
use std::collections::HashSet;
15
/// Tunable parameters for [`CodeEDA`].
///
/// Each `*_prob` field is compared against a fresh uniform `f32` draw once per
/// augmentation attempt; the corresponding operation runs when the draw is
/// smaller than the probability.
#[derive(Debug, Clone, PartialEq)]
pub struct CodeEDAConfig {
    /// Probability of applying synonym replacement (renaming one assigned variable).
    pub sr_prob: f32,
    /// Probability of applying random insertion (adding a comment or `pass` line).
    pub ri_prob: f32,
    /// Probability of applying random swap (exchanging two adjacent lines).
    pub rs_prob: f32,
    /// Probability of applying random deletion (removing a comment or `pass` line).
    pub rd_prob: f32,
    /// Minimum quality score a candidate must reach to be kept by `augment`.
    pub quality_threshold: f32,
    /// Seed for the random number generator; equal seeds give equal output.
    pub seed: u64,
}
32
33impl Default for CodeEDAConfig {
34 fn default() -> Self {
35 Self {
36 sr_prob: 0.1,
37 ri_prob: 0.1,
38 rs_prob: 0.1,
39 rd_prob: 0.05,
40 quality_threshold: 0.75,
41 seed: 42,
42 }
43 }
44}
45
46#[derive(Debug)]
48pub struct CodeEDA {
49 config: CodeEDAConfig,
50 rng: StdRng,
51}
52
53impl CodeEDA {
54 #[must_use]
56 pub fn new() -> Self {
57 Self::with_config(CodeEDAConfig::default())
58 }
59
60 #[must_use]
62 pub fn with_config(config: CodeEDAConfig) -> Self {
63 let rng = StdRng::seed_from_u64(config.seed);
64 Self { config, rng }
65 }
66
67 pub fn augment(&mut self, code: &str, n_augmentations: usize) -> Vec<String> {
71 let mut results = Vec::with_capacity(n_augmentations);
72
73 for _ in 0..n_augmentations * 2 {
74 let augmented = self.apply_augmentations(code);
76 let quality = self.quality_score(&augmented, code);
77
78 if quality >= self.config.quality_threshold {
79 results.push(augmented);
80 if results.len() >= n_augmentations {
81 break;
82 }
83 }
84 }
85
86 results
87 }
88
89 fn apply_augmentations(&mut self, code: &str) -> String {
91 let mut result = code.to_string();
92
93 if self.rng.random::<f32>() < self.config.sr_prob {
94 result = self.synonym_replacement(&result);
95 }
96 if self.rng.random::<f32>() < self.config.ri_prob {
97 result = self.random_insertion(&result);
98 }
99 if self.rng.random::<f32>() < self.config.rs_prob {
100 result = self.random_swap(&result);
101 }
102 if self.rng.random::<f32>() < self.config.rd_prob {
103 result = self.random_deletion(&result);
104 }
105
106 result
107 }
108
109 fn synonym_replacement(&mut self, code: &str) -> String {
111 let variables = self.extract_variables(code);
112 if variables.is_empty() {
113 return code.to_string();
114 }
115
116 let mut var_list: Vec<_> = variables.into_iter().collect();
118 var_list.sort();
119 let idx = self.rng.random_range(0..var_list.len());
120 let old_var = &var_list[idx];
121
122 let new_var = self.generate_variable_name(old_var);
124
125 self.replace_identifier(code, old_var, &new_var)
127 }
128
129 fn random_insertion(&mut self, code: &str) -> String {
131 let lines: Vec<&str> = code.lines().collect();
132 if lines.is_empty() {
133 return code.to_string();
134 }
135
136 let insert_idx = self.rng.random_range(0..=lines.len());
137 let insert_type = self.rng.random_range(0..3);
138
139 let insertion = match insert_type {
140 0 => " # augmented".to_string(),
141 1 => " pass # placeholder".to_string(),
142 _ => format!(" # line {}", insert_idx + 1),
143 };
144
145 let mut result_lines: Vec<String> = lines.iter().map(|s| (*s).to_string()).collect();
146 result_lines.insert(insert_idx, insertion);
147 result_lines.join("\n")
148 }
149
150 fn random_swap(&mut self, code: &str) -> String {
152 let lines: Vec<&str> = code.lines().collect();
153 if lines.len() < 2 {
154 return code.to_string();
155 }
156
157 let swappable = self.find_swappable_pairs(&lines);
159 if swappable.is_empty() {
160 return code.to_string();
161 }
162
163 let (i, j) = swappable[self.rng.random_range(0..swappable.len())];
164 let mut result_lines: Vec<String> = lines.iter().map(|s| (*s).to_string()).collect();
165 result_lines.swap(i, j);
166 result_lines.join("\n")
167 }
168
169 fn random_deletion(&mut self, code: &str) -> String {
171 let lines: Vec<&str> = code.lines().collect();
172 if lines.len() <= 2 {
173 return code.to_string();
174 }
175
176 let deletable: Vec<usize> = lines
178 .iter()
179 .enumerate()
180 .filter(|(_, line)| {
181 let trimmed = line.trim();
182 trimmed.starts_with('#') || trimmed == "pass"
183 })
184 .map(|(i, _)| i)
185 .collect();
186
187 if deletable.is_empty() {
188 return code.to_string();
189 }
190
191 let del_idx = deletable[self.rng.random_range(0..deletable.len())];
192 let result_lines: Vec<&str> = lines
193 .iter()
194 .enumerate()
195 .filter(|(i, _)| *i != del_idx)
196 .map(|(_, line)| *line)
197 .collect();
198 result_lines.join("\n")
199 }
200
201 fn extract_variables(&self, code: &str) -> HashSet<String> {
203 let mut vars = HashSet::new();
204
205 for line in code.lines() {
207 let trimmed = line.trim();
208 if let Some(eq_pos) = trimmed.find('=') {
209 if eq_pos > 0 && !trimmed[..eq_pos].contains('(') {
210 let lhs = trimmed[..eq_pos].trim();
211 if !lhs.ends_with(['!', '<', '>', '=']) {
213 for var in lhs.split(',') {
215 let var = var.trim();
216 if is_valid_identifier(var) && !is_keyword(var) {
217 vars.insert(var.to_string());
218 }
219 }
220 }
221 }
222 }
223 }
224
225 vars
226 }
227
228 fn generate_variable_name(&mut self, old: &str) -> String {
230 let suffixes = ["_new", "_v2", "_alt", "_mod", "2"];
231 let suffix = suffixes[self.rng.random_range(0..suffixes.len())];
232 format!("{old}{suffix}")
233 }
234
235 fn replace_identifier(&self, code: &str, old: &str, new: &str) -> String {
237 let mut result = String::with_capacity(code.len() + 32);
238 let chars: Vec<char> = code.chars().collect();
239 let old_chars: Vec<char> = old.chars().collect();
240 let mut i = 0;
241
242 while i < chars.len() {
243 if i + old_chars.len() <= chars.len() {
244 let matches = chars[i..i + old_chars.len()]
245 .iter()
246 .zip(old_chars.iter())
247 .all(|(a, b)| a == b);
248
249 if matches {
250 let before_ok =
252 i == 0 || !chars[i - 1].is_alphanumeric() && chars[i - 1] != '_';
253 let after_ok = i + old_chars.len() >= chars.len()
254 || !chars[i + old_chars.len()].is_alphanumeric()
255 && chars[i + old_chars.len()] != '_';
256
257 if before_ok && after_ok {
258 result.push_str(new);
259 i += old_chars.len();
260 continue;
261 }
262 }
263 }
264 result.push(chars[i]);
265 i += 1;
266 }
267
268 result
269 }
270
271 fn find_swappable_pairs(&self, lines: &[&str]) -> Vec<(usize, usize)> {
273 let mut pairs = Vec::new();
274
275 for i in 0..lines.len().saturating_sub(1) {
276 let indent_i = lines[i].len() - lines[i].trim_start().len();
277 let indent_j = lines[i + 1].len() - lines[i + 1].trim_start().len();
278
279 if indent_i == indent_j {
281 let line_i = lines[i].trim();
282 let line_j = lines[i + 1].trim();
283
284 let is_simple_i = !line_i.starts_with("if ")
286 && !line_i.starts_with("for ")
287 && !line_i.starts_with("while ")
288 && !line_i.starts_with("def ")
289 && !line_i.starts_with("class ")
290 && !line_i.starts_with("return")
291 && !line_i.is_empty();
292
293 let is_simple_j = !line_j.starts_with("if ")
294 && !line_j.starts_with("for ")
295 && !line_j.starts_with("while ")
296 && !line_j.starts_with("def ")
297 && !line_j.starts_with("class ")
298 && !line_j.starts_with("return")
299 && !line_j.is_empty();
300
301 if is_simple_i && is_simple_j {
302 pairs.push((i, i + 1));
303 }
304 }
305 }
306
307 pairs
308 }
309
310 #[must_use]
316 pub fn quality_score(&self, augmented: &str, original: &str) -> f32 {
317 if !self.basic_syntax_check(augmented) {
319 return 0.0;
320 }
321
322 let orig_tokens: HashSet<_> = tokenize(original).collect();
324 let aug_tokens: HashSet<_> = tokenize(augmented).collect();
325
326 if orig_tokens.is_empty() {
327 return 1.0;
328 }
329
330 let overlap = orig_tokens.intersection(&aug_tokens).count();
331 overlap as f32 / orig_tokens.len() as f32
332 }
333
334 #[must_use]
338 pub fn diversity_score(&self, batch: &[String]) -> f32 {
339 if batch.is_empty() {
340 return 0.0;
341 }
342
343 let unique: HashSet<_> = batch.iter().collect();
344 unique.len() as f32 / batch.len() as f32
345 }
346
347 fn basic_syntax_check(&self, code: &str) -> bool {
349 let mut paren_depth = 0i32;
350 let mut bracket_depth = 0i32;
351 let mut brace_depth = 0i32;
352 let mut in_string = false;
353 let mut string_char = ' ';
354
355 for c in code.chars() {
356 if in_string {
357 if c == string_char {
358 in_string = false;
359 }
360 continue;
361 }
362
363 match c {
364 '"' | '\'' => {
365 in_string = true;
366 string_char = c;
367 }
368 '(' => paren_depth += 1,
369 ')' => paren_depth -= 1,
370 '[' => bracket_depth += 1,
371 ']' => bracket_depth -= 1,
372 '{' => brace_depth += 1,
373 '}' => brace_depth -= 1,
374 _ => {}
375 }
376
377 if paren_depth < 0 || bracket_depth < 0 || brace_depth < 0 {
378 return false;
379 }
380 }
381
382 paren_depth == 0 && bracket_depth == 0 && brace_depth == 0 && !in_string
383 }
384}
385
386impl Default for CodeEDA {
387 fn default() -> Self {
388 Self::new()
389 }
390}
391
/// Splits `code` into identifier-like tokens.
///
/// A token is a maximal run of alphanumeric characters and underscores; every
/// other character acts as a separator and empty pieces are dropped.
fn tokenize(code: &str) -> impl Iterator<Item = &str> {
    code.split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|piece| !piece.is_empty())
}
397
/// Returns `true` when `s` is shaped like an identifier: non-empty, starting
/// with an alphabetic character or `_`, and continuing with alphanumeric
/// characters or `_` only.
fn is_valid_identifier(s: &str) -> bool {
    let mut chars = s.chars();
    // Consuming the first character here also rejects the empty string.
    matches!(chars.next(), Some(first) if first.is_alphabetic() || first == '_')
        && chars.all(|c| c.is_alphanumeric() || c == '_')
}
408
/// Returns `true` when `s` is one of the reserved Python keywords listed
/// below. The comparison is case-sensitive.
fn is_keyword(s: &str) -> bool {
    matches!(
        s,
        "False"
            | "None"
            | "True"
            | "and"
            | "as"
            | "assert"
            | "async"
            | "await"
            | "break"
            | "class"
            | "continue"
            | "def"
            | "del"
            | "elif"
            | "else"
            | "except"
            | "finally"
            | "for"
            | "from"
            | "global"
            | "if"
            | "import"
            | "in"
            | "is"
            | "lambda"
            | "nonlocal"
            | "not"
            | "or"
            | "pass"
            | "raise"
            | "return"
            | "try"
            | "while"
            | "with"
            | "yield"
    )
}
450
/// Outcome of augmenting one sample in a batch.
#[derive(Debug, Clone, PartialEq)]
pub struct AugmentationResult {
    /// The unmodified input sample.
    pub original: String,
    /// Accepted variants generated from `original`.
    pub variants: Vec<String>,
    /// Quality score of each variant, in the same order as `variants`.
    pub quality_scores: Vec<f32>,
    /// Fraction of distinct strings among `variants`.
    pub diversity_score: f32,
}
463
464#[derive(Debug)]
466pub struct BatchAugmenter {
467 eda: CodeEDA,
468 pub factor: f32,
470}
471
472impl BatchAugmenter {
473 #[must_use]
475 pub fn new(config: CodeEDAConfig, factor: f32) -> Self {
476 Self {
477 eda: CodeEDA::with_config(config),
478 factor,
479 }
480 }
481
482 pub fn augment_batch(&mut self, samples: &[String]) -> Vec<AugmentationResult> {
484 #[allow(clippy::cast_sign_loss)]
485 let n_aug = (self.factor.max(0.0) as usize).max(1);
486
487 samples
488 .iter()
489 .map(|code| {
490 let variants = self.eda.augment(code, n_aug);
491 let quality_scores: Vec<f32> = variants
492 .iter()
493 .map(|v| self.eda.quality_score(v, code))
494 .collect();
495 let diversity_score = self.eda.diversity_score(&variants);
496
497 AugmentationResult {
498 original: code.clone(),
499 variants,
500 quality_scores,
501 diversity_score,
502 }
503 })
504 .collect()
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an augmenter with the given operation probabilities and the
    /// default quality threshold and seed.
    fn eda_with_probs(sr_prob: f32, ri_prob: f32, rs_prob: f32, rd_prob: f32) -> CodeEDA {
        CodeEDA::with_config(CodeEDAConfig {
            sr_prob,
            ri_prob,
            rs_prob,
            rd_prob,
            ..Default::default()
        })
    }

    #[test]
    fn test_code_eda_basic() {
        let mut eda = CodeEDA::new();
        let code = "x = 1\ny = 2\nz = x + y";
        let augmented = eda.augment(code, 3);

        assert!(!augmented.is_empty());
        for aug in &augmented {
            let quality = eda.quality_score(aug, code);
            assert!(quality >= 0.75);
        }
    }

    #[test]
    fn test_synonym_replacement() {
        let mut eda = eda_with_probs(1.0, 0.0, 0.0, 0.0);

        let code = "foo = 1\nbar = foo + 2";
        let augmented = eda.augment(code, 1);

        assert!(!augmented.is_empty());
        let aug = &augmented[0];
        // The input itself contains "2", so checking for a suffix substring
        // proves nothing; a rename must make the output differ from the input.
        assert_ne!(aug, code);
        assert!(
            ["_new", "_v2", "_alt", "_mod", "foo2", "bar2"]
                .iter()
                .any(|marker| aug.contains(marker)),
            "expected a renamed variable in {aug:?}"
        );
    }

    #[test]
    fn test_random_insertion() {
        let mut eda = eda_with_probs(0.0, 1.0, 0.0, 0.0);

        let code = "x = 1";
        let augmented = eda.augment(code, 1);

        assert!(!augmented.is_empty());
        assert!(augmented[0].lines().count() > code.lines().count());
    }

    #[test]
    fn test_quality_score() {
        let eda = CodeEDA::new();

        // Identical code scores a perfect 1.0.
        let score = eda.quality_score("x = 1\ny = 2", "x = 1\ny = 2");
        assert!((score - 1.0).abs() < f32::EPSILON);

        // A single rename keeps most tokens.
        let score = eda.quality_score("x_new = 1\ny = 2", "x = 1\ny = 2");
        assert!(score > 0.5);

        // Unbalanced brackets score zero.
        let score = eda.quality_score("x = (1", "x = 1");
        assert!(score < f32::EPSILON);
    }

    #[test]
    fn test_diversity_score() {
        let eda = CodeEDA::new();

        let batch = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        assert!((eda.diversity_score(&batch) - 1.0).abs() < f32::EPSILON);

        let batch = vec!["a".to_string(), "a".to_string(), "a".to_string()];
        assert!((eda.diversity_score(&batch) - 1.0 / 3.0).abs() < f32::EPSILON);

        let batch: Vec<String> = vec![];
        assert!(eda.diversity_score(&batch) < f32::EPSILON);
    }

    #[test]
    fn test_batch_augmenter() {
        let config = CodeEDAConfig::default();
        let mut augmenter = BatchAugmenter::new(config, 2.0);

        let samples = vec!["x = 1".to_string(), "y = 2".to_string()];
        let results = augmenter.augment_batch(&samples);

        assert_eq!(results.len(), 2);
        for result in &results {
            assert!(result.diversity_score >= 0.0);
        }
    }

    #[test]
    fn test_extract_variables() {
        let eda = CodeEDA::new();

        let vars = eda.extract_variables("x = 1\ny = 2\nif x == y: pass");
        assert!(vars.contains("x"));
        assert!(vars.contains("y"));
        assert!(!vars.contains("if"));
    }

    #[test]
    fn test_basic_syntax_check() {
        let eda = CodeEDA::new();

        assert!(eda.basic_syntax_check("x = (1 + 2)"));
        assert!(eda.basic_syntax_check("x = [1, 2, 3]"));
        assert!(eda.basic_syntax_check("x = {'a': 1}"));
        assert!(eda.basic_syntax_check("x = \"hello\""));

        assert!(!eda.basic_syntax_check("x = (1 + 2"));
        assert!(!eda.basic_syntax_check("x = [1, 2"));
        assert!(!eda.basic_syntax_check("x = \"hello"));
    }

    #[test]
    fn test_is_valid_identifier() {
        assert!(is_valid_identifier("foo"));
        assert!(is_valid_identifier("_bar"));
        assert!(is_valid_identifier("baz123"));
        assert!(is_valid_identifier("__init__"));

        assert!(!is_valid_identifier("123abc"));
        assert!(!is_valid_identifier(""));
        assert!(!is_valid_identifier("foo-bar"));
    }

    #[test]
    fn test_is_keyword() {
        assert!(is_keyword("if"));
        assert!(is_keyword("for"));
        assert!(is_keyword("return"));
        assert!(is_keyword("True"));

        assert!(!is_keyword("foo"));
        assert!(!is_keyword("bar"));
    }

    #[test]
    fn test_augment_empty_code() {
        let mut eda = CodeEDA::new();
        let augmented = eda.augment("", 3);
        for aug in &augmented {
            assert!(eda.basic_syntax_check(aug));
        }
    }

    #[test]
    fn test_augment_single_char() {
        let mut eda = CodeEDA::new();
        let augmented = eda.augment("x", 3);
        assert!(augmented.is_empty() || augmented.iter().all(|a| eda.basic_syntax_check(a)));
    }

    #[test]
    fn test_augment_whitespace_only() {
        let mut eda = CodeEDA::new();
        let augmented = eda.augment(" \n\t\n ", 3);
        for aug in &augmented {
            assert!(eda.basic_syntax_check(aug));
        }
    }

    #[test]
    fn test_extract_variables_tuple_unpacking() {
        let eda = CodeEDA::new();
        let vars = eda.extract_variables("a, b, c = 1, 2, 3");
        assert!(vars.contains("a"));
        assert!(vars.contains("b"));
        assert!(vars.contains("c"));
    }

    #[test]
    fn test_extract_variables_no_assignments() {
        let eda = CodeEDA::new();
        let vars = eda.extract_variables("print('hello')\nfoo()");
        assert!(vars.is_empty());
    }

    #[test]
    fn test_extract_variables_with_comparison() {
        let eda = CodeEDA::new();
        let vars = eda.extract_variables("if x == y:\n pass");
        assert!(!vars.contains("x"));
    }

    #[test]
    fn test_synonym_replacement_no_variables() {
        let mut eda = eda_with_probs(1.0, 0.0, 0.0, 0.0);

        let code = "print('hello')";
        let augmented = eda.augment(code, 1);
        assert!(augmented.is_empty() || eda.basic_syntax_check(&augmented[0]));
    }

    #[test]
    fn test_random_swap_single_line() {
        let mut eda = eda_with_probs(0.0, 0.0, 1.0, 0.0);

        let code = "x = 1";
        let augmented = eda.augment(code, 1);
        // A single line has nothing to swap with.
        assert!(augmented.is_empty() || augmented[0] == code);
    }

    #[test]
    fn test_random_deletion_minimal_code() {
        let mut eda = eda_with_probs(0.0, 0.0, 0.0, 1.0);

        let code = "x = 1\ny = 2";
        let augmented = eda.augment(code, 1);
        // Inputs of two lines or fewer are never shortened.
        assert!(augmented.is_empty() || augmented[0].lines().count() >= 2);
    }

    #[test]
    fn test_random_deletion_removes_comment() {
        let mut eda = CodeEDA::with_config(CodeEDAConfig {
            sr_prob: 0.0,
            ri_prob: 0.0,
            rs_prob: 0.0,
            rd_prob: 1.0,
            quality_threshold: 0.0,
            ..Default::default()
        });

        let code = "x = 1\n# comment\ny = 2\nz = 3";
        let augmented = eda.augment(code, 1);
        if !augmented.is_empty() {
            assert!(!augmented[0].contains("# comment") || augmented[0].lines().count() < 4);
        }
    }

    #[test]
    fn test_quality_score_empty_original() {
        let eda = CodeEDA::new();
        let score = eda.quality_score("x = 1", "");
        assert!((score - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quality_score_nested_brackets() {
        let eda = CodeEDA::new();
        let code = "x = [[1, 2], [3, 4]]";
        let score = eda.quality_score(code, code);
        assert!((score - 1.0).abs() < f32::EPSILON);
        assert!(eda.basic_syntax_check(code));
    }

    #[test]
    fn test_quality_score_unbalanced_nested() {
        let eda = CodeEDA::new();
        let score = eda.quality_score("x = [[1, 2]", "x = 1");
        assert!(score < f32::EPSILON);
    }

    #[test]
    fn test_replace_identifier_word_boundary() {
        let eda = CodeEDA::new();

        let result = eda.replace_identifier("max = x + max_value", "x", "y");
        // Only the standalone `x` may change; `max` and `max_value` contain
        // the letter but are different identifiers.
        assert!(result.contains("max"));
        assert!(result.contains("y"));
        assert!(result.contains("max_value"));
    }

    #[test]
    fn test_replace_identifier_at_start() {
        let eda = CodeEDA::new();
        let result = eda.replace_identifier("foo = 1", "foo", "bar");
        assert_eq!(result, "bar = 1");
    }

    #[test]
    fn test_replace_identifier_at_end() {
        let eda = CodeEDA::new();
        let result = eda.replace_identifier("x = foo", "foo", "bar");
        assert_eq!(result, "x = bar");
    }

    #[test]
    fn test_find_swappable_pairs_control_flow() {
        let eda = CodeEDA::new();
        let lines: Vec<&str> = vec!["if x:", " y = 1"];
        let pairs = eda.find_swappable_pairs(&lines);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_find_swappable_pairs_different_indent() {
        let eda = CodeEDA::new();
        let lines: Vec<&str> = vec!["x = 1", " y = 2"];
        let pairs = eda.find_swappable_pairs(&lines);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_find_swappable_pairs_valid() {
        let eda = CodeEDA::new();
        let lines: Vec<&str> = vec!["x = 1", "y = 2", "z = 3"];
        let pairs = eda.find_swappable_pairs(&lines);
        assert!(!pairs.is_empty());
    }

    #[test]
    fn test_basic_syntax_check_escaped_quotes() {
        let eda = CodeEDA::new();
        // NOTE(review): despite the name, this input has no backslash escape;
        // it only checks a plain double-quoted literal.
        let result = eda.basic_syntax_check(r#"x = "hello""#);
        assert!(result);
    }

    #[test]
    fn test_basic_syntax_check_mixed_brackets() {
        let eda = CodeEDA::new();
        assert!(eda.basic_syntax_check("x = ([1, 2], {3: 4})"));
        assert!(!eda.basic_syntax_check("x = ([1, 2}, {3: 4])"));
    }

    #[test]
    fn test_config_probabilities_boundary() {
        let config = CodeEDAConfig {
            sr_prob: 0.0,
            ri_prob: 0.0,
            rs_prob: 0.0,
            rd_prob: 0.0,
            quality_threshold: 0.0,
            seed: 42,
        };
        let mut eda = CodeEDA::with_config(config);
        let code = "x = 1";
        let augmented = eda.augment(code, 5);
        // With every probability at zero no operation can fire.
        for aug in &augmented {
            assert_eq!(aug, code);
        }
    }

    #[test]
    fn test_config_all_ops_enabled() {
        let config = CodeEDAConfig {
            sr_prob: 1.0,
            ri_prob: 1.0,
            rs_prob: 1.0,
            rd_prob: 1.0,
            quality_threshold: 0.0,
            seed: 42,
        };
        let mut eda = CodeEDA::with_config(config);
        let code = "x = 1\n# comment\ny = 2\nz = 3";
        let augmented = eda.augment(code, 3);
        assert!(!augmented.is_empty());
    }

    #[test]
    fn test_batch_augmenter_empty_samples() {
        let config = CodeEDAConfig::default();
        let mut augmenter = BatchAugmenter::new(config, 2.0);
        let samples: Vec<String> = vec![];
        let results = augmenter.augment_batch(&samples);
        assert!(results.is_empty());
    }

    #[test]
    fn test_batch_augmenter_factor_zero() {
        let config = CodeEDAConfig::default();
        let mut augmenter = BatchAugmenter::new(config, 0.0);
        let samples = vec!["x = 1".to_string()];
        let results = augmenter.augment_batch(&samples);
        // A zero factor is raised to one variant per sample, so every sample
        // still gets a result entry.
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_augmentation_result_fields() {
        let result = AugmentationResult {
            original: "x = 1".to_string(),
            variants: vec!["x_new = 1".to_string()],
            quality_scores: vec![0.8],
            diversity_score: 1.0,
        };
        assert_eq!(result.original, "x = 1");
        assert_eq!(result.variants.len(), 1);
        assert_eq!(result.quality_scores.len(), 1);
        assert!((result.diversity_score - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_tokenize_special_chars() {
        let tokens: Vec<_> = tokenize("x = 1 + y * 2").collect();
        assert!(tokens.contains(&"x"));
        assert!(tokens.contains(&"1"));
        assert!(tokens.contains(&"y"));
        assert!(tokens.contains(&"2"));
        // Operators are separators, never tokens.
        assert!(!tokens.contains(&"+"));
        assert!(!tokens.contains(&"*"));
    }

    #[test]
    fn test_is_valid_identifier_unicode() {
        assert!(is_valid_identifier("über"));
        assert!(is_valid_identifier("x123"));
        assert!(!is_valid_identifier("123über"));
    }

    #[test]
    fn test_code_eda_deterministic_with_seed() {
        let config = CodeEDAConfig {
            seed: 12345,
            ..Default::default()
        };
        let mut eda1 = CodeEDA::with_config(config.clone());
        let mut eda2 = CodeEDA::with_config(config);

        let code = "x = 1\ny = 2\nz = 3";
        let aug1 = eda1.augment(code, 3);
        let aug2 = eda2.augment(code, 3);

        assert_eq!(aug1, aug2, "Same seed should produce same augmentations");
    }

    #[test]
    fn test_code_eda_different_seeds() {
        let mut eda1 = CodeEDA::with_config(CodeEDAConfig {
            seed: 1,
            sr_prob: 1.0,
            ..Default::default()
        });
        let mut eda2 = CodeEDA::with_config(CodeEDAConfig {
            seed: 2,
            sr_prob: 1.0,
            ..Default::default()
        });

        let code = "foo = 1\nbar = foo + 2";
        let aug1 = eda1.augment(code, 1);
        let aug2 = eda2.augment(code, 1);

        // Different seeds may still coincide by chance, so only check that
        // both runs produce output.
        assert!(!aug1.is_empty());
        assert!(!aug2.is_empty());
    }

    #[test]
    fn test_all_python_keywords() {
        let keywords = [
            "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class",
            "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global",
            "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise",
            "return", "try", "while", "with", "yield",
        ];
        for kw in keywords {
            assert!(is_keyword(kw), "{kw} should be a keyword");
        }
    }

    #[test]
    fn test_non_keywords() {
        let non_keywords = ["foo", "bar", "baz", "x", "y", "z", "print", "len", "str"];
        for nk in non_keywords {
            assert!(!is_keyword(nk), "{nk} should not be a keyword");
        }
    }
}
1018
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Generates one to nine lines of simple Python-like code: assignments,
    /// comments and call expressions, joined with `\n`.
    // No property in this module draws from the strategy yet.
    #[allow(dead_code)]
    fn python_code_strategy() -> impl Strategy<Value = String> {
        prop::collection::vec(
            prop_oneof![
                // Simple assignment.
                "[a-z][a-z0-9_]{0,10} = [0-9]{1,5}".prop_map(|s| s),
                // Comment line.
                "# [a-zA-Z0-9 ]{0,20}".prop_map(|s| s),
                // Function call.
                "[a-z]+\\([0-9, ]*\\)".prop_map(|s| s),
            ],
            1..10,
        )
        .prop_map(|lines| lines.join("\n"))
    }

    proptest! {
        /// Every accepted variant passes the basic syntax check.
        #[test]
        fn prop_augmented_code_is_syntactically_valid(
            seed in 0u64..10000,
            n_aug in 1usize..5,
        ) {
            let config = CodeEDAConfig {
                seed,
                quality_threshold: 0.5,
                ..Default::default()
            };
            let mut eda = CodeEDA::with_config(config);
            let code = "x = 1\ny = 2\nz = 3";
            let augmented = eda.augment(code, n_aug);

            for aug in &augmented {
                prop_assert!(eda.basic_syntax_check(aug));
            }
        }

        /// Quality scores stay within `0.0..=1.0`.
        #[test]
        fn prop_quality_score_bounded(
            code in "[a-z]+ = [0-9]+",
            aug in "[a-z]+ = [0-9]+",
        ) {
            let eda = CodeEDA::new();
            let score = eda.quality_score(&aug, &code);
            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        /// Diversity scores stay within `0.0..=1.0`.
        #[test]
        fn prop_diversity_score_bounded(
            batch in prop::collection::vec("[a-z]+", 1..10),
        ) {
            let eda = CodeEDA::new();
            let score = eda.diversity_score(&batch);
            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        /// Two augmenters built from the same config produce identical output.
        #[test]
        fn prop_deterministic_with_seed(
            seed in 0u64..10000,
            code in "[a-z]+ = [0-9]+\n[a-z]+ = [0-9]+",
        ) {
            let config = CodeEDAConfig {
                seed,
                ..Default::default()
            };
            let mut eda1 = CodeEDA::with_config(config.clone());
            let mut eda2 = CodeEDA::with_config(config);

            let aug1 = eda1.augment(&code, 3);
            let aug2 = eda2.augment(&code, 3);

            prop_assert_eq!(aug1, aug2);
        }

        /// Extracted variables are valid, non-keyword identifiers.
        #[test]
        fn prop_extracted_vars_are_valid_identifiers(
            var in "[a-z][a-z0-9_]{0,10}",
        ) {
            let eda = CodeEDA::new();
            let code = format!("{var} = 42");
            let vars = eda.extract_variables(&code);

            for v in vars {
                prop_assert!(is_valid_identifier(&v));
                prop_assert!(!is_keyword(&v));
            }
        }

        /// Replacing an identifier changes the length by at most the length
        /// difference per occurrence (two occurrences here).
        #[test]
        fn prop_replace_identifier_similar_length(
            old in "[a-z]{3,6}",
            new in "[a-z]{3,6}",
        ) {
            let eda = CodeEDA::new();
            let code = format!("{old} = 1\n{old} + 2");
            let result = eda.replace_identifier(&code, &old, &new);

            let len_diff = (result.len() as i64 - code.len() as i64).unsigned_abs();
            let replacement_diff = (new.len() as i64 - old.len() as i64).unsigned_abs();
            prop_assert!(len_diff <= replacement_diff * 2 + 1);
        }

        /// Unclosed parentheses always score zero.
        #[test]
        fn prop_unbalanced_scores_zero(
            n_open in 1usize..5,
        ) {
            let eda = CodeEDA::new();
            let unbalanced = "(".repeat(n_open);
            let score = eda.quality_score(&unbalanced, "x = 1");
            prop_assert!(score < f32::EPSILON);
        }

        /// Every reported swappable pair shares the same indentation.
        #[test]
        fn prop_swappable_pairs_same_indent(
            indent in 0usize..4,
            n_lines in 2usize..6,
        ) {
            let eda = CodeEDA::new();
            let space = " ".repeat(indent * 4);
            let lines: Vec<String> = (0..n_lines)
                .map(|i| format!("{space}x{i} = {i}"))
                .collect();
            let lines_ref: Vec<&str> = lines.iter().map(|s| s.as_str()).collect();

            let pairs = eda.find_swappable_pairs(&lines_ref);

            for (i, j) in pairs {
                let indent_i = lines_ref[i].len() - lines_ref[i].trim_start().len();
                let indent_j = lines_ref[j].len() - lines_ref[j].trim_start().len();
                prop_assert_eq!(indent_i, indent_j);
            }
        }
    }
}