ant_node/replication/
bootstrap.rs1use std::collections::HashSet;
7use std::sync::Arc;
8use std::time::Duration;
9
10use crate::logging::{debug, info, warn};
11use tokio::sync::RwLock;
12use tokio_util::sync::CancellationToken;
13
14use saorsa_core::DhtNetworkEvent;
15
16use crate::ant_protocol::XorName;
17use crate::replication::scheduling::ReplicationQueues;
18use crate::replication::types::BootstrapState;
19
/// Outcome of waiting on the DHT bootstrap gate in
/// [`wait_for_bootstrap_complete`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BootstrapGateResult {
    /// A `DhtNetworkEvent::BootstrapComplete` event was observed.
    Received,
    /// The gate was abandoned without observing `BootstrapComplete`: either
    /// the timeout elapsed or the DHT event channel reported an error.
    TimedOut,
    /// The shutdown token was cancelled while waiting.
    Shutdown,
}
34
35pub async fn wait_for_bootstrap_complete(
47 mut dht_events: tokio::sync::broadcast::Receiver<DhtNetworkEvent>,
48 timeout_secs: u64,
49 shutdown: &CancellationToken,
50) -> BootstrapGateResult {
51 let timeout = Duration::from_secs(timeout_secs);
52
53 let result = tokio::select! {
54 () = shutdown.cancelled() => {
55 debug!("Bootstrap sync: shutdown during BootstrapComplete wait");
56 BootstrapGateResult::Shutdown
57 }
58 () = tokio::time::sleep(timeout) => {
59 warn!(
60 "Bootstrap sync: timed out after {timeout_secs}s waiting for \
61 BootstrapComplete — proceeding (likely a bootstrap node with no peers)",
62 );
63 BootstrapGateResult::TimedOut
64 }
65 gate = async {
66 loop {
67 match dht_events.recv().await {
68 Ok(DhtNetworkEvent::BootstrapComplete { num_peers }) => {
69 info!(
70 "Bootstrap sync: DHT bootstrap complete \
71 with {num_peers} peers in routing table"
72 );
73 break BootstrapGateResult::Received;
74 }
75 Ok(_) => {}
76 Err(e) => {
77 warn!(
78 "Bootstrap sync: DHT event channel error: {e}, \
79 proceeding without gate"
80 );
81 break BootstrapGateResult::TimedOut;
82 }
83 }
84 }
85 } => gate,
86 };
87 drop(dht_events);
88 result
89}
90
91pub async fn mark_bootstrap_drained(bootstrap_state: &Arc<RwLock<BootstrapState>>) {
99 let mut state = bootstrap_state.write().await;
100 state.drained = true;
101 info!("Bootstrap explicitly marked as drained");
102}
103
104pub async fn check_bootstrap_drained(
113 bootstrap_state: &Arc<RwLock<BootstrapState>>,
114 queues: &ReplicationQueues,
115) -> bool {
116 let mut state = bootstrap_state.write().await;
117 if state.drained {
118 return true;
119 }
120
121 if state.pending_peer_requests > 0 {
122 return false;
123 }
124
125 if queues.is_bootstrap_work_empty(&state.pending_keys) {
126 state.drained = true;
127 info!("Bootstrap drained: all peer requests completed and work queues empty");
128 true
129 } else {
130 false
131 }
132}
133
134#[allow(clippy::implicit_hasher)]
136pub async fn track_discovered_keys(
137 bootstrap_state: &Arc<RwLock<BootstrapState>>,
138 keys: &HashSet<XorName>,
139) {
140 let mut state = bootstrap_state.write().await;
141 state.pending_keys.extend(keys);
142 debug!(
143 "Bootstrap tracking {} total discovered keys",
144 state.pending_keys.len()
145 );
146}
147
148pub async fn increment_pending_requests(
150 bootstrap_state: &Arc<RwLock<BootstrapState>>,
151 count: usize,
152) {
153 let mut state = bootstrap_state.write().await;
154 state.pending_peer_requests += count;
155}
156
157pub async fn decrement_pending_requests(
159 bootstrap_state: &Arc<RwLock<BootstrapState>>,
160 count: usize,
161) {
162 let mut state = bootstrap_state.write().await;
163 state.pending_peer_requests = state.pending_peer_requests.saturating_sub(count);
164}
165
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::time::Instant;

    use tokio::sync::RwLock;

    use super::*;
    use crate::replication::scheduling::ReplicationQueues;
    use crate::replication::types::{
        BootstrapState, HintPipeline, VerificationEntry, VerificationState,
    };

    fn xor_name_from_byte(b: u8) -> XorName {
        [b; 32]
    }

    /// Creates a fresh DHT event channel for the bootstrap-gate tests.
    fn dht_channel() -> (
        tokio::sync::broadcast::Sender<DhtNetworkEvent>,
        tokio::sync::broadcast::Receiver<DhtNetworkEvent>,
    ) {
        tokio::sync::broadcast::channel(8)
    }

    #[tokio::test]
    async fn gate_received_on_bootstrap_complete() {
        let (tx, rx) = dht_channel();
        assert!(tx
            .send(DhtNetworkEvent::BootstrapComplete { num_peers: 3 })
            .is_ok());
        let shutdown = CancellationToken::new();

        let result = wait_for_bootstrap_complete(rx, 60, &shutdown).await;
        assert_eq!(result, BootstrapGateResult::Received);
    }

    #[tokio::test]
    async fn gate_reports_shutdown_when_cancelled() {
        // Keep the sender alive so the receiver stays pending.
        let (_tx, rx) = dht_channel();
        let shutdown = CancellationToken::new();
        shutdown.cancel();

        let result = wait_for_bootstrap_complete(rx, 60, &shutdown).await;
        assert_eq!(result, BootstrapGateResult::Shutdown);
    }

    #[tokio::test]
    async fn gate_times_out_without_event() {
        // Keep the sender alive so only the timeout can resolve the wait.
        let (_tx, rx) = dht_channel();
        let shutdown = CancellationToken::new();

        let result = wait_for_bootstrap_complete(rx, 0, &shutdown).await;
        assert_eq!(result, BootstrapGateResult::TimedOut);
    }

    #[tokio::test]
    async fn gate_proceeds_when_channel_closed() {
        let (tx, rx) = dht_channel();
        drop(tx);
        let shutdown = CancellationToken::new();

        let result = wait_for_bootstrap_complete(rx, 60, &shutdown).await;
        assert_eq!(result, BootstrapGateResult::TimedOut);
    }

    #[tokio::test]
    async fn check_drained_when_already_drained() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: true,
            pending_peer_requests: 5,
            pending_keys: HashSet::new(),
        }));
        let queues = ReplicationQueues::new();

        assert!(
            check_bootstrap_drained(&state, &queues).await,
            "should be drained when flag is already set"
        );
    }

    #[tokio::test]
    async fn check_drained_blocked_by_pending_requests() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: false,
            pending_peer_requests: 2,
            pending_keys: HashSet::new(),
        }));
        let queues = ReplicationQueues::new();

        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "should not drain with pending requests"
        );
    }

    #[tokio::test]
    async fn check_drained_transitions_when_all_work_done() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: false,
            pending_peer_requests: 0,
            pending_keys: std::iter::once(xor_name_from_byte(0x01)).collect(),
        }));
        let queues = ReplicationQueues::new();

        assert!(check_bootstrap_drained(&state, &queues).await);
        assert!(state.read().await.drained, "drained flag should be set");
    }

    #[tokio::test]
    async fn check_drained_blocked_by_queued_key() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: false,
            pending_peer_requests: 0,
            pending_keys: std::iter::once(xor_name_from_byte(0x01)).collect(),
        }));
        let mut queues = ReplicationQueues::new();

        let entry = VerificationEntry {
            state: VerificationState::PendingVerify,
            pipeline: HintPipeline::Replica,
            verified_sources: Vec::new(),
            tried_sources: HashSet::new(),
            created_at: Instant::now(),
            hint_sender: saorsa_core::identity::PeerId::from_bytes([0u8; 32]),
        };
        queues.add_pending_verify(xor_name_from_byte(0x01), entry);

        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "should not drain while bootstrap key is still in pipeline"
        );
    }

    #[tokio::test]
    async fn mark_bootstrap_drained_sets_flag() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        mark_bootstrap_drained(&state).await;
        assert!(state.read().await.drained);
    }

    #[tokio::test]
    async fn track_discovered_keys_accumulates() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        let set_a: HashSet<XorName> = [xor_name_from_byte(0x01), xor_name_from_byte(0x02)]
            .into_iter()
            .collect();
        let set_b: HashSet<XorName> = [xor_name_from_byte(0x02), xor_name_from_byte(0x03)]
            .into_iter()
            .collect();

        track_discovered_keys(&state, &set_a).await;
        track_discovered_keys(&state, &set_b).await;

        let s = state.read().await;
        assert_eq!(s.pending_keys.len(), 3, "should deduplicate across calls");
    }

    #[tokio::test]
    async fn increment_and_decrement_pending_requests() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));

        increment_pending_requests(&state, 5).await;
        assert_eq!(state.read().await.pending_peer_requests, 5);

        decrement_pending_requests(&state, 3).await;
        assert_eq!(state.read().await.pending_peer_requests, 2);

        // Over-decrementing must clamp at zero rather than underflow.
        decrement_pending_requests(&state, 10).await;
        assert_eq!(
            state.read().await.pending_peer_requests,
            0,
            "should saturate at zero"
        );
    }
}