use std::collections::HashSet;
use std::net::SocketAddr;
use tracing::{debug, info, warn};
use crate::catalog::ClusterCatalog;
use crate::error::{ClusterError, Result};
use crate::lifecycle_state::ClusterLifecycleTracker;
use crate::multi_raft::MultiRaft;
use crate::routing::{GroupInfo, RoutingTable};
use crate::rpc_codec::{JoinRequest, JoinResponse, LEADER_REDIRECT_PREFIX, RaftRpc};
use crate::topology::{ClusterTopology, NodeInfo, NodeState};
use crate::transport::NexarTransport;
use super::config::{ClusterConfig, ClusterState};
/// Upper bound on leader redirects followed within a single join attempt;
/// guards against redirect loops between nodes with stale leader views.
const MAX_REDIRECTS_PER_ATTEMPT: u32 = 3;
pub(crate) fn parse_leader_hint(error: &str) -> Option<SocketAddr> {
error
.strip_prefix(LEADER_REDIRECT_PREFIX)
.and_then(|s| s.trim().parse().ok())
}
/// Join an existing cluster via the configured seed nodes.
///
/// Retries according to `config.join_retry`, moving `lifecycle` to
/// `Joining` at the start of each attempt and to `Failed` once all
/// attempts are exhausted (or immediately when no seeds are configured).
/// Returns the materialized `ClusterState` on the first successful attempt.
pub(super) async fn join(
    config: &ClusterConfig,
    catalog: &ClusterCatalog,
    transport: &NexarTransport,
    lifecycle: &ClusterLifecycleTracker,
) -> Result<ClusterState> {
    info!(
        node_id = config.node_id,
        seeds = ?config.seed_nodes,
        "joining existing cluster"
    );
    // Joining is impossible without at least one seed; fail fast.
    if config.seed_nodes.is_empty() {
        let err = ClusterError::Transport {
            detail: "no seed nodes configured".into(),
        };
        lifecycle.to_failed(err.to_string());
        return Err(err);
    }
    // One request template shared across every attempt; cloned per send.
    let request = JoinRequest {
        node_id: config.node_id,
        listen_addr: config.listen_addr.to_string(),
        wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
    };
    let retry = config.join_retry;
    let mut last_failure: Option<ClusterError> = None;
    for attempt in 0..retry.max_attempts {
        lifecycle.to_joining(attempt);
        // Honor the backoff schedule before each attempt (zero on the first).
        let backoff = retry.backoff_for(attempt);
        if !backoff.is_zero() {
            debug!(
                node_id = config.node_id,
                attempt,
                delay_ms = backoff.as_millis() as u64,
                "backing off before next join attempt"
            );
            tokio::time::sleep(backoff).await;
        }
        let failure = match try_join_once(config, catalog, transport, &request).await {
            Ok(state) => return Ok(state),
            Err(e) => e,
        };
        warn!(
            node_id = config.node_id,
            attempt,
            error = %failure,
            "join attempt failed; will retry"
        );
        last_failure = Some(failure);
    }
    // Every attempt failed: surface the most recent concrete error, with a
    // synthetic one only in the degenerate zero-attempt case.
    let max_attempts = retry.max_attempts;
    let err = last_failure.unwrap_or_else(|| ClusterError::Transport {
        detail: format!("join exhausted {max_attempts} attempts with no concrete error"),
    });
    lifecycle.to_failed(err.to_string());
    Err(err)
}
/// Run a single join attempt: probe the seed list (sorted so every node
/// tries seeds in a deterministic order), following at most
/// `MAX_REDIRECTS_PER_ATTEMPT` leader redirects, and return the first
/// successfully applied `ClusterState`.
async fn try_join_once(
    config: &ClusterConfig,
    catalog: &ClusterCatalog,
    transport: &NexarTransport,
    req_template: &JoinRequest,
) -> Result<ClusterState> {
    // Sort the seeds up front for a deterministic probe order.
    let mut seeds: Vec<SocketAddr> = config.seed_nodes.iter().copied().collect();
    seeds.sort();
    let mut work: std::collections::VecDeque<SocketAddr> = std::collections::VecDeque::from(seeds);
    let mut visited: HashSet<SocketAddr> = HashSet::new();
    let mut redirects: u32 = 0;
    let mut last_err: Option<ClusterError> = None;
    while let Some(addr) = work.pop_front() {
        // Never contact the same address twice within one attempt.
        if !visited.insert(addr) {
            continue;
        }
        let reply = transport
            .send_rpc_to_addr(addr, RaftRpc::JoinRequest(req_template.clone()))
            .await;
        match reply {
            Ok(RaftRpc::JoinResponse(resp)) => {
                if resp.success {
                    return apply_join_response(config, catalog, transport, &resp);
                }
                // A rejection may carry a leader hint worth chasing.
                if let Some(leader) = parse_leader_hint(&resp.error) {
                    if redirects < MAX_REDIRECTS_PER_ATTEMPT && !visited.contains(&leader) {
                        info!(
                            node_id = config.node_id,
                            from = %addr,
                            to = %leader,
                            "following leader redirect"
                        );
                        redirects += 1;
                        // Jump the hinted leader to the front of the queue.
                        work.push_front(leader);
                        continue;
                    }
                    debug!(
                        node_id = config.node_id,
                        from = %addr,
                        leader = %leader,
                        redirects,
                        "redirect cap reached or loop detected; falling through"
                    );
                }
                last_err = Some(ClusterError::Transport {
                    detail: format!("join rejected by {addr}: {}", resp.error),
                });
            }
            Ok(other) => {
                last_err = Some(ClusterError::Transport {
                    detail: format!("unexpected response from {addr}: {other:?}"),
                });
            }
            Err(e) => {
                debug!(%addr, error = %e, "seed unreachable");
                last_err = Some(e);
            }
        }
    }
    // Exhausted the queue without success; report the last concrete error.
    Err(last_err.unwrap_or_else(|| ClusterError::Transport {
        detail: "no seed nodes produced a response".into(),
    }))
}
/// Materialize local cluster state from a successful `JoinResponse`:
/// rebuild topology and routing, persist them to the catalog, configure
/// the per-group raft instances for this node's roles, and register
/// remote peers with the transport.
fn apply_join_response(
    config: &ClusterConfig,
    catalog: &ClusterCatalog,
    transport: &NexarTransport,
    resp: &JoinResponse,
) -> Result<ClusterState> {
    // Reconstruct the topology the responding node sent us.
    let mut topology = ClusterTopology::new();
    for node in &resp.nodes {
        // Unknown state bytes default to Active instead of failing the join;
        // our own entry is always recorded as Active.
        let state = if node.node_id == config.node_id {
            NodeState::Active
        } else {
            NodeState::from_u8(node.state).unwrap_or(NodeState::Active)
        };
        topology.add_node(NodeInfo {
            node_id: node.node_id,
            addr: node.addr.clone(),
            state,
            raft_groups: node.raft_groups.clone(),
            wire_version: node.wire_version,
        });
    }
    // Rebuild the routing table from the group membership snapshot.
    let group_members: std::collections::HashMap<_, _> = resp
        .groups
        .iter()
        .map(|g| {
            (
                g.group_id,
                GroupInfo {
                    leader: g.leader,
                    members: g.members.clone(),
                    learners: g.learners.clone(),
                },
            )
        })
        .collect();
    let routing = RoutingTable::from_parts(resp.vshard_to_group.clone(), group_members);
    // Persist identity, topology, and routing before wiring up raft.
    catalog.save_cluster_id(resp.cluster_id)?;
    catalog.save_topology(&topology)?;
    catalog.save_routing(&routing)?;
    let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone());
    for g in &resp.groups {
        if g.members.contains(&config.node_id) {
            // Voter: start the group with every other voter as a peer.
            let peers: Vec<u64> = g
                .members
                .iter()
                .copied()
                .filter(|&id| id != config.node_id)
                .collect();
            multi_raft.add_group(g.group_id, peers)?;
        } else if g.learners.contains(&config.node_id) {
            // Learner: needs the full voter set plus the other learners.
            let other_learners: Vec<u64> = g
                .learners
                .iter()
                .copied()
                .filter(|&id| id != config.node_id)
                .collect();
            multi_raft.add_group_as_learner(g.group_id, g.members.clone(), other_learners)?;
        }
        // Groups we belong to in neither role are routing-only for this node.
    }
    // Make every remote peer addressable over the transport; entries with
    // unparseable addresses are skipped silently, as before.
    for node in &resp.nodes {
        if node.node_id == config.node_id {
            continue;
        }
        if let Ok(addr) = node.addr.parse::<SocketAddr>() {
            transport.register_peer(node.node_id, addr);
        }
    }
    info!(
        node_id = config.node_id,
        nodes = topology.node_count(),
        groups = routing.num_groups(),
        "joined cluster"
    );
    Ok(ClusterState {
        topology,
        routing,
        multi_raft,
    })
}
#[cfg(test)]
mod tests {
    use super::super::bootstrap_fn::bootstrap;
    use super::super::config::JoinRetryPolicy;
    use super::super::handle_join::handle_join_request;
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    // Open a fresh ClusterCatalog in a temp directory; the TempDir is
    // returned alongside so the backing files outlive the catalog handle.
    fn temp_catalog() -> (tempfile::TempDir, ClusterCatalog) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cluster.redb");
        let catalog = ClusterCatalog::open(&path).unwrap();
        (dir, catalog)
    }

    // Messages carrying the redirect prefix followed by a valid socket
    // address must yield that address.
    #[test]
    fn parse_leader_hint_extracts_valid_addr() {
        assert_eq!(
            parse_leader_hint("not leader; retry at 10.0.0.1:9400"),
            Some("10.0.0.1:9400".parse().unwrap())
        );
        assert_eq!(
            parse_leader_hint("not leader; retry at 127.0.0.1:65535"),
            Some("127.0.0.1:65535".parse().unwrap())
        );
    }

    // Errors that do not start with the redirect prefix must never be
    // mistaken for leader hints, even when they contain an address.
    #[test]
    fn parse_leader_hint_rejects_unrelated_error() {
        assert_eq!(
            parse_leader_hint("node_id 2 already registered with different address 10.0.0.2:9400"),
            None
        );
        assert_eq!(parse_leader_hint(""), None);
        assert_eq!(
            parse_leader_hint("conf change commit timeout on group 0"),
            None
        );
    }

    // A correct prefix with a malformed or missing address yields None.
    #[test]
    fn parse_leader_hint_rejects_malformed_addr() {
        assert_eq!(parse_leader_hint("not leader; retry at notanaddress"), None);
        assert_eq!(parse_leader_hint("not leader; retry at "), None);
        // Bare IP without a port is not a valid SocketAddr.
        assert_eq!(parse_leader_hint("not leader; retry at 10.0.0.1"), None);
    }

    // The default policy uses exponential backoff (250ms doubling, capped)
    // with no delay on attempt 0 and none past max_attempts.
    #[test]
    fn join_retry_policy_default_schedule() {
        let policy = JoinRetryPolicy::default();
        assert_eq!(policy.backoff_for(0), Duration::ZERO);
        assert_eq!(policy.backoff_for(1), Duration::from_millis(250));
        assert_eq!(policy.backoff_for(2), Duration::from_millis(500));
        assert_eq!(policy.backoff_for(3), Duration::from_secs(1));
        assert_eq!(policy.backoff_for(4), Duration::from_secs(2));
        assert_eq!(policy.backoff_for(5), Duration::from_secs(4));
        assert_eq!(policy.backoff_for(6), Duration::from_secs(8));
        assert_eq!(policy.backoff_for(7), Duration::from_secs(16));
        assert_eq!(policy.backoff_for(8), Duration::from_secs(32));
        assert_eq!(policy.backoff_for(9), Duration::ZERO);
    }

    // A tightened policy (for tests) must keep the whole retry schedule
    // well under a few seconds so test suites stay fast.
    #[test]
    fn join_retry_policy_test_schedule_is_subsecond() {
        let policy = JoinRetryPolicy {
            max_attempts: 8,
            max_backoff_secs: 2,
        };
        let total: Duration = (0..=policy.max_attempts)
            .map(|a| policy.backoff_for(a))
            .sum();
        assert!(
            total < Duration::from_secs(5),
            "test schedule too slow: {total:?}"
        );
        assert_eq!(policy.backoff_for(8), Duration::from_secs(2));
    }

    // End-to-end: node 1 bootstraps and serves join RPCs; node 2 joins via
    // node 1 as its seed. Verifies both nodes' views and node 2's catalog.
    #[tokio::test]
    async fn full_bootstrap_join_flow() {
        let t1 = Arc::new(NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap());
        let t2 = Arc::new(NexarTransport::new(2, "127.0.0.1:0".parse().unwrap()).unwrap());
        let (_dir1, catalog1) = temp_catalog();
        let (_dir2, catalog2) = temp_catalog();
        let addr1 = t1.local_addr();
        let addr2 = t2.local_addr();
        let config1 = ClusterConfig {
            node_id: 1,
            listen_addr: addr1,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir1.path().to_path_buf(),
            force_bootstrap: false,
            join_retry: Default::default(),
            swim_udp_addr: None,
        };
        let state1 = bootstrap(&config1, &catalog1).unwrap();
        let topology1 = Arc::new(Mutex::new(state1.topology));
        let routing1 = Arc::new(state1.routing);
        // Minimal RPC handler for node 1: answers JoinRequest via
        // handle_join_request and rejects everything else.
        struct JoinHandler {
            topology: Arc<Mutex<ClusterTopology>>,
            routing: Arc<RoutingTable>,
        }
        impl crate::transport::RaftRpcHandler for JoinHandler {
            async fn handle_rpc(&self, rpc: RaftRpc) -> Result<RaftRpc> {
                match rpc {
                    RaftRpc::JoinRequest(req) => {
                        let mut topo = self.topology.lock().unwrap();
                        let resp = handle_join_request(&req, &mut topo, &self.routing, 99);
                        Ok(RaftRpc::JoinResponse(resp))
                    }
                    other => Err(ClusterError::Transport {
                        detail: format!("unexpected: {other:?}"),
                    }),
                }
            }
        }
        let handler = Arc::new(JoinHandler {
            topology: topology1.clone(),
            routing: routing1.clone(),
        });
        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let t1c = t1.clone();
        tokio::spawn(async move {
            t1c.serve(handler, shutdown_rx).await.unwrap();
        });
        // Give the serve task a moment to start accepting connections.
        tokio::time::sleep(Duration::from_millis(30)).await;
        let config2 = ClusterConfig {
            node_id: 2,
            listen_addr: addr2,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir2.path().to_path_buf(),
            force_bootstrap: false,
            join_retry: Default::default(),
            swim_udp_addr: None,
        };
        let lifecycle = ClusterLifecycleTracker::new();
        let state2 = join(&config2, &catalog2, &t2, &lifecycle).await.unwrap();
        // join() leaves the tracker in Joining; the caller transitions onward.
        assert!(matches!(
            lifecycle.current(),
            crate::lifecycle_state::ClusterLifecycleState::Joining { .. }
        ));
        assert_eq!(state2.topology.node_count(), 2);
        assert_eq!(state2.routing.num_groups(), 2);
        // apply_join_response must have persisted topology and routing.
        assert!(catalog2.load_topology().unwrap().is_some());
        assert!(catalog2.load_routing().unwrap().is_some());
        // Node 1's shared topology should now include the joiner.
        let topo1 = topology1.lock().unwrap();
        assert_eq!(topo1.node_count(), 2);
        assert!(topo1.contains(2));
        shutdown_tx.send(true).unwrap();
    }
}