use rand::Rng;
use tch::{Kind, Tensor};
/// A `Vec`-backed collection from which elements are removed at random
/// positions chosen by the random source `R`.
///
/// Removal is done with `Vec::swap_remove`, so the order of the remaining
/// elements is not preserved.
#[derive(Debug)]
pub struct RandomRemoveVec<T, R> {
    /// The stored elements.
    inner: Vec<T>,
    /// Random source used to pick which index is removed.
    rng: R,
}
impl<T, R> RandomRemoveVec<T, R>
where
R: Rng,
{
pub fn with_capacity(capacity: usize, rng: R) -> Self {
RandomRemoveVec {
inner: Vec::with_capacity(capacity + 1),
rng,
}
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn push(&mut self, value: T) {
self.inner.push(value);
}
pub fn len(&self) -> usize {
self.inner.len()
}
}
impl<T, R> RandomRemoveVec<T, R>
where
    R: Rng,
{
    /// Takes out an element at a random index, or returns `None` when the
    /// collection is empty (no random number is drawn in that case).
    ///
    /// The index is drawn with `gen_range(0, len)`. The last element is
    /// moved into the vacated slot, so order is not preserved.
    pub fn remove_random(&mut self) -> Option<T> {
        match self.inner.len() {
            0 => None,
            len => {
                let idx = self.rng.gen_range(0, len);
                Some(self.inner.swap_remove(idx))
            }
        }
    }

    /// Adds `replacement`, then takes out and returns an element at a random
    /// index, drawn from the collection including `replacement` itself.
    ///
    /// The collection's length is therefore unchanged by this call.
    pub fn push_and_remove_random(&mut self, replacement: T) -> T {
        self.inner.push(replacement);
        let len = self.inner.len();
        let idx = self.rng.gen_range(0, len);
        self.inner.swap_remove(idx)
    }
}
/// Builds a mask of valid positions from per-sequence lengths.
///
/// Element `[b, t]` of the result is the outcome of `t < seq_lens[b]`.
///
/// * `seq_lens` - sequence lengths, one per batch entry. Assumed to be 1-D
///   with shape `[batch_size]` (TODO confirm at call sites); other ranks
///   broadcast differently.
/// * `max_len` - number of positions (columns) in the mask.
///
/// Returns a `[batch_size, max_len]` comparison-result tensor on the same
/// device as `seq_lens`.
pub fn seq_len_to_mask(seq_lens: &Tensor, max_len: i64) -> Tensor {
    // Positions 0..max_len as a [1, max_len] row. Comparing it with the
    // [batch_size, 1] column of lengths broadcasts to [batch_size, max_len],
    // so there is no need to materialize a repeated copy of the range.
    Tensor::arange(max_len, (Kind::Int, seq_lens.device()))
        .unsqueeze(0)
        .lt_1(&seq_lens.unsqueeze(1))
}
#[cfg(test)]
mod tests {
    use rand::{Rng, SeedableRng};
    use rand_xorshift::XorShiftRng;
    use super::RandomRemoveVec;

    /// Drives `RandomRemoveVec` with a seeded RNG and checks each removal
    /// against a second, identically seeded RNG that replays the index draws
    /// the collection makes internally.
    #[test]
    fn random_remove_vec() {
        // Shadow RNG with the same seed: each `gen_range` call below mirrors
        // the draw made inside the following `RandomRemoveVec` call.
        let mut rng = XorShiftRng::seed_from_u64(42);
        let mut elems = RandomRemoveVec::with_capacity(3, XorShiftRng::seed_from_u64(42));
        elems.push(1);
        elems.push(2);
        elems.push(3);

        // Push 4 -> [1, 2, 3, 4]; index 1 removes 2, swap_remove leaves [1, 4, 3].
        assert_eq!(rng.gen_range(0, 4_usize), 1);
        assert_eq!(elems.push_and_remove_random(4), 2);
        // Push 5 -> [1, 4, 3, 5]; index 2 removes 3, leaving [1, 4, 5].
        assert_eq!(rng.gen_range(0, 4_usize), 2);
        assert_eq!(elems.push_and_remove_random(5), 3);
        // Push 6 -> [1, 4, 5, 6]; index 1 removes 4, leaving [1, 6, 5].
        assert_eq!(rng.gen_range(0, 4_usize), 1);
        assert_eq!(elems.push_and_remove_random(6), 4);
        // [1, 6, 5]; index 1 removes 6, leaving [1, 5].
        assert_eq!(rng.gen_range(0, 3_usize), 1);
        assert_eq!(elems.remove_random().unwrap(), 6);
        // [1, 5]; index 0 removes 1, leaving [5].
        assert_eq!(rng.gen_range(0, 2_usize), 0);
        assert_eq!(elems.remove_random().unwrap(), 1);
        // [5]; index 0 removes 5, leaving [].
        assert_eq!(rng.gen_range(0, 1_usize), 0);
        assert_eq!(elems.remove_random().unwrap(), 5);

        // Nothing left to remove.
        assert_eq!(elems.remove_random(), None);
        // When empty, the pushed element is the only candidate, so it comes back.
        assert_eq!(elems.push_and_remove_random(7), 7);
        assert_eq!(elems.push_and_remove_random(8), 8);
    }
}