1use anyhow::Error;
2use bytes::Bytes;
3use http_body_util::{BodyExt, Full};
4use hyper::body::Incoming;
5use hyper::header;
6use hyper::{Request, Response, StatusCode, Uri};
7use hyper_rustls::HttpsConnector;
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::Client;
10use std::sync::Arc;
11use tokio::time::{timeout, Duration};
12use tracing::{error, info};
13
14use crate::config::Config;
15use crate::proxy::ActionResult;
16
17use crate::proxy::directives::{
18 handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy,
19 handle_strip_prefix, handle_uri_replace,
20};
21
22type ResponseBody =
25 http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
26
27fn is_hop_header(name: &header::HeaderName) -> bool {
33 matches!(
34 name,
35 &header::CONNECTION
36 | &header::UPGRADE
37 | &header::TE
38 | &header::TRAILER
39 | &header::PROXY_AUTHENTICATE
40 | &header::PROXY_AUTHORIZATION
41 )
42}
43
44pub fn process_directives(
47 directives: &[crate::config::Directive],
48 req: &mut Request<Incoming>,
49 current_path: &str,
50) -> Result<ActionResult, String> {
51 let mut modified_path = current_path.to_string();
52
53 for directive in directives {
54 match directive {
55 crate::config::Directive::Header { name, value } => {
59 if let Err(e) = handle_header(name, value.as_deref(), req) {
60 info!(" Failed to apply header {}: {}", name, e);
61 }
62 }
63
64 crate::config::Directive::UriReplace { find, replace } => {
66 handle_uri_replace(find, replace, &mut modified_path);
67 }
68
69 crate::config::Directive::StripPrefix { prefix } => {
71 handle_strip_prefix(prefix, &mut modified_path);
72 }
73
74 crate::config::Directive::HandlePath {
76 pattern,
77 directives: nested_directives,
78 } => {
79 if let Some(remaining_path) = match_pattern(pattern, &modified_path) {
80 info!(" Matched handle_path: {}", pattern);
81 return process_directives(nested_directives, req, &remaining_path);
83 }
84 }
85
86 crate::config::Directive::Method {
88 methods,
89 directives: nested_directives,
90 } => {
91 if handle_method(methods, req) {
92 info!(" Matched method directive");
93 return process_directives(nested_directives, req, &modified_path);
95 }
96 }
97
98 crate::config::Directive::Redirect { status, url } => {
100 return Ok(handle_redirect(status, url));
101 }
102
103 crate::config::Directive::Respond { status, body } => {
105 return Ok(handle_respond(status, body));
106 }
107
108 crate::config::Directive::ReverseProxy {
110 to,
111 connect_timeout,
112 read_timeout,
113 } => {
114 return Ok(handle_reverse_proxy(
115 to,
116 &modified_path,
117 *connect_timeout,
118 *read_timeout,
119 ));
120 }
121 }
122 }
123
124 Err(format!(
125 "No action directive (respond or reverse_proxy) found in configuration for path: {}",
126 current_path
127 ))
128}
129
130pub async fn proxy(
141 mut req: Request<Incoming>,
142 client: Client<HttpsConnector<HttpConnector>, Incoming>,
143 config: Arc<Config>,
144 remote_addr: std::net::SocketAddr,
145) -> Result<Response<ResponseBody>, Error> {
146 let path = req.uri().path().to_string();
148
149 let host = req
151 .headers()
152 .get(hyper::header::HOST)
153 .and_then(|h| h.to_str().ok())
154 .unwrap_or("localhost");
155
156 if tracing::enabled!(tracing::Level::INFO) {
158 }
161
162 let site_config = match config.sites.get(host) {
164 Some(config) => config,
165 None => {
166 error!("No configuration found for host: {}", host);
167 return Ok(error_response(
168 StatusCode::NOT_FOUND,
169 &format!("No configuration found for host: {}", host),
170 ));
171 }
172 };
173
174 let action_result =
176 process_directives(&site_config.directives, &mut req, &path).map_err(anyhow::Error::msg)?;
177
178 match action_result {
180 ActionResult::Redirect { status, url } => {
181 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::FOUND);
182 let boxed: ResponseBody = Full::new(Bytes::from(url.clone()))
183 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
184 .boxed();
185 Ok(Response::builder()
186 .status(status_code)
187 .header("Location", &url)
188 .body(boxed)?)
189 }
190 ActionResult::Respond { status, body } => {
191 let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::OK);
192 let boxed: ResponseBody = Full::new(Bytes::from(body))
193 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
194 .boxed();
195 Ok(Response::builder().status(status_code).body(boxed)?)
196 }
197 ActionResult::ReverseProxy {
198 backend_url,
199 path_to_send,
200 connect_timeout: _,
201 read_timeout,
202 } => {
203 let backend_with_proto =
205 if backend_url.starts_with("http://") || backend_url.starts_with("https://") {
206 backend_url
207 } else {
208 format!("http://{}", backend_url)
209 };
210
211 let mut parts = backend_with_proto.parse::<Uri>()?.into_parts();
213 parts.path_and_query = Some(path_to_send.parse()?);
214 let new_uri = Uri::from_parts(parts)?;
215
216 if tracing::enabled!(tracing::Level::INFO) {
218 }
221
222 *req.uri_mut() = new_uri.clone();
223
224 let original_host_header = req.headers().get(hyper::header::HOST).cloned();
227
228 req.headers_mut().remove(hyper::header::HOST);
230 if let Some(authority) = new_uri.authority() {
231 if let Ok(host_value) = authority.as_str().parse::<hyper::header::HeaderValue>() {
232 req.headers_mut().insert(hyper::header::HOST, host_value);
233 }
234 }
235
236 if let Some(host_value) = original_host_header.clone() {
239 req.headers_mut().insert("X-Forwarded-Host", host_value);
240 }
241
242 let original_scheme = req.uri().scheme_str().unwrap_or("http");
244 match original_scheme {
246 "http" => {
247 req.headers_mut().insert(
248 "X-Forwarded-Proto",
249 hyper::header::HeaderValue::from_static("http"),
250 );
251 }
252 "https" => {
253 req.headers_mut().insert(
254 "X-Forwarded-Proto",
255 hyper::header::HeaderValue::from_static("https"),
256 );
257 }
258 _ => {} }
260
261 if let Ok(ip_value) =
263 hyper::header::HeaderValue::from_str(&remote_addr.ip().to_string())
264 {
265 req.headers_mut().insert("X-Forwarded-For", ip_value);
266 }
267
268 req.headers_mut().remove(header::CONNECTION);
271
272 req.headers_mut().remove("accept-encoding");
275
276 let backend_timeout = read_timeout.unwrap_or(30);
278 match timeout(Duration::from_secs(backend_timeout), client.request(req)).await {
279 Ok(Ok(response)) => {
280 let status = response.status();
282 let headers = response.headers().clone();
283
284 if tracing::enabled!(tracing::Level::INFO) {
286 }
289
290 let mut builder = Response::builder().status(status);
292
293 for (name, value) in headers.iter() {
297 if !is_hop_header(name) && name != header::CONTENT_LENGTH {
298 builder = builder.header(name, value);
299 }
300 }
301
302 let (_, incoming_body) = response.into_parts();
304 let boxed: ResponseBody = incoming_body
305 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
306 .boxed();
307
308 Ok(builder.body(boxed)?)
309 }
310 Ok(Err(e)) => {
311 error!("Backend connection failed: {:?}", e);
313
314 if e.is_connect() {
315 error!(" Reason: Connection refused - backend unavailable");
316 } else {
317 error!(" Reason: Other connection error");
318 }
319
320 Ok(error_response(
321 StatusCode::BAD_GATEWAY,
322 "Backend service unavailable",
323 ))
324 }
325 Err(_) => {
326 error!(
328 "Backend request timed out after {} seconds",
329 backend_timeout
330 );
331
332 Ok(error_response(
333 StatusCode::GATEWAY_TIMEOUT,
334 "Backend request timed out",
335 ))
336 }
337 }
338 }
339 }
340}
341
342fn error_response(status: StatusCode, message: &str) -> Response<ResponseBody> {
344 let body = format!(
345 r#"<!DOCTYPE html>
346 <html>
347 <head><title>{} {}</title></head>
348 <body>
349 <h1>{} {}</h1>
350 <p>{}</p>
351 <hr>
352 <p><em>Rust Proxy Server</em></p>
353 </body>
354 </html>"#,
355 status.as_u16(),
356 status.canonical_reason().unwrap_or("Error"),
357 status.as_u16(),
358 status.canonical_reason().unwrap_or("Error"),
359 message
360 );
361
362 let full = Full::new(Bytes::from(body));
363 let boxed: ResponseBody = full
364 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
365 .boxed();
366
367 Response::builder()
368 .status(status)
369 .header("Content-Type", "text/html; charset=utf-8")
370 .body(boxed)
371 .unwrap()
372}
373
/// Matches a request path against a `handle_path` pattern and returns the
/// remaining path to hand to the nested directives.
///
/// Supported pattern forms:
/// * `"<prefix>/*"` matches `<prefix>` itself and any path below `<prefix>/`.
///   It returns the remainder after `<prefix>`, which always starts with `/`;
///   a bare `<prefix>` yields `"/"`. Matching respects path-segment
///   boundaries, so `/api/*` does not match `/apiary`.
/// * `"/*"` (empty prefix) is a catch-all that returns `path` unchanged.
/// * Any other pattern must equal `path` exactly, and then yields `"/"`.
///
/// Returns `None` when `path` does not match.
pub fn match_pattern(pattern: &str, path: &str) -> Option<String> {
    match pattern.strip_suffix("/*") {
        Some("") => Some(path.to_string()),
        Some(prefix) => {
            let rest = path.strip_prefix(prefix)?;
            if rest.is_empty() {
                Some("/".to_string())
            } else if rest.starts_with('/') {
                Some(rest.to_string())
            } else {
                // The prefix ended mid-segment, e.g. `/api` against `/apiary`.
                None
            }
        }
        None if pattern == path => Some("/".to_string()),
        None => None,
    }
}