1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#[cfg(test)]
mod tests;
use lentrait::Len;
/// Builds a boolean padding mask for a batch of sequences.
///
/// The returned mask has the same shape as `batch`: one inner vector per
/// sequence, one entry per token. An entry is `true` when the corresponding
/// token equals `pad_value` and `false` otherwise.
///
/// An empty `batch` yields an empty mask. Sequences may have different
/// lengths; each mask row matches the length of its own sequence.
pub fn pad_mask<T: std::cmp::PartialEq>(batch: &[Vec<T>], pad_value: T) -> Vec<Vec<bool>> {
    batch
        .iter()
        .map(|seq| seq.iter().map(|token| *token == pad_value).collect())
        .collect()
}
/// Pads every sequence in `batch` in place so that all sequences are as long
/// as the longest one.
///
/// Shorter sequences are extended at the end with clones of `pad_value`.
/// Sequences that already have the maximum length are left unchanged, and an
/// empty batch is a no-op.
pub fn pad_batch<T: std::clone::Clone>(batch: &mut [Vec<T>], pad_value: T) {
    // Length every sequence must reach; 0 when the batch is empty.
    let target = batch.iter().map(Vec::len).max().unwrap_or(0);

    for sequence in batch.iter_mut() {
        // `target` is the maximum length, so this subtraction cannot underflow.
        let missing = target - sequence.len();
        sequence.extend(std::iter::repeat_with(|| pad_value.clone()).take(missing));
    }
}
/// Removes, in place, every position at which any list holds an element whose
/// length falls outside `[min_length, max_length]` (both bounds inclusive).
///
/// The lists are treated as parallel columns: when the element at index `i`
/// of *any* list is too short or too long, index `i` is removed from *all*
/// lists so they stay aligned. A bound of `None` means "no limit" on that
/// side.
///
/// Only indices below `lists[0].len()` are inspected; any extra trailing
/// elements in longer lists are kept as they are. An empty `lists` slice is a
/// no-op.
///
/// # Panics
///
/// Panics if any list is shorter than the first list.
pub fn filter_by_length<T: Len>(
    lists: &mut [Vec<T>],
    min_length: Option<usize>,
    max_length: Option<usize>,
) {
    if lists.is_empty() {
        return;
    }

    let min = min_length.unwrap_or(0);
    let max = max_length.unwrap_or(usize::MAX);
    let rows = lists[0].len();

    assert!(
        lists.iter().all(|list| list.len() >= rows),
        "filter_by_length: every list must be at least as long as the first list"
    );

    // Decide once per index whether the whole row survives.
    let keep: Vec<bool> = (0..rows)
        .map(|i| {
            lists.iter().all(|list| {
                let len = list[i].len();
                len >= min && len <= max
            })
        })
        .collect();

    // Apply the decision to each list with a single linear `retain` pass
    // instead of repeated O(n) `remove` calls.
    for list in lists.iter_mut() {
        let mut index = 0;
        list.retain(|_| {
            let keep_item = index >= rows || keep[index];
            index += 1;
            keep_item
        });
    }
}
/// Shuffles several parallel lists in place using the same random
/// permutation, so elements that shared an index before the call still share
/// an index afterwards.
///
/// Randomness comes from `rand::thread_rng()`. An empty `lists` slice is a
/// no-op.
///
/// # Panics
///
/// Panics if the lists do not all have the same length, since rows could not
/// be kept aligned in that case.
pub fn shuffle_lists<T: std::clone::Clone>(lists: &mut [Vec<T>]) {
    use rand::seq::SliceRandom;
    use rand::thread_rng;

    if lists.is_empty() {
        return;
    }

    let rows = lists[0].len();
    assert!(
        lists.iter().all(|list| list.len() == rows),
        "shuffle_lists: all lists must have the same length"
    );

    // Transpose: one row per index, holding that index's element of each list.
    let mut zipped: Vec<Vec<T>> = (0..rows)
        .map(|i| lists.iter().map(|list| list[i].clone()).collect())
        .collect();

    zipped.shuffle(&mut thread_rng());

    // Move the shuffled rows back into the columns (no second clone needed).
    for (i, row) in zipped.into_iter().enumerate() {
        for (list, item) in lists.iter_mut().zip(row) {
            list[i] = item;
        }
    }
}
/// Sorts several parallel lists in place by the length of the elements in the
/// first list, keeping all lists aligned.
///
/// The element lengths of `lists[0]` are the sort key; every other list is
/// reordered with the same permutation. The order is ascending (shortest
/// first) unless `longest_first` is `Some(true)`, in which case the ascending
/// result is reversed. The sort is unstable: the relative order of rows whose
/// keys are equal is unspecified.
///
/// An empty `lists` slice is a no-op.
///
/// # Panics
///
/// Panics if the lists do not all have the same length.
pub fn sort_lists_by_length<T: Len + std::clone::Clone>(
    lists: &mut [Vec<T>],
    longest_first: Option<bool>,
) {
    if lists.is_empty() {
        return;
    }

    let rows = lists[0].len();
    assert!(
        lists.iter().all(|list| list.len() == rows),
        "sort_lists_by_length: all lists must have the same length"
    );

    // Transpose: one row per index, holding that index's element of each list.
    let mut zipped: Vec<Vec<T>> = (0..rows)
        .map(|i| lists.iter().map(|list| list[i].clone()).collect())
        .collect();

    // `lists` is non-empty here, so every row has an element at position 0.
    zipped.sort_unstable_by_key(|row| row[0].len());
    if longest_first.unwrap_or(false) {
        zipped.reverse();
    }

    // Move the sorted rows back into the columns (no second clone needed).
    for (i, row) in zipped.into_iter().enumerate() {
        for (list, item) in lists.iter_mut().zip(row) {
            list[i] = item;
        }
    }
}