use dashmap::DashMap;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Health status of a single upstream backend as observed by probes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendHealth {
    /// Last probe succeeded.
    Healthy,
    /// Last probe failed (connect error or timeout).
    Unhealthy,
    /// Not yet probed.
    Unknown,
}

impl BackendHealth {
    /// Whether a backend in this state may receive traffic.
    ///
    /// `Unknown` counts as usable so that freshly added backends can serve
    /// requests before their first health probe completes.
    #[must_use]
    pub fn is_usable(self) -> bool {
        match self {
            BackendHealth::Healthy | BackendHealth::Unknown => true,
            BackendHealth::Unhealthy => false,
        }
    }
}
/// A named load-balanced upstream: a list of backend addresses plus the
/// shared state used to pick among them.
///
/// `Clone` is shallow for the mutable state: clones share the same health
/// map and round-robin counter through `Arc`, so a cloned snapshot still
/// observes live health updates made through any other clone.
#[derive(Clone, Debug)]
pub struct StreamService {
/// Human-readable service name; used in log messages.
pub name: String,
/// Configured backend addresses, in round-robin order.
pub backends: Vec<SocketAddr>,
// Per-backend probe status, keyed by address. Shared across clones.
health: Arc<RwLock<HashMap<SocketAddr, BackendHealth>>>,
// Monotonically increasing round-robin cursor; wrapped via modulo at
// selection time, so overflow is harmless.
rr_index: Arc<AtomicUsize>,
}
impl StreamService {
    /// Creates a service whose backends all start in the `Unknown` health
    /// state (treated as usable until proven otherwise).
    #[must_use]
    pub fn new(name: String, backends: Vec<SocketAddr>) -> Self {
        let health: HashMap<SocketAddr, BackendHealth> = backends
            .iter()
            .map(|addr| (*addr, BackendHealth::Unknown))
            .collect();
        Self {
            name,
            backends,
            health: Arc::new(RwLock::new(health)),
            rr_index: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Picks the next backend by round-robin, skipping backends currently
    /// marked `Unhealthy`.
    ///
    /// Falls back to plain round-robin (ignoring health) when the health map
    /// is write-locked at this instant, or when every backend is unhealthy —
    /// a fully degraded service still forwards traffic rather than rejecting
    /// it. Returns `None` only when the backend list is empty.
    #[must_use]
    pub fn select_backend(&self) -> Option<SocketAddr> {
        if self.backends.is_empty() {
            return None;
        }
        let len = self.backends.len();
        // Relaxed is sufficient: the counter only spreads load; no ordering
        // with other memory operations is required.
        let start = self.rr_index.fetch_add(1, Ordering::Relaxed);
        if let Ok(health) = self.health.try_read() {
            for i in 0..len {
                let idx = (start + i) % len;
                let addr = self.backends[idx];
                // Addresses missing from the map default to Unknown (usable).
                let status = health.get(&addr).copied().unwrap_or(BackendHealth::Unknown);
                if status.is_usable() {
                    return Some(addr);
                }
            }
        }
        // Last resort: lock contended or all backends unhealthy.
        Some(self.backends[start % len])
    }

    /// Replaces the backend list, preserving recorded health for backends
    /// that survive the update and dropping entries for removed ones.
    ///
    /// Health-map maintenance is best-effort: writers never hold the lock
    /// across an `.await`, so contention is transient and a brief yield/retry
    /// loop almost always succeeds. If the lock stays contended, the map
    /// update is skipped instead of panicking — readers treat missing entries
    /// as `Unknown` and never look up removed addresses, so skipping is safe.
    pub fn update_backends(&mut self, backends: Vec<SocketAddr>) {
        const MAX_ATTEMPTS: usize = 64;
        for attempt in 1..=MAX_ATTEMPTS {
            match self.health.try_write() {
                Ok(mut health) => {
                    // Seed new backends as Unknown; keep existing statuses.
                    for addr in &backends {
                        health.entry(*addr).or_insert(BackendHealth::Unknown);
                    }
                    // Prune entries for backends that were removed.
                    let keep: std::collections::HashSet<SocketAddr> =
                        backends.iter().copied().collect();
                    health.retain(|addr, _| keep.contains(addr));
                    break;
                }
                Err(_) if attempt < MAX_ATTEMPTS => std::thread::yield_now(),
                Err(_) => {
                    tracing::warn!(
                        service = %self.name,
                        "Health map write contention during backend update; skipping health-map maintenance"
                    );
                }
            }
        }
        self.backends = backends;
    }

    /// Records a probe result for `addr`. Silently ignored when `addr` is no
    /// longer tracked (e.g. it was removed while a probe was in flight).
    pub async fn set_backend_health(&self, addr: SocketAddr, status: BackendHealth) {
        let mut health = self.health.write().await;
        if let Some(h) = health.get_mut(&addr) {
            *h = status;
        }
    }

    /// Returns the recorded health for `addr`, or `Unknown` if untracked.
    pub async fn get_backend_health(&self, addr: SocketAddr) -> BackendHealth {
        let health = self.health.read().await;
        health.get(&addr).copied().unwrap_or(BackendHealth::Unknown)
    }

    /// Total number of configured backends, regardless of health.
    #[must_use]
    pub fn backend_count(&self) -> usize {
        self.backends.len()
    }

    /// Number of backends currently usable for traffic — `Healthy` or
    /// `Unknown` (see [`BackendHealth::is_usable`]), not strictly healthy.
    pub async fn healthy_count(&self) -> usize {
        let health = self.health.read().await;
        self.backends
            .iter()
            .filter(|addr| {
                health
                    .get(addr)
                    .copied()
                    .unwrap_or(BackendHealth::Unknown)
                    .is_usable()
            })
            .count()
    }
}
/// Port-indexed lookup tables for TCP and UDP stream services.
///
/// Backed by `DashMap`, so registration, resolution, and updates need only
/// `&self` and may be called concurrently from many tasks.
#[derive(Default)]
pub struct StreamRegistry {
// TCP services keyed by listening port.
tcp_services: DashMap<u16, StreamService>,
// UDP services keyed by listening port.
udp_services: DashMap<u16, StreamService>,
}
impl StreamRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers (or replaces) the TCP service listening on `port`.
    pub fn register_tcp(&self, port: u16, service: StreamService) {
        tracing::debug!(
            port = port,
            service = %service.name,
            backends = service.backend_count(),
            "Registered TCP stream service"
        );
        self.tcp_services.insert(port, service);
    }

    /// Registers (or replaces) the UDP service listening on `port`.
    pub fn register_udp(&self, port: u16, service: StreamService) {
        tracing::debug!(
            port = port,
            service = %service.name,
            backends = service.backend_count(),
            "Registered UDP stream service"
        );
        self.udp_services.insert(port, service);
    }

    /// Returns a snapshot clone of the TCP service on `port`, if any.
    /// Clones share health state and the round-robin counter via `Arc`.
    #[must_use]
    pub fn resolve_tcp(&self, port: u16) -> Option<StreamService> {
        self.tcp_services.get(&port).map(|s| s.clone())
    }

    /// Returns a snapshot clone of the UDP service on `port`, if any.
    #[must_use]
    pub fn resolve_udp(&self, port: u16) -> Option<StreamService> {
        self.udp_services.get(&port).map(|s| s.clone())
    }

    /// Replaces the backend list of the TCP service on `port`; no-op when
    /// the port is not registered.
    pub fn update_tcp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
        if let Some(mut service) = self.tcp_services.get_mut(&port) {
            tracing::debug!(
                port = port,
                service = %service.name,
                old_count = service.backend_count(),
                new_count = backends.len(),
                "Updating TCP backends"
            );
            service.update_backends(backends);
        }
    }

    /// Replaces the backend list of the UDP service on `port`; no-op when
    /// the port is not registered.
    pub fn update_udp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
        if let Some(mut service) = self.udp_services.get_mut(&port) {
            tracing::debug!(
                port = port,
                service = %service.name,
                old_count = service.backend_count(),
                new_count = backends.len(),
                "Updating UDP backends"
            );
            service.update_backends(backends);
        }
    }

    /// Removes and returns the TCP service on `port`, if any.
    #[must_use]
    pub fn unregister_tcp(&self, port: u16) -> Option<StreamService> {
        self.tcp_services.remove(&port).map(|(_, s)| s)
    }

    /// Removes and returns the UDP service on `port`, if any.
    #[must_use]
    pub fn unregister_udp(&self, port: u16) -> Option<StreamService> {
        self.udp_services.remove(&port).map(|(_, s)| s)
    }

    /// Number of registered TCP services.
    #[must_use]
    pub fn tcp_count(&self) -> usize {
        self.tcp_services.len()
    }

    /// Number of registered UDP services.
    #[must_use]
    pub fn udp_count(&self) -> usize {
        self.udp_services.len()
    }

    /// All registered TCP ports (unordered).
    #[must_use]
    pub fn tcp_ports(&self) -> Vec<u16> {
        self.tcp_services.iter().map(|e| *e.key()).collect()
    }

    /// All registered UDP ports (unordered).
    #[must_use]
    pub fn udp_ports(&self) -> Vec<u16> {
        self.udp_services.iter().map(|e| *e.key()).collect()
    }

    /// Snapshot of all TCP services as `(port, service)` pairs (unordered).
    #[must_use]
    pub fn list_tcp_services(&self) -> Vec<(u16, StreamService)> {
        self.tcp_services
            .iter()
            .map(|e| (*e.key(), e.value().clone()))
            .collect()
    }

    /// Snapshot of all UDP services as `(port, service)` pairs (unordered).
    #[must_use]
    pub fn list_udp_services(&self) -> Vec<(u16, StreamService)> {
        self.udp_services
            .iter()
            .map(|e| (*e.key(), e.value().clone()))
            .collect()
    }

    /// Spawns a background task that, every `interval`, probes each backend
    /// of every registered TCP service with a `TcpStream::connect` bounded
    /// by `timeout`, recording `Healthy`/`Unhealthy` per backend.
    ///
    /// UDP services are not probed: a connect-style probe proves nothing
    /// about a datagram peer. Individual probe tasks are detached; the
    /// returned handle controls only the outer scheduling loop (abort it to
    /// stop health checking).
    #[must_use]
    pub fn spawn_health_checker(
        self: &Arc<Self>,
        interval: Duration,
        timeout: Duration,
    ) -> tokio::task::JoinHandle<()> {
        let registry = Arc::clone(self);
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            // Consume the immediate first tick so the first probe round runs
            // one full interval after spawn, not instantly.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                // Fixed from mangled `®istry.tcp_services` (garbled
                // `&registry`), which did not compile.
                for entry in registry.tcp_services.iter() {
                    let service = entry.value().clone();
                    let backends = service.backends.clone();
                    for addr in backends {
                        let svc = service.clone();
                        let probe_timeout = timeout;
                        tokio::spawn(async move {
                            let result = tokio::time::timeout(
                                probe_timeout,
                                tokio::net::TcpStream::connect(addr),
                            )
                            .await;
                            let health = match result {
                                Ok(Ok(_stream)) => BackendHealth::Healthy,
                                Ok(Err(e)) => {
                                    tracing::debug!(
                                        service = %svc.name,
                                        backend = %addr,
                                        error = %e,
                                        "TCP health check failed (connect error)"
                                    );
                                    BackendHealth::Unhealthy
                                }
                                Err(_) => {
                                    tracing::debug!(
                                        service = %svc.name,
                                        backend = %addr,
                                        "TCP health check failed (timeout)"
                                    );
                                    BackendHealth::Unhealthy
                                }
                            };
                            svc.set_backend_health(addr, health).await;
                        });
                    }
                }
            }
        })
    }
}