ghostflow_data/
sampler.rs

1//! Samplers for data loading
2
3use rand::seq::SliceRandom;
4use rand::thread_rng;
5
/// A source of dataset indices.
///
/// A sampler is an iterator over `usize` indices that additionally reports a
/// length through [`Sampler::len`].
pub trait Sampler: Iterator<Item = usize> {
    /// Returns the number of indices this sampler reports.
    ///
    /// What the count refers to (total versus remaining) is defined by the
    /// implementor.
    fn len(&self) -> usize;

    /// Returns `true` when [`Sampler::len`] is zero.
    ///
    /// Provided with a default body so that existing implementors keep
    /// compiling without changes.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
10
11/// Sequential sampler - returns indices in order
12pub struct SequentialSampler {
13    current: usize,
14    len: usize,
15}
16
17impl SequentialSampler {
18    pub fn new(len: usize) -> Self {
19        SequentialSampler { current: 0, len }
20    }
21}
22
23impl Iterator for SequentialSampler {
24    type Item = usize;
25
26    fn next(&mut self) -> Option<Self::Item> {
27        if self.current < self.len {
28            let idx = self.current;
29            self.current += 1;
30            Some(idx)
31        } else {
32            None
33        }
34    }
35}
36
37impl Sampler for SequentialSampler {
38    fn len(&self) -> usize {
39        self.len
40    }
41}
42
43/// Random sampler - returns indices in random order
44pub struct RandomSampler {
45    indices: Vec<usize>,
46    current: usize,
47}
48
49impl RandomSampler {
50    pub fn new(len: usize) -> Self {
51        let mut indices: Vec<usize> = (0..len).collect();
52        indices.shuffle(&mut thread_rng());
53        RandomSampler { indices, current: 0 }
54    }
55}
56
57impl Iterator for RandomSampler {
58    type Item = usize;
59
60    fn next(&mut self) -> Option<Self::Item> {
61        if self.current < self.indices.len() {
62            let idx = self.indices[self.current];
63            self.current += 1;
64            Some(idx)
65        } else {
66            None
67        }
68    }
69}
70
71impl Sampler for RandomSampler {
72    fn len(&self) -> usize {
73        self.indices.len()
74    }
75}
76
77/// Weighted random sampler
78pub struct WeightedRandomSampler {
79    indices: Vec<usize>,
80    current: usize,
81}
82
83impl WeightedRandomSampler {
84    pub fn new(weights: &[f32], num_samples: usize, replacement: bool) -> Self {
85        let total_weight: f32 = weights.iter().sum();
86        let normalized: Vec<f32> = weights.iter().map(|w| w / total_weight).collect();
87        
88        let mut indices = Vec::with_capacity(num_samples);
89        let mut available: Vec<usize> = (0..weights.len()).collect();
90        
91        for _ in 0..num_samples {
92            // Simple weighted selection
93            let r: f32 = rand::random();
94            let mut cumsum = 0.0f32;
95            
96            for (i, &w) in normalized.iter().enumerate() {
97                cumsum += w;
98                if r < cumsum {
99                    if replacement {
100                        indices.push(i);
101                    } else if available.contains(&i) {
102                        indices.push(i);
103                        available.retain(|&x| x != i);
104                    }
105                    break;
106                }
107            }
108        }
109        
110        WeightedRandomSampler { indices, current: 0 }
111    }
112}
113
114impl Iterator for WeightedRandomSampler {
115    type Item = usize;
116
117    fn next(&mut self) -> Option<Self::Item> {
118        if self.current < self.indices.len() {
119            let idx = self.indices[self.current];
120            self.current += 1;
121            Some(idx)
122        } else {
123            None
124        }
125    }
126}
127
128impl Sampler for WeightedRandomSampler {
129    fn len(&self) -> usize {
130        self.indices.len()
131    }
132}
133
134/// Batch sampler - yields batches of indices
135pub struct BatchSampler<S: Sampler> {
136    sampler: S,
137    batch_size: usize,
138    drop_last: bool,
139}
140
141impl<S: Sampler> BatchSampler<S> {
142    pub fn new(sampler: S, batch_size: usize, drop_last: bool) -> Self {
143        BatchSampler {
144            sampler,
145            batch_size,
146            drop_last,
147        }
148    }
149}
150
151impl<S: Sampler> Iterator for BatchSampler<S> {
152    type Item = Vec<usize>;
153
154    fn next(&mut self) -> Option<Self::Item> {
155        let mut batch = Vec::with_capacity(self.batch_size);
156        
157        for _ in 0..self.batch_size {
158            if let Some(idx) = self.sampler.next() {
159                batch.push(idx);
160            } else {
161                break;
162            }
163        }
164        
165        if batch.is_empty() {
166            None
167        } else if self.drop_last && batch.len() < self.batch_size {
168            None
169        } else {
170            Some(batch)
171        }
172    }
173}