1use std::collections::HashSet;
24use std::net::SocketAddr;
25use std::sync::{Arc, Mutex, RwLock};
26
27use tracing::{debug, info, warn};
28
29use crate::catalog::ClusterCatalog;
30use crate::error::{ClusterError, Result};
31use crate::lifecycle_state::ClusterLifecycleTracker;
32use crate::multi_raft::MultiRaft;
33use crate::routing::{GroupInfo, RoutingTable};
34use crate::rpc_codec::{JoinRequest, JoinResponse, LEADER_REDIRECT_PREFIX, RaftRpc};
35use crate::topology::{ClusterTopology, NodeInfo, NodeState};
36use crate::transport::NexarTransport;
37
38use super::config::{ClusterConfig, ClusterState};
39
/// Upper bound on leader redirects followed within a single join attempt.
/// Bounds the work done when nodes disagree about who the leader is (e.g.
/// during an election), so a redirect chain cannot grow without limit.
const MAX_REDIRECTS_PER_ATTEMPT: u32 = 3;
44
45pub(crate) fn parse_leader_hint(error: &str) -> Option<SocketAddr> {
58 error
59 .strip_prefix(LEADER_REDIRECT_PREFIX)
60 .and_then(|s| s.trim().parse().ok())
61}
62
63pub(super) async fn join(
75 config: &ClusterConfig,
76 catalog: &ClusterCatalog,
77 transport: &NexarTransport,
78 lifecycle: &ClusterLifecycleTracker,
79) -> Result<ClusterState> {
80 info!(
81 node_id = config.node_id,
82 seeds = ?config.seed_nodes,
83 "joining existing cluster"
84 );
85
86 if config.seed_nodes.is_empty() {
87 let err = ClusterError::Transport {
88 detail: "no seed nodes configured".into(),
89 };
90 lifecycle.to_failed(err.to_string());
91 return Err(err);
92 }
93
94 let req_template = JoinRequest {
95 node_id: config.node_id,
96 listen_addr: config.listen_addr.to_string(),
97 wire_version: crate::topology::CLUSTER_WIRE_FORMAT_VERSION,
98 spiffe_id: None,
99 spki_pin: transport.local_spki_pin().map(|arr| arr.to_vec()),
100 };
101
102 let policy = config.join_retry;
103 let mut last_err: Option<ClusterError> = None;
104
105 for attempt in 0..policy.max_attempts {
106 lifecycle.to_joining(attempt);
107
108 let delay = policy.backoff_for(attempt);
109 if !delay.is_zero() {
110 debug!(
111 node_id = config.node_id,
112 attempt,
113 delay_ms = delay.as_millis() as u64,
114 "backing off before next join attempt"
115 );
116 tokio::time::sleep(delay).await;
117 }
118
119 match try_join_once(config, catalog, transport, &req_template).await {
120 Ok(state) => return Ok(state),
121 Err(e) => {
122 warn!(
123 node_id = config.node_id,
124 attempt,
125 error = %e,
126 "join attempt failed; will retry"
127 );
128 last_err = Some(e);
129 }
130 }
131 }
132
133 let max_attempts = policy.max_attempts;
134 let err = last_err.unwrap_or_else(|| ClusterError::Transport {
135 detail: format!("join exhausted {max_attempts} attempts with no concrete error"),
136 });
137 lifecycle.to_failed(err.to_string());
138 Err(err)
139}
140
141async fn try_join_once(
146 config: &ClusterConfig,
147 catalog: &ClusterCatalog,
148 transport: &NexarTransport,
149 req_template: &JoinRequest,
150) -> Result<ClusterState> {
151 let mut work: std::collections::VecDeque<SocketAddr> =
161 config.seed_nodes.iter().copied().collect();
162 {
163 let mut sorted: Vec<SocketAddr> = work.drain(..).collect();
167 sorted.sort();
168 work.extend(sorted);
169 }
170 let mut visited: HashSet<SocketAddr> = HashSet::new();
171 let mut redirects: u32 = 0;
172 let mut last_err: Option<ClusterError> = None;
173
174 while let Some(addr) = work.pop_front() {
175 if !visited.insert(addr) {
176 continue;
177 }
178
179 let rpc = RaftRpc::JoinRequest(req_template.clone());
180 match transport.send_rpc_to_addr(addr, rpc).await {
181 Ok(RaftRpc::JoinResponse(resp)) => {
182 if resp.success {
183 return apply_join_response(config, catalog, transport, &resp);
184 }
185 if let Some(leader) = parse_leader_hint(&resp.error) {
187 if redirects < MAX_REDIRECTS_PER_ATTEMPT && !visited.contains(&leader) {
188 info!(
189 node_id = config.node_id,
190 from = %addr,
191 to = %leader,
192 "following leader redirect"
193 );
194 redirects += 1;
195 work.push_front(leader);
196 continue;
197 }
198 debug!(
199 node_id = config.node_id,
200 from = %addr,
201 leader = %leader,
202 redirects,
203 "redirect cap reached or loop detected; falling through"
204 );
205 }
206 last_err = Some(ClusterError::Transport {
207 detail: format!("join rejected by {addr}: {}", resp.error),
208 });
209 }
210 Ok(other) => {
211 last_err = Some(ClusterError::Transport {
212 detail: format!("unexpected response from {addr}: {other:?}"),
213 });
214 }
215 Err(e) => {
216 debug!(%addr, error = %e, "seed unreachable");
217 last_err = Some(e);
218 }
219 }
220 }
221
222 Err(last_err.unwrap_or_else(|| ClusterError::Transport {
223 detail: "no seed nodes produced a response".into(),
224 }))
225}
226
/// Materialize a successful [`JoinResponse`] into local cluster state.
///
/// Builds the topology and routing table from the response, persists both
/// (plus the cluster id) to the catalog, constructs the local `MultiRaft`
/// with a group per membership (voter or learner), and registers transport
/// peers for every other node. Returns the assembled [`ClusterState`].
fn apply_join_response(
    config: &ClusterConfig,
    catalog: &ClusterCatalog,
    transport: &NexarTransport,
    resp: &JoinResponse,
) -> Result<ClusterState> {
    let mut topology = ClusterTopology::new();
    for node in &resp.nodes {
        // Unknown state bytes from a newer peer degrade to Active rather than
        // failing the join.
        let state = NodeState::from_u8(node.state).unwrap_or(NodeState::Active);
        // Only a pin of exactly 32 bytes is accepted; anything else is dropped.
        let spki_pin: Option<[u8; 32]> = node.spki_pin.as_deref().and_then(|b| {
            if b.len() == 32 {
                let mut arr = [0u8; 32];
                arr.copy_from_slice(b);
                Some(arr)
            } else {
                None
            }
        });
        let mut info = NodeInfo::new(
            node.node_id,
            // NOTE(review): an unparseable address silently falls back to
            // 0.0.0.0:0 here (such a node also gets no transport peer below).
            node.addr.parse().unwrap_or_else(|_| {
                "0.0.0.0:0"
                    .parse()
                    .expect("invariant: \"0.0.0.0:0\" is a valid SocketAddr literal")
            }),
            state,
        )
        .with_wire_version(node.wire_version)
        .with_spiffe_id(node.spiffe_id.clone())
        .with_spki_pin(spki_pin);
        info.raft_groups = node.raft_groups.clone();
        // We just joined successfully, so force our own entry to Active
        // regardless of what the leader reported.
        if node.node_id == config.node_id {
            info.state = NodeState::Active;
        }
        topology.add_node(info);
    }

    // Translate the wire-format group list into the routing table's shape.
    let mut group_members = std::collections::HashMap::new();
    for g in &resp.groups {
        group_members.insert(
            g.group_id,
            GroupInfo {
                leader: g.leader,
                members: g.members.clone(),
                learners: g.learners.clone(),
            },
        );
    }
    let routing = RoutingTable::from_parts(resp.vshard_to_group.clone(), group_members);

    // Persist before constructing MultiRaft so a crash after this point can
    // recover the cluster view from the catalog.
    catalog.save_cluster_id(resp.cluster_id)?;
    catalog.save_topology(&topology)?;
    catalog.save_routing(&routing)?;

    let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone())
        .with_election_timeout(config.election_timeout_min, config.election_timeout_max);
    for g in &resp.groups {
        let is_voter = g.members.contains(&config.node_id);
        let is_learner = g.learners.contains(&config.node_id);

        if is_voter {
            // Voter membership: peers are the other voters.
            let peers: Vec<u64> = g
                .members
                .iter()
                .copied()
                .filter(|&id| id != config.node_id)
                .collect();
            multi_raft.add_group(g.group_id, peers)?;
        } else if is_learner {
            // Learner membership: track all voters plus the other learners.
            let voters = g.members.clone();
            let other_learners: Vec<u64> = g
                .learners
                .iter()
                .copied()
                .filter(|&id| id != config.node_id)
                .collect();
            multi_raft.add_group_as_learner(g.group_id, voters, other_learners)?;
        }
        // Groups we belong to neither as voter nor learner are not started.
    }

    // Register transport peers for every other node with a valid address.
    for node in &resp.nodes {
        if node.node_id != config.node_id
            && let Ok(addr) = node.addr.parse::<SocketAddr>()
        {
            transport.register_peer(node.node_id, addr);
        }
    }

    info!(
        node_id = config.node_id,
        nodes = topology.node_count(),
        groups = routing.num_groups(),
        "joined cluster"
    );

    Ok(ClusterState {
        topology: Arc::new(RwLock::new(topology)),
        routing: Arc::new(RwLock::new(routing)),
        multi_raft: Arc::new(Mutex::new(multi_raft)),
    })
}
360
#[cfg(test)]
mod tests {
    use super::super::bootstrap_fn::bootstrap;
    use super::super::config::JoinRetryPolicy;
    use super::super::handle_join::handle_join_request;
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    // Open a fresh redb-backed catalog in a temp dir; the TempDir is returned
    // so it stays alive (and the files with it) for the test's duration.
    fn temp_catalog() -> (tempfile::TempDir, ClusterCatalog) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cluster.redb");
        let catalog = ClusterCatalog::open(&path).unwrap();
        (dir, catalog)
    }

    #[test]
    fn parse_leader_hint_extracts_valid_addr() {
        // Messages carrying the redirect prefix plus a valid SocketAddr parse.
        assert_eq!(
            parse_leader_hint("not leader; retry at 10.0.0.1:9400"),
            Some("10.0.0.1:9400".parse().unwrap())
        );
        assert_eq!(
            parse_leader_hint("not leader; retry at 127.0.0.1:65535"),
            Some("127.0.0.1:65535".parse().unwrap())
        );
    }

    #[test]
    fn parse_leader_hint_rejects_unrelated_error() {
        // Errors without the redirect prefix must never yield an address,
        // even when they happen to contain one.
        assert_eq!(
            parse_leader_hint("node_id 2 already registered with different address 10.0.0.2:9400"),
            None
        );
        assert_eq!(parse_leader_hint(""), None);
        assert_eq!(
            parse_leader_hint("conf change commit timeout on group 0"),
            None
        );
    }

    #[test]
    fn parse_leader_hint_rejects_malformed_addr() {
        // Prefix present but the remainder isn't a full host:port address.
        assert_eq!(parse_leader_hint("not leader; retry at notanaddress"), None);
        assert_eq!(parse_leader_hint("not leader; retry at "), None);
        assert_eq!(parse_leader_hint("not leader; retry at 10.0.0.1"), None);
    }

    #[test]
    fn join_retry_policy_default_schedule() {
        // Default schedule: immediate first attempt, then exponential backoff
        // doubling from 250ms, capped until attempts run out.
        let policy = JoinRetryPolicy::default();
        assert_eq!(policy.backoff_for(0), Duration::ZERO);
        assert_eq!(policy.backoff_for(1), Duration::from_millis(250));
        assert_eq!(policy.backoff_for(2), Duration::from_millis(500));
        assert_eq!(policy.backoff_for(3), Duration::from_secs(1));
        assert_eq!(policy.backoff_for(4), Duration::from_secs(2));
        assert_eq!(policy.backoff_for(5), Duration::from_secs(4));
        assert_eq!(policy.backoff_for(6), Duration::from_secs(8));
        assert_eq!(policy.backoff_for(7), Duration::from_secs(16));
        assert_eq!(policy.backoff_for(8), Duration::from_secs(32));
        // Past the final attempt the delay collapses back to zero.
        assert_eq!(policy.backoff_for(9), Duration::ZERO);
    }

    #[test]
    fn join_retry_policy_test_schedule_is_subsecond() {
        // A tightly-capped policy must keep the whole schedule fast enough
        // for use inside tests.
        let policy = JoinRetryPolicy {
            max_attempts: 8,
            max_backoff_secs: 2,
        };
        let total: Duration = (0..=policy.max_attempts)
            .map(|a| policy.backoff_for(a))
            .sum();
        assert!(
            total < Duration::from_secs(5),
            "test schedule too slow: {total:?}"
        );
        // The cap itself is still honored on the last attempt.
        assert_eq!(policy.backoff_for(8), Duration::from_secs(2));
    }

    // End-to-end: node 1 bootstraps and serves join requests; node 2 joins
    // through the real transport, and both sides' views are checked.
    #[tokio::test]
    async fn full_bootstrap_join_flow() {
        use crate::transport::credentials::TransportCredentials;
        let t1 = Arc::new(
            NexarTransport::new(
                1,
                "127.0.0.1:0".parse().unwrap(),
                TransportCredentials::Insecure,
            )
            .unwrap(),
        );
        let t2 = Arc::new(
            NexarTransport::new(
                2,
                "127.0.0.1:0".parse().unwrap(),
                TransportCredentials::Insecure,
            )
            .unwrap(),
        );

        let (_dir1, catalog1) = temp_catalog();
        let (_dir2, catalog2) = temp_catalog();

        let addr1 = t1.local_addr();
        let addr2 = t2.local_addr();

        let config1 = ClusterConfig {
            node_id: 1,
            listen_addr: addr1,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir1.path().to_path_buf(),
            force_bootstrap: false,
            join_retry: Default::default(),
            swim_udp_addr: None,
            election_timeout_min: Duration::from_millis(150),
            election_timeout_max: Duration::from_millis(300),
            install_snapshot_chunk_bytes: 4 * 1024 * 1024,
            orphan_partial_max_age_secs: 300,
        };
        let state1 = bootstrap(&config1, &catalog1, None).unwrap();

        let topology1 = state1.topology.clone();
        let routing1 = state1.routing.clone();

        // Minimal RPC handler: answers JoinRequest against node 1's live
        // topology/routing; everything else is rejected.
        struct JoinHandler {
            topology: std::sync::Arc<std::sync::RwLock<ClusterTopology>>,
            routing: std::sync::Arc<std::sync::RwLock<RoutingTable>>,
        }

        impl crate::transport::RaftRpcHandler for JoinHandler {
            async fn handle_rpc(&self, rpc: RaftRpc) -> Result<RaftRpc> {
                match rpc {
                    RaftRpc::JoinRequest(req) => {
                        let mut topo = self.topology.write().unwrap();
                        let routing = self.routing.read().unwrap();
                        let resp = handle_join_request(&req, &mut topo, &routing, 99);
                        Ok(RaftRpc::JoinResponse(resp))
                    }
                    other => Err(ClusterError::Transport {
                        detail: format!("unexpected: {other:?}"),
                    }),
                }
            }
        }

        let handler = Arc::new(JoinHandler {
            topology: topology1.clone(),
            routing: routing1.clone(),
        });

        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let t1c = t1.clone();
        tokio::spawn(async move {
            t1c.serve(handler, shutdown_rx).await.unwrap();
        });

        // Give the server task a moment to start listening.
        tokio::time::sleep(Duration::from_millis(30)).await;

        let config2 = ClusterConfig {
            node_id: 2,
            listen_addr: addr2,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir2.path().to_path_buf(),
            force_bootstrap: false,
            join_retry: Default::default(),
            swim_udp_addr: None,
            election_timeout_min: Duration::from_millis(150),
            election_timeout_max: Duration::from_millis(300),
            install_snapshot_chunk_bytes: 4 * 1024 * 1024,
            orphan_partial_max_age_secs: 300,
        };

        let lifecycle = ClusterLifecycleTracker::new();
        let state2 = join(&config2, &catalog2, &t2, &lifecycle).await.unwrap();
        // join() leaves the tracker in Joining; the caller promotes it later.
        assert!(matches!(
            lifecycle.current(),
            crate::lifecycle_state::ClusterLifecycleState::Joining { .. }
        ));

        assert_eq!(state2.topology.read().unwrap().node_count(), 2);
        assert_eq!(state2.routing.read().unwrap().num_groups(), 3);

        // The joiner persisted its view to its own catalog.
        assert!(catalog2.load_topology().unwrap().is_some());
        assert!(catalog2.load_routing().unwrap().is_some());

        // Node 1's handler registered the joiner in the shared topology.
        let topo1 = topology1.read().unwrap();
        assert_eq!(topo1.node_count(), 2);
        assert!(topo1.contains(2));

        shutdown_tx.send(true).unwrap();
    }
}