// nodedb_cluster/bootstrap.rs

1//! Cluster bootstrap and join protocol.
2//!
3//! Three startup paths:
4//!
5//! 1. **Bootstrap**: First seed node — creates topology, routing table, Raft groups,
6//!    persists to catalog. The cluster is born.
7//!
8//! 2. **Join**: New node contacts a seed, receives full cluster state via
9//!    `JoinResponse`, persists, and registers peers.
10//!
11//! 3. **Restart**: Node loads topology + routing from catalog, reconnects to
12//!    known peers.
13
14use std::net::SocketAddr;
15use tracing::{info, warn};
16
17use crate::catalog::ClusterCatalog;
18use crate::error::{ClusterError, Result};
19use crate::multi_raft::MultiRaft;
20use crate::routing::{GroupInfo, RoutingTable};
21use crate::rpc_codec::{JoinGroupInfo, JoinNodeInfo, JoinRequest, JoinResponse, RaftRpc};
22use crate::topology::{ClusterTopology, NodeInfo, NodeState};
23use crate::transport::NexarTransport;
24
/// Configuration for cluster formation.
///
/// One instance describes a single node's identity, network endpoints, and
/// the parameters used if this node ends up bootstrapping a new cluster.
#[derive(Debug, Clone)]
pub struct ClusterConfig {
    /// This node's unique ID.
    pub node_id: u64,
    /// Address to listen on for Raft RPCs.
    pub listen_addr: SocketAddr,
    /// Seed node addresses for bootstrap/join.
    ///
    /// If `listen_addr` appears in this list, the node is itself a seed and
    /// may bootstrap a brand-new cluster (see `start_cluster`); otherwise it
    /// always joins via one of these seeds.
    pub seed_nodes: Vec<SocketAddr>,
    /// Number of Raft groups to create on bootstrap.
    /// Only consulted on the bootstrap path; join/restart take the group
    /// layout from the cluster/catalog instead.
    pub num_groups: u64,
    /// Replication factor (number of replicas per group).
    pub replication_factor: usize,
    /// Data directory for persistent Raft log storage.
    pub data_dir: std::path::PathBuf,
}
41
/// Result of cluster startup — everything needed to run the Raft loop.
pub struct ClusterState {
    /// Cluster membership view: all known nodes and their states.
    pub topology: ClusterTopology,
    /// Virtual-shard → group and group → members mappings.
    pub routing: RoutingTable,
    /// The per-group Raft instances hosted by this node.
    pub multi_raft: MultiRaft,
}
48
49/// Start the cluster — bootstrap, join, or restart depending on state.
50///
51/// Returns the initialized cluster state ready for the Raft loop.
52pub async fn start_cluster(
53    config: &ClusterConfig,
54    catalog: &ClusterCatalog,
55    transport: &NexarTransport,
56) -> Result<ClusterState> {
57    // Check if we have existing state.
58    if catalog.is_bootstrapped()? {
59        return restart(config, catalog, transport);
60    }
61
62    // No existing state — try bootstrap or join.
63    let is_seed = config.seed_nodes.contains(&config.listen_addr);
64
65    if is_seed && should_bootstrap(config, transport).await {
66        bootstrap(config, catalog)
67    } else {
68        join(config, catalog, transport).await
69    }
70}
71
72/// Check if this seed should bootstrap a new cluster.
73///
74/// A seed bootstraps if no other seed is already running.
75async fn should_bootstrap(config: &ClusterConfig, transport: &NexarTransport) -> bool {
76    for addr in &config.seed_nodes {
77        if *addr == config.listen_addr {
78            continue;
79        }
80        // Try to contact another seed.
81        let probe = RaftRpc::JoinRequest(JoinRequest {
82            node_id: config.node_id,
83            listen_addr: config.listen_addr.to_string(),
84        });
85        match transport.send_rpc_to_addr(*addr, probe).await {
86            Ok(_) => return false, // Another seed is alive — join instead.
87            Err(_) => continue,    // Seed not reachable — keep checking.
88        }
89    }
90    // No other seed responded — we bootstrap.
91    true
92}
93
94// ── Bootstrap ───────────────────────────────────────────────────────
95
96/// Bootstrap a new cluster: this node is the founding member.
97fn bootstrap(config: &ClusterConfig, catalog: &ClusterCatalog) -> Result<ClusterState> {
98    info!(
99        node_id = config.node_id,
100        addr = %config.listen_addr,
101        groups = config.num_groups,
102        "bootstrapping new cluster"
103    );
104
105    // Create topology with this node.
106    let mut topology = ClusterTopology::new();
107    topology.add_node(NodeInfo::new(
108        config.node_id,
109        config.listen_addr,
110        NodeState::Active,
111    ));
112
113    // Create routing table: all groups on this single node.
114    let routing = RoutingTable::uniform(
115        config.num_groups,
116        &[config.node_id],
117        config.replication_factor.min(1), // Single node → RF=1.
118    );
119
120    // Create MultiRaft with all groups (single-node, no peers).
121    let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone());
122    for group_id in routing.group_ids() {
123        multi_raft.add_group(group_id, vec![])?;
124    }
125
126    // Generate cluster ID and persist everything.
127    let cluster_id = generate_cluster_id();
128    catalog.save_cluster_id(cluster_id)?;
129    catalog.save_topology(&topology)?;
130    catalog.save_routing(&routing)?;
131
132    info!(
133        node_id = config.node_id,
134        cluster_id,
135        groups = config.num_groups,
136        "cluster bootstrapped"
137    );
138
139    Ok(ClusterState {
140        topology,
141        routing,
142        multi_raft,
143    })
144}
145
146// ── Join ────────────────────────────────────────────────────────────
147
148/// Join an existing cluster by contacting seed nodes.
149async fn join(
150    config: &ClusterConfig,
151    catalog: &ClusterCatalog,
152    transport: &NexarTransport,
153) -> Result<ClusterState> {
154    info!(
155        node_id = config.node_id,
156        seeds = ?config.seed_nodes,
157        "joining existing cluster"
158    );
159
160    let req = RaftRpc::JoinRequest(JoinRequest {
161        node_id: config.node_id,
162        listen_addr: config.listen_addr.to_string(),
163    });
164
165    // Try each seed until one accepts.
166    let mut last_err = None;
167    for addr in &config.seed_nodes {
168        match transport.send_rpc_to_addr(*addr, req.clone()).await {
169            Ok(RaftRpc::JoinResponse(resp)) => {
170                if !resp.success {
171                    last_err = Some(ClusterError::Transport {
172                        detail: format!("join rejected by {addr}: {}", resp.error),
173                    });
174                    continue;
175                }
176                return apply_join_response(config, catalog, transport, &resp);
177            }
178            Ok(other) => {
179                last_err = Some(ClusterError::Transport {
180                    detail: format!("unexpected response from {addr}: {other:?}"),
181                });
182            }
183            Err(e) => {
184                warn!(%addr, error = %e, "seed unreachable");
185                last_err = Some(e);
186            }
187        }
188    }
189
190    Err(last_err.unwrap_or_else(|| ClusterError::Transport {
191        detail: "no seed nodes configured".into(),
192    }))
193}
194
195/// Apply a JoinResponse: reconstruct topology, routing, and MultiRaft from wire data.
196fn apply_join_response(
197    config: &ClusterConfig,
198    catalog: &ClusterCatalog,
199    transport: &NexarTransport,
200    resp: &JoinResponse,
201) -> Result<ClusterState> {
202    // Reconstruct topology.
203    let mut topology = ClusterTopology::new();
204    for node in &resp.nodes {
205        let state = NodeState::from_u8(node.state).unwrap_or(NodeState::Active);
206        let mut info = NodeInfo {
207            node_id: node.node_id,
208            addr: node.addr.clone(),
209            state,
210            raft_groups: node.raft_groups.clone(),
211        };
212        // If this is us, mark as Active.
213        if node.node_id == config.node_id {
214            info.state = NodeState::Active;
215        }
216        topology.add_node(info);
217    }
218
219    // Reconstruct routing table.
220    let mut group_members = std::collections::HashMap::new();
221    for g in &resp.groups {
222        group_members.insert(
223            g.group_id,
224            GroupInfo {
225                leader: g.leader,
226                members: g.members.clone(),
227            },
228        );
229    }
230    let routing = RoutingTable::from_parts(resp.vshard_to_group.clone(), group_members);
231
232    // Create MultiRaft — join the groups that include this node.
233    let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone());
234    for g in &resp.groups {
235        if g.members.contains(&config.node_id) {
236            let peers: Vec<u64> = g
237                .members
238                .iter()
239                .copied()
240                .filter(|&id| id != config.node_id)
241                .collect();
242            multi_raft.add_group(g.group_id, peers)?;
243        }
244    }
245
246    // Register all peers in the transport.
247    for node in &resp.nodes {
248        if node.node_id != config.node_id
249            && let Ok(addr) = node.addr.parse::<SocketAddr>()
250        {
251            transport.register_peer(node.node_id, addr);
252        }
253    }
254
255    // Persist.
256    catalog.save_topology(&topology)?;
257    catalog.save_routing(&routing)?;
258
259    info!(
260        node_id = config.node_id,
261        nodes = topology.node_count(),
262        groups = routing.num_groups(),
263        "joined cluster"
264    );
265
266    Ok(ClusterState {
267        topology,
268        routing,
269        multi_raft,
270    })
271}
272
273// ── Restart ─────────────────────────────────────────────────────────
274
275/// Restart from persisted state — load topology and routing from catalog.
276fn restart(
277    config: &ClusterConfig,
278    catalog: &ClusterCatalog,
279    transport: &NexarTransport,
280) -> Result<ClusterState> {
281    let topology = catalog
282        .load_topology()?
283        .ok_or_else(|| ClusterError::Transport {
284            detail: "catalog is bootstrapped but topology is missing".into(),
285        })?;
286
287    let routing = catalog
288        .load_routing()?
289        .ok_or_else(|| ClusterError::Transport {
290            detail: "catalog is bootstrapped but routing table is missing".into(),
291        })?;
292
293    // Reconstruct MultiRaft from routing table.
294    let mut multi_raft = MultiRaft::new(config.node_id, routing.clone(), config.data_dir.clone());
295    for (group_id, info) in routing.group_members() {
296        if info.members.contains(&config.node_id) {
297            let peers: Vec<u64> = info
298                .members
299                .iter()
300                .copied()
301                .filter(|&id| id != config.node_id)
302                .collect();
303            multi_raft.add_group(*group_id, peers)?;
304        }
305    }
306
307    // Register all known peers in the transport.
308    for node in topology.all_nodes() {
309        if node.node_id != config.node_id
310            && let Some(addr) = node.socket_addr()
311        {
312            transport.register_peer(node.node_id, addr);
313        }
314    }
315
316    info!(
317        node_id = config.node_id,
318        nodes = topology.node_count(),
319        groups = multi_raft.group_count(),
320        "restarted from catalog"
321    );
322
323    Ok(ClusterState {
324        topology,
325        routing,
326        multi_raft,
327    })
328}
329
330// ── Join request handler ────────────────────────────────────────────
331
332/// Build a JoinResponse from current cluster state.
333///
334/// Called by the RPC handler on the seed/leader node when a JoinRequest arrives.
335pub fn handle_join_request(
336    req: &JoinRequest,
337    topology: &mut ClusterTopology,
338    routing: &RoutingTable,
339) -> JoinResponse {
340    // Add the new node to topology.
341    let addr: SocketAddr = match req.listen_addr.parse() {
342        Ok(a) => a,
343        Err(e) => {
344            return JoinResponse {
345                success: false,
346                error: format!("invalid listen_addr '{}': {e}", req.listen_addr),
347                nodes: vec![],
348                vshard_to_group: vec![],
349                groups: vec![],
350            };
351        }
352    };
353
354    if topology.contains(req.node_id) {
355        // Node already known — update address if changed.
356        if let Some(existing) = topology.get_node_mut(req.node_id) {
357            existing.addr = req.listen_addr.clone();
358            existing.state = NodeState::Active;
359        }
360    } else {
361        topology.add_node(NodeInfo::new(req.node_id, addr, NodeState::Active));
362    }
363
364    // Build wire response.
365    let nodes: Vec<JoinNodeInfo> = topology
366        .all_nodes()
367        .map(|n| JoinNodeInfo {
368            node_id: n.node_id,
369            addr: n.addr.clone(),
370            state: n.state.as_u8(),
371            raft_groups: n.raft_groups.clone(),
372        })
373        .collect();
374
375    let groups: Vec<JoinGroupInfo> = routing
376        .group_members()
377        .iter()
378        .map(|(&gid, info)| JoinGroupInfo {
379            group_id: gid,
380            leader: info.leader,
381            members: info.members.clone(),
382        })
383        .collect();
384
385    JoinResponse {
386        success: true,
387        error: String::new(),
388        nodes,
389        vshard_to_group: routing.vshard_to_group().to_vec(),
390        groups,
391    }
392}
393
/// Generate a unique cluster ID (random u64).
///
/// Uniqueness is probabilistic: a full 64-bit random draw makes a collision
/// between independently bootstrapped clusters vanishingly unlikely.
fn generate_cluster_id() -> u64 {
    use rand::Rng;
    rand::rng().random::<u64>()
}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::ClusterCatalog;

    /// Open a fresh catalog in a temp dir.
    ///
    /// The `TempDir` guard is returned alongside the catalog: dropping it
    /// deletes the directory, so callers must keep it alive for the test.
    fn temp_catalog() -> (tempfile::TempDir, ClusterCatalog) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("cluster.redb");
        let catalog = ClusterCatalog::open(&path).unwrap();
        (dir, catalog)
    }

    /// Bootstrap on a single seed creates topology, routing, Raft groups,
    /// and persists all of it to the catalog.
    #[test]
    fn bootstrap_creates_cluster() {
        let (_dir, catalog) = temp_catalog();
        let config = ClusterConfig {
            node_id: 1,
            listen_addr: "127.0.0.1:9400".parse().unwrap(),
            seed_nodes: vec!["127.0.0.1:9400".parse().unwrap()],
            num_groups: 4,
            replication_factor: 1,
            data_dir: _dir.path().to_path_buf(),
        };

        let state = bootstrap(&config, &catalog).unwrap();

        // In-memory state: one active node hosting all four groups.
        assert_eq!(state.topology.node_count(), 1);
        assert_eq!(state.topology.active_nodes().len(), 1);
        assert_eq!(state.routing.num_groups(), 4);
        assert_eq!(state.multi_raft.group_count(), 4);

        // Verify persistence: a reload from the catalog round-trips.
        assert!(catalog.is_bootstrapped().unwrap());
        let loaded_topo = catalog.load_topology().unwrap().unwrap();
        assert_eq!(loaded_topo.node_count(), 1);
        let loaded_rt = catalog.load_routing().unwrap().unwrap();
        assert_eq!(loaded_rt.num_groups(), 4);
    }

    /// Bootstrap, then restart: state must be recoverable from the catalog
    /// alone, without contacting any seed.
    #[tokio::test]
    async fn restart_from_catalog() {
        let (_dir, catalog) = temp_catalog();
        let config = ClusterConfig {
            node_id: 1,
            listen_addr: "127.0.0.1:9400".parse().unwrap(),
            seed_nodes: vec![],
            num_groups: 4,
            replication_factor: 1,
            data_dir: _dir.path().to_path_buf(),
        };

        // Bootstrap first.
        let _ = bootstrap(&config, &catalog).unwrap();

        // Create transport for restart (port 0 = OS-assigned).
        let transport = NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap();

        // Restart — should load from catalog.
        let state = restart(&config, &catalog, &transport).unwrap();

        assert_eq!(state.topology.node_count(), 1);
        assert_eq!(state.routing.num_groups(), 4);
        assert_eq!(state.multi_raft.group_count(), 4);
    }

    /// A join request from an unknown node adds it to the topology and
    /// returns the full cluster view.
    #[test]
    fn handle_join_request_adds_node() {
        let mut topology = ClusterTopology::new();
        topology.add_node(NodeInfo::new(
            1,
            "10.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));

        let routing = RoutingTable::uniform(2, &[1], 1);

        let req = JoinRequest {
            node_id: 2,
            listen_addr: "10.0.0.2:9400".into(),
        };

        let resp = handle_join_request(&req, &mut topology, &routing);

        assert!(resp.success);
        assert_eq!(resp.nodes.len(), 2);
        // 1024 — presumably RoutingTable's fixed virtual-shard count;
        // NOTE(review): confirm against the routing module.
        assert_eq!(resp.vshard_to_group.len(), 1024);
        assert_eq!(resp.groups.len(), 2);

        // Topology should now contain node 2.
        assert!(topology.contains(2));
        assert_eq!(topology.node_count(), 2);
    }

    /// Replaying the same join request must not duplicate the node.
    #[test]
    fn handle_join_request_idempotent() {
        let mut topology = ClusterTopology::new();
        topology.add_node(NodeInfo::new(
            1,
            "10.0.0.1:9400".parse().unwrap(),
            NodeState::Active,
        ));

        let routing = RoutingTable::uniform(1, &[1], 1);

        let req = JoinRequest {
            node_id: 2,
            listen_addr: "10.0.0.2:9400".into(),
        };

        // Join twice — should be idempotent.
        let _ = handle_join_request(&req, &mut topology, &routing);
        let resp = handle_join_request(&req, &mut topology, &routing);

        assert!(resp.success);
        assert_eq!(resp.nodes.len(), 2); // Still 2, not 3.
    }

    /// An unparseable listen address is rejected without mutating topology.
    #[test]
    fn handle_join_invalid_addr() {
        let mut topology = ClusterTopology::new();
        let routing = RoutingTable::uniform(1, &[1], 1);

        let req = JoinRequest {
            node_id: 2,
            listen_addr: "not-a-valid-address".into(),
        };

        let resp = handle_join_request(&req, &mut topology, &routing);
        assert!(!resp.success);
        assert!(!resp.error.is_empty());
    }

    /// End-to-end: node 1 bootstraps and serves join RPCs; node 2 joins over
    /// the real transport and both sides end up with a 2-node view.
    #[tokio::test]
    async fn full_bootstrap_join_flow() {
        // Node 1 bootstraps, Node 2 joins via QUIC.
        use std::sync::{Arc, Mutex};
        use std::time::Duration;

        let t1 = Arc::new(NexarTransport::new(1, "127.0.0.1:0".parse().unwrap()).unwrap());
        let t2 = Arc::new(NexarTransport::new(2, "127.0.0.1:0".parse().unwrap()).unwrap());

        let (_dir1, catalog1) = temp_catalog();
        let (_dir2, catalog2) = temp_catalog();

        let addr1 = t1.local_addr();
        let addr2 = t2.local_addr();

        // Bootstrap node 1.
        let config1 = ClusterConfig {
            node_id: 1,
            listen_addr: addr1,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir1.path().to_path_buf(),
        };
        let state1 = bootstrap(&config1, &catalog1).unwrap();

        // Set up a handler for node 1 that handles JoinRequests.
        // Topology is behind a Mutex so the test can inspect it afterwards.
        let topology1 = Arc::new(Mutex::new(state1.topology));
        let routing1 = Arc::new(state1.routing);

        struct JoinHandler {
            topology: Arc<Mutex<ClusterTopology>>,
            routing: Arc<RoutingTable>,
        }

        impl crate::transport::RaftRpcHandler for JoinHandler {
            // Only JoinRequest is expected in this test; anything else errors.
            async fn handle_rpc(&self, rpc: RaftRpc) -> Result<RaftRpc> {
                match rpc {
                    RaftRpc::JoinRequest(req) => {
                        let mut topo = self.topology.lock().unwrap();
                        let resp = handle_join_request(&req, &mut topo, &self.routing);
                        Ok(RaftRpc::JoinResponse(resp))
                    }
                    other => Err(ClusterError::Transport {
                        detail: format!("unexpected: {other:?}"),
                    }),
                }
            }
        }

        let handler = Arc::new(JoinHandler {
            topology: topology1.clone(),
            routing: routing1.clone(),
        });

        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
        let t1c = t1.clone();
        tokio::spawn(async move {
            t1c.serve(handler, shutdown_rx).await.unwrap();
        });

        // Give the serve task a moment to start accepting connections.
        tokio::time::sleep(Duration::from_millis(30)).await;

        // Node 2 joins.
        let config2 = ClusterConfig {
            node_id: 2,
            listen_addr: addr2,
            seed_nodes: vec![addr1],
            num_groups: 2,
            replication_factor: 1,
            data_dir: _dir2.path().to_path_buf(),
        };

        let state2 = join(&config2, &catalog2, &t2).await.unwrap();

        assert_eq!(state2.topology.node_count(), 2);
        assert_eq!(state2.routing.num_groups(), 2);

        // Verify node 2's state was persisted.
        assert!(catalog2.load_topology().unwrap().is_some());
        assert!(catalog2.load_routing().unwrap().is_some());

        // Verify node 1's topology was updated by the join handler.
        let topo1 = topology1.lock().unwrap();
        assert_eq!(topo1.node_count(), 2);
        assert!(topo1.contains(2));

        shutdown_tx.send(true).unwrap();
    }
}
621}