phantom_frame/
proxy.rs

1use crate::cache::{CacheStore, CachedResponse};
2use crate::path_matcher::should_cache_path;
3use crate::CreateProxyConfig;
4use axum::{
5    body::Body,
6    extract::State,
7    http::{HeaderMap, HeaderName, HeaderValue, Request, Response, StatusCode},
8};
9use std::sync::Arc;
10
11#[derive(Clone)]
12pub struct ProxyState {
13    cache: CacheStore,
14    config: CreateProxyConfig,
15}
16
17impl ProxyState {
18    pub fn new(cache: CacheStore, config: CreateProxyConfig) -> Self {
19        Self { cache, config }
20    }
21}
22
23/// Main proxy handler that serves prerendered content from cache
24/// or fetches from backend if not cached
25pub async fn proxy_handler(
26    State(state): State<Arc<ProxyState>>,
27    req: Request<Body>,
28) -> Result<Response<Body>, StatusCode> {
29    let method_str = req.method().as_str();
30    let path = req.uri().path();
31    let query = req.uri().query().unwrap_or("");
32    let headers = req.headers();
33    
34    // Check if this path should be cached based on include/exclude patterns
35    let should_cache = should_cache_path(
36        method_str,
37        path,
38        &state.config.include_paths,
39        &state.config.exclude_paths,
40    );
41    
42    // Generate cache key using the configured function
43    let req_info = crate::RequestInfo {
44        method: method_str,
45        path,
46        query,
47        headers,
48    };
49    let cache_key = (state.config.cache_key_fn)(&req_info);
50
51    // Try to get from cache first (only if caching is enabled for this path)
52    if should_cache {
53        if let Some(cached) = state.cache.get(&cache_key).await {
54            tracing::info!("Cache hit for: {} {}", method_str, cache_key);
55            return Ok(build_response_from_cache(cached));
56        }
57        tracing::info!("Cache miss for: {} {}, fetching from backend", method_str, cache_key);
58    } else {
59        tracing::info!("{} {} not cacheable (filtered), proxying directly", method_str, path);
60    }
61
62    // Fetch from backend (proxy_url)
63    let target_url = format!("{}{}", state.config.proxy_url, req.uri());
64    let client = reqwest::Client::new();
65
66    let method = req.method().clone();
67    let headers = req.headers().clone();
68
69    let response = match client
70        .request(method, &target_url)
71        .headers(convert_headers(&headers))
72        .send()
73        .await
74    {
75        Ok(resp) => resp,
76        Err(e) => {
77            tracing::error!("Failed to fetch from backend: {}", e);
78            return Err(StatusCode::BAD_GATEWAY);
79        }
80    };
81
82    // Cache the response (only if caching is enabled for this path)
83    let status = response.status().as_u16();
84    let response_headers = response.headers().clone();
85    let body_bytes = match response.bytes().await {
86        Ok(bytes) => bytes.to_vec(),
87        Err(e) => {
88            tracing::error!("Failed to read response body: {}", e);
89            return Err(StatusCode::BAD_GATEWAY);
90        }
91    };
92
93    let cached_response = CachedResponse {
94        body: body_bytes.clone(),
95        headers: convert_headers_to_map(&response_headers),
96        status,
97    };
98
99    if should_cache {
100        state
101            .cache
102            .set(cache_key.clone(), cached_response.clone())
103            .await;
104        tracing::info!("Cached response for: {} {}", method_str, cache_key);
105    }
106
107    Ok(build_response_from_cache(cached_response))
108}
109
110fn build_response_from_cache(cached: CachedResponse) -> Response<Body> {
111    let mut response = Response::builder().status(cached.status);
112
113    // Add headers
114    let headers = response.headers_mut().unwrap();
115    for (key, value) in cached.headers {
116        if let Ok(header_name) = key.parse::<HeaderName>() {
117            if let Ok(header_value) = HeaderValue::from_str(&value) {
118                headers.insert(header_name, header_value);
119            }
120        }
121    }
122
123    response.body(Body::from(cached.body)).unwrap()
124}
125
126fn convert_headers(headers: &HeaderMap) -> reqwest::header::HeaderMap {
127    let mut req_headers = reqwest::header::HeaderMap::new();
128    for (key, value) in headers {
129        if let Ok(val) = value.to_str() {
130            if let Ok(header_value) = reqwest::header::HeaderValue::from_str(val) {
131                req_headers.insert(key.clone(), header_value);
132            }
133        }
134    }
135    req_headers
136}
137
138fn convert_headers_to_map(
139    headers: &reqwest::header::HeaderMap,
140) -> std::collections::HashMap<String, String> {
141    let mut map = std::collections::HashMap::new();
142    for (key, value) in headers {
143        if let Ok(val) = value.to_str() {
144            map.insert(key.to_string(), val.to_string());
145        }
146    }
147    map
148}