1use crate::cache::{CacheStore, CachedResponse};
2use axum::{
3 body::Body,
4 extract::State,
5 http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
6};
7use std::sync::Arc;
8
/// State shared with [`proxy_handler`]: the response cache plus the backend
/// base URL that incoming request URIs are appended to.
#[derive(Clone)]
pub struct ProxyState {
    /// Store queried before contacting the backend and filled after a fetch.
    cache: CacheStore,
    /// Backend origin; the handler prefixes it to the request URI to form the
    /// target URL.
    proxy_url: String,
}
14
15impl ProxyState {
16 pub fn new(cache: CacheStore, proxy_url: String) -> Self {
17 Self { cache, proxy_url }
18 }
19}
20
21pub async fn proxy_handler(
24 State(state): State<Arc<ProxyState>>,
25 req: Request<Body>,
26) -> Result<Response<Body>, StatusCode> {
27 let path = req.uri().path();
28 let query = req.uri().query().unwrap_or("");
29 let cache_key = format!("{}?{}", path, query);
30
31 if let Some(cached) = state.cache.get(&cache_key).await {
33 tracing::info!("Cache hit for: {}", cache_key);
34 return Ok(build_response_from_cache(cached));
35 }
36
37 tracing::info!("Cache miss for: {}, fetching from backend", cache_key);
38
39 let target_url = format!("{}{}", state.proxy_url, req.uri());
41 let client = reqwest::Client::new();
42
43 let method = req.method().clone();
44 let headers = req.headers().clone();
45
46 let response = match client
47 .request(method, &target_url)
48 .headers(convert_headers(&headers))
49 .send()
50 .await
51 {
52 Ok(resp) => resp,
53 Err(e) => {
54 tracing::error!("Failed to fetch from backend: {}", e);
55 return Err(StatusCode::BAD_GATEWAY);
56 }
57 };
58
59 let status = response.status().as_u16();
61 let response_headers = response.headers().clone();
62 let body_bytes = match response.bytes().await {
63 Ok(bytes) => bytes.to_vec(),
64 Err(e) => {
65 tracing::error!("Failed to read response body: {}", e);
66 return Err(StatusCode::BAD_GATEWAY);
67 }
68 };
69
70 let cached_response = CachedResponse {
71 body: body_bytes.clone(),
72 headers: convert_headers_to_map(&response_headers),
73 status,
74 };
75
76 state
77 .cache
78 .set(cache_key.clone(), cached_response.clone())
79 .await;
80 tracing::info!("Cached response for: {}", cache_key);
81
82 Ok(build_response_from_cache(cached_response))
83}
84
85fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
86 let mut response = Response::builder().status(cached.status);
87
88 let headers = response.headers_mut().unwrap();
90 for (key, value) in cached.headers {
91 if let Ok(header_name) = key.parse::<HeaderName>() {
92 if let Ok(header_value) = HeaderValue::from_str(&value) {
93 headers.insert(header_name, header_value);
94 }
95 }
96 }
97
98 response.body(Body::from(cached.body)).unwrap()
99}
100
101fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
102 let mut req_headers = reqwest::header::HeaderMap::new();
103 for (key, value) in headers {
104 if let Ok(val) = value.to_str() {
105 if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
106 req_headers.insert(key.clone(), header_value);
107 }
108 }
109 }
110 req_headers
111}
112
113fn convert_headers_to_map(
114 headers: &reqwest::header::HeaderMap,
115) -> std::collections::HashMap<String, String> {
116 let mut map = std::collections::HashMap::new();
117 for (key, value) in headers {
118 if let Ok(val) = value.to_str() {
119 map.insert(key.to_string(), val.to_string());
120 }
121 }
122 map
123}