1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use rand::seq::SliceRandom;
use rand::thread_rng;
use rand::Rng;

use crate::{sampler::Sampler, Len};

/// Sampler yielding the indexes `0..data_source_len` in a random order.
///
/// The value itself only stores the configuration; the random order is drawn
/// when the sampler is turned into an iterator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RandomSampler {
    /// Number of elements in the data source being sampled.
    data_source_len: usize,
    /// Whether indexes are drawn with replacement.
    /// When `true`, the same index may show up more than once in a single pass.
    replacement: bool,
}

impl Sampler for RandomSampler {
    /// Builds a sampler over `data_source_len` indexes.
    ///
    /// Samplers created this way always draw without replacement.
    fn new(data_source_len: usize) -> Self {
        let replacement = false;
        Self {
            data_source_len,
            replacement,
        }
    }
}
impl Len for RandomSampler {
    /// Number of indexes one full pass over this sampler yields, i.e. the
    /// length of the data source it was created for.
    fn len(&self) -> usize {
        let Self {
            data_source_len, ..
        } = *self;
        data_source_len
    }
}
impl IntoIterator for RandomSampler {
    type Item = usize;
    type IntoIter = RandomSamplerIter;

    /// Consumes the sampler and builds a [`RandomSamplerIter`] from its
    /// length and replacement setting.
    fn into_iter(self) -> Self::IntoIter {
        let Self {
            data_source_len,
            replacement,
        } = self;
        RandomSamplerIter::new(data_source_len, replacement)
    }
}
/// Iterator that yields random indexes between zero and `data_source_len`.
///
/// The whole sequence of indexes is stored in `indexes`; iteration walks that
/// vector from front to back using `idx` as a cursor.
#[derive(Debug, Clone)]
pub struct RandomSamplerIter {
    /// The length of the data source.
    data_source_len: usize,
    /// The pre-computed sequence of dataset indexes to yield.
    indexes: Vec<usize>,
    /// Position in `indexes` of the next index to yield.
    idx: usize,
}

impl RandomSamplerIter {
    /// Create a new `RandomSamplerIter`.
    ///
    /// Exactly `data_source_len` indexes are drawn up front, using the
    /// thread-local random number generator.
    ///
    /// # Arguments
    ///
    /// * `data_source_len` - The length of the dataset.
    /// * `replacement` - Whether the same index may be yielded more than once over one iteration.
    ///   When `false`, the indexes are a random permutation of `0..data_source_len`.
    ///   When `true`, each index is drawn independently and uniformly from `0..data_source_len`.
    fn new(data_source_len: usize, replacement: bool) -> Self {
        let mut rng = thread_rng();
        let indexes: Vec<usize> = if replacement {
            // The closure never runs for an empty data source, so `gen_range`
            // is never handed the empty range `0..0` (which would panic).
            (0..data_source_len)
                .map(|_| rng.gen_range(0..data_source_len))
                .collect()
        } else {
            let mut permutation: Vec<usize> = (0..data_source_len).collect();
            permutation.shuffle(&mut rng);
            permutation
        };
        Self {
            data_source_len,
            indexes,
            idx: 0,
        }
    }
}
impl Iterator for RandomSamplerIter {
    type Item = usize;
    fn next(&mut self) -> Option<Self::Item> {
        if self.idx < self.data_source_len {
            self.idx += 1;
            Some(self.indexes[self.idx - 1])
        } else {
            None
        }
    }
}

/// Sampling without replacement must visit every index of the data source
/// exactly once, in any order.
#[test]
fn random_sampler() {
    let random_sampler = RandomSampler {
        data_source_len: 10,
        replacement: false,
    };
    let mut indexes: Vec<usize> = random_sampler.into_iter().collect();
    assert_eq!(indexes.len(), 10);
    // The order is random, so compare against the sorted sequence.
    indexes.sort_unstable();
    assert_eq!(indexes, (0..10).collect::<Vec<usize>>());
}