use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::VecDeque;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::watch;
use tracing::{debug, info, warn};

19#[derive(Error, Debug)]
21pub enum PartitionError {
22 #[error("Network partition detected")]
23 PartitionDetected,
24
25 #[error("Peer unreachable: {0}")]
26 PeerUnreachable(String),
27
28 #[error("Queue full: cannot accept more requests")]
29 QueueFull,
30
31 #[error("Recovery timeout")]
32 RecoveryTimeout,
33}
34
/// Coarse connectivity state tracked by the partition detector.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PartitionState {
    /// No partition is suspected. This is the initial (default) state.
    #[default]
    Healthy,
    /// Entered from `Healthy` when a failure check finds more than half of the
    /// tracked peers unhealthy.
    Suspected,
    /// Entered from `Suspected` when a later failure check still finds more
    /// than half of the tracked peers unhealthy.
    Partitioned,
    /// Entered from `Partitioned` once more than half of the tracked peers are
    /// healthy again.
    Recovering,
}

/// Tunable thresholds and limits for partition detection.
#[derive(Debug, Clone)]
pub struct PartitionConfig {
    /// Number of failures inside `failure_window` at which a peer counts as
    /// unhealthy.
    pub failure_threshold: usize,
    /// Sliding window over which a peer's failures are counted.
    pub failure_window: Duration,
    /// Interval between probes.
    /// NOTE(review): nothing in this file reads this value — confirm the consumer.
    pub probe_interval: Duration,
    /// Maximum number of requests held in the queue before new ones are rejected.
    pub max_queued_requests: usize,
    /// Number of probes used during recovery.
    /// NOTE(review): nothing in this file reads this value — confirm the consumer.
    pub recovery_probe_count: usize,
    /// Age after which a peer's most recent success (or, if it never
    /// succeeded, its most recent failure) makes the peer count as unhealthy.
    pub peer_timeout: Duration,
}

impl Default for PartitionConfig {
    /// Returns the stock configuration: 3 failures within 10 s, 5 s probe
    /// interval, 1000 queued requests, 3 recovery probes, 30 s peer timeout.
    fn default() -> Self {
        let secs = Duration::from_secs;

        Self {
            failure_threshold: 3,
            failure_window: secs(10),
            probe_interval: secs(5),
            max_queued_requests: 1000,
            recovery_probe_count: 3,
            peer_timeout: secs(30),
        }
    }
}

79#[derive(Debug, Clone, Default)]
81pub struct PartitionStats {
82 pub partitions_detected: u64,
84 pub recoveries: u64,
86 pub queued_requests: usize,
88 pub dropped_requests: u64,
90 pub avg_partition_duration: Option<Duration>,
92 pub state: PartitionState,
94}
95
/// A request held back while it cannot be delivered.
#[derive(Debug, Clone)]
// `queued_at` is recorded when the request is enqueued but nothing in this
// file reads it back, hence the dead-code allowance.
#[allow(dead_code)]
struct QueuedRequest {
    /// Destination peer of the request.
    peer: SocketAddr,
    /// Opaque request payload.
    data: Vec<u8>,
    /// When the request was placed on the queue.
    queued_at: Instant,
}

/// Rolling health record for a single peer.
#[derive(Debug, Clone)]
struct PeerHealth {
    /// When the most recent success was recorded, if any.
    last_success: Option<Instant>,
    /// When the most recent failure was recorded, if any.
    last_failure: Option<Instant>,
    /// Length of `failures` as of the last recorded event.
    failure_count: usize,
    /// Failure timestamps in the order they were recorded (oldest first).
    failures: VecDeque<Instant>,
}

impl PeerHealth {
    /// Creates a record with no history.
    fn new() -> Self {
        Self {
            last_success: None,
            last_failure: None,
            failure_count: 0,
            failures: VecDeque::new(),
        }
    }

    /// Records a success now and forgets every previously recorded failure.
    fn record_success(&mut self) {
        self.last_success = Some(Instant::now());
        self.failures.clear();
        self.failure_count = 0;
    }

    /// Records a failure now and drops failures older than `window`.
    ///
    /// `failure_count` is refreshed to the number of failures still inside
    /// the window, including the one just recorded.
    fn record_failure(&mut self, window: Duration) {
        let now = Instant::now();
        self.last_failure = Some(now);
        self.failures.push_back(now);

        // Timestamps are appended in order, so expired entries are always at
        // the front; stop at the first one that is still inside the window.
        while matches!(
            self.failures.front(),
            Some(&oldest) if now.duration_since(oldest) > window
        ) {
            self.failures.pop_front();
        }

        self.failure_count = self.failures.len();
    }

    /// Reports whether the peer should be treated as unhealthy.
    ///
    /// The peer is unhealthy when `failure_count` has reached `threshold`, or
    /// when its reference timestamp is older than `timeout`. The reference is
    /// the last success if there is one, otherwise the last failure; a peer
    /// with no history at all is healthy.
    fn is_unhealthy(&self, threshold: usize, timeout: Duration) -> bool {
        if self.failure_count >= threshold {
            return true;
        }

        self.last_success
            .or(self.last_failure)
            .map_or(false, |reference| reference.elapsed() > timeout)
    }
}

175pub struct PartitionDetector {
177 config: PartitionConfig,
178 state: Arc<RwLock<PartitionState>>,
179 peer_health: Arc<DashMap<SocketAddr, PeerHealth>>,
180 queued_requests: Arc<RwLock<VecDeque<QueuedRequest>>>,
181 stats: Arc<RwLock<PartitionStats>>,
182 state_tx: watch::Sender<PartitionState>,
183 state_rx: watch::Receiver<PartitionState>,
184 partition_start: Arc<RwLock<Option<Instant>>>,
185}
186
187impl PartitionDetector {
188 pub fn new(config: PartitionConfig) -> Self {
190 let (state_tx, state_rx) = watch::channel(PartitionState::Healthy);
191
192 Self {
193 config,
194 state: Arc::new(RwLock::new(PartitionState::Healthy)),
195 peer_health: Arc::new(DashMap::new()),
196 queued_requests: Arc::new(RwLock::new(VecDeque::new())),
197 stats: Arc::new(RwLock::new(PartitionStats::default())),
198 state_tx,
199 state_rx,
200 partition_start: Arc::new(RwLock::new(None)),
201 }
202 }
203
204 pub fn record_success(&self, peer: &SocketAddr) {
206 {
207 let mut health = self
208 .peer_health
209 .entry(*peer)
210 .or_insert_with(PeerHealth::new);
211 health.record_success();
212 } if *self.state.read() != PartitionState::Healthy {
216 self.check_recovery();
217 }
218 }
219
220 pub fn record_failure(&self, peer: &SocketAddr) {
222 {
223 let mut health = self
224 .peer_health
225 .entry(*peer)
226 .or_insert_with(PeerHealth::new);
227 health.record_failure(self.config.failure_window);
228
229 debug!("Peer {} failure count: {}", peer, health.failure_count);
230 } self.check_partition();
234 }
235
236 fn check_partition(&self) {
238 let unhealthy_count = self
239 .peer_health
240 .iter()
241 .filter(|entry| {
242 entry
243 .value()
244 .is_unhealthy(self.config.failure_threshold, self.config.peer_timeout)
245 })
246 .count();
247
248 let total_peers = self.peer_health.len();
249
250 if total_peers > 0 && unhealthy_count * 2 > total_peers {
252 let current_state = *self.state.read();
253
254 if current_state == PartitionState::Healthy {
255 self.transition_to_suspected();
256 } else if current_state == PartitionState::Suspected {
257 self.transition_to_partitioned();
258 }
259 }
260 }
261
262 fn check_recovery(&self) {
264 let healthy_count = self
265 .peer_health
266 .iter()
267 .filter(|entry| {
268 !entry
269 .value()
270 .is_unhealthy(self.config.failure_threshold, self.config.peer_timeout)
271 })
272 .count();
273
274 let total_peers = self.peer_health.len();
275
276 if total_peers > 0 && healthy_count * 2 > total_peers {
278 let current_state = *self.state.read();
279
280 if current_state == PartitionState::Partitioned {
281 self.transition_to_recovering();
282 } else if current_state == PartitionState::Recovering {
283 self.transition_to_healthy();
284 }
285 }
286 }
287
288 fn transition_to_suspected(&self) {
290 *self.state.write() = PartitionState::Suspected;
291 let _ = self.state_tx.send(PartitionState::Suspected);
292 warn!("Network partition suspected");
293 }
294
295 fn transition_to_partitioned(&self) {
297 *self.state.write() = PartitionState::Partitioned;
298 *self.partition_start.write() = Some(Instant::now());
299 let _ = self.state_tx.send(PartitionState::Partitioned);
300
301 let mut stats = self.stats.write();
302 stats.partitions_detected += 1;
303 stats.state = PartitionState::Partitioned;
304
305 warn!("Network partition detected - queueing requests");
306 }
307
308 fn transition_to_recovering(&self) {
310 *self.state.write() = PartitionState::Recovering;
311 let _ = self.state_tx.send(PartitionState::Recovering);
312 info!("Network partition recovering");
313 }
314
315 fn transition_to_healthy(&self) {
317 *self.state.write() = PartitionState::Healthy;
318 let _ = self.state_tx.send(PartitionState::Healthy);
319
320 if let Some(start) = *self.partition_start.read() {
322 let duration = start.elapsed();
323 let mut stats = self.stats.write();
324
325 stats.avg_partition_duration = Some(
326 stats
327 .avg_partition_duration
328 .map(|avg| (avg + duration) / 2)
329 .unwrap_or(duration),
330 );
331 stats.recoveries += 1;
332 stats.state = PartitionState::Healthy;
333 }
334
335 *self.partition_start.write() = None;
336
337 info!("Network partition recovered - processing queued requests");
338
339 self.flush_queue();
341 }
342
343 pub fn queue_request(&self, peer: SocketAddr, data: Vec<u8>) -> Result<(), PartitionError> {
345 let mut queue = self.queued_requests.write();
346
347 if queue.len() >= self.config.max_queued_requests {
348 self.stats.write().dropped_requests += 1;
349 return Err(PartitionError::QueueFull);
350 }
351
352 queue.push_back(QueuedRequest {
353 peer,
354 data,
355 queued_at: Instant::now(),
356 });
357
358 self.stats.write().queued_requests = queue.len();
359
360 Ok(())
361 }
362
363 fn flush_queue(&self) {
365 let requests: Vec<_> = {
366 let mut queue = self.queued_requests.write();
367 queue.drain(..).collect()
368 };
369
370 info!("Flushing {} queued requests", requests.len());
371
372 self.stats.write().queued_requests = 0;
374 }
375
376 pub fn drain_queue(&self) -> Vec<(SocketAddr, Vec<u8>)> {
378 let requests: Vec<_> = {
379 let mut queue = self.queued_requests.write();
380 queue.drain(..).collect()
381 };
382
383 self.stats.write().queued_requests = 0;
384
385 requests
386 .into_iter()
387 .map(|req| (req.peer, req.data))
388 .collect()
389 }
390
391 pub fn state(&self) -> PartitionState {
393 *self.state.read()
394 }
395
396 pub fn stats(&self) -> PartitionStats {
398 self.stats.read().clone()
399 }
400
401 pub async fn wait_state_change(&self) -> PartitionState {
403 let mut rx = self.state_rx.clone();
404 let _ = rx.changed().await;
405 let state = *rx.borrow();
406 state
407 }
408
409 pub fn is_peer_reachable(&self, peer: &SocketAddr) -> bool {
411 if let Some(health) = self.peer_health.get(peer) {
412 !health.is_unhealthy(self.config.failure_threshold, self.config.peer_timeout)
413 } else {
414 true }
416 }
417
418 pub fn unhealthy_peers(&self) -> Vec<SocketAddr> {
420 self.peer_health
421 .iter()
422 .filter(|entry| {
423 entry
424 .value()
425 .is_unhealthy(self.config.failure_threshold, self.config.peer_timeout)
426 })
427 .map(|entry| *entry.key())
428 .collect()
429 }
430
431 pub fn clear_peer_health(&self) {
433 self.peer_health.clear();
434 info!("Cleared peer health data");
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the loopback address shared by the tests.
    fn test_peer() -> SocketAddr {
        "127.0.0.1:8080".parse().unwrap()
    }

    #[test]
    fn test_peer_health() {
        let mut health = PeerHealth::new();
        let window = Duration::from_secs(10);

        health.record_failure(window);
        assert_eq!(health.failure_count, 1);

        health.record_failure(window);
        assert_eq!(health.failure_count, 2);

        // A success wipes the failure history.
        health.record_success();
        assert_eq!(health.failure_count, 0);
    }

    #[test]
    fn test_partition_detection() {
        let config = PartitionConfig {
            failure_threshold: 2,
            ..Default::default()
        };

        let detector = PartitionDetector::new(config);
        let peer = test_peer();

        assert_eq!(detector.state(), PartitionState::Healthy);

        // First failure is below the threshold, the second raises suspicion,
        // the third confirms the partition.
        detector.record_failure(&peer);
        assert_eq!(detector.state(), PartitionState::Healthy);
        detector.record_failure(&peer);
        assert_eq!(detector.state(), PartitionState::Suspected);
        detector.record_failure(&peer);
        assert_eq!(detector.state(), PartitionState::Partitioned);

        assert_eq!(detector.stats().partitions_detected, 1);
    }

    #[test]
    fn test_queue_request() {
        let detector = PartitionDetector::new(PartitionConfig::default());

        let result = detector.queue_request(test_peer(), vec![1, 2, 3]);
        assert!(result.is_ok());

        let stats = detector.stats();
        assert_eq!(stats.queued_requests, 1);
    }

    #[test]
    fn test_queue_full() {
        let config = PartitionConfig {
            max_queued_requests: 2,
            ..Default::default()
        };

        let detector = PartitionDetector::new(config);
        let peer = test_peer();

        detector.queue_request(peer, vec![1]).unwrap();
        detector.queue_request(peer, vec![2]).unwrap();

        let result = detector.queue_request(peer, vec![3]);
        assert!(matches!(result, Err(PartitionError::QueueFull)));

        // The rejected request is counted and the queue is left as it was.
        let stats = detector.stats();
        assert_eq!(stats.dropped_requests, 1);
        assert_eq!(stats.queued_requests, 2);
    }

    #[test]
    fn test_drain_queue() {
        let detector = PartitionDetector::new(PartitionConfig::default());
        let peer = test_peer();

        detector.queue_request(peer, vec![1, 2, 3]).unwrap();
        detector.queue_request(peer, vec![4, 5, 6]).unwrap();

        let drained = detector.drain_queue();
        assert_eq!(drained.len(), 2);
        // Requests come back in the order they were queued.
        assert_eq!(drained[0], (peer, vec![1, 2, 3]));
        assert_eq!(drained[1], (peer, vec![4, 5, 6]));
        assert_eq!(detector.stats().queued_requests, 0);
        assert!(detector.drain_queue().is_empty());
    }

    #[tokio::test]
    async fn test_state_transitions() {
        let config = PartitionConfig {
            failure_threshold: 1,
            ..Default::default()
        };

        let detector = PartitionDetector::new(config);
        let peer = test_peer();

        assert_eq!(detector.state(), PartitionState::Healthy);

        // With a threshold of one, a single failure on the only peer raises
        // suspicion but does not yet declare a partition.
        detector.record_failure(&peer);
        assert_eq!(detector.state(), PartitionState::Suspected);

        // A success report must never escalate the state.
        detector.record_success(&peer);

        let state = detector.state();
        assert!(
            matches!(state, PartitionState::Suspected | PartitionState::Healthy),
            "Expected one of the valid states, got: {:?}",
            state
        );
        assert_eq!(detector.stats().partitions_detected, 0);
    }
}