ant_quic/bootstrap_cache/
selection.rs1use super::entry::CachedPeer;
4use rand::Rng;
5use std::collections::HashSet;
6
/// Strategy used to pick peers out of the bootstrap cache.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SelectionStrategy {
    /// Always take the peers with the highest `quality_score`.
    BestFirst,
    /// Mostly take the best peers, but reserve a share of the slots for
    /// exploring other peers (see [`select_epsilon_greedy`]).
    EpsilonGreedy {
        /// Fraction of the requested slots reserved for exploration; values
        /// outside `[0.0, 1.0]` behave as if clamped to that range.
        epsilon: f64,
    },
    /// Uniformly random selection that ignores quality scores.
    Random,
}

impl Default for SelectionStrategy {
    /// Epsilon-greedy with `epsilon = 0.1`.
    fn default() -> Self {
        Self::EpsilonGreedy { epsilon: 0.1 }
    }
}
26
27pub fn select_epsilon_greedy(peers: &[CachedPeer], count: usize, epsilon: f64) -> Vec<&CachedPeer> {
40 if peers.is_empty() || count == 0 {
41 return Vec::new();
42 }
43
44 let mut rng = rand::thread_rng();
45 let mut selected = Vec::with_capacity(count.min(peers.len()));
46 let mut used_indices = HashSet::new();
47
48 let mut sorted_indices: Vec<usize> = (0..peers.len()).collect();
50 sorted_indices.sort_by(|&a, &b| {
51 peers[b]
52 .quality_score
53 .partial_cmp(&peers[a].quality_score)
54 .unwrap_or(std::cmp::Ordering::Equal)
55 });
56
57 let target_count = count.min(peers.len());
59 let explore_count = ((target_count as f64) * epsilon).ceil() as usize;
60 let exploit_count = target_count.saturating_sub(explore_count);
61
62 for &idx in sorted_indices.iter().take(exploit_count) {
64 if used_indices.insert(idx) && selected.len() < target_count {
65 selected.push(&peers[idx]);
66 }
67 }
68
69 let remaining: Vec<usize> = (0..peers.len())
72 .filter(|idx| !used_indices.contains(idx))
73 .collect();
74
75 if !remaining.is_empty() && selected.len() < target_count {
76 let (untested, tested): (Vec<_>, Vec<_>) = remaining.iter().partition(|&&idx| {
78 peers[idx].stats.success_count + peers[idx].stats.failure_count == 0
79 });
80
81 let explore_pool = if !untested.is_empty() {
83 untested
84 } else {
85 tested
86 };
87
88 let mut explore_indices: Vec<usize> = explore_pool.into_iter().copied().collect();
90 for i in (1..explore_indices.len()).rev() {
92 let j = rng.gen_range(0..=i);
93 explore_indices.swap(i, j);
94 }
95
96 for &idx in explore_indices.iter() {
97 if selected.len() >= target_count {
98 break;
99 }
100 if used_indices.insert(idx) {
101 selected.push(&peers[idx]);
102 }
103 }
104 }
105
106 for &idx in &sorted_indices {
108 if selected.len() >= target_count {
109 break;
110 }
111 if used_indices.insert(idx) {
112 selected.push(&peers[idx]);
113 }
114 }
115
116 selected
117}
118
119#[allow(dead_code)]
123pub fn select_with_capabilities(
124 peers: &[CachedPeer],
125 count: usize,
126 require_relay: bool,
127 require_coordination: bool,
128) -> Vec<&CachedPeer> {
129 let mut filtered: Vec<&CachedPeer> = peers
130 .iter()
131 .filter(|p| {
132 (!require_relay || p.capabilities.supports_relay)
133 && (!require_coordination || p.capabilities.supports_coordination)
134 })
135 .collect();
136
137 if filtered.is_empty() {
138 return Vec::new();
139 }
140
141 filtered.sort_by(|a, b| {
143 b.quality_score
144 .partial_cmp(&a.quality_score)
145 .unwrap_or(std::cmp::Ordering::Equal)
146 });
147
148 filtered.into_iter().take(count).collect()
149}
150
151#[allow(dead_code)]
153pub fn select_by_strategy(
154 peers: &[CachedPeer],
155 count: usize,
156 strategy: SelectionStrategy,
157) -> Vec<&CachedPeer> {
158 match strategy {
159 SelectionStrategy::BestFirst => {
160 let mut sorted: Vec<&CachedPeer> = peers.iter().collect();
161 sorted.sort_by(|a, b| {
162 b.quality_score
163 .partial_cmp(&a.quality_score)
164 .unwrap_or(std::cmp::Ordering::Equal)
165 });
166 sorted.into_iter().take(count).collect()
167 }
168 SelectionStrategy::EpsilonGreedy { epsilon } => {
169 select_epsilon_greedy(peers, count, epsilon)
170 }
171 SelectionStrategy::Random => {
172 let mut rng = rand::thread_rng();
173 let mut indices: Vec<usize> = (0..peers.len()).collect();
174 for i in (1..indices.len()).rev() {
176 let j = rng.gen_range(0..=i);
177 indices.swap(i, j);
178 }
179 indices.into_iter().take(count).map(|i| &peers[i]).collect()
180 }
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bootstrap_cache::entry::PeerSource;
    use crate::nat_traversal_api::PeerId;

    /// Builds `count` seed peers on 127.0.0.1:9000.. whose `quality_score` is
    /// `i / count`, so higher indices are better and the last peer is best.
    fn create_test_peers(count: usize) -> Vec<CachedPeer> {
        (0..count)
            .map(|i| {
                let mut peer = CachedPeer::new(
                    PeerId([i as u8; 32]),
                    vec![format!("127.0.0.1:{}", 9000 + i).parse().unwrap()],
                    PeerSource::Seed,
                );
                peer.quality_score = i as f64 / count as f64;
                peer
            })
            .collect()
    }

    /// Peer ids in selection order, for comparing two selections.
    fn ids(selection: &[&CachedPeer]) -> Vec<PeerId> {
        selection.iter().map(|p| p.peer_id).collect()
    }

    /// Asserts that `selection` is ordered by non-increasing quality score.
    fn assert_best_first(selection: &[&CachedPeer]) {
        for pair in selection.windows(2) {
            assert!(pair[0].quality_score >= pair[1].quality_score);
        }
    }

    /// Asserts that no peer appears twice in `selection` (compared by address).
    fn assert_distinct(selection: &[&CachedPeer]) {
        let mut ptrs: Vec<*const CachedPeer> =
            selection.iter().map(|&p| p as *const CachedPeer).collect();
        ptrs.sort();
        ptrs.dedup();
        assert_eq!(ptrs.len(), selection.len(), "selection contains duplicates");
    }

    #[test]
    fn test_select_empty() {
        let peers: Vec<CachedPeer> = vec![];
        let selected = select_epsilon_greedy(&peers, 5, 0.1);
        assert!(selected.is_empty());
    }

    #[test]
    fn test_select_zero_count() {
        let peers = create_test_peers(5);
        assert!(select_epsilon_greedy(&peers, 0, 0.5).is_empty());
    }

    #[test]
    fn test_select_pure_exploitation() {
        let peers = create_test_peers(10);
        let selected = select_epsilon_greedy(&peers, 5, 0.0);

        assert_eq!(selected.len(), 5);
        assert_best_first(&selected);
        assert!((selected[0].quality_score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_select_with_exploration() {
        let peers = create_test_peers(20);
        let first_selection = ids(&select_epsilon_greedy(&peers, 10, 0.5));

        let has_variation = (0..10)
            .any(|_| ids(&select_epsilon_greedy(&peers, 10, 0.5)) != first_selection);
        assert!(has_variation, "Expected variation with epsilon=0.5");
    }

    #[test]
    fn test_epsilon_greedy_never_duplicates() {
        let peers = create_test_peers(20);
        for &epsilon in &[0.0, 0.3, 0.5, 1.0] {
            let selected = select_epsilon_greedy(&peers, 10, epsilon);
            assert_eq!(selected.len(), 10);
            assert_distinct(&selected);
        }
    }

    #[test]
    fn test_select_more_than_available() {
        let peers = create_test_peers(3);
        let selected = select_epsilon_greedy(&peers, 10, 0.1);
        assert_eq!(selected.len(), 3);
        assert_distinct(&selected);
    }

    #[test]
    fn test_select_with_capabilities() {
        let mut peers = create_test_peers(10);

        peers[0].capabilities.supports_relay = true;
        peers[5].capabilities.supports_relay = true;
        peers[9].capabilities.supports_relay = true;

        let relays = select_with_capabilities(&peers, 10, true, false);
        assert_eq!(relays.len(), 3);
        assert!(relays.iter().all(|p| p.capabilities.supports_relay));
    }

    #[test]
    fn test_select_with_combined_capabilities() {
        let mut peers = create_test_peers(10);
        // Start from a known baseline rather than relying on constructor defaults.
        for peer in &mut peers {
            peer.capabilities.supports_relay = false;
            peer.capabilities.supports_coordination = false;
        }
        peers[2].capabilities.supports_relay = true;
        peers[4].capabilities.supports_coordination = true;
        for &i in &[6, 8] {
            peers[i].capabilities.supports_relay = true;
            peers[i].capabilities.supports_coordination = true;
        }

        let both = select_with_capabilities(&peers, 10, true, true);
        assert_eq!(both.len(), 2);
        assert!(both[0].peer_id == PeerId([8u8; 32]));
        assert!(both[1].peer_id == PeerId([6u8; 32]));

        let coordinators = select_with_capabilities(&peers, 10, false, true);
        assert_eq!(coordinators.len(), 3);
        assert_best_first(&coordinators);

        assert_eq!(select_with_capabilities(&peers, 10, false, false).len(), 10);
        assert_eq!(select_with_capabilities(&peers, 1, true, true).len(), 1);
    }

    #[test]
    fn test_best_first_strategy() {
        let peers = create_test_peers(10);
        let selected = select_by_strategy(&peers, 5, SelectionStrategy::BestFirst);

        assert_eq!(selected.len(), 5);
        assert_best_first(&selected);
        assert!((selected[0].quality_score - 0.9).abs() < 0.01);
    }

    #[test]
    fn test_best_first_more_than_available() {
        let peers = create_test_peers(3);
        let selected = select_by_strategy(&peers, 10, SelectionStrategy::BestFirst);
        assert_eq!(selected.len(), 3);
        assert_best_first(&selected);
    }

    #[test]
    fn test_random_strategy() {
        let peers = create_test_peers(20);
        let first = select_by_strategy(&peers, 10, SelectionStrategy::Random);
        assert_eq!(first.len(), 10);
        assert_distinct(&first);
        let first_selection = ids(&first);

        let has_variation = (0..10).any(|_| {
            ids(&select_by_strategy(&peers, 10, SelectionStrategy::Random)) != first_selection
        });
        assert!(has_variation, "Random selection should vary");
    }

    #[test]
    fn test_random_more_than_available() {
        let peers = create_test_peers(3);
        let selected = select_by_strategy(&peers, 10, SelectionStrategy::Random);
        assert_eq!(selected.len(), 3);
        assert_distinct(&selected);
    }
}