1use crate::cache::{CacheStore, CachedResponse};
2use crate::path_matcher::should_cache_path;
3use crate::CreateProxyConfig;
4use axum::{
5 body::Body,
6 extract::State,
7 http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
8};
9use std::sync::Arc;
10
/// State shared with [`proxy_handler`]: the response cache plus the proxy
/// configuration.
///
/// The handler receives this wrapped in an `Arc` through axum's `State`
/// extractor.
#[derive(Clone)]
pub struct ProxyState {
    /// Store that cached backend responses are read from and written to.
    cache: CacheStore,
    /// Proxy settings; the handler reads `proxy_url`, `include_paths`,
    /// `exclude_paths` and `cache_key_fn` from it.
    config: CreateProxyConfig,
}
16
17impl ProxyState {
18 pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
19 Self { cache, config }
20 }
21}
22
23pub async fn proxy_handler(
26 State(state): State<Arc<ProxyState>>,
27 req: Request<Body>,
28) -> Result<Response<Body>, StatusCode> {
29 let method_str = req.method().as_str();
30 let path = req.uri().path();
31 let query = req.uri().query().unwrap_or("");
32 let headers = req.headers();
33
34 let should_cache = should_cache_path(
36 method_str,
37 path,
38 &state.config.include_paths,
39 &state.config.exclude_paths,
40 );
41
42 let req_info = crate::RequestInfo {
44 method: method_str,
45 path,
46 query,
47 headers,
48 };
49 let cache_key = (state.config.cache_key_fn)(&req_info);
50
51 if should_cache {
53 if let Some(cached) = state.cache.get(&cache_key).await {
54 tracing::info!("Cache hit for: {} {}", method_str, cache_key);
55 return Ok(build_response_from_cache(cached));
56 }
57 tracing::info!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
58 } else {
59 tracing::info!("{} {} not cacheable (filtered), proxying directly", method_str, path);
60 }
61
62 let target_url = format!("{}{}", state.config.proxy_url, req.uri());
64 let client = reqwest::Client::new();
65
66 let method = req.method().clone();
67 let headers = req.headers().clone();
68
69 let response = match client
70 .request(method, &target_url)
71 .headers(convert_headers(&headers))
72 .send()
73 .await
74 {
75 Ok(resp) => resp,
76 Err(e) => {
77 tracing::error!("Failed to fetch from backend: {}", e);
78 return Err(StatusCode::BAD_GATEWAY);
79 }
80 };
81
82 let status = response.status().as_u16();
84 let response_headers = response.headers().clone();
85 let body_bytes = match response.bytes().await {
86 Ok(bytes) => bytes.to_vec(),
87 Err(e) => {
88 tracing::error!("Failed to read response body: {}", e);
89 return Err(StatusCode::BAD_GATEWAY);
90 }
91 };
92
93 let cached_response = CachedResponse {
94 body: body_bytes.clone(),
95 headers: convert_headers_to_map(&response_headers),
96 status,
97 };
98
99 if should_cache {
100 state
101 .cache
102 .set(cache_key.clone(), cached_response.clone())
103 .await;
104 tracing::info!("Cached response for: {} {}", method_str, cache_key);
105 }
106
107 Ok(build_response_from_cache(cached_response))
108}
109
110fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
111 let mut response = Response::builder().status(cached.status);
112
113 let headers = response.headers_mut().unwrap();
115 for (key, value) in cached.headers {
116 if let Ok(header_name) = key.parse::<HeaderName>() {
117 if let Ok(header_value) = HeaderValue::from_str(&value) {
118 headers.insert(header_name, header_value);
119 }
120 }
121 }
122
123 response.body(Body::from(cached.body)).unwrap()
124}
125
126fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
127 let mut req_headers = reqwest::header::HeaderMap::new();
128 for (key, value) in headers {
129 if let Ok(val) = value.to_str() {
130 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
131 req_headers.insert(key.clone(), header_value);
132 }
133 }
134 }
135 req_headers
136}
137
138fn convert_headers_to_map(
139 headers: &reqwest::header::HeaderMap,
140) -> std::collections::HashMap<String, String> {
141 let mut map = std::collections::HashMap::new();
142 for (key, value) in headers {
143 if let Ok(val) = value.to_str() {
144 map.insert(key.to_string(), val.to_string());
145 }
146 }
147 map
148}