// zlayer_agent/proxy_manager.rs

1//! Proxy management for agent-controlled services
2//!
3//! This module provides the `ProxyManager` struct that integrates the proxy crate
4//! with the agent's service management. It handles:
5//! - Managing proxy routes based on `ServiceSpec` endpoints (HTTP/HTTPS/WebSocket)
6//! - Managing L4 stream proxy listeners (TCP/UDP)
7//! - Tracking and updating backend servers for load balancing
8//! - Coordinating proxy server lifecycle
9
10use crate::error::Result;
11use std::collections::{HashMap, HashSet};
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::RwLock;
17use tracing::{debug, info, warn};
18use zlayer_proxy::{
19    load_existing_certs_into_resolver, CertManager, LbStrategy, LoadBalancer, NetworkPolicyChecker,
20    ProxyConfig, ProxyServer, RouteEntry, ServiceRegistry, SniCertResolver, StreamRegistry,
21    TcpStreamService, UdpStreamService,
22};
23use zlayer_spec::{ExposeType, Protocol, ServiceSpec};
24
/// Configuration for the `ProxyManager`
///
/// Controls where the HTTP (and optionally HTTPS) listeners bind and
/// whether HTTP/2 is negotiated.
#[derive(Debug, Clone)]
pub struct ProxyManagerConfig {
    /// HTTP bind address
    pub http_addr: SocketAddr,
    /// HTTPS bind address (optional)
    pub https_addr: Option<SocketAddr>,
    /// Whether to enable HTTP/2
    pub http2_enabled: bool,
}

impl Default for ProxyManagerConfig {
    /// Defaults: HTTP on `0.0.0.0:80`, no HTTPS listener, HTTP/2 enabled.
    fn default() -> Self {
        let http_addr: SocketAddr = "0.0.0.0:80"
            .parse()
            .expect("static default bind address is valid");
        Self {
            http_addr,
            https_addr: None,
            http2_enabled: true,
        }
    }
}

impl ProxyManagerConfig {
    /// Create a new configuration with the specified HTTP address
    #[must_use]
    pub fn new(http_addr: SocketAddr) -> Self {
        Self {
            https_addr: None,
            http2_enabled: true,
            http_addr,
        }
    }

    /// Set the HTTPS address
    #[must_use]
    pub fn with_https(mut self, addr: SocketAddr) -> Self {
        self.https_addr = Some(addr);
        self
    }

    /// Set HTTP/2 support
    #[must_use]
    pub fn with_http2(mut self, enabled: bool) -> Self {
        self.http2_enabled = enabled;
        self
    }
}
71
/// Per-service tracking information for cleanup purposes.
///
/// One record is kept per registered service so that `remove_service` can
/// tear down exactly the ports and routes that service owned.
#[derive(Debug, Clone)]
struct ServiceTracking {
    /// Endpoint names (retained for Debug output and future introspection)
    #[allow(dead_code)]
    endpoint_names: Vec<String>,
    /// TCP ports owned by this service
    tcp_ports: Vec<u16>,
    /// UDP ports owned by this service
    udp_ports: Vec<u16>,
    /// HTTP/HTTPS/WebSocket ports owned by this service
    http_ports: Vec<u16>,
}
85
/// Manages proxy routing for agent-controlled services
///
/// The `ProxyManager` coordinates between the agent's service lifecycle and
/// the proxy crate's routing/load balancing infrastructure. It supports:
///
/// - **HTTP/HTTPS/WebSocket (L7)**: Multiple port listeners sharing the same
///   `ServiceRegistry` for request matching and load balancing.
/// - **TCP/UDP (L4)**: Standalone stream proxy listeners that forward raw
///   connections/datagrams to backends via the `StreamRegistry`.
pub struct ProxyManager {
    /// Configuration
    config: ProxyManagerConfig,
    /// Shared service registry for HTTP request matching and backend management
    registry: Arc<ServiceRegistry>,
    /// Load balancer for health-aware backend selection
    load_balancer: Arc<LoadBalancer>,
    /// Per-port HTTP proxy server handles (keyed by listen port)
    servers: RwLock<HashMap<u16, Arc<ProxyServer>>>,
    /// Tracked services and their endpoints (includes port ownership for cleanup)
    services: RwLock<HashMap<String, ServiceTracking>>,
    /// Stream registry for L4 TCP/UDP proxy routing; `None` disables L4 proxying
    stream_registry: Option<Arc<StreamRegistry>>,
    /// Certificate manager for TLS; `None` disables HTTPS listeners
    cert_manager: Option<Arc<CertManager>>,
    /// Ports with active TCP stream listeners (to avoid double-binding)
    tcp_listeners: RwLock<HashSet<u16>>,
    /// Ports with active UDP stream listeners (to avoid double-binding)
    udp_listeners: RwLock<HashSet<u16>>,
    /// Number of active proxy connections (for graceful drain on shutdown)
    active_connections: Arc<AtomicU64>,
    /// Optional network policy checker for access control enforcement
    network_policy_checker: Option<NetworkPolicyChecker>,
}
119
120impl ProxyManager {
121    /// Create a new `ProxyManager` with the given configuration, service registry,
122    /// and optional certificate manager.
123    pub fn new(
124        config: ProxyManagerConfig,
125        registry: Arc<ServiceRegistry>,
126        cert_manager: Option<Arc<CertManager>>,
127    ) -> Self {
128        let load_balancer = Arc::new(LoadBalancer::new());
129
130        Self {
131            config,
132            registry,
133            load_balancer,
134            servers: RwLock::new(HashMap::new()),
135            services: RwLock::new(HashMap::new()),
136            stream_registry: None,
137            cert_manager,
138            tcp_listeners: RwLock::new(HashSet::new()),
139            udp_listeners: RwLock::new(HashSet::new()),
140            active_connections: Arc::new(AtomicU64::new(0)),
141            network_policy_checker: None,
142        }
143    }
144
145    /// Get a reference to the service registry
146    pub fn registry(&self) -> Arc<ServiceRegistry> {
147        self.registry.clone()
148    }
149
150    /// Get a reference to the load balancer
151    pub fn load_balancer(&self) -> Arc<LoadBalancer> {
152        self.load_balancer.clone()
153    }
154
    /// Get the number of currently active proxy connections.
    ///
    /// Relaxed ordering suffices: this is a monitoring counter, not a
    /// synchronization point.
    pub fn active_connections(&self) -> u64 {
        self.active_connections.load(Ordering::Relaxed)
    }
159
    /// Get a reference to the certificate manager (if configured).
    ///
    /// Returns `None` when the manager was built without TLS support.
    pub fn cert_manager(&self) -> Option<&Arc<CertManager>> {
        self.cert_manager.as_ref()
    }
164
    /// Set the stream registry for L4 proxy integration (TCP/UDP).
    ///
    /// Required before `ensure_tcp_listener`/`ensure_udp_listener` can start
    /// stream listeners; without it those calls log a warning and no-op.
    pub fn set_stream_registry(&mut self, registry: Arc<StreamRegistry>) {
        self.stream_registry = Some(registry);
    }
169
    /// Builder pattern: add stream registry for L4 proxy integration.
    ///
    /// Consuming variant of [`Self::set_stream_registry`] for construction chains.
    #[must_use]
    pub fn with_stream_registry(mut self, registry: Arc<StreamRegistry>) -> Self {
        self.stream_registry = Some(registry);
        self
    }
176
    /// Get the stream registry (if configured).
    pub fn stream_registry(&self) -> Option<&Arc<StreamRegistry>> {
        self.stream_registry.as_ref()
    }
181
    /// Set the network policy checker for access control enforcement.
    ///
    /// Only listeners started *after* this call pick up the checker; already
    /// running servers are not retrofitted.
    pub fn set_network_policy_checker(&mut self, checker: NetworkPolicyChecker) {
        self.network_policy_checker = Some(checker);
    }
186
    /// Builder pattern: add network policy checker for access control enforcement.
    ///
    /// Consuming variant of [`Self::set_network_policy_checker`].
    #[must_use]
    pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
        self.network_policy_checker = Some(checker);
        self
    }
193
194    /// Start listening on a specific port bound to the given address.
195    ///
196    /// If already listening on this port, skip.
197    /// All port listeners share the same `ServiceRegistry` for request matching.
198    ///
199    /// # Errors
200    /// Returns an error if the proxy server cannot be started.
201    pub async fn listen_on(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
202        let mut servers = self.servers.write().await;
203
204        if servers.contains_key(&port) {
205            debug!(port = port, "Already listening on port");
206            return Ok(());
207        }
208
209        let addr = SocketAddr::new(bind_ip, port);
210        let mut proxy_config = ProxyConfig::default();
211        proxy_config.server.http_addr = addr;
212        proxy_config.server.http2_enabled = self.config.http2_enabled;
213
214        let mut server = ProxyServer::with_registry(
215            proxy_config,
216            self.registry.clone(),
217            self.load_balancer.clone(),
218        );
219        if let Some(ref checker) = self.network_policy_checker {
220            server = server.with_network_policy_checker(checker.clone());
221        }
222        let server = Arc::new(server);
223
224        info!(port = port, bind = %addr, "Proxy listening on port");
225
226        let server_clone = server.clone();
227        tokio::spawn(async move {
228            if let Err(e) = server_clone.run().await {
229                tracing::error!(port = port, error = %e, "Proxy server error on port");
230            }
231        });
232
233        servers.insert(port, server);
234        Ok(())
235    }
236
237    /// Start an HTTPS listener on the given port using `SniCertResolver` for dynamic cert selection.
238    ///
239    /// If already listening on this port, skip.
240    /// Requires a `CertManager` to be configured; logs a warning and returns `Ok(())` if not.
241    ///
242    /// # Errors
243    /// Returns an error if the HTTPS proxy server cannot be started.
244    pub async fn listen_on_tls(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
245        let mut servers = self.servers.write().await;
246
247        if servers.contains_key(&port) {
248            debug!(port = port, "Already listening on port (TLS)");
249            return Ok(());
250        }
251
252        let Some(cert_manager) = &self.cert_manager else {
253            warn!(
254                port = port,
255                "Cannot start TLS listener: no CertManager configured"
256            );
257            return Ok(());
258        };
259
260        // Create SniCertResolver and load existing certs
261        let sni_resolver = Arc::new(SniCertResolver::new());
262
263        // Load existing certificates (best-effort; log warnings on failure)
264        let _ = load_existing_certs_into_resolver(cert_manager, &sni_resolver).await;
265
266        let addr = SocketAddr::new(bind_ip, port);
267        let mut proxy_config = ProxyConfig::default();
268        proxy_config.server.https_addr = addr;
269
270        let mut server = ProxyServer::with_tls_resolver(
271            proxy_config,
272            self.registry.clone(),
273            self.load_balancer.clone(),
274            sni_resolver,
275        )
276        .with_cert_manager(Arc::clone(cert_manager));
277        if let Some(ref checker) = self.network_policy_checker {
278            server = server.with_network_policy_checker(checker.clone());
279        }
280        let server = Arc::new(server);
281
282        info!(port = port, bind = %addr, "HTTPS proxy listening on port");
283
284        let server_clone = server.clone();
285        tokio::spawn(async move {
286            if let Err(e) = server_clone.run_https().await {
287                tracing::error!(port = port, error = %e, "HTTPS proxy server error");
288            }
289        });
290
291        servers.insert(port, server);
292        Ok(())
293    }
294
295    /// Stop all proxy servers on all ports.
296    ///
297    /// After signalling each server to shut down, waits up to 30 seconds for
298    /// active connections to drain before returning.
299    pub async fn stop(&self) {
300        let mut servers = self.servers.write().await;
301        for (port, server) in servers.drain() {
302            info!(port = port, "Stopping proxy on port");
303            server.shutdown();
304        }
305
306        // Wait up to 30s for active connections to drain
307        let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
308        while self.active_connections.load(Ordering::Relaxed) > 0 {
309            if tokio::time::Instant::now() >= deadline {
310                let remaining = self.active_connections.load(Ordering::Relaxed);
311                warn!(
312                    remaining = remaining,
313                    "Drain timeout reached, forcing shutdown"
314                );
315                break;
316            }
317            tokio::time::sleep(Duration::from_millis(100)).await;
318        }
319
320        info!("All proxy servers stopped");
321    }
322
323    /// Remove and shut down the listener on a specific port.
324    pub async fn unbind(&self, port: u16) {
325        let mut servers = self.servers.write().await;
326        if let Some(server) = servers.remove(&port) {
327            info!(port = port, "Unbinding proxy from port");
328            server.shutdown();
329        }
330    }
331
332    /// Scan a service's endpoints and ensure the proxy is listening on all
333    /// required ports.
334    ///
335    /// - **HTTP/HTTPS/WebSocket** endpoints start an HTTP proxy listener.
336    /// - **TCP** endpoints bind a `TcpListener` and spawn a `TcpStreamService`.
337    /// - **UDP** endpoints bind a `UdpSocket` and spawn a `UdpStreamService`.
338    ///
339    /// Bind address is determined by the `expose` type:
340    /// - **Public** endpoints bind to `0.0.0.0` (all interfaces).
341    /// - **Internal** endpoints bind to the overlay IP so they are only
342    ///   reachable from within the overlay network.  If no overlay is
343    ///   available, internal endpoints bind to `127.0.0.1` (localhost only).
344    ///
345    /// # Errors
346    /// Returns an error if an HTTP/HTTPS listener cannot be started.
347    pub async fn ensure_ports_for_service(
348        &self,
349        spec: &ServiceSpec,
350        overlay_ip: Option<IpAddr>,
351    ) -> Result<()> {
352        for endpoint in &spec.endpoints {
353            let bind_ip = match endpoint.expose {
354                ExposeType::Public => IpAddr::V4(Ipv4Addr::UNSPECIFIED), // 0.0.0.0
355                ExposeType::Internal => {
356                    // Prefer overlay IP; fall back to loopback if overlay is unavailable.
357                    let ip = overlay_ip.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
358                    if overlay_ip.is_none() {
359                        warn!(
360                            endpoint = %endpoint.name,
361                            port = endpoint.port,
362                            "No overlay IP available for internal endpoint; binding to 127.0.0.1"
363                        );
364                    }
365                    ip
366                }
367            };
368
369            match endpoint.protocol {
370                Protocol::Https => {
371                    // L7 TLS: start HTTPS proxy listener with SNI cert resolution
372                    self.listen_on_tls(endpoint.port, bind_ip).await?;
373                }
374                Protocol::Http | Protocol::Websocket => {
375                    // L7: start HTTP proxy listener
376                    self.listen_on(endpoint.port, bind_ip).await?;
377                }
378                Protocol::Tcp => {
379                    // L4 TCP: bind listener and spawn TcpStreamService
380                    self.ensure_tcp_listener(endpoint.port, bind_ip).await;
381                }
382                Protocol::Udp => {
383                    // L4 UDP: bind socket and spawn UdpStreamService
384                    self.ensure_udp_listener(endpoint.port, bind_ip).await;
385                }
386            }
387        }
388        Ok(())
389    }
390
391    /// Ensure a TCP stream listener is running on the given port.
392    ///
393    /// If a listener is already active on this port, this is a no-op.
394    /// Requires `stream_registry` to be configured; logs a warning if not.
395    async fn ensure_tcp_listener(&self, port: u16, bind_ip: IpAddr) {
396        // Check if already listening
397        {
398            let listeners = self.tcp_listeners.read().await;
399            if listeners.contains(&port) {
400                debug!(port = port, "TCP stream listener already active");
401                return;
402            }
403        }
404
405        let registry = if let Some(r) = &self.stream_registry {
406            Arc::clone(r)
407        } else {
408            warn!(
409                port = port,
410                "Cannot start TCP listener: StreamRegistry not configured"
411            );
412            return;
413        };
414
415        let addr = SocketAddr::new(bind_ip, port);
416        let listener = match tokio::net::TcpListener::bind(addr).await {
417            Ok(l) => l,
418            Err(e) => {
419                warn!(
420                    port = port,
421                    bind = %addr,
422                    error = %e,
423                    "Failed to bind TCP stream listener, continuing"
424                );
425                return;
426            }
427        };
428
429        // Mark as active before spawning
430        {
431            let mut listeners = self.tcp_listeners.write().await;
432            listeners.insert(port);
433        }
434
435        let tcp_service = Arc::new(TcpStreamService::new(registry, port));
436        tokio::spawn(async move {
437            tcp_service.serve(listener).await;
438        });
439
440        info!(port = port, bind = %addr, "TCP stream proxy listening");
441    }
442
443    /// Ensure a UDP stream listener is running on the given port.
444    ///
445    /// If a listener is already active on this port, this is a no-op.
446    /// Requires `stream_registry` to be configured; logs a warning if not.
447    async fn ensure_udp_listener(&self, port: u16, bind_ip: IpAddr) {
448        // Check if already listening
449        {
450            let listeners = self.udp_listeners.read().await;
451            if listeners.contains(&port) {
452                debug!(port = port, "UDP stream listener already active");
453                return;
454            }
455        }
456
457        let registry = if let Some(r) = &self.stream_registry {
458            Arc::clone(r)
459        } else {
460            warn!(
461                port = port,
462                "Cannot start UDP listener: StreamRegistry not configured"
463            );
464            return;
465        };
466
467        let addr = SocketAddr::new(bind_ip, port);
468        let socket = match tokio::net::UdpSocket::bind(addr).await {
469            Ok(s) => s,
470            Err(e) => {
471                warn!(
472                    port = port,
473                    bind = %addr,
474                    error = %e,
475                    "Failed to bind UDP stream listener, continuing"
476                );
477                return;
478            }
479        };
480
481        // Mark as active before spawning
482        {
483            let mut listeners = self.udp_listeners.write().await;
484            listeners.insert(port);
485        }
486
487        let udp_service = Arc::new(UdpStreamService::new(registry, port, None));
488        tokio::spawn(async move {
489            if let Err(e) = udp_service.serve(socket).await {
490                tracing::error!(
491                    port = port,
492                    error = %e,
493                    "UDP stream proxy service failed"
494                );
495            }
496        });
497
498        info!(port = port, bind = %addr, "UDP stream proxy listening");
499    }
500
501    /// Add routes for a service based on its specification
502    ///
503    /// This creates proxy routes for each endpoint defined in the `ServiceSpec`.
504    /// HTTP/HTTPS/WebSocket endpoints get L7 routes via the `ServiceRegistry`.
505    /// TCP/UDP endpoints are tracked but their L4 registration is handled
506    /// by the `ServiceManager::register_service_routes()` method.
507    pub async fn add_service(&self, name: &str, spec: &ServiceSpec) {
508        let mut services = self.services.write().await;
509
510        // Track which endpoints and ports we're adding
511        let mut endpoint_names = Vec::new();
512        let mut tcp_ports = Vec::new();
513        let mut udp_ports = Vec::new();
514        let mut http_ports = Vec::new();
515
516        for endpoint in &spec.endpoints {
517            match endpoint.protocol {
518                Protocol::Http | Protocol::Https | Protocol::Websocket => {
519                    // L7: register route in the ServiceRegistry
520                    let entry = RouteEntry::from_endpoint(name, endpoint);
521                    self.registry.register(entry).await;
522                    http_ports.push(endpoint.port);
523
524                    info!(
525                        service = name,
526                        endpoint = %endpoint.name,
527                        protocol = ?endpoint.protocol,
528                        path = ?endpoint.path,
529                        expose = ?endpoint.expose,
530                        "Added HTTP proxy route for service"
531                    );
532                }
533                Protocol::Tcp => {
534                    tcp_ports.push(endpoint.port);
535                    info!(
536                        service = name,
537                        endpoint = %endpoint.name,
538                        protocol = ?endpoint.protocol,
539                        port = endpoint.port,
540                        expose = ?endpoint.expose,
541                        "Tracking TCP stream endpoint for service"
542                    );
543                }
544                Protocol::Udp => {
545                    udp_ports.push(endpoint.port);
546                    info!(
547                        service = name,
548                        endpoint = %endpoint.name,
549                        protocol = ?endpoint.protocol,
550                        port = endpoint.port,
551                        expose = ?endpoint.expose,
552                        "Tracking UDP stream endpoint for service"
553                    );
554                }
555            }
556
557            endpoint_names.push(endpoint.name.clone());
558        }
559
560        // Register the service in the load balancer (starts with no backends)
561        self.load_balancer
562            .register(name, vec![], LbStrategy::RoundRobin);
563
564        services.insert(
565            name.to_string(),
566            ServiceTracking {
567                endpoint_names,
568                tcp_ports,
569                udp_ports,
570                http_ports,
571            },
572        );
573    }
574
    /// Remove all routes, L4 listeners, and HTTP server handles for a service.
    ///
    /// This performs a full cleanup of all proxy resources associated with the
    /// service:
    /// - Removes L7 (HTTP/HTTPS/WebSocket) routes from the `ServiceRegistry`
    /// - Unregisters TCP/UDP stream services from the `StreamRegistry`
    /// - Removes port tracking for TCP/UDP listeners
    /// - Shuts down HTTP proxy server handles that were exclusively owned by
    ///   this service (only if no other service uses the same port)
    ///
    /// No-op if the service is not tracked. The `services` write lock is held
    /// for the entire cleanup so the steps are atomic with respect to
    /// concurrent `add_service`/`remove_service` calls; the tcp/udp/servers
    /// locks are taken nested inside it.
    pub async fn remove_service(&self, name: &str) {
        let mut services = self.services.write().await;

        if let Some(tracking) = services.remove(name) {
            // 1. Remove L7 routes from the ServiceRegistry
            self.registry.unregister_service(name).await;

            // 1b. Remove from the load balancer
            self.load_balancer.unregister(name);

            // 2. Unregister TCP stream services and clear port tracking
            //    (unregistration is best-effort; errors are ignored)
            if !tracking.tcp_ports.is_empty() {
                let mut tcp_set = self.tcp_listeners.write().await;
                for port in &tracking.tcp_ports {
                    if let Some(registry) = &self.stream_registry {
                        let _ = registry.unregister_tcp(*port);
                    }
                    tcp_set.remove(port);
                    debug!(service = name, port = port, "Removed TCP listener tracking");
                }
            }

            // 3. Unregister UDP stream services and clear port tracking
            //    (same best-effort handling as TCP above)
            if !tracking.udp_ports.is_empty() {
                let mut udp_set = self.udp_listeners.write().await;
                for port in &tracking.udp_ports {
                    if let Some(registry) = &self.stream_registry {
                        let _ = registry.unregister_udp(*port);
                    }
                    udp_set.remove(port);
                    debug!(service = name, port = port, "Removed UDP listener tracking");
                }
            }

            // 4. Shut down HTTP proxy servers on ports exclusively owned by
            //    this service (skip ports still used by other services).
            //    `tracking` was already removed from `services` above, so the
            //    in-use set below reflects only the remaining services.
            if !tracking.http_ports.is_empty() {
                let ports_still_in_use: HashSet<u16> = services
                    .values()
                    .flat_map(|t| t.http_ports.iter().copied())
                    .collect();

                let mut servers = self.servers.write().await;
                for port in &tracking.http_ports {
                    if !ports_still_in_use.contains(port) {
                        if let Some(server) = servers.remove(port) {
                            server.shutdown();
                            info!(
                                service = name,
                                port = port,
                                "Shut down HTTP proxy server (no remaining services on port)"
                            );
                        }
                    }
                }
            }

            info!(service = name, "Removed all proxy resources for service");
        }
    }
644
645    /// Add a single backend to a service
646    pub async fn add_backend(&self, service: &str, addr: SocketAddr) {
647        self.registry.add_backend(service, addr).await;
648        self.load_balancer.add_backend(service, addr);
649        info!(service = service, backend = %addr, "Registered backend with proxy");
650    }
651
652    /// Remove a backend from a service
653    pub async fn remove_backend(&self, service: &str, addr: SocketAddr) {
654        self.registry.remove_backend(service, addr).await;
655        self.load_balancer.remove_backend(service, &addr);
656        debug!(service = service, backend = %addr, "Removed backend from service");
657    }
658
659    /// Update the health status of a backend in the load balancer.
660    ///
661    /// Delegates to [`LoadBalancer::mark_health`] so that unhealthy backends
662    /// are skipped during selection.
663    #[allow(clippy::unused_async)]
664    pub async fn update_backend_health(&self, service: &str, addr: SocketAddr, healthy: bool) {
665        self.load_balancer.mark_health(service, &addr, healthy);
666        debug!(
667            service = service,
668            backend = %addr,
669            healthy = healthy,
670            "Updated backend health in load balancer"
671        );
672    }
673
674    /// Update the backends for a service
675    ///
676    /// This replaces all backends for the given service with the provided list.
677    /// Each backend should be the address where the service replica is listening.
678    pub async fn update_backends(&self, service: &str, addrs: Vec<SocketAddr>) {
679        self.registry.update_backends(service, addrs.clone()).await;
680        self.load_balancer.update_backends(service, addrs);
681        debug!(service = service, "Updated backends for service");
682    }
683
684    /// Get the number of registered routes
685    pub async fn route_count(&self) -> usize {
686        self.registry.route_count().await
687    }
688
689    /// Get the list of registered service names
690    pub async fn list_services(&self) -> Vec<String> {
691        self.services.read().await.keys().cloned().collect()
692    }
693
694    /// Check if a service has any registered endpoints
695    pub async fn has_service(&self, name: &str) -> bool {
696        self.services.read().await.contains_key(name)
697    }
698}
699
700#[cfg(test)]
701mod tests {
702    use super::*;
703
704    fn mock_service_spec_with_endpoints() -> ServiceSpec {
705        use zlayer_spec::*;
706        serde_yaml::from_str::<DeploymentSpec>(
707            r"
708version: v1
709deployment: test
710services:
711  test:
712    rtype: service
713    image:
714      name: test:latest
715    endpoints:
716      - name: http
717        protocol: http
718        port: 8080
719        path: /api
720        expose: public
721      - name: websocket
722        protocol: websocket
723        port: 8081
724        path: /ws
725        expose: internal
726",
727        )
728        .unwrap()
729        .services
730        .remove("test")
731        .unwrap()
732    }
733
734    fn mock_service_spec_tcp_only() -> ServiceSpec {
735        use zlayer_spec::*;
736        serde_yaml::from_str::<DeploymentSpec>(
737            r"
738version: v1
739deployment: test
740services:
741  test:
742    rtype: service
743    image:
744      name: test:latest
745    endpoints:
746      - name: grpc
747        protocol: tcp
748        port: 9000
749",
750        )
751        .unwrap()
752        .services
753        .remove("test")
754        .unwrap()
755    }
756
757    #[tokio::test]
758    async fn test_proxy_manager_new() {
759        let config = ProxyManagerConfig::default();
760        let registry = Arc::new(ServiceRegistry::new());
761        let manager = ProxyManager::new(config, registry, None);
762
763        assert_eq!(manager.route_count().await, 0);
764        assert!(manager.list_services().await.is_empty());
765    }
766
767    #[tokio::test]
768    async fn test_add_service_with_http_endpoints() {
769        let config = ProxyManagerConfig::default();
770        let registry = Arc::new(ServiceRegistry::new());
771        let manager = ProxyManager::new(config, registry, None);
772
773        let spec = mock_service_spec_with_endpoints();
774        manager.add_service("api", &spec).await;
775
776        // Should have 2 routes (http and websocket)
777        assert_eq!(manager.route_count().await, 2);
778        assert!(manager.has_service("api").await);
779    }
780
781    #[tokio::test]
782    async fn test_tcp_endpoints_tracked_not_routed() {
783        let config = ProxyManagerConfig::default();
784        let registry = Arc::new(ServiceRegistry::new());
785        let manager = ProxyManager::new(config, registry, None);
786
787        let spec = mock_service_spec_tcp_only();
788        manager.add_service("grpc-service", &spec).await;
789
790        // TCP endpoints don't add HTTP routes
791        assert_eq!(manager.route_count().await, 0);
792        // But the service is still tracked with its endpoint name
793        assert!(manager.has_service("grpc-service").await);
794    }
795
796    #[tokio::test]
797    async fn test_remove_service() {
798        let config = ProxyManagerConfig::default();
799        let registry = Arc::new(ServiceRegistry::new());
800        let manager = ProxyManager::new(config, registry, None);
801
802        let spec = mock_service_spec_with_endpoints();
803        manager.add_service("api", &spec).await;
804        assert_eq!(manager.route_count().await, 2);
805
806        manager.remove_service("api").await;
807        assert_eq!(manager.route_count().await, 0);
808        assert!(!manager.has_service("api").await);
809    }
810
811    #[tokio::test]
812    async fn test_backend_management() {
813        let config = ProxyManagerConfig::default();
814        let registry = Arc::new(ServiceRegistry::new());
815        let manager = ProxyManager::new(config, registry.clone(), None);
816
817        let spec = mock_service_spec_with_endpoints();
818        manager.add_service("api", &spec).await;
819
820        // Add backends
821        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
822        let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();
823
824        manager.add_backend("api", addr1).await;
825        manager.add_backend("api", addr2).await;
826
827        // Verify backends via the registry's resolve
828        let resolved = registry.resolve(None, "/api").await.unwrap();
829        assert_eq!(resolved.backends.len(), 2);
830
831        // Remove a backend
832        manager.remove_backend("api", addr1).await;
833        let resolved = registry.resolve(None, "/api").await.unwrap();
834        assert_eq!(resolved.backends.len(), 1);
835    }
836
837    #[tokio::test]
838    async fn test_update_backends_replaces_all() {
839        let config = ProxyManagerConfig::default();
840        let registry = Arc::new(ServiceRegistry::new());
841        let manager = ProxyManager::new(config, registry.clone(), None);
842
843        let spec = mock_service_spec_with_endpoints();
844        manager.add_service("api", &spec).await;
845
846        // Add initial backend
847        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
848        manager.add_backend("api", addr1).await;
849
850        // Update with new backends (replaces)
851        let new_backends: Vec<SocketAddr> = vec![
852            "127.0.0.1:9000".parse().unwrap(),
853            "127.0.0.1:9001".parse().unwrap(),
854            "127.0.0.1:9002".parse().unwrap(),
855        ];
856        manager.update_backends("api", new_backends).await;
857
858        let resolved = registry.resolve(None, "/api").await.unwrap();
859        assert_eq!(resolved.backends.len(), 3);
860    }
861
862    #[tokio::test]
863    async fn test_config_builder() {
864        let config = ProxyManagerConfig::new("0.0.0.0:8080".parse().unwrap())
865            .with_https("0.0.0.0:8443".parse().unwrap())
866            .with_http2(false);
867
868        assert_eq!(
869            config.http_addr,
870            "0.0.0.0:8080".parse::<SocketAddr>().unwrap()
871        );
872        assert_eq!(
873            config.https_addr,
874            Some("0.0.0.0:8443".parse::<SocketAddr>().unwrap())
875        );
876        assert!(!config.http2_enabled);
877    }
878
879    /// Test that `ensure_ports_for_service` correctly differentiates
880    /// Public (0.0.0.0) vs Internal (overlay or 127.0.0.1) bind addresses.
881    /// We can't actually bind in unit tests, but we verify the function
882    /// processes both endpoint types without error.
883    #[tokio::test]
884    async fn test_ensure_ports_differentiates_public_and_internal() {
885        let config = ProxyManagerConfig::default();
886        let registry = Arc::new(ServiceRegistry::new());
887        let manager = ProxyManager::new(config, registry, None);
888
889        let spec = mock_service_spec_with_endpoints();
890        // Passing None for overlay_ip: internal endpoints should fall back to 127.0.0.1
891        let result = manager.ensure_ports_for_service(&spec, None).await;
892        // listen_on may fail because we can't actually bind in tests, but
893        // the function itself should run without panicking.
894        let _ = result;
895    }
896
897    #[tokio::test]
898    async fn test_ensure_ports_with_overlay_ip() {
899        let config = ProxyManagerConfig::default();
900        let registry = Arc::new(ServiceRegistry::new());
901        let manager = ProxyManager::new(config, registry, None);
902
903        let spec = mock_service_spec_with_endpoints();
904        // Pass an overlay IP -- internal endpoints should bind there
905        let overlay_ip: IpAddr = "10.200.0.5".parse().unwrap();
906        let result = manager
907            .ensure_ports_for_service(&spec, Some(overlay_ip))
908            .await;
909        let _ = result;
910    }
911
912    fn mock_mixed_service_spec() -> ServiceSpec {
913        use zlayer_spec::*;
914        serde_yaml::from_str::<DeploymentSpec>(
915            r"
916version: v1
917deployment: test
918services:
919  mixed:
920    rtype: service
921    image:
922      name: test:latest
923    endpoints:
924      - name: http
925        protocol: http
926        port: 8080
927        path: /api
928        expose: public
929      - name: grpc
930        protocol: tcp
931        port: 9000
932        expose: public
933      - name: game
934        protocol: udp
935        port: 27015
936        expose: public
937",
938        )
939        .unwrap()
940        .services
941        .remove("mixed")
942        .unwrap()
943    }
944
945    #[tokio::test]
946    async fn test_add_mixed_service_tracks_all_endpoints() {
947        let config = ProxyManagerConfig::default();
948        let registry = Arc::new(ServiceRegistry::new());
949        let manager = ProxyManager::new(config, registry, None);
950
951        let spec = mock_mixed_service_spec();
952        manager.add_service("mixed", &spec).await;
953
954        // Only 1 HTTP route (tcp and udp don't add HTTP routes)
955        assert_eq!(manager.route_count().await, 1);
956        // Service is tracked
957        assert!(manager.has_service("mixed").await);
958    }
959
960    #[tokio::test]
961    async fn test_ensure_ports_tcp_with_stream_registry() {
962        use zlayer_proxy::StreamService;
963
964        let stream_registry = Arc::new(StreamRegistry::new());
965        let config = ProxyManagerConfig::default();
966        let registry = Arc::new(ServiceRegistry::new());
967        let mut manager = ProxyManager::new(config, registry, None);
968        manager.set_stream_registry(stream_registry.clone());
969
970        let spec = mock_service_spec_tcp_only();
971
972        // Register the TCP service in the stream registry first (as ServiceManager does)
973        stream_registry.register_tcp(9000, StreamService::new("grpc-service".to_string(), vec![]));
974
975        // Ensure ports -- should bind TCP listener
976        let result = manager.ensure_ports_for_service(&spec, None).await;
977        assert!(result.is_ok());
978
979        // Verify the TCP listener port is tracked
980        let tcp_ports = manager.tcp_listeners.read().await;
981        assert!(tcp_ports.contains(&9000));
982    }
983
984    #[tokio::test]
985    async fn test_ensure_ports_tcp_without_stream_registry() {
986        let config = ProxyManagerConfig::default();
987        let registry = Arc::new(ServiceRegistry::new());
988        let manager = ProxyManager::new(config, registry, None);
989
990        let spec = mock_service_spec_tcp_only();
991
992        // Without stream registry, ensure_ports should not fail, just warn
993        let result = manager.ensure_ports_for_service(&spec, None).await;
994        assert!(result.is_ok());
995
996        // No TCP listeners should be tracked
997        let tcp_ports = manager.tcp_listeners.read().await;
998        assert!(tcp_ports.is_empty());
999    }
1000
1001    #[tokio::test]
1002    async fn test_stream_registry_setter() {
1003        let stream_registry = Arc::new(StreamRegistry::new());
1004        let config = ProxyManagerConfig::default();
1005        let registry = Arc::new(ServiceRegistry::new());
1006        let mut manager = ProxyManager::new(config, registry, None);
1007
1008        assert!(manager.stream_registry().is_none());
1009        manager.set_stream_registry(stream_registry.clone());
1010        assert!(manager.stream_registry().is_some());
1011    }
1012
1013    #[tokio::test]
1014    async fn test_registry_accessor() {
1015        let config = ProxyManagerConfig::default();
1016        let registry = Arc::new(ServiceRegistry::new());
1017        let manager = ProxyManager::new(config, registry.clone(), None);
1018
1019        // registry() should return the same Arc
1020        assert_eq!(Arc::as_ptr(&manager.registry()), Arc::as_ptr(&registry));
1021    }
1022}