1use std::collections::HashSet;
7use std::sync::Arc;
8use std::time::Duration;
9
10use crate::logging::{debug, info, warn};
11use tokio::sync::RwLock;
12use tokio_util::sync::CancellationToken;
13
14use saorsa_core::DhtNetworkEvent;
15
16use crate::ant_protocol::XorName;
17use crate::replication::scheduling::ReplicationQueues;
18use crate::replication::types::BootstrapState;
19
/// How the wait in `wait_for_bootstrap_complete` ended.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BootstrapGateResult {
    /// A `DhtNetworkEvent::BootstrapComplete` event was observed.
    Received,
    /// The gate was not satisfied: either the timeout elapsed, or the DHT
    /// event channel reported an error that ended the wait.
    TimedOut,
    /// The shutdown token was cancelled while waiting.
    Shutdown,
}
34
35pub async fn wait_for_bootstrap_complete(
47 mut dht_events: tokio::sync::broadcast::Receiver<DhtNetworkEvent>,
48 timeout_secs: u64,
49 shutdown: &CancellationToken,
50) -> BootstrapGateResult {
51 let timeout = Duration::from_secs(timeout_secs);
52
53 let result = tokio::select! {
54 () = shutdown.cancelled() => {
55 debug!("Bootstrap sync: shutdown during BootstrapComplete wait");
56 BootstrapGateResult::Shutdown
57 }
58 () = tokio::time::sleep(timeout) => {
59 warn!(
60 "Bootstrap sync: timed out after {timeout_secs}s waiting for \
61 BootstrapComplete — proceeding (likely a bootstrap node with no peers)",
62 );
63 BootstrapGateResult::TimedOut
64 }
65 gate = async {
66 loop {
67 match dht_events.recv().await {
68 Ok(DhtNetworkEvent::BootstrapComplete { num_peers }) => {
69 info!(
70 "Bootstrap sync: DHT bootstrap complete \
71 with {num_peers} peers in routing table"
72 );
73 break BootstrapGateResult::Received;
74 }
75 Ok(_) => {}
76 Err(e) => {
77 warn!(
78 "Bootstrap sync: DHT event channel error: {e}, \
79 proceeding without gate"
80 );
81 break BootstrapGateResult::TimedOut;
82 }
83 }
84 }
85 } => gate,
86 };
87 drop(dht_events);
88 result
89}
90
91pub async fn mark_bootstrap_drained(bootstrap_state: &Arc<RwLock<BootstrapState>>) {
99 let mut state = bootstrap_state.write().await;
100 state.drained = true;
101 info!("Bootstrap explicitly marked as drained");
102}
103
104pub async fn check_bootstrap_drained(
113 bootstrap_state: &Arc<RwLock<BootstrapState>>,
114 queues: &ReplicationQueues,
115) -> bool {
116 let mut state = bootstrap_state.write().await;
117 if state.drained {
118 return true;
119 }
120
121 if state.pending_peer_requests > 0 {
122 return false;
123 }
124
125 if !state.capacity_rejected_sources.is_empty() {
132 let n = state.capacity_rejected_sources.len();
133 debug!("Bootstrap NOT drained: {n} source(s) have outstanding capacity-rejected hints");
134 return false;
135 }
136
137 if queues.is_bootstrap_work_empty(&state.pending_keys) {
138 state.drained = true;
139 info!("Bootstrap drained: all peer requests completed and work queues empty");
140 true
141 } else {
142 false
143 }
144}
145
146pub async fn note_capacity_rejected(
154 bootstrap_state: &Arc<RwLock<BootstrapState>>,
155 source: saorsa_core::identity::PeerId,
156) {
157 let mut state = bootstrap_state.write().await;
158 if state.capacity_rejected_sources.insert(source) {
159 let n = state.capacity_rejected_sources.len();
160 debug!(
161 "Bootstrap: source {source} now has outstanding capacity-rejected hints \
162 ({n} sources outstanding)"
163 );
164 }
165}
166
167pub async fn clear_capacity_rejected(
174 bootstrap_state: &Arc<RwLock<BootstrapState>>,
175 source: &saorsa_core::identity::PeerId,
176) {
177 let mut state = bootstrap_state.write().await;
178 if state.capacity_rejected_sources.remove(source) {
179 let n = state.capacity_rejected_sources.len();
180 debug!(
181 "Bootstrap: cleared outstanding capacity rejections for {source} \
182 ({n} sources still outstanding)"
183 );
184 }
185}
186
187#[allow(clippy::implicit_hasher)]
189pub async fn track_discovered_keys(
190 bootstrap_state: &Arc<RwLock<BootstrapState>>,
191 keys: &HashSet<XorName>,
192) {
193 let mut state = bootstrap_state.write().await;
194 state.pending_keys.extend(keys);
195 debug!(
196 "Bootstrap tracking {} total discovered keys",
197 state.pending_keys.len()
198 );
199}
200
201pub async fn increment_pending_requests(
203 bootstrap_state: &Arc<RwLock<BootstrapState>>,
204 count: usize,
205) {
206 let mut state = bootstrap_state.write().await;
207 state.pending_peer_requests += count;
208}
209
210pub async fn decrement_pending_requests(
212 bootstrap_state: &Arc<RwLock<BootstrapState>>,
213 count: usize,
214) {
215 let mut state = bootstrap_state.write().await;
216 state.pending_peer_requests = state.pending_peer_requests.saturating_sub(count);
217}
218
#[cfg(test)]
// Tests may panic on unexpected values; that is the intended failure mode.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Arc;
    use std::time::Instant;

    use tokio::sync::{broadcast, RwLock};

    use super::*;
    use crate::replication::scheduling::ReplicationQueues;
    use crate::replication::types::{
        BootstrapState, HintPipeline, VerificationEntry, VerificationState,
    };

    fn xor_name_from_byte(b: u8) -> XorName {
        [b; 32]
    }

    // ---- wait_for_bootstrap_complete ------------------------------------

    #[tokio::test]
    async fn gate_received_on_bootstrap_complete() {
        let (tx, rx) = broadcast::channel::<DhtNetworkEvent>(8);
        assert!(tx
            .send(DhtNetworkEvent::BootstrapComplete { num_peers: 3 })
            .is_ok());
        let shutdown = CancellationToken::new();

        let result = wait_for_bootstrap_complete(rx, 60, &shutdown).await;
        assert_eq!(result, BootstrapGateResult::Received);
    }

    #[tokio::test]
    async fn gate_shutdown_when_token_cancelled() {
        // Keep the sender alive so the channel does not report `Closed`.
        let (_tx, rx) = broadcast::channel::<DhtNetworkEvent>(8);
        let shutdown = CancellationToken::new();
        shutdown.cancel();

        let result = wait_for_bootstrap_complete(rx, 60, &shutdown).await;
        assert_eq!(result, BootstrapGateResult::Shutdown);
    }

    #[tokio::test]
    async fn gate_times_out_without_event() {
        // Keep the sender alive so only the timeout can end the wait.
        let (_tx, rx) = broadcast::channel::<DhtNetworkEvent>(8);
        let shutdown = CancellationToken::new();

        let result = wait_for_bootstrap_complete(rx, 0, &shutdown).await;
        assert_eq!(result, BootstrapGateResult::TimedOut);
    }

    #[tokio::test]
    async fn gate_proceeds_when_channel_closed() {
        let (tx, rx) = broadcast::channel::<DhtNetworkEvent>(8);
        drop(tx);
        let shutdown = CancellationToken::new();

        let result = wait_for_bootstrap_complete(rx, 60, &shutdown).await;
        assert_eq!(
            result,
            BootstrapGateResult::TimedOut,
            "a closed event channel should abandon the gate"
        );
    }

    // ---- check_bootstrap_drained ----------------------------------------

    #[tokio::test]
    async fn check_drained_when_already_drained() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: true,
            pending_peer_requests: 5,
            pending_keys: HashSet::new(),
            capacity_rejected_sources: HashSet::new(),
        }));
        let queues = ReplicationQueues::new();

        assert!(
            check_bootstrap_drained(&state, &queues).await,
            "should be drained when flag is already set"
        );
    }

    #[tokio::test]
    async fn check_drained_blocked_by_pending_requests() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: false,
            pending_peer_requests: 2,
            pending_keys: HashSet::new(),
            capacity_rejected_sources: HashSet::new(),
        }));
        let queues = ReplicationQueues::new();

        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "should not drain with pending requests"
        );
    }

    #[tokio::test]
    async fn check_drained_transitions_when_all_work_done() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: false,
            pending_peer_requests: 0,
            pending_keys: std::iter::once(xor_name_from_byte(0x01)).collect(),
            capacity_rejected_sources: HashSet::new(),
        }));
        let queues = ReplicationQueues::new();

        assert!(check_bootstrap_drained(&state, &queues).await);
        assert!(state.read().await.drained, "drained flag should be set");
    }

    #[tokio::test]
    async fn check_drained_blocked_by_queued_key() {
        let state = Arc::new(RwLock::new(BootstrapState {
            drained: false,
            pending_peer_requests: 0,
            pending_keys: std::iter::once(xor_name_from_byte(0x01)).collect(),
            capacity_rejected_sources: HashSet::new(),
        }));
        let mut queues = ReplicationQueues::new();

        let entry = VerificationEntry {
            state: VerificationState::PendingVerify,
            pipeline: HintPipeline::Replica,
            verified_sources: Vec::new(),
            tried_sources: HashSet::new(),
            created_at: Instant::now(),
            hint_sender: saorsa_core::identity::PeerId::from_bytes([0u8; 32]),
        };
        queues.add_pending_verify(xor_name_from_byte(0x01), entry);

        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "should not drain while bootstrap key is still in pipeline"
        );
    }

    #[tokio::test]
    async fn drained_flag_is_sticky() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        let queues = ReplicationQueues::new();

        mark_bootstrap_drained(&state).await;
        note_capacity_rejected(&state, saorsa_core::identity::PeerId::from_bytes([0x42; 32]))
            .await;
        increment_pending_requests(&state, 1).await;

        assert!(
            check_bootstrap_drained(&state, &queues).await,
            "once drained, later blockers must not un-drain bootstrap"
        );
    }

    // ---- simple state mutators ------------------------------------------

    #[tokio::test]
    async fn mark_bootstrap_drained_sets_flag() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        mark_bootstrap_drained(&state).await;
        assert!(state.read().await.drained);
    }

    #[tokio::test]
    async fn track_discovered_keys_accumulates() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        let set_a: HashSet<XorName> = [xor_name_from_byte(0x01), xor_name_from_byte(0x02)]
            .into_iter()
            .collect();
        let set_b: HashSet<XorName> = [xor_name_from_byte(0x02), xor_name_from_byte(0x03)]
            .into_iter()
            .collect();

        track_discovered_keys(&state, &set_a).await;
        track_discovered_keys(&state, &set_b).await;

        let s = state.read().await;
        assert_eq!(s.pending_keys.len(), 3, "should deduplicate across calls");
    }

    #[tokio::test]
    async fn increment_and_decrement_pending_requests() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));

        increment_pending_requests(&state, 5).await;
        assert_eq!(state.read().await.pending_peer_requests, 5);

        decrement_pending_requests(&state, 3).await;
        assert_eq!(state.read().await.pending_peer_requests, 2);

        decrement_pending_requests(&state, 10).await;
        assert_eq!(
            state.read().await.pending_peer_requests,
            0,
            "should saturate at zero"
        );
    }

    // ---- capacity-rejected sources --------------------------------------

    #[tokio::test]
    async fn capacity_rejected_clears_on_clean_cycle() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        let queues = ReplicationQueues::new();
        let source = saorsa_core::identity::PeerId::from_bytes([7u8; 32]);

        note_capacity_rejected(&state, source).await;
        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "drain must be blocked while a source has outstanding capacity rejections"
        );

        clear_capacity_rejected(&state, &source).await;
        assert!(
            check_bootstrap_drained(&state, &queues).await,
            "drain must complete once the source's outstanding rejections are cleared"
        );
    }

    #[tokio::test]
    async fn capacity_rejected_is_per_source() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        let queues = ReplicationQueues::new();
        let source_a = saorsa_core::identity::PeerId::from_bytes([0xAA; 32]);
        let source_b = saorsa_core::identity::PeerId::from_bytes([0xBB; 32]);

        note_capacity_rejected(&state, source_a).await;
        note_capacity_rejected(&state, source_b).await;
        assert!(!check_bootstrap_drained(&state, &queues).await);

        clear_capacity_rejected(&state, &source_a).await;
        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "B's outstanding rejections must keep drain blocked"
        );

        clear_capacity_rejected(&state, &source_b).await;
        assert!(check_bootstrap_drained(&state, &queues).await);
    }

    #[tokio::test]
    async fn capacity_rejected_is_idempotent_per_source() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        let queues = ReplicationQueues::new();
        let source = saorsa_core::identity::PeerId::from_bytes([0x11; 32]);

        note_capacity_rejected(&state, source).await;
        note_capacity_rejected(&state, source).await;
        assert_eq!(state.read().await.capacity_rejected_sources.len(), 1);

        clear_capacity_rejected(&state, &source).await;
        assert!(
            check_bootstrap_drained(&state, &queues).await,
            "a single clear must undo repeated notes for the same source"
        );
    }

    #[tokio::test]
    async fn clearing_unknown_source_keeps_drain_blocked() {
        let state = Arc::new(RwLock::new(BootstrapState::new()));
        let queues = ReplicationQueues::new();
        let noted = saorsa_core::identity::PeerId::from_bytes([0x21; 32]);
        let unknown = saorsa_core::identity::PeerId::from_bytes([0x22; 32]);

        note_capacity_rejected(&state, noted).await;
        clear_capacity_rejected(&state, &unknown).await;

        assert!(
            !check_bootstrap_drained(&state, &queues).await,
            "clearing a source that was never noted must not unblock drain"
        );
    }
}