use crate::error::BootstrapError;
use crate::network::DHTConfig;
use crate::rate_limit::{JoinRateLimiter, JoinRateLimiterConfig};
use crate::security::{BootstrapIpLimiter, IPDiversityConfig};
use crate::{P2PError, Result};
use parking_lot::Mutex;
use saorsa_transport::bootstrap_cache::{
BootstrapCache as AntBootstrapCache, BootstrapCacheConfig, CachedPeer, PeerCapabilities,
};
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::task::JoinHandle;
use tracing::{info, warn};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BootstrapConfig {
pub cache_dir: PathBuf,
pub max_peers: usize,
pub epsilon: f64,
pub rate_limit: JoinRateLimiterConfig,
pub diversity: IPDiversityConfig,
}
impl Default for BootstrapConfig {
fn default() -> Self {
Self {
cache_dir: default_cache_dir(),
max_peers: 20_000,
epsilon: 0.1,
rate_limit: JoinRateLimiterConfig::default(),
diversity: IPDiversityConfig::default(),
}
}
}
/// Front end for the saorsa-transport bootstrap cache that adds join admission control.
///
/// Peers added through `add_peer` must pass both the join rate limiter and the IP
/// diversity limiter; `add_peer_trusted` and every other operation delegate straight to
/// the shared cache.
pub struct BootstrapManager {
    /// Peer cache; held in an `Arc` so a clone can be handed to the maintenance task.
    cache: Arc<AntBootstrapCache>,
    /// Join rate limiter, checked against the joining peer's IP in `add_peer`.
    rate_limiter: JoinRateLimiter,
    /// IP diversity tracker; locked only in `add_peer`, where the guard is released
    /// before any `.await`.
    ip_limiter: Mutex<BootstrapIpLimiter>,
    /// Diversity configuration the manager was built with, exposed via `diversity_config()`.
    diversity_config: IPDiversityConfig,
    /// Handle of the background maintenance task; `Some` once `start_maintenance` ran.
    maintenance_handle: Option<JoinHandle<()>>,
}
impl BootstrapManager {
    /// Shared constructor behind every public `with_*` entry point.
    ///
    /// Opens the bootstrap cache described by `config`, then builds the join rate limiter
    /// and the IP diversity limiter. No maintenance task is started; call
    /// [`BootstrapManager::start_maintenance`] for that.
    ///
    /// `allow_loopback` and `k_value` are forwarded unchanged to
    /// [`BootstrapIpLimiter::with_loopback_and_k`].
    ///
    /// # Errors
    ///
    /// Returns `P2PError::Bootstrap(BootstrapError::CacheError(_))` if the cache cannot be
    /// opened.
    async fn with_config_loopback_and_k(
        config: BootstrapConfig,
        allow_loopback: bool,
        k_value: usize,
    ) -> Result<Self> {
        let ant_config = BootstrapCacheConfig::builder()
            .cache_dir(&config.cache_dir)
            .max_peers(config.max_peers)
            .epsilon(config.epsilon)
            .build();
        let cache = AntBootstrapCache::open(ant_config).await.map_err(|e| {
            P2PError::Bootstrap(BootstrapError::CacheError(
                format!("Failed to open bootstrap cache: {e}").into(),
            ))
        })?;
        Ok(Self {
            cache: Arc::new(cache),
            rate_limiter: JoinRateLimiter::new(config.rate_limit),
            ip_limiter: Mutex::new(BootstrapIpLimiter::with_loopback_and_k(
                config.diversity.clone(),
                allow_loopback,
                k_value,
            )),
            diversity_config: config.diversity,
            maintenance_handle: None,
        })
    }

    /// Creates a manager with [`BootstrapConfig::default()`].
    ///
    /// # Errors
    ///
    /// Fails if the bootstrap cache cannot be opened.
    pub async fn new() -> Result<Self> {
        Self::with_config(BootstrapConfig::default()).await
    }

    /// Creates a manager from `config`, with loopback disallowed in the IP limiter and
    /// the default DHT `k` value.
    ///
    /// # Errors
    ///
    /// Fails if the bootstrap cache cannot be opened.
    pub async fn with_config(config: BootstrapConfig) -> Result<Self> {
        Self::with_config_loopback_and_k(config, false, DHTConfig::DEFAULT_K_VALUE).await
    }

    /// Creates a manager whose IP limiter follows `node_config`.
    ///
    /// If the node config carries a diversity config, it replaces `config.diversity`.
    /// `allow_loopback` and the DHT `k_value` are taken from `node_config`.
    ///
    /// # Errors
    ///
    /// Fails if the bootstrap cache cannot be opened.
    pub async fn with_node_config(
        mut config: BootstrapConfig,
        node_config: &crate::network::NodeConfig,
    ) -> Result<Self> {
        if let Some(ref diversity) = node_config.diversity_config {
            config.diversity = diversity.clone();
        }
        Self::with_config_loopback_and_k(
            config,
            node_config.allow_loopback,
            node_config.dht_config.k_value,
        )
        .await
    }

    /// Starts the cache's background maintenance task.
    ///
    /// Calling this again while a handle is already stored is a no-op. The task is
    /// aborted when the manager is dropped. Currently always returns `Ok(())`.
    pub fn start_maintenance(&mut self) -> Result<()> {
        if self.maintenance_handle.is_some() {
            return Ok(());
        }
        let handle = self.cache.clone().start_maintenance();
        self.maintenance_handle = Some(handle);
        info!("Started bootstrap cache maintenance tasks");
        Ok(())
    }

    /// Adds an untrusted peer after admission checks.
    ///
    /// Only `addr.ip()` is checked, first by the join rate limiter and then by the IP
    /// diversity limiter; `addresses` are stored as given.
    ///
    /// NOTE(review): the rate limiter is consulted before the diversity check, so a join
    /// rejected for diversity may still count against the rate limit — confirm intended.
    ///
    /// # Errors
    ///
    /// * `BootstrapError::InvalidData` if `addresses` is empty.
    /// * `BootstrapError::RateLimited` if either the rate limiter or the diversity
    ///   limiter rejects the IP.
    pub async fn add_peer(&self, addr: &SocketAddr, addresses: Vec<SocketAddr>) -> Result<()> {
        if addresses.is_empty() {
            return Err(P2PError::Bootstrap(BootstrapError::InvalidData(
                "No addresses provided".to_string().into(),
            )));
        }
        let ip = addr.ip();
        self.rate_limiter.check_join_allowed(&ip).map_err(|e| {
            warn!("Rate limit exceeded for {}: {}", ip, e);
            P2PError::Bootstrap(BootstrapError::RateLimited(e.to_string().into()))
        })?;
        // Scoped so the synchronous lock is released before the `.await` below.
        {
            let mut diversity = self.ip_limiter.lock();
            if !diversity.can_accept(ip) {
                warn!("IP diversity limit exceeded for {}", ip);
                return Err(P2PError::Bootstrap(BootstrapError::RateLimited(
                    "IP diversity limits exceeded".to_string().into(),
                )));
            }
            // Tracking failures are logged, not fatal: the peer is still admitted.
            if let Err(e) = diversity.track(ip) {
                warn!("Failed to track IP diversity for {}: {}", ip, e);
            }
        }
        self.cache.add_seed(*addr, addresses).await;
        Ok(())
    }

    /// Adds a peer without any rate-limit or diversity checks (for known-good seeds).
    pub async fn add_peer_trusted(&self, addr: &SocketAddr, addresses: Vec<SocketAddr>) {
        self.cache.add_seed(*addr, addresses).await;
    }

    /// Records a successful contact with `addr`, with its round-trip time in milliseconds.
    pub async fn record_success(&self, addr: &SocketAddr, rtt_ms: u32) {
        self.cache.record_success(addr, rtt_ms).await;
    }

    /// Records a failed contact attempt with `addr`.
    pub async fn record_failure(&self, addr: &SocketAddr) {
        self.cache.record_failure(addr).await;
    }

    /// Asks the cache to select `count` peers using its selection policy.
    pub async fn select_peers(&self, count: usize) -> Vec<CachedPeer> {
        self.cache.select_peers(count).await
    }

    /// Asks the cache to select `count` relay peers.
    pub async fn select_relay_peers(&self, count: usize) -> Vec<CachedPeer> {
        self.cache.select_relay_peers(count).await
    }

    /// Asks the cache to select `count` coordinator peers.
    pub async fn select_coordinators(&self, count: usize) -> Vec<CachedPeer> {
        self.cache.select_coordinators(count).await
    }

    /// Returns a snapshot of the cache statistics as a [`BootstrapStats`].
    pub async fn stats(&self) -> BootstrapStats {
        let ant_stats = self.cache.stats().await;
        BootstrapStats {
            total_peers: ant_stats.total_peers,
            relay_peers: ant_stats.relay_peers,
            coordinator_peers: ant_stats.coordinator_peers,
            average_quality: ant_stats.average_quality,
            untested_peers: ant_stats.untested_peers,
        }
    }

    /// Number of peers currently in the cache.
    pub async fn peer_count(&self) -> usize {
        self.cache.peer_count().await
    }

    /// Persists the cache to its configured directory.
    ///
    /// # Errors
    ///
    /// Returns `P2PError::Bootstrap(BootstrapError::CacheError(_))` if saving fails.
    pub async fn save(&self) -> Result<()> {
        self.cache.save().await.map_err(|e| {
            P2PError::Bootstrap(BootstrapError::CacheError(
                format!("Failed to save cache: {e}").into(),
            ))
        })
    }

    /// Replaces the recorded capabilities of `addr`.
    pub async fn update_capabilities(&self, addr: &SocketAddr, capabilities: PeerCapabilities) {
        self.cache.update_capabilities(addr, capabilities).await;
    }

    /// Returns `true` if `addr` is present in the cache.
    pub async fn contains(&self, addr: &SocketAddr) -> bool {
        self.cache.contains(addr).await
    }

    /// Returns the cached entry for `addr`, if any.
    pub async fn get_peer(&self, addr: &SocketAddr) -> Option<CachedPeer> {
        self.cache.get(addr).await
    }

    /// The diversity configuration this manager was built with.
    pub fn diversity_config(&self) -> &IPDiversityConfig {
        &self.diversity_config
    }
}

impl Drop for BootstrapManager {
    /// Stops the background maintenance task, if one was started.
    ///
    /// Dropping a tokio `JoinHandle` only detaches its task. The task was started on its
    /// own `Arc` clone of the cache, so without an explicit abort it would keep running
    /// after the manager is gone, and it would presumably keep the cache alive too.
    fn drop(&mut self) {
        if let Some(handle) = self.maintenance_handle.take() {
            handle.abort();
        }
    }
}
/// Aggregate statistics about the bootstrap cache, as returned by
/// `BootstrapManager::stats`.
#[derive(Clone, Debug, Default)]
pub struct BootstrapStats {
    /// Number of peers currently held in the cache.
    pub total_peers: usize,
    /// Number of cached peers the cache reports as relays.
    pub relay_peers: usize,
    /// Number of cached peers the cache reports as coordinators.
    pub coordinator_peers: usize,
    /// Average peer quality score as reported by the cache.
    pub average_quality: f64,
    /// Number of cached peers the cache reports as not yet tested.
    pub untested_peers: usize,
}
/// Picks the default on-disk location for the bootstrap cache.
///
/// Preference order:
/// 1. `<platform cache dir>/saorsa/bootstrap`.
/// 2. `<home>/.cache/saorsa/bootstrap`.
/// 3. `.saorsa-bootstrap-cache`, relative to the current working directory, when neither
///    directory can be determined.
fn default_cache_dir() -> PathBuf {
    // Only consulted when the platform cache directory is unavailable.
    let under_home = || {
        dirs::home_dir().map(|home| home.join(".cache").join("saorsa").join("bootstrap"))
    };
    dirs::cache_dir()
        .map(|base| base.join("saorsa").join("bootstrap"))
        .or_else(under_home)
        .unwrap_or_else(|| PathBuf::from(".saorsa-bootstrap-cache"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a small config rooted in `temp_dir` so tests never touch the user's real
    /// cache directory.
    fn test_config(temp_dir: &TempDir) -> BootstrapConfig {
        BootstrapConfig {
            cache_dir: temp_dir.path().to_path_buf(),
            max_peers: 100,
            // NOTE(review): epsilon 0.0 presumably disables randomized exploration so
            // selection is deterministic — confirm against saorsa_transport.
            epsilon: 0.0, rate_limit: JoinRateLimiterConfig::default(),
            diversity: IPDiversityConfig::default(),
        }
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await;
        assert!(manager.is_ok());
        let manager = manager.unwrap();
        // A cache opened in an empty temp dir starts with no peers.
        assert_eq!(manager.peer_count().await, 0);
    }

    #[tokio::test]
    async fn test_add_and_get_peer() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await.unwrap();
        // Non-loopback address so the untrusted `add_peer` path is exercised end to end.
        let addr: SocketAddr = "10.0.0.1:9000".parse().unwrap();
        let result = manager.add_peer(&addr, vec![addr]).await;
        assert!(result.is_ok());
        assert_eq!(manager.peer_count().await, 1);
        assert!(manager.contains(&addr).await);
    }

    #[tokio::test]
    async fn test_add_peer_no_addresses_fails() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await.unwrap();
        let addr: SocketAddr = "10.0.0.1:9000".parse().unwrap();
        // An empty address list is rejected before any limiter is consulted.
        let result = manager.add_peer(&addr, vec![]).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            P2PError::Bootstrap(BootstrapError::InvalidData(_))
        ));
    }

    #[tokio::test]
    async fn test_add_trusted_peer_bypasses_checks() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await.unwrap();
        // Loopback is used on purpose: `with_config` passes allow_loopback = false, so
        // this address would presumably be rejected by the untrusted `add_peer` path.
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        manager.add_peer_trusted(&addr, vec![addr]).await;
        assert_eq!(manager.peer_count().await, 1);
        assert!(manager.contains(&addr).await);
    }

    #[tokio::test]
    async fn test_record_success_updates_quality() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await.unwrap();
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        manager.add_peer_trusted(&addr, vec![addr]).await;
        let initial_peer = manager.get_peer(&addr).await.unwrap();
        let initial_quality = initial_peer.quality_score;
        for _ in 0..5 {
            manager.record_success(&addr, 50).await;
        }
        let updated_peer = manager.get_peer(&addr).await.unwrap();
        // NOTE(review): `>=` also passes when quality is unchanged; the assertion is
        // weaker than its message suggests.
        assert!(
            updated_peer.quality_score >= initial_quality,
            "Quality should improve after successes"
        );
    }

    #[tokio::test]
    async fn test_record_failure_decreases_quality() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await.unwrap();
        let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
        manager.add_peer_trusted(&addr, vec![addr]).await;
        // Raise quality first so the subsequent drop is observable.
        for _ in 0..3 {
            manager.record_success(&addr, 50).await;
        }
        let good_peer = manager.get_peer(&addr).await.unwrap();
        let good_quality = good_peer.quality_score;
        for _ in 0..5 {
            manager.record_failure(&addr).await;
        }
        let bad_peer = manager.get_peer(&addr).await.unwrap();
        assert!(
            bad_peer.quality_score < good_quality,
            "Quality should decrease after failures"
        );
    }

    #[tokio::test]
    async fn test_select_peers_returns_best() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await.unwrap();
        // Peer `i` receives `i` successes, giving ten peers of increasing quality.
        for i in 0..10 {
            let addr: SocketAddr = format!("127.0.0.1:{}", 9000 + i).parse().unwrap();
            manager.add_peer_trusted(&addr, vec![addr]).await;
            for _ in 0..i {
                manager.record_success(&addr, 50).await;
            }
        }
        let selected = manager.select_peers(5).await;
        assert_eq!(selected.len(), 5);
        // Adjacent pairs must be non-increasing in quality.
        for i in 0..4 {
            assert!(
                selected[i].quality_score >= selected[i + 1].quality_score,
                "Peers should be sorted by quality"
            );
        }
    }

    #[tokio::test]
    async fn test_stats() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(&temp_dir);
        let manager = BootstrapManager::with_config(config).await.unwrap();
        for i in 0..5 {
            let addr: SocketAddr = format!("127.0.0.1:{}", 9000 + i).parse().unwrap();
            manager.add_peer_trusted(&addr, vec![addr]).await;
        }
        let stats = manager.stats().await;
        assert_eq!(stats.total_peers, 5);
        // No success/failure has been recorded, so every peer counts as untested.
        assert_eq!(stats.untested_peers, 5); }

    #[tokio::test]
    async fn test_persistence() {
        let temp_dir = TempDir::new().unwrap();
        let cache_path = temp_dir.path().to_path_buf();
        // First manager: add one peer and save it to `cache_path`.
        {
            let config = BootstrapConfig {
                cache_dir: cache_path.clone(),
                max_peers: 100,
                epsilon: 0.0,
                rate_limit: JoinRateLimiterConfig::default(),
                diversity: IPDiversityConfig::default(),
            };
            let manager = BootstrapManager::with_config(config).await.unwrap();
            let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap();
            manager.add_peer_trusted(&addr, vec![addr]).await;
            let count_before = manager.peer_count().await;
            assert_eq!(count_before, 1, "Peer should be in cache before save");
            manager.save().await.unwrap();
            // NOTE(review): presumably gives any background write time to settle —
            // confirm whether this sleep is actually needed.
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        }
        // Second manager: reopen the same directory.
        {
            let config = BootstrapConfig {
                cache_dir: cache_path,
                max_peers: 100,
                epsilon: 0.0,
                rate_limit: JoinRateLimiterConfig::default(),
                diversity: IPDiversityConfig::default(),
            };
            let manager = BootstrapManager::with_config(config).await.unwrap();
            let count = manager.peer_count().await;
            // NOTE(review): the reloaded count is not asserted; this block only checks
            // that reopening succeeds and prints a note when nothing was restored.
            if count == 0 {
                eprintln!(
                    "Note: saorsa-transport BootstrapCache may have different persistence behavior"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let temp_dir = TempDir::new().unwrap();
        // usize::MAX limits effectively disable the diversity check, so only the rate
        // limiter can reject a join here.
        let diversity_config = IPDiversityConfig {
            max_per_ip: Some(usize::MAX),
            max_per_subnet: Some(usize::MAX),
        };
        let config = BootstrapConfig {
            cache_dir: temp_dir.path().to_path_buf(),
            max_peers: 100,
            epsilon: 0.0,
            rate_limit: JoinRateLimiterConfig {
                // Only the per-/24 limit (2) is tight; all other limits are set high.
                max_joins_per_64_per_hour: 100, max_joins_per_48_per_hour: 100, max_joins_per_24_per_hour: 2, max_global_joins_per_minute: 100,
                global_burst_size: 10,
            },
            diversity: diversity_config,
        };
        let manager = BootstrapManager::with_config(config).await.unwrap();
        // 192.168.1.10 and 192.168.1.11 use up the two joins allowed for 192.168.1.0/24.
        for i in 0..2 {
            let addr: SocketAddr = format!("192.168.1.{}:{}", 10 + i, 9000 + i)
                .parse()
                .unwrap();
            let result = manager.add_peer(&addr, vec![addr]).await;
            assert!(
                result.is_ok(),
                "First 2 peers should be allowed: {:?}",
                result
            );
        }
        // A third address in the same /24 must be rejected.
        let addr: SocketAddr = "192.168.1.100:9100".parse().unwrap();
        let result = manager.add_peer(&addr, vec![addr]).await;
        assert!(result.is_err(), "Third peer should be rate limited");
        assert!(matches!(
            result.unwrap_err(),
            P2PError::Bootstrap(BootstrapError::RateLimited(_))
        ));
    }
}