use std::collections::{HashMap, HashSet};
use std::time::Duration;
#[cfg(not(test))]
use std::time::Instant;
#[cfg(test)]
use mock_instant::Instant;
use serde::{Deserialize, Serialize};
use tracing::debug;
use crate::NodeId;
/// Tracks heartbeat statistics per node and classifies nodes as live or dead.
///
/// Heartbeats are fed in through `report_heartbeat`; the classification of a
/// node only changes when `update_node_liveliness` is called for it.
pub struct FailureDetector {
/// Heartbeat sampling window of every node that has reported a heartbeat and
/// has not been declared dead since.
node_samples: HashMap<NodeId, SamplingWindow>,
/// Threshold and window parameters used by the detector.
config: FailureDetectorConfig,
/// Nodes whose phi value was at or below the threshold at their last update.
live_nodes: HashSet<NodeId>,
/// Nodes declared dead, mapped to the instant at which they were declared dead.
dead_nodes: HashMap<NodeId, Instant>,
}
impl FailureDetector {
    /// Creates a failure detector that does not know about any node yet.
    pub fn new(config: FailureDetectorConfig) -> Self {
        Self {
            node_samples: HashMap::new(),
            config,
            live_nodes: HashSet::new(),
            dead_nodes: HashMap::new(),
        }
    }

    /// Records a heartbeat for `node_id`.
    ///
    /// A sampling window is created for the node the first time it is seen (or
    /// the first time it is seen again after having been declared dead).
    pub fn report_heartbeat(&mut self, node_id: &NodeId) {
        debug!(node_id = ?node_id, "reporting node heartbeat.");
        // Look the window up by reference first: heartbeats from already-tracked
        // nodes are the common case and should not pay for cloning the node id.
        if let Some(heartbeat_window) = self.node_samples.get_mut(node_id) {
            heartbeat_window.report_heartbeat();
            return;
        }
        let mut heartbeat_window = SamplingWindow::new(
            self.config.sampling_window_size,
            self.config.max_interval,
            self.config.initial_interval,
        );
        heartbeat_window.report_heartbeat();
        self.node_samples.insert(node_id.clone(), heartbeat_window);
    }

    /// Re-evaluates whether `node_id` is live or dead from its current phi value.
    ///
    /// Nodes without a sampling window (never seen, or already declared dead and
    /// silent since) are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the node's sampling window has a mean interval that is not
    /// strictly positive (see `SamplingWindow::phi`).
    pub fn update_node_liveliness(&mut self, node_id: &NodeId) {
        if let Some(phi) = self.phi(node_id) {
            debug!(node_id = ?node_id, phi = phi, "updating node liveliness");
            if phi > self.config.phi_threshold {
                self.live_nodes.remove(node_id);
                self.dead_nodes.insert(node_id.clone(), Instant::now());
                // Drop the statistics so that a node coming back starts from a
                // fresh sampling window.
                self.node_samples.remove(node_id);
            } else {
                // Only clone the id when the node is not already known as live.
                if !self.live_nodes.contains(node_id) {
                    self.live_nodes.insert(node_id.clone());
                }
                self.dead_nodes.remove(node_id);
            }
        }
    }

    /// Forgets the dead nodes whose grace period has elapsed and returns them.
    ///
    /// The order of the returned nodes is unspecified.
    pub fn garbage_collect(&mut self) -> Vec<NodeId> {
        let grace_period = self.config.dead_node_grace_period;
        let mut garbage_collected_nodes = Vec::new();
        // Single pass: decide and remove at the same time.
        self.dead_nodes.retain(|node_id, declared_dead_at| {
            let is_expired = declared_dead_at.elapsed() >= grace_period;
            if is_expired {
                garbage_collected_nodes.push(node_id.clone());
            }
            !is_expired
        });
        garbage_collected_nodes
    }

    /// Iterates over the nodes currently considered live.
    pub fn live_nodes(&self) -> impl Iterator<Item = &NodeId> {
        self.live_nodes.iter()
    }

    /// Iterates over the nodes currently considered dead (and not yet garbage collected).
    pub fn dead_nodes(&self) -> impl Iterator<Item = &NodeId> {
        self.dead_nodes.keys()
    }

    /// Returns the current phi value of `node_id`, or `None` if the node has no
    /// sampling window.
    fn phi(&self, node_id: &NodeId) -> Option<f64> {
        self.node_samples.get(node_id).map(SamplingWindow::phi)
    }
}
/// Parameters of the failure detector.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct FailureDetectorConfig {
/// A node whose phi value is strictly greater than this threshold is declared dead.
pub phi_threshold: f64,
/// Number of heartbeat intervals kept per node to compute the mean interval.
pub sampling_window_size: usize,
/// Heartbeat intervals longer than this are not recorded in the sampling window.
pub max_interval: Duration,
/// Interval recorded for the first heartbeat of a node, when there is no
/// previous heartbeat to measure against.
pub initial_interval: Duration,
/// How long a node stays in the dead set before `garbage_collect` drops it.
pub dead_node_grace_period: Duration,
}
impl FailureDetectorConfig {
pub fn new(
phi_threshold: f64,
sampling_window_size: usize,
max_interval: Duration,
initial_interval: Duration,
dead_node_grace_period: Duration,
) -> Self {
Self {
phi_threshold,
sampling_window_size,
max_interval,
initial_interval,
dead_node_grace_period,
}
}
}
impl Default for FailureDetectorConfig {
fn default() -> Self {
Self {
phi_threshold: 8.0,
sampling_window_size: 1000,
max_interval: Duration::from_secs(10),
initial_interval: Duration::from_secs(5),
dead_node_grace_period: Duration::from_secs(24 * 60 * 60), }
}
}
/// Heartbeat history of a single node.
#[derive(Debug)]
struct SamplingWindow {
/// Bounded collection of recorded heartbeat intervals, in seconds, with their mean.
intervals: BoundedArrayStats,
/// Instant of the most recent heartbeat, or `None` before the first one.
last_heartbeat: Option<Instant>,
/// Intervals longer than this are not recorded.
max_interval: Duration,
/// Interval recorded for the very first heartbeat.
initial_interval: Duration,
}
impl SamplingWindow {
    /// Creates a window holding at most `window_size` intervals, with no
    /// heartbeat recorded yet.
    pub fn new(window_size: usize, max_interval: Duration, initial_interval: Duration) -> Self {
        let intervals = BoundedArrayStats::new(window_size);
        Self {
            intervals,
            last_heartbeat: None,
            max_interval,
            initial_interval,
        }
    }

    /// Registers a heartbeat received now.
    ///
    /// The first heartbeat records `initial_interval`; later heartbeats record
    /// the time elapsed since the previous one, unless it exceeds `max_interval`,
    /// in which case nothing is recorded. The heartbeat timestamp is refreshed
    /// in every case.
    pub fn report_heartbeat(&mut self) {
        match &self.last_heartbeat {
            None => {
                let first_sample = self.initial_interval.as_secs_f64();
                self.intervals.append(first_sample);
            }
            Some(previous_heartbeat) => {
                let gap = previous_heartbeat.elapsed();
                if gap <= self.max_interval {
                    self.intervals.append(gap.as_secs_f64());
                }
            }
        }
        self.last_heartbeat = Some(Instant::now());
    }

    /// Returns the time elapsed since the last heartbeat divided by the mean
    /// recorded interval.
    ///
    /// # Panics
    ///
    /// Panics if no heartbeat has been recorded yet or if the mean interval is
    /// not strictly positive.
    pub fn phi(&self) -> f64 {
        assert!(self.intervals.mean() > 0.0 && self.last_heartbeat.is_some());
        let since_last_heartbeat = self.last_heartbeat.unwrap().elapsed();
        since_last_heartbeat.as_secs_f64() / self.intervals.mean()
    }
}
/// Fixed-capacity ring buffer of samples that maintains their running mean.
///
/// Once `size` samples have been stored, each new sample overwrites the oldest one.
#[derive(Debug)]
struct BoundedArrayStats {
    /// Storage for the samples; always `size` elements long.
    data: Vec<f64>,
    /// Maximum number of samples retained.
    size: usize,
    /// Set once the buffer has wrapped around, i.e. old samples are being overwritten.
    is_filled: bool,
    /// Position of the next write. It may equal `size` right after the last slot
    /// was written; the wrap-around happens lazily on the following `append`.
    index: usize,
    /// Sum of the samples currently retained.
    sum: f64,
    /// Mean of the samples currently retained, `0.0` while empty.
    mean: f64,
}

impl BoundedArrayStats {
    /// Creates an empty buffer able to retain `size` samples.
    pub fn new(size: usize) -> Self {
        Self {
            data: vec![0.0; size],
            size,
            is_filled: false,
            index: 0,
            sum: 0.0,
            mean: 0.0,
        }
    }

    /// Returns the mean of the retained samples (`0.0` if there are none).
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Adds a sample, evicting the oldest one when the buffer is full.
    ///
    /// A buffer created with a capacity of zero cannot retain anything: the
    /// sample is ignored and the mean stays at `0.0`.
    pub fn append(&mut self, interval: f64) {
        if self.size == 0 {
            // Without this guard the write below would index an empty vector.
            return;
        }
        if self.index == self.size {
            // Lazy wrap-around: the previous append wrote the last slot.
            self.is_filled = true;
            self.index = 0;
        }
        if self.is_filled {
            // The slot about to be overwritten leaves the window.
            self.sum -= self.data[self.index];
        }
        self.sum += interval;
        self.data[self.index] = interval;
        self.index += 1;
        self.mean = self.sum / self.len() as f64;
    }

    /// Number of samples currently retained.
    fn len(&self) -> usize {
        if self.is_filled {
            self.size
        } else {
            self.index
        }
    }
}
// Unit tests for the failure detector. Time is driven through `MockClock`, so
// every `MockClock::advance` call below is what makes `Instant::elapsed` move.
#[cfg(test)]
mod tests {
use std::time::Duration;
use mock_instant::MockClock;
use rand::prelude::*;
use super::{BoundedArrayStats, SamplingWindow};
use crate::failure_detector::{FailureDetector, FailureDetectorConfig};
use crate::NodeId;
// Full life cycle with three nodes: live -> dead -> garbage collected.
#[test]
fn test_failure_detector() {
let mut rng = rand::thread_rng();
let mut failure_detector = FailureDetector::new(FailureDetectorConfig::default());
let intervals_choices = vec![1u64, 2];
let node_ids_choices = vec![
NodeId::for_test_localhost(10_001),
NodeId::for_test_localhost(10_002),
NodeId::for_test_localhost(10_003),
];
// Each iteration advances the clock by 1 or 2 seconds and reports a
// heartbeat for one randomly chosen node.
for _ in 0..=2000 {
let time_offset = intervals_choices.choose(&mut rng).unwrap();
let node_id = node_ids_choices.choose(&mut rng).unwrap();
MockClock::advance(Duration::from_secs(*time_offset));
failure_detector.report_heartbeat(node_id);
}
for node_id in &node_ids_choices {
failure_detector.update_node_liveliness(node_id);
}
// The iteration order of the live set is unspecified, hence the sort.
let mut live_nodes = failure_detector
.live_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>();
live_nodes.sort_unstable();
assert_eq!(live_nodes, vec!["node-10001", "node-10002", "node-10003"]);
// No node is dead yet, so there is nothing to collect.
assert_eq!(failure_detector.garbage_collect(), vec![]);
// After 50 seconds without any heartbeat, every node is expected to be dead.
MockClock::advance(Duration::from_secs(50));
for node_id in &node_ids_choices {
failure_detector.update_node_liveliness(node_id);
}
let mut dead_nodes = failure_detector
.dead_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>();
dead_nodes.sort_unstable();
assert_eq!(dead_nodes, vec!["node-10001", "node-10002", "node-10003"]);
// The nodes were just declared dead: the grace period has not elapsed.
assert_eq!(failure_detector.garbage_collect(), vec![]);
// 25 hours exceeds the default grace period of 24 hours.
MockClock::advance(Duration::from_secs(25 * 60 * 60));
let garbage_collected_nodes = failure_detector.garbage_collect();
assert_eq!(
failure_detector
.live_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>(),
Vec::<&str>::new()
);
assert_eq!(
failure_detector
.dead_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>(),
Vec::<&str>::new()
);
let mut removed_nodes = garbage_collected_nodes
.iter()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>();
removed_nodes.sort_unstable();
assert_eq!(
removed_nodes,
vec!["node-10001", "node-10002", "node-10003"]
);
}
// A node that went dead becomes live again once it resumes heartbeating.
#[test]
fn test_failure_detector_node_state_from_live_to_down_to_live() {
let mut rng = rand::thread_rng();
let mut failure_detector = FailureDetector::new(FailureDetectorConfig::default());
let intervals_choices = vec![1u64, 2];
let node_1 = NodeId::for_test_localhost(10_001);
for _ in 0..=2000 {
let time_offset = intervals_choices.choose(&mut rng).unwrap();
MockClock::advance(Duration::from_secs(*time_offset));
failure_detector.report_heartbeat(&node_1);
}
failure_detector.update_node_liveliness(&node_1);
assert_eq!(
failure_detector
.live_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>(),
vec!["node-10001"]
);
// 20 seconds of silence against heartbeats spaced 1 to 2 seconds apart:
// the node must leave the live set.
MockClock::advance(Duration::from_secs(20));
failure_detector.update_node_liveliness(&node_1);
assert_eq!(
failure_detector
.live_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>(),
Vec::<&str>::new()
);
// Heartbeats resume; the node gets a fresh sampling window.
for _ in 0..=500 {
let time_offset = intervals_choices.choose(&mut rng).unwrap();
MockClock::advance(Duration::from_secs(*time_offset));
failure_detector.report_heartbeat(&node_1);
}
failure_detector.update_node_liveliness(&node_1);
assert_eq!(
failure_detector
.live_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>(),
vec!["node-10001"]
);
}
// With a single heartbeat, only the initial interval (5 seconds by default)
// is in the window, so phi is the elapsed time divided by 5.
#[test]
fn test_failure_detector_node_state_after_initial_interval() {
let mut failure_detector = FailureDetector::new(FailureDetectorConfig::default());
let node_id = NodeId::for_test_localhost(10_001);
failure_detector.report_heartbeat(&node_id);
// phi = 1 / 5, below the default threshold of 8.
MockClock::advance(Duration::from_secs(1));
failure_detector.update_node_liveliness(&node_id);
let live_nodes = failure_detector
.live_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>();
assert_eq!(live_nodes, vec!["node-10001"]);
// phi = 41 / 5, above the default threshold of 8.
MockClock::advance(Duration::from_secs(40));
failure_detector.update_node_liveliness(&node_id);
let live_nodes = failure_detector
.live_nodes()
.map(|node_id| node_id.id.as_str())
.collect::<Vec<_>>();
assert_eq!(live_nodes, Vec::<&str>::new());
}
// Checks the phi computation and that over-long intervals are not recorded.
#[test]
fn test_sampling_window() {
let mut sampling_window =
SamplingWindow::new(10, Duration::from_secs(5), Duration::from_secs(2));
// First heartbeat records the initial interval (2 s), second records 3 s.
sampling_window.report_heartbeat();
MockClock::advance(Duration::from_secs(3));
sampling_window.report_heartbeat();
let mean = (2.0 + 3.0) / 2.0;
assert!((sampling_window.phi() - (0.0 / mean)).abs() < f64::EPSILON);
MockClock::advance(Duration::from_secs(1));
assert!((sampling_window.phi() - (1.0 / mean)).abs() < f64::EPSILON);
// 6 seconds have now elapsed since the last heartbeat, which exceeds the
// 5 second maximum interval: the sample is dropped.
MockClock::advance(Duration::from_secs(5));
sampling_window.report_heartbeat();
MockClock::advance(Duration::from_secs(2));
assert!(
(sampling_window.phi() - (2.0 / mean)).abs() < f64::EPSILON,
"Mean value should not change."
);
}
// Checks the ring buffer bookkeeping before and after it wraps around.
#[test]
fn test_bounded_array_stats() {
let mut bounded_array = BoundedArrayStats::new(10);
// 9 samples (1 through 9): the buffer is not full yet.
for i in 1..10 {
bounded_array.append(i as f64);
}
assert_eq!(bounded_array.index, 9);
assert_eq!(bounded_array.len(), 9);
assert!(!bounded_array.is_filled);
assert!((bounded_array.mean() - 5.0f64).abs() < f64::EPSILON);
// 4 more samples (10 through 13): the buffer wraps and now holds 4 through 13.
for i in 10..14 {
bounded_array.append(i as f64);
}
assert_eq!(bounded_array.index, 3);
assert_eq!(bounded_array.len(), 10);
assert!(bounded_array.is_filled);
assert!((bounded_array.mean() - 8.5f64).abs() < f64::EPSILON);
}
}