1use crate::cache::{CacheStore, CachedResponse};
2use crate::path_matcher::should_cache_path;
3use crate::CreateProxyConfig;
4use axum::{
5 body::Body,
6 extract::Extension,
7 http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
8};
9use std::sync::Arc;
10use hyper_util::rt::TokioIo;
11
12#[derive(Clone)]
13pub struct ProxyState {
14 cache: CacheStore,
15 config: CreateProxyConfig,
16}
17
18impl ProxyState {
19 pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
20 Self { cache, config }
21 }
22}
23
24fn is_upgrade_request(headers: &HeaderMap) -> bool {
32 headers
33 .get(axum::http::header::CONNECTION)
34 .and_then(|v| v.to_str().ok())
35 .map(|v| v.to_lowercase().contains("upgrade"))
36 .unwrap_or(false)
37 || headers.contains_key(axum::http::header::UPGRADE)
38}
39
40pub async fn proxy_handler(
43 Extension(state): Extension<Arc<ProxyState>>,
44 req: Request<Body>,
45) -> Result<Response<Body>, StatusCode> {
46 let method_str = req.method().as_str();
47 let path = req.uri().path();
48 let query = req.uri().query().unwrap_or("");
49 let headers = req.headers();
50
51 if state.config.enable_websocket && is_upgrade_request(headers) {
54 tracing::info!("Upgrade request detected for {} {}, establishing direct proxy tunnel", method_str, path);
55 return handle_upgrade_request(state, req).await;
56 } else if !state.config.enable_websocket && is_upgrade_request(headers) {
57 tracing::warn!("Upgrade request detected for {} {} but WebSocket support is disabled", method_str, path);
58 return Err(StatusCode::NOT_IMPLEMENTED);
59 }
60
61 let should_cache = should_cache_path(
63 method_str,
64 path,
65 &state.config.include_paths,
66 &state.config.exclude_paths,
67 );
68
69 let req_info = crate::RequestInfo {
71 method: method_str,
72 path,
73 query,
74 headers,
75 };
76 let cache_key = (state.config.cache_key_fn)(&req_info);
77
78 if should_cache {
80 if let Some(cached) = state.cache.get(&cache_key).await {
81 tracing::info!("Cache hit for: {} {}", method_str, cache_key);
82 return Ok(build_response_from_cache(cached));
83 }
84 tracing::info!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
85 } else {
86 tracing::info!("{} {} not cacheable (filtered), proxying directly", method_str, path);
87 }
88
89 let target_url = format!("{}{}", state.config.proxy_url, req.uri());
91 let client = reqwest::Client::new();
92
93 let method = req.method().clone();
94 let headers = req.headers().clone();
95
96 let response = match client
97 .request(method, &target_url)
98 .headers(convert_headers(&headers))
99 .send()
100 .await
101 {
102 Ok(resp) => resp,
103 Err(e) => {
104 tracing::error!("Failed to fetch from backend: {}", e);
105 return Err(StatusCode::BAD_GATEWAY);
106 }
107 };
108
109 let status = response.status().as_u16();
111 let response_headers = response.headers().clone();
112 let body_bytes = match response.bytes().await {
113 Ok(bytes) => bytes.to_vec(),
114 Err(e) => {
115 tracing::error!("Failed to read response body: {}", e);
116 return Err(StatusCode::BAD_GATEWAY);
117 }
118 };
119
120 let cached_response = CachedResponse {
121 body: body_bytes.clone(),
122 headers: convert_headers_to_map(&response_headers),
123 status,
124 };
125
126 if should_cache {
127 state
128 .cache
129 .set(cache_key.clone(), cached_response.clone())
130 .await;
131 tracing::info!("Cached response for: {} {}", method_str, cache_key);
132 }
133
134 Ok(build_response_from_cache(cached_response))
135}
136
137async fn handle_upgrade_request(
149 state: Arc<ProxyState>,
150 mut req: Request<Body>,
151) -> Result<Response<Body>, StatusCode> {
152 let target_url = format!("{}{}", state.config.proxy_url, req.uri());
153
154 let backend_uri = target_url.parse::<hyper::Uri>().map_err(|e| {
156 tracing::error!("Failed to parse backend URL: {}", e);
157 StatusCode::BAD_GATEWAY
158 })?;
159
160 let host = backend_uri.host().ok_or_else(|| {
161 tracing::error!("No host in backend URL");
162 StatusCode::BAD_GATEWAY
163 })?;
164
165 let port = backend_uri.port_u16().unwrap_or_else(|| {
166 if backend_uri.scheme_str() == Some("https") {
167 443
168 } else {
169 80
170 }
171 });
172
173 let client_upgrade = hyper::upgrade::on(&mut req);
176
177 let backend_stream = tokio::net::TcpStream::connect((host, port))
179 .await
180 .map_err(|e| {
181 tracing::error!("Failed to connect to backend {}:{}: {}", host, port, e);
182 StatusCode::BAD_GATEWAY
183 })?;
184
185 let backend_stream = TokioIo::new(backend_stream);
186
187 let (mut sender, conn) = hyper::client::conn::http1::handshake(backend_stream)
189 .await
190 .map_err(|e| {
191 tracing::error!("Failed to handshake with backend: {}", e);
192 StatusCode::BAD_GATEWAY
193 })?;
194
195 tokio::spawn(async move {
197 if let Err(e) = conn.await {
198 tracing::error!("Connection to backend failed: {}", e);
199 }
200 });
201
202 let backend_response = sender.send_request(req).await.map_err(|e| {
204 tracing::error!("Failed to send request to backend: {}", e);
205 StatusCode::BAD_GATEWAY
206 })?;
207
208 let status = backend_response.status();
210 if status != StatusCode::SWITCHING_PROTOCOLS {
211 tracing::warn!("Backend did not accept upgrade request, status: {}", status);
212 let (parts, body) = backend_response.into_parts();
214 let body = Body::new(body);
215 return Ok(Response::from_parts(parts, body));
216 }
217
218 let backend_headers = backend_response.headers().clone();
220
221 tokio::spawn(async move {
223 tracing::info!("Starting upgrade tunnel establishment");
224
225 let (client_result, backend_result) = tokio::join!(
227 client_upgrade,
228 hyper::upgrade::on(backend_response)
229 );
230
231 match (client_result, backend_result) {
232 (Ok(client_upgraded), Ok(backend_upgraded)) => {
233 tracing::info!("Both upgrades successful, establishing bidirectional tunnel");
234
235 let mut client_stream = TokioIo::new(client_upgraded);
237 let mut backend_stream = TokioIo::new(backend_upgraded);
238
239 match tokio::io::copy_bidirectional(&mut client_stream, &mut backend_stream).await {
241 Ok((client_to_backend, backend_to_client)) => {
242 tracing::info!(
243 "Tunnel closed gracefully. Transferred {} bytes client->backend, {} bytes backend->client",
244 client_to_backend,
245 backend_to_client
246 );
247 }
248 Err(e) => {
249 tracing::error!("Tunnel error: {}", e);
250 }
251 }
252 }
253 (Err(e), _) => {
254 tracing::error!("Client upgrade failed: {}", e);
255 }
256 (_, Err(e)) => {
257 tracing::error!("Backend upgrade failed: {}", e);
258 }
259 }
260 });
261
262 let mut response = Response::builder()
264 .status(StatusCode::SWITCHING_PROTOCOLS)
265 .body(Body::empty())
266 .unwrap();
267
268 if let Some(upgrade_header) = backend_headers.get(axum::http::header::UPGRADE) {
271 response.headers_mut().insert(
272 axum::http::header::UPGRADE,
273 upgrade_header.clone(),
274 );
275 }
276 if let Some(connection_header) = backend_headers.get(axum::http::header::CONNECTION) {
277 response.headers_mut().insert(
278 axum::http::header::CONNECTION,
279 connection_header.clone(),
280 );
281 }
282 if let Some(sec_websocket_accept) = backend_headers.get("sec-websocket-accept") {
283 response.headers_mut().insert(
284 HeaderName::from_static("sec-websocket-accept"),
285 sec_websocket_accept.clone(),
286 );
287 }
288
289 tracing::info!("Upgrade response sent to client, tunnel task spawned");
290
291 Ok(response)
292}
293
294fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
295 let mut response = Response::builder().status(cached.status);
296
297 let headers = response.headers_mut().unwrap();
299 for (key, value) in cached.headers {
300 if let Ok(header_name) = key.parse::<HeaderName>() {
301 if let Ok(header_value) = HeaderValue::from_str(&value) {
302 headers.insert(header_name, header_value);
303 }
304 }
305 }
306
307 response.body(Body::from(cached.body)).unwrap()
308}
309
310fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
311 let mut req_headers = reqwest::header::HeaderMap::new();
312 for (key, value) in headers {
313 if let Ok(val) = value.to_str() {
314 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
315 req_headers.insert(key.clone(), header_value);
316 }
317 }
318 }
319 req_headers
320}
321
322fn convert_headers_to_map(
323 headers: &reqwest::header::HeaderMap,
324) -> std::collections::HashMap<String, String> {
325 let mut map = std::collections::HashMap::new();
326 for (key, value) in headers {
327 if let Ok(val) = value.to_str() {
328 map.insert(key.to_string(), val.to_string());
329 }
330 }
331 map
332}