use super::entry::CachedPeer;
use rand::Rng;
use std::collections::HashSet;
/// Strategy used by [`select_by_strategy`] to choose peers from the cache.
#[derive(Debug, Clone, Copy)]
pub enum SelectionStrategy {
    /// Take the peers with the highest `quality_score` first.
    BestFirst,
    /// Mostly take the best peers, but reserve a share of the requested slots
    /// for exploration (see [`select_epsilon_greedy`]).
    EpsilonGreedy {
        /// Fraction of the requested slots given to exploration; the resulting
        /// slot count is rounded up.
        epsilon: f64,
    },
    /// Take peers in shuffled order, ignoring their quality score.
    Random,
}
impl Default for SelectionStrategy {
fn default() -> Self {
Self::EpsilonGreedy { epsilon: 0.1 }
}
}
/// Selects up to `count` peers with an epsilon-greedy policy.
///
/// The result holds `min(count, peers.len())` distinct peers, built in three
/// passes:
///
/// 1. **Exploit** – the best peers by descending `quality_score`.
/// 2. **Explore** – `ceil(target * epsilon)` slots are filled in shuffled order
///    from the peers not chosen yet. Peers with no recorded successes or
///    failures form the pool; already-tested peers are used as the pool only
///    when no untested peer is left.
/// 3. **Top-up** – if the exploration pool ran dry, the remaining slots are
///    filled best-first.
///
/// Because the exploration share is rounded up, any positive `epsilon`
/// reserves at least one slot for exploration. A non-positive or NaN `epsilon`
/// results in pure exploitation; `epsilon >= 1.0` skips the exploit pass.
///
/// A NaN `quality_score` ranks below every real score.
///
/// Returns an empty vector when `peers` is empty or `count` is zero.
pub fn select_epsilon_greedy(peers: &[CachedPeer], count: usize, epsilon: f64) -> Vec<&CachedPeer> {
    if peers.is_empty() || count == 0 {
        return Vec::new();
    }

    // `partial_cmp` returns `None` for NaN; treating that as "equal" makes the
    // comparator non-total, for which `sort_by` may panic or return an
    // unspecified order. Mapping NaN to the lowest rank keeps the order total.
    fn rank(peer: &CachedPeer) -> f64 {
        if peer.quality_score.is_nan() {
            f64::NEG_INFINITY
        } else {
            peer.quality_score
        }
    }

    let mut rng = rand::thread_rng();
    let target_count = count.min(peers.len());
    let mut selected = Vec::with_capacity(target_count);
    let mut used_indices = HashSet::new();

    // Peer indices ordered best-first.
    let mut sorted_indices: Vec<usize> = (0..peers.len()).collect();
    sorted_indices.sort_by(|&a, &b| {
        rank(&peers[b])
            .partial_cmp(&rank(&peers[a]))
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let explore_count = ((target_count as f64) * epsilon).ceil() as usize;
    let exploit_count = target_count.saturating_sub(explore_count);

    // Exploit: the indices are unique and `exploit_count <= target_count`, so
    // every one of them is taken.
    for &idx in sorted_indices.iter().take(exploit_count) {
        used_indices.insert(idx);
        selected.push(&peers[idx]);
    }

    // Explore: draw from the peers that were not picked above.
    if selected.len() < target_count {
        let (untested, tested): (Vec<usize>, Vec<usize>) = (0..peers.len())
            .filter(|idx| !used_indices.contains(idx))
            .partition(|&idx| {
                peers[idx].stats.success_count + peers[idx].stats.failure_count == 0
            });
        let mut explore_indices = if untested.is_empty() {
            tested
        } else {
            untested
        };

        // Fisher-Yates shuffle of the exploration pool.
        for i in (1..explore_indices.len()).rev() {
            let j = rng.gen_range(0..=i);
            explore_indices.swap(i, j);
        }

        for idx in explore_indices {
            if selected.len() >= target_count {
                break;
            }
            used_indices.insert(idx);
            selected.push(&peers[idx]);
        }
    }

    // Top-up: the exploration pool may have been smaller than needed.
    for &idx in &sorted_indices {
        if selected.len() >= target_count {
            break;
        }
        if used_indices.insert(idx) {
            selected.push(&peers[idx]);
        }
    }

    selected
}
/// Picks up to `count` peers, ranking those with the requested capabilities first.
///
/// The flags express a *preference*, not a filter: a peer earns one point for
/// each requested capability it advertises (`supports_relay` when
/// `require_relay` is set, `supports_coordination` when `require_coordination`
/// is set). Peers are ordered by that score, then by descending
/// `quality_score`, and the first `count` are returned. Peers lacking the
/// requested capabilities are still returned when there are not enough
/// matching ones to fill `count`.
///
/// A NaN `quality_score` ranks below every real score.
///
/// Returns an empty vector when `peers` is empty or `count` is zero.
#[allow(dead_code)]
pub fn select_with_capabilities(
    peers: &[CachedPeer],
    count: usize,
    require_relay: bool,
    require_coordination: bool,
) -> Vec<&CachedPeer> {
    if peers.is_empty() || count == 0 {
        return Vec::new();
    }

    // Number of requested capabilities the peer advertises (0, 1 or 2).
    fn preference_score(peer: &CachedPeer, require_relay: bool, require_coordination: bool) -> u8 {
        u8::from(require_relay && peer.capabilities.supports_relay)
            + u8::from(require_coordination && peer.capabilities.supports_coordination)
    }

    // `partial_cmp` returns `None` for NaN; treating that as "equal" makes the
    // comparator non-total, for which `sort_by` may panic or return an
    // unspecified order. Mapping NaN to the lowest rank keeps the order total.
    fn rank(peer: &CachedPeer) -> f64 {
        if peer.quality_score.is_nan() {
            f64::NEG_INFINITY
        } else {
            peer.quality_score
        }
    }

    let mut candidates: Vec<&CachedPeer> = peers.iter().collect();
    candidates.sort_by(|a, b| {
        let a_pref = preference_score(a, require_relay, require_coordination);
        let b_pref = preference_score(b, require_relay, require_coordination);
        b_pref.cmp(&a_pref).then_with(|| {
            rank(b)
                .partial_cmp(&rank(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });
    candidates.truncate(count);
    candidates
}
/// Picks up to `count` relay candidates able to reach a target of the given
/// address family.
///
/// A peer is kept as a candidate when any of the following holds:
/// * `capabilities.supports_dual_stack()` is true,
/// * it advertises no external addresses at all, or
/// * it has an address in the target's family (`has_ipv4()` when
///   `target_is_ipv4`, otherwise `has_ipv6()`).
///
/// Candidates are then ordered by:
/// 1. dual-stack support first, when `prefer_dual_stack` is set;
/// 2. a preference score of 2 for `supports_relay` plus 1 for advertising an
///    address in the target's family;
/// 3. descending `quality_score`, with NaN ranked below every real score.
///
/// Note that `supports_relay` only affects the ordering: peers without it are
/// ranked lower, not removed.
///
/// Returns an empty vector when `peers` is empty, `count` is zero, or no peer
/// passes the filter.
pub fn select_relays_for_target(
    peers: &[CachedPeer],
    count: usize,
    target_is_ipv4: bool,
    prefer_dual_stack: bool,
) -> Vec<&CachedPeer> {
    if peers.is_empty() || count == 0 {
        return Vec::new();
    }

    // `partial_cmp` returns `None` for NaN; treating that as "equal" makes the
    // comparator non-total, for which `sort_by` may panic or return an
    // unspecified order. Mapping NaN to the lowest rank keeps the order total.
    fn rank(peer: &CachedPeer) -> f64 {
        if peer.quality_score.is_nan() {
            f64::NEG_INFINITY
        } else {
            peer.quality_score
        }
    }

    // 1 when the peer advertises an address in the target's family, else 0.
    // Peers without any advertised address pass the filter but score 0 here.
    let ip_match = |peer: &CachedPeer| {
        if peer.capabilities.external_addresses.is_empty() {
            0u8
        } else if target_is_ipv4 {
            u8::from(peer.capabilities.has_ipv4())
        } else {
            u8::from(peer.capabilities.has_ipv6())
        }
    };

    let mut candidates: Vec<&CachedPeer> = peers
        .iter()
        .filter(|p| {
            let caps = &p.capabilities;
            if caps.supports_dual_stack() || caps.external_addresses.is_empty() {
                return true;
            }
            if target_is_ipv4 {
                caps.has_ipv4()
            } else {
                caps.has_ipv6()
            }
        })
        .collect();

    candidates.sort_by(|a, b| {
        if prefer_dual_stack {
            let a_ds = a.capabilities.supports_dual_stack();
            let b_ds = b.capabilities.supports_dual_stack();
            if a_ds != b_ds {
                return b_ds.cmp(&a_ds);
            }
        }
        let a_pref = (u8::from(a.capabilities.supports_relay) * 2).saturating_add(ip_match(a));
        let b_pref = (u8::from(b.capabilities.supports_relay) * 2).saturating_add(ip_match(b));
        b_pref.cmp(&a_pref).then_with(|| {
            rank(b)
                .partial_cmp(&rank(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });
    candidates.truncate(count);
    candidates
}
/// Picks up to `count` peers whose capabilities report dual-stack support.
///
/// Only peers for which `capabilities.supports_dual_stack()` is true are
/// considered. They are ordered with `supports_relay` peers first, then by
/// descending `quality_score` (NaN ranks below every real score). As with the
/// other selectors, `supports_relay` affects the ordering only.
///
/// Returns an empty vector when no peer qualifies or `count` is zero.
pub fn select_dual_stack_relays(peers: &[CachedPeer], count: usize) -> Vec<&CachedPeer> {
    // `partial_cmp` returns `None` for NaN; treating that as "equal" makes the
    // comparator non-total, for which `sort_by` may panic or return an
    // unspecified order. Mapping NaN to the lowest rank keeps the order total.
    fn rank(peer: &CachedPeer) -> f64 {
        if peer.quality_score.is_nan() {
            f64::NEG_INFINITY
        } else {
            peer.quality_score
        }
    }

    let mut dual_stack: Vec<&CachedPeer> = peers
        .iter()
        .filter(|p| p.capabilities.supports_dual_stack())
        .collect();

    dual_stack.sort_by(|a, b| {
        let a_pref = u8::from(a.capabilities.supports_relay);
        let b_pref = u8::from(b.capabilities.supports_relay);
        b_pref.cmp(&a_pref).then_with(|| {
            rank(b)
                .partial_cmp(&rank(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        })
    });
    dual_stack.truncate(count);
    dual_stack
}
/// Selects up to `count` peers according to `strategy`.
///
/// * [`SelectionStrategy::BestFirst`] – peers in descending `quality_score`
///   order; a NaN score ranks below every real score.
/// * [`SelectionStrategy::EpsilonGreedy`] – delegates to
///   [`select_epsilon_greedy`] with the variant's `epsilon`.
/// * [`SelectionStrategy::Random`] – peers in shuffled order, ignoring quality.
///
/// Fewer than `count` peers are returned when `peers` is shorter than `count`.
#[allow(dead_code)]
pub fn select_by_strategy(
    peers: &[CachedPeer],
    count: usize,
    strategy: SelectionStrategy,
) -> Vec<&CachedPeer> {
    match strategy {
        SelectionStrategy::BestFirst => {
            // `partial_cmp` returns `None` for NaN; treating that as "equal"
            // makes the comparator non-total, for which `sort_by` may panic or
            // return an unspecified order. Mapping NaN to the lowest rank
            // keeps the order total.
            fn rank(peer: &CachedPeer) -> f64 {
                if peer.quality_score.is_nan() {
                    f64::NEG_INFINITY
                } else {
                    peer.quality_score
                }
            }

            let mut sorted: Vec<&CachedPeer> = peers.iter().collect();
            sorted.sort_by(|a, b| {
                rank(b)
                    .partial_cmp(&rank(a))
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            sorted.truncate(count);
            sorted
        }
        SelectionStrategy::EpsilonGreedy { epsilon } => {
            select_epsilon_greedy(peers, count, epsilon)
        }
        SelectionStrategy::Random => {
            let mut rng = rand::thread_rng();
            let mut indices: Vec<usize> = (0..peers.len()).collect();
            // Fisher-Yates shuffle of the peer indices.
            for i in (1..indices.len()).rev() {
                let j = rng.gen_range(0..=i);
                indices.swap(i, j);
            }
            indices.into_iter().take(count).map(|i| &peers[i]).collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bootstrap_cache::entry::PeerSource;
    use crate::nat_traversal_api::PeerId;
    // Builds `count` seed peers on 127.0.0.1 ports 9000.. whose quality score
    // rises with the index (`i / count`), so the last peer is the best one.
    fn create_test_peers(count: usize) -> Vec<CachedPeer> {
        (0..count)
            .map(|i| {
                let mut peer = CachedPeer::new(
                    PeerId([i as u8; 32]),
                    vec![format!("127.0.0.1:{}", 9000 + i).parse().unwrap()],
                    PeerSource::Seed,
                );
                peer.quality_score = i as f64 / count as f64;
                peer
            })
            .collect()
    }
    // An empty peer list yields an empty selection.
    #[test]
    fn test_select_empty() {
        let peers: Vec<CachedPeer> = vec![];
        let selected = select_epsilon_greedy(&peers, 5, 0.1);
        assert!(selected.is_empty());
    }
    // With epsilon = 0 the result is the top peers in descending quality order.
    #[test]
    fn test_select_pure_exploitation() {
        let peers = create_test_peers(10);
        let selected = select_epsilon_greedy(&peers, 5, 0.0);
        assert_eq!(selected.len(), 5);
        for i in 0..4 {
            assert!(selected[i].quality_score >= selected[i + 1].quality_score);
        }
        // The best of ten peers has score 9 / 10.
        assert!((selected[0].quality_score - 0.9).abs() < 0.01);
    }
    // With epsilon = 0.5 repeated selections are expected to differ at least
    // once within ten attempts (probabilistic check).
    #[test]
    fn test_select_with_exploration() {
        let peers = create_test_peers(20);
        let mut has_variation = false;
        let first_selection = select_epsilon_greedy(&peers, 10, 0.5);
        for _ in 0..10 {
            let selection = select_epsilon_greedy(&peers, 10, 0.5);
            if selection.iter().map(|p| p.peer_id).collect::<Vec<_>>()
                != first_selection
                    .iter()
                    .map(|p| p.peer_id)
                    .collect::<Vec<_>>()
            {
                has_variation = true;
                break;
            }
        }
        assert!(has_variation, "Expected variation with epsilon=0.5");
    }
    // Asking for more peers than exist returns every peer exactly once.
    #[test]
    fn test_select_more_than_available() {
        let peers = create_test_peers(3);
        let selected = select_epsilon_greedy(&peers, 10, 0.1);
        assert_eq!(selected.len(), 3); }
    // Relay-capable peers are ranked ahead of the rest, so all three of them
    // appear in a selection of five.
    #[test]
    fn test_select_with_capabilities() {
        let mut peers = create_test_peers(10);
        peers[0].capabilities.supports_relay = true;
        peers[5].capabilities.supports_relay = true;
        peers[9].capabilities.supports_relay = true;
        let relays = select_with_capabilities(&peers, 5, true, false);
        assert_eq!(relays.len(), 5);
        let relay_count = relays
            .iter()
            .filter(|peer| peer.capabilities.supports_relay)
            .count();
        assert!(relay_count >= 3, "Expected relay peers to be prioritized");
    }
    // `BestFirst` returns peers in non-increasing quality order.
    #[test]
    fn test_best_first_strategy() {
        let peers = create_test_peers(10);
        let selected = select_by_strategy(&peers, 5, SelectionStrategy::BestFirst);
        assert_eq!(selected.len(), 5);
        for i in 0..4 {
            assert!(selected[i].quality_score >= selected[i + 1].quality_score);
        }
    }
    // `Random` selections are expected to differ at least once within ten
    // attempts (probabilistic check).
    #[test]
    fn test_random_strategy() {
        let peers = create_test_peers(20);
        let mut has_variation = false;
        let first_selection = select_by_strategy(&peers, 10, SelectionStrategy::Random);
        for _ in 0..10 {
            let selection = select_by_strategy(&peers, 10, SelectionStrategy::Random);
            if selection.iter().map(|p| p.peer_id).collect::<Vec<_>>()
                != first_selection
                    .iter()
                    .map(|p| p.peer_id)
                    .collect::<Vec<_>>()
            {
                has_variation = true;
                break;
            }
        }
        assert!(has_variation, "Random selection should vary");
    }
    // Builds a relay-capable peer with the given quality whose external
    // addresses are the parsed `ipv4_addrs` followed by the `ipv6_addrs`.
    fn create_relay_peer_with_addresses(
        id: u8,
        quality: f64,
        ipv4_addrs: Vec<&str>,
        ipv6_addrs: Vec<&str>,
    ) -> CachedPeer {
        let mut peer = CachedPeer::new(PeerId([id; 32]), vec![], PeerSource::Seed);
        peer.quality_score = quality;
        peer.capabilities.supports_relay = true;
        for addr in ipv4_addrs {
            peer.capabilities
                .external_addresses
                .push(addr.parse().unwrap());
        }
        for addr in ipv6_addrs {
            peer.capabilities
                .external_addresses
                .push(addr.parse().unwrap());
        }
        peer
    }
    // IPv4 target: the peer with both families (1) and the IPv4-only peer (2)
    // qualify; the IPv6-only peer (3) does not.
    #[test]
    fn test_select_relays_for_ipv4_target() {
        let peers = vec![
            create_relay_peer_with_addresses(1, 0.9, vec!["1.2.3.4:9000"], vec!["[::1]:9000"]),
            create_relay_peer_with_addresses(2, 0.7, vec!["5.6.7.8:9001"], vec![]),
            create_relay_peer_with_addresses(3, 0.95, vec![], vec!["[2001:db8::1]:9002"]),
        ];
        let selected = select_relays_for_target(&peers, 10, true, false);
        assert_eq!(selected.len(), 2);
        let ids: Vec<u8> = selected.iter().map(|p| p.peer_id.0[0]).collect();
        assert!(ids.contains(&1)); assert!(ids.contains(&2)); assert!(!ids.contains(&3)); }
    // IPv6 target: the peer with both families (1) and the IPv6-only peer (3)
    // qualify; the IPv4-only peer (2) does not.
    #[test]
    fn test_select_relays_for_ipv6_target() {
        let peers = vec![
            create_relay_peer_with_addresses(1, 0.9, vec!["1.2.3.4:9000"], vec!["[::1]:9000"]),
            create_relay_peer_with_addresses(2, 0.95, vec!["5.6.7.8:9001"], vec![]),
            create_relay_peer_with_addresses(3, 0.7, vec![], vec!["[2001:db8::1]:9002"]),
        ];
        let selected = select_relays_for_target(&peers, 10, false, false);
        assert_eq!(selected.len(), 2);
        let ids: Vec<u8> = selected.iter().map(|p| p.peer_id.0[0]).collect();
        assert!(ids.contains(&1)); assert!(!ids.contains(&2)); assert!(ids.contains(&3)); }
    // Without the dual-stack preference the higher-quality IPv4-only peer (2)
    // leads; with it, the peer that has both address families (1) leads
    // despite its lower quality.
    #[test]
    fn test_select_relays_prefer_dual_stack() {
        let peers = vec![
            create_relay_peer_with_addresses(1, 0.5, vec!["1.2.3.4:9000"], vec!["[::1]:9000"]),
            create_relay_peer_with_addresses(2, 0.9, vec!["5.6.7.8:9001"], vec![]),
        ];
        let selected = select_relays_for_target(&peers, 10, true, false);
        assert_eq!(selected[0].peer_id.0[0], 2);
        let selected = select_relays_for_target(&peers, 10, true, true);
        assert_eq!(selected[0].peer_id.0[0], 1); }
    // Only the two peers with both address families are returned, best first.
    #[test]
    fn test_select_dual_stack_relays() {
        let peers = vec![
            create_relay_peer_with_addresses(1, 0.9, vec!["1.2.3.4:9000"], vec!["[::1]:9000"]),
            create_relay_peer_with_addresses(2, 0.8, vec!["5.6.7.8:9001"], vec![]),
            create_relay_peer_with_addresses(3, 0.7, vec![], vec!["[2001:db8::1]:9002"]),
            create_relay_peer_with_addresses(4, 0.6, vec!["10.0.0.1:9003"], vec!["[::2]:9003"]),
        ];
        let selected = select_dual_stack_relays(&peers, 10);
        assert_eq!(selected.len(), 2);
        for peer in &selected {
            assert!(peer.capabilities.supports_dual_stack());
        }
        assert!(selected[0].quality_score >= selected[1].quality_score);
    }
    // NOTE(review): despite the name, the non-relay peer is NOT excluded. Both
    // peers are returned, and the relay-capable peer is expected first even
    // though the non-relay peer has the higher quality score.
    #[test]
    fn test_select_relays_excludes_non_relays() {
        let mut peers = vec![create_relay_peer_with_addresses(
            1,
            0.9,
            vec!["1.2.3.4:9000"],
            vec![],
        )];
        let mut non_relay = CachedPeer::new(PeerId([2; 32]), vec![], PeerSource::Seed);
        non_relay.quality_score = 0.99;
        non_relay.capabilities.supports_relay = false;
        non_relay
            .capabilities
            .external_addresses
            .push("5.6.7.8:9001".parse().unwrap());
        peers.push(non_relay);
        let selected = select_relays_for_target(&peers, 10, true, false);
        assert_eq!(selected.len(), 2);
        assert_eq!(selected[0].peer_id.0[0], 1);
    }
    // An IPv6-only relay is not returned for an IPv4 target.
    #[test]
    fn test_select_relays_empty_when_no_match() {
        let peers = vec![
            create_relay_peer_with_addresses(1, 0.9, vec![], vec!["[::1]:9000"]),
        ];
        let selected = select_relays_for_target(&peers, 10, true, false);
        assert!(selected.is_empty());
    }
}