Skip to main content

zlayer_proxy/stream/
registry.rs

1//! Stream service registry for L4 routing
2//!
3//! Maps listen ports to backend services for TCP and UDP proxying.
4//! Includes health-aware backend selection: unhealthy backends are
5//! skipped during round-robin selection, with a fallback to any
6//! backend if all are marked unhealthy.
7
8use dashmap::DashMap;
9use std::collections::HashMap;
10use std::net::SocketAddr;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15
16use super::config::{StreamHealthProbe, StreamProxyConfig};
17
/// Health state of a stream backend, as recorded by the background health checker.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BackendHealth {
    /// The backend passed its most recent health probe.
    Healthy,
    /// The backend failed its most recent health probe; round-robin selection
    /// skips it while at least one other backend is usable.
    Unhealthy,
    /// No probe result yet (or the service has no probe configured); this is
    /// treated exactly like `Healthy`.
    Unknown,
}

impl BackendHealth {
    /// Whether traffic may be routed to a backend in this state.
    ///
    /// Only `Unhealthy` is excluded — an `Unknown` backend is optimistically
    /// considered usable.
    #[must_use]
    pub fn is_usable(self) -> bool {
        match self {
            Self::Healthy | Self::Unknown => true,
            Self::Unhealthy => false,
        }
    }
}
36
37/// A resolved stream service with backend addresses and health state
38#[derive(Clone, Debug)]
39pub struct StreamService {
40    /// Service name (for logging/metrics)
41    pub name: String,
42    /// Backend addresses for load balancing
43    pub backends: Vec<SocketAddr>,
44    /// Per-backend health state
45    health: Arc<RwLock<HashMap<SocketAddr, BackendHealth>>>,
46    /// Round-robin index for backend selection
47    rr_index: Arc<AtomicUsize>,
48    /// Runtime L4 config (TLS / proxy-protocol / session-timeout / health probe)
49    /// translated from the endpoint's `stream:` block. Drives the health
50    /// checker's probe selection (see [`StreamRegistry::spawn_health_checker`]).
51    pub config: StreamProxyConfig,
52}
53
54impl StreamService {
55    /// Create a new stream service
56    #[must_use]
57    pub fn new(name: String, backends: Vec<SocketAddr>) -> Self {
58        let health: HashMap<SocketAddr, BackendHealth> = backends
59            .iter()
60            .map(|addr| (*addr, BackendHealth::Unknown))
61            .collect();
62        Self {
63            name,
64            backends,
65            health: Arc::new(RwLock::new(health)),
66            rr_index: Arc::new(AtomicUsize::new(0)),
67            config: StreamProxyConfig::default(),
68        }
69    }
70
71    /// Attach a runtime [`StreamProxyConfig`] to this service.
72    ///
73    /// Builder-style; preserves [`StreamService::new`]'s 2-arg arity so the
74    /// existing call sites keep compiling. The config's `health_check` drives
75    /// what the background health checker probes.
76    #[must_use]
77    pub fn with_config(mut self, config: StreamProxyConfig) -> Self {
78        self.config = config;
79        self
80    }
81
82    /// Select next backend using round-robin, skipping unhealthy backends.
83    ///
84    /// Tries up to `backends.len()` candidates. If all backends are unhealthy,
85    /// falls back to returning *any* backend (better than nothing).
86    #[must_use]
87    pub fn select_backend(&self) -> Option<SocketAddr> {
88        if self.backends.is_empty() {
89            return None;
90        }
91
92        let len = self.backends.len();
93        let start = self.rr_index.fetch_add(1, Ordering::Relaxed);
94
95        // Try to read health state without blocking; if the lock is held,
96        // just fall through to simple round-robin.
97        let health_guard = self.health.try_read();
98
99        if let Ok(health) = health_guard {
100            // First pass: find a healthy backend
101            for i in 0..len {
102                let idx = (start + i) % len;
103                let addr = self.backends[idx];
104                let status = health.get(&addr).copied().unwrap_or(BackendHealth::Unknown);
105                if status.is_usable() {
106                    return Some(addr);
107                }
108            }
109        }
110
111        // Fallback: all unhealthy or lock contention — use simple round-robin
112        Some(self.backends[start % len])
113    }
114
115    /// Update backend addresses (for scaling events).
116    ///
117    /// New backends start with `Unknown` health; removed backends are pruned
118    /// from the health map.
119    pub fn update_backends(&mut self, backends: Vec<SocketAddr>) {
120        // We need to block here since this is called from a &mut self context
121        // (inside DashMap::get_mut), so we can use blocking write.
122        let mut health = self
123            .health
124            .try_write()
125            .unwrap_or_else(|_| {
126                // In the extremely unlikely case of write contention, just proceed
127                // with a fresh health map.
128                tracing::warn!(service = %self.name, "Health map write contention during backend update");
129                // This should never actually happen since update_backends holds &mut self
130                unreachable!("update_backends requires exclusive access")
131            });
132
133        // Add new backends with Unknown health
134        for addr in &backends {
135            health.entry(*addr).or_insert(BackendHealth::Unknown);
136        }
137
138        // Remove backends that are no longer present
139        let backend_set: std::collections::HashSet<SocketAddr> = backends.iter().copied().collect();
140        health.retain(|addr, _| backend_set.contains(addr));
141
142        self.backends = backends;
143    }
144
145    /// Set the health status of a specific backend
146    pub async fn set_backend_health(&self, addr: SocketAddr, status: BackendHealth) {
147        let mut health = self.health.write().await;
148        if let Some(h) = health.get_mut(&addr) {
149            *h = status;
150        }
151    }
152
153    /// Get the health status of a specific backend
154    pub async fn get_backend_health(&self, addr: SocketAddr) -> BackendHealth {
155        let health = self.health.read().await;
156        health.get(&addr).copied().unwrap_or(BackendHealth::Unknown)
157    }
158
159    /// Get current backend count
160    #[must_use]
161    pub fn backend_count(&self) -> usize {
162        self.backends.len()
163    }
164
165    /// Get count of healthy (usable) backends
166    pub async fn healthy_count(&self) -> usize {
167        let health = self.health.read().await;
168        self.backends
169            .iter()
170            .filter(|addr| {
171                health
172                    .get(addr)
173                    .copied()
174                    .unwrap_or(BackendHealth::Unknown)
175                    .is_usable()
176            })
177            .count()
178    }
179}
180
181/// Registry for L4 stream services
182///
183/// Maps listen ports to services for both TCP and UDP protocols.
184#[derive(Default)]
185pub struct StreamRegistry {
186    /// TCP services by listen port
187    tcp_services: DashMap<u16, StreamService>,
188    /// UDP services by listen port
189    udp_services: DashMap<u16, StreamService>,
190}
191
192impl StreamRegistry {
193    /// Create a new empty registry
194    #[must_use]
195    pub fn new() -> Self {
196        Self::default()
197    }
198
199    /// Register a TCP service for a port
200    pub fn register_tcp(&self, port: u16, service: StreamService) {
201        tracing::debug!(
202            port = port,
203            service = %service.name,
204            backends = service.backend_count(),
205            "Registered TCP stream service"
206        );
207        self.tcp_services.insert(port, service);
208    }
209
210    /// Register a UDP service for a port
211    pub fn register_udp(&self, port: u16, service: StreamService) {
212        tracing::debug!(
213            port = port,
214            service = %service.name,
215            backends = service.backend_count(),
216            "Registered UDP stream service"
217        );
218        self.udp_services.insert(port, service);
219    }
220
221    /// Resolve TCP service for a port
222    #[must_use]
223    pub fn resolve_tcp(&self, port: u16) -> Option<StreamService> {
224        self.tcp_services.get(&port).map(|s| s.clone())
225    }
226
227    /// Resolve UDP service for a port
228    #[must_use]
229    pub fn resolve_udp(&self, port: u16) -> Option<StreamService> {
230        self.udp_services.get(&port).map(|s| s.clone())
231    }
232
233    /// Apply a runtime [`StreamProxyConfig`] to the TCP service on `port`.
234    ///
235    /// No-op when no TCP service is registered for that port. Used by the agent
236    /// to attach the endpoint's translated `stream:` settings (notably the
237    /// health probe) to a service whose backends were registered out-of-band.
238    pub fn set_tcp_config(&self, port: u16, config: StreamProxyConfig) {
239        if let Some(mut service) = self.tcp_services.get_mut(&port) {
240            service.config = config;
241        }
242    }
243
244    /// Apply a runtime [`StreamProxyConfig`] to the UDP service on `port`.
245    ///
246    /// No-op when no UDP service is registered for that port.
247    pub fn set_udp_config(&self, port: u16, config: StreamProxyConfig) {
248        if let Some(mut service) = self.udp_services.get_mut(&port) {
249            service.config = config;
250        }
251    }
252
253    /// Update backends for a TCP service
254    pub fn update_tcp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
255        if let Some(mut service) = self.tcp_services.get_mut(&port) {
256            tracing::debug!(
257                port = port,
258                service = %service.name,
259                old_count = service.backend_count(),
260                new_count = backends.len(),
261                "Updating TCP backends"
262            );
263            service.update_backends(backends);
264        }
265    }
266
267    /// Update backends for a UDP service
268    pub fn update_udp_backends(&self, port: u16, backends: Vec<SocketAddr>) {
269        if let Some(mut service) = self.udp_services.get_mut(&port) {
270            tracing::debug!(
271                port = port,
272                service = %service.name,
273                old_count = service.backend_count(),
274                new_count = backends.len(),
275                "Updating UDP backends"
276            );
277            service.update_backends(backends);
278        }
279    }
280
281    /// Remove a TCP service
282    #[must_use]
283    pub fn unregister_tcp(&self, port: u16) -> Option<StreamService> {
284        self.tcp_services.remove(&port).map(|(_, s)| s)
285    }
286
287    /// Remove a UDP service
288    #[must_use]
289    pub fn unregister_udp(&self, port: u16) -> Option<StreamService> {
290        self.udp_services.remove(&port).map(|(_, s)| s)
291    }
292
293    /// Get count of registered TCP services
294    #[must_use]
295    pub fn tcp_count(&self) -> usize {
296        self.tcp_services.len()
297    }
298
299    /// Get count of registered UDP services
300    #[must_use]
301    pub fn udp_count(&self) -> usize {
302        self.udp_services.len()
303    }
304
305    /// List all registered TCP ports
306    #[must_use]
307    pub fn tcp_ports(&self) -> Vec<u16> {
308        self.tcp_services.iter().map(|e| *e.key()).collect()
309    }
310
311    /// List all registered UDP ports
312    #[must_use]
313    pub fn udp_ports(&self) -> Vec<u16> {
314        self.udp_services.iter().map(|e| *e.key()).collect()
315    }
316
317    /// List all registered TCP services with their listen ports.
318    #[must_use]
319    pub fn list_tcp_services(&self) -> Vec<(u16, StreamService)> {
320        self.tcp_services
321            .iter()
322            .map(|e| (*e.key(), e.value().clone()))
323            .collect()
324    }
325
326    /// List all registered UDP services with their listen ports.
327    #[must_use]
328    pub fn list_udp_services(&self) -> Vec<(u16, StreamService)> {
329        self.udp_services
330            .iter()
331            .map(|e| (*e.key(), e.value().clone()))
332            .collect()
333    }
334
335    /// Spawn a background health checker that periodically probes registered
336    /// stream backends.
337    ///
338    /// TCP services are probed with a connect-only check (matching the
339    /// `TcpConnect` health-check type — the only supported TCP probe).
340    ///
341    /// UDP services are probed only when their [`StreamProxyConfig::health_check`]
342    /// is `Some(StreamHealthProbe::UdpProbe { .. })`: the configured request is
343    /// sent to each backend and the backend is marked `Healthy` iff a reply
344    /// arrives (and matches `expect` by byte-substring when set). UDP services
345    /// without a configured probe remain `Unknown` (always usable) — preserving
346    /// the previous "never probe UDP" behavior.
347    ///
348    /// The task runs every `interval` and uses `timeout` for each probe.
349    /// Returns a `JoinHandle` that can be used to cancel the checker.
350    #[must_use]
351    pub fn spawn_health_checker(
352        self: &Arc<Self>,
353        interval: Duration,
354        timeout: Duration,
355    ) -> tokio::task::JoinHandle<()> {
356        let registry = Arc::clone(self);
357
358        tokio::spawn(async move {
359            let mut ticker = tokio::time::interval(interval);
360            // Skip the first immediate tick
361            ticker.tick().await;
362
363            loop {
364                ticker.tick().await;
365
366                // Iterate all TCP services and probe each backend
367                for entry in &registry.tcp_services {
368                    let service = entry.value().clone();
369                    let backends = service.backends.clone();
370
371                    for addr in backends {
372                        let svc = service.clone();
373                        let probe_timeout = timeout;
374
375                        // Probe each backend concurrently
376                        tokio::spawn(async move {
377                            let result = tokio::time::timeout(
378                                probe_timeout,
379                                tokio::net::TcpStream::connect(addr),
380                            )
381                            .await;
382
383                            let health = match result {
384                                Ok(Ok(_stream)) => BackendHealth::Healthy,
385                                Ok(Err(e)) => {
386                                    tracing::debug!(
387                                        service = %svc.name,
388                                        backend = %addr,
389                                        error = %e,
390                                        "TCP health check failed (connect error)"
391                                    );
392                                    BackendHealth::Unhealthy
393                                }
394                                Err(_) => {
395                                    tracing::debug!(
396                                        service = %svc.name,
397                                        backend = %addr,
398                                        "TCP health check failed (timeout)"
399                                    );
400                                    BackendHealth::Unhealthy
401                                }
402                            };
403
404                            svc.set_backend_health(addr, health).await;
405                        });
406                    }
407                }
408
409                // Iterate all UDP services and probe each backend when a UDP
410                // probe is configured. Services without one are left untouched.
411                for entry in &registry.udp_services {
412                    let service = entry.value().clone();
413                    let Some(StreamHealthProbe::UdpProbe { request, expect }) =
414                        service.config.health_check.clone()
415                    else {
416                        continue;
417                    };
418                    let backends = service.backends.clone();
419
420                    for addr in backends {
421                        let svc = service.clone();
422                        let probe_timeout = timeout;
423                        let request = request.clone();
424                        let expect = expect.clone();
425
426                        tokio::spawn(async move {
427                            let health = match probe_udp_backend(
428                                addr,
429                                &request,
430                                expect.as_deref(),
431                                probe_timeout,
432                            )
433                            .await
434                            {
435                                Ok(true) => BackendHealth::Healthy,
436                                Ok(false) => {
437                                    tracing::debug!(
438                                        service = %svc.name,
439                                        backend = %addr,
440                                        "UDP health check failed (reply did not match expect)"
441                                    );
442                                    BackendHealth::Unhealthy
443                                }
444                                Err(e) => {
445                                    tracing::debug!(
446                                        service = %svc.name,
447                                        backend = %addr,
448                                        error = %e,
449                                        "UDP health check failed"
450                                    );
451                                    BackendHealth::Unhealthy
452                                }
453                            };
454
455                            svc.set_backend_health(addr, health).await;
456                        });
457                    }
458                }
459            }
460        })
461    }
462}
463
464/// Probe a single UDP backend by sending `request` and waiting (up to
465/// `timeout`) for any reply.
466///
467/// Returns `Ok(true)` when a reply arrives that satisfies `expect` (any reply
468/// when `expect` is `None`; a reply containing `expect` as a byte-substring
469/// otherwise). Returns `Ok(false)` when a reply arrives but does not contain
470/// `expect`. Returns `Err(..)` on socket errors or recv timeout.
471///
472/// # Errors
473///
474/// Returns an error if binding/connecting the probe socket fails, if the send
475/// fails, or if no reply arrives before `timeout` elapses.
476pub async fn probe_udp_backend(
477    addr: SocketAddr,
478    request: &[u8],
479    expect: Option<&[u8]>,
480    timeout: Duration,
481) -> std::result::Result<bool, std::io::Error> {
482    let socket = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
483    socket.connect(addr).await?;
484    socket.send(request).await?;
485
486    let mut buf = vec![0u8; 65535];
487    let len = tokio::time::timeout(timeout, socket.recv(&mut buf))
488        .await
489        .map_err(|_| {
490            std::io::Error::new(std::io::ErrorKind::TimedOut, "UDP health probe timed out")
491        })??;
492
493    let reply = &buf[..len];
494    match expect {
495        Some(pat) => Ok(byte_contains(reply, pat)),
496        None => Ok(true),
497    }
498}
499
/// Reports whether `needle` occurs as a contiguous run of bytes in `haystack`.
///
/// An empty `needle` matches any haystack (including an empty one); a needle
/// longer than the haystack never matches.
#[must_use]
fn byte_contains(haystack: &[u8], needle: &[u8]) -> bool {
    match needle.len() {
        // `windows(0)` would panic, so the empty needle is answered up front.
        0 => true,
        width if width > haystack.len() => false,
        width => haystack.windows(width).any(|window| window == needle),
    }
}
512
#[cfg(test)]
mod health_probe_tests {
    use super::*;
    use std::time::Duration;
    use tokio::net::UdpSocket;

    /// Bind a loopback UDP socket that echoes back up to `replies` datagrams
    /// on a background task, and return the socket's address.
    async fn spawn_echo(replies: usize) -> std::net::SocketAddr {
        let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let local = socket.local_addr().unwrap();
        tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            for _ in 0..replies {
                if let Ok((n, peer)) = socket.recv_from(&mut buf).await {
                    let _ = socket.send_to(&buf[..n], peer).await;
                }
            }
        });
        local
    }

    #[test]
    fn byte_contains_matches() {
        let cases: [(&[u8], &[u8]); 4] = [
            (b"hello world", b"world"),
            (b"hello world", b"hello"),
            (b"\xFF\x00\xAB", b"\x00\xAB"),
            (b"anything", b""), // empty needle always matches
        ];
        for (haystack, needle) in cases {
            assert!(byte_contains(haystack, needle));
        }
    }

    #[test]
    fn byte_contains_rejects() {
        let cases: [(&[u8], &[u8]); 3] = [
            (b"hello", b"world"),
            (b"abc", b"abcd"), // needle longer than haystack
            (b"", b"x"),
        ];
        for (haystack, needle) in cases {
            assert!(!byte_contains(haystack, needle));
        }
    }

    #[tokio::test]
    async fn udp_probe_healthy_against_echo() {
        let echo_addr = spawn_echo(1).await;

        // Any reply (no expect) -> healthy.
        let healthy = probe_udp_backend(echo_addr, b"ping", None, Duration::from_secs(2))
            .await
            .unwrap();
        assert!(healthy, "echo reply with no expect must be healthy");
    }

    #[tokio::test]
    async fn udp_probe_expect_substring() {
        let echo_addr = spawn_echo(2).await;
        let wait = Duration::from_secs(2);

        // Reply contains expect substring -> healthy.
        let matched = probe_udp_backend(echo_addr, b"PONG-token", Some(b"token"), wait)
            .await
            .unwrap();
        assert!(matched, "reply containing expect substring must be healthy");

        // Reply does NOT contain expect -> unhealthy (Ok(false)).
        let mismatched = probe_udp_backend(echo_addr, b"abc", Some(b"zzz"), wait)
            .await
            .unwrap();
        assert!(
            !mismatched,
            "reply missing expect substring must be unhealthy"
        );
    }

    #[tokio::test]
    async fn udp_probe_dead_port_times_out() {
        // Grab a free port, then release it so nothing is listening there.
        let dead_addr = {
            let placeholder = UdpSocket::bind("127.0.0.1:0").await.unwrap();
            placeholder.local_addr().unwrap()
        };

        // No listener -> no reply -> error (treated as Unhealthy by caller).
        let outcome =
            probe_udp_backend(dead_addr, b"ping", None, Duration::from_millis(300)).await;
        assert!(outcome.is_err(), "probe to dead UDP port must error (timeout)");
    }
}