1use std::sync::Arc;
2
3use axum::{
4 body::{to_bytes, Body},
5 extract::{Path, Query, State},
6 http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
7 response::{IntoResponse, Response},
8 routing::{any, get},
9 Json, Router,
10};
11use reqwest::Client;
12
13use crate::Config;
14
15#[derive(Clone)]
16pub struct ProductRouters {
17 pub login: Option<Router>,
18 pub pay: Option<Router>,
19}
20
21impl ProductRouters {
22 pub fn empty() -> Self {
23 Self {
24 login: None,
25 pay: None,
26 }
27 }
28}
29
30impl Default for ProductRouters {
31 fn default() -> Self {
32 Self::empty()
33 }
34}
35
36#[derive(Clone)]
37struct ProxyState {
38 client: Client,
39 base_url: String,
40}
41
42pub fn router(config: &Config) -> Router {
43 router_with_products(config, ProductRouters::default())
44}
45
46pub fn router_with_products(config: &Config, products: ProductRouters) -> Router {
47 let mut app = Router::new().route("/health", get(health));
48
49 if let Some(login_router) = products.login {
50 app = app.nest("/login", login_router);
51 } else if let Some(upstream) = config.login_upstream_url.clone() {
52 app = app.nest("/login", proxy_router(upstream));
53 }
54
55 if let Some(pay_router) = products.pay {
56 app = app.nest("/pay", pay_router);
57 } else if let Some(upstream) = config.pay_upstream_url.clone() {
58 app = app.nest("/pay", proxy_router(upstream));
59 }
60
61 app
62}
63
64fn proxy_router(base_url: String) -> Router {
65 let state = ProxyState {
66 client: Client::new(),
67 base_url,
68 };
69
70 Router::new()
71 .route("/", any(proxy_root))
72 .route("/{*path}", any(proxy_path))
73 .with_state(Arc::new(state))
74}
75
76async fn health() -> Json<serde_json::Value> {
77 Json(serde_json::json!({
78 "status": "ok",
79 "service": "cedros-admin-server"
80 }))
81}
82
83async fn proxy_root(
84 State(state): State<Arc<ProxyState>>,
85 method: Method,
86 headers: HeaderMap,
87 Query(query): Query<std::collections::HashMap<String, String>>,
88 body: Body,
89) -> Response {
90 proxy_request(state, method, headers, query, None, body).await
91}
92
93async fn proxy_path(
94 State(state): State<Arc<ProxyState>>,
95 method: Method,
96 headers: HeaderMap,
97 Path(path): Path<String>,
98 Query(query): Query<std::collections::HashMap<String, String>>,
99 body: Body,
100) -> Response {
101 proxy_request(state, method, headers, query, Some(path), body).await
102}
103
104async fn proxy_request(
105 state: Arc<ProxyState>,
106 method: Method,
107 headers: HeaderMap,
108 query: std::collections::HashMap<String, String>,
109 path: Option<String>,
110 body: Body,
111) -> Response {
112 let target_url = build_target_url(&state.base_url, path.as_deref(), &query);
113 let body_bytes = match to_bytes(body, usize::MAX).await {
114 Ok(bytes) => bytes,
115 Err(error) => {
116 return (
117 StatusCode::BAD_REQUEST,
118 format!("failed to read request body: {error}"),
119 )
120 .into_response()
121 }
122 };
123
124 let reqwest_method = match reqwest::Method::from_bytes(method.as_str().as_bytes()) {
125 Ok(method) => method,
126 Err(error) => {
127 return (
128 StatusCode::BAD_REQUEST,
129 format!("unsupported method: {error}"),
130 )
131 .into_response()
132 }
133 };
134
135 let mut request = state.client.request(reqwest_method, target_url).body(body_bytes.to_vec());
136 for (name, value) in forwardable_headers(&headers) {
137 request = request.header(name, value);
138 }
139
140 match request.send().await {
141 Ok(response) => proxy_response(response).await,
142 Err(error) => (
143 StatusCode::BAD_GATEWAY,
144 format!("upstream request failed: {error}"),
145 )
146 .into_response(),
147 }
148}
149
150fn build_target_url(
151 base_url: &str,
152 path: Option<&str>,
153 query: &std::collections::HashMap<String, String>,
154) -> String {
155 let trimmed_base = base_url.trim_end_matches('/');
156 let path = path.unwrap_or("").trim_start_matches('/');
157 let mut url = if path.is_empty() {
158 trimmed_base.to_string()
159 } else {
160 format!("{trimmed_base}/{path}")
161 };
162
163 if !query.is_empty() {
164 let mut pairs: Vec<_> = query.iter().collect();
165 pairs.sort_by(|(left, _), (right, _)| left.cmp(right));
166 let query_string = pairs
167 .into_iter()
168 .map(|(key, value)| format!("{}={}", urlencoding::encode(key), urlencoding::encode(value)))
169 .collect::<Vec<_>>()
170 .join("&");
171 url.push('?');
172 url.push_str(&query_string);
173 }
174
175 url
176}
177
178fn forwardable_headers(headers: &HeaderMap) -> Vec<(HeaderName, HeaderValue)> {
179 const HOP_BY_HOP: &[&str] = &[
180 "connection",
181 "keep-alive",
182 "proxy-authenticate",
183 "proxy-authorization",
184 "te",
185 "trailer",
186 "transfer-encoding",
187 "upgrade",
188 "host",
189 "content-length",
190 ];
191
192 headers
193 .iter()
194 .filter_map(|(name, value)| {
195 if HOP_BY_HOP.iter().any(|blocked| name.as_str().eq_ignore_ascii_case(blocked)) {
196 return None;
197 }
198
199 Some((name.clone(), value.clone()))
200 })
201 .collect()
202}
203
204async fn proxy_response(response: reqwest::Response) -> Response {
205 let status = StatusCode::from_u16(response.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
206 let mut builder = Response::builder().status(status);
207
208 for (name, value) in response.headers().iter() {
209 if name.as_str().eq_ignore_ascii_case("content-length")
210 || name.as_str().eq_ignore_ascii_case("transfer-encoding")
211 {
212 continue;
213 }
214
215 builder = builder.header(name, value);
216 }
217
218 match response.bytes().await {
219 Ok(bytes) => builder
220 .body(Body::from(bytes))
221 .unwrap_or_else(|_| (StatusCode::BAD_GATEWAY, "failed to build proxy response").into_response()),
222 Err(error) => (
223 StatusCode::BAD_GATEWAY,
224 format!("failed to read upstream response: {error}"),
225 )
226 .into_response(),
227 }
228}
229
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use axum::{routing::get, Router};
    use reqwest::StatusCode;
    use tokio::net::TcpListener;

    use super::{router, router_with_products, Config, ProductRouters};

    #[tokio::test]
    async fn health_endpoint_is_available() {
        let address = spawn(router(&Config::default())).await;

        let response = fetch(address, "/health").await;

        assert_eq!(response.status(), StatusCode::OK);
        let payload: serde_json::Value = response.json().await.unwrap();
        assert_eq!(payload["status"], "ok");
    }

    #[tokio::test]
    async fn embedded_product_router_is_mounted_under_login_prefix() {
        let embedded_login = Router::new().route("/admin/ping", get(|| async { "pong" }));
        let products = ProductRouters {
            login: Some(embedded_login),
            pay: None,
        };
        let address = spawn(router_with_products(&Config::default(), products)).await;

        let response = fetch(address, "/login/admin/ping").await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await.unwrap(), "pong");
    }

    #[tokio::test]
    async fn standalone_proxy_forwards_requests_to_login_upstream() {
        let upstream = Router::new().route("/admin/ping", get(|| async { "upstream-pong" }));
        let upstream_address = spawn(upstream).await;
        let config = Config {
            login_upstream_url: Some(format!("http://{upstream_address}")),
            ..Config::default()
        };
        let address = spawn(router(&config)).await;

        let response = fetch(address, "/login/admin/ping").await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await.unwrap(), "upstream-pong");
    }

    /// Serves `app` on an ephemeral loopback port in a background task; returns the bound address.
    async fn spawn(app: Router) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        bound
    }

    /// Sends a GET request for `path` to the server listening at `address`.
    async fn fetch(address: SocketAddr, path: &str) -> reqwest::Response {
        reqwest::get(format!("http://{address}{path}")).await.unwrap()
    }
}