use crate::batching;
#[test]
fn pad_mask_test() {
    // Every position holding the "PAD" token should be flagged `true`;
    // real tokens stay `false`.
    let tokens = vec![vec!["d", "hello", "how"], vec!["hi", "yo", "PAD"]];
    let mask = batching::pad_mask(&tokens, "PAD");

    let want = vec![
        vec![false, false, false],
        vec![false, false, true],
    ];
    assert_eq!(want, mask);
}
#[test]
fn pad_batch_test() {
    // Per the expectation below, the shorter row is extended with the pad
    // value (0) until it matches the longest row's length; the longest row
    // is left untouched.
    let mut batch = vec![vec![1, 2, 3, 1], vec![1, 4, 6, 2, 3, 5, 67]];
    batching::pad_batch(&mut batch, 0);

    let want = vec![vec![1, 2, 3, 1, 0, 0, 0], vec![1, 4, 6, 2, 3, 5, 67]];
    assert_eq!(batch, want);
}
#[test]
fn filter_by_length_test() {
    // Two parallel lists. The 7-element sequence at index 1 of the first
    // list exceeds the limit of 6, so it is dropped together with the entry
    // at the same index in the second list.
    let mut lists = vec![
        vec![vec![1, 2, 3, 1], vec![1, 4, 6, 2, 3, 5, 67], vec![1, 2, 3]],
        vec![vec![1, 1], vec![1, 67], vec![1, 2, 3]],
    ];

    // NOTE(review): the `None, Some(6)` pair presumably means (min, max)
    // length bounds — confirm against `batching::filter_by_length`.
    batching::filter_by_length(&mut lists, None, Some(6));

    let want = vec![
        vec![vec![1, 2, 3, 1], vec![1, 2, 3]],
        vec![vec![1, 1], vec![1, 2, 3]],
    ];
    assert_eq!(lists, want);
}
#[test]
fn shuffle_lists_test() {
    let mut seqs = vec![
        vec![vec![1, 2, 3, 1], vec![1, 4, 6, 2, 3, 5, 67], vec![1, 2, 3]],
        vec![vec![1, 1], vec![1, 67], vec![1, 2, 3]],
    ];
    let original = seqs.clone();

    // A random shuffle can legitimately reproduce the input order, so retry
    // (at most 10 times) until the data differs from the original.
    let mut attempts = 0;
    while seqs == original && attempts < 10 {
        batching::shuffle_lists(&mut seqs);
        attempts += 1;
    }

    assert_ne!(seqs, original);
}
#[test]
fn sort_lists_by_length_test() {
    // Turns a slice of string literals into an owned row of `String`s.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|&w| String::from(w)).collect()
    }

    // Ordering follows the string lengths of the first list; the second list
    // (whose entries are all the same length) is permuted in lockstep.
    let ascending = vec![
        owned(&["yo", "hello", "how are you"]),
        owned(&["who", "hey", "wow"]),
    ];
    let descending = vec![
        owned(&["how are you", "hello", "yo"]),
        owned(&["wow", "hey", "who"]),
    ];

    let mut seqs = vec![
        owned(&["hello", "how are you", "yo"]),
        owned(&["hey", "wow", "who"]),
    ];

    // `Some(false)` yields shortest-first, `Some(true)` longest-first.
    batching::sort_lists_by_length(&mut seqs, Some(false));
    assert_eq!(seqs, ascending);

    batching::sort_lists_by_length(&mut seqs, Some(true));
    assert_eq!(seqs, descending);
}