1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#[cfg(test)]
mod tests;
use lentrait::Len;
/// Builds a boolean mask with the same (possibly ragged) shape as `batch`.
///
/// Each entry is `true` when the corresponding token equals `pad_value` and `false`
/// otherwise. An empty batch yields an empty mask; sequences of different lengths
/// yield mask rows of matching lengths.
pub fn pad_mask<T: std::cmp::PartialEq>(batch: &[Vec<T>], pad_value: T) -> Vec<Vec<bool>> {
    batch
        .iter()
        // `collect` sizes each row from the exact iterator length, so no capacity guess is needed.
        .map(|seq| seq.iter().map(|token| *token == pad_value).collect())
        .collect()
}
/// Right-pads every sequence in `batch` with clones of `pad_value` until all of them
/// are as long as the longest sequence. Sequences already at that length, and an
/// empty batch, are left untouched.
pub fn pad_batch<T: std::clone::Clone>(batch: &mut Vec<Vec<T>>, pad_value: T) {
    let target = batch.iter().map(Vec::len).max().unwrap_or(0);
    for seq in batch.iter_mut() {
        // `target` is the maximum length, so this never underflows.
        let missing = target - seq.len();
        // `repeat_with` is lazy: exactly `missing` clones are produced.
        seq.extend(std::iter::repeat_with(|| pad_value.clone()).take(missing));
    }
}
/// Removes, from every list at once, each index whose item falls outside the
/// inclusive length range `[min_length, max_length]` in *any* of the lists.
///
/// `lists` is treated as a set of parallel columns, so items that share an index stay
/// aligned after filtering. A missing bound means "unbounded" (`0` for the minimum,
/// `usize::MAX` for the maximum). An empty `lists` is left unchanged.
///
/// Runs in a single pass per list via `Vec::retain`, instead of repeated `Vec::remove`.
///
/// # Panics
///
/// Panics if the lists do not all have the same length.
pub fn filter_by_length<T: Len>(
    lists: &mut Vec<Vec<T>>,
    min_length: Option<usize>,
    max_length: Option<usize>,
) {
    let len = match lists.first() {
        Some(first) => first.len(),
        None => return,
    };
    assert!(
        lists.iter().all(|list| list.len() == len),
        "filter_by_length: all lists must have the same length"
    );

    let min = min_length.unwrap_or(0);
    let max = max_length.unwrap_or(usize::MAX);

    // keep[i] is true only when the item at index i is within bounds in every list.
    let keep: Vec<bool> = (0..len)
        .map(|i| {
            lists.iter().all(|list| {
                let item_len = list[i].len();
                min <= item_len && item_len <= max
            })
        })
        .collect();

    for list in lists.iter_mut() {
        // `retain` visits elements exactly once, in order, so the flags line up by index.
        let mut flags = keep.iter();
        list.retain(|_| *flags.next().expect("list lengths were checked above"));
    }
}
pub fn shuffle_lists<T: std::clone::Clone>(lists: &mut Vec<Vec<T>>) {
use rand::thread_rng;
use rand::seq::SliceRandom;
let mut zipped: Vec<Vec<T>> = vec![Vec::with_capacity(lists.len()); lists[0].len()];
for list in lists.iter() {
for (i, item) in list.iter().enumerate() {
zipped[i].push(item.clone());
}
}
zipped.shuffle(&mut thread_rng());
for (x, list) in lists.iter_mut().enumerate() {
for (i, item) in list.iter_mut().enumerate() {
*item = zipped[i][x].clone();
}
}
}
/// Sorts parallel lists together by the length of the items in the first list.
///
/// The permutation that orders `lists[0]` by each item's `Len::len` is applied to every
/// list, so items that share an index stay aligned. The sort is stable: items of equal
/// length keep their original relative order in both directions.
///
/// * `longest_first` — `Some(true)` sorts by descending length; `None` or `Some(false)`
///   sorts by ascending length.
///
/// An empty `lists` is left unchanged. Each item is cloned exactly once.
///
/// # Panics
///
/// Panics if the lists do not all have the same length.
pub fn sort_lists_by_length<T: Len + std::clone::Clone>(
    lists: &mut Vec<Vec<T>>,
    longest_first: Option<bool>,
) {
    let len = match lists.first() {
        Some(first) => first.len(),
        None => return,
    };
    assert!(
        lists.iter().all(|list| list.len() == len),
        "sort_lists_by_length: all lists must have the same length"
    );

    // Compute each key once; lengths are `usize`, which is totally ordered.
    let keys: Vec<usize> = lists[0].iter().map(|item| item.len()).collect();

    let mut order: Vec<usize> = (0..len).collect();
    if longest_first.unwrap_or(false) {
        // Compare in reverse instead of reversing afterwards, so ties stay stable.
        order.sort_by(|&a, &b| keys[b].cmp(&keys[a]));
    } else {
        order.sort_by_key(|&i| keys[i]);
    }

    for list in lists.iter_mut() {
        let sorted: Vec<T> = order.iter().map(|&i| list[i].clone()).collect();
        *list = sorted;
    }
}