use std::io::{Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use crate::proxy_config::HealthCheckConfig;
/// Spawns a background thread named `health-<upstream_name>` that periodically
/// probes every backend and publishes the currently-healthy subset into `live`.
///
/// Every backend starts out considered live. Each round first sleeps for
/// `config.interval_secs` (clamped to at least one second), then probes each
/// backend with `check_backend` using `config.path` and `config.timeout_ms`:
/// - a live backend is removed after `config.unhealthy_threshold` consecutive
///   failed probes;
/// - a removed backend is restored after `config.healthy_threshold` consecutive
///   successful probes.
///
/// State transitions are logged to stderr. After every round `live` is
/// overwritten with the live backends, in the same order as `backends`.
/// If the thread cannot be spawned, the error is logged and `live` is left as-is.
pub(crate) fn start_health_checker(
    upstream_name: String,
    backends: Vec<String>,
    live: Arc<RwLock<Vec<String>>>,
    config: HealthCheckConfig,
) {
    let thread_name = format!("health-{}", upstream_name);
    // The worker closure takes ownership of `upstream_name`; keep a copy so a
    // spawn failure can still be attributed to its upstream.
    let name_for_error = upstream_name.clone();
    let spawned = std::thread::Builder::new()
        .name(thread_name)
        .spawn(move || {
            // A zero interval would probe back-to-back without ever pausing.
            let interval = Duration::from_secs(config.interval_secs.max(1));
            let timeout = Duration::from_millis(config.timeout_ms);
            let mut successes: Vec<u32> = vec![0; backends.len()];
            let mut failures: Vec<u32> = vec![0; backends.len()];
            let mut is_live: Vec<bool> = vec![true; backends.len()];
            loop {
                std::thread::sleep(interval);
                for (i, backend) in backends.iter().enumerate() {
                    if check_backend(backend, &config.path, timeout) {
                        // Saturate: a backend that stays healthy keeps counting,
                        // and a plain `+= 1` would eventually overflow.
                        successes[i] = successes[i].saturating_add(1);
                        failures[i] = 0;
                        if !is_live[i] && successes[i] >= config.healthy_threshold {
                            is_live[i] = true;
                            eprintln!(
                                "[health] upstream={} backend={} restored ({}x ok)",
                                upstream_name, backend, successes[i]
                            );
                        }
                    } else {
                        failures[i] = failures[i].saturating_add(1);
                        successes[i] = 0;
                        if is_live[i] && failures[i] >= config.unhealthy_threshold {
                            is_live[i] = false;
                            eprintln!(
                                "[health] upstream={} backend={} removed ({}x fail)",
                                upstream_name, backend, failures[i]
                            );
                        }
                    }
                }
                let live_list: Vec<String> = backends
                    .iter()
                    .zip(&is_live)
                    .filter(|&(_, &up)| up)
                    .map(|(b, _)| b.clone())
                    .collect();
                // The list is replaced wholesale, so a poisoned lock holds no
                // half-updated state worth guarding. Recover and keep publishing
                // rather than freezing the live set forever.
                let mut guard = live
                    .write()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                *guard = live_list;
            }
        });
    if let Err(err) = spawned {
        eprintln!(
            "[health] upstream={} failed to start health checker: {}",
            name_for_error, err
        );
    }
}
/// Probes one backend with a single HTTP/1.1 `GET` and reports whether it
/// answered with a 2xx status line.
///
/// `backend` is split with `parse_backend_url`. `path` is the request target;
/// an empty path is sent as `/`. `timeout` bounds each TCP connect attempt and
/// each socket read/write. TLS backends are handed to `check_via_tls` with the
/// same request bytes.
///
/// Returns `false` on any failure:
/// - the URL cannot be parsed;
/// - the host cannot be resolved;
/// - no resolved address accepts a connection;
/// - an I/O error occurs;
/// - the response does not start with `HTTP/1.1 2` or `HTTP/1.0 2`.
fn check_backend(backend: &str, path: &str, timeout: Duration) -> bool {
    let (host, port, tls) = match parse_backend_url(backend) {
        Some(t) => t,
        None => return false,
    };
    // Resolve through the `(host, port)` tuple rather than a "host:port" string.
    // IPv6 hosts arrive without brackets, and "::1:8080" is ambiguous as text.
    let addrs = match (host.as_str(), port).to_socket_addrs() {
        Ok(a) => a,
        Err(_) => return false,
    };
    // Like `TcpStream::connect`, try each resolved address until one connects,
    // instead of judging the backend on the first address alone.
    let stream = match addrs.find_map(|addr| TcpStream::connect_timeout(&addr, timeout).ok()) {
        Some(s) => s,
        None => return false,
    };
    // Best effort: if a timeout cannot be set, the probe still runs.
    let _ = stream.set_read_timeout(Some(timeout));
    let _ = stream.set_write_timeout(Some(timeout));
    // RFC 9110 §7.2: an IPv6 literal is bracketed in Host, and the port is
    // included whenever it differs from the scheme's default.
    let host_field = if host.contains(':') {
        format!("[{}]", host)
    } else {
        host.clone()
    };
    let default_port = if tls { 443 } else { 80 };
    let authority = if port == default_port {
        host_field
    } else {
        format!("{}:{}", host_field, port)
    };
    // An empty path would produce the invalid request line "GET  HTTP/1.1".
    let target = if path.is_empty() { "/" } else { path };
    let req = format!(
        "GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
        target, authority
    );
    if tls {
        return check_via_tls(stream, &host, req.as_bytes());
    }
    let mut stream = stream;
    if stream.write_all(req.as_bytes()).is_err() {
        return false;
    }
    // A single `read` may return fewer bytes than requested, so require the
    // whole "HTTP/1.x 2" prefix before judging the status.
    let mut status = [0u8; 10];
    if stream.read_exact(&mut status).is_err() {
        return false;
    }
    &status == b"HTTP/1.1 2" || &status == b"HTTP/1.0 2"
}
/// Sends `req` over TLS on an already-connected `stream` and reports whether
/// the response starts with an HTTP/1.x 2xx status line.
///
/// The server certificate is checked against the `webpki_roots` trust anchors.
/// `host` is converted with `ServerName::try_from` and given to the rustls
/// client connection. Returns `false` in any of these cases:
/// - `host` cannot be converted;
/// - the connection cannot be created;
/// - writing or reading fails;
/// - the status prefix is not `HTTP/1.1 2` or `HTTP/1.0 2`.
#[cfg(any(feature = "http-client", feature = "http2"))]
fn check_via_tls(stream: TcpStream, host: &str, req: &[u8]) -> bool {
    use rustls::pki_types::ServerName;
    use rustls::ClientConfig;
    use std::sync::Arc;
    // NOTE(review): the root store and client config are rebuilt on every
    // probe; caching them once would avoid re-collecting the roots each time.
    let root_store =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let config = Arc::new(
        ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );
    let server_name = match ServerName::try_from(host.to_string()) {
        Ok(n) => n,
        Err(_) => return false,
    };
    let conn = match rustls::ClientConnection::new(config, server_name) {
        Ok(c) => c,
        Err(_) => return false,
    };
    let mut tls = rustls::StreamOwned::new(conn, stream);
    if tls.write_all(req).is_err() {
        return false;
    }
    // A single `read` may yield only part of the status line (for example when
    // it spans TLS records), so read exactly the prefix we need to inspect.
    let mut status = [0u8; 10];
    if tls.read_exact(&mut status).is_err() {
        return false;
    }
    &status == b"HTTP/1.1 2" || &status == b"HTTP/1.0 2"
}
/// TLS probe used when the crate is built without either the `http-client` or
/// the `http2` feature.
///
/// No TLS implementation is compiled in, so an `https://` backend can never be
/// verified; it is always reported as unhealthy. The stream is dropped unused.
#[cfg(not(any(feature = "http-client", feature = "http2")))]
fn check_via_tls(_stream: TcpStream, _host: &str, _req: &[u8]) -> bool {
    const TLS_COMPILED_IN: bool = false;
    TLS_COMPILED_IN
}
/// Splits a backend URL into `(host, port, use_tls)`.
///
/// Accepted schemes:
/// - `https://` — TLS, default port 443.
/// - `http://` and `h2://` — plaintext, default port 80.
/// - no scheme (bare `host[:port]`) — plaintext, default port 80.
///
/// Only the authority is inspected. It ends at the first `/`, `?` or `#`
/// (RFC 3986 §3.2), and everything after it is ignored.
///
/// Host forms:
/// - `[v6addr]` or `[v6addr]:port` — brackets are stripped from the host. A
///   missing or unparsable port falls back to the scheme default.
/// - an unbracketed IPv6 literal such as `::1` — taken whole as the host with
///   the default port. A port can only be attached to IPv6 through brackets.
/// - anything else — a trailing `:port` is split off when it parses as `u16`.
///   Otherwise the whole authority is kept as the host.
///
/// Returns `None` when the authority or the resulting host is empty, or when
/// a `[` is never closed by `]`.
pub(crate) fn parse_backend_url(backend: &str) -> Option<(String, u16, bool)> {
    let (rest, tls, default_port) = if let Some(r) = backend.strip_prefix("https://") {
        (r, true, 443u16)
    } else if let Some(r) = backend.strip_prefix("http://") {
        (r, false, 80u16)
    } else if let Some(r) = backend.strip_prefix("h2://") {
        (r, false, 80u16)
    } else {
        (backend, false, 80u16)
    };
    // Stop at the path, query or fragment so "host:9000?x" yields port 9000.
    let authority_end = rest
        .find(|c: char| c == '/' || c == '?' || c == '#')
        .unwrap_or(rest.len());
    let host_port = &rest[..authority_end];
    if host_port.is_empty() {
        return None;
    }
    let (host, port) = if let Some(bracketed) = host_port.strip_prefix('[') {
        let close = bracketed.find(']')?;
        let port = bracketed[close + 1..]
            .strip_prefix(':')
            .and_then(|p| p.parse::<u16>().ok())
            .unwrap_or(default_port);
        (bracketed[..close].to_string(), port)
    } else if host_port.parse::<std::net::Ipv6Addr>().is_ok() {
        // Without brackets the last ':' belongs to the address (e.g. "fe80::1"),
        // not a port separator.
        (host_port.to_string(), default_port)
    } else if let Some(colon) = host_port.rfind(':') {
        match host_port[colon + 1..].parse::<u16>() {
            Ok(p) => (host_port[..colon].to_string(), p),
            Err(_) => (host_port.to_string(), default_port),
        }
    } else {
        (host_port.to_string(), default_port)
    };
    if host.is_empty() {
        return None;
    }
    Some((host, port, tls))
}