1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14#[cfg(feature = "logging")]
15use tracing::info_span;
16#[cfg(feature = "logging")]
17use tracing::Instrument;
18
19use crate::config::{extract_hostname, Config, SiteConfig};
20#[cfg(feature = "logging")]
21use crate::proxy::access_log::AccessLogGuard;
22use crate::proxy::access_log::{ensure_request_id, final_request_id};
23use crate::proxy::ActionResult;
24
25use crate::proxy::directives::{
26 apply_header_up, handle_header, handle_method, handle_redirect, handle_respond,
27 handle_reverse_proxy, handle_strip_prefix, handle_uri_replace,
28};
29
/// Body type of every response produced in this file: a type-erased
/// (`BoxBody`) stream of `Bytes` whose error type is a boxed, thread-safe
/// `std::error::Error`. Both in-memory `Full` bodies and streamed backend
/// bodies are boxed into it further down.
type ResponseBody =
    http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
34
35fn is_hop_header(name: &header::HeaderName) -> bool {
41 matches!(
42 name,
43 &header::CONNECTION
44 | &header::UPGRADE
45 | &header::TE
46 | &header::TRAILER
47 | &header::PROXY_AUTHENTICATE
48 | &header::PROXY_AUTHORIZATION
49 )
50}
51
52pub fn process_directives(
59 directives: &[crate::config::Directive],
60 req: &mut Request<Incoming>,
61 current_path: &str,
62) -> Result<ActionResult, String> {
63 let mut modified_path = current_path.to_string();
64
65 for directive in directives {
66 match directive {
67 crate::config::Directive::Header { name, value } => {
68 if let Err(e) = handle_header(name, value.as_deref(), req) {
69 info!(" Failed to apply header {}: {}", name, e);
70 }
71 }
72
73 crate::config::Directive::UriReplace { find, replace } => {
74 handle_uri_replace(find, replace, &mut modified_path);
75 }
76
77 crate::config::Directive::StripPrefix { prefix } => {
78 handle_strip_prefix(prefix, &mut modified_path);
79 }
80
81 crate::config::Directive::HandlePath {
82 pattern,
83 directives: nested_directives,
84 } => {
85 if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
86 info!(" Matched handle_path: {}", pattern);
87 return process_directives(nested_directives, req, &remaining_path);
88 }
89 }
90
91 crate::config::Directive::Method {
92 methods,
93 directives: nested_directives,
94 } => {
95 if handle_method(methods, req) {
96 info!(" Matched method directive");
97 return process_directives(nested_directives, req, &modified_path);
98 }
99 }
100
101 crate::config::Directive::Redirect { status, url } => {
102 return Ok(handle_redirect(status, url));
103 }
104
105 crate::config::Directive::Respond { status, body } => {
106 return Ok(handle_respond(status, body));
107 }
108
109 crate::config::Directive::ReverseProxy {
110 to,
111 connect_timeout,
112 read_timeout,
113 header_up,
114 } => {
115 return Ok(handle_reverse_proxy(
116 to,
117 &modified_path,
118 *connect_timeout,
119 *read_timeout,
120 header_up.clone(),
121 ));
122 }
123 }
124 }
125
126 Err(format!(
127 "No action directive (respond or reverse_proxy) found in configuration for path: {}",
128 current_path
129 ))
130}
131
132pub async fn proxy(
143 mut req: Request<Incoming>,
144 client: Client<HttpsConnector<HttpConnector>, Incoming>,
145 config: Arc<Config>,
146 remote_addr: std::net::SocketAddr,
147 is_tls: bool,
148) -> Result<Response<ResponseBody>, Error> {
149 let initial_request_id = ensure_request_id(&mut req);
151
152 #[cfg(feature = "logging")]
153 let span = info_span!("request", req_id = %initial_request_id);
154
155 let future = async move {
156 let path = req.uri().path().to_string();
157 let host = req
158 .headers()
159 .get(hyper::header::HOST)
160 .and_then(|h| h.to_str().ok())
161 .unwrap_or("localhost");
162
163 #[cfg(feature = "metrics")]
164 let mut metrics_guard =
165 crate::metrics::MetricsGuard::new(req.method().to_string(), host.to_string());
166
167 #[cfg(feature = "logging")]
168 let mut log_guard = AccessLogGuard::new(
169 initial_request_id.clone(),
170 remote_addr,
171 req.method().to_string(),
172 path.clone(),
173 host.to_string(),
174 );
175
176 let site_config = match find_site(&config, host, is_tls) {
178 Some(config) => config,
179 None => {
180 error!("No configuration found for host: {}", host);
181 let (response, _body_len) = error_response_with_id(
182 StatusCode::NOT_FOUND,
183 &format!("No configuration found for host: {}", host),
184 &initial_request_id,
185 );
186 #[cfg(feature = "logging")]
187 {
188 log_guard.set_bytes_sent(_body_len);
189 log_guard.finish(404);
190 }
191 #[cfg(feature = "metrics")]
192 metrics_guard.record(404);
193 return Ok(response);
194 }
195 };
196
197 let action_result = match process_directives(&site_config.directives, &mut req, &path) {
199 Ok(result) => result,
200 Err(e) => {
201 error!("Directive processing error: {}", e);
202 let final_id = final_request_id(&req, &initial_request_id);
203 #[cfg(feature = "logging")]
204 {
205 log_guard.set_request_id(final_id.clone());
206 tracing::Span::current().record("req_id", final_id.as_str());
207 }
208 let (response, _body_len) =
209 error_response_with_id(StatusCode::INTERNAL_SERVER_ERROR, &e, &final_id);
210 #[cfg(feature = "logging")]
211 {
212 log_guard.set_bytes_sent(_body_len);
213 log_guard.finish(500);
214 }
215 #[cfg(feature = "metrics")]
216 metrics_guard.record(500);
217 return Ok(response);
218 }
219 };
220
221 let request_id = final_request_id(&req, &initial_request_id);
223 #[cfg(feature = "logging")]
224 {
225 log_guard.set_request_id(request_id.clone());
226 tracing::Span::current().record("req_id", request_id.as_str());
228 }
229
230 match action_result {
232 ActionResult::Redirect { status, url } => {
233 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
234
235 let boxed: ResponseBody = Full::new(Bytes::new())
236 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
237 .boxed();
238 let response = Response::builder()
239 .status(status_code)
240 .header("Location", &url)
241 .header("X-Request-ID", &request_id)
242 .body(boxed)?;
243 #[cfg(feature = "logging")]
244 {
245 log_guard.set_bytes_sent(0);
246 log_guard.finish(status_code.as_u16());
247 }
248 #[cfg(feature = "metrics")]
249 metrics_guard.record(status_code.as_u16());
250 Ok(response)
251 }
252 ActionResult::Respond { status, body } => {
253 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
254 let _body_len = body.len();
255
256 let boxed: ResponseBody = Full::new(Bytes::from(body))
257 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
258 .boxed();
259 let response = Response::builder()
260 .status(status_code)
261 .header("X-Request-ID", &request_id)
262 .body(boxed)?;
263 #[cfg(feature = "logging")]
264 {
265 log_guard.set_bytes_sent(_body_len);
266 log_guard.finish(status_code.as_u16());
267 }
268 #[cfg(feature = "metrics")]
269 metrics_guard.record(status_code.as_u16());
270 Ok(response)
271 }
272 ActionResult::ReverseProxy {
273 backend_url,
274 path_to_send,
275 connect_timeout: _,
276 read_timeout,
277 header_up,
278 } => {
279 let backend_with_proto =
281 if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
282 backend_url
283 } else {
284 format!("http://{}", backend_url)
285 };
286
287 let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
289 parts.path_and_query = Some(path_to_send.parse()?);
290 let new_uri = Uri::from_parts(parts)?;
291
292 let original_request_uri = req
295 .uri()
296 .path_and_query()
297 .map(|pq| pq.as_str().to_string())
298 .unwrap_or_default();
299
300 *req.uri_mut() = new_uri.clone();
301
302 let upstream_host = new_uri
304 .authority()
305 .map(|a| a.as_str().to_string())
306 .unwrap_or_default();
307
308 let remote_ip = crate::auth::headers::extract_remote_ip(&req)
310 .unwrap_or_else(|| remote_addr.ip().to_string());
311
312 let original_host_header = req.headers().get(hyper::header::HOST).cloned();
314
315 req.headers_mut().remove(hyper::header::HOST);
317 if let Some(authority) = new_uri.authority() {
318 if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>()
319 {
320 req.headers_mut().insert(hyper::header::HOST, host_value);
321 }
322 }
323
324 if let Some(host_value) = original_host_header.clone() {
326 req.headers_mut().insert("X-Forwarded-Host", host_value);
327 }
328
329 req.headers_mut().insert(
331 "X-Forwarded-Proto",
332 hyper::header::HeaderValue::from_static(if is_tls { "https" } else { "http" }),
333 );
334
335 if let Ok(ip_value) =
337 hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
338 {
339 req.headers_mut().insert("X-Forwarded-For", ip_value);
340 }
341
342 req.headers_mut().remove(header::CONNECTION);
344 req.headers_mut().remove("accept-encoding");
345
346 apply_header_up(
347 &header_up,
348 &mut req,
349 &upstream_host,
350 &original_request_uri,
351 &remote_ip,
352 );
353
354 let backend_timeout = read_timeout.unwrap_or(30);
356 match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
357 Ok(Ok(response)) => {
358 let status = response.status();
359 let headers = response.headers().clone();
360
361 let mut builder = Response::builder().status(status);
363
364 for (name, value) in headers.iter() {
366 if !is_hop_header(name) && name != header::CONTENT_LENGTH {
367 builder = builder.header(name, value);
368 }
369 }
370
371 let (_, incoming_body) = response.into_parts();
372 let boxed: ResponseBody = incoming_body
373 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
374 .boxed();
375
376 let response = builder.header("X-Request-ID", &request_id).body(boxed)?;
377 #[cfg(feature = "logging")]
378 log_guard.finish(status.as_u16());
379 #[cfg(feature = "metrics")]
380 metrics_guard.record(status.as_u16());
381 Ok(response)
382 }
383 Ok(Err(e)) => {
384 error!("Backend connection failed: {:?}", e);
385 if e.is_connect() {
386 error!(" Reason: Connection refused - backend unavailable");
387 } else {
388 error!(" Reason: Other connection error");
389 }
390
391 let (response, _body_len) = error_response_with_id(
392 StatusCode::BAD_GATEWAY,
393 "Backend service unavailable",
394 &request_id,
395 );
396 #[cfg(feature = "logging")]
397 {
398 log_guard.set_bytes_sent(_body_len);
399 log_guard.finish(502);
400 }
401 #[cfg(feature = "metrics")]
402 metrics_guard.record(502);
403 Ok(response)
404 }
405 Err(_) => {
406 error!(
407 "Backend request timed out after {} seconds",
408 backend_timeout
409 );
410
411 let (response, _body_len) = error_response_with_id(
412 StatusCode::GATEWAY_TIMEOUT,
413 "Backend request timed out",
414 &request_id,
415 );
416 #[cfg(feature = "logging")]
417 {
418 log_guard.set_bytes_sent(_body_len);
419 log_guard.finish(504);
420 }
421 #[cfg(feature = "metrics")]
422 metrics_guard.record(504);
423 Ok(response)
424 }
425 }
426 }
427 }
428 };
429
430 #[cfg(feature = "logging")]
431 let future = future.instrument(span);
432
433 future.await
434}
435
436fn error_response_with_id(
440 status: StatusCode,
441 message: &str,
442 request_id: &str,
443) -> (Response<ResponseBody>, usize) {
444 let body = format!(
445 r#"<!DOCTYPE html>
446 <html>
447 <head><title>{} {}</title></head>
448 <body>
449 <h1>{} {}</h1>
450 <p>{}</p>
451 <hr>
452 <p><em>Rust Proxy Server</em></p>
453 </body>
454 </html>"#,
455 status.as_u16(),
456 status.canonical_reason().unwrap_or("Error"),
457 status.as_u16(),
458 status.canonical_reason().unwrap_or("Error"),
459 message
460 );
461
462 let body_len = body.len();
463 let full = Full::new(Bytes::from(body));
464 let boxed: ResponseBody = full
465 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
466 .boxed();
467
468 let mut builder = Response::builder()
469 .status(status)
470 .header("Content-Type", "text/html; charset=utf-8");
471
472 if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
473 builder = builder.header("X-Request-ID", val);
474 }
475
476 let response = builder.body(boxed).unwrap_or_else(|_| {
477 Response::new(
478 Full::new(Bytes::from("Internal Server Error"))
479 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
480 .boxed(),
481 )
482 });
483
484 (response, body_len)
485}
486
/// Matches `path` against a `handle_path`-style `pattern` and returns the
/// path that nested directives should see.
///
/// * A pattern ending in `/*` (e.g. `/api/*`) is a prefix pattern. It matches
///   the bare prefix (`/api`) and anything below it (`/api/users`), but only
///   on a path-segment boundary, so `/apifoo` does not match. The prefix is
///   stripped from the result; a bare-prefix match yields `/` so the result
///   is never an empty path.
/// * Any other pattern must equal `path` exactly, and the result is `/`.
///
/// Returns `None` when the pattern does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some(prefix) => {
            let remaining = path.strip_prefix(prefix)?;
            if remaining.is_empty() {
                // Same result as the exact-match branch below.
                Some("/".to_string())
            } else if remaining.starts_with('/') {
                Some(remaining.to_string())
            } else {
                // The prefix ended in the middle of a segment ("/api" vs "/apifoo").
                None
            }
        }
        None => {
            if pattern == path {
                Some("/".to_string())
            } else {
                None
            }
        }
    }
}
503
504pub fn find_site<'a>(config: &'a Config, host: &str, is_tls: bool) -> Option<&'a SiteConfig> {
526 if let Some(site) = config.sites.get(host) {
528 return Some(site);
529 }
530
531 let has_port = if host.starts_with('[') {
535 if let Some(bracket_end) = host.find(']') {
537 host[bracket_end..].contains(':')
538 } else {
539 false
540 }
541 } else {
542 host.contains(':')
543 };
544
545 if !has_port {
546 let default_port = if is_tls { 443 } else { 80 };
548 let candidate = format!("{}:{}", host, default_port);
549 if let Some(site) = config.sites.get(&candidate) {
550 return Some(site);
551 }
552
553 if is_tls {
555 let mut matches = config.sites.values().filter(|s| {
556 s.tls.is_some() && extract_hostname(&s.address).eq_ignore_ascii_case(host)
557 });
558 if let Some(site) = matches.next() {
559 if matches.next().is_none() {
560 return Some(site);
561 }
562 }
563 }
564 } else {
565 let hostname = if host.starts_with('[') {
567 let end = host.find(']').unwrap_or(host.len());
569 &host[1..end]
570 } else {
571 host.rsplit_once(':').map(|(name, _)| name).unwrap_or(host)
573 };
574 if let Some(site) = config.sites.get(hostname) {
575 return Some(site);
576 }
577 }
578
579 None
580}
581
#[cfg(test)]
mod find_site_tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Config` with one directive-less site per `(address, has_tls)`
    /// pair, keyed by the address. TLS-enabled sites get placeholder
    /// certificate and key paths.
    fn make_config(sites: Vec<(&str, bool)>) -> Config {
        let map: HashMap<String, crate::config::SiteConfig> = sites
            .into_iter()
            .map(|(addr, has_tls)| {
                let tls = match has_tls {
                    true => Some(crate::config::TlsConfig {
                        cert_path: "/fake/cert.pem".to_string(),
                        key_path: "/fake/key.pem".to_string(),
                    }),
                    false => None,
                };
                let site = crate::config::SiteConfig {
                    address: addr.to_string(),
                    directives: vec![],
                    tls,
                };
                (addr.to_string(), site)
            })
            .collect();
        Config { sites: map }
    }

    /// A Host value equal to a configured address matches directly.
    #[test]
    fn test_exact_match() {
        let config = make_config(vec![("example.com:443", true)]);
        let found = find_site(&config, "example.com:443", true);
        assert!(found.is_some());
    }

    /// A port-less Host on TLS falls back to `<host>:443`.
    #[test]
    fn test_tls_host_without_port_finds_443() {
        let config = make_config(vec![("example.com:443", true)]);
        let found = find_site(&config, "example.com", true);
        assert!(
            found.is_some(),
            "Should find example.com:443 when Host has no port and is_tls=true"
        );
    }

    /// A port-less Host on plain HTTP falls back to `<host>:80`.
    #[test]
    fn test_http_host_without_port_finds_80() {
        let config = make_config(vec![("example.com:80", false)]);
        let found = find_site(&config, "example.com", false);
        assert!(
            found.is_some(),
            "Should find example.com:80 when Host has no port and is_tls=false"
        );
    }

    /// A TLS request is not routed to a non-TLS `:80` site.
    #[test]
    fn test_tls_host_without_port_no_match_on_80() {
        let config = make_config(vec![("example.com:80", false)]);
        let found = find_site(&config, "example.com", true);
        assert!(
            found.is_none(),
            "TLS on port 443 should not find :80 site"
        );
    }

    /// A Host with a port falls back to the port-less site key.
    #[test]
    fn test_host_with_port_strips_port_fallback() {
        let config = make_config(vec![("example.com", false)]);
        let found = find_site(&config, "example.com:8080", false);
        assert!(
            found.is_some(),
            "Should strip port from Host and find config without port"
        );
    }

    /// A port-less Host on TLS finds the only TLS site with that hostname,
    /// even when it is configured on a non-default port.
    #[test]
    fn test_tls_host_without_port_finds_non_standard_port() {
        let config = make_config(vec![("alpha.local:8443", true)]);
        let found = find_site(&config, "alpha.local", true);
        assert!(
            found.is_some(),
            "Should find alpha.local:8443 when Host has no port on TLS"
        );
    }

    /// An unrelated hostname yields no site.
    #[test]
    fn test_no_match() {
        let config = make_config(vec![("other.com:443", true)]);
        let found = find_site(&config, "example.com", true);
        assert!(found.is_none());
    }
}