use std::{
collections::HashMap,
io,
net::{SocketAddr, ToSocketAddrs},
sync::RwLock,
time::{Duration, Instant},
};
use lazy_static::lazy_static;
use tokio::net::TcpStream;
lazy_static! {
    /// Process-wide health-check state, initialised on first access.
    ///
    /// Configured with a 60-second `fail_timeout` (seconds, see `HealthRecord::new`),
    /// `max_fails` = 3 and `min_rises` = 2. All `HealthCheck` associated functions
    /// read and update this single instance behind an `RwLock`.
    static ref HEALTH_CHECK: RwLock<HealthCheck> = RwLock::new(HealthCheck::new(60, 3, 2));
}
/// Connection-health bookkeeping for a single socket address.
struct HealthRecord {
    /// When `HealthCheck::check_can_request` last let a request through, if ever.
    last_request: Option<Instant>,
    /// When the record was created or last had a success/failure recorded.
    last_record: Instant,
    /// How long recorded state stays relevant before it is treated as stale.
    fail_timeout: Duration,
    /// Failures counted since the last success (or since the state was cleared).
    fall_times: usize,
    /// Successes counted since the last failure (or since the state was cleared).
    rise_times: usize,
    /// Whether the address is currently considered down.
    failed: bool,
}

impl HealthRecord {
    /// Builds a fresh, healthy record whose state expires after `fail_timeout` seconds.
    pub fn new(fail_timeout: usize) -> Self {
        let timeout = Duration::from_secs(fail_timeout as u64);
        Self {
            failed: false,
            rise_times: 0,
            fall_times: 0,
            fail_timeout: timeout,
            last_record: Instant::now(),
            last_request: None,
        }
    }

    /// Resets both counters and marks the address healthy again.
    ///
    /// Timestamps (`last_request`, `last_record`) and `fail_timeout` are left untouched.
    pub fn clear_status(&mut self) {
        self.failed = false;
        self.rise_times = 0;
        self.fall_times = 0;
    }
}
/// Tracks per-address connection outcomes and decides whether an address
/// should currently be treated as down.
pub struct HealthCheck {
    /// Default status lifetime, in seconds, for per-address records.
    fail_timeout: usize,
    /// Failure count at or above which an address is flagged as failed.
    max_fails: usize,
    /// Success count at or above which a failed address is flagged healthy again.
    min_rises: usize,
    /// Per-address state keyed by the resolved socket address.
    health_map: HashMap<SocketAddr, HealthRecord>,
}
impl HealthCheck {
pub fn new(fail_timeout: usize, max_fails: usize, min_rises: usize) -> Self {
Self {
fail_timeout,
max_fails,
min_rises,
health_map: HashMap::new(),
}
}
pub fn instance() -> &'static RwLock<HealthCheck> {
&HEALTH_CHECK
}
pub fn check_can_request(addr: &SocketAddr, duration: Duration) -> bool {
if let Ok(mut h) = HEALTH_CHECK.write() {
if !h.health_map.contains_key(addr) {
let mut health = HealthRecord::new(30);
health.fall_times = 0;
health.last_request = Some(Instant::now());
h.health_map.insert(addr.clone(), health);
return true;
}
let value = h.health_map.get(&addr).unwrap();
let can = if let Some(ins) = value.last_request {
Instant::now().duration_since(ins) > duration
} else {
true
};
if can {
h.health_map.get_mut(&addr).unwrap().last_request = Some(Instant::now());
}
can
} else {
true
}
}
pub fn is_fall_down(addr: &SocketAddr) -> bool {
if let Ok(h) = HEALTH_CHECK.read() {
if !h.health_map.contains_key(addr) {
return false;
}
let value = h.health_map.get(&addr).unwrap();
if Instant::now().duration_since(value.last_record) > value.fail_timeout {
return false;
}
value.failed
} else {
false
}
}
pub fn check_fall_down(addr: &SocketAddr, fail_timeout: &Duration, fall_times: &usize, rise_times: &usize) -> bool {
if let Ok(h) = HEALTH_CHECK.read() {
if !h.health_map.contains_key(addr) {
return false;
}
let value = h.health_map.get(&addr).unwrap();
if Instant::now().duration_since(value.last_record) > *fail_timeout {
return false;
}
if &value.fall_times >= fall_times {
return true;
}
if &value.rise_times >= rise_times {
return false;
}
value.failed
} else {
false
}
}
pub fn add_fall_down(addr: SocketAddr) {
if let Ok(mut h) = HEALTH_CHECK.write() {
if !h.health_map.contains_key(&addr) {
let mut health = HealthRecord::new(h.fail_timeout);
health.fall_times = 1;
h.health_map.insert(addr, health);
} else {
let max_fails = h.max_fails;
let value = h.health_map.get_mut(&addr).unwrap();
if Instant::now().duration_since(value.last_record) > value.fail_timeout {
value.clear_status();
}
value.last_record = Instant::now();
value.fall_times += 1;
value.rise_times = 0;
if value.fall_times >= max_fails {
value.failed = true;
}
}
}
}
pub fn add_rise_up(addr: SocketAddr) {
if let Ok(mut h) = HEALTH_CHECK.write() {
if !h.health_map.contains_key(&addr) {
let mut health = HealthRecord::new(h.fail_timeout);
health.rise_times = 1;
h.health_map.insert(addr, health);
} else {
let min_rises = h.min_rises;
let value = h.health_map.get_mut(&addr).unwrap();
if Instant::now().duration_since(value.last_record) > value.fail_timeout {
value.clear_status();
}
value.last_record = Instant::now();
value.rise_times += 1;
value.fall_times = 0;
if value.rise_times >= min_rises {
value.failed = false;
}
}
}
}
pub async fn connect<A>(addr: &A) -> io::Result<TcpStream>
where
A: ToSocketAddrs,
{
let addrs = addr.to_socket_addrs()?;
let mut last_err = None;
for addr in addrs {
if Self::is_fall_down(&addr) {
last_err = Some(io::Error::new(io::ErrorKind::Other, "health check falldown"));
} else {
log::trace!("尝试与远端{addr}建立连接");
match TcpStream::connect(&addr).await {
Ok(stream) =>
{
if let Ok(local) = stream.local_addr() {
log::trace!("成功与远端{addr}建立连接:{local}->{addr}");
}
Self::add_rise_up(addr);
return Ok(stream)
},
Err(e) => {
log::trace!("与远端{addr}建立连接失败, 原因: {:?}", e);
Self::add_fall_down(addr);
last_err = Some(e)
},
}
}
}
Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve to any address",
)
}))
}
pub async fn connect_timeout<A>(addr: &A, connect: Option<Duration>) -> io::Result<TcpStream>
where
A: ToSocketAddrs,
{
if connect.is_none() {
HealthCheck::connect(addr).await
} else {
match tokio::time::timeout(connect.unwrap(), HealthCheck::connect(addr)).await {
Ok(s) => s,
Err(_) => return Err(io::Error::new(io::ErrorKind::NotConnected, "timeout")),
}
}
}
}