1use crate::cache::{CacheStore, CachedResponse};
2use crate::path_matcher::should_cache_path;
3use crate::CreateProxyConfig;
4use axum::{
5 body::Body,
6 extract::Extension,
7 http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
8};
9use std::sync::Arc;
10use hyper_util::rt::TokioIo;
11
12#[derive(Clone)]
13pub struct ProxyState {
14 cache: CacheStore,
15 config: CreateProxyConfig,
16}
17
18impl ProxyState {
19 pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
20 Self { cache, config }
21 }
22}
23
24fn is_upgrade_request(headers: &HeaderMap) -> bool {
32 headers
33 .get(axum::http::header::CONNECTION)
34 .and_then(|v| v.to_str().ok())
35 .map(|v| v.to_lowercase().contains("upgrade"))
36 .unwrap_or(false)
37 || headers.contains_key(axum::http::header::UPGRADE)
38}
39
40pub async fn proxy_handler(
43 Extension(state): Extension<Arc<ProxyState>>,
44 req: Request<Body>,
45) -> Result<Response<Body>, StatusCode> {
46 let is_upgrade = is_upgrade_request(req.headers());
49
50 if is_upgrade {
51 let method_str = req.method().as_str();
52 let path = req.uri().path();
53
54 if state.config.enable_websocket {
55 tracing::debug!("Upgrade request detected for {} {}, establishing direct proxy tunnel", method_str, path);
56 return handle_upgrade_request(state, req).await;
57 } else {
58 tracing::warn!("Upgrade request detected for {} {} but WebSocket support is disabled", method_str, path);
59 return Err(StatusCode::NOT_IMPLEMENTED);
60 }
61 }
62
63 let method = req.method().clone();
65 let method_str = method.as_str();
66 let uri = req.uri().clone();
67 let path = uri.path();
68 let query = uri.query().unwrap_or("");
69 let headers = req.headers().clone();
70
71 if state.config.forward_get_only && method != axum::http::Method::GET {
73 tracing::warn!("Non-GET request {} {} rejected (forward_get_only is enabled)", method_str, path);
74 return Err(StatusCode::METHOD_NOT_ALLOWED);
75 }
76
77 let should_cache = should_cache_path(
79 method_str,
80 path,
81 &state.config.include_paths,
82 &state.config.exclude_paths,
83 );
84
85 let req_info = crate::RequestInfo {
87 method: method_str,
88 path,
89 query,
90 headers: &headers,
91 };
92 let cache_key = (state.config.cache_key_fn)(&req_info);
93
94 if state.config.cache_404_capacity > 0 {
96 if let Some(cached) = state.cache.get_404(&cache_key).await {
97 tracing::debug!("404 cache hit for: {} {}", method_str, cache_key);
98 return Ok(build_response_from_cache(cached));
99 }
100 }
101
102 if should_cache {
104 if let Some(cached) = state.cache.get(&cache_key).await {
105 tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
106 return Ok(build_response_from_cache(cached));
107 }
108 tracing::debug!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
109 } else {
110 tracing::debug!("{} {} not cacheable (filtered), proxying directly", method_str, path);
111 }
112
113 let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
115 Ok(bytes) => bytes,
116 Err(e) => {
117 tracing::error!("Failed to read request body: {}", e);
118 return Err(StatusCode::BAD_REQUEST);
119 }
120 };
121
122 let target_url = format!("{}{}", state.config.proxy_url, uri);
124 let client = reqwest::Client::new();
125
126 let response = match client
127 .request(method.clone(), &target_url)
128 .headers(convert_headers(&headers))
129 .body(body_bytes.to_vec())
130 .send()
131 .await
132 {
133 Ok(resp) => resp,
134 Err(e) => {
135 tracing::error!("Failed to fetch from backend: {}", e);
136 return Err(StatusCode::BAD_GATEWAY);
137 }
138 };
139
140 let status = response.status().as_u16();
142 let response_headers = response.headers().clone();
143 let body_bytes = match response.bytes().await {
144 Ok(bytes) => bytes.to_vec(),
145 Err(e) => {
146 tracing::error!("Failed to read response body: {}", e);
147 return Err(StatusCode::BAD_GATEWAY);
148 }
149 };
150
151 let cached_response = CachedResponse {
152 body: body_bytes.clone(),
153 headers: convert_headers_to_map(&response_headers),
154 status,
155 };
156
157 let mut is_404 = status == 404;
159 if !is_404 && state.config.use_404_meta {
160 if let Ok(body_str) = std::str::from_utf8(&body_bytes) {
162 let name_dbl = "name=\"phantom-404\"";
164 let name_sgl = "name='phantom-404'";
165 let content_dbl = "content=\"true\"";
166 let content_sgl = "content='true'";
167
168 if (body_str.contains(name_dbl) || body_str.contains(name_sgl))
169 && (body_str.contains(content_dbl) || body_str.contains(content_sgl))
170 {
171 is_404 = true;
172 }
173 }
174 }
175
176 if is_404 && state.config.cache_404_capacity > 0 {
177 state
179 .cache
180 .set_404(cache_key.clone(), cached_response.clone())
181 .await;
182 tracing::debug!("Cached 404 response for: {} {}", method_str, cache_key);
183 } else if should_cache {
184 state
185 .cache
186 .set(cache_key.clone(), cached_response.clone())
187 .await;
188 tracing::debug!("Cached response for: {} {}", method_str, cache_key);
189 }
190
191 Ok(build_response_from_cache(cached_response))
192}
193
194async fn handle_upgrade_request(
206 state: Arc<ProxyState>,
207 mut req: Request<Body>,
208) -> Result<Response<Body>, StatusCode> {
209 let target_url = format!("{}{}", state.config.proxy_url, req.uri());
210
211 let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
213 tracing::error!("Failed to parse backend URL: {}", e);
214 StatusCode::BAD_GATEWAY
215 })?;
216
217 let host = backend_uri.host().ok_or_else(|| {
218 tracing::error!("No host in backend URL");
219 StatusCode::BAD_GATEWAY
220 })?;
221
222 let port = backend_uri.port_u16().unwrap_or_else(|| {
223 if backend_uri.scheme_str() == Some("https") {
224 443
225 } else {
226 80
227 }
228 });
229
230 let client_upgrade = hyper::upgrade::on(&mut req);
233
234 let backend_stream = tokio::net::TcpStream::connect((host, port))
236 .await
237 .map_err(|e| {
238 tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
239 StatusCode::BAD_GATEWAY
240 })?;
241
242 let backend_io = TokioIo::new(backend_stream);
243
244 let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
246 .await
247 .map_err(|e| {
248 tracing::error!("Failed to handshake with backend: {}", e);
249 StatusCode::BAD_GATEWAY
250 })?;
251
252 let conn_task = tokio::spawn(async move {
254 match conn.with_upgrades().await {
255 Ok(parts) => {
256 tracing::debug!("Backend connection upgraded successfully");
257 Ok(parts)
258 }
259 Err(e) => {
260 tracing::error!("Backend connection failed: {}", e);
261 Err(e)
262 }
263 }
264 });
265
266 let backend_response = sender.send_request(req).await.map_err(|e| {
268 tracing::error!("Failed to send request to backend: {}", e);
269 StatusCode::BAD_GATEWAY
270 })?;
271
272 let status = backend_response.status();
274 if status != StatusCode::SWITCHING_PROTOCOLS {
275 tracing::warn!("Backend did not accept upgrade request, status: {}", status);
276 let (parts, body) = backend_response.into_parts();
278 let body = Body::new(body);
279 return Ok(Response::from_parts(parts, body));
280 }
281
282 let backend_headers = backend_response.headers().clone();
284
285 let backend_upgrade = hyper::upgrade::on(backend_response);
287
288 tokio::spawn(async move {
290 tracing::debug!("Starting upgrade tunnel establishment");
291
292 let (client_result, backend_result) = tokio::join!(
294 client_upgrade,
295 backend_upgrade
296 );
297
298 drop(conn_task);
300
301 match (client_result, backend_result) {
302 (Ok(client_upgraded), Ok(backend_upgraded)) => {
303 tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");
304
305 let mut client_stream = TokioIo::new(client_upgraded);
307 let mut backend_stream = TokioIo::new(backend_upgraded);
308
309 match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
311 Ok((client_to_backend, backend_to_client)) => {
312 tracing::debug!(
313 "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
314 client_to_backend,
315 backend_to_client
316 );
317 }
318 Err(e) => {
319 tracing::error!("Tunnel error: {}", e);
320 }
321 }
322 }
323 (Err(e), _) => {
324 tracing::error!("Client upgrade failed: {}", e);
325 }
326 (_, Err(e)) => {
327 tracing::error!("Backend upgrade failed: {}", e);
328 }
329 }
330 });
331
332 let mut response = Response::builder()
334 .status(StatusCode::SWITCHING_PROTOCOLS)
335 .body(Body::empty())
336 .unwrap();
337
338 if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
341 response.headers_mut().insert(
342 axum::http::header::UPGRADE,
343 upgrade_header.clone(),
344 );
345 }
346 if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
347 response.headers_mut().insert(
348 axum::http::header::CONNECTION,
349 connection_header.clone(),
350 );
351 }
352 if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
353 response.headers_mut().insert(
354 HeaderName::from_static("sec-websocket-accept"),
355 sec_websocket_accept.clone(),
356 );
357 }
358
359 tracing::debug!("Upgrade response sent to client, tunnel task spawned");
360
361 Ok(response)
362}
363
364fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
365 let mut response = Response::builder().status(cached.status);
366
367 let headers = response.headers_mut().unwrap();
369 for (key, value) in cached.headers {
370 if let Ok(header_name) = key.parse::<HeaderName>() {
371 if let Ok(header_value) = HeaderValue::from_str(&value) {
372 headers.insert(header_name, header_value);
373 } else {
374 tracing::warn!("Failed to parse header value for key '{}': {:?}", key, value);
375 }
376 } else {
377 tracing::warn!("Failed to parse header name: {}", key);
378 }
379 }
380
381 response.body(Body::from(cached.body)).unwrap()
382}
383
384fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
385 let mut req_headers = reqwest::header::HeaderMap::new();
386 for (key, value) in headers {
387 if key == axum::http::header::HOST {
389 continue;
390 }
391 if let Ok(val) = value.to_str() {
392 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
393 req_headers.insert(key.clone(), header_value);
394 }
395 }
396 }
397 req_headers
398}
399
400fn convert_headers_to_map(
401 headers: &reqwest::header::HeaderMap,
402) -> std::collections::HashMap<String, String> {
403 let mut map = std::collections::HashMap::new();
404 for (key, value) in headers {
405 if let Ok(val) = value.to_str() {
406 map.insert(key.to_string(), val.to_string());
407 } else {
408 tracing::debug!("Could not convert header '{}' to string", key);
410 }
411 }
412 map
413}