fast_aug/text/
chars_random_swap.rs

1use super::base::BaseTextAugmenter;
2use super::utils::{Doc, TextAugmentParameters};
3use crate::base::BaseAugmenter;
4use std::collections::HashSet;
5
6/// Augmenter that swaps random chars in random words in text
7///
8/// # Examples
9/// ```rust
10/// use fast_aug::base::BaseAugmenter;
11/// use fast_aug::text::{CharsRandomSwapAugmenter, TextAugmentParameters};
12///
13/// let rng = &mut rand::thread_rng();
14/// let augmenter = CharsRandomSwapAugmenter::new(
15///     TextAugmentParameters::new(0.5, None, None),
16///     TextAugmentParameters::new(0.5, None, None),
17///     None,
18/// );
19/// augmenter.augment("Some text!".to_string(), rng);
20/// augmenter.augment_batch(vec!["Some text!".to_string()], rng);
21/// ```
22pub struct CharsRandomSwapAugmenter {
23    /// Parameters to calculate number of words that will be augmented
24    word_params: TextAugmentParameters,
25    /// Parameters to calculate number of chars that will be augmented in each word
26    char_params: TextAugmentParameters,
27    /// Filter, Set of words that cannot be augmented
28    stopwords: Option<HashSet<String>>,
29}
30
31impl CharsRandomSwapAugmenter {
32    pub fn new(
33        word_params: TextAugmentParameters,
34        char_params: TextAugmentParameters,
35        stopwords: Option<HashSet<String>>,
36    ) -> Self {
37        CharsRandomSwapAugmenter {
38            word_params,
39            char_params,
40            stopwords,
41        }
42    }
43}
44
45impl BaseTextAugmenter for CharsRandomSwapAugmenter {}
46
47impl BaseAugmenter<String, Doc> for CharsRandomSwapAugmenter {
48    fn augment_inner(&self, mut input: Doc, rng: &mut dyn rand::RngCore) -> Doc {
49        // TODO: adjacent, middle, random swaps (now only random)
50        // Select random word tokens
51        let word_tokens_indexes = input.get_word_indexes(false, self.stopwords.as_ref());
52        let num_tokens_to_change = self.word_params.num_elements(word_tokens_indexes.len());
53        let selected_tokens_indexes =
54            self.select_random_element_indexes(rng, word_tokens_indexes, num_tokens_to_change);
55
56        // For all selected tokens select random chars and swap them
57        for token_index in selected_tokens_indexes {
58            let token = &mut input.tokens[token_index];
59            let num_chars_to_change = self.char_params.num_elements(token.utf8_len());
60
61            let selected_chars_indexes =
62                self.select_random_element_indexes(rng, (0..token.utf8_len()).collect(), num_chars_to_change);
63            let mut chars = token.token().chars().collect::<Vec<char>>();
64            selected_chars_indexes.chunks(2).for_each(|chunk| {
65                if chunk.len() == 2 {
66                    chars.swap(chunk[0], chunk[1]);
67                }
68            });
69            let new_token = chars.iter().collect::<String>();
70            token.change(&new_token, *token.kind());
71
72            input.num_changes += 1;
73        }
74
75        input
76    }
77
78    fn convert_to_inner(&self, input: String) -> Doc {
79        Doc::new(&input)
80    }
81
82    fn convert_to_outer(&self, input: Doc) -> String {
83        input.to_string()
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case(vec!["ABCD", "EFGH", "IJKL", "MNOP", "QRST"], 0.5, 0.5, 3 ; "round 2.5 as 3 words, round 2.5 as 3 chars each")]
    #[test_case(vec!["ABCD", "EFGH", "IJKL", "MNOP", "QRST"], 0.0, 0.5, 0 ; "swap chars in 0 words - no changes")]
    #[test_case(vec!["ABCD", "EFGH", "IJKL", "MNOP", "QRST"], 0.5, 0.0, 0 ; "swap 0 chars - no changes")]
    fn test_swap(input_tokens: Vec<&str>, words_p: f32, chars_p: f32, expected_doc_changes: usize) {
        let doc = Doc::from_tokens(input_tokens);
        let aug = CharsRandomSwapAugmenter::new(
            TextAugmentParameters::new(words_p, None, None),
            TextAugmentParameters::new(chars_p, None, None),
            None,
        );

        // Snapshot the tokens so they can be compared with the augmented ones.
        let tokens_before = doc.tokens.clone();
        let doc = aug.augment_inner(doc, &mut rand::thread_rng());
        let tokens_after = &doc.tokens;

        if expected_doc_changes == 0 {
            assert_eq!(&tokens_before, tokens_after);
        } else {
            assert_eq!(tokens_before.len(), tokens_after.len());
            assert_ne!(&tokens_before, tokens_after);
            assert_eq!(doc.num_changes, expected_doc_changes);
        }

        // Swapping only moves characters around, so every changed token keeps its length.
        let changed_pairs: Vec<_> = tokens_before
            .iter()
            .zip(tokens_after.iter())
            .filter(|(before, after)| before.token() != after.token())
            .collect();
        for (before, after) in &changed_pairs {
            assert_eq!(before.token().len(), after.token().len());
        }
        assert_eq!(changed_pairs.len(), expected_doc_changes);
    }

    #[test_case("It’s ” #NBAwards" ; "utf8 len not equal bytes len")]
    fn test_swap_bugs(text: &str) {
        let doc = Doc::new(text);
        let aug = CharsRandomSwapAugmenter::new(
            TextAugmentParameters::new(1.0, None, None),
            TextAugmentParameters::new(0.3, None, None),
            None,
        );

        let tokens_before = doc.tokens.clone();
        let doc = aug.augment_inner(doc, &mut rand::thread_rng());

        assert_eq!(tokens_before.len(), doc.tokens.len());
        assert_ne!(tokens_before, doc.tokens);
    }
}