1use std::net::SocketAddr;
15use tracing::{info, warn};
16
17use crate::catalog::ClusterCatalog;
18use crate::error::{ClusterError, Result};
19use crate::multi_raft::MultiRaft;
20use crate::routing::{GroupInfo, RoutingTable};
21use crate::rpc_codec::{JoinGroupInfo, JoinNodeInfo, JoinRequest, JoinResponse, RaftRpc};
22use crate::topology::{ClusterTopology, NodeInfo, NodeState};
23use crate::transport::NexarTransport;
24
/// Startup configuration for a cluster node.
///
/// Consumed by `start_cluster`, which uses it to decide between restarting
/// from the local catalog, bootstrapping a new cluster, or joining an
/// existing one through the seed nodes.
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// Identifier of this node; used in the topology, in outgoing
    /// `JoinRequest`s, and as the local id for `MultiRaft`.
    pub node_id: u64,
    /// Address this node advertises to peers. If it appears in `seed_nodes`,
    /// the node is treated as a seed.
    pub listen_addr: SocketAddr,
    /// Addresses contacted when probing for / joining an existing cluster.
    pub seed_nodes: Vec<SocketAddr>,
    /// Number of Raft groups created when bootstrapping a new cluster.
    pub num_groups: u64,
    /// Requested number of replicas per group, used when bootstrapping.
    pub replication_factor: usize,
    /// Directory handed to `MultiRaft::new` for Raft group storage.
    pub data_dir: std::path::PathBuf,
}
41
/// Cluster membership state produced by `start_cluster` for this node.
pub struct ClusterState {
    /// Nodes known to this node, with their addresses and states.
    pub topology: ClusterTopology,
    /// Mapping of vshards to Raft groups, plus each group's leader/members.
    pub routing: RoutingTable,
    /// The Raft groups this node hosts locally.
    pub multi_raft: MultiRaft,
}
48
49pub async fn start_cluster(
53 config: &ClusterConfig,
54 catalog: &ClusterCatalog,
55 transport: &NexarTransport,
56) -> Result<ClusterState> {
57 if catalog.is_bootstrapped()? {
59 return restart(config, catalog, transport);
60 }
61
62 let is_seed = config.seed_nodes.contains(&config.listen_addr);
64
65 if is_seed && should_bootstrap(config, transport).await {
66 bootstrap(config, catalog)
67 } else {
68 join(config, catalog, transport).await
69 }
70}
71
72async fn should_bootstrap(config: &ClusterConfig, transport: &NexarTransport) -> bool {
76 for addr in &config.seed_nodes {
77 if *addr == config.listen_addr {
78 continue;
79 }
80 let probe = RaftRpc::JoinRequest(JoinRequest {
82 node_id: config.node_id,
83 listen_addr: config.listen_addr.to_string(),
84 });
85 match transport.send_rpc_to_addr(*addr, probe).await {
86 Ok(_) => return false, Err(_) => continue, }
89 }
90 true
92}
93
94fn bootstrap(config: &ClusterConfig, catalog: &ClusterCatalog) -> Result<ClusterState> {
98 info!(
99 node_id = config.node_id,
100 addr = %config.listen_addr,
101 groups = config.num_groups,
102 "bootstrapping new cluster"
103 );
104
105 let mut topology = ClusterTopology::new();
107 topology.add_node(NodeInfo::new(
108 config.node_id,
109 config.listen_addr,
110 NodeState::Active,
111 ));
112
113 let routing = RoutingTable::uniform(
115 config.num_groups,
116 &[config.node_id],
117 config.replication_factor.min(1), );
119
120 let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone());
122 for group_id in routing.group_ids() {
123 multi_raft.add_group(group_id, vec![])?;
124 }
125
126 let cluster_id = generate_cluster_id();
128 catalog.save_cluster_id(cluster_id)?;
129 catalog.save_topology(&topology)?;
130 catalog.save_routing(&routing)?;
131
132 info!(
133 node_id = config.node_id,
134 cluster_id,
135 groups = config.num_groups,
136 "cluster bootstrapped"
137 );
138
139 Ok(ClusterState {
140 topology,
141 routing,
142 multi_raft,
143 })
144}
145
146async fn join(
150 config: &ClusterConfig,
151 catalog: &ClusterCatalog,
152 transport: &NexarTransport,
153) -> Result<ClusterState> {
154 info!(
155 node_id = config.node_id,
156 seeds = ?config.seed_nodes,
157 "joining existing cluster"
158 );
159
160 let req = RaftRpc::JoinRequest(JoinRequest {
161 node_id: config.node_id,
162 listen_addr: config.listen_addr.to_string(),
163 });
164
165 let mut last_err = None;
167 for addr in &config.seed_nodes {
168 match transport.send_rpc_to_addr(*addr, req.clone()).await {
169 Ok(RaftRpc::JoinResponse(resp)) => {
170 if !resp.success {
171 last_err = Some(ClusterError::Transport {
172 detail: format!("join rejected by {addr}: {}", resp.error),
173 });
174 continue;
175 }
176 return apply_join_response(config, catalog, transport, &resp);
177 }
178 Ok(other) => {
179 last_err = Some(ClusterError::Transport {
180 detail: format!("unexpected response from {addr}: {other:?}"),
181 });
182 }
183 Err(e) => {
184 warn!(%addr, error = %e, "seed unreachable");
185 last_err = Some(e);
186 }
187 }
188 }
189
190 Err(last_err.unwrap_or_else(|| ClusterError::Transport {
191 detail: "no seed nodes configured".into(),
192 }))
193}
194
195fn apply_join_response(
197 config: &ClusterConfig,
198 catalog: &ClusterCatalog,
199 transport: &NexarTransport,
200 resp: &JoinResponse,
201) -> Result<ClusterState> {
202 let mut topology = ClusterTopology::new();
204 for node in &resp.nodes {
205 let state = NodeState::from_u8(node.state).unwrap_or(NodeState::Active);
206 let mut info = NodeInfo {
207 node_id: node.node_id,
208 addr: node.addr.clone(),
209 state,
210 raft_groups: node.raft_groups.clone(),
211 };
212 if node.node_id == config.node_id {
214 info.state = NodeState::Active;
215 }
216 topology.add_node(info);
217 }
218
219 let mut group_members = std::collections::HashMap::new();
221 for g in &resp.groups {
222 group_members.insert(
223 g.group_id,
224 GroupInfo {
225 leader: g.leader,
226 members: g.members.clone(),
227 },
228 );
229 }
230 let routing = RoutingTable::from_parts(resp.vshard_to_group.clone(), group_members);
231
232 let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone());
234 for g in &resp.groups {
235 if g.members.contains(&config.node_id) {
236 let peers: Vec<u64> = g
237 .members
238 .iter()
239 .copied()
240 .filter(|&id| id != config.node_id)
241 .collect();
242 multi_raft.add_group(g.group_id, peers)?;
243 }
244 }
245
246 for node in &resp.nodes {
248 if node.node_id != config.node_id
249 && let Ok(addr) = node.addr.parse::<SocketAddr>()
250 {
251 transport.register_peer(node.node_id, addr);
252 }
253 }
254
255 catalog.save_topology(&topology)?;
257 catalog.save_routing(&routing)?;
258
259 info!(
260 node_id = config.node_id,
261 nodes = topology.node_count(),
262 groups = routing.num_groups(),
263 "joined cluster"
264 );
265
266 Ok(ClusterState {
267 topology,
268 routing,
269 multi_raft,
270 })
271}
272
273fn restart(
277 config: &ClusterConfig,
278 catalog: &ClusterCatalog,
279 transport: &NexarTransport,
280) -> Result<ClusterState> {
281 let topology = catalog
282 .load_topology()?
283 .ok_or_else(|| ClusterError::Transport {
284 detail: "catalog is bootstrapped but topology is missing".into(),
285 })?;
286
287 let routing = catalog
288 .load_routing()?
289 .ok_or_else(|| ClusterError::Transport {
290 detail: "catalog is bootstrapped but routing table is missing".into(),
291 })?;
292
293 let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone());
295 for (group_id, info) in routing.group_members() {
296 if info.members.contains(&config.node_id) {
297 let peers: Vec<u64> = info
298 .members
299 .iter()
300 .copied()
301 .filter(|&id| id != config.node_id)
302 .collect();
303 multi_raft.add_group(*group_id, peers)?;
304 }
305 }
306
307 for node in topology.all_nodes() {
309 if node.node_id != config.node_id
310 && let Some(addr) = node.socket_addr()
311 {
312 transport.register_peer(node.node_id, addr);
313 }
314 }
315
316 info!(
317 node_id = config.node_id,
318 nodes = topology.node_count(),
319 groups = multi_raft.group_count(),
320 "restarted from catalog"
321 );
322
323 Ok(ClusterState {
324 topology,
325 routing,
326 multi_raft,
327 })
328}
329
330pub fn handle_join_request(
336 req: &JoinRequest,
337 topology: &mut ClusterTopology,
338 routing: &RoutingTable,
339) -> JoinResponse {
340 let addr: SocketAddr = match req.listen_addr.parse() {
342 Ok(a) => a,
343 Err(e) => {
344 return JoinResponse {
345 success: false,
346 error: format!("invalid listen_addr '{}': {e}", req.listen_addr),
347 nodes: vec![],
348 vshard_to_group: vec![],
349 groups: vec![],
350 };
351 }
352 };
353
354 if topology.contains(req.node_id) {
355 if let Some(existing) = topology.get_node_mut(req.node_id) {
357 existing.addr = req.listen_addr.clone();
358 existing.state = NodeState::Active;
359 }
360 } else {
361 topology.add_node(NodeInfo::new(req.node_id, addr, NodeState::Active));
362 }
363
364 let nodes: Vec<JoinNodeInfo> = topology
366 .all_nodes()
367 .map(|n| JoinNodeInfo {
368 node_id: n.node_id,
369 addr: n.addr.clone(),
370 state: n.state.as_u8(),
371 raft_groups: n.raft_groups.clone(),
372 })
373 .collect();
374
375 let groups: Vec<JoinGroupInfo> = routing
376 .group_members()
377 .iter()
378 .map(|(&gid, info)| JoinGroupInfo {
379 group_id: gid,
380 leader: info.leader,
381 members: info.members.clone(),
382 })
383 .collect();
384
385 JoinResponse {
386 success: true,
387 error: String::new(),
388 nodes,
389 vshard_to_group: routing.vshard_to_group().to_vec(),
390 groups,
391 }
392}
393
394fn generate_cluster_id() -> u64 {
396 use rand::Rng;
397 rand::rng().random::<u64>()
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::ClusterCatalog;

    /// Creates a temporary directory and opens a `cluster.redb` catalog in it.
    ///
    /// The `TempDir` is returned alongside the catalog so the caller controls
    /// how long the directory stays alive.
    fn temp_catalog() -> (tempfile::TempDir, ClusterCatalog) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cluster.redb");
        let catalog = ClusterCatalog::open(&path).unwrap();
        (dir, catalog)
    }

    /// Bootstrapping yields a one-node cluster with the requested number of
    /// groups, and persists topology and routing to the catalog.
    #[test]
    fn bootstrap_creates_cluster() {
        let (_dir, catalog) = temp_catalog();
        let config = ClusterConfig {
            node_id: 1,
            listen_addr: "127.0.0.1:9400".parse().unwrap(),
            seed_nodes: vec!["127.0.0.1:9400".parse().unwrap()],
            num_groups: 4,
            replication_factor: 1,
            data_dir: _dir.path().to_path_buf(),
        };

        let state = bootstrap(&config, &catalog).unwrap();

        assert_eq!(state.topology.node_count(), 1);
        assert_eq!(state.topology.active_nodes().len(), 1);
        assert_eq!(state.routing.num_groups(), 4);
        assert_eq!(state.multi_raft.group_count(), 4);

        // The catalog must report itself bootstrapped and round-trip the state.
        assert!(catalog.is_bootstrapped().unwrap());
        let loaded_topo = catalog.load_topology().unwrap().unwrap();
        assert_eq!(loaded_topo.node_count(), 1);
        let loaded_rt = catalog.load_routing().unwrap().unwrap();
        assert_eq!(loaded_rt.num_groups(), 4);
    }

    /// After a bootstrap, `restart` rebuilds the same topology, routing and
    /// local Raft groups purely from the catalog.
    #[tokio::test]
    async fn restart_from_catalog() {
        let (_dir, catalog) = temp_catalog();
        let config = ClusterConfig {
            node_id: 1,
            listen_addr: "127.0.0.1:9400".parse().unwrap(),
            seed_nodes: vec![],
            num_groups: 4,
            replication_factor: 1,
            data_dir: _dir.path().to_path_buf(),
        };

        // Populate the catalog; the returned in-memory state is discarded.
        let _ = bootstrap(&config, &catalog).unwrap();

        // Port 0 lets the OS choose a free port for the transport.
        let transport = NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap();

        let state = restart(&config, &catalog, &transport).unwrap();

        assert_eq!(state.topology.node_count(), 1);
        assert_eq!(state.routing.num_groups(), 4);
        assert_eq!(state.multi_raft.group_count(), 4);
    }

    /// A join request from an unknown node adds it to the topology, and the
    /// response carries all nodes plus the routing snapshot.
    #[test]
    fn handle_join_request_adds_node() {
        let mut topology = ClusterTopology::new();
        topology.add_node(NodeInfo::new(
            1,
            "10.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));

        let routing = RoutingTable::uniform(2, &[1], 1);

        let req = JoinRequest {
            node_id: 2,
            listen_addr: "10.0.0.2:9400".into(),
        };

        let resp = handle_join_request(&req, &mut topology, &routing);

        assert!(resp.success);
        assert_eq!(resp.nodes.len(), 2);
        // The uniform routing table exposes 1024 vshard entries (per this assertion).
        assert_eq!(resp.vshard_to_group.len(), 1024);
        assert_eq!(resp.groups.len(), 2);

        // The handler mutates the caller's topology in place.
        assert!(topology.contains(2));
        assert_eq!(topology.node_count(), 2);
    }

    /// Sending the same join request twice must not create a duplicate node.
    #[test]
    fn handle_join_request_idempotent() {
        let mut topology = ClusterTopology::new();
        topology.add_node(NodeInfo::new(
            1,
            "10.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));

        let routing = RoutingTable::uniform(1, &[1], 1);

        let req = JoinRequest {
            node_id: 2,
            listen_addr: "10.0.0.2:9400".into(),
        };

        let _ = handle_join_request(&req, &mut topology, &routing);
        // Second, identical request: the existing entry is updated in place.
        let resp = handle_join_request(&req, &mut topology, &routing);

        assert!(resp.success);
        // Still two nodes: node 1 plus a single entry for node 2.
        assert_eq!(resp.nodes.len(), 2);
    }

    /// An unparseable listen address is rejected with a non-empty error.
    #[test]
    fn handle_join_invalid_addr() {
        let mut topology = ClusterTopology::new();
        let routing = RoutingTable::uniform(1, &[1], 1);

        let req = JoinRequest {
            node_id: 2,
            listen_addr: "not-a-valid-address".into(),
        };

        let resp = handle_join_request(&req, &mut topology, &routing);
        assert!(!resp.success);
        assert!(!resp.error.is_empty());
    }

    /// End-to-end flow over real transports. Node 1 bootstraps and serves join
    /// requests; node 2 joins through it. Both sides must end up with a
    /// two-node topology.
    #[tokio::test]
    async fn full_bootstrap_join_flow() {
        use std::sync::{Arc, Mutex};
        use std::time::Duration;

        // Port 0: let the OS pick free ports for both transports.
        let t1 = Arc::new(NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap());
        let t2 = Arc::new(NexarTransport::new(2, "127.0.0.1:0".parse().unwrap()).unwrap());

        let (_dir1, catalog1) = temp_catalog();
        let (_dir2, catalog2) = temp_catalog();

        let addr1 = t1.local_addr();
        let addr2 = t2.local_addr();

        // Node 1 is its own (only) seed and bootstraps directly.
        let config1 = ClusterConfig {
            node_id: 1,
            listen_addr: addr1,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir1.path().to_path_buf(),
        };
        let state1 = bootstrap(&config1, &catalog1).unwrap();

        // Share node 1's topology with the handler so the test can inspect it later.
        let topology1 = Arc::new(Mutex::new(state1.topology));
        let routing1 = Arc::new(state1.routing);

        /// Minimal RPC handler that only answers `JoinRequest`s.
        struct JoinHandler {
            topology: Arc<Mutex<ClusterTopology>>,
            routing: Arc<RoutingTable>,
        }

        impl crate::transport::RaftRpcHandler for JoinHandler {
            async fn handle_rpc(&self, rpc: RaftRpc) -> Result<RaftRpc> {
                match rpc {
                    RaftRpc::JoinRequest(req) => {
                        // std Mutex is acceptable here: the guard is never held across an `.await`.
                        let mut topo = self.topology.lock().unwrap();
                        let resp = handle_join_request(&req, &mut topo, &self.routing);
                        Ok(RaftRpc::JoinResponse(resp))
                    }
                    other => Err(ClusterError::Transport {
                        detail: format!("unexpected: {other:?}"),
                    }),
                }
            }
        }

        let handler = Arc::new(JoinHandler {
            topology: topology1.clone(),
            routing: routing1.clone(),
        });

        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let t1c = t1.clone();
        tokio::spawn(async move {
            t1c.serve(handler, shutdown_rx).await.unwrap();
        });

        // Give the server task a moment to start. This is a fixed delay, so the
        // test is potentially timing-sensitive.
        tokio::time::sleep(Duration::from_millis(30)).await;

        // Node 2 lists only node 1 as a seed, so it joins.
        let config2 = ClusterConfig {
            node_id: 2,
            listen_addr: addr2,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir2.path().to_path_buf(),
        };

        let state2 = join(&config2, &catalog2, &t2).await.unwrap();

        assert_eq!(state2.topology.node_count(), 2);
        assert_eq!(state2.routing.num_groups(), 2);

        // The joining node must have persisted what it learned.
        assert!(catalog2.load_topology().unwrap().is_some());
        assert!(catalog2.load_routing().unwrap().is_some());

        // The seed side must have recorded the new member too.
        let topo1 = topology1.lock().unwrap();
        assert_eq!(topo1.node_count(), 2);
        assert!(topo1.contains(2));

        shutdown_tx.send(true).unwrap();
    }
}