1use iroh::EndpointId;
19use std::collections::HashMap;
20use std::sync::{Arc, RwLock};
21use std::time::{Duration, SystemTime};
22use tracing::{debug, info, warn};
23
24#[derive(Debug, Clone, PartialEq, Eq)]
26pub enum PartitionEvent {
27 PartitionDetected {
29 peer_id: EndpointId,
30 consecutive_failures: u64,
31 },
32 PartitionHealed {
34 peer_id: EndpointId,
35 partition_duration: Duration,
36 },
37 PeerRecovered { peer_id: EndpointId },
39 HeartbeatSuccess { peer_id: EndpointId },
41 HeartbeatFailure {
43 peer_id: EndpointId,
44 consecutive_failures: u64,
45 },
46}
47
/// Connectivity state of a single peer as seen by the partition detector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeerPartitionState {
    /// Heartbeats are succeeding, or failures are still below the partition threshold.
    Connected,
    /// The peer has been declared unreachable.
    Partitioned,
    /// One heartbeat has succeeded since the partition; the next success makes the peer
    /// `Connected` again.
    Recovering,
}
58
59#[derive(Debug, Clone)]
61pub struct PeerHeartbeat {
62 pub state: PeerPartitionState,
64 pub last_heartbeat: SystemTime,
66 pub partition_detected_at: Option<SystemTime>,
68 pub consecutive_failures: u64,
70}
71
72impl PeerHeartbeat {
73 pub fn new() -> Self {
75 Self {
76 state: PeerPartitionState::Connected,
77 last_heartbeat: SystemTime::now(),
78 partition_detected_at: None,
79 consecutive_failures: 0,
80 }
81 }
82
83 pub fn record_success(&mut self, peer_id: EndpointId) -> Option<PartitionEvent> {
87 let now = SystemTime::now();
88 let was_partitioned = self.state == PeerPartitionState::Partitioned;
89
90 self.last_heartbeat = now;
91 self.consecutive_failures = 0;
92
93 if was_partitioned {
94 let partition_duration = self
96 .partition_detected_at
97 .and_then(|detected_at| now.duration_since(detected_at).ok())
98 .unwrap_or(Duration::from_secs(0));
99
100 self.state = PeerPartitionState::Recovering;
101 self.partition_detected_at = None;
102
103 info!(
104 peer_id = ?peer_id,
105 partition_duration_secs = partition_duration.as_secs(),
106 "Partition healed - peer recovering"
107 );
108
109 return Some(PartitionEvent::PartitionHealed {
110 peer_id,
111 partition_duration,
112 });
113 } else if self.state == PeerPartitionState::Recovering {
114 self.state = PeerPartitionState::Connected;
116
117 info!(peer_id = ?peer_id, "Peer fully recovered");
118
119 return Some(PartitionEvent::PeerRecovered { peer_id });
120 }
121
122 debug!(peer_id = ?peer_id, "Heartbeat success");
123 Some(PartitionEvent::HeartbeatSuccess { peer_id })
124 }
125
126 pub fn record_failure(
130 &mut self,
131 peer_id: EndpointId,
132 timeout_threshold: u64,
133 ) -> Option<PartitionEvent> {
134 self.consecutive_failures += 1;
135
136 if self.consecutive_failures >= timeout_threshold
138 && self.state != PeerPartitionState::Partitioned
139 {
140 self.state = PeerPartitionState::Partitioned;
141 self.partition_detected_at = Some(SystemTime::now());
142
143 warn!(
144 peer_id = ?peer_id,
145 consecutive_failures = self.consecutive_failures,
146 "Partition detected"
147 );
148
149 return Some(PartitionEvent::PartitionDetected {
150 peer_id,
151 consecutive_failures: self.consecutive_failures,
152 });
153 }
154
155 if self.state == PeerPartitionState::Connected {
157 debug!(
158 peer_id = ?peer_id,
159 consecutive_failures = self.consecutive_failures,
160 threshold = timeout_threshold,
161 "Heartbeat failure"
162 );
163
164 return Some(PartitionEvent::HeartbeatFailure {
165 peer_id,
166 consecutive_failures: self.consecutive_failures,
167 });
168 }
169
170 None
171 }
172
173 pub fn is_timeout(&self, timeout: Duration) -> bool {
175 SystemTime::now()
176 .duration_since(self.last_heartbeat)
177 .map(|elapsed| elapsed > timeout)
178 .unwrap_or(false)
179 }
180
181 pub fn partition_duration(&self) -> Option<Duration> {
183 self.partition_detected_at
184 .and_then(|detected_at| SystemTime::now().duration_since(detected_at).ok())
185 }
186}
187
188impl Default for PeerHeartbeat {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
/// Tuning knobs for [`PartitionDetector`].
#[derive(Debug, Clone)]
pub struct PartitionConfig {
    /// Intended spacing between heartbeats; used here to derive the default timeout.
    pub heartbeat_interval: Duration,
    /// A peer whose last successful heartbeat is older than this is marked partitioned by
    /// `check_timeouts`.
    pub heartbeat_timeout: Duration,
    /// Consecutive heartbeat failures at which a peer is marked partitioned.
    pub failure_threshold: u64,
}

impl Default for PartitionConfig {
    /// A 5 s heartbeat interval, a timeout of three intervals (15 s), and a threshold of
    /// three consecutive failures.
    fn default() -> Self {
        const INTERVAL: Duration = Duration::from_secs(5);
        const TIMEOUT_IN_INTERVALS: u32 = 3;

        PartitionConfig {
            heartbeat_interval: INTERVAL,
            heartbeat_timeout: INTERVAL * TIMEOUT_IN_INTERVALS,
            failure_threshold: 3,
        }
    }
}
215
216pub struct PartitionDetector {
220 heartbeats: Arc<RwLock<HashMap<EndpointId, PeerHeartbeat>>>,
222 config: PartitionConfig,
224}
225
226impl PartitionDetector {
227 pub fn new() -> Self {
229 Self::with_config(PartitionConfig::default())
230 }
231
232 pub fn with_config(config: PartitionConfig) -> Self {
234 Self {
235 heartbeats: Arc::new(RwLock::new(HashMap::new())),
236 config,
237 }
238 }
239
240 pub fn config(&self) -> &PartitionConfig {
242 &self.config
243 }
244
245 pub fn register_peer(&self, peer_id: EndpointId) {
247 self.heartbeats
248 .write()
249 .unwrap_or_else(|e| e.into_inner())
250 .entry(peer_id)
251 .or_default();
252 }
253
254 pub fn unregister_peer(&self, peer_id: &EndpointId) {
256 self.heartbeats
257 .write()
258 .unwrap_or_else(|e| e.into_inner())
259 .remove(peer_id);
260 }
261
262 pub fn record_heartbeat_success(&self, peer_id: &EndpointId) -> Option<PartitionEvent> {
266 self.heartbeats
267 .write()
268 .unwrap()
269 .get_mut(peer_id)
270 .and_then(|hb| hb.record_success(*peer_id))
271 }
272
273 pub fn record_heartbeat_failure(&self, peer_id: &EndpointId) -> Option<PartitionEvent> {
277 self.heartbeats
278 .write()
279 .unwrap()
280 .get_mut(peer_id)
281 .and_then(|hb| hb.record_failure(*peer_id, self.config.failure_threshold))
282 }
283
284 pub fn get_peer_state(&self, peer_id: &EndpointId) -> Option<PeerPartitionState> {
286 self.heartbeats
287 .read()
288 .unwrap()
289 .get(peer_id)
290 .map(|hb| hb.state)
291 }
292
293 pub fn get_peer_heartbeat(&self, peer_id: &EndpointId) -> Option<PeerHeartbeat> {
295 self.heartbeats
296 .read()
297 .unwrap_or_else(|e| e.into_inner())
298 .get(peer_id)
299 .cloned()
300 }
301
302 pub fn get_partitioned_peers(&self) -> Vec<EndpointId> {
304 self.heartbeats
305 .read()
306 .unwrap()
307 .iter()
308 .filter(|(_, hb)| hb.state == PeerPartitionState::Partitioned)
309 .map(|(peer_id, _)| *peer_id)
310 .collect()
311 }
312
313 pub fn check_timeouts(&self) -> Vec<PartitionEvent> {
317 let mut events = Vec::new();
318
319 let mut heartbeats = self.heartbeats.write().unwrap_or_else(|e| e.into_inner());
320 for (peer_id, hb) in heartbeats.iter_mut() {
321 if hb.state != PeerPartitionState::Partitioned
322 && hb.is_timeout(self.config.heartbeat_timeout)
323 {
324 hb.state = PeerPartitionState::Partitioned;
325 hb.partition_detected_at = Some(SystemTime::now());
326
327 warn!(
328 peer_id = ?peer_id,
329 timeout_secs = self.config.heartbeat_timeout.as_secs(),
330 "Partition detected via timeout"
331 );
332
333 events.push(PartitionEvent::PartitionDetected {
334 peer_id: *peer_id,
335 consecutive_failures: hb.consecutive_failures,
336 });
337 }
338 }
339
340 events
341 }
342
343 pub fn peer_count(&self) -> usize {
345 self.heartbeats
346 .read()
347 .unwrap_or_else(|e| e.into_inner())
348 .len()
349 }
350}
351
352impl Default for PartitionDetector {
353 fn default() -> Self {
354 Self::new()
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a fresh random peer identifier for a test.
    fn random_peer_id() -> EndpointId {
        let mut rng = rand::rng();
        iroh::SecretKey::generate(&mut rng).public()
    }

    /// A wall-clock time far enough in the past to exceed the default 15 s timeout.
    fn two_minutes_ago() -> SystemTime {
        SystemTime::now()
            .checked_sub(Duration::from_secs(120))
            .expect("system clock is more than two minutes past the epoch")
    }

    #[test]
    fn test_peer_heartbeat_success_resets_failures() {
        let mut hb = PeerHeartbeat::new();
        hb.consecutive_failures = 2;
        let peer_id = random_peer_id();

        let event = hb.record_success(peer_id);

        assert_eq!(hb.consecutive_failures, 0);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert_eq!(event, Some(PartitionEvent::HeartbeatSuccess { peer_id }));
    }

    #[test]
    fn test_peer_heartbeat_partition_detection() {
        let mut hb = PeerHeartbeat::new();
        let peer_id = random_peer_id();

        let event1 = hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert!(matches!(
            event1,
            Some(PartitionEvent::HeartbeatFailure { .. })
        ));

        let event2 = hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert!(matches!(
            event2,
            Some(PartitionEvent::HeartbeatFailure { .. })
        ));

        let event3 = hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Partitioned);
        assert!(hb.partition_detected_at.is_some());
        assert!(matches!(
            event3,
            Some(PartitionEvent::PartitionDetected { .. })
        ));
    }

    #[test]
    fn test_peer_heartbeat_failure_after_partition_is_silent() {
        let mut hb = PeerHeartbeat::new();
        let peer_id = random_peer_id();
        for _ in 0..3 {
            hb.record_failure(peer_id, 3);
        }
        assert_eq!(hb.state, PeerPartitionState::Partitioned);

        // Further failures keep counting but do not re-announce the partition.
        assert_eq!(hb.record_failure(peer_id, 3), None);
        assert_eq!(hb.consecutive_failures, 4);
        assert_eq!(hb.state, PeerPartitionState::Partitioned);
    }

    #[test]
    fn test_peer_heartbeat_recovery() {
        let mut hb = PeerHeartbeat::new();
        let peer_id = random_peer_id();

        hb.record_failure(peer_id, 3);
        hb.record_failure(peer_id, 3);
        hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Partitioned);

        let event1 = hb.record_success(peer_id);
        assert_eq!(hb.state, PeerPartitionState::Recovering);
        assert!(hb.partition_detected_at.is_none());
        assert!(matches!(
            event1,
            Some(PartitionEvent::PartitionHealed { .. })
        ));

        let event2 = hb.record_success(peer_id);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert!(matches!(event2, Some(PartitionEvent::PeerRecovered { .. })));
    }

    #[test]
    fn test_peer_heartbeat_timeout() {
        let mut hb = PeerHeartbeat::new();
        assert!(!hb.is_timeout(Duration::from_secs(60)));
        assert_eq!(hb.partition_duration(), None);

        hb.last_heartbeat = two_minutes_ago();
        assert!(hb.is_timeout(Duration::from_secs(60)));
    }

    #[test]
    fn test_partition_config_defaults() {
        let config = PartitionConfig::default();
        assert_eq!(config.heartbeat_interval, Duration::from_secs(5));
        assert_eq!(config.heartbeat_timeout, Duration::from_secs(15));
        assert_eq!(config.failure_threshold, 3);
    }

    #[test]
    fn test_partition_detector_creation() {
        let detector = PartitionDetector::new();
        assert_eq!(detector.peer_count(), 0);
        assert_eq!(detector.config().heartbeat_interval, Duration::from_secs(5));
    }

    #[test]
    fn test_detector_ignores_unregistered_peers() {
        let detector = PartitionDetector::new();
        let peer_id = random_peer_id();

        assert_eq!(detector.record_heartbeat_success(&peer_id), None);
        assert_eq!(detector.record_heartbeat_failure(&peer_id), None);
        assert_eq!(detector.get_peer_state(&peer_id), None);
        assert!(detector.get_peer_heartbeat(&peer_id).is_none());
        assert_eq!(detector.peer_count(), 0);
    }

    #[test]
    fn test_detector_register_is_idempotent() {
        let detector = PartitionDetector::new();
        let peer_id = random_peer_id();

        detector.register_peer(peer_id);
        detector.record_heartbeat_failure(&peer_id);
        detector.register_peer(peer_id);

        let hb = detector
            .get_peer_heartbeat(&peer_id)
            .expect("peer was registered");
        assert_eq!(hb.consecutive_failures, 1);
        assert_eq!(detector.peer_count(), 1);
    }

    #[test]
    fn test_detector_partition_lifecycle() {
        let detector = PartitionDetector::new();
        let peer_id = random_peer_id();

        detector.register_peer(peer_id);
        assert_eq!(detector.peer_count(), 1);
        assert_eq!(
            detector.get_peer_state(&peer_id),
            Some(PeerPartitionState::Connected)
        );

        let threshold = detector.config().failure_threshold;
        for _ in 1..threshold {
            assert!(matches!(
                detector.record_heartbeat_failure(&peer_id),
                Some(PartitionEvent::HeartbeatFailure { .. })
            ));
        }
        assert_eq!(
            detector.record_heartbeat_failure(&peer_id),
            Some(PartitionEvent::PartitionDetected {
                peer_id,
                consecutive_failures: threshold,
            })
        );
        assert_eq!(detector.get_partitioned_peers(), vec![peer_id]);

        assert!(matches!(
            detector.record_heartbeat_success(&peer_id),
            Some(PartitionEvent::PartitionHealed { .. })
        ));
        assert_eq!(
            detector.get_peer_state(&peer_id),
            Some(PeerPartitionState::Recovering)
        );
        assert!(detector.get_partitioned_peers().is_empty());

        detector.unregister_peer(&peer_id);
        assert_eq!(detector.peer_count(), 0);
        assert_eq!(detector.get_peer_state(&peer_id), None);
    }

    #[test]
    fn test_detector_check_timeouts_partitions_stale_peers_once() {
        let detector = PartitionDetector::new();
        let stale = random_peer_id();
        let fresh = random_peer_id();
        detector.register_peer(stale);
        detector.register_peer(fresh);

        // Backdate `stale` well past the default heartbeat timeout.
        let backdated = two_minutes_ago();
        {
            let mut heartbeats = detector.heartbeats.write().unwrap();
            heartbeats
                .get_mut(&stale)
                .expect("stale peer was registered")
                .last_heartbeat = backdated;
        }

        let events = detector.check_timeouts();
        assert_eq!(
            events,
            vec![PartitionEvent::PartitionDetected {
                peer_id: stale,
                consecutive_failures: 0,
            }]
        );
        assert_eq!(
            detector.get_peer_state(&stale),
            Some(PeerPartitionState::Partitioned)
        );
        assert_eq!(
            detector.get_peer_state(&fresh),
            Some(PeerPartitionState::Connected)
        );
        assert!(detector
            .get_peer_heartbeat(&stale)
            .and_then(|hb| hb.partition_detected_at)
            .is_some());

        // Peers that are already partitioned are not reported again.
        assert!(detector.check_timeouts().is_empty());
    }
}