// fast_aug/text/chars_random_swap.rs
use super::base::BaseTextAugmenter;
use super::utils::{Doc, TextAugmentParameters};
use crate::base::BaseAugmenter;
use std::collections::HashSet;
/// Augmenter that swaps random pairs of characters inside randomly
/// selected word tokens of a [`Doc`].
pub struct CharsRandomSwapAugmenter {
    /// Controls how many word tokens are selected for augmentation
    /// (fraction/limits resolved via `num_elements` on the token count).
    word_params: TextAugmentParameters,
    /// Controls how many character positions are selected within each
    /// chosen word (resolved against the word's UTF-8 character length).
    char_params: TextAugmentParameters,
    /// Optional set of words excluded from selection; `None` disables
    /// stopword filtering.
    stopwords: Option<HashSet<String>>,
}
30
31impl CharsRandomSwapAugmenter {
32 pub fn new(
33 word_params: TextAugmentParameters,
34 char_params: TextAugmentParameters,
35 stopwords: Option<HashSet<String>>,
36 ) -> Self {
37 CharsRandomSwapAugmenter {
38 word_params,
39 char_params,
40 stopwords,
41 }
42 }
43}
44
// Marker impl: all text-augmenter behavior comes from the trait's defaults.
impl BaseTextAugmenter for CharsRandomSwapAugmenter {}
46
47impl BaseAugmenter<String, Doc> for CharsRandomSwapAugmenter {
48 fn augment_inner(&self, mut input: Doc, rng: &mut dyn rand::RngCore) -> Doc {
49 let word_tokens_indexes = input.get_word_indexes(false, self.stopwords.as_ref());
52 let num_tokens_to_change = self.word_params.num_elements(word_tokens_indexes.len());
53 let selected_tokens_indexes =
54 self.select_random_element_indexes(rng, word_tokens_indexes, num_tokens_to_change);
55
56 for token_index in selected_tokens_indexes {
58 let token = &mut input.tokens[token_index];
59 let num_chars_to_change = self.char_params.num_elements(token.utf8_len());
60
61 let selected_chars_indexes =
62 self.select_random_element_indexes(rng, (0..token.utf8_len()).collect(), num_chars_to_change);
63 let mut chars = token.token().chars().collect::<Vec<char>>();
64 selected_chars_indexes.chunks(2).for_each(|chunk| {
65 if chunk.len() == 2 {
66 chars.swap(chunk[0], chunk[1]);
67 }
68 });
69 let new_token = chars.iter().collect::<String>();
70 token.change(&new_token, *token.kind());
71
72 input.num_changes += 1;
73 }
74
75 input
76 }
77
78 fn convert_to_inner(&self, input: String) -> Doc {
79 Doc::new(&input)
80 }
81
82 fn convert_to_outer(&self, input: Doc) -> String {
83 input.to_string()
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    // Parameterized over (tokens, word fraction, char fraction, expected change count).
    // With 5 words of 4 chars each: p=0.5 selects 3 words (2.5 rounded up),
    // and p=0.5 on 4 chars selects 2 positions, i.e. exactly one swap per word.
    #[test_case(vec!["ABCD", "EFGH", "IJKL", "MNOP", "QRST"], 0.5, 0.5, 3 ; "round 2.5 as 3 words, round 2.5 as 3 chars each")]
    #[test_case(vec!["ABCD", "EFGH", "IJKL", "MNOP", "QRST"], 0.0, 0.5, 0 ; "swap chars in 0 words - no changes")]
    #[test_case(vec!["ABCD", "EFGH", "IJKL", "MNOP", "QRST"], 0.5, 0.0, 0 ; "swap 0 chars - no changes")]
    fn test_swap(input_tokens: Vec<&str>, words_p: f32, chars_p: f32, expected_doc_changes: usize) {
        let mut doc = Doc::from_tokens(input_tokens);
        let words_params = TextAugmentParameters::new(words_p, None, None);
        let chars_params = TextAugmentParameters::new(chars_p, None, None);
        let aug = CharsRandomSwapAugmenter::new(words_params, chars_params, None);

        let doc_tokens_before = doc.tokens.clone();

        doc = aug.augment_inner(doc, &mut rand::thread_rng());

        let doc_tokens_after = doc.tokens.clone();

        if expected_doc_changes == 0 {
            // Nothing selected: the document must be untouched.
            assert_eq!(doc_tokens_before, doc_tokens_after);
        } else {
            // Swaps change content but never the token count or doc shape.
            assert_eq!(doc_tokens_before.len(), doc_tokens_after.len());
            assert_ne!(doc_tokens_before, doc_tokens_after);
            assert_eq!(doc.num_changes, expected_doc_changes);
        }

        // A swap permutes characters in place, so every changed token must
        // keep its byte length; count how many tokens actually changed.
        let mut num_changed_words = 0;
        for (token_before, token_after) in doc_tokens_before.iter().zip(doc_tokens_after.iter()) {
            if token_before.token() != token_after.token() {
                assert_eq!(token_before.token().len(), token_after.token().len());
                assert_ne!(token_before.token(), token_after.token());
                num_changed_words += 1;
            }
        }
        assert_eq!(num_changed_words, expected_doc_changes);
    }

    // Regression test: input with multibyte characters ("’", "”") where the
    // UTF-8 char count differs from the byte length; swapping must not panic
    // on a char boundary.
    #[test_case("It’s ” #NBAwards" ; "utf8 len not equal bytes len")]
    fn test_swap_bugs(text: &str) {
        let mut doc = Doc::new(text);
        let words_params = TextAugmentParameters::new(1.0, None, None);
        let chars_params = TextAugmentParameters::new(0.3, None, None);
        let aug = CharsRandomSwapAugmenter::new(words_params, chars_params, None);

        let doc_tokens_before = doc.tokens.clone();

        doc = aug.augment_inner(doc, &mut rand::thread_rng());

        let doc_tokens_after = doc.tokens.clone();

        assert_eq!(doc_tokens_before.len(), doc_tokens_after.len());
        assert_ne!(doc_tokens_before, doc_tokens_after);
    }
}