use super::config::BootstrapCacheConfig;
use super::entry::{CachedPeer, ConnectionOutcome, PeerCapabilities, PeerSource};
use super::persistence::{CacheData, CachePersistence};
use super::selection::select_epsilon_greedy;
use crate::nat_traversal_api::PeerId;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Instant, SystemTime};
use tokio::sync::{RwLock, broadcast};
use tracing::{debug, info, warn};
/// Events broadcast to `subscribe()`-ers whenever the cache changes.
#[derive(Debug, Clone)]
pub enum CacheEvent {
/// The peer set was modified (insert/update/observation); carries the new total peer count.
Updated {
peer_count: usize,
},
/// The cache was successfully persisted to disk.
Saved,
/// Peers were merged in from another source; `added` is presumably the number
/// of newly added entries — no emitter is visible in this part of the file.
Merged {
added: usize,
},
/// Stale peers were removed by cleanup; `removed` is how many were dropped.
Cleaned {
removed: usize,
},
}
/// Aggregate snapshot of cache contents, produced by `BootstrapCache::stats`.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
/// Total number of cached peers.
pub total_peers: usize,
/// Peers whose capabilities report relay support.
pub relay_peers: usize,
/// Peers whose capabilities report coordination support.
pub coordinator_peers: usize,
/// Relay-capable peers that also report dual-stack support.
pub dual_stack_relay_peers: usize,
/// Mean quality score across all peers (0.0 when the cache is empty).
pub average_quality: f64,
/// Peers with no recorded connection attempts (success + failure == 0).
pub untested_peers: usize,
}
/// Persistent, quality-scored cache of bootstrap peers with optional
/// background maintenance (periodic save, stale cleanup, quality refresh).
#[derive(Debug)]
pub struct BootstrapCache {
/// Behaviour settings: capacity, intervals, TTLs, scoring weights.
config: BootstrapCacheConfig,
/// Shared peer table, guarded by an async RwLock so refreshes can mutate in place.
data: Arc<RwLock<CacheData>>,
/// Disk load/save backend.
persistence: CachePersistence,
/// Broadcast channel for `CacheEvent` notifications; send errors are ignored
/// (no subscribers is fine).
event_tx: broadcast::Sender<CacheEvent>,
/// Instant of the most recent successful save.
last_save: Arc<RwLock<Instant>>,
/// Instant of the most recent stale-peer cleanup.
last_cleanup: Arc<RwLock<Instant>>,
}
impl BootstrapCache {
/// Open (or create) the bootstrap cache rooted at `config.cache_dir`.
///
/// Loads any previously persisted peer data from disk and sets up the
/// event broadcast channel. Returns an I/O error if the persistence
/// backend cannot be initialised or the cache file cannot be read.
pub async fn open(config: BootstrapCacheConfig) -> std::io::Result<Self> {
    let persistence = CachePersistence::new(&config.cache_dir, config.enable_file_locking)?;
    let data = persistence.load()?;
    info!("Opened bootstrap cache with {} peers", data.peers.len());
    // Receiver side is dropped; subscribers attach later via `subscribe()`.
    let (event_tx, _) = broadcast::channel(256);
    let started = Instant::now();
    Ok(Self {
        config,
        data: Arc::new(RwLock::new(data)),
        persistence,
        event_tx,
        last_save: Arc::new(RwLock::new(started)),
        last_cleanup: Arc::new(RwLock::new(started)),
    })
}
/// Subscribe to cache change notifications; each receiver gets every
/// `CacheEvent` sent after this call (bounded buffer of 256, lagging
/// receivers may miss events per broadcast-channel semantics).
pub fn subscribe(&self) -> broadcast::Receiver<CacheEvent> {
self.event_tx.subscribe()
}
/// Number of peers currently held in the cache.
pub async fn peer_count(&self) -> usize {
    let guard = self.data.read().await;
    guard.peers.len()
}
/// Look up a peer by id, refreshing its time-based capability state and
/// quality score before returning a clone.
///
/// Returns `None` when the peer is not cached. Takes the write lock
/// because the refresh mutates the stored entry in place.
pub async fn get_peer(&self, peer_id: &PeerId) -> Option<CachedPeer> {
    let mut data = self.data.write().await;
    let peer = data.peers.get_mut(&peer_id.0)?;
    // Use the shared refresh helper instead of duplicating its body inline,
    // keeping TTL/quality maintenance consistent with `get` and the selectors.
    self.refresh_cached_peer(peer, SystemTime::now());
    Some(peer.clone())
}
/// Re-derive a peer's time-sensitive state: expire direct-reachability
/// evidence older than the configured TTL, then recompute the quality
/// score with the configured weights. Called before every read that
/// exposes peer state to callers.
fn refresh_cached_peer(&self, peer: &mut CachedPeer, now: SystemTime) {
peer.capabilities
.refresh_direct_capabilities(self.config.reachability_ttl, now);
peer.calculate_quality(&self.config.weights);
}
/// Select up to `count` peers using epsilon-greedy exploration/exploitation.
///
/// Every peer's TTL-based capabilities and quality score are refreshed
/// before the snapshot is handed to the selection strategy; the lock is
/// released before selection runs.
pub async fn select_peers(&self, count: usize) -> Vec<CachedPeer> {
    let refreshed_at = SystemTime::now();
    let snapshot: Vec<CachedPeer> = {
        let mut guard = self.data.write().await;
        guard
            .peers
            .values_mut()
            .for_each(|p| self.refresh_cached_peer(p, refreshed_at));
        guard.peers.values().cloned().collect()
    };
    select_epsilon_greedy(&snapshot, count, self.config.epsilon)
        .into_iter()
        .cloned()
        .collect()
}
/// Select up to `count` peers ranked for relay duty.
///
/// Peers are refreshed (capability TTLs, quality) under the write lock,
/// then a cloned snapshot is filtered/ordered by the capability selector
/// with relay required and coordination not required.
pub async fn select_relay_peers(&self, count: usize) -> Vec<CachedPeer> {
    let refreshed_at = SystemTime::now();
    let snapshot: Vec<CachedPeer> = {
        let mut guard = self.data.write().await;
        for entry in guard.peers.values_mut() {
            self.refresh_cached_peer(entry, refreshed_at);
        }
        guard.peers.values().cloned().collect()
    };
    // Arguments: require_relay = true, require_coordination = false.
    super::selection::select_with_capabilities(&snapshot, count, true, false)
        .into_iter()
        .cloned()
        .collect()
}
/// Select up to `count` peers ranked for NAT-traversal coordination duty.
///
/// Mirrors `select_relay_peers` but asks the capability selector for
/// coordination support instead of relay support.
pub async fn select_coordinators(&self, count: usize) -> Vec<CachedPeer> {
    let refreshed_at = SystemTime::now();
    let snapshot: Vec<CachedPeer> = {
        let mut guard = self.data.write().await;
        for entry in guard.peers.values_mut() {
            self.refresh_cached_peer(entry, refreshed_at);
        }
        guard.peers.values().cloned().collect()
    };
    // Arguments: require_relay = false, require_coordination = true.
    super::selection::select_with_capabilities(&snapshot, count, false, true)
        .into_iter()
        .cloned()
        .collect()
}
/// Select up to `count` relay peers suited for reaching `target`,
/// optionally preferring dual-stack relays.
///
/// Refreshes all peers first so the relay ranking sees current
/// capability and quality data.
pub async fn select_relays_for_target(
    &self,
    count: usize,
    target: &std::net::SocketAddr,
    prefer_dual_stack: bool,
) -> Vec<CachedPeer> {
    use super::selection::select_relays_for_target;
    let refreshed_at = SystemTime::now();
    let snapshot: Vec<CachedPeer> = {
        let mut guard = self.data.write().await;
        guard
            .peers
            .values_mut()
            .for_each(|p| self.refresh_cached_peer(p, refreshed_at));
        guard.peers.values().cloned().collect()
    };
    select_relays_for_target(&snapshot, count, *target, prefer_dual_stack)
        .into_iter()
        .cloned()
        .collect()
}
/// Select up to `count` dual-stack-capable relay peers.
///
/// Same refresh-then-snapshot pattern as the other selectors; the
/// dual-stack filtering itself lives in the selection module.
pub async fn select_dual_stack_relays(&self, count: usize) -> Vec<CachedPeer> {
    use super::selection::select_dual_stack_relays;
    let refreshed_at = SystemTime::now();
    let snapshot: Vec<CachedPeer> = {
        let mut guard = self.data.write().await;
        guard
            .peers
            .values_mut()
            .for_each(|p| self.refresh_cached_peer(p, refreshed_at));
        guard.peers.values().cloned().collect()
    };
    select_dual_stack_relays(&snapshot, count)
        .into_iter()
        .cloned()
        .collect()
}
/// Insert or replace a peer, evicting the lowest-quality entries first
/// when the cache is full and the peer is genuinely new.
///
/// Emits `CacheEvent::Updated` with the resulting peer count after the
/// lock is released.
pub async fn upsert(&self, peer: CachedPeer) {
    let key = peer.peer_id.0;
    let peer_count = {
        let mut guard = self.data.write().await;
        // Only evict for brand-new entries; replacing an existing peer
        // does not grow the map.
        if guard.peers.len() >= self.config.max_peers && !guard.peers.contains_key(&key) {
            self.evict_lowest_quality(&mut guard);
        }
        guard.peers.insert(key, peer);
        guard.peers.len()
    };
    let _ = self.event_tx.send(CacheEvent::Updated { peer_count });
}
/// Add a peer known from static/seed configuration.
pub async fn add_seed(&self, peer_id: PeerId, addresses: Vec<SocketAddr>) {
    self.upsert(CachedPeer::new(peer_id, addresses, PeerSource::Seed))
        .await;
}
/// Add a peer learned from a live connection, optionally seeding its
/// capabilities with what the connection discovered.
pub async fn add_from_connection(
    &self,
    peer_id: PeerId,
    addresses: Vec<SocketAddr>,
    caps: Option<PeerCapabilities>,
) {
    let mut entry = CachedPeer::new(peer_id, addresses, PeerSource::Connection);
    // Overwrite the defaults only when capabilities were actually discovered.
    if let Some(discovered) = caps {
        entry.capabilities = discovered;
    }
    self.upsert(entry).await;
}
/// Record the result of a connection attempt against a cached peer and
/// recompute its quality score. Unknown peers are silently ignored.
///
/// A successful outcome without an RTT measurement is charged a default
/// of 100 ms.
pub async fn record_outcome(&self, peer_id: &PeerId, outcome: ConnectionOutcome) {
    let mut guard = self.data.write().await;
    let Some(peer) = guard.peers.get_mut(&peer_id.0) else {
        return;
    };
    if outcome.success {
        let rtt = outcome.rtt_ms.unwrap_or(100);
        peer.record_success(rtt, outcome.capabilities_discovered);
    } else {
        peer.record_failure();
    }
    peer.calculate_quality(&self.config.weights);
}
/// Convenience wrapper: record a successful connection with a measured RTT.
pub async fn record_success(&self, peer_id: &PeerId, rtt_ms: u32) {
    let outcome = ConnectionOutcome {
        success: true,
        rtt_ms: Some(rtt_ms),
        capabilities_discovered: None,
    };
    self.record_outcome(peer_id, outcome).await;
}
/// Convenience wrapper: record a failed connection attempt.
pub async fn record_failure(&self, peer_id: &PeerId) {
    let outcome = ConnectionOutcome {
        success: false,
        rtt_ms: None,
        capabilities_discovered: None,
    };
    self.record_outcome(peer_id, outcome).await;
}
/// Replace a cached peer's capabilities wholesale and recompute its
/// quality score. No-op when the peer is not cached.
pub async fn update_capabilities(&self, peer_id: &PeerId, caps: PeerCapabilities) {
    let mut guard = self.data.write().await;
    let Some(peer) = guard.peers.get_mut(&peer_id.0) else {
        return;
    };
    peer.capabilities = caps;
    peer.calculate_quality(&self.config.weights);
}
/// Record first-hand evidence that `peer_id` was directly reachable at `address`.
///
/// Inserts the peer (as a `Connection`-sourced entry) if unknown, appends the
/// address if new, bumps last-seen/last-attempt and the success counter,
/// records the direct observation on the capabilities, then refreshes TTLs
/// and quality. Emits `CacheEvent::Updated` after releasing the lock.
pub async fn observe_direct_reachability(&self, peer_id: PeerId, address: SocketAddr) {
let mut data = self.data.write().await;
let now = SystemTime::now();
let peer = data
.peers
.entry(peer_id.0)
.or_insert_with(|| CachedPeer::new(peer_id, vec![address], PeerSource::Connection));
// Avoid duplicate address entries.
if !peer.addresses.contains(&address) {
peer.addresses.push(address);
}
peer.last_seen = now;
peer.last_attempt = Some(now);
// Direct observation counts as a successful contact; saturating to
// guard against counter overflow.
peer.stats.success_count = peer.stats.success_count.saturating_add(1);
peer.capabilities.record_direct_observation(address, now);
self.refresh_cached_peer(peer, now);
let count = data.peers.len();
drop(data);
let _ = self
.event_tx
.send(CacheEvent::Updated { peer_count: count });
}
/// Fetch a peer by id, refreshing its TTL-based state and quality score
/// before cloning it out. Returns `None` when not cached.
pub async fn get(&self, peer_id: &PeerId) -> Option<CachedPeer> {
    let mut guard = self.data.write().await;
    let entry = guard.peers.get_mut(&peer_id.0)?;
    self.refresh_cached_peer(entry, SystemTime::now());
    Some(entry.clone())
}
/// Store (or replace) an opaque token for a cached peer; no-op when the
/// peer is unknown.
pub async fn update_token(&self, peer_id: PeerId, token: Vec<u8>) {
    let mut guard = self.data.write().await;
    if let Some(entry) = guard.peers.get_mut(&peer_id.0) {
        entry.token = Some(token);
    }
}
/// Collect every stored peer token into a map keyed by peer id; peers
/// without a token are omitted.
pub async fn get_all_tokens(&self) -> std::collections::HashMap<PeerId, Vec<u8>> {
    let guard = self.data.read().await;
    let mut tokens = std::collections::HashMap::new();
    for peer in guard.peers.values() {
        if let Some(token) = &peer.token {
            tokens.insert(peer.peer_id, token.clone());
        }
    }
    tokens
}
/// Whether a peer with this id is currently cached.
pub async fn contains(&self, peer_id: &PeerId) -> bool {
    let guard = self.data.read().await;
    guard.peers.contains_key(&peer_id.0)
}
/// Remove a peer from the cache, returning the removed entry if it existed.
pub async fn remove(&self, peer_id: &PeerId) -> Option<CachedPeer> {
    let mut guard = self.data.write().await;
    guard.peers.remove(&peer_id.0)
}
/// Persist the cache to disk via the persistence backend.
///
/// Skips the write entirely (still returning `Ok`) when fewer than
/// `min_peers_to_save` peers are cached, so a near-empty in-memory set
/// never clobbers a previously saved larger one. On success, updates
/// `last_save` and emits `CacheEvent::Saved`.
///
/// NOTE(review): the write lock is held across the disk write, so a slow
/// disk blocks all cache operations for the duration — confirm this is
/// acceptable or move serialization out of the critical section.
pub async fn save(&self) -> std::io::Result<()> {
let mut data = self.data.write().await;
if data.peers.len() < self.config.min_peers_to_save {
debug!(
"Skipping save: only {} peers (min: {})",
data.peers.len(),
self.config.min_peers_to_save
);
return Ok(());
}
// `save` takes `&mut CacheData` — presumably it updates persistence
// bookkeeping on the data as well; verify in the persistence module.
self.persistence.save(&mut data)?;
drop(data);
*self.last_save.write().await = Instant::now();
let _ = self.event_tx.send(CacheEvent::Saved);
Ok(())
}
/// Drop peers considered stale by the configured threshold; returns how
/// many were removed. Emits `CacheEvent::Cleaned` (and logs) only when
/// something was actually removed, and stamps `last_cleanup` either way.
pub async fn cleanup_stale(&self) -> usize {
    let threshold = self.config.stale_threshold;
    let mut guard = self.data.write().await;
    let before = guard.peers.len();
    guard.peers.retain(|_, peer| !peer.is_stale(threshold));
    let removed = before - guard.peers.len();
    if removed > 0 {
        info!("Cleaned up {} stale peers", removed);
        let _ = self.event_tx.send(CacheEvent::Cleaned { removed });
    }
    drop(guard);
    *self.last_cleanup.write().await = Instant::now();
    removed
}
/// Recompute every peer's quality score with the configured weights and
/// broadcast an `Updated` event with the current peer count.
pub async fn recalculate_quality(&self) {
    let mut guard = self.data.write().await;
    guard
        .peers
        .values_mut()
        .for_each(|peer| peer.calculate_quality(&self.config.weights));
    let peer_count = guard.peers.len();
    let _ = self.event_tx.send(CacheEvent::Updated { peer_count });
}
pub async fn stats(&self) -> CacheStats {
let mut data = self.data.write().await;
let now = SystemTime::now();
for peer in data.peers.values_mut() {
self.refresh_cached_peer(peer, now);
}
let relay_count = data
.peers
.values()
.filter(|p| p.capabilities.supports_relay)
.count();
let coord_count = data
.peers
.values()
.filter(|p| p.capabilities.supports_coordination)
.count();
let dual_stack_count = data
.peers
.values()
.filter(|p| p.capabilities.supports_relay && p.capabilities.supports_dual_stack())
.count();
let untested = data
.peers
.values()
.filter(|p| p.stats.success_count + p.stats.failure_count == 0)
.count();
let avg_quality = if data.peers.is_empty() {
0.0
} else {
data.peers.values().map(|p| p.quality_score).sum::<f64>() / data.peers.len() as f64
};
CacheStats {
total_peers: data.peers.len(),
relay_peers: relay_count,
coordinator_peers: coord_count,
dual_stack_relay_peers: dual_stack_count,
average_quality: avg_quality,
untested_peers: untested,
}
}
/// Spawn the background maintenance task: periodic save, stale-peer
/// cleanup, and quality recalculation on their configured intervals.
///
/// Runs until the returned `JoinHandle` is aborted (the loop itself
/// never exits). Save failures are logged, not propagated.
pub fn start_maintenance(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let mut save_timer = tokio::time::interval(self.config.save_interval);
        let mut cleanup_timer = tokio::time::interval(self.config.cleanup_interval);
        let mut quality_timer = tokio::time::interval(self.config.quality_update_interval);
        loop {
            tokio::select! {
                _ = save_timer.tick() => {
                    if let Err(e) = self.save().await {
                        warn!("Failed to save cache: {}", e);
                    }
                }
                _ = cleanup_timer.tick() => {
                    self.cleanup_stale().await;
                }
                _ = quality_timer.tick() => {
                    self.recalculate_quality().await;
                }
            }
        }
    })
}
/// Clone out every cached peer after refreshing TTL-based state and
/// quality scores.
pub async fn all_peers(&self) -> Vec<CachedPeer> {
    let refreshed_at = SystemTime::now();
    let mut guard = self.data.write().await;
    guard
        .peers
        .values_mut()
        .for_each(|p| self.refresh_cached_peer(p, refreshed_at));
    guard.peers.values().cloned().collect()
}
/// Read-only access to the configuration this cache was opened with.
pub fn config(&self) -> &BootstrapCacheConfig {
&self.config
}
/// Evict roughly 5% of capacity (at least one peer), lowest quality
/// first, to make room for a new entry.
///
/// The caller already holds the write lock and passes the borrowed
/// `CacheData` directly.
fn evict_lowest_quality(&self, data: &mut CacheData) {
    let evict_count = (self.config.max_peers / 20).max(1);
    let mut sorted: Vec<_> = data.peers.iter().collect();
    // `f64::total_cmp` gives a total order even for NaN quality scores,
    // unlike `partial_cmp(..).unwrap_or(Equal)`, which lets NaN entries
    // scramble the sort. Unstable sort: no allocation, stability irrelevant.
    sorted.sort_unstable_by(|a, b| a.1.quality_score.total_cmp(&b.1.quality_score));
    // Collect the ids first so the immutable borrow ends before removal.
    let to_remove: Vec<[u8; 32]> = sorted
        .into_iter()
        .take(evict_count)
        .map(|(id, _)| *id)
        .collect();
    let removed = to_remove.len();
    for id in to_remove {
        data.peers.remove(&id);
    }
    // Report what was actually removed, not the requested budget.
    debug!("Evicted {} lowest quality peers", removed);
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Build a small cache in a temp dir. epsilon(0.0) makes selection fully
// greedy (deterministic by quality); min_peers_to_save(1) lets tests
// persist even a single peer.
async fn create_test_cache(temp_dir: &TempDir) -> BootstrapCache {
let config = BootstrapCacheConfig::builder()
.cache_dir(temp_dir.path())
.max_peers(100)
.epsilon(0.0) .min_peers_to_save(1)
.build();
BootstrapCache::open(config).await.unwrap()
}
// A freshly opened cache starts empty.
#[tokio::test]
async fn test_cache_creation() {
let temp_dir = TempDir::new().unwrap();
let cache = create_test_cache(&temp_dir).await;
assert_eq!(cache.peer_count().await, 0);
}
// Seed insertion is visible through count/contains/get.
#[tokio::test]
async fn test_add_and_get() {
let temp_dir = TempDir::new().unwrap();
let cache = create_test_cache(&temp_dir).await;
let peer_id = PeerId([1u8; 32]);
cache
.add_seed(peer_id, vec!["127.0.0.1:9000".parse().unwrap()])
.await;
assert_eq!(cache.peer_count().await, 1);
assert!(cache.contains(&peer_id).await);
let peer = cache.get(&peer_id).await.unwrap();
assert_eq!(peer.addresses.len(), 1);
}
// With epsilon = 0 selection is greedy, so results come back ordered by
// descending quality score.
#[tokio::test]
async fn test_select_peers() {
let temp_dir = TempDir::new().unwrap();
let cache = create_test_cache(&temp_dir).await;
for i in 0..10usize {
let peer_id = PeerId([i as u8; 32]);
let mut peer = CachedPeer::new(
peer_id,
vec![format!("127.0.0.1:{}", 9000 + i).parse().unwrap()],
PeerSource::Seed,
);
peer.quality_score = i as f64 / 10.0;
cache.upsert(peer).await;
}
let selected = cache.select_peers(5).await;
assert_eq!(selected.len(), 5);
assert!(selected[0].quality_score >= selected[4].quality_score);
}
// Saved state survives drop and reopen of the cache in the same directory.
#[tokio::test]
async fn test_persistence() {
let temp_dir = TempDir::new().unwrap();
{
let cache = create_test_cache(&temp_dir).await;
cache
.add_seed(PeerId([1; 32]), vec!["127.0.0.1:9000".parse().unwrap()])
.await;
cache.save().await.unwrap();
}
{
let cache = create_test_cache(&temp_dir).await;
assert_eq!(cache.peer_count().await, 1);
assert!(cache.contains(&PeerId([1; 32])).await);
}
}
// Explicit relay/coordination hints recorded on capabilities must round-trip
// through save + reopen, including the derived supports_* flags.
#[tokio::test]
async fn test_persisted_explicit_assist_hints_survive_reopen() {
let temp_dir = TempDir::new().unwrap();
let peer_id = PeerId([9; 32]);
let peer_addr: SocketAddr = "198.51.100.9:9000".parse().unwrap();
{
let cache = create_test_cache(&temp_dir).await;
let mut peer = CachedPeer::new(peer_id, vec![peer_addr], PeerSource::Merge);
peer.capabilities.record_assist_hints(true, true);
cache.upsert(peer).await;
cache.save().await.unwrap();
}
{
let cache = create_test_cache(&temp_dir).await;
let peer = cache.get(&peer_id).await.expect("peer should reload");
assert!(peer.capabilities.hinted_supports_relay);
assert!(peer.capabilities.hinted_supports_coordination);
assert!(peer.capabilities.supports_relay);
assert!(peer.capabilities.supports_coordination);
assert!(peer.addresses.contains(&peer_addr));
}
}
// Repeated successes should raise the quality score and the success rate.
#[tokio::test]
async fn test_quality_scoring() {
let temp_dir = TempDir::new().unwrap();
let cache = create_test_cache(&temp_dir).await;
let peer_id = PeerId([1; 32]);
cache
.add_seed(peer_id, vec!["127.0.0.1:9000".parse().unwrap()])
.await;
let peer = cache.get(&peer_id).await.unwrap();
let initial_quality = peer.quality_score;
for _ in 0..5 {
cache.record_success(&peer_id, 50).await;
}
let peer = cache.get(&peer_id).await.unwrap();
assert!(peer.quality_score > initial_quality);
assert!(peer.success_rate() > 0.9);
}
// Inserting past max_peers triggers lowest-quality eviction, keeping the
// cache at or below capacity.
#[tokio::test]
async fn test_eviction() {
let temp_dir = TempDir::new().unwrap();
let config = BootstrapCacheConfig::builder()
.cache_dir(temp_dir.path())
.max_peers(10)
.build();
let cache = BootstrapCache::open(config).await.unwrap();
for i in 0..15u8 {
let peer_id = PeerId([i; 32]);
let mut peer = CachedPeer::new(
peer_id,
vec![format!("127.0.0.1:{}", 9000 + i as u16).parse().unwrap()],
PeerSource::Seed,
);
peer.quality_score = i as f64 / 15.0;
cache.upsert(peer).await;
}
assert!(cache.peer_count().await <= 10);
}
// Two peers with direct observations count as relay/coordinator capable;
// the plain seed does not. No peer has recorded outcomes, so all 3 are untested.
#[tokio::test]
async fn test_stats() {
let temp_dir = TempDir::new().unwrap();
let cache = create_test_cache(&temp_dir).await;
let mut peer1 = CachedPeer::new(
PeerId([1; 32]),
vec!["203.0.113.1:9001".parse().unwrap()],
PeerSource::Seed,
);
peer1
.capabilities
.record_direct_observation("203.0.113.1:9001".parse().unwrap(), SystemTime::now());
cache.upsert(peer1).await;
let mut peer2 = CachedPeer::new(
PeerId([2; 32]),
vec!["198.51.100.2:9002".parse().unwrap()],
PeerSource::Seed,
);
peer2
.capabilities
.record_direct_observation("198.51.100.2:9002".parse().unwrap(), SystemTime::now());
cache.upsert(peer2).await;
cache
.add_seed(PeerId([3; 32]), vec!["127.0.0.1:9003".parse().unwrap()])
.await;
let stats = cache.stats().await;
assert_eq!(stats.total_peers, 3);
assert_eq!(stats.relay_peers, 2);
assert_eq!(stats.coordinator_peers, 2);
assert_eq!(stats.untested_peers, 3);
}
// Relay selection ranks peers with direct-reachability evidence ahead of
// those without, even when the latter have higher quality scores.
#[tokio::test]
async fn test_select_relay_peers() {
let temp_dir = TempDir::new().unwrap();
let cache = create_test_cache(&temp_dir).await;
for i in 0..10u8 {
let addr: SocketAddr = format!("127.0.0.1:{}", 9000 + i as u16).parse().unwrap();
let mut peer = CachedPeer::new(PeerId([i; 32]), vec![addr], PeerSource::Seed);
if i % 2 == 0 {
peer.capabilities
.record_direct_observation(addr, SystemTime::now());
}
peer.quality_score = i as f64 / 10.0;
cache.upsert(peer).await;
}
let relays = cache.select_relay_peers(10).await;
assert_eq!(relays.len(), 10);
let relay_capable = relays
.iter()
.take(5)
.filter(|p| p.capabilities.direct_reachability_scope.is_some())
.count();
assert_eq!(
relay_capable, 5,
"Scoped direct-evidence peers should be first"
);
}
// A private-network observation (192.168.x.x) records LocalNetwork scope and
// success evidence, but must not promote the peer to relay/coordinator.
#[tokio::test]
async fn test_observe_direct_reachability_preserves_local_scope_without_global_promotion() {
let temp_dir = TempDir::new().unwrap();
let cache = create_test_cache(&temp_dir).await;
let peer_id = PeerId([9; 32]);
let addr: SocketAddr = "192.168.1.50:9000".parse().unwrap();
cache.observe_direct_reachability(peer_id, addr).await;
let peer = cache.get(&peer_id).await.expect("peer inserted");
assert!(!peer.capabilities.supports_relay);
assert!(!peer.capabilities.supports_coordination);
assert_eq!(
peer.capabilities.direct_reachability_scope,
Some(crate::reachability::ReachabilityScope::LocalNetwork)
);
assert!(peer.addresses.contains(&addr));
assert!(
peer.capabilities
.reachable_addresses
.iter()
.any(|entry| entry.address == addr)
);
assert!(peer.success_rate() > 0.0);
}
}