use crate::error::Result;
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use zlayer_proxy::{
load_existing_certs_into_resolver, CertManager, LbStrategy, LoadBalancer, NetworkPolicyChecker,
ProxyConfig, ProxyServer, RouteEntry, ServiceRegistry, SniCertResolver, StreamRegistry,
TcpStreamService, UdpStreamService,
};
use zlayer_spec::{ExposeType, Protocol, ServiceSpec};
/// Listener settings for a [`ProxyManager`].
#[derive(Debug, Clone)]
pub struct ProxyManagerConfig {
/// Primary plain-HTTP listen address (the `Default` impl uses `0.0.0.0:80`).
pub http_addr: SocketAddr,
/// Optional HTTPS listen address (`None` by default).
pub https_addr: Option<SocketAddr>,
/// Whether HTTP/2 is enabled on the proxy servers this manager starts
/// (`true` by default).
pub http2_enabled: bool,
}
impl Default for ProxyManagerConfig {
fn default() -> Self {
Self {
http_addr: "0.0.0.0:80".parse().unwrap(),
https_addr: None,
http2_enabled: true,
}
}
}
impl ProxyManagerConfig {
    /// Creates a configuration that serves plain HTTP on `http_addr`, with no
    /// HTTPS address and HTTP/2 enabled.
    #[must_use]
    pub fn new(http_addr: SocketAddr) -> Self {
        let https_addr = None;
        let http2_enabled = true;
        Self {
            http_addr,
            https_addr,
            http2_enabled,
        }
    }

    /// Builder-style setter for the HTTPS listen address.
    #[must_use]
    pub fn with_https(self, addr: SocketAddr) -> Self {
        Self {
            https_addr: Some(addr),
            ..self
        }
    }

    /// Builder-style toggle for HTTP/2 support.
    #[must_use]
    pub fn with_http2(self, enabled: bool) -> Self {
        Self {
            http2_enabled: enabled,
            ..self
        }
    }
}
/// Per-service bookkeeping written by `ProxyManager::add_service` and consumed
/// by `ProxyManager::remove_service` to release the resources the service used.
#[derive(Debug, Clone)]
struct ServiceTracking {
/// Endpoint names as declared in the spec. Recorded but never read in this
/// file, hence the `dead_code` allowance.
#[allow(dead_code)]
endpoint_names: Vec<String>,
/// Ports of `Protocol::Tcp` endpoints.
tcp_ports: Vec<u16>,
/// Ports of `Protocol::Udp` endpoints.
udp_ports: Vec<u16>,
/// Ports of HTTP, HTTPS and WebSocket endpoints.
http_ports: Vec<u16>,
}
/// Owns the proxy listeners — HTTP/HTTPS servers keyed by port plus raw TCP/UDP
/// stream listeners — and keeps the route registry and load balancer in sync
/// with service and backend changes.
pub struct ProxyManager {
/// Manager-wide settings (only `http2_enabled` is read in this file).
config: ProxyManagerConfig,
/// HTTP route registry shared with every `ProxyServer` this manager starts.
registry: Arc<ServiceRegistry>,
/// Load balancer shared with every `ProxyServer` this manager starts.
load_balancer: Arc<LoadBalancer>,
/// Running HTTP/HTTPS proxy servers, keyed by listen port.
servers: RwLock<HashMap<u16, Arc<ProxyServer>>>,
/// Per-service tracking, keyed by service name.
services: RwLock<HashMap<String, ServiceTracking>>,
/// Required to start TCP/UDP stream listeners; without it they are skipped
/// with a warning.
stream_registry: Option<Arc<StreamRegistry>>,
/// Required to start TLS listeners; without it they are skipped with a warning.
cert_manager: Option<Arc<CertManager>>,
/// Ports on which this manager has started a TCP stream listener.
tcp_listeners: RwLock<HashSet<u16>>,
/// Ports on which this manager has started a UDP stream listener.
udp_listeners: RwLock<HashSet<u16>>,
/// Polled by `stop` while draining.
/// NOTE(review): nothing in this file increments it, and it is not passed to
/// the servers here — confirm it is wired up elsewhere.
active_connections: Arc<AtomicU64>,
/// Optional policy checker, cloned into each `ProxyServer` at creation.
network_policy_checker: Option<NetworkPolicyChecker>,
}
impl ProxyManager {
    /// Creates a manager with no listeners, no tracked services, no stream
    /// registry and no network-policy checker.
    ///
    /// A fresh, empty [`LoadBalancer`] is created internally; the HTTP route
    /// `registry` and the optional `cert_manager` are supplied by the caller.
    pub fn new(
        config: ProxyManagerConfig,
        registry: Arc<ServiceRegistry>,
        cert_manager: Option<Arc<CertManager>>,
    ) -> Self {
        let load_balancer = Arc::new(LoadBalancer::new());
        Self {
            config,
            registry,
            load_balancer,
            servers: RwLock::new(HashMap::new()),
            services: RwLock::new(HashMap::new()),
            stream_registry: None,
            cert_manager,
            tcp_listeners: RwLock::new(HashSet::new()),
            udp_listeners: RwLock::new(HashSet::new()),
            active_connections: Arc::new(AtomicU64::new(0)),
            network_policy_checker: None,
        }
    }

    /// Returns a shared handle to the HTTP route registry.
    pub fn registry(&self) -> Arc<ServiceRegistry> {
        self.registry.clone()
    }

    /// Returns a shared handle to the load balancer.
    pub fn load_balancer(&self) -> Arc<LoadBalancer> {
        self.load_balancer.clone()
    }

    /// Returns the current value of the active-connection counter.
    ///
    /// NOTE(review): nothing in this file increments the counter, and the `Arc`
    /// is not handed to the proxy servers here, so this reads 0 unless it is
    /// wired up elsewhere — confirm.
    pub fn active_connections(&self) -> u64 {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Returns the certificate manager used for TLS listeners, if configured.
    pub fn cert_manager(&self) -> Option<&Arc<CertManager>> {
        self.cert_manager.as_ref()
    }

    /// Sets the stream registry needed to start TCP/UDP stream listeners.
    pub fn set_stream_registry(&mut self, registry: Arc<StreamRegistry>) {
        self.stream_registry = Some(registry);
    }

    /// Builder-style variant of [`Self::set_stream_registry`].
    #[must_use]
    pub fn with_stream_registry(mut self, registry: Arc<StreamRegistry>) -> Self {
        self.stream_registry = Some(registry);
        self
    }

    /// Returns the stream registry, if configured.
    pub fn stream_registry(&self) -> Option<&Arc<StreamRegistry>> {
        self.stream_registry.as_ref()
    }

    /// Sets the network-policy checker cloned into servers started afterwards.
    /// Servers that are already running are not affected.
    pub fn set_network_policy_checker(&mut self, checker: NetworkPolicyChecker) {
        self.network_policy_checker = Some(checker);
    }

    /// Builder-style variant of [`Self::set_network_policy_checker`].
    #[must_use]
    pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
        self.network_policy_checker = Some(checker);
        self
    }

    /// Starts a plain-HTTP proxy server on `bind_ip:port`. If a server is
    /// already registered for `port` (on any address), this is a no-op.
    ///
    /// The server shares this manager's route registry and load balancer, plus
    /// the network-policy checker if one is set. It runs on a spawned task.
    /// Errors from `ProxyServer::run` (presumably including bind failures) are
    /// only logged, so the port is recorded even if the server later fails.
    ///
    /// # Errors
    ///
    /// No path in the current body returns an error.
    pub async fn listen_on(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
        let mut servers = self.servers.write().await;
        if servers.contains_key(&port) {
            debug!(port = port, "Already listening on port");
            return Ok(());
        }
        let addr = SocketAddr::new(bind_ip, port);
        let mut proxy_config = ProxyConfig::default();
        proxy_config.server.http_addr = addr;
        proxy_config.server.http2_enabled = self.config.http2_enabled;
        let mut server = ProxyServer::with_registry(
            proxy_config,
            self.registry.clone(),
            self.load_balancer.clone(),
        );
        if let Some(ref checker) = self.network_policy_checker {
            server = server.with_network_policy_checker(checker.clone());
        }
        let server = Arc::new(server);
        info!(port = port, bind = %addr, "Proxy listening on port");
        let server_clone = server.clone();
        tokio::spawn(async move {
            if let Err(e) = server_clone.run().await {
                tracing::error!(port = port, error = %e, "Proxy server error on port");
            }
        });
        servers.insert(port, server);
        Ok(())
    }

    /// Starts an HTTPS proxy server on `bind_ip:port`, using SNI-based
    /// certificate resolution backed by the configured [`CertManager`].
    ///
    /// This is a no-op if a server is already registered for `port`. If no
    /// `CertManager` is configured, it logs a warning and returns `Ok(())`
    /// without starting anything. The manager-wide `http2_enabled` setting is
    /// applied, as in [`Self::listen_on`]. Runtime errors from `run_https` are
    /// only logged.
    ///
    /// # Errors
    ///
    /// No path in the current body returns an error.
    pub async fn listen_on_tls(&self, port: u16, bind_ip: IpAddr) -> Result<()> {
        let mut servers = self.servers.write().await;
        if servers.contains_key(&port) {
            debug!(port = port, "Already listening on port (TLS)");
            return Ok(());
        }
        let Some(cert_manager) = &self.cert_manager else {
            warn!(
                port = port,
                "Cannot start TLS listener: no CertManager configured"
            );
            return Ok(());
        };
        let sni_resolver = Arc::new(SniCertResolver::new());
        // Pre-load certificates already held by the CertManager into the SNI
        // resolver. This is best-effort: the result is discarded.
        // NOTE(review): consider logging a failure here rather than ignoring it.
        let _ = load_existing_certs_into_resolver(cert_manager, &sni_resolver).await;
        let addr = SocketAddr::new(bind_ip, port);
        let mut proxy_config = ProxyConfig::default();
        proxy_config.server.https_addr = addr;
        // Honour the manager-wide HTTP/2 setting on TLS listeners too, matching
        // `listen_on`. Previously this path always used the `ProxyConfig` default.
        proxy_config.server.http2_enabled = self.config.http2_enabled;
        let mut server = ProxyServer::with_tls_resolver(
            proxy_config,
            self.registry.clone(),
            self.load_balancer.clone(),
            sni_resolver,
        )
        .with_cert_manager(Arc::clone(cert_manager));
        if let Some(ref checker) = self.network_policy_checker {
            server = server.with_network_policy_checker(checker.clone());
        }
        let server = Arc::new(server);
        info!(port = port, bind = %addr, "HTTPS proxy listening on port");
        let server_clone = server.clone();
        tokio::spawn(async move {
            if let Err(e) = server_clone.run_https().await {
                tracing::error!(port = port, error = %e, "HTTPS proxy server error");
            }
        });
        servers.insert(port, server);
        Ok(())
    }

    /// Shuts down every HTTP/HTTPS proxy server started by this manager. It then
    /// polls every 100 ms, for at most 30 s, until the active-connection counter
    /// reaches zero.
    ///
    /// The server-map lock is released before the drain wait, so concurrent
    /// calls (`listen_on`, `unbind`, `remove_service`, ...) are not blocked for
    /// the whole drain period. TCP/UDP stream listeners are not stopped here.
    pub async fn stop(&self) {
        const DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
        const DRAIN_POLL_INTERVAL: Duration = Duration::from_millis(100);

        // The write guard is a temporary, so it is dropped at the end of this
        // statement, before any waiting happens.
        let stopping: Vec<(u16, Arc<ProxyServer>)> =
            self.servers.write().await.drain().collect();
        for (port, server) in stopping {
            info!(port = port, "Stopping proxy on port");
            server.shutdown();
        }

        let deadline = tokio::time::Instant::now() + DRAIN_TIMEOUT;
        while self.active_connections.load(Ordering::Relaxed) > 0 {
            if tokio::time::Instant::now() >= deadline {
                let remaining = self.active_connections.load(Ordering::Relaxed);
                warn!(
                    remaining = remaining,
                    "Drain timeout reached, forcing shutdown"
                );
                break;
            }
            tokio::time::sleep(DRAIN_POLL_INTERVAL).await;
        }
        info!("All proxy servers stopped");
    }

    /// Shuts down and forgets the HTTP/HTTPS server on `port`, if any.
    /// Service tracking is not touched.
    pub async fn unbind(&self, port: u16) {
        let mut servers = self.servers.write().await;
        if let Some(server) = servers.remove(&port) {
            info!(port = port, "Unbinding proxy from port");
            server.shutdown();
        }
    }

    /// Ensures a listener exists for every endpoint in `spec`.
    ///
    /// Public endpoints bind `0.0.0.0`. Internal endpoints bind `overlay_ip`,
    /// falling back to `127.0.0.1` with a warning when no overlay IP is given.
    /// HTTPS endpoints get a TLS server, HTTP/WebSocket endpoints a plain server,
    /// and TCP/UDP endpoints a stream listener. Stream-listener failures are
    /// logged and skipped.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::listen_on`] / [`Self::listen_on_tls`].
    pub async fn ensure_ports_for_service(
        &self,
        spec: &ServiceSpec,
        overlay_ip: Option<IpAddr>,
    ) -> Result<()> {
        for endpoint in &spec.endpoints {
            let bind_ip = match endpoint.expose {
                ExposeType::Public => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
                ExposeType::Internal => {
                    let ip = overlay_ip.unwrap_or(IpAddr::V4(Ipv4Addr::LOCALHOST));
                    if overlay_ip.is_none() {
                        warn!(
                            endpoint = %endpoint.name,
                            port = endpoint.port,
                            "No overlay IP available for internal endpoint; binding to 127.0.0.1"
                        );
                    }
                    ip
                }
            };
            match endpoint.protocol {
                Protocol::Https => {
                    self.listen_on_tls(endpoint.port, bind_ip).await?;
                }
                Protocol::Http | Protocol::Websocket => {
                    self.listen_on(endpoint.port, bind_ip).await?;
                }
                Protocol::Tcp => {
                    self.ensure_tcp_listener(endpoint.port, bind_ip).await;
                }
                Protocol::Udp => {
                    self.ensure_udp_listener(endpoint.port, bind_ip).await;
                }
            }
        }
        Ok(())
    }

    /// Starts a TCP stream listener on `bind_ip:port` unless one is already
    /// recorded for `port`. It is skipped with a warning when no stream registry
    /// is configured or the bind fails.
    async fn ensure_tcp_listener(&self, port: u16, bind_ip: IpAddr) {
        // Hold the write lock across check -> bind -> insert. Otherwise two
        // concurrent callers could both pass the "already active" check and race
        // to bind the same port, and the loser would log a spurious bind failure.
        let mut listeners = self.tcp_listeners.write().await;
        if listeners.contains(&port) {
            debug!(port = port, "TCP stream listener already active");
            return;
        }
        let Some(registry) = self.stream_registry.as_ref().map(Arc::clone) else {
            warn!(
                port = port,
                "Cannot start TCP listener: StreamRegistry not configured"
            );
            return;
        };
        let addr = SocketAddr::new(bind_ip, port);
        let listener = match tokio::net::TcpListener::bind(addr).await {
            Ok(l) => l,
            Err(e) => {
                warn!(
                    port = port,
                    bind = %addr,
                    error = %e,
                    "Failed to bind TCP stream listener, continuing"
                );
                return;
            }
        };
        listeners.insert(port);
        drop(listeners);

        // NOTE(review): no handle to this task is kept, so nothing here can stop it.
        let tcp_service = Arc::new(TcpStreamService::new(registry, port));
        tokio::spawn(async move {
            tcp_service.serve(listener).await;
        });
        info!(port = port, bind = %addr, "TCP stream proxy listening");
    }

    /// Starts a UDP stream listener on `bind_ip:port` unless one is already
    /// recorded for `port`. It is skipped with a warning when no stream registry
    /// is configured or the bind fails.
    async fn ensure_udp_listener(&self, port: u16, bind_ip: IpAddr) {
        // Same race avoidance as `ensure_tcp_listener`: hold the write lock
        // across check -> bind -> insert.
        let mut listeners = self.udp_listeners.write().await;
        if listeners.contains(&port) {
            debug!(port = port, "UDP stream listener already active");
            return;
        }
        let Some(registry) = self.stream_registry.as_ref().map(Arc::clone) else {
            warn!(
                port = port,
                "Cannot start UDP listener: StreamRegistry not configured"
            );
            return;
        };
        let addr = SocketAddr::new(bind_ip, port);
        let socket = match tokio::net::UdpSocket::bind(addr).await {
            Ok(s) => s,
            Err(e) => {
                warn!(
                    port = port,
                    bind = %addr,
                    error = %e,
                    "Failed to bind UDP stream listener, continuing"
                );
                return;
            }
        };
        listeners.insert(port);
        drop(listeners);

        // NOTE(review): no handle to this task is kept, so nothing here can stop it.
        let udp_service = Arc::new(UdpStreamService::new(registry, port, None));
        tokio::spawn(async move {
            if let Err(e) = udp_service.serve(socket).await {
                tracing::error!(
                    port = port,
                    error = %e,
                    "UDP stream proxy service failed"
                );
            }
        });
        info!(port = port, bind = %addr, "UDP stream proxy listening");
    }

    /// Registers `spec`'s endpoints under service `name`.
    ///
    /// HTTP/HTTPS/WebSocket endpoints become routes in the registry. TCP/UDP
    /// endpoints are only recorded, since their listeners come from
    /// [`Self::ensure_ports_for_service`]. The service is (re)registered with the
    /// load balancer with an empty backend list and round-robin strategy, and
    /// its ports are recorded for [`Self::remove_service`].
    pub async fn add_service(&self, name: &str, spec: &ServiceSpec) {
        let mut services = self.services.write().await;
        let mut endpoint_names = Vec::new();
        let mut tcp_ports = Vec::new();
        let mut udp_ports = Vec::new();
        let mut http_ports = Vec::new();
        for endpoint in &spec.endpoints {
            match endpoint.protocol {
                Protocol::Http | Protocol::Https | Protocol::Websocket => {
                    let entry = RouteEntry::from_endpoint(name, endpoint);
                    self.registry.register(entry).await;
                    http_ports.push(endpoint.port);
                    info!(
                        service = name,
                        endpoint = %endpoint.name,
                        protocol = ?endpoint.protocol,
                        path = ?endpoint.path,
                        expose = ?endpoint.expose,
                        "Added HTTP proxy route for service"
                    );
                }
                Protocol::Tcp => {
                    tcp_ports.push(endpoint.port);
                    info!(
                        service = name,
                        endpoint = %endpoint.name,
                        protocol = ?endpoint.protocol,
                        port = endpoint.port,
                        expose = ?endpoint.expose,
                        "Tracking TCP stream endpoint for service"
                    );
                }
                Protocol::Udp => {
                    udp_ports.push(endpoint.port);
                    info!(
                        service = name,
                        endpoint = %endpoint.name,
                        protocol = ?endpoint.protocol,
                        port = endpoint.port,
                        expose = ?endpoint.expose,
                        "Tracking UDP stream endpoint for service"
                    );
                }
            }
            endpoint_names.push(endpoint.name.clone());
        }
        self.load_balancer
            .register(name, vec![], LbStrategy::RoundRobin);
        services.insert(
            name.to_string(),
            ServiceTracking {
                endpoint_names,
                tcp_ports,
                udp_ports,
                http_ports,
            },
        );
    }

    /// Removes service `name` and the proxy resources recorded for it.
    ///
    /// This unregisters routes and load-balancer state and unregisters the
    /// service's TCP/UDP ports from the stream registry, forgetting them as
    /// active listeners. It also shuts down HTTP/HTTPS servers on ports that no
    /// remaining service uses. Unknown names are ignored.
    ///
    /// NOTE(review): the spawned TCP/UDP listener tasks are not stopped here, only
    /// forgotten. A later `ensure_*_listener` on the same port may therefore
    /// fail to bind while an old task still holds the socket — confirm whether
    /// `serve` exits on its own.
    pub async fn remove_service(&self, name: &str) {
        let mut services = self.services.write().await;
        if let Some(tracking) = services.remove(name) {
            self.registry.unregister_service(name).await;
            self.load_balancer.unregister(name);
            if !tracking.tcp_ports.is_empty() {
                let mut tcp_set = self.tcp_listeners.write().await;
                for port in &tracking.tcp_ports {
                    if let Some(registry) = &self.stream_registry {
                        let _ = registry.unregister_tcp(*port);
                    }
                    tcp_set.remove(port);
                    debug!(service = name, port = port, "Removed TCP listener tracking");
                }
            }
            if !tracking.udp_ports.is_empty() {
                let mut udp_set = self.udp_listeners.write().await;
                for port in &tracking.udp_ports {
                    if let Some(registry) = &self.stream_registry {
                        let _ = registry.unregister_udp(*port);
                    }
                    udp_set.remove(port);
                    debug!(service = name, port = port, "Removed UDP listener tracking");
                }
            }
            if !tracking.http_ports.is_empty() {
                // The removed service is already gone from `services`, so this
                // set contains only ports other services still need.
                let ports_still_in_use: HashSet<u16> = services
                    .values()
                    .flat_map(|t| t.http_ports.iter().copied())
                    .collect();
                let mut servers = self.servers.write().await;
                for port in &tracking.http_ports {
                    if !ports_still_in_use.contains(port) {
                        if let Some(server) = servers.remove(port) {
                            server.shutdown();
                            info!(
                                service = name,
                                port = port,
                                "Shut down HTTP proxy server (no remaining services on port)"
                            );
                        }
                    }
                }
            }
            info!(service = name, "Removed all proxy resources for service");
        }
    }

    /// Adds backend `addr` to `service` in both the route registry and the
    /// load balancer.
    pub async fn add_backend(&self, service: &str, addr: SocketAddr) {
        self.registry.add_backend(service, addr).await;
        self.load_balancer.add_backend(service, addr);
        info!(service = service, backend = %addr, "Registered backend with proxy");
    }

    /// Removes backend `addr` from `service` in both the route registry and the
    /// load balancer.
    pub async fn remove_backend(&self, service: &str, addr: SocketAddr) {
        self.registry.remove_backend(service, addr).await;
        self.load_balancer.remove_backend(service, &addr);
        debug!(service = service, backend = %addr, "Removed backend from service");
    }

    /// Marks backend `addr` of `service` healthy or unhealthy in the load balancer.
    // The body never awaits; it is presumably kept `async` so callers can
    // `.await` it uniformly with the other backend methods (hence the allow).
    #[allow(clippy::unused_async)]
    pub async fn update_backend_health(&self, service: &str, addr: SocketAddr, healthy: bool) {
        self.load_balancer.mark_health(service, &addr, healthy);
        debug!(
            service = service,
            backend = %addr,
            healthy = healthy,
            "Updated backend health in load balancer"
        );
    }

    /// Replaces all backends of `service` with `addrs` in both the route
    /// registry and the load balancer.
    pub async fn update_backends(&self, service: &str, addrs: Vec<SocketAddr>) {
        // Both sinks take an owned Vec, so one clone is unavoidable.
        self.registry.update_backends(service, addrs.clone()).await;
        self.load_balancer.update_backends(service, addrs);
        debug!(service = service, "Updated backends for service");
    }

    /// Number of routes currently in the route registry.
    pub async fn route_count(&self) -> usize {
        self.registry.route_count().await
    }

    /// Names of all services currently tracked, in unspecified order.
    pub async fn list_services(&self) -> Vec<String> {
        self.services.read().await.keys().cloned().collect()
    }

    /// Whether service `name` is currently tracked.
    pub async fn has_service(&self, name: &str) -> bool {
        self.services.read().await.contains_key(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Service spec with a public HTTP endpoint (8080, `/api`) and an internal
    /// WebSocket endpoint (8081, `/ws`).
    fn mock_service_spec_with_endpoints() -> ServiceSpec {
        use zlayer_spec::*;
        // YAML nesting must be preserved: `test` lives under `services`, and
        // `endpoints` is a list inside `test`.
        serde_yaml::from_str::<DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: http
        protocol: http
        port: 8080
        path: /api
        expose: public
      - name: websocket
        protocol: websocket
        port: 8081
        path: /ws
        expose: internal
",
        )
        .unwrap()
        .services
        .remove("test")
        .unwrap()
    }

    /// Service spec with a single TCP endpoint on port 9000.
    fn mock_service_spec_tcp_only() -> ServiceSpec {
        mock_service_spec_tcp_only_port(9000)
    }

    /// Service spec with a single TCP endpoint on `port`.
    fn mock_service_spec_tcp_only_port(port: u16) -> ServiceSpec {
        use zlayer_spec::*;
        let yaml = format!(
            "
version: v1
deployment: test
services:
  test:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: grpc
        protocol: tcp
        port: {port}
"
        );
        serde_yaml::from_str::<DeploymentSpec>(&yaml)
            .unwrap()
            .services
            .remove("test")
            .unwrap()
    }

    /// Returns a port that was free a moment ago. The probe listener is
    /// dropped on return, so another process could grab the port before the
    /// test binds it; this is acceptable for tests.
    fn reserve_free_tcp_port() -> u16 {
        let listener =
            std::net::TcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral test port");
        listener.local_addr().unwrap().port()
    }

    #[tokio::test]
    async fn test_proxy_manager_new() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        assert_eq!(manager.route_count().await, 0);
        assert!(manager.list_services().await.is_empty());
    }

    #[tokio::test]
    async fn test_add_service_with_http_endpoints() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;
        assert_eq!(manager.route_count().await, 2);
        assert!(manager.has_service("api").await);
    }

    #[tokio::test]
    async fn test_tcp_endpoints_tracked_not_routed() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        let spec = mock_service_spec_tcp_only();
        manager.add_service("grpc-service", &spec).await;
        assert_eq!(manager.route_count().await, 0);
        assert!(manager.has_service("grpc-service").await);
    }

    #[tokio::test]
    async fn test_remove_service() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;
        assert_eq!(manager.route_count().await, 2);
        manager.remove_service("api").await;
        assert_eq!(manager.route_count().await, 0);
        assert!(!manager.has_service("api").await);
    }

    #[tokio::test]
    async fn test_backend_management() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry.clone(), None);
        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;
        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:8081".parse().unwrap();
        manager.add_backend("api", addr1).await;
        manager.add_backend("api", addr2).await;
        let resolved = registry.resolve(None, "/api").await.unwrap();
        assert_eq!(resolved.backends.len(), 2);
        manager.remove_backend("api", addr1).await;
        let resolved = registry.resolve(None, "/api").await.unwrap();
        assert_eq!(resolved.backends.len(), 1);
    }

    #[tokio::test]
    async fn test_update_backends_replaces_all() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry.clone(), None);
        let spec = mock_service_spec_with_endpoints();
        manager.add_service("api", &spec).await;
        let addr1: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        manager.add_backend("api", addr1).await;
        let new_backends: Vec<SocketAddr> = vec![
            "127.0.0.1:9000".parse().unwrap(),
            "127.0.0.1:9001".parse().unwrap(),
            "127.0.0.1:9002".parse().unwrap(),
        ];
        manager.update_backends("api", new_backends).await;
        let resolved = registry.resolve(None, "/api").await.unwrap();
        assert_eq!(resolved.backends.len(), 3);
    }

    // Pure builder logic; no async runtime is needed.
    #[test]
    fn test_config_builder() {
        let config = ProxyManagerConfig::new("0.0.0.0:8080".parse().unwrap())
            .with_https("0.0.0.0:8443".parse().unwrap())
            .with_http2(false);
        assert_eq!(
            config.http_addr,
            "0.0.0.0:8080".parse::<SocketAddr>().unwrap()
        );
        assert_eq!(
            config.https_addr,
            Some("0.0.0.0:8443".parse::<SocketAddr>().unwrap())
        );
        assert!(!config.http2_enabled);
    }

    #[tokio::test]
    async fn test_ensure_ports_differentiates_public_and_internal() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        let spec = mock_service_spec_with_endpoints();
        let result = manager.ensure_ports_for_service(&spec, None).await;
        // `listen_on` never fails synchronously and always records the port;
        // runtime bind errors are only logged by the spawned server task.
        assert!(result.is_ok());
        let servers = manager.servers.read().await;
        assert!(servers.contains_key(&8080));
        assert!(servers.contains_key(&8081));
    }

    #[tokio::test]
    async fn test_ensure_ports_with_overlay_ip() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        let spec = mock_service_spec_with_endpoints();
        let overlay_ip: IpAddr = "10.200.0.5".parse().unwrap();
        let result = manager
            .ensure_ports_for_service(&spec, Some(overlay_ip))
            .await;
        assert!(result.is_ok());
        let servers = manager.servers.read().await;
        assert!(servers.contains_key(&8080));
        assert!(servers.contains_key(&8081));
    }

    /// Service spec mixing HTTP (8080), TCP (9000) and UDP (27015) endpoints.
    fn mock_mixed_service_spec() -> ServiceSpec {
        use zlayer_spec::*;
        serde_yaml::from_str::<DeploymentSpec>(
            r"
version: v1
deployment: test
services:
  mixed:
    rtype: service
    image:
      name: test:latest
    endpoints:
      - name: http
        protocol: http
        port: 8080
        path: /api
        expose: public
      - name: grpc
        protocol: tcp
        port: 9000
        expose: public
      - name: game
        protocol: udp
        port: 27015
        expose: public
",
        )
        .unwrap()
        .services
        .remove("mixed")
        .unwrap()
    }

    #[tokio::test]
    async fn test_add_mixed_service_tracks_all_endpoints() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        let spec = mock_mixed_service_spec();
        manager.add_service("mixed", &spec).await;
        assert_eq!(manager.route_count().await, 1);
        assert!(manager.has_service("mixed").await);
    }

    #[tokio::test]
    async fn test_ensure_ports_tcp_with_stream_registry() {
        use zlayer_proxy::StreamService;
        let stream_registry = Arc::new(StreamRegistry::new());
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let mut manager = ProxyManager::new(config, registry, None);
        manager.set_stream_registry(stream_registry.clone());
        let port = reserve_free_tcp_port();
        let spec = mock_service_spec_tcp_only_port(port);
        stream_registry.register_tcp(port, StreamService::new("grpc-service".to_string(), vec![]));
        let result = manager.ensure_ports_for_service(&spec, None).await;
        assert!(result.is_ok());
        let tcp_ports = manager.tcp_listeners.read().await;
        assert!(tcp_ports.contains(&port));
    }

    #[tokio::test]
    async fn test_ensure_ports_tcp_without_stream_registry() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry, None);
        let spec = mock_service_spec_tcp_only();
        let result = manager.ensure_ports_for_service(&spec, None).await;
        assert!(result.is_ok());
        let tcp_ports = manager.tcp_listeners.read().await;
        assert!(tcp_ports.is_empty());
    }

    #[tokio::test]
    async fn test_stream_registry_setter() {
        let stream_registry = Arc::new(StreamRegistry::new());
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let mut manager = ProxyManager::new(config, registry, None);
        assert!(manager.stream_registry().is_none());
        manager.set_stream_registry(stream_registry.clone());
        assert!(manager.stream_registry().is_some());
    }

    #[tokio::test]
    async fn test_registry_accessor() {
        let config = ProxyManagerConfig::default();
        let registry = Arc::new(ServiceRegistry::new());
        let manager = ProxyManager::new(config, registry.clone(), None);
        assert_eq!(Arc::as_ptr(&manager.registry()), Arc::as_ptr(&registry));
    }
}