use crate::error::{NetworkError, NetworkResult};
use crate::network::NetworkNode;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::time::sleep;
/// Retry and backoff policy used by [`BootstrapConnector`].
#[derive(Clone, Debug)]
pub struct BootstrapConfig {
    /// Number of failed attempts after which a connection is abandoned.
    ///
    /// This is effectively the total number of attempts per address; at least
    /// one attempt is always made, even when this is `0`.
    pub max_retries: u32,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
    /// Starting delay between attempts.
    pub initial_backoff: Duration,
    /// Upper bound applied to the delay as it grows.
    pub max_backoff: Duration,
}

impl Default for BootstrapConfig {
    /// Three attempts, with the delay doubling from 100 ms up to at most 5 s.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(100);
        let max_backoff = Duration::from_millis(5_000);
        BootstrapConfig {
            max_retries: 3,
            backoff_multiplier: 2.0,
            initial_backoff,
            max_backoff,
        }
    }
}
/// Establishes bootstrap connections using the retry policy in a [`BootstrapConfig`].
#[derive(Debug, Clone)]
pub struct BootstrapConnector {
    /// Retry/backoff policy applied to every connection attempt.
    config: BootstrapConfig,
}
impl BootstrapConnector {
    /// Creates a connector that uses `BootstrapConfig::default()`.
    pub fn new() -> Self {
        Self {
            config: BootstrapConfig::default(),
        }
    }

    /// Creates a connector that uses the supplied retry/backoff configuration.
    pub fn with_config(config: BootstrapConfig) -> Self {
        Self { config }
    }

    /// Connects `node` to `addr`, retrying with exponential backoff on failure.
    ///
    /// `config.max_retries` is the total number of connection attempts. At least
    /// one attempt is always made, even when it is `0`. Between attempts the task
    /// sleeps for the current backoff delay:
    /// - The delay starts at `initial_backoff`, capped at `max_backoff`.
    /// - After each failure it is multiplied by `backoff_multiplier`.
    /// - It never exceeds `max_backoff`.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConnectionFailed` once every attempt has failed.
    /// The message includes the attempt count and the last underlying error.
    pub async fn connect_with_retry(
        &self,
        node: &NetworkNode,
        addr: SocketAddr,
    ) -> NetworkResult<()> {
        // Cap the first delay as well, so `max_backoff` really bounds every sleep.
        let mut backoff = self.config.initial_backoff.min(self.config.max_backoff);
        let mut attempt: u32 = 0;
        loop {
            match node.connect_addr(addr).await {
                Ok(_peer_id) => return Ok(()),
                Err(e) => {
                    attempt += 1;
                    if attempt >= self.config.max_retries {
                        return Err(NetworkError::ConnectionFailed(format!(
                            "Bootstrap connection failed after {} attempts: {}",
                            attempt, e
                        )));
                    }
                    sleep(backoff).await;
                    backoff = self.next_backoff(backoff);
                }
            }
        }
    }

    /// Computes the delay that follows `current`: `current * backoff_multiplier`,
    /// capped at `max_backoff`.
    ///
    /// `Duration::from_secs_f64` panics on negative, non-finite or overflowing
    /// input. So a misconfigured multiplier (negative, NaN, or large enough to
    /// overflow) falls back to `max_backoff` instead of panicking mid-retry.
    fn next_backoff(&self, current: Duration) -> Duration {
        let max = self.config.max_backoff;
        let scaled = current.as_secs_f64() * self.config.backoff_multiplier;
        if scaled.is_finite() && scaled >= 0.0 && scaled < max.as_secs_f64() {
            // `min` guards against f64 -> Duration rounding landing just above `max`.
            Duration::from_secs_f64(scaled).min(max)
        } else {
            max
        }
    }

    /// Tries to connect `node` to every address in `addrs` concurrently and
    /// returns how many connections succeeded.
    ///
    /// Each address gets its own `tokio::spawn`ed task running
    /// [`Self::connect_with_retry`] with a copy of this connector's config. A task
    /// that panics is counted as a failure. Spawned tasks are detached: if the
    /// returned future is dropped before completion, the in-flight attempts keep
    /// running in the background. An empty `addrs` returns `0`.
    pub async fn connect_multiple(&self, node: &NetworkNode, addrs: &[SocketAddr]) -> usize {
        let handles: Vec<_> = addrs
            .iter()
            .map(|&addr| {
                let node_clone = node.clone();
                let config = self.config.clone();
                tokio::spawn(async move {
                    let connector = BootstrapConnector::with_config(config);
                    connector
                        .connect_with_retry(&node_clone, addr)
                        .await
                        .is_ok()
                })
            })
            .collect();
        futures::future::join_all(handles)
            .await
            .into_iter()
            .filter(|r| matches!(r, Ok(true)))
            .count()
    }
}
impl Default for BootstrapConnector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One backoff step: `current * multiplier`, capped at `max_backoff`.
    /// Mirrors the arithmetic used between retries for well-formed configs.
    fn step(current: Duration, config: &BootstrapConfig) -> Duration {
        std::cmp::min(
            Duration::from_secs_f64(current.as_secs_f64() * config.backoff_multiplier),
            config.max_backoff,
        )
    }

    #[test]
    fn test_bootstrap_config_default() {
        let config = BootstrapConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.backoff_multiplier, 2.0);
        assert_eq!(config.initial_backoff, Duration::from_millis(100));
        assert_eq!(config.max_backoff, Duration::from_secs(5));
    }

    #[test]
    fn test_bootstrap_config_custom() {
        let config = BootstrapConfig {
            max_retries: 5,
            backoff_multiplier: 1.5,
            initial_backoff: Duration::from_millis(50),
            max_backoff: Duration::from_secs(10),
        };
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.backoff_multiplier, 1.5);
        assert_eq!(config.initial_backoff, Duration::from_millis(50));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
    }

    #[test]
    fn test_bootstrap_connector_new() {
        let connector = BootstrapConnector::new();
        assert_eq!(connector.config.max_retries, 3);
    }

    #[test]
    fn test_bootstrap_connector_with_config() {
        let config = BootstrapConfig {
            max_retries: 2,
            backoff_multiplier: 2.0,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
        };
        let connector = BootstrapConnector::with_config(config.clone());
        assert_eq!(connector.config.max_retries, config.max_retries);
        assert_eq!(connector.config.backoff_multiplier, config.backoff_multiplier);
        assert_eq!(connector.config.initial_backoff, config.initial_backoff);
        assert_eq!(connector.config.max_backoff, config.max_backoff);
    }

    #[test]
    fn test_bootstrap_connector_default() {
        let connector = BootstrapConnector::default();
        assert_eq!(connector.config.max_retries, 3);
    }

    #[test]
    fn test_exponential_backoff_calculation() {
        let config = BootstrapConfig {
            max_retries: 3,
            backoff_multiplier: 2.0,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
        };
        let mut backoff = config.initial_backoff;
        assert_eq!(backoff, Duration::from_millis(100));
        for expected_ms in [200, 400, 800] {
            backoff = step(backoff, &config);
            assert_eq!(backoff, Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn test_max_backoff_clamping() {
        let config = BootstrapConfig {
            max_retries: 5,
            backoff_multiplier: 2.0,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(5),
        };
        // 1s -> 2s -> 4s -> (8s capped) 5s -> 5s -> 5s
        let expected_secs = [2, 4, 5, 5, 5];
        let mut backoff = config.initial_backoff;
        for secs in expected_secs {
            backoff = step(backoff, &config);
            assert_eq!(backoff, Duration::from_secs(secs));
        }
        assert_eq!(backoff, config.max_backoff);
    }
}