ghostflow_data/
sampler.rs

1//! Samplers for data loading
2
3use rand::seq::SliceRandom;
4use rand::thread_rng;
5
/// A source of `usize` indices.
///
/// Every sampler is an [`Iterator`] yielding `usize` values and additionally
/// reports a length through [`Sampler::len`].
pub trait Sampler: Iterator<Item = usize> {
    /// Number of indices this sampler reports; the exact meaning is defined
    /// by each implementor.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] reports zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
13
/// Sampler that hands out the indices `0..len` in ascending order.
pub struct SequentialSampler {
    /// Next index to yield.
    current: usize,
    /// Exclusive upper bound of the index range.
    len: usize,
}

impl SequentialSampler {
    /// Creates a sampler over `0..len`, positioned at index `0`.
    pub fn new(len: usize) -> Self {
        let current = 0;
        Self { current, len }
    }
}
25
26impl Iterator for SequentialSampler {
27    type Item = usize;
28
29    fn next(&mut self) -> Option<Self::Item> {
30        if self.current < self.len {
31            let idx = self.current;
32            self.current += 1;
33            Some(idx)
34        } else {
35            None
36        }
37    }
38}
39
40impl Sampler for SequentialSampler {
41    fn len(&self) -> usize {
42        self.len
43    }
44}
45
46/// Random sampler - returns indices in random order
47pub struct RandomSampler {
48    indices: Vec<usize>,
49    current: usize,
50}
51
52impl RandomSampler {
53    pub fn new(len: usize) -> Self {
54        let mut indices: Vec<usize> = (0..len).collect();
55        indices.shuffle(&mut thread_rng());
56        RandomSampler { indices, current: 0 }
57    }
58}
59
60impl Iterator for RandomSampler {
61    type Item = usize;
62
63    fn next(&mut self) -> Option<Self::Item> {
64        if self.current < self.indices.len() {
65            let idx = self.indices[self.current];
66            self.current += 1;
67            Some(idx)
68        } else {
69            None
70        }
71    }
72}
73
74impl Sampler for RandomSampler {
75    fn len(&self) -> usize {
76        self.indices.len()
77    }
78}
79
80/// Weighted random sampler
81pub struct WeightedRandomSampler {
82    indices: Vec<usize>,
83    current: usize,
84}
85
86impl WeightedRandomSampler {
87    pub fn new(weights: &[f32], num_samples: usize, replacement: bool) -> Self {
88        let total_weight: f32 = weights.iter().sum();
89        let normalized: Vec<f32> = weights.iter().map(|w| w / total_weight).collect();
90        
91        let mut indices = Vec::with_capacity(num_samples);
92        let mut available: Vec<usize> = (0..weights.len()).collect();
93        
94        for _ in 0..num_samples {
95            // Simple weighted selection
96            let r: f32 = rand::random();
97            let mut cumsum = 0.0f32;
98            
99            for (i, &w) in normalized.iter().enumerate() {
100                cumsum += w;
101                if r < cumsum {
102                    if replacement {
103                        indices.push(i);
104                    } else if available.contains(&i) {
105                        indices.push(i);
106                        available.retain(|&x| x != i);
107                    }
108                    break;
109                }
110            }
111        }
112        
113        WeightedRandomSampler { indices, current: 0 }
114    }
115}
116
117impl Iterator for WeightedRandomSampler {
118    type Item = usize;
119
120    fn next(&mut self) -> Option<Self::Item> {
121        if self.current < self.indices.len() {
122            let idx = self.indices[self.current];
123            self.current += 1;
124            Some(idx)
125        } else {
126            None
127        }
128    }
129}
130
131impl Sampler for WeightedRandomSampler {
132    fn len(&self) -> usize {
133        self.indices.len()
134    }
135}
136
137/// Batch sampler - yields batches of indices
138pub struct BatchSampler<S: Sampler> {
139    sampler: S,
140    batch_size: usize,
141    drop_last: bool,
142}
143
144impl<S: Sampler> BatchSampler<S> {
145    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
146        BatchSampler {
147            sampler,
148            batch_size,
149            drop_last,
150        }
151    }
152}
153
154impl<S: Sampler> Iterator for BatchSampler<S> {
155    type Item = Vec<usize>;
156
157    fn next(&mut self) -> Option<Self::Item> {
158        let mut batch = Vec::with_capacity(self.batch_size);
159        
160        for _ in 0..self.batch_size {
161            if let Some(idx) = self.sampler.next() {
162                batch.push(idx);
163            } else {
164                break;
165            }
166        }
167        
168        if batch.is_empty() || (self.drop_last && batch.len() < self.batch_size) {
169            None
170        } else {
171            Some(batch)
172        }
173    }
174}