#[cfg(test)]
mod tests;
use lentrait::Len;
/// Builds a boolean padding mask for a batch of sequences.
///
/// The result has one row per sequence in `batch`, and each row has one entry
/// per token of that sequence. An entry is `true` where the token compares
/// equal to `pad_value` and `false` otherwise.
///
/// An empty `batch` yields an empty mask.
pub fn pad_mask<T: std::cmp::PartialEq>(batch: &[Vec<T>], pad_value: T) -> Vec<Vec<bool>> {
    batch
        .iter()
        .map(|seq| seq.iter().map(|token| *token == pad_value).collect())
        .collect()
}
/// Pads every example in `batch` in place to the length of the longest one.
///
/// Shorter examples get clones of `pad_value` appended until they are as long
/// as the longest example. Examples that already have that length, and an
/// empty `batch`, are left unchanged.
pub fn pad_batch<T: std::clone::Clone>(batch: &mut [Vec<T>], pad_value: T) {
    // Target length: the maximum over all examples (0 when the batch is empty).
    let target = batch.iter().map(Vec::len).max().unwrap_or(0);

    for row in batch.iter_mut() {
        let missing = target - row.len();
        // One clone per appended slot, nothing cloned when no padding is needed.
        row.extend(std::iter::repeat_with(|| pad_value.clone()).take(missing));
    }
}
/// Removes, in place, every row whose elements fall outside a length range.
///
/// `lists` is treated as a set of parallel lists: row `i` consists of the
/// `i`-th element of every list. A row is dropped from *all* lists as soon as
/// any of its elements has a `len()` greater than `max_length` or smaller
/// than `min_length`. The relative order of the surviving rows is preserved.
///
/// * `min_length` - inclusive lower bound; `None` means no lower bound.
/// * `max_length` - inclusive upper bound; `None` means no upper bound.
///
/// An empty `lists` slice is a no-op.
///
/// # Panics
///
/// The number of rows is taken from the first list. The lists are expected to
/// have equal lengths; a list shorter than the first one may cause an
/// out-of-bounds panic when one of its missing rows is inspected.
pub fn filter_by_length<T: Len>(
    lists: &mut [Vec<T>],
    min_length: Option<usize>,
    max_length: Option<usize>,
) {
    let rows = match lists.first() {
        Some(first) => first.len(),
        None => return,
    };
    let min = min_length.unwrap_or(0);
    let max = max_length.unwrap_or(usize::MAX);

    // Decide once per row whether it survives, so each list can be compacted
    // in a single pass instead of paying for a `Vec::remove` per rejected row.
    let keep: Vec<bool> = (0..rows)
        .map(|row| {
            lists.iter().all(|list| {
                let len = list[row].len();
                len >= min && len <= max
            })
        })
        .collect();

    if keep.iter().all(|&kept| kept) {
        return;
    }

    for list in lists.iter_mut() {
        let mut row = 0;
        // `retain` visits elements in order, so `row` tracks the original
        // index. Elements past the first list's length were never inspected
        // and are kept.
        list.retain(|_| {
            let kept = keep.get(row).copied().unwrap_or(true);
            row += 1;
            kept
        });
    }
}
pub fn shuffle_lists<T: std::clone::Clone>(lists: &mut [Vec<T>]) {
use rand::seq::SliceRandom;
use rand::thread_rng;
let mut zipped: Vec<Vec<T>> = vec![Vec::with_capacity(lists.len()); lists[0].len()];
for list in lists.iter() {
for (i, item) in list.iter().enumerate() {
zipped[i].push(item.clone());
}
}
zipped.shuffle(&mut thread_rng());
for (x, list) in lists.iter_mut().enumerate() {
for (i, item) in list.iter_mut().enumerate() {
*item = zipped[i][x].clone();
}
}
}
/// Sorts several parallel lists in place by the length of the first list's elements.
///
/// Row `i` is the `i`-th element of every list. Rows are ordered by the
/// `len()` of their element in `lists[0]`, and the same reordering is applied
/// to every list so rows stay aligned. The sort is unstable: rows whose keys
/// have equal length may end up in any relative order.
///
/// * `longest_first` - `Some(true)` sorts from longest to shortest; `None` or
///   `Some(false)` sorts from shortest to longest.
///
/// An empty `lists` slice is a no-op.
///
/// # Panics
///
/// Panics if the lists do not all have the same length.
pub fn sort_lists_by_length<T: Len + std::clone::Clone>(
    lists: &mut [Vec<T>],
    longest_first: Option<bool>,
) {
    let rows = match lists.first() {
        Some(first) => first.len(),
        None => return,
    };
    assert!(
        lists.iter().all(|list| list.len() == rows),
        "all lists must have the same length"
    );

    // Sort row indices rather than a transposed copy of the data; lengths are
    // `usize`, so a plain total-order key comparison is sufficient.
    let mut order: Vec<usize> = (0..rows).collect();
    {
        let keys = &lists[0];
        order.sort_unstable_by_key(|&row| keys[row].len());
    }
    if longest_first.unwrap_or(false) {
        order.reverse();
    }

    for list in lists.iter_mut() {
        let sorted: Vec<T> = order.iter().map(|&row| list[row].clone()).collect();
        *list = sorted;
    }
}