nodedb_cluster/bootstrap/
handle_join.rs1use std::net::SocketAddr;
21
22use tracing::warn;
23
24use crate::routing::RoutingTable;
25use crate::rpc_codec::{JoinGroupInfo, JoinNodeInfo, JoinRequest, JoinResponse};
26use crate::topology::{CLUSTER_WIRE_FORMAT_VERSION, ClusterTopology, NodeInfo, NodeState};
27
28pub fn handle_join_request(
40 req: &JoinRequest,
41 topology: &mut ClusterTopology,
42 routing: &RoutingTable,
43 cluster_id: u64,
44) -> JoinResponse {
45 if req.wire_version != CLUSTER_WIRE_FORMAT_VERSION {
52 warn!(
53 node_id = req.node_id,
54 joiner_wire_version = req.wire_version,
55 expected_wire_version = CLUSTER_WIRE_FORMAT_VERSION,
56 "join request rejected: joiner cluster wire_version mismatch"
57 );
58 return reject(format!(
59 "joiner wire_version {} does not match this cluster's wire_version {} — \
60 rolling upgrade is required before this node can join",
61 req.wire_version, CLUSTER_WIRE_FORMAT_VERSION
62 ));
63 }
64
65 let addr: SocketAddr = match req.listen_addr.parse() {
67 Ok(a) => a,
68 Err(e) => {
69 return reject(format!("invalid listen_addr '{}': {e}", req.listen_addr));
70 }
71 };
72
73 if let Some(existing) = topology.get_node(req.node_id) {
75 let existing_addr = existing.addr.clone();
76 if existing_addr != req.listen_addr {
77 return reject(format!(
79 "node_id {} already registered with different address {} (request: {})",
80 req.node_id, existing_addr, req.listen_addr
81 ));
82 }
83 if existing.state != NodeState::Active
88 && let Some(entry) = topology.get_node_mut(req.node_id)
89 {
90 entry.state = NodeState::Active;
91 }
92 return build_response(topology, routing, cluster_id);
93 }
94
95 let spki_pin: Option<[u8; 32]> = req.spki_pin.as_deref().and_then(|b| {
100 if b.len() == 32 {
101 let mut arr = [0u8; 32];
102 arr.copy_from_slice(b);
103 Some(arr)
104 } else {
105 None
106 }
107 });
108 topology.add_node(
109 NodeInfo::new(req.node_id, addr, NodeState::Active)
110 .with_wire_version(req.wire_version)
111 .with_spiffe_id(req.spiffe_id.clone())
112 .with_spki_pin(spki_pin),
113 );
114 build_response(topology, routing, cluster_id)
115}
116
117fn build_response(
119 topology: &ClusterTopology,
120 routing: &RoutingTable,
121 cluster_id: u64,
122) -> JoinResponse {
123 let nodes: Vec<JoinNodeInfo> = topology
124 .all_nodes()
125 .map(|n| JoinNodeInfo {
126 node_id: n.node_id,
127 addr: n.addr.clone(),
128 state: n.state.as_u8(),
129 raft_groups: n.raft_groups.clone(),
130 wire_version: n.wire_version,
131 spiffe_id: n.spiffe_id.clone(),
132 spki_pin: n.spki_pin.map(|arr| arr.to_vec()),
133 })
134 .collect();
135
136 let groups: Vec<JoinGroupInfo> = routing
137 .group_members()
138 .iter()
139 .map(|(&gid, info)| JoinGroupInfo {
140 group_id: gid,
141 leader: info.leader,
142 members: info.members.clone(),
143 learners: info.learners.clone(),
144 })
145 .collect();
146
147 JoinResponse {
148 success: true,
149 error: String::new(),
150 cluster_id,
151 nodes,
152 vshard_to_group: routing.vshard_to_group().to_vec(),
153 groups,
154 }
155}
156
157fn reject(error: String) -> JoinResponse {
159 JoinResponse {
160 success: false,
161 error,
162 cluster_id: 0,
163 nodes: vec![],
164 vshard_to_group: vec![],
165 groups: vec![],
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Topology seeded with a single active node: id 1 at `10.0.0.1:9400`.
    fn topo_with_one_node() -> ClusterTopology {
        let seed = NodeInfo::new(1, "10.0.0.1:9400".parse().unwrap(), NodeState::Active);
        let mut topology = ClusterTopology::new();
        topology.add_node(seed);
        topology
    }

    /// Join request for `node_id` at `listen_addr`, using the current wire
    /// version and no SPIFFE id or SPKI pin.
    fn join_request(node_id: u64, listen_addr: &str) -> JoinRequest {
        JoinRequest {
            node_id,
            listen_addr: listen_addr.into(),
            wire_version: CLUSTER_WIRE_FORMAT_VERSION,
            spiffe_id: None,
            spki_pin: None,
        }
    }

    #[test]
    fn handle_join_request_adds_node() {
        let mut topology = topo_with_one_node();
        let routing = RoutingTable::uniform(2, &[1], 1);
        let req = join_request(2, "10.0.0.2:9400");

        let resp = handle_join_request(&req, &mut topology, &routing, 42);

        // The response describes the whole cluster, including the new node.
        assert!(resp.success);
        assert_eq!(resp.nodes.len(), 2);
        assert_eq!(resp.vshard_to_group.len(), 1024);
        assert_eq!(resp.groups.len(), 3);

        // The topology itself was updated.
        assert!(topology.contains(2));
        assert_eq!(topology.node_count(), 2);
    }

    #[test]
    fn handle_join_request_idempotent() {
        let mut topology = topo_with_one_node();
        let routing = RoutingTable::uniform(1, &[1], 1);
        let req = join_request(2, "10.0.0.2:9400");

        // Send the same request twice; only the second response is inspected.
        let _ = handle_join_request(&req, &mut topology, &routing, 42);
        let resp = handle_join_request(&req, &mut topology, &routing, 42);

        assert!(resp.success);
        assert_eq!(resp.nodes.len(), 2);
        assert_eq!(topology.node_count(), 2);
    }

    #[test]
    fn handle_join_request_idempotent_no_mutation() {
        let mut topology = topo_with_one_node();
        let routing = RoutingTable::uniform(1, &[1], 1);
        let req = join_request(2, "10.0.0.2:9400");

        // Record the topology after the first join...
        let resp1 = handle_join_request(&req, &mut topology, &routing, 7);
        let ids_before: Vec<u64> = topology.all_nodes().map(|n| n.node_id).collect();
        let count_before = topology.node_count();

        // ...and check that a repeated join leaves it as it was.
        let resp2 = handle_join_request(&req, &mut topology, &routing, 7);
        assert_eq!(resp1.cluster_id, 7);
        assert_eq!(resp2.cluster_id, 7);
        let ids_after: Vec<u64> = topology.all_nodes().map(|n| n.node_id).collect();

        assert!(resp1.success && resp2.success);
        assert_eq!(count_before, topology.node_count());
        assert_eq!(ids_before, ids_after);
        assert_eq!(resp2.nodes.len(), 2);

        let n2 = topology.get_node(2).unwrap();
        assert_eq!(n2.state, NodeState::Active);
    }

    #[test]
    fn handle_join_request_rejects_id_collision() {
        let mut topology = topo_with_one_node();
        let routing = RoutingTable::uniform(1, &[1], 1);

        // First claim of node id 2 succeeds.
        let first = join_request(2, "10.0.0.2:9400");
        let resp1 = handle_join_request(&first, &mut topology, &routing, 11);
        assert!(resp1.success);

        // A second claim of the same id from a different address is refused.
        let second = join_request(2, "10.0.0.99:9400");
        let resp2 = handle_join_request(&second, &mut topology, &routing, 11);

        assert!(!resp2.success);
        assert!(
            resp2.error.contains("already registered"),
            "error should mention collision: {}",
            resp2.error
        );

        // The original registration is left intact.
        assert_eq!(topology.node_count(), 2);
        let n2 = topology.get_node(2).unwrap();
        assert_eq!(n2.addr, "10.0.0.2:9400");
    }

    #[test]
    fn handle_join_invalid_addr() {
        let mut topology = ClusterTopology::new();
        let routing = RoutingTable::uniform(1, &[1], 1);
        let req = join_request(2, "not-a-valid-address");

        let resp = handle_join_request(&req, &mut topology, &routing, 42);

        assert!(!resp.success);
        assert!(!resp.error.is_empty());
    }
}