Skip to main content

cedros_admin/
router.rs

1use std::sync::Arc;
2
3use axum::{
4    body::{to_bytes, Body},
5    extract::{Path, Query, State},
6    http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
7    response::{IntoResponse, Response},
8    routing::{any, get},
9    Json, Router,
10};
11use reqwest::Client;
12
13use crate::Config;
14
15#[derive(Clone)]
16pub struct ProductRouters {
17    pub login: Option<Router>,
18    pub pay: Option<Router>,
19}
20
21impl ProductRouters {
22    pub fn empty() -> Self {
23        Self {
24            login: None,
25            pay: None,
26        }
27    }
28}
29
30impl Default for ProductRouters {
31    fn default() -> Self {
32        Self::empty()
33    }
34}
35
36#[derive(Clone)]
37struct ProxyState {
38    client: Client,
39    base_url: String,
40}
41
42pub fn router(config: &Config) -> Router {
43    router_with_products(config, ProductRouters::default())
44}
45
46pub fn router_with_products(config: &Config, products: ProductRouters) -> Router {
47    let mut app = Router::new().route("/health", get(health));
48
49    if let Some(login_router) = products.login {
50        app = app.nest("/login", login_router);
51    } else if let Some(upstream) = config.login_upstream_url.clone() {
52        app = app.nest("/login", proxy_router(upstream));
53    }
54
55    if let Some(pay_router) = products.pay {
56        app = app.nest("/pay", pay_router);
57    } else if let Some(upstream) = config.pay_upstream_url.clone() {
58        app = app.nest("/pay", proxy_router(upstream));
59    }
60
61    app
62}
63
64fn proxy_router(base_url: String) -> Router {
65    let state = ProxyState {
66        client: Client::new(),
67        base_url,
68    };
69
70    Router::new()
71        .route("/", any(proxy_root))
72        .route("/{*path}", any(proxy_path))
73        .with_state(Arc::new(state))
74}
75
76async fn health() -> Json<serde_json::Value> {
77    Json(serde_json::json!({
78        "status": "ok",
79        "service": "cedros-admin-server"
80    }))
81}
82
83async fn proxy_root(
84    State(state): State<Arc<ProxyState>>,
85    method: Method,
86    headers: HeaderMap,
87    Query(query): Query<std::collections::HashMap<String, String>>,
88    body: Body,
89) -> Response {
90    proxy_request(state, method, headers, query, None, body).await
91}
92
93async fn proxy_path(
94    State(state): State<Arc<ProxyState>>,
95    method: Method,
96    headers: HeaderMap,
97    Path(path): Path<String>,
98    Query(query): Query<std::collections::HashMap<String, String>>,
99    body: Body,
100) -> Response {
101    proxy_request(state, method, headers, query, Some(path), body).await
102}
103
104async fn proxy_request(
105    state: Arc<ProxyState>,
106    method: Method,
107    headers: HeaderMap,
108    query: std::collections::HashMap<String, String>,
109    path: Option<String>,
110    body: Body,
111) -> Response {
112    let target_url = build_target_url(&state.base_url, path.as_deref(), &query);
113    let body_bytes = match to_bytes(body, usize::MAX).await {
114        Ok(bytes) => bytes,
115        Err(error) => {
116            return (
117                StatusCode::BAD_REQUEST,
118                format!("failed to read request body: {error}"),
119            )
120                .into_response()
121        }
122    };
123
124    let reqwest_method = match reqwest::Method::from_bytes(method.as_str().as_bytes()) {
125        Ok(method) => method,
126        Err(error) => {
127            return (
128                StatusCode::BAD_REQUEST,
129                format!("unsupported method: {error}"),
130            )
131                .into_response()
132        }
133    };
134
135    let mut request = state.client.request(reqwest_method, target_url).body(body_bytes.to_vec());
136    for (name, value) in forwardable_headers(&headers) {
137        request = request.header(name, value);
138    }
139
140    match request.send().await {
141        Ok(response) => proxy_response(response).await,
142        Err(error) => (
143            StatusCode::BAD_GATEWAY,
144            format!("upstream request failed: {error}"),
145        )
146            .into_response(),
147    }
148}
149
150fn build_target_url(
151    base_url: &str,
152    path: Option<&str>,
153    query: &std::collections::HashMap<String, String>,
154) -> String {
155    let trimmed_base = base_url.trim_end_matches('/');
156    let path = path.unwrap_or("").trim_start_matches('/');
157    let mut url = if path.is_empty() {
158        trimmed_base.to_string()
159    } else {
160        format!("{trimmed_base}/{path}")
161    };
162
163    if !query.is_empty() {
164        let mut pairs: Vec<_> = query.iter().collect();
165        pairs.sort_by(|(left, _), (right, _)| left.cmp(right));
166        let query_string = pairs
167            .into_iter()
168            .map(|(key, value)| format!("{}={}", urlencoding::encode(key), urlencoding::encode(value)))
169            .collect::<Vec<_>>()
170            .join("&");
171        url.push('?');
172        url.push_str(&query_string);
173    }
174
175    url
176}
177
178fn forwardable_headers(headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
179    const HOP_BY_HOP: &[&str] = &[
180        "connection",
181        "keep-alive",
182        "proxy-authenticate",
183        "proxy-authorization",
184        "te",
185        "trailer",
186        "transfer-encoding",
187        "upgrade",
188        "host",
189        "content-length",
190    ];
191
192    headers
193        .iter()
194        .filter_map(|(name, value)| {
195            if HOP_BY_HOP.iter().any(|blocked| name.as_str().eq_ignore_ascii_case(blocked)) {
196                return None;
197            }
198
199            Some((name.clone(), value.clone()))
200        })
201        .collect()
202}
203
204async fn proxy_response(response: reqwest::Response) -> Response {
205    let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
206    let mut builder = Response::builder().status(status);
207
208    for (name, value) in response.headers().iter() {
209        if name.as_str().eq_ignore_ascii_case("content-length")
210            || name.as_str().eq_ignore_ascii_case("transfer-encoding")
211        {
212            continue;
213        }
214
215        builder = builder.header(name, value);
216    }
217
218    match response.bytes().await {
219        Ok(bytes) => builder
220            .body(Body::from(bytes))
221            .unwrap_or_else(|_| (StatusCode::BAD_GATEWAY, "failed to build proxy response").into_response()),
222        Err(error) => (
223            StatusCode::BAD_GATEWAY,
224            format!("failed to read upstream response: {error}"),
225        )
226            .into_response(),
227    }
228}
229
#[cfg(test)]
mod tests {
    use axum::{
        extract::RawQuery,
        routing::{get, post},
        Router,
    };
    use reqwest::StatusCode;
    use tokio::net::TcpListener;

    use super::{router, router_with_products, Config, ProductRouters};

    #[tokio::test]
    async fn health_endpoint_is_available() {
        let config = Config::default();
        let app = router(&config);
        let address = spawn(app).await;

        let response = reqwest::get(format!("http://{address}/health")).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.json::<serde_json::Value>().await.unwrap()["status"], "ok");
    }

    #[tokio::test]
    async fn embedded_product_router_is_mounted_under_login_prefix() {
        let config = Config::default();
        let products = ProductRouters {
            login: Some(Router::new().route("/admin/ping", get(|| async { "pong" }))),
            pay: None,
        };
        let app = router_with_products(&config, products);
        let address = spawn(app).await;

        let response = reqwest::get(format!("http://{address}/login/admin/ping"))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await.unwrap(), "pong");
    }

    #[tokio::test]
    async fn standalone_proxy_forwards_requests_to_login_upstream() {
        let upstream = Router::new().route("/admin/ping", get(|| async { "upstream-pong" }));
        let upstream_address = spawn(upstream).await;

        let config = Config {
            login_upstream_url: Some(format!("http://{upstream_address}")),
            ..Config::default()
        };
        let app = router(&config);
        let address = spawn(app).await;

        let response = reqwest::get(format!("http://{address}/login/admin/ping"))
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await.unwrap(), "upstream-pong");
    }

    /// The `/pay` proxy must forward the HTTP method, the request body and the query string.
    #[tokio::test]
    async fn standalone_proxy_forwards_method_body_and_query_to_pay_upstream() {
        let upstream = Router::new().route(
            "/charge",
            post(|RawQuery(query): RawQuery, body: String| async move {
                format!("{}|{body}", query.unwrap_or_default())
            }),
        );
        let upstream_address = spawn(upstream).await;

        let config = Config {
            pay_upstream_url: Some(format!("http://{upstream_address}")),
            ..Config::default()
        };
        let app = router(&config);
        let address = spawn(app).await;

        let response = reqwest::Client::new()
            .post(format!("http://{address}/pay/charge?b=2&a=1"))
            .body("payload")
            .send()
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        // The proxy re-serialises the query sorted by key.
        assert_eq!(response.text().await.unwrap(), "a=1&b=2|payload");
    }

    async fn spawn(app: Router) -> std::net::SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        address
    }
}