1use crate::cache::{CacheStore, CachedResponse};
2use crate::path_matcher::should_cache_path;
3use crate::CreateProxyConfig;
4use axum::{
5 body::Body,
6 extract::Extension,
7 http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
8};
9use std::sync::Arc;
10use hyper_util::rt::TokioIo;
11
12#[derive(Clone)]
13pub struct ProxyState {
14 cache: CacheStore,
15 config: CreateProxyConfig,
16}
17
18impl ProxyState {
19 pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
20 Self { cache, config }
21 }
22}
23
24fn is_upgrade_request(headers: &HeaderMap) -> bool {
32 headers
33 .get(axum::http::header::CONNECTION)
34 .and_then(|v| v.to_str().ok())
35 .map(|v| v.to_lowercase().contains("upgrade"))
36 .unwrap_or(false)
37 || headers.contains_key(axum::http::header::UPGRADE)
38}
39
40pub async fn proxy_handler(
43 Extension(state): Extension<Arc<ProxyState>>,
44 req: Request<Body>,
45) -> Result<Response<Body>, StatusCode> {
46 let is_upgrade = is_upgrade_request(req.headers());
49
50 if is_upgrade {
51 let method_str = req.method().as_str();
52 let path = req.uri().path();
53
54 if state.config.enable_websocket {
55 tracing::debug!("Upgrade request detected for {} {}, establishing direct proxy tunnel", method_str, path);
56 return handle_upgrade_request(state, req).await;
57 } else {
58 tracing::warn!("Upgrade request detected for {} {} but WebSocket support is disabled", method_str, path);
59 return Err(StatusCode::NOT_IMPLEMENTED);
60 }
61 }
62
63 let method = req.method().clone();
65 let method_str = method.as_str();
66 let uri = req.uri().clone();
67 let path = uri.path();
68 let query = uri.query().unwrap_or("");
69 let headers = req.headers().clone();
70
71 if state.config.forward_get_only && method != axum::http::Method::GET {
73 tracing::warn!("Non-GET request {} {} rejected (forward_get_only is enabled)", method_str, path);
74 return Err(StatusCode::METHOD_NOT_ALLOWED);
75 }
76
77 let should_cache = should_cache_path(
79 method_str,
80 path,
81 &state.config.include_paths,
82 &state.config.exclude_paths,
83 );
84
85 let req_info = crate::RequestInfo {
87 method: method_str,
88 path,
89 query,
90 headers: &headers,
91 };
92 let cache_key = (state.config.cache_key_fn)(&req_info);
93
94 if should_cache {
96 if let Some(cached) = state.cache.get(&cache_key).await {
97 tracing::debug!("Cache hit for: {} {}", method_str, cache_key);
98 return Ok(build_response_from_cache(cached));
99 }
100 tracing::debug!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
101 } else {
102 tracing::debug!("{} {} not cacheable (filtered), proxying directly", method_str, path);
103 }
104
105 let body_bytes = match axum::body::to_bytes(req.into_body(), usize::MAX).await {
107 Ok(bytes) => bytes,
108 Err(e) => {
109 tracing::error!("Failed to read request body: {}", e);
110 return Err(StatusCode::BAD_REQUEST);
111 }
112 };
113
114 let target_url = format!("{}{}", state.config.proxy_url, uri);
116 let client = reqwest::Client::new();
117
118 let response = match client
119 .request(method.clone(), &target_url)
120 .headers(convert_headers(&headers))
121 .body(body_bytes.to_vec())
122 .send()
123 .await
124 {
125 Ok(resp) => resp,
126 Err(e) => {
127 tracing::error!("Failed to fetch from backend: {}", e);
128 return Err(StatusCode::BAD_GATEWAY);
129 }
130 };
131
132 let status = response.status().as_u16();
134 let response_headers = response.headers().clone();
135 let body_bytes = match response.bytes().await {
136 Ok(bytes) => bytes.to_vec(),
137 Err(e) => {
138 tracing::error!("Failed to read response body: {}", e);
139 return Err(StatusCode::BAD_GATEWAY);
140 }
141 };
142
143 let cached_response = CachedResponse {
144 body: body_bytes.clone(),
145 headers: convert_headers_to_map(&response_headers),
146 status,
147 };
148
149 if should_cache {
150 state
151 .cache
152 .set(cache_key.clone(), cached_response.clone())
153 .await;
154 tracing::debug!("Cached response for: {} {}", method_str, cache_key);
155 }
156
157 Ok(build_response_from_cache(cached_response))
158}
159
160async fn handle_upgrade_request(
172 state: Arc<ProxyState>,
173 mut req: Request<Body>,
174) -> Result<Response<Body>, StatusCode> {
175 let target_url = format!("{}{}", state.config.proxy_url, req.uri());
176
177 let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
179 tracing::error!("Failed to parse backend URL: {}", e);
180 StatusCode::BAD_GATEWAY
181 })?;
182
183 let host = backend_uri.host().ok_or_else(|| {
184 tracing::error!("No host in backend URL");
185 StatusCode::BAD_GATEWAY
186 })?;
187
188 let port = backend_uri.port_u16().unwrap_or_else(|| {
189 if backend_uri.scheme_str() == Some("https") {
190 443
191 } else {
192 80
193 }
194 });
195
196 let client_upgrade = hyper::upgrade::on(&mut req);
199
200 let backend_stream = tokio::net::TcpStream::connect((host, port))
202 .await
203 .map_err(|e| {
204 tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
205 StatusCode::BAD_GATEWAY
206 })?;
207
208 let backend_io = TokioIo::new(backend_stream);
209
210 let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_io)
212 .await
213 .map_err(|e| {
214 tracing::error!("Failed to handshake with backend: {}", e);
215 StatusCode::BAD_GATEWAY
216 })?;
217
218 let conn_task = tokio::spawn(async move {
220 match conn.with_upgrades().await {
221 Ok(parts) => {
222 tracing::debug!("Backend connection upgraded successfully");
223 Ok(parts)
224 }
225 Err(e) => {
226 tracing::error!("Backend connection failed: {}", e);
227 Err(e)
228 }
229 }
230 });
231
232 let backend_response = sender.send_request(req).await.map_err(|e| {
234 tracing::error!("Failed to send request to backend: {}", e);
235 StatusCode::BAD_GATEWAY
236 })?;
237
238 let status = backend_response.status();
240 if status != StatusCode::SWITCHING_PROTOCOLS {
241 tracing::warn!("Backend did not accept upgrade request, status: {}", status);
242 let (parts, body) = backend_response.into_parts();
244 let body = Body::new(body);
245 return Ok(Response::from_parts(parts, body));
246 }
247
248 let backend_headers = backend_response.headers().clone();
250
251 let backend_upgrade = hyper::upgrade::on(backend_response);
253
254 tokio::spawn(async move {
256 tracing::debug!("Starting upgrade tunnel establishment");
257
258 let (client_result, backend_result) = tokio::join!(
260 client_upgrade,
261 backend_upgrade
262 );
263
264 drop(conn_task);
266
267 match (client_result, backend_result) {
268 (Ok(client_upgraded), Ok(backend_upgraded)) => {
269 tracing::debug!("Both upgrades successful, establishing bidirectional tunnel");
270
271 let mut client_stream = TokioIo::new(client_upgraded);
273 let mut backend_stream = TokioIo::new(backend_upgraded);
274
275 match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
277 Ok((client_to_backend, backend_to_client)) => {
278 tracing::debug!(
279 "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
280 client_to_backend,
281 backend_to_client
282 );
283 }
284 Err(e) => {
285 tracing::error!("Tunnel error: {}", e);
286 }
287 }
288 }
289 (Err(e), _) => {
290 tracing::error!("Client upgrade failed: {}", e);
291 }
292 (_, Err(e)) => {
293 tracing::error!("Backend upgrade failed: {}", e);
294 }
295 }
296 });
297
298 let mut response = Response::builder()
300 .status(StatusCode::SWITCHING_PROTOCOLS)
301 .body(Body::empty())
302 .unwrap();
303
304 if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
307 response.headers_mut().insert(
308 axum::http::header::UPGRADE,
309 upgrade_header.clone(),
310 );
311 }
312 if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
313 response.headers_mut().insert(
314 axum::http::header::CONNECTION,
315 connection_header.clone(),
316 );
317 }
318 if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
319 response.headers_mut().insert(
320 HeaderName::from_static("sec-websocket-accept"),
321 sec_websocket_accept.clone(),
322 );
323 }
324
325 tracing::debug!("Upgrade response sent to client, tunnel task spawned");
326
327 Ok(response)
328}
329
330fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
331 let mut response = Response::builder().status(cached.status);
332
333 let headers = response.headers_mut().unwrap();
335 for (key, value) in cached.headers {
336 if let Ok(header_name) = key.parse::<HeaderName>() {
337 if let Ok(header_value) = HeaderValue::from_str(&value) {
338 headers.insert(header_name, header_value);
339 } else {
340 tracing::warn!("Failed to parse header value for key '{}': {:?}", key, value);
341 }
342 } else {
343 tracing::warn!("Failed to parse header name: {}", key);
344 }
345 }
346
347 response.body(Body::from(cached.body)).unwrap()
348}
349
350fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
351 let mut req_headers = reqwest::header::HeaderMap::new();
352 for (key, value) in headers {
353 if key == axum::http::header::HOST {
355 continue;
356 }
357 if let Ok(val) = value.to_str() {
358 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
359 req_headers.insert(key.clone(), header_value);
360 }
361 }
362 }
363 req_headers
364}
365
366fn convert_headers_to_map(
367 headers: &reqwest::header::HeaderMap,
368) -> std::collections::HashMap<String, String> {
369 let mut map = std::collections::HashMap::new();
370 for (key, value) in headers {
371 if let Ok(val) = value.to_str() {
372 map.insert(key.to_string(), val.to_string());
373 } else {
374 tracing::debug!("Could not convert header '{}' to string", key);
376 }
377 }
378 map
379}