// ant_node/replication/bootstrap.rs
1//! New-node bootstrap logic (Section 16).
2//!
3//! A joining node performs active sync to discover and verify keys it should
4//! hold, then transitions to normal operation once all bootstrap work drains.
5
6use std::collections::HashSet;
7use std::sync::Arc;
8use std::time::Duration;
9
10use crate::logging::{debug, info, warn};
11use tokio::sync::RwLock;
12use tokio_util::sync::CancellationToken;
13
14use saorsa_core::DhtNetworkEvent;
15
16use crate::ant_protocol::XorName;
17use crate::replication::scheduling::ReplicationQueues;
18use crate::replication::types::BootstrapState;
19
20// ---------------------------------------------------------------------------
21// DHT bootstrap gate
22// ---------------------------------------------------------------------------
23
/// Result of gating on the `DhtNetworkEvent::BootstrapComplete` event.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BootstrapGateResult {
    /// `BootstrapComplete` arrived — the routing table is populated.
    Received,
    /// The wait timed out or the event channel failed — callers proceed
    /// anyway (expected for a bootstrap node that has no peers to sync from).
    TimedOut,
    /// Cancellation was signalled while waiting.
    Shutdown,
}
34
35/// Wait for saorsa-core's `DhtNetworkEvent::BootstrapComplete` before
36/// returning.
37///
38/// The caller must supply a pre-subscribed `dht_events` receiver. This is
39/// critical: the subscription must be created **before**
40/// `P2PNode::start()` so the `BootstrapComplete` event is not missed.
41///
42/// Returns [`BootstrapGateResult::Received`] on success,
43/// [`BootstrapGateResult::TimedOut`] if the timeout elapses (e.g. a
44/// bootstrap node with no peers), or [`BootstrapGateResult::Shutdown`] if
45/// cancellation is signalled.
46pub async fn wait_for_bootstrap_complete(
47    mut dht_events: tokio::sync::broadcast::Receiver<DhtNetworkEvent>,
48    timeout_secs: u64,
49    shutdown: &CancellationToken,
50) -> BootstrapGateResult {
51    let timeout = Duration::from_secs(timeout_secs);
52
53    let result = tokio::select! {
54        () = shutdown.cancelled() => {
55            debug!("Bootstrap sync: shutdown during BootstrapComplete wait");
56            BootstrapGateResult::Shutdown
57        }
58        () = tokio::time::sleep(timeout) => {
59            warn!(
60                "Bootstrap sync: timed out after {timeout_secs}s waiting for \
61                 BootstrapComplete — proceeding (likely a bootstrap node with no peers)",
62            );
63            BootstrapGateResult::TimedOut
64        }
65        gate = async {
66            loop {
67                match dht_events.recv().await {
68                    Ok(DhtNetworkEvent::BootstrapComplete { num_peers }) => {
69                        info!(
70                            "Bootstrap sync: DHT bootstrap complete \
71                             with {num_peers} peers in routing table"
72                        );
73                        break BootstrapGateResult::Received;
74                    }
75                    Ok(_) => {}
76                    Err(e) => {
77                        warn!(
78                            "Bootstrap sync: DHT event channel error: {e}, \
79                             proceeding without gate"
80                        );
81                        break BootstrapGateResult::TimedOut;
82                    }
83                }
84            }
85        } => gate,
86    };
87    drop(dht_events);
88    result
89}
90
91// ---------------------------------------------------------------------------
92// Bootstrap sync
93// ---------------------------------------------------------------------------
94
95// `snapshot_close_neighbors` is defined in `neighbor_sync` and re-used here.
96
97/// Mark bootstrap as complete, updating the shared state.
98pub async fn mark_bootstrap_drained(bootstrap_state: &Arc<RwLock<BootstrapState>>) {
99    let mut state = bootstrap_state.write().await;
100    state.drained = true;
101    info!("Bootstrap explicitly marked as drained");
102}
103
104/// Check if bootstrap is drained and update state if so.
105///
106/// Bootstrap is drained when:
107/// 1. All bootstrap peer requests have completed.
108/// 2. All bootstrap-discovered keys have left the pipeline (no longer in
109///    `PendingVerify`, `FetchQueue`, or `InFlightFetch`).
110///
111/// Returns `true` if bootstrap is (now) drained.
112pub async fn check_bootstrap_drained(
113    bootstrap_state: &Arc<RwLock<BootstrapState>>,
114    queues: &ReplicationQueues,
115) -> bool {
116    let mut state = bootstrap_state.write().await;
117    if state.drained {
118        return true;
119    }
120
121    if state.pending_peer_requests > 0 {
122        return false;
123    }
124
125    if queues.is_bootstrap_work_empty(&state.pending_keys) {
126        state.drained = true;
127        info!("Bootstrap drained: all peer requests completed and work queues empty");
128        true
129    } else {
130        false
131    }
132}
133
134/// Record a set of discovered keys into the bootstrap state for drain tracking.
135#[allow(clippy::implicit_hasher)]
136pub async fn track_discovered_keys(
137    bootstrap_state: &Arc<RwLock<BootstrapState>>,
138    keys: &HashSet<XorName>,
139) {
140    let mut state = bootstrap_state.write().await;
141    state.pending_keys.extend(keys);
142    debug!(
143        "Bootstrap tracking {} total discovered keys",
144        state.pending_keys.len()
145    );
146}
147
148/// Increment the pending peer request counter.
149pub async fn increment_pending_requests(
150    bootstrap_state: &Arc<RwLock<BootstrapState>>,
151    count: usize,
152) {
153    let mut state = bootstrap_state.write().await;
154    state.pending_peer_requests += count;
155}
156
157/// Decrement the pending peer request counter (saturating).
158pub async fn decrement_pending_requests(
159    bootstrap_state: &Arc<RwLock<BootstrapState>>,
160    count: usize,
161) {
162    let mut state = bootstrap_state.write().await;
163    state.pending_peer_requests = state.pending_peer_requests.saturating_sub(count);
164}
165
166// ---------------------------------------------------------------------------
167// Tests
168// ---------------------------------------------------------------------------
169
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::time::Instant;

    use tokio::sync::RwLock;

    use super::*;
    use crate::replication::scheduling::ReplicationQueues;
    use crate::replication::types::{
        BootstrapState, HintPipeline, VerificationEntry, VerificationState,
    };

    /// Deterministic `XorName` built by repeating a single byte.
    fn key_of(byte: u8) -> XorName {
        [byte; 32]
    }

    /// Wrap a `BootstrapState` the way production code shares it.
    fn shared_state(state: BootstrapState) -> Arc<RwLock<BootstrapState>> {
        Arc::new(RwLock::new(state))
    }

    #[tokio::test]
    async fn check_drained_when_already_drained() {
        let state = shared_state(BootstrapState {
            drained: true,
            pending_peer_requests: 5,
            pending_keys: HashSet::new(),
        });

        assert!(
            check_bootstrap_drained(&state, &ReplicationQueues::new()).await,
            "should be drained when flag is already set"
        );
    }

    #[tokio::test]
    async fn check_drained_blocked_by_pending_requests() {
        let state = shared_state(BootstrapState {
            drained: false,
            pending_peer_requests: 2,
            pending_keys: HashSet::new(),
        });

        assert!(
            !check_bootstrap_drained(&state, &ReplicationQueues::new()).await,
            "should not drain with pending requests"
        );
    }

    #[tokio::test]
    async fn check_drained_transitions_when_all_work_done() {
        let state = shared_state(BootstrapState {
            drained: false,
            pending_peer_requests: 0,
            pending_keys: HashSet::from([key_of(0x01)]),
        });

        // Key 0x01 is absent from every queue, so the drain check succeeds.
        assert!(check_bootstrap_drained(&state, &ReplicationQueues::new()).await);
        assert!(state.read().await.drained, "drained flag should be set");
    }

    #[tokio::test]
    async fn check_drained_blocked_by_queued_key() {
        let state = shared_state(BootstrapState {
            drained: false,
            pending_peer_requests: 0,
            pending_keys: HashSet::from([key_of(0x01)]),
        });

        // Park the bootstrap key in the pending-verify queue so it still
        // counts as in-flight pipeline work.
        let mut queues = ReplicationQueues::new();
        queues.add_pending_verify(
            key_of(0x01),
            VerificationEntry {
                state: VerificationState::PendingVerify,
                pipeline: HintPipeline::Replica,
                verified_sources: Vec::new(),
                tried_sources: HashSet::new(),
                created_at: Instant::now(),
                hint_sender: saorsa_core::identity::PeerId::from_bytes([0u8; 32]),
            },
        );

        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "should not drain while bootstrap key is still in pipeline"
        );
    }

    #[tokio::test]
    async fn mark_bootstrap_drained_sets_flag() {
        let state = shared_state(BootstrapState::new());
        mark_bootstrap_drained(&state).await;
        assert!(state.read().await.drained);
    }

    #[tokio::test]
    async fn track_discovered_keys_accumulates() {
        let state = shared_state(BootstrapState::new());
        let first: HashSet<XorName> = HashSet::from([key_of(0x01), key_of(0x02)]);
        let second: HashSet<XorName> = HashSet::from([key_of(0x02), key_of(0x03)]);

        track_discovered_keys(&state, &first).await;
        track_discovered_keys(&state, &second).await;

        // 0x02 appears in both sets and must only be counted once.
        assert_eq!(
            state.read().await.pending_keys.len(),
            3,
            "should deduplicate across calls"
        );
    }

    #[tokio::test]
    async fn increment_and_decrement_pending_requests() {
        let state = shared_state(BootstrapState::new());

        increment_pending_requests(&state, 5).await;
        assert_eq!(state.read().await.pending_peer_requests, 5);

        decrement_pending_requests(&state, 3).await;
        assert_eq!(state.read().await.pending_peer_requests, 2);

        // Underflow must saturate rather than wrap.
        decrement_pending_requests(&state, 10).await;
        assert_eq!(
            state.read().await.pending_peer_requests,
            0,
            "should saturate at zero"
        );
    }
}