use std::net::SocketAddr;
use tracing::warn;
use crate::routing::RoutingTable;
use crate::rpc_codec::{JoinGroupInfo, JoinNodeInfo, JoinRequest, JoinResponse};
use crate::topology::{CLUSTER_WIRE_FORMAT_VERSION, ClusterTopology, NodeInfo, NodeState};
pub fn handle_join_request(
req: &JoinRequest,
topology: &mut ClusterTopology,
routing: &RoutingTable,
cluster_id: u64,
) -> JoinResponse {
if req.wire_version != CLUSTER_WIRE_FORMAT_VERSION {
warn!(
node_id = req.node_id,
joiner_wire_version = req.wire_version,
expected_wire_version = CLUSTER_WIRE_FORMAT_VERSION,
"join request rejected: joiner cluster wire_version mismatch"
);
return reject(format!(
"joiner wire_version {} does not match this cluster's wire_version {} — \
rolling upgrade is required before this node can join",
req.wire_version, CLUSTER_WIRE_FORMAT_VERSION
));
}
let addr: SocketAddr = match req.listen_addr.parse() {
Ok(a) => a,
Err(e) => {
return reject(format!("invalid listen_addr '{}': {e}", req.listen_addr));
}
};
if let Some(existing) = topology.get_node(req.node_id) {
let existing_addr = existing.addr.clone();
if existing_addr != req.listen_addr {
return reject(format!(
"node_id {} already registered with different address {} (request: {})",
req.node_id, existing_addr, req.listen_addr
));
}
if existing.state != NodeState::Active
&& let Some(entry) = topology.get_node_mut(req.node_id)
{
entry.state = NodeState::Active;
}
return build_response(topology, routing, cluster_id);
}
let spki_pin: Option<[u8; 32]> = req.spki_pin.as_deref().and_then(|b| {
if b.len() == 32 {
let mut arr = [0u8; 32];
arr.copy_from_slice(b);
Some(arr)
} else {
None
}
});
topology.add_node(
NodeInfo::new(req.node_id, addr, NodeState::Active)
.with_wire_version(req.wire_version)
.with_spiffe_id(req.spiffe_id.clone())
.with_spki_pin(spki_pin),
);
build_response(topology, routing, cluster_id)
}
fn build_response(
topology: &ClusterTopology,
routing: &RoutingTable,
cluster_id: u64,
) -> JoinResponse {
let nodes: Vec<JoinNodeInfo> = topology
.all_nodes()
.map(|n| JoinNodeInfo {
node_id: n.node_id,
addr: n.addr.clone(),
state: n.state.as_u8(),
raft_groups: n.raft_groups.clone(),
wire_version: n.wire_version,
spiffe_id: n.spiffe_id.clone(),
spki_pin: n.spki_pin.map(|arr| arr.to_vec()),
})
.collect();
let groups: Vec<JoinGroupInfo> = routing
.group_members()
.iter()
.map(|(&gid, info)| JoinGroupInfo {
group_id: gid,
leader: info.leader,
members: info.members.clone(),
learners: info.learners.clone(),
})
.collect();
JoinResponse {
success: true,
error: String::new(),
cluster_id,
nodes,
vshard_to_group: routing.vshard_to_group().to_vec(),
groups,
}
}
fn reject(error: String) -> JoinResponse {
JoinResponse {
success: false,
error,
cluster_id: 0,
nodes: vec![],
vshard_to_group: vec![],
groups: vec![],
}
}
#[cfg(test)]
mod tests {
use super::*;
fn topo_with_one_node() -> ClusterTopology {
let mut topology = ClusterTopology::new();
topology.add_node(NodeInfo::new(
1,
"10.0.0.1:9400".parse().unwrap(),
NodeState::Active,
));
topology
}
#[test]
fn handle_join_request_adds_node() {
let mut topology = topo_with_one_node();
let routing = RoutingTable::uniform(2, &[1], 1);
let req = JoinRequest {
node_id: 2,
listen_addr: "10.0.0.2:9400".into(),
wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
spiffe_id: None,
spki_pin: None,
};
let resp = handle_join_request(&req, &mut topology, &routing, 42);
assert!(resp.success);
assert_eq!(resp.nodes.len(), 2);
assert_eq!(resp.vshard_to_group.len(), 1024);
assert_eq!(resp.groups.len(), 3);
assert!(topology.contains(2));
assert_eq!(topology.node_count(), 2);
}
#[test]
fn handle_join_request_idempotent() {
let mut topology = topo_with_one_node();
let routing = RoutingTable::uniform(1, &[1], 1);
let req = JoinRequest {
node_id: 2,
listen_addr: "10.0.0.2:9400".into(),
wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
spiffe_id: None,
spki_pin: None,
};
let _ = handle_join_request(&req, &mut topology, &routing, 42);
let resp = handle_join_request(&req, &mut topology, &routing, 42);
assert!(resp.success);
assert_eq!(resp.nodes.len(), 2); assert_eq!(topology.node_count(), 2);
}
#[test]
fn handle_join_request_idempotent_no_mutation() {
let mut topology = topo_with_one_node();
let routing = RoutingTable::uniform(1, &[1], 1);
let req = JoinRequest {
node_id: 2,
listen_addr: "10.0.0.2:9400".into(),
wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
spiffe_id: None,
spki_pin: None,
};
let resp1 = handle_join_request(&req, &mut topology, &routing, 7);
let ids_before: Vec<u64> = topology.all_nodes().map(|n| n.node_id).collect();
let count_before = topology.node_count();
let resp2 = handle_join_request(&req, &mut topology, &routing, 7);
assert_eq!(resp1.cluster_id, 7);
assert_eq!(resp2.cluster_id, 7);
let ids_after: Vec<u64> = topology.all_nodes().map(|n| n.node_id).collect();
assert!(resp1.success && resp2.success);
assert_eq!(count_before, topology.node_count());
assert_eq!(ids_before, ids_after);
assert_eq!(resp2.nodes.len(), 2);
let n2 = topology.get_node(2).unwrap();
assert_eq!(n2.state, NodeState::Active);
}
#[test]
fn handle_join_request_rejects_id_collision() {
let mut topology = topo_with_one_node();
let routing = RoutingTable::uniform(1, &[1], 1);
let req1 = JoinRequest {
node_id: 2,
listen_addr: "10.0.0.2:9400".into(),
wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
spiffe_id: None,
spki_pin: None,
};
let resp1 = handle_join_request(&req1, &mut topology, &routing, 11);
assert!(resp1.success);
let req2 = JoinRequest {
node_id: 2,
listen_addr: "10.0.0.99:9400".into(),
wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
spiffe_id: None,
spki_pin: None,
};
let resp2 = handle_join_request(&req2, &mut topology, &routing, 11);
assert!(!resp2.success);
assert!(
resp2.error.contains("already registered"),
"error should mention collision: {}",
resp2.error
);
assert_eq!(topology.node_count(), 2);
let n2 = topology.get_node(2).unwrap();
assert_eq!(n2.addr, "10.0.0.2:9400");
}
#[test]
fn handle_join_invalid_addr() {
let mut topology = ClusterTopology::new();
let routing = RoutingTable::uniform(1, &[1], 1);
let req = JoinRequest {
node_id: 2,
listen_addr: "not-a-valid-address".into(),
wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
spiffe_id: None,
spki_pin: None,
};
let resp = handle_join_request(&req, &mut topology, &routing, 42);
assert!(!resp.success);
assert!(!resp.error.is_empty());
}
}