use std::net::SocketAddr;
use std::time::Duration;
use tracing::{debug, info};
use crate::rpc_codec::{PingRequest, RaftRpc};
use crate::transport::NexarTransport;
use super::config::ClusterConfig;
/// Maximum number of `Ping` probes `should_bootstrap` sends to the designated
/// bootstrapper before it stops waiting and returns.
const MAX_PROBE_ATTEMPTS: u32 = 10;
/// Pause between consecutive probes (skipped after the final attempt).
const PROBE_INTERVAL: Duration = Duration::from_millis(300);
/// Upper bound on each individual probe, enforced via `tokio::time::timeout`.
/// With the current values, a node deferring to an unresponsive bootstrapper
/// waits roughly 10 × 200 ms + 9 × 300 ms ≈ 4.7 s in total.
const PROBE_TIMEOUT: Duration = Duration::from_millis(200);
/// Decides whether this node should bootstrap a brand-new cluster.
///
/// Returns `true` when any of the following holds:
/// - `config.force_bootstrap` is set;
/// - `config.seed_nodes` is empty, so there is nobody to defer to;
/// - `config.listen_addr` equals the designated bootstrapper, i.e. the
///   smallest seed address as chosen by [`designated_bootstrapper`].
///
/// Otherwise the node defers and returns `false`. Before returning, it pings
/// the designated bootstrapper up to [`MAX_PROBE_ATTEMPTS`] times. Each ping
/// is bounded by [`PROBE_TIMEOUT`], and attempts are spaced by
/// [`PROBE_INTERVAL`].
///
/// The probe outcome only affects how long this function waits and what it
/// logs. The result is `false` whether or not the designated node answered.
pub(super) async fn should_bootstrap(config: &ClusterConfig, transport: &NexarTransport) -> bool {
    if config.force_bootstrap {
        info!(
            node_id = config.node_id,
            listen_addr = %config.listen_addr,
            "force_bootstrap flag set — bootstrapping unconditionally"
        );
        return true;
    }

    let designated = match designated_bootstrapper(&config.seed_nodes) {
        Some(addr) => addr,
        None => {
            // No seeds: this node cannot defer to anyone, so it bootstraps.
            // Logged like every other bootstrap decision, because an empty
            // seed list is an easy misconfiguration to miss otherwise.
            info!(
                node_id = config.node_id,
                listen_addr = %config.listen_addr,
                "no seed nodes configured — bootstrapping"
            );
            return true;
        }
    };

    // Exact `SocketAddr` equality: `listen_addr` must be spelled exactly as it
    // appears in `seed_nodes`.
    // NOTE(review): a wildcard listen address (e.g. 0.0.0.0) never matches a
    // concrete seed entry — confirm that configs always use concrete addrs.
    if designated == config.listen_addr {
        info!(
            node_id = config.node_id,
            listen_addr = %config.listen_addr,
            "this node is the designated bootstrapper"
        );
        return true;
    }

    info!(
        node_id = config.node_id,
        listen_addr = %config.listen_addr,
        %designated,
        "deferring to designated bootstrapper; probing for liveness"
    );

    for attempt in 0..MAX_PROBE_ATTEMPTS {
        let probe_result = tokio::time::timeout(
            PROBE_TIMEOUT,
            ping_probe(designated, transport, config.node_id),
        )
        .await;

        match probe_result {
            Ok(Ok(())) => {
                info!(
                    node_id = config.node_id,
                    %designated,
                    attempt,
                    "designated bootstrapper is up"
                );
                return false;
            }
            Ok(Err(e)) => {
                debug!(
                    node_id = config.node_id,
                    %designated,
                    attempt,
                    error = %e,
                    "ping probe failed"
                );
            }
            Err(_elapsed) => {
                debug!(
                    node_id = config.node_id,
                    %designated,
                    attempt,
                    timeout_ms = PROBE_TIMEOUT.as_millis() as u64,
                    "ping probe timed out"
                );
            }
        }

        // No point sleeping after the final attempt.
        if attempt + 1 < MAX_PROBE_ATTEMPTS {
            tokio::time::sleep(PROBE_INTERVAL).await;
        }
    }

    info!(
        node_id = config.node_id,
        %designated,
        "designated bootstrapper did not respond; proceeding to join loop"
    );
    false
}
/// Selects the seed that acts as the designated bootstrapper.
///
/// The choice is the minimum address under `SocketAddr`'s total order. For
/// addresses of the same family, the IP is compared first and then the port.
/// The minimum does not depend on list order, so every node configured with
/// the same seed set picks the same bootstrapper.
///
/// Returns `None` when `seed_nodes` is empty.
pub(super) fn designated_bootstrapper(seed_nodes: &[SocketAddr]) -> Option<SocketAddr> {
    seed_nodes.iter().copied().min()
}
async fn ping_probe(
addr: SocketAddr,
transport: &NexarTransport,
self_node_id: u64,
) -> crate::error::Result<()> {
let rpc = RaftRpc::Ping(PingRequest {
sender_id: self_node_id,
topology_version: 0,
});
match transport.send_rpc_to_addr(addr, rpc).await {
Ok(RaftRpc::Pong(_)) => Ok(()),
Ok(other) => Err(crate::error::ClusterError::Transport {
detail: format!("unexpected response to Ping from {addr}: {other:?}"),
}),
Err(e) => Err(e),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Parses a socket-address literal; panics on malformed test input.
    fn addr(s: &str) -> SocketAddr {
        s.parse().unwrap()
    }

    /// Builds a `ClusterConfig` for `node_id` listening on `listen` with the
    /// given seed list. Every other field gets a fixed test value, and
    /// `force_bootstrap` is off.
    fn cfg_with_seeds(node_id: u64, listen: &str, seeds: &[&str]) -> ClusterConfig {
        ClusterConfig {
            node_id,
            listen_addr: addr(listen),
            seed_nodes: seeds.iter().map(|s| addr(s)).collect(),
            num_groups: 2,
            replication_factor: 1,
            data_dir: std::env::temp_dir(),
            force_bootstrap: false,
            join_retry: Default::default(),
        }
    }

    #[test]
    fn designated_bootstrapper_picks_smallest() {
        let seeds = vec![
            addr("10.0.0.3:9400"),
            addr("10.0.0.1:9400"),
            addr("10.0.0.2:9400"),
        ];
        assert_eq!(designated_bootstrapper(&seeds), Some(addr("10.0.0.1:9400")));
    }

    #[test]
    fn designated_bootstrapper_empty_is_none() {
        assert!(designated_bootstrapper(&[]).is_none());
    }

    #[test]
    fn designated_bootstrapper_tie_break_by_port() {
        let seeds = vec![addr("10.0.0.1:9401"), addr("10.0.0.1:9400")];
        assert_eq!(designated_bootstrapper(&seeds), Some(addr("10.0.0.1:9400")));
    }

    #[tokio::test]
    async fn should_bootstrap_when_self_is_lowest_seed() {
        let cfg = cfg_with_seeds(
            1,
            "10.0.0.1:9400",
            &["10.0.0.1:9400", "10.0.0.2:9400", "10.0.0.3:9400"],
        );
        let transport = Arc::new(NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap());
        assert!(should_bootstrap(&cfg, &transport).await);
    }

    /// With no seeds there is nobody to defer to, so the node must bootstrap
    /// without probing anything.
    #[tokio::test]
    async fn should_bootstrap_when_no_seeds() {
        let cfg = cfg_with_seeds(1, "10.0.0.1:9400", &[]);
        let transport = Arc::new(NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap());
        assert!(should_bootstrap(&cfg, &transport).await);
    }

    #[tokio::test]
    async fn force_bootstrap_overrides_rule() {
        let mut cfg = cfg_with_seeds(
            3,
            "10.0.0.3:9400",
            &["10.0.0.1:9400", "10.0.0.2:9400", "10.0.0.3:9400"],
        );
        cfg.force_bootstrap = true;
        let transport = Arc::new(NexarTransport::new(3, "127.0.0.1:0".parse().unwrap()).unwrap());
        assert!(should_bootstrap(&cfg, &transport).await);
    }

    /// The designated seed is 127.0.0.1:1. That port is assumed to have no
    /// listener, so every probe fails or times out.
    ///
    /// The outer 8 s timeout sits above the worst-case probe budget of
    /// roughly 10 × 200 ms + 9 × 300 ms ≈ 4.7 s.
    #[tokio::test]
    async fn should_bootstrap_false_when_designated_unreachable() {
        let cfg = cfg_with_seeds(2, "127.0.0.1:9400", &["127.0.0.1:1", "127.0.0.1:9400"]);
        let transport = Arc::new(NexarTransport::new(2, "127.0.0.1:0".parse().unwrap()).unwrap());
        let result =
            tokio::time::timeout(Duration::from_secs(8), should_bootstrap(&cfg, &transport))
                .await
                .expect("should_bootstrap should not hang");
        assert!(!result);
    }
}