use super::base::BaseTextAugmenter;
use super::utils::{Doc, TextAugmentParameters};
use crate::base::BaseAugmenter;
use crate::text::Token;
use rand::prelude::SliceRandom;
use std::collections::HashSet;
/// Augmenter that inserts randomly chosen vocabulary words at random
/// positions of the input token sequence.
pub struct WordsRandomInsertAugmenter {
    /// Controls how many insertion points are selected relative to the
    /// number of tokens in the input (see `TextAugmentParameters::num_elements`).
    word_params: TextAugmentParameters,
    /// Pool of candidate words to draw inserted tokens from.
    /// NOTE(review): presumably must be non-empty — `augment_inner` unwraps
    /// the result of choosing from it.
    vocabulary: Vec<String>,
    /// Unused by insertion (inserting never touches existing words);
    /// kept for signature parity with the other word-level augmenters.
    #[allow(dead_code)]
    stopwords: Option<HashSet<String>>,
}
impl WordsRandomInsertAugmenter {
pub fn new(
word_params: TextAugmentParameters,
vocabulary: Vec<String>,
stopwords: Option<HashSet<String>>,
) -> Self {
WordsRandomInsertAugmenter {
word_params,
vocabulary,
stopwords,
}
}
}
// Marker impl: the augmenter relies entirely on the trait's default helper
// methods (e.g. `select_random_element_indexes`); nothing to override.
impl BaseTextAugmenter for WordsRandomInsertAugmenter {}
impl BaseAugmenter<String, Doc> for WordsRandomInsertAugmenter {
    /// Inserts `word_params.num_elements(len)` randomly chosen vocabulary
    /// words into the token sequence at randomly chosen positions.
    ///
    /// Insertion slots are drawn from `0..=len`, so a word may be placed
    /// before the first token or after the last one. `num_changes` is set
    /// to the number of inserted words.
    ///
    /// # Panics
    /// Panics if the vocabulary is empty while at least one insertion is
    /// requested.
    fn augment_inner(&self, mut input: Doc, rng: &mut dyn rand::RngCore) -> Doc {
        let num_tokens_to_insert = self.word_params.num_elements(input.tokens.len());

        // Valid insertion slots are 0..=len (len means "append at the end").
        let mut insert_positions = self.select_random_element_indexes(
            rng,
            (0..=input.tokens.len()).collect(),
            num_tokens_to_insert,
        );
        // Sort ascending so the merge below is a single forward pass.
        // (sort_unstable: plain usize keys, stability is irrelevant.)
        insert_positions.sort_unstable();

        // Draw one random word per insertion slot. `choose` returns None
        // only for an empty vocabulary, which is a construction error.
        let words_to_insert: Vec<Token> = (0..num_tokens_to_insert)
            .map(|_| {
                let word = self
                    .vocabulary
                    .choose(rng)
                    .expect("vocabulary must not be empty");
                Token::from_str(word)
            })
            .collect();
        debug_assert_eq!(insert_positions.len(), words_to_insert.len());

        // Merge original tokens and new words in one pass over both lists.
        let mut new_tokens: Vec<Token> =
            Vec::with_capacity(input.tokens.len() + words_to_insert.len());
        let mut next_original = 0;
        for (position, word) in insert_positions.into_iter().zip(words_to_insert) {
            // Copy original tokens preceding this insertion point.
            while next_original < position {
                new_tokens.push(input.tokens[next_original].clone());
                next_original += 1;
            }
            // Place the new word. Duplicate positions (if the selection
            // allows them) simply stack several words at the same spot.
            new_tokens.push(word);
        }
        // Copy whatever original tokens remain after the last insertion.
        new_tokens.extend_from_slice(&input.tokens[next_original..]);

        input.num_changes = num_tokens_to_insert;
        input.tokens = new_tokens;
        input
    }

    /// Wraps a raw string into the internal `Doc` representation.
    fn convert_to_inner(&self, input: String) -> Doc {
        Doc::new(&input)
    }

    /// Renders the augmented `Doc` back into a plain string.
    fn convert_to_outer(&self, input: Doc) -> String {
        input.to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    // Each case: input tokens, vocabulary, insertion probability,
    // expected token count after augmentation, expected `num_changes`.
    // Case-name fix: with p=0.1 over 5 tokens the raw count is 0.5 (not
    // 0.05), which rounds up to a single inserted word.
    #[test_case(vec!["A", "B", "C", "D", "E"], vec!["A", "T", "K"], 0.5, 8, 3 ; "round 2.5 as +3")]
    #[test_case(vec!["A", "B", "C", "D", "E", "D"], vec!["A", "T", "K"], 0.5, 9, 3 ; "3 of 6")]
    #[test_case(vec!["\t", "B", " ", "D", "!"], vec!["A", "T", "K"], 0.2, 6, 1 ; "1 word")]
    #[test_case(vec!["\t", "B", " ", "D", "!"], vec!["A", "T", "K"], 0.1, 6, 1 ; "round 0.5 as 1 word")]
    #[test_case(vec!["\t", "B", " ", "D", "!"], vec!["A", "T", "K"], 0.0, 5, 0 ; "insert probability=0")]
    fn test_insert(
        input_tokens: Vec<&str>,
        vocab: Vec<&str>,
        p: f32,
        expected_len: usize,
        expected_doc_changes: usize,
    ) {
        let doc = Doc::from_tokens(input_tokens);
        let params = TextAugmentParameters::new(p, None, None);
        let vocab: Vec<String> = vocab.into_iter().map(|s| s.to_string()).collect();
        let aug = WordsRandomInsertAugmenter::new(params, vocab, None);

        let tokens_before = doc.tokens.clone();
        let augmented = aug.augment_inner(doc, &mut rand::thread_rng());

        // The augmented doc must contain exactly the expected token count.
        assert_eq!(augmented.tokens.len(), expected_len);
        if expected_doc_changes == 0 {
            // No insertions requested: the token sequence is untouched.
            assert_eq!(tokens_before, augmented.tokens);
        } else {
            // Insertions happened: sequence differs and the change counter
            // reflects how many words were added.
            assert_ne!(tokens_before, augmented.tokens);
            assert_eq!(augmented.num_changes, expected_doc_changes);
        }
    }
}