use dashmap::DashMap;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use super::config::{StreamHealthProbe, StreamProxyConfig};
/// Last recorded health state of a single backend address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendHealth {
    /// The backend is considered reachable.
    Healthy,
    /// The backend is considered unreachable and is skipped during selection.
    Unhealthy,
    /// No health result has been recorded for the backend yet.
    Unknown,
}

impl BackendHealth {
    /// Returns `true` when a backend in this state may receive traffic.
    ///
    /// Only [`BackendHealth::Unhealthy`] rules a backend out; a backend whose
    /// state is still [`BackendHealth::Unknown`] is treated as usable.
    #[must_use]
    pub fn is_usable(self) -> bool {
        match self {
            Self::Healthy | Self::Unknown => true,
            Self::Unhealthy => false,
        }
    }
}
/// A named stream service: its backend addresses, per-backend health state,
/// round-robin cursor and proxy configuration.
///
/// `health` and `rr_index` live behind `Arc`, so a value produced by `clone()`
/// shares them with the value it was cloned from, whereas `name`, `backends`
/// and `config` are copied into each clone.
#[derive(Clone, Debug)]
pub struct StreamService {
/// Service name; emitted as the `service` field of this module's log events.
pub name: String,
/// Backend addresses, in the order used for round-robin selection.
pub backends: Vec<SocketAddr>,
/// Health state keyed by backend address (shared between clones).
health: Arc<RwLock<HashMap<SocketAddr, BackendHealth>>>,
/// Monotonically incremented counter that drives round-robin selection
/// (shared between clones).
rr_index: Arc<AtomicUsize>,
/// Proxy configuration attached to this service.
pub config: StreamProxyConfig,
}
impl StreamService {
    /// Creates a service named `name` with the given `backends`.
    ///
    /// Every backend starts as [`BackendHealth::Unknown`], the round-robin
    /// counter starts at zero and the configuration is
    /// `StreamProxyConfig::default()`.
    #[must_use]
    pub fn new(name: String, backends: Vec<SocketAddr>) -> Self {
        let health = Self::fresh_health_map(&backends);
        Self {
            name,
            backends,
            health: Arc::new(RwLock::new(health)),
            rr_index: Arc::new(AtomicUsize::new(0)),
            config: StreamProxyConfig::default(),
        }
    }

    /// Builder-style setter that replaces the service configuration.
    #[must_use]
    pub fn with_config(mut self, config: StreamProxyConfig) -> Self {
        self.config = config;
        self
    }

    /// Picks the next backend in round-robin order, preferring usable ones.
    ///
    /// Returns `None` only when the service has no backends. Otherwise the
    /// scan starts at the round-robin position and returns the first backend
    /// whose health is usable. If the health map cannot be read without
    /// blocking, or no backend is usable, the backend at the round-robin
    /// position is returned regardless of its health (best effort).
    #[must_use]
    pub fn select_backend(&self) -> Option<SocketAddr> {
        let len = self.backends.len();
        if len == 0 {
            return None;
        }
        // Reduce the counter modulo `len` *before* adding the scan offset.
        // The counter wraps at `usize::MAX`, and adding to the raw value could
        // overflow (a panic in debug builds) once it gets close to the limit.
        let offset = self.rr_index.fetch_add(1, Ordering::Relaxed) % len;

        // Never block the connection path on the health lock.
        if let Ok(health) = self.health.try_read() {
            let usable = self
                .backends
                .iter()
                .cycle()
                .skip(offset)
                .take(len)
                .find(|addr| Self::status_of(&health, addr).is_usable());
            if let Some(addr) = usable {
                return Some(*addr);
            }
        }
        Some(self.backends[offset])
    }

    /// Replaces the backend list and reconciles the health map with it.
    ///
    /// Backends that remain keep their recorded health, new backends start as
    /// [`BackendHealth::Unknown`], and entries for removed backends are
    /// dropped.
    ///
    /// The health map is shared with every clone of this service, so holding
    /// `&mut self` does not guarantee exclusive access to its lock. The write
    /// lock is therefore acquired with a bounded number of non-blocking
    /// attempts. If all of them fail, this instance switches to a brand-new
    /// health map in which every backend is `Unknown`; clones created earlier
    /// keep the old map. This method never panics on lock contention.
    pub fn update_backends(&mut self, backends: Vec<SocketAddr>) {
        // Bounded so that a synchronous caller can never spin forever. Every
        // accessor in this impl drops its guard without awaiting while it is
        // held, so contention is expected to clear within a few attempts.
        const MAX_LOCK_ATTEMPTS: usize = 64;

        let keep: std::collections::HashSet<SocketAddr> = backends.iter().copied().collect();
        let mut reconciled = false;
        for _ in 0..MAX_LOCK_ATTEMPTS {
            if let Ok(mut health) = self.health.try_write() {
                for addr in &backends {
                    health.entry(*addr).or_insert(BackendHealth::Unknown);
                }
                health.retain(|addr, _| keep.contains(addr));
                reconciled = true;
                break;
            }
            std::thread::yield_now();
        }

        if !reconciled {
            tracing::warn!(service = %self.name, "Health map write contention during backend update");
            // Fall back to fresh state rather than panicking or blocking.
            self.health = Arc::new(RwLock::new(Self::fresh_health_map(&backends)));
        }
        self.backends = backends;
    }

    /// Records `status` for `addr`.
    ///
    /// Addresses that are not present in the health map are ignored, so a
    /// late result for a backend that has been removed does not re-add it.
    pub async fn set_backend_health(&self, addr: SocketAddr, status: BackendHealth) {
        let mut health = self.health.write().await;
        if let Some(slot) = health.get_mut(&addr) {
            *slot = status;
        }
    }

    /// Returns the recorded health of `addr`, or [`BackendHealth::Unknown`]
    /// when nothing is recorded for it.
    pub async fn get_backend_health(&self, addr: SocketAddr) -> BackendHealth {
        let health = self.health.read().await;
        Self::status_of(&health, &addr)
    }

    /// Number of configured backends, regardless of health.
    #[must_use]
    pub fn backend_count(&self) -> usize {
        self.backends.len()
    }

    /// Number of configured backends whose health is usable
    /// (see [`BackendHealth::is_usable`]).
    pub async fn healthy_count(&self) -> usize {
        let health = self.health.read().await;
        self.backends
            .iter()
            .filter(|addr| Self::status_of(&health, addr).is_usable())
            .count()
    }

    /// Builds a health map in which every backend is `Unknown`.
    fn fresh_health_map(backends: &[SocketAddr]) -> HashMap<SocketAddr, BackendHealth> {
        backends
            .iter()
            .map(|addr| (*addr, BackendHealth::Unknown))
            .collect()
    }

    /// Looks up `addr` in `health`, defaulting to `Unknown` when absent.
    fn status_of(health: &HashMap<SocketAddr, BackendHealth>, addr: &SocketAddr) -> BackendHealth {
        health.get(addr).copied().unwrap_or(BackendHealth::Unknown)
    }
}
/// Registry of stream services keyed by listening port, with separate tables
/// for TCP and UDP.
///
/// Both tables are `DashMap`s, which is why the registry's methods take
/// `&self` even when they modify it.
#[derive(Default)]
pub struct StreamRegistry {
/// TCP services keyed by port.
tcp_services: DashMap<u16, StreamService>,
/// UDP services keyed by port.
udp_services: DashMap<u16, StreamService>,
}
impl StreamRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `service` for TCP `port`, replacing any service already
    /// registered on that port.
    pub fn register_tcp(&self, port: u16, service: StreamService) {
        tracing::debug!(
            port = port,
            service = %service.name,
            backends = service.backend_count(),
            "Registered TCP stream service"
        );
        self.tcp_services.insert(port, service);
    }

    /// Registers `service` for UDP `port`, replacing any service already
    /// registered on that port.
    pub fn register_udp(&self, port: u16, service: StreamService) {
        tracing::debug!(
            port = port,
            service = %service.name,
            backends = service.backend_count(),
            "Registered UDP stream service"
        );
        self.udp_services.insert(port, service);
    }

    /// Returns a clone of the TCP service registered on `port`, if any.
    #[must_use]
    pub fn resolve_tcp(&self, port: u16) -> Option<StreamService> {
        self.tcp_services
            .get(&port)
            .map(|entry| entry.value().clone())
    }

    /// Returns a clone of the UDP service registered on `port`, if any.
    #[must_use]
    pub fn resolve_udp(&self, port: u16) -> Option<StreamService> {
        self.udp_services
            .get(&port)
            .map(|entry| entry.value().clone())
    }

    /// Replaces the configuration of the TCP service on `port`.
    /// Does nothing when no service is registered there.
    pub fn set_tcp_config(&self, port: u16, config: StreamProxyConfig) {
        if let Some(mut service) = self.tcp_services.get_mut(&port) {
            service.config = config;
        }
    }

    /// Replaces the configuration of the UDP service on `port`.
    /// Does nothing when no service is registered there.
    pub fn set_udp_config(&self, port: u16, config: StreamProxyConfig) {
        if let Some(mut service) = self.udp_services.get_mut(&port) {
            service.config = config;
        }
    }

    /// Replaces the backend list of the TCP service on `port`.
    /// Does nothing when no service is registered there.
    pub fn update_tcp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
        if let Some(mut service) = self.tcp_services.get_mut(&port) {
            tracing::debug!(
                port = port,
                service = %service.name,
                old_count = service.backend_count(),
                new_count = backends.len(),
                "Updating TCP backends"
            );
            service.update_backends(backends);
        }
    }

    /// Replaces the backend list of the UDP service on `port`.
    /// Does nothing when no service is registered there.
    pub fn update_udp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
        if let Some(mut service) = self.udp_services.get_mut(&port) {
            tracing::debug!(
                port = port,
                service = %service.name,
                old_count = service.backend_count(),
                new_count = backends.len(),
                "Updating UDP backends"
            );
            service.update_backends(backends);
        }
    }

    /// Removes and returns the TCP service registered on `port`, if any.
    #[must_use]
    pub fn unregister_tcp(&self, port: u16) -> Option<StreamService> {
        self.tcp_services.remove(&port).map(|(_, service)| service)
    }

    /// Removes and returns the UDP service registered on `port`, if any.
    #[must_use]
    pub fn unregister_udp(&self, port: u16) -> Option<StreamService> {
        self.udp_services.remove(&port).map(|(_, service)| service)
    }

    /// Number of registered TCP services.
    #[must_use]
    pub fn tcp_count(&self) -> usize {
        self.tcp_services.len()
    }

    /// Number of registered UDP services.
    #[must_use]
    pub fn udp_count(&self) -> usize {
        self.udp_services.len()
    }

    /// Ports that currently have a TCP service, in unspecified order.
    #[must_use]
    pub fn tcp_ports(&self) -> Vec<u16> {
        self.tcp_services.iter().map(|entry| *entry.key()).collect()
    }

    /// Ports that currently have a UDP service, in unspecified order.
    #[must_use]
    pub fn udp_ports(&self) -> Vec<u16> {
        self.udp_services.iter().map(|entry| *entry.key()).collect()
    }

    /// Snapshot of every TCP service as `(port, service clone)` pairs.
    #[must_use]
    pub fn list_tcp_services(&self) -> Vec<(u16, StreamService)> {
        self.tcp_services
            .iter()
            .map(|entry| (*entry.key(), entry.value().clone()))
            .collect()
    }

    /// Snapshot of every UDP service as `(port, service clone)` pairs.
    #[must_use]
    pub fn list_udp_services(&self) -> Vec<(u16, StreamService)> {
        self.udp_services
            .iter()
            .map(|entry| (*entry.key(), entry.value().clone()))
            .collect()
    }

    /// Spawns a background task that probes every registered backend once per
    /// `interval` and records the outcome on the owning service.
    ///
    /// * TCP backends are probed with a plain connect bounded by `timeout`.
    /// * UDP backends are probed only when the service configuration carries a
    ///   `StreamHealthProbe::UdpProbe`; other UDP services are skipped.
    ///
    /// The first probe round starts one `interval` after the task is spawned.
    /// The task loops forever and holds a strong reference to the registry;
    /// abort the returned handle to stop it. Must be called from within a
    /// Tokio runtime, and `interval` must be non-zero because
    /// `tokio::time::interval` panics on a zero period.
    #[must_use]
    pub fn spawn_health_checker(
        self: &Arc<Self>,
        interval: Duration,
        timeout: Duration,
    ) -> tokio::task::JoinHandle<()> {
        let registry = Arc::clone(self);
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            // The first tick of a Tokio interval completes immediately;
            // consume it so that probing starts after one full period.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                registry.spawn_tcp_probe_round(timeout);
                registry.spawn_udp_probe_round(timeout);
            }
        })
    }

    /// Spawns one detached connect-probe task per TCP backend.
    ///
    /// Deliberately synchronous so that no map iterator is held across an
    /// `.await`. Must run inside a Tokio runtime.
    fn spawn_tcp_probe_round(&self, timeout: Duration) {
        for entry in self.tcp_services.iter() {
            let service = entry.value();
            for &addr in &service.backends {
                let svc = service.clone();
                tokio::spawn(async move {
                    let outcome =
                        tokio::time::timeout(timeout, tokio::net::TcpStream::connect(addr)).await;
                    let health = match outcome {
                        // The connection is dropped right away; only
                        // reachability matters.
                        Ok(Ok(_stream)) => BackendHealth::Healthy,
                        Ok(Err(e)) => {
                            tracing::debug!(
                                service = %svc.name,
                                backend = %addr,
                                error = %e,
                                "TCP health check failed (connect error)"
                            );
                            BackendHealth::Unhealthy
                        }
                        Err(_) => {
                            tracing::debug!(
                                service = %svc.name,
                                backend = %addr,
                                "TCP health check failed (timeout)"
                            );
                            BackendHealth::Unhealthy
                        }
                    };
                    svc.set_backend_health(addr, health).await;
                });
            }
        }
    }

    /// Spawns one detached request/reply probe task per backend of every UDP
    /// service that has a `UdpProbe` configured.
    ///
    /// Deliberately synchronous so that no map iterator is held across an
    /// `.await`. Must run inside a Tokio runtime.
    fn spawn_udp_probe_round(&self, timeout: Duration) {
        for entry in self.udp_services.iter() {
            let service = entry.value();
            let Some(StreamHealthProbe::UdpProbe { request, expect }) =
                service.config.health_check.clone()
            else {
                continue;
            };
            for &addr in &service.backends {
                let svc = service.clone();
                let request = request.clone();
                let expect = expect.clone();
                tokio::spawn(async move {
                    let health =
                        match probe_udp_backend(addr, &request, expect.as_deref(), timeout).await {
                            Ok(true) => BackendHealth::Healthy,
                            Ok(false) => {
                                tracing::debug!(
                                    service = %svc.name,
                                    backend = %addr,
                                    "UDP health check failed (reply did not match expect)"
                                );
                                BackendHealth::Unhealthy
                            }
                            Err(e) => {
                                tracing::debug!(
                                    service = %svc.name,
                                    backend = %addr,
                                    error = %e,
                                    "UDP health check failed"
                                );
                                BackendHealth::Unhealthy
                            }
                        };
                    svc.set_backend_health(addr, health).await;
                });
            }
        }
    }
}
/// Sends `request` to `addr` over UDP and waits up to `timeout` for one reply.
///
/// Returns `Ok(true)` when a reply arrives and either `expect` is `None` or
/// the reply contains `expect` as a contiguous byte sequence, and `Ok(false)`
/// when a reply arrives but does not contain `expect`.
///
/// # Errors
///
/// Returns the underlying I/O error if binding, connecting, sending or
/// receiving fails, and an error of kind [`std::io::ErrorKind::TimedOut`] if
/// no reply arrives within `timeout`.
pub async fn probe_udp_backend(
    addr: SocketAddr,
    request: &[u8],
    expect: Option<&[u8]>,
    timeout: Duration,
) -> std::result::Result<bool, std::io::Error> {
    // Bind to the wildcard address of the backend's own address family: a
    // socket bound to the IPv4 wildcard cannot connect to an IPv6 peer.
    let local_ip = if addr.is_ipv6() {
        std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED)
    } else {
        std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)
    };
    let socket = tokio::net::UdpSocket::bind(SocketAddr::new(local_ip, 0)).await?;
    // Connecting restricts the socket to datagrams from `addr`.
    socket.connect(addr).await?;
    socket.send(request).await?;

    // Large enough for any single UDP datagram.
    let mut buf = vec![0u8; 65535];
    let len = tokio::time::timeout(timeout, socket.recv(&mut buf))
        .await
        .map_err(|_| {
            std::io::Error::new(std::io::ErrorKind::TimedOut, "UDP health probe timed out")
        })??;

    let reply = &buf[..len];
    match expect {
        Some(pat) => Ok(byte_contains(reply, pat)),
        None => Ok(true),
    }
}
/// Reports whether `needle` occurs in `haystack` as a contiguous byte run.
///
/// An empty `needle` is contained in every haystack, including an empty one.
#[must_use]
fn byte_contains(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` would panic, so the empty needle is answered up front.
        0 => true,
        n if n > haystack.len() => false,
        n => haystack.windows(n).any(|candidate| candidate == needle),
    }
}
#[cfg(test)]
mod health_probe_tests {
    use super::*;
    use std::time::Duration;
    use tokio::net::UdpSocket;

    /// Binds a loopback UDP socket and spawns a task that echoes back up to
    /// `max_datagrams` datagrams to their sender. Returns the bound address.
    async fn spawn_echo_server(max_datagrams: usize) -> std::net::SocketAddr {
        let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let local = socket.local_addr().unwrap();
        tokio::spawn(async move {
            let mut scratch = vec![0u8; 1500];
            for _ in 0..max_datagrams {
                if let Ok((n, peer)) = socket.recv_from(&mut scratch).await {
                    let _ = socket.send_to(&scratch[..n], peer).await;
                }
            }
        });
        local
    }

    #[test]
    fn byte_contains_matches() {
        assert!(byte_contains(b"hello world", b"world"));
        assert!(byte_contains(b"hello world", b"hello"));
        assert!(byte_contains(b"\xFF\x00\xAB", b"\x00\xAB"));
        // The empty needle matches everything.
        assert!(byte_contains(b"anything", b""));
    }

    #[test]
    fn byte_contains_rejects() {
        assert!(!byte_contains(b"hello", b"world"));
        // Needle longer than the haystack.
        assert!(!byte_contains(b"abc", b"abcd"));
        // Empty haystack, non-empty needle.
        assert!(!byte_contains(b"", b"x"));
    }

    #[tokio::test]
    async fn udp_probe_healthy_against_echo() {
        let echo_addr = spawn_echo_server(1).await;

        let ok = probe_udp_backend(echo_addr, b"ping", None, Duration::from_secs(2))
            .await
            .unwrap();

        assert!(ok, "echo reply with no expect must be healthy");
    }

    #[tokio::test]
    async fn udp_probe_expect_substring() {
        // Two probes are sent below, so the echo server must answer twice.
        let echo_addr = spawn_echo_server(2).await;
        let wait = Duration::from_secs(2);

        let ok = probe_udp_backend(echo_addr, b"PONG-token", Some(b"token"), wait)
            .await
            .unwrap();
        assert!(ok, "reply containing expect substring must be healthy");

        let not_matched = probe_udp_backend(echo_addr, b"abc", Some(b"zzz"), wait)
            .await
            .unwrap();
        assert!(
            !not_matched,
            "reply missing expect substring must be unhealthy"
        );
    }

    #[tokio::test]
    async fn udp_probe_dead_port_times_out() {
        // Bind to obtain a free port, then release it so nothing is listening.
        let dead_addr = {
            let placeholder = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            placeholder.local_addr().unwrap()
        };

        let res = probe_udp_backend(dead_addr, b"ping", None, Duration::from_millis(300)).await;

        assert!(res.is_err(), "probe to dead UDP port must error (timeout)");
    }
}