phantom_frame/
proxy.rs

1use crate::cache::{CacheStore, CachedResponse};
2use axum::{
3    body::Body,
4    extract::State,
5    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
6};
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct ProxyState {
11    cache: CacheStore,
12    proxy_url: String,
13}
14
15impl ProxyState {
16    pub fn new(cache: CacheStore, proxy_url: String) -> Self {
17        Self { cache, proxy_url }
18    }
19}
20
21/// Main proxy handler that serves prerendered content from cache
22/// or fetches from backend if not cached
23pub async fn proxy_handler(
24    State(state): State<Arc<ProxyState>>,
25    req: Request<Body>,
26) -> Result<Response<Body>, StatusCode> {
27    let path = req.uri().path();
28    let query = req.uri().query().unwrap_or("");
29    let cache_key = format!("{}?{}", path, query);
30
31    // Try to get from cache first
32    if let Some(cached) = state.cache.get(&cache_key).await {
33        tracing::info!("Cache hit for: {}", cache_key);
34        return Ok(build_response_from_cache(cached));
35    }
36
37    tracing::info!("Cache miss for: {}, fetching from backend", cache_key);
38
39    // Fetch from backend (proxy_url)
40    let target_url = format!("{}{}", state.proxy_url, req.uri());
41    let client = reqwest::Client::new();
42
43    let method = req.method().clone();
44    let headers = req.headers().clone();
45
46    let response = match client
47        .request(method, &target_url)
48        .headers(convert_headers(&headers))
49        .send()
50        .await
51    {
52        Ok(resp) => resp,
53        Err(e) => {
54            tracing::error!("Failed to fetch from backend: {}", e);
55            return Err(StatusCode::BAD_GATEWAY);
56        }
57    };
58
59    // Cache the response
60    let status = response.status().as_u16();
61    let response_headers = response.headers().clone();
62    let body_bytes = match response.bytes().await {
63        Ok(bytes) => bytes.to_vec(),
64        Err(e) => {
65            tracing::error!("Failed to read response body: {}", e);
66            return Err(StatusCode::BAD_GATEWAY);
67        }
68    };
69
70    let cached_response = CachedResponse {
71        body: body_bytes.clone(),
72        headers: convert_headers_to_map(&response_headers),
73        status,
74    };
75
76    state
77        .cache
78        .set(cache_key.clone(), cached_response.clone())
79        .await;
80    tracing::info!("Cached response for: {}", cache_key);
81
82    Ok(build_response_from_cache(cached_response))
83}
84
85fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
86    let mut response = Response::builder().status(cached.status);
87
88    // Add headers
89    let headers = response.headers_mut().unwrap();
90    for (key, value) in cached.headers {
91        if let Ok(header_name) = key.parse::<HeaderName>() {
92            if let Ok(header_value) = HeaderValue::from_str(&value) {
93                headers.insert(header_name, header_value);
94            }
95        }
96    }
97
98    response.body(Body::from(cached.body)).unwrap()
99}
100
101fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
102    let mut req_headers = reqwest::header::HeaderMap::new();
103    for (key, value) in headers {
104        if let Ok(val) = value.to_str() {
105            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
106                req_headers.insert(key.clone(), header_value);
107            }
108        }
109    }
110    req_headers
111}
112
113fn convert_headers_to_map(
114    headers: &reqwest::header::HeaderMap,
115) -> std::collections::HashMap<String, String> {
116    let mut map = std::collections::HashMap::new();
117    for (key, value) in headers {
118        if let Ok(val) = value.to_str() {
119            map.insert(key.to_string(), val.to_string());
120        }
121    }
122    map
123}