1use crate::acme::CertManager;
7use crate::config::ProxyConfig;
8use crate::error::{ProxyError, Result};
9use crate::lb::LoadBalancer;
10use crate::network_policy::NetworkPolicyChecker;
11use crate::routes::ServiceRegistry;
12use crate::service::ReverseProxyService;
13use crate::sni_resolver::SniCertResolver;
14use hyper::body::Incoming;
15use hyper::server::conn::http1;
16use hyper::service::service_fn;
17use hyper::Request;
18use hyper_util::rt::TokioIo;
19use std::net::SocketAddr;
20use std::sync::Arc;
21use tokio::net::TcpListener;
22use tokio::sync::watch;
23use tokio_rustls::TlsAcceptor;
24use tracing::{debug, error, info, warn};
25
/// HTTP/HTTPS reverse-proxy front end.
///
/// Binds listeners, accepts TCP connections, and serves each connection on its
/// own Tokio task with a `ReverseProxyService` built from the shared state below.
pub struct ProxyServer {
    /// Proxy configuration (including the HTTP/HTTPS listen addresses); passed to
    /// every connection's service.
    config: Arc<ProxyConfig>,
    /// Service/route registry handed to every connection's service.
    registry: Arc<ServiceRegistry>,
    /// Load balancer handed to every connection's service.
    load_balancer: Arc<LoadBalancer>,
    /// Sending `true` on this channel asks all accept loops to stop.
    shutdown_tx: watch::Sender<bool>,
    /// Template receiver; each accept loop clones it to observe shutdown.
    shutdown_rx: watch::Receiver<bool>,
    /// TLS acceptor for HTTPS listeners; `None` means HTTPS is not configured.
    tls_acceptor: Option<TlsAcceptor>,
    /// Optional ACME certificate manager forwarded to each connection's service.
    cert_manager: Option<Arc<CertManager>>,
    /// Optional network-policy checker forwarded to each connection's service.
    network_policy_checker: Option<NetworkPolicyChecker>,
}
45
46impl ProxyServer {
47 pub fn new(
49 config: ProxyConfig,
50 registry: Arc<ServiceRegistry>,
51 load_balancer: Arc<LoadBalancer>,
52 ) -> Self {
53 let (shutdown_tx, shutdown_rx) = watch::channel(false);
54
55 Self {
56 config: Arc::new(config),
57 registry,
58 load_balancer,
59 shutdown_tx,
60 shutdown_rx,
61 tls_acceptor: None,
62 cert_manager: None,
63 network_policy_checker: None,
64 }
65 }
66
67 pub fn with_registry(
69 config: ProxyConfig,
70 registry: Arc<ServiceRegistry>,
71 load_balancer: Arc<LoadBalancer>,
72 ) -> Self {
73 Self::new(config, registry, load_balancer)
74 }
75
76 pub fn with_tls_resolver(
78 config: ProxyConfig,
79 registry: Arc<ServiceRegistry>,
80 load_balancer: Arc<LoadBalancer>,
81 resolver: Arc<SniCertResolver>,
82 ) -> Self {
83 let tls_config = rustls::ServerConfig::builder()
84 .with_no_client_auth()
85 .with_cert_resolver(resolver);
86 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
87 let (shutdown_tx, shutdown_rx) = watch::channel(false);
88
89 Self {
90 config: Arc::new(config),
91 registry,
92 load_balancer,
93 shutdown_tx,
94 shutdown_rx,
95 tls_acceptor: Some(acceptor),
96 cert_manager: None,
97 network_policy_checker: None,
98 }
99 }
100
101 #[must_use]
103 pub fn with_cert_manager(mut self, cm: Arc<CertManager>) -> Self {
104 self.cert_manager = Some(cm);
105 self
106 }
107
108 #[must_use]
110 pub fn with_network_policy_checker(mut self, checker: NetworkPolicyChecker) -> Self {
111 self.network_policy_checker = Some(checker);
112 self
113 }
114
115 #[must_use]
117 pub fn has_tls(&self) -> bool {
118 self.tls_acceptor.is_some()
119 }
120
121 #[must_use]
123 pub fn tls_acceptor(&self) -> Option<&TlsAcceptor> {
124 self.tls_acceptor.as_ref()
125 }
126
127 #[must_use]
129 pub fn registry(&self) -> Arc<ServiceRegistry> {
130 self.registry.clone()
131 }
132
133 #[must_use]
135 pub fn config(&self) -> Arc<ProxyConfig> {
136 self.config.clone()
137 }
138
139 pub fn shutdown(&self) {
141 let _ = self.shutdown_tx.send(true);
142 }
143
144 pub async fn run(&self) -> Result<()> {
151 let addr = self.config.server.http_addr;
152 let listener = TcpListener::bind(addr)
153 .await
154 .map_err(|e| ProxyError::BindFailed {
155 addr,
156 reason: e.to_string(),
157 })?;
158
159 info!(addr = %addr, "HTTP proxy server listening");
160
161 self.accept_loop(listener).await
162 }
163
164 pub async fn run_on(&self, addr: SocketAddr) -> Result<()> {
171 let listener = TcpListener::bind(addr)
172 .await
173 .map_err(|e| ProxyError::BindFailed {
174 addr,
175 reason: e.to_string(),
176 })?;
177
178 info!(addr = %addr, "HTTP proxy server listening");
179
180 self.accept_loop(listener).await
181 }
182
183 async fn accept_loop(&self, listener: TcpListener) -> Result<()> {
184 let mut shutdown_rx = self.shutdown_rx.clone();
185
186 loop {
187 tokio::select! {
188 _ = shutdown_rx.changed() => {
190 if *shutdown_rx.borrow() {
191 info!("Shutting down proxy server");
192 break;
193 }
194 }
195
196 result = listener.accept() => {
198 match result {
199 Ok((stream, remote_addr)) => {
200 let registry = self.registry.clone();
201 let load_balancer = self.load_balancer.clone();
202 let config = self.config.clone();
203 let cert_manager = self.cert_manager.clone();
204 let npc = self.network_policy_checker.clone();
205
206 tokio::spawn(async move {
207 if let Err(e) = Self::handle_connection(
208 stream,
209 remote_addr,
210 registry,
211 load_balancer,
212 config,
213 cert_manager,
214 npc,
215 ).await {
216 debug!(
217 error = %e,
218 remote_addr = %remote_addr,
219 "Connection error"
220 );
221 }
222 });
223 }
224 Err(e) => {
225 warn!(error = %e, "Failed to accept connection");
226 }
227 }
228 }
229 }
230 }
231
232 Ok(())
233 }
234
235 #[allow(clippy::too_many_arguments)]
236 async fn handle_connection(
237 stream: tokio::net::TcpStream,
238 remote_addr: SocketAddr,
239 registry: Arc<ServiceRegistry>,
240 load_balancer: Arc<LoadBalancer>,
241 config: Arc<ProxyConfig>,
242 cert_manager: Option<Arc<CertManager>>,
243 network_policy_checker: Option<NetworkPolicyChecker>,
244 ) -> Result<()> {
245 let io = TokioIo::new(stream);
246
247 let mut service =
248 ReverseProxyService::new(registry, load_balancer, config).with_remote_addr(remote_addr);
249 if let Some(cm) = cert_manager {
250 service = service.with_cert_manager(cm);
251 }
252 if let Some(checker) = network_policy_checker {
253 service = service.with_network_policy_checker(checker);
254 }
255
256 let service = service_fn(move |req: Request<Incoming>| {
257 let svc = service.clone();
258 async move {
259 match svc.proxy_request(req).await {
260 Ok(response) => Ok::<_, hyper::Error>(response),
261 Err(e) => {
262 error!(error = %e, "Proxy error");
263 Ok(ReverseProxyService::error_response(&e))
264 }
265 }
266 }
267 });
268
269 http1::Builder::new()
270 .preserve_header_case(true)
271 .title_case_headers(false)
272 .serve_connection(io, service)
273 .with_upgrades()
274 .await
275 .map_err(ProxyError::Hyper)?;
276
277 Ok(())
278 }
279
280 pub async fn run_https(&self) -> Result<()> {
289 let acceptor = self
290 .tls_acceptor
291 .as_ref()
292 .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
293
294 let addr = self.config.server.https_addr;
295 let listener = TcpListener::bind(addr)
296 .await
297 .map_err(|e| ProxyError::BindFailed {
298 addr,
299 reason: e.to_string(),
300 })?;
301
302 info!(addr = %addr, "HTTPS proxy server listening");
303
304 self.accept_loop_tls(listener, acceptor.clone()).await
305 }
306
307 pub async fn run_https_on(&self, addr: SocketAddr) -> Result<()> {
314 let acceptor = self
315 .tls_acceptor
316 .as_ref()
317 .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
318
319 let listener = TcpListener::bind(addr)
320 .await
321 .map_err(|e| ProxyError::BindFailed {
322 addr,
323 reason: e.to_string(),
324 })?;
325
326 info!(addr = %addr, "HTTPS proxy server listening");
327
328 self.accept_loop_tls(listener, acceptor.clone()).await
329 }
330
331 #[allow(clippy::similar_names)]
341 pub async fn run_both(&self) -> Result<()> {
342 let http_addr = self.config.server.http_addr;
343 let https_addr = self.config.server.https_addr;
344
345 let acceptor = self
346 .tls_acceptor
347 .as_ref()
348 .ok_or_else(|| ProxyError::Config("TLS not configured".to_string()))?;
349
350 let http_listener =
351 TcpListener::bind(http_addr)
352 .await
353 .map_err(|e| ProxyError::BindFailed {
354 addr: http_addr,
355 reason: e.to_string(),
356 })?;
357
358 let https_listener =
359 TcpListener::bind(https_addr)
360 .await
361 .map_err(|e| ProxyError::BindFailed {
362 addr: https_addr,
363 reason: e.to_string(),
364 })?;
365
366 info!(http = %http_addr, https = %https_addr, "Proxy server listening");
367
368 let http_future = self.accept_loop(http_listener);
370 let https_future = self.accept_loop_tls(https_listener, acceptor.clone());
371
372 tokio::select! {
373 result = http_future => result,
374 result = https_future => result,
375 }
376 }
377
378 async fn accept_loop_tls(&self, listener: TcpListener, acceptor: TlsAcceptor) -> Result<()> {
379 let mut shutdown_rx = self.shutdown_rx.clone();
380
381 loop {
382 tokio::select! {
383 _ = shutdown_rx.changed() => {
385 if *shutdown_rx.borrow() {
386 info!("Shutting down HTTPS proxy server");
387 break;
388 }
389 }
390
391 result = listener.accept() => {
393 match result {
394 Ok((stream, remote_addr)) => {
395 let registry = self.registry.clone();
396 let load_balancer = self.load_balancer.clone();
397 let config = self.config.clone();
398 let acceptor = acceptor.clone();
399 let cert_manager = self.cert_manager.clone();
400 let npc = self.network_policy_checker.clone();
401
402 tokio::spawn(async move {
403 if let Err(e) = Self::handle_tls_connection(
404 stream,
405 remote_addr,
406 registry,
407 load_balancer,
408 config,
409 acceptor,
410 cert_manager,
411 npc,
412 ).await {
413 debug!(
414 error = %e,
415 remote_addr = %remote_addr,
416 "TLS connection error"
417 );
418 }
419 });
420 }
421 Err(e) => {
422 warn!(error = %e, "Failed to accept TLS connection");
423 }
424 }
425 }
426 }
427 }
428
429 Ok(())
430 }
431
432 #[allow(clippy::too_many_arguments)]
433 async fn handle_tls_connection(
434 stream: tokio::net::TcpStream,
435 remote_addr: SocketAddr,
436 registry: Arc<ServiceRegistry>,
437 load_balancer: Arc<LoadBalancer>,
438 config: Arc<ProxyConfig>,
439 acceptor: TlsAcceptor,
440 cert_manager: Option<Arc<CertManager>>,
441 network_policy_checker: Option<NetworkPolicyChecker>,
442 ) -> Result<()> {
443 let tls_stream = acceptor
445 .accept(stream)
446 .await
447 .map_err(|e| ProxyError::Tls(format!("TLS handshake failed: {e}")))?;
448
449 let io = TokioIo::new(tls_stream);
450
451 let mut service = ReverseProxyService::new(registry, load_balancer, config)
452 .with_remote_addr(remote_addr)
453 .with_tls(true);
454 if let Some(cm) = cert_manager {
455 service = service.with_cert_manager(cm);
456 }
457 if let Some(checker) = network_policy_checker {
458 service = service.with_network_policy_checker(checker);
459 }
460
461 let service = service_fn(move |req: Request<Incoming>| {
462 let svc = service.clone();
463 async move {
464 match svc.proxy_request(req).await {
465 Ok(response) => Ok::<_, hyper::Error>(response),
466 Err(e) => {
467 error!(error = %e, "Proxy error");
468 Ok(ReverseProxyService::error_response(&e))
469 }
470 }
471 }
472 });
473
474 http1::Builder::new()
475 .preserve_header_case(true)
476 .title_case_headers(false)
477 .serve_connection(io, service)
478 .with_upgrades()
479 .await
480 .map_err(ProxyError::Hyper)?;
481
482 Ok(())
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lb::LoadBalancer;
    use crate::routes::{ResolvedService, RouteEntry};
    use std::time::Duration;
    use zlayer_spec::{ExposeType, Protocol};

    /// Upper bound on how long an accept loop may take to notice shutdown.
    const SHUTDOWN_DEADLINE: Duration = Duration::from_secs(5);

    /// Build a public HTTP route entry for `service` with the given backends.
    fn make_entry(
        service: &str,
        host: Option<&str>,
        path: &str,
        backends: Vec<SocketAddr>,
    ) -> RouteEntry {
        RouteEntry {
            service_name: service.to_string(),
            endpoint_name: "http".to_string(),
            host: host.map(std::string::ToString::to_string),
            path_prefix: path.to_string(),
            resolved: ResolvedService {
                name: service.to_string(),
                backends,
                use_tls: false,
                sni_hostname: String::new(),
                expose: ExposeType::Public,
                protocol: Protocol::Http,
                strip_prefix: false,
                path_prefix: path.to_string(),
                target_port: 8080,
            },
        }
    }

    /// A plain-HTTP server with an empty registry and default config.
    fn make_server() -> ProxyServer {
        ProxyServer::new(
            ProxyConfig::default(),
            Arc::new(ServiceRegistry::new()),
            Arc::new(LoadBalancer::new()),
        )
    }

    /// Loopback address with an OS-assigned port, so tests never collide.
    fn ephemeral_addr() -> SocketAddr {
        "127.0.0.1:0".parse().unwrap()
    }

    #[tokio::test]
    async fn test_server_shutdown() {
        let server = make_server();

        server.shutdown();
        assert!(*server.shutdown_rx.borrow());

        // The signal was sent before the loop started; it must still be honoured.
        let result = tokio::time::timeout(SHUTDOWN_DEADLINE, server.run_on(ephemeral_addr()))
            .await
            .expect("accept loop did not observe shutdown");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_new_server_has_no_tls() {
        let server = make_server();

        assert!(!server.has_tls());
        assert!(server.tls_acceptor().is_none());
    }

    #[tokio::test]
    async fn test_https_requires_tls() {
        let server = make_server();

        let err = server
            .run_https_on(ephemeral_addr())
            .await
            .expect_err("HTTPS without a TLS acceptor must fail");
        assert!(matches!(err, ProxyError::Config(_)));

        let err = server
            .run_both()
            .await
            .expect_err("run_both without a TLS acceptor must fail");
        assert!(matches!(err, ProxyError::Config(_)));
    }

    #[tokio::test]
    async fn test_registry_integration() {
        let registry = Arc::new(ServiceRegistry::new());

        registry
            .register(make_entry(
                "test-service",
                None,
                "/api",
                vec!["127.0.0.1:8081".parse().unwrap()],
            ))
            .await;

        let lb = Arc::new(LoadBalancer::new());
        let server = ProxyServer::new(ProxyConfig::default(), registry, lb);

        let reg = server.registry();
        assert_eq!(reg.route_count().await, 1);
    }
}