1use dashmap::DashMap;
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15
16use super::config::{StreamHealthProbe, StreamProxyConfig};
17
/// Last recorded health state of a single backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BackendHealth {
    /// The backend passed its most recent health probe.
    Healthy,
    /// The backend failed its most recent health probe.
    Unhealthy,
    /// No probe result has been recorded for the backend yet.
    Unknown,
}

impl BackendHealth {
    /// Returns `true` when traffic may be sent to a backend in this state.
    ///
    /// Only [`BackendHealth::Unhealthy`] is excluded; a backend whose state is
    /// still [`BackendHealth::Unknown`] is treated as usable.
    #[must_use]
    pub fn is_usable(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Healthy | Self::Unknown => true,
            Self::Unhealthy => false,
        }
    }
}
36
37#[derive(Clone, Debug)]
39pub struct StreamService {
40 pub name: String,
42 pub backends: Vec<SocketAddr>,
44 health: Arc<RwLock<HashMap<SocketAddr, BackendHealth>>>,
46 rr_index: Arc<AtomicUsize>,
48 pub config: StreamProxyConfig,
52}
53
54impl StreamService {
55 #[must_use]
57 pub fn new(name: String, backends: Vec<SocketAddr>) -> Self {
58 let health: HashMap<SocketAddr, BackendHealth> = backends
59 .iter()
60 .map(|addr| (*addr, BackendHealth::Unknown))
61 .collect();
62 Self {
63 name,
64 backends,
65 health: Arc::new(RwLock::new(health)),
66 rr_index: Arc::new(AtomicUsize::new(0)),
67 config: StreamProxyConfig::default(),
68 }
69 }
70
71 #[must_use]
77 pub fn with_config(mut self, config: StreamProxyConfig) -> Self {
78 self.config = config;
79 self
80 }
81
82 #[must_use]
87 pub fn select_backend(&self) -> Option<SocketAddr> {
88 if self.backends.is_empty() {
89 return None;
90 }
91
92 let len = self.backends.len();
93 let start = self.rr_index.fetch_add(1, Ordering::Relaxed);
94
95 let health_guard = self.health.try_read();
98
99 if let Ok(health) = health_guard {
100 for i in 0..len {
102 let idx = (start + i) % len;
103 let addr = self.backends[idx];
104 let status = health.get(&addr).copied().unwrap_or(BackendHealth::Unknown);
105 if status.is_usable() {
106 return Some(addr);
107 }
108 }
109 }
110
111 Some(self.backends[start % len])
113 }
114
115 pub fn update_backends(&mut self, backends: Vec<SocketAddr>) {
120 let mut health = self
123 .health
124 .try_write()
125 .unwrap_or_else(|_| {
126 tracing::warn!(service = %self.name, "Health map write contention during backend update");
129 unreachable!("update_backends requires exclusive access")
131 });
132
133 for addr in &backends {
135 health.entry(*addr).or_insert(BackendHealth::Unknown);
136 }
137
138 let backend_set: std::collections::HashSet<SocketAddr> = backends.iter().copied().collect();
140 health.retain(|addr, _| backend_set.contains(addr));
141
142 self.backends = backends;
143 }
144
145 pub async fn set_backend_health(&self, addr: SocketAddr, status: BackendHealth) {
147 let mut health = self.health.write().await;
148 if let Some(h) = health.get_mut(&addr) {
149 *h = status;
150 }
151 }
152
153 pub async fn get_backend_health(&self, addr: SocketAddr) -> BackendHealth {
155 let health = self.health.read().await;
156 health.get(&addr).copied().unwrap_or(BackendHealth::Unknown)
157 }
158
159 #[must_use]
161 pub fn backend_count(&self) -> usize {
162 self.backends.len()
163 }
164
165 pub async fn healthy_count(&self) -> usize {
167 let health = self.health.read().await;
168 self.backends
169 .iter()
170 .filter(|addr| {
171 health
172 .get(addr)
173 .copied()
174 .unwrap_or(BackendHealth::Unknown)
175 .is_usable()
176 })
177 .count()
178 }
179}
180
181#[derive(Default)]
185pub struct StreamRegistry {
186 tcp_services: DashMap<u16, StreamService>,
188 udp_services: DashMap<u16, StreamService>,
190}
191
192impl StreamRegistry {
193 #[must_use]
195 pub fn new() -> Self {
196 Self::default()
197 }
198
199 pub fn register_tcp(&self, port: u16, service: StreamService) {
201 tracing::debug!(
202 port = port,
203 service = %service.name,
204 backends = service.backend_count(),
205 "Registered TCP stream service"
206 );
207 self.tcp_services.insert(port, service);
208 }
209
210 pub fn register_udp(&self, port: u16, service: StreamService) {
212 tracing::debug!(
213 port = port,
214 service = %service.name,
215 backends = service.backend_count(),
216 "Registered UDP stream service"
217 );
218 self.udp_services.insert(port, service);
219 }
220
221 #[must_use]
223 pub fn resolve_tcp(&self, port: u16) -> Option<StreamService> {
224 self.tcp_services.get(&port).map(|s| s.clone())
225 }
226
227 #[must_use]
229 pub fn resolve_udp(&self, port: u16) -> Option<StreamService> {
230 self.udp_services.get(&port).map(|s| s.clone())
231 }
232
233 pub fn set_tcp_config(&self, port: u16, config: StreamProxyConfig) {
239 if let Some(mut service) = self.tcp_services.get_mut(&port) {
240 service.config = config;
241 }
242 }
243
244 pub fn set_udp_config(&self, port: u16, config: StreamProxyConfig) {
248 if let Some(mut service) = self.udp_services.get_mut(&port) {
249 service.config = config;
250 }
251 }
252
253 pub fn update_tcp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
255 if let Some(mut service) = self.tcp_services.get_mut(&port) {
256 tracing::debug!(
257 port = port,
258 service = %service.name,
259 old_count = service.backend_count(),
260 new_count = backends.len(),
261 "Updating TCP backends"
262 );
263 service.update_backends(backends);
264 }
265 }
266
267 pub fn update_udp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
269 if let Some(mut service) = self.udp_services.get_mut(&port) {
270 tracing::debug!(
271 port = port,
272 service = %service.name,
273 old_count = service.backend_count(),
274 new_count = backends.len(),
275 "Updating UDP backends"
276 );
277 service.update_backends(backends);
278 }
279 }
280
281 #[must_use]
283 pub fn unregister_tcp(&self, port: u16) -> Option<StreamService> {
284 self.tcp_services.remove(&port).map(|(_, s)| s)
285 }
286
287 #[must_use]
289 pub fn unregister_udp(&self, port: u16) -> Option<StreamService> {
290 self.udp_services.remove(&port).map(|(_, s)| s)
291 }
292
293 #[must_use]
295 pub fn tcp_count(&self) -> usize {
296 self.tcp_services.len()
297 }
298
299 #[must_use]
301 pub fn udp_count(&self) -> usize {
302 self.udp_services.len()
303 }
304
305 #[must_use]
307 pub fn tcp_ports(&self) -> Vec<u16> {
308 self.tcp_services.iter().map(|e| *e.key()).collect()
309 }
310
311 #[must_use]
313 pub fn udp_ports(&self) -> Vec<u16> {
314 self.udp_services.iter().map(|e| *e.key()).collect()
315 }
316
317 #[must_use]
319 pub fn list_tcp_services(&self) -> Vec<(u16, StreamService)> {
320 self.tcp_services
321 .iter()
322 .map(|e| (*e.key(), e.value().clone()))
323 .collect()
324 }
325
326 #[must_use]
328 pub fn list_udp_services(&self) -> Vec<(u16, StreamService)> {
329 self.udp_services
330 .iter()
331 .map(|e| (*e.key(), e.value().clone()))
332 .collect()
333 }
334
335 #[must_use]
351 pub fn spawn_health_checker(
352 self: &Arc<Self>,
353 interval: Duration,
354 timeout: Duration,
355 ) -> tokio::task::JoinHandle<()> {
356 let registry = Arc::clone(self);
357
358 tokio::spawn(async move {
359 let mut ticker = tokio::time::interval(interval);
360 ticker.tick().await;
362
363 loop {
364 ticker.tick().await;
365
366 for entry in ®istry.tcp_services {
368 let service = entry.value().clone();
369 let backends = service.backends.clone();
370
371 for addr in backends {
372 let svc = service.clone();
373 let probe_timeout = timeout;
374
375 tokio::spawn(async move {
377 let result = tokio::time::timeout(
378 probe_timeout,
379 tokio::net::TcpStream::connect(addr),
380 )
381 .await;
382
383 let health = match result {
384 Ok(Ok(_stream)) => BackendHealth::Healthy,
385 Ok(Err(e)) => {
386 tracing::debug!(
387 service = %svc.name,
388 backend = %addr,
389 error = %e,
390 "TCP health check failed (connect error)"
391 );
392 BackendHealth::Unhealthy
393 }
394 Err(_) => {
395 tracing::debug!(
396 service = %svc.name,
397 backend = %addr,
398 "TCP health check failed (timeout)"
399 );
400 BackendHealth::Unhealthy
401 }
402 };
403
404 svc.set_backend_health(addr, health).await;
405 });
406 }
407 }
408
409 for entry in ®istry.udp_services {
412 let service = entry.value().clone();
413 let Some(StreamHealthProbe::UdpProbe { request, expect }) =
414 service.config.health_check.clone()
415 else {
416 continue;
417 };
418 let backends = service.backends.clone();
419
420 for addr in backends {
421 let svc = service.clone();
422 let probe_timeout = timeout;
423 let request = request.clone();
424 let expect = expect.clone();
425
426 tokio::spawn(async move {
427 let health = match probe_udp_backend(
428 addr,
429 &request,
430 expect.as_deref(),
431 probe_timeout,
432 )
433 .await
434 {
435 Ok(true) => BackendHealth::Healthy,
436 Ok(false) => {
437 tracing::debug!(
438 service = %svc.name,
439 backend = %addr,
440 "UDP health check failed (reply did not match expect)"
441 );
442 BackendHealth::Unhealthy
443 }
444 Err(e) => {
445 tracing::debug!(
446 service = %svc.name,
447 backend = %addr,
448 error = %e,
449 "UDP health check failed"
450 );
451 BackendHealth::Unhealthy
452 }
453 };
454
455 svc.set_backend_health(addr, health).await;
456 });
457 }
458 }
459 }
460 })
461 }
462}
463
464pub async fn probe_udp_backend(
477 addr: SocketAddr,
478 request: &[u8],
479 expect: Option<&[u8]>,
480 timeout: Duration,
481) -> std::result::Result<bool, std::io::Error> {
482 let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
483 socket.connect(addr).await?;
484 socket.send(request).await?;
485
486 let mut buf = vec![0u8; 65535];
487 let len = tokio::time::timeout(timeout, socket.recv(&mut buf))
488 .await
489 .map_err(|_| {
490 std::io::Error::new(std::io::ErrorKind::TimedOut, "UDP health probe timed out")
491 })??;
492
493 let reply = &buf[..len];
494 match expect {
495 Some(pat) => Ok(byte_contains(reply, pat)),
496 None => Ok(true),
497 }
498}
499
/// Reports whether `needle` occurs in `haystack` as a contiguous run of
/// bytes.
///
/// An empty `needle` matches every haystack, including an empty one.
#[must_use]
fn byte_contains(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` would panic, so the empty needle is answered up front.
        0 => true,
        width if width > haystack.len() => false,
        width => haystack.windows(width).any(|window| window == needle),
    }
}
512
#[cfg(test)]
mod health_probe_tests {
    use super::*;
    use std::time::Duration;
    use tokio::net::UdpSocket;

    /// Binds a loopback UDP socket, spawns a task that echoes up to `rounds`
    /// datagrams back to their sender, and returns the socket's address.
    async fn spawn_echo(rounds: usize) -> std::net::SocketAddr {
        let echo = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let echo_addr = echo.local_addr().unwrap();
        tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            for _ in 0..rounds {
                if let Ok((n, peer)) = echo.recv_from(&mut buf).await {
                    let _ = echo.send_to(&buf[..n], peer).await;
                }
            }
        });
        echo_addr
    }

    #[test]
    fn byte_contains_matches() {
        // (haystack, needle) pairs that must match; the last one covers the
        // empty needle.
        let cases: [(&[u8], &[u8]); 4] = [
            (b"hello world", b"world"),
            (b"hello world", b"hello"),
            (b"\xFF\x00\xAB", b"\x00\xAB"),
            (b"anything", b""),
        ];
        for (haystack, needle) in cases.iter() {
            assert!(byte_contains(haystack, needle));
        }
    }

    #[test]
    fn byte_contains_rejects() {
        // Absent needle, needle longer than haystack, empty haystack.
        let cases: [(&[u8], &[u8]); 3] = [
            (b"hello", b"world"),
            (b"abc", b"abcd"),
            (b"", b"x"),
        ];
        for (haystack, needle) in cases.iter() {
            assert!(!byte_contains(haystack, needle));
        }
    }

    #[tokio::test]
    async fn udp_probe_healthy_against_echo() {
        let echo_addr = spawn_echo(1).await;

        let probe = probe_udp_backend(echo_addr, b"ping", None, Duration::from_secs(2)).await;
        let ok = probe.unwrap();
        assert!(ok, "echo reply with no expect must be healthy");
    }

    #[tokio::test]
    async fn udp_probe_expect_substring() {
        // Two probes are sent below, so the echo server answers twice.
        let echo_addr = spawn_echo(2).await;
        let wait = Duration::from_secs(2);

        let ok = probe_udp_backend(echo_addr, b"PONG-token", Some(b"token"), wait)
            .await
            .unwrap();
        assert!(ok, "reply containing expect substring must be healthy");

        let not_matched = probe_udp_backend(echo_addr, b"abc", Some(b"zzz"), wait)
            .await
            .unwrap();
        assert!(
            !not_matched,
            "reply missing expect substring must be unhealthy"
        );
    }

    #[tokio::test]
    async fn udp_probe_dead_port_times_out() {
        // Grab a free port, then release it so nothing is listening there.
        let dead_addr = {
            let dead = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            dead.local_addr().unwrap()
        };

        let outcome =
            probe_udp_backend(dead_addr, b"ping", None, Duration::from_millis(300)).await;
        assert!(
            outcome.is_err(),
            "probe to dead UDP port must error (timeout)"
        );
    }
}