use iroh::EndpointId;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
use tracing::{debug, info, warn};
/// Notification produced when a peer's heartbeat bookkeeping is updated.
///
/// Returned by [`PeerHeartbeat::record_success`], [`PeerHeartbeat::record_failure`]
/// and [`PartitionDetector::check_timeouts`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartitionEvent {
    /// The peer has just entered [`PeerPartitionState::Partitioned`], either by
    /// reaching the consecutive-failure threshold or by heartbeat timeout.
    PartitionDetected {
        peer_id: EndpointId,
        /// Failure count at the moment of detection. When the partition was
        /// detected by timeout this can be below the threshold (even zero).
        consecutive_failures: u64,
    },
    /// A successful heartbeat arrived while the peer was partitioned; the peer
    /// moves to [`PeerPartitionState::Recovering`].
    PartitionHealed {
        peer_id: EndpointId,
        /// Time from partition detection to this heartbeat. Zero if the detection
        /// time was unknown or the wall clock moved backwards.
        partition_duration: Duration,
    },
    /// A successful heartbeat arrived while recovering; the peer is connected again.
    PeerRecovered { peer_id: EndpointId },
    /// A routine successful heartbeat from a connected peer.
    HeartbeatSuccess { peer_id: EndpointId },
    /// A failed heartbeat from a connected peer that has not yet hit the threshold.
    HeartbeatFailure {
        peer_id: EndpointId,
        /// Consecutive failures so far, including this one.
        consecutive_failures: u64,
    },
}
/// Connectivity state tracked per peer by [`PeerHeartbeat`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeerPartitionState {
    /// Heartbeats are succeeding, or failures are still below the threshold.
    Connected,
    /// The peer is considered unreachable: the failure threshold was reached or
    /// no successful heartbeat arrived within the timeout.
    Partitioned,
    /// One successful heartbeat has arrived after a partition; the next success
    /// returns the peer to `Connected`.
    Recovering,
}
/// Per-peer heartbeat bookkeeping that drives the [`PeerPartitionState`] machine.
#[derive(Debug, Clone)]
pub struct PeerHeartbeat {
    /// Current connectivity state.
    pub state: PeerPartitionState,
    /// Wall-clock time of the last successful heartbeat (creation time initially).
    pub last_heartbeat: SystemTime,
    /// Set when a partition is detected, cleared when it heals.
    pub partition_detected_at: Option<SystemTime>,
    /// Heartbeat failures since the last success.
    pub consecutive_failures: u64,
}
impl PeerHeartbeat {
    /// Builds the initial record for a newly tracked peer.
    ///
    /// The peer starts [`PeerPartitionState::Connected`] with no failures, and the
    /// last-heartbeat timestamp is the construction time, so the clock checked by
    /// [`PeerHeartbeat::is_timeout`] starts now.
    pub fn new() -> Self {
        let created_at = SystemTime::now();
        Self {
            consecutive_failures: 0,
            partition_detected_at: None,
            last_heartbeat: created_at,
            state: PeerPartitionState::Connected,
        }
    }

    /// Records a successful heartbeat and advances the state machine.
    ///
    /// Always refreshes `last_heartbeat` and clears `consecutive_failures`, then:
    /// - `Partitioned` → `Recovering`, returning [`PartitionEvent::PartitionHealed`]
    ///   with the time since detection (zero if unknown or the clock went backwards);
    /// - `Recovering` → `Connected`, returning [`PartitionEvent::PeerRecovered`];
    /// - `Connected` stays put, returning [`PartitionEvent::HeartbeatSuccess`].
    ///
    /// The result is currently always `Some`.
    pub fn record_success(&mut self, peer_id: EndpointId) -> Option<PartitionEvent> {
        let now = SystemTime::now();
        let prior = self.state;
        self.last_heartbeat = now;
        self.consecutive_failures = 0;

        let event = match prior {
            PeerPartitionState::Partitioned => {
                // `take()` reads the detection time and clears it in one step.
                let partition_duration = self
                    .partition_detected_at
                    .take()
                    .and_then(|since| now.duration_since(since).ok())
                    .unwrap_or_default();
                self.state = PeerPartitionState::Recovering;
                info!(
                    peer_id = ?peer_id,
                    partition_duration_secs = partition_duration.as_secs(),
                    "Partition healed - peer recovering"
                );
                PartitionEvent::PartitionHealed {
                    peer_id,
                    partition_duration,
                }
            }
            PeerPartitionState::Recovering => {
                self.state = PeerPartitionState::Connected;
                info!(peer_id = ?peer_id, "Peer fully recovered");
                PartitionEvent::PeerRecovered { peer_id }
            }
            PeerPartitionState::Connected => {
                debug!(peer_id = ?peer_id, "Heartbeat success");
                PartitionEvent::HeartbeatSuccess { peer_id }
            }
        };
        Some(event)
    }

    /// Records a failed heartbeat and advances the state machine.
    ///
    /// Despite its name, `timeout_threshold` is a count of consecutive failures,
    /// not a duration. After the counter is incremented:
    /// - a peer not already `Partitioned` that has reached the threshold becomes
    ///   `Partitioned` and [`PartitionEvent::PartitionDetected`] is returned;
    /// - a `Connected` peer below the threshold yields
    ///   [`PartitionEvent::HeartbeatFailure`];
    /// - otherwise (already `Partitioned`, or `Recovering` below the threshold) the
    ///   state is unchanged and `None` is returned.
    pub fn record_failure(
        &mut self,
        peer_id: EndpointId,
        timeout_threshold: u64,
    ) -> Option<PartitionEvent> {
        self.consecutive_failures += 1;
        let failures = self.consecutive_failures;
        let threshold_reached = failures >= timeout_threshold;
        let current = self.state;

        match current {
            // Already partitioned: keep counting, but do not re-announce.
            PeerPartitionState::Partitioned => None,
            _ if threshold_reached => {
                self.state = PeerPartitionState::Partitioned;
                self.partition_detected_at = Some(SystemTime::now());
                warn!(
                    peer_id = ?peer_id,
                    consecutive_failures = failures,
                    "Partition detected"
                );
                Some(PartitionEvent::PartitionDetected {
                    peer_id,
                    consecutive_failures: failures,
                })
            }
            PeerPartitionState::Connected => {
                debug!(
                    peer_id = ?peer_id,
                    consecutive_failures = failures,
                    threshold = timeout_threshold,
                    "Heartbeat failure"
                );
                Some(PartitionEvent::HeartbeatFailure {
                    peer_id,
                    consecutive_failures: failures,
                })
            }
            // Failures below the threshold while recovering are not reported.
            PeerPartitionState::Recovering => None,
        }
    }

    /// Returns `true` if more than `timeout` has elapsed since the last successful
    /// heartbeat. Returns `false` if `last_heartbeat` lies in the future, for
    /// example after the wall clock moved backwards.
    pub fn is_timeout(&self, timeout: Duration) -> bool {
        match SystemTime::now().duration_since(self.last_heartbeat) {
            Ok(elapsed) => elapsed > timeout,
            Err(_) => false,
        }
    }

    /// Returns how long the current partition has lasted so far.
    ///
    /// Returns `None` when no partition is recorded, or when the detection time
    /// lies in the future.
    pub fn partition_duration(&self) -> Option<Duration> {
        let detected_at = self.partition_detected_at?;
        SystemTime::now().duration_since(detected_at).ok()
    }
}
impl Default for PeerHeartbeat {
fn default() -> Self {
Self::new()
}
}
/// Tuning knobs for [`PartitionDetector`].
#[derive(Debug, Clone)]
pub struct PartitionConfig {
    /// Nominal spacing between heartbeats. The detector does not read this field
    /// itself; presumably it is used by whatever schedules heartbeats — confirm.
    pub heartbeat_interval: Duration,
    /// Maximum age of the last successful heartbeat before
    /// [`PartitionDetector::check_timeouts`] marks a peer partitioned.
    pub heartbeat_timeout: Duration,
    /// Consecutive failures at which [`PartitionDetector::record_heartbeat_failure`]
    /// marks a peer partitioned.
    pub failure_threshold: u64,
}

impl Default for PartitionConfig {
    /// Defaults: a 5s interval, a timeout of three missed intervals (15s), and a
    /// threshold of 3 consecutive failures.
    fn default() -> Self {
        const DEFAULT_INTERVAL_SECS: u64 = 5;
        const MISSED_INTERVALS_BEFORE_TIMEOUT: u32 = 3;

        let interval = Duration::from_secs(DEFAULT_INTERVAL_SECS);
        PartitionConfig {
            heartbeat_timeout: interval * MISSED_INTERVALS_BEFORE_TIMEOUT,
            failure_threshold: 3,
            heartbeat_interval: interval,
        }
    }
}
/// Registry of per-peer [`PeerHeartbeat`]s, with partition detection driven by a
/// [`PartitionConfig`].
///
/// The heartbeat map sits behind an `RwLock`, so every method takes `&self`.
pub struct PartitionDetector {
    /// Heartbeat state keyed by peer id.
    heartbeats: Arc<RwLock<HashMap<EndpointId, PeerHeartbeat>>>,
    /// Thresholds used by failure recording and timeout checks.
    config: PartitionConfig,
}
impl PartitionDetector {
    /// Creates a detector with [`PartitionConfig::default`] settings.
    pub fn new() -> Self {
        Self::with_config(PartitionConfig::default())
    }

    /// Creates a detector with no registered peers and the given settings.
    pub fn with_config(config: PartitionConfig) -> Self {
        Self {
            heartbeats: Arc::new(RwLock::new(HashMap::new())),
            config,
        }
    }

    /// Returns the settings this detector was built with.
    pub fn config(&self) -> &PartitionConfig {
        &self.config
    }

    /// Takes the shared read lock on the heartbeat map.
    ///
    /// A poisoned lock (a previous holder panicked) is recovered instead of
    /// propagated. Otherwise one panic would make every later call panic too.
    fn read_heartbeats(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<EndpointId, PeerHeartbeat>> {
        self.heartbeats.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Takes the exclusive write lock on the heartbeat map, recovering from
    /// poisoning in the same way as [`Self::read_heartbeats`].
    fn write_heartbeats(
        &self,
    ) -> std::sync::RwLockWriteGuard<'_, HashMap<EndpointId, PeerHeartbeat>> {
        self.heartbeats.write().unwrap_or_else(|e| e.into_inner())
    }

    /// Starts tracking `peer_id`. If the peer is already registered, its existing
    /// heartbeat state is left untouched.
    pub fn register_peer(&self, peer_id: EndpointId) {
        self.write_heartbeats().entry(peer_id).or_default();
    }

    /// Stops tracking `peer_id`; a no-op for unknown peers.
    pub fn unregister_peer(&self, peer_id: &EndpointId) {
        self.write_heartbeats().remove(peer_id);
    }

    /// Forwards a successful heartbeat to the peer's [`PeerHeartbeat::record_success`].
    ///
    /// Returns `None` if the peer is not registered.
    pub fn record_heartbeat_success(&self, peer_id: &EndpointId) -> Option<PartitionEvent> {
        self.write_heartbeats()
            .get_mut(peer_id)
            .and_then(|hb| hb.record_success(*peer_id))
    }

    /// Forwards a failed heartbeat to the peer's [`PeerHeartbeat::record_failure`],
    /// using `config.failure_threshold` as the partition threshold.
    ///
    /// Returns `None` if the peer is not registered, or if the failure produced no event.
    pub fn record_heartbeat_failure(&self, peer_id: &EndpointId) -> Option<PartitionEvent> {
        let threshold = self.config.failure_threshold;
        self.write_heartbeats()
            .get_mut(peer_id)
            .and_then(|hb| hb.record_failure(*peer_id, threshold))
    }

    /// Returns the current state of `peer_id`, or `None` if it is not registered.
    pub fn get_peer_state(&self, peer_id: &EndpointId) -> Option<PeerPartitionState> {
        self.read_heartbeats().get(peer_id).map(|hb| hb.state)
    }

    /// Returns a snapshot (clone) of the peer's heartbeat record, or `None` if it
    /// is not registered.
    pub fn get_peer_heartbeat(&self, peer_id: &EndpointId) -> Option<PeerHeartbeat> {
        self.read_heartbeats().get(peer_id).cloned()
    }

    /// Returns every registered peer currently in [`PeerPartitionState::Partitioned`].
    ///
    /// The order is unspecified (`HashMap` iteration order).
    pub fn get_partitioned_peers(&self) -> Vec<EndpointId> {
        self.read_heartbeats()
            .iter()
            .filter(|(_, hb)| hb.state == PeerPartitionState::Partitioned)
            .map(|(peer_id, _)| *peer_id)
            .collect()
    }

    /// Marks stale peers as partitioned.
    ///
    /// A peer counts as stale when it is not already partitioned and its last
    /// successful heartbeat is older than `config.heartbeat_timeout`. Returns one
    /// [`PartitionEvent::PartitionDetected`] per newly partitioned peer, in
    /// unspecified order.
    ///
    /// The reported `consecutive_failures` is the peer's current count. It is not
    /// reset and may be below `failure_threshold`.
    pub fn check_timeouts(&self) -> Vec<PartitionEvent> {
        let timeout = self.config.heartbeat_timeout;
        let mut heartbeats = self.write_heartbeats();
        let mut events = Vec::new();
        for (peer_id, hb) in heartbeats.iter_mut() {
            if hb.state == PeerPartitionState::Partitioned || !hb.is_timeout(timeout) {
                continue;
            }
            hb.state = PeerPartitionState::Partitioned;
            hb.partition_detected_at = Some(SystemTime::now());
            warn!(
                peer_id = ?peer_id,
                timeout_secs = timeout.as_secs(),
                "Partition detected via timeout"
            );
            events.push(PartitionEvent::PartitionDetected {
                peer_id: *peer_id,
                consecutive_failures: hb.consecutive_failures,
            });
        }
        events
    }

    /// Returns the number of registered peers.
    pub fn peer_count(&self) -> usize {
        self.read_heartbeats().len()
    }
}
impl Default for PartitionDetector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_peer_heartbeat_success_resets_failures() {
        let mut hb = PeerHeartbeat::new();
        hb.consecutive_failures = 2;
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        let event = hb.record_success(peer_id);
        assert_eq!(hb.consecutive_failures, 0);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert_eq!(event, Some(PartitionEvent::HeartbeatSuccess { peer_id }));
    }

    #[test]
    fn test_peer_heartbeat_partition_detection() {
        let mut hb = PeerHeartbeat::new();
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        let event1 = hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert!(matches!(
            event1,
            Some(PartitionEvent::HeartbeatFailure { .. })
        ));
        let event2 = hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert!(matches!(
            event2,
            Some(PartitionEvent::HeartbeatFailure { .. })
        ));
        let event3 = hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Partitioned);
        assert!(hb.partition_detected_at.is_some());
        assert!(matches!(
            event3,
            Some(PartitionEvent::PartitionDetected { .. })
        ));
    }

    #[test]
    fn test_peer_heartbeat_recovery() {
        let mut hb = PeerHeartbeat::new();
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        hb.record_failure(peer_id, 3);
        hb.record_failure(peer_id, 3);
        hb.record_failure(peer_id, 3);
        assert_eq!(hb.state, PeerPartitionState::Partitioned);
        let event1 = hb.record_success(peer_id);
        assert_eq!(hb.state, PeerPartitionState::Recovering);
        assert!(hb.partition_detected_at.is_none());
        assert!(matches!(
            event1,
            Some(PartitionEvent::PartitionHealed { .. })
        ));
        let event2 = hb.record_success(peer_id);
        assert_eq!(hb.state, PeerPartitionState::Connected);
        assert!(matches!(event2, Some(PartitionEvent::PeerRecovered { .. })));
    }

    /// Failures on an already-partitioned peer keep counting but emit no event.
    #[test]
    fn test_peer_heartbeat_partitioned_failures_are_silent() {
        let mut hb = PeerHeartbeat::new();
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        for _ in 0..3 {
            hb.record_failure(peer_id, 3);
        }
        assert_eq!(hb.record_failure(peer_id, 3), None);
        assert_eq!(hb.consecutive_failures, 4);
        assert_eq!(hb.state, PeerPartitionState::Partitioned);
    }

    /// A failure below the threshold while recovering emits nothing and keeps the state.
    #[test]
    fn test_peer_heartbeat_recovering_failure_is_silent() {
        let mut hb = PeerHeartbeat::new();
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        for _ in 0..3 {
            hb.record_failure(peer_id, 3);
        }
        hb.record_success(peer_id);
        assert_eq!(hb.state, PeerPartitionState::Recovering);
        assert_eq!(hb.record_failure(peer_id, 3), None);
        assert_eq!(hb.state, PeerPartitionState::Recovering);
        assert_eq!(hb.consecutive_failures, 1);
    }

    #[test]
    fn test_partition_config_defaults() {
        let config = PartitionConfig::default();
        assert_eq!(config.heartbeat_interval, Duration::from_secs(5));
        assert_eq!(config.heartbeat_timeout, Duration::from_secs(15));
        assert_eq!(config.failure_threshold, 3);
    }

    #[test]
    fn test_partition_detector_creation() {
        let detector = PartitionDetector::new();
        assert_eq!(detector.peer_count(), 0);
        assert_eq!(detector.config().heartbeat_interval, Duration::from_secs(5));
    }

    #[test]
    fn test_detector_ignores_unregistered_peers() {
        let detector = PartitionDetector::new();
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        assert_eq!(detector.record_heartbeat_success(&peer_id), None);
        assert_eq!(detector.record_heartbeat_failure(&peer_id), None);
        assert_eq!(detector.get_peer_state(&peer_id), None);
        assert!(detector.get_peer_heartbeat(&peer_id).is_none());
        assert_eq!(detector.peer_count(), 0);
    }

    #[test]
    fn test_detector_register_is_idempotent_and_unregister_removes() {
        let detector = PartitionDetector::new();
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        detector.register_peer(peer_id);
        detector.record_heartbeat_failure(&peer_id);
        // Re-registering must not reset the existing record.
        detector.register_peer(peer_id);
        assert_eq!(detector.peer_count(), 1);
        assert_eq!(
            detector
                .get_peer_heartbeat(&peer_id)
                .map(|hb| hb.consecutive_failures),
            Some(1)
        );
        detector.unregister_peer(&peer_id);
        assert_eq!(detector.peer_count(), 0);
    }

    #[test]
    fn test_detector_failure_threshold_and_heal() {
        let detector = PartitionDetector::new();
        let peer_id = iroh::SecretKey::generate(&mut rand::rng()).public();
        detector.register_peer(peer_id);
        for _ in 0..2 {
            assert!(matches!(
                detector.record_heartbeat_failure(&peer_id),
                Some(PartitionEvent::HeartbeatFailure { .. })
            ));
        }
        assert_eq!(
            detector.record_heartbeat_failure(&peer_id),
            Some(PartitionEvent::PartitionDetected {
                peer_id,
                consecutive_failures: 3,
            })
        );
        assert_eq!(detector.get_partitioned_peers(), vec![peer_id]);
        assert!(matches!(
            detector.record_heartbeat_success(&peer_id),
            Some(PartitionEvent::PartitionHealed { .. })
        ));
        assert_eq!(
            detector.get_peer_state(&peer_id),
            Some(PeerPartitionState::Recovering)
        );
        assert!(detector.get_partitioned_peers().is_empty());
    }

    #[test]
    fn test_check_timeouts_partitions_stale_peers_once() {
        let detector = PartitionDetector::new();
        let stale = iroh::SecretKey::generate(&mut rand::rng()).public();
        let fresh = iroh::SecretKey::generate(&mut rand::rng()).public();
        detector.register_peer(stale);
        detector.register_peer(fresh);
        {
            // Backdate the stale peer well past the 15s default timeout.
            let mut heartbeats = detector.heartbeats.write().unwrap();
            let hb = heartbeats.get_mut(&stale).expect("stale peer was registered");
            hb.last_heartbeat = SystemTime::now() - Duration::from_secs(60);
        }
        assert_eq!(
            detector.check_timeouts(),
            vec![PartitionEvent::PartitionDetected {
                peer_id: stale,
                consecutive_failures: 0,
            }]
        );
        assert_eq!(
            detector.get_peer_state(&stale),
            Some(PeerPartitionState::Partitioned)
        );
        assert_eq!(
            detector.get_peer_state(&fresh),
            Some(PeerPartitionState::Connected)
        );
        // Already-partitioned peers are not reported again.
        assert!(detector.check_timeouts().is_empty());
    }
}