1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5 Router,
6 extract::State,
7 http::{
8 HeaderValue, Method, Request, StatusCode,
9 header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE},
10 },
11 middleware::{self, Next},
12 response::{IntoResponse, Response},
13};
14use tower_http::cors::CorsLayer;
15use tower_http::trace::TraceLayer;
16use relay_core_runtime::CoreState;
17use tracing::info;
18
19use crate::routes;
20
21#[derive(Debug, Clone)]
23pub struct HttpApiConfig {
24 pub addr: SocketAddr,
26 pub bearer_token: Option<String>,
27 pub allowed_origins: Vec<HeaderValue>,
28}
29
30impl HttpApiConfig {
31 pub fn new(port: u16) -> Self {
32 Self {
33 addr: SocketAddr::from(([127, 0, 0, 1], port)),
34 bearer_token: None,
35 allowed_origins: Vec::new(),
36 }
37 }
38
39 pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
40 self.bearer_token = Some(token.into());
41 self
42 }
43
44 pub fn with_allowed_origins(mut self, origins: impl IntoIterator<Item = HeaderValue>) -> Self {
45 self.allowed_origins = origins.into_iter().collect();
46 self
47 }
48}
49
50pub struct HttpApiServer {
52 config: HttpApiConfig,
53 state: Arc<CoreState>,
54}
55
56impl HttpApiServer {
57 pub fn new(config: HttpApiConfig, state: Arc<CoreState>) -> Self {
58 Self { config, state }
59 }
60
61 pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
63 let app = build_router(self.state, Arc::new(self.config.clone()));
64 let listener = tokio::net::TcpListener::bind(self.config.addr).await?;
65 info!("relay-core HTTP API listening on {}", self.config.addr);
66 axum::serve(listener, app).await?;
67 Ok(())
68 }
69}
70
71fn build_router(state: Arc<CoreState>, config: Arc<HttpApiConfig>) -> Router {
72 let router = Router::new()
73 .merge(routes::version::router())
74 .merge(routes::metrics::router(state.clone()))
75 .merge(routes::flows::router(state.clone()))
76 .merge(routes::rules::router(state.clone()))
77 .merge(routes::intercepts::router(state.clone()))
78 .merge(routes::events::router(state))
79 .route_layer(middleware::from_fn_with_state(config.clone(), require_bearer_token))
80 .layer(TraceLayer::new_for_http());
81
82 if config.allowed_origins.is_empty() {
83 router
84 } else {
85 router.layer(
86 CorsLayer::new()
87 .allow_origin(config.allowed_origins.clone())
88 .allow_methods([
89 Method::GET,
90 Method::POST,
91 Method::PUT,
92 Method::DELETE,
93 Method::OPTIONS,
94 ])
95 .allow_headers([AUTHORIZATION, CONTENT_TYPE, ACCEPT, ORIGIN]),
96 )
97 }
98}
99
100async fn require_bearer_token(
101 State(config): State<Arc<HttpApiConfig>>,
102 request: Request<axum::body::Body>,
103 next: Next,
104) -> Response {
105 if request.method() == Method::OPTIONS {
106 return next.run(request).await;
107 }
108
109 let Some(expected_token) = config.bearer_token.as_deref() else {
110 return next.run(request).await;
111 };
112
113 let expected_value = format!("Bearer {}", expected_token);
114 let is_authorized = request
115 .headers()
116 .get(AUTHORIZATION)
117 .and_then(|value| value.to_str().ok())
118 .map(|value| value == expected_value)
119 .unwrap_or(false);
120
121 if is_authorized {
122 return next.run(request).await;
123 }
124
125 (
126 StatusCode::UNAUTHORIZED,
127 [
128 (WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")),
129 (CONTENT_TYPE, HeaderValue::from_static("application/json")),
130 ],
131 serde_json::json!({
132 "error": "missing_or_invalid_bearer_token"
133 })
134 .to_string(),
135 )
136 .into_response()
137}
138
#[cfg(test)]
mod tests {
    use super::{HttpApiConfig, build_router};
    use axum::{
        Router,
        body::{Body, Bytes, to_bytes},
        http::{
            HeaderValue, Method, Request, StatusCode,
            header::{ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE},
        },
        response::Response,
    };
    use relay_core_api::flow::Flow;
    use relay_core_api::policy::ProxyPolicy;
    use relay_core_runtime::{CoreState, audit::AuditActor};
    use serde_json::{Value, json};
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};
    use tower::ServiceExt;

    /// Port passed to every test config; `build_router` never binds it.
    const TEST_PORT: u16 = 8082;

    /// Builds an HTTP flow fixture through its JSON representation.
    ///
    /// `ts` is a millisecond timestamp. The flow id and the minute/second/millis
    /// of the start time are derived from it; the date and hour are fixed.
    fn sample_http_flow(host: &str, path: &str, method: &str, status: u16, ts: i64) -> Flow {
        let flow_id = format!(
            "00000000-0000-0000-0000-{:012}",
            (ts as u64) % 1_000_000_000_000
        );
        let minute = (ts / 60_000) % 60;
        let second = (ts / 1_000) % 60;
        let millis = (ts % 1_000).abs();
        let start_rfc3339 = format!("2023-11-14T22:{minute:02}:{second:02}.{millis:03}Z");
        serde_json::from_value(json!({
            "id": flow_id,
            "start_time": start_rfc3339,
            "end_time": start_rfc3339,
            "network": {
                "client_ip": "127.0.0.1",
                "client_port": 12000,
                "server_ip": "127.0.0.1",
                "server_port": 8080,
                "protocol": "TCP",
                "tls": false,
                "tls_version": null,
                "sni": null
            },
            "layer": {
                "type": "Http",
                "data": {
                    "request": {
                        "method": method,
                        "url": format!("http://{}{}", host, path),
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "query": [],
                        "body": null
                    },
                    "response": {
                        "status": status,
                        "status_text": "OK",
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "body": null,
                        "timing": {
                            "time_to_first_byte": null,
                            "time_to_last_byte": null
                        }
                    },
                    "error": null
                }
            },
            "tags": []
        }))
        .expect("flow json should deserialize")
    }

    /// Creates a fresh runtime state with no persistence argument.
    async fn new_state() -> Arc<CoreState> {
        Arc::new(CoreState::new(None).await)
    }

    /// Builds the API router over `state` with `config`.
    fn router_with(state: Arc<CoreState>, config: HttpApiConfig) -> Router {
        build_router(state, Arc::new(config))
    }

    /// Builds an empty-bodied `GET` request for `uri` with the given headers.
    fn get(uri: &str, headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri(uri).method(Method::GET);
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).expect("request should build")
    }

    /// Drives one request through `app`.
    async fn send(app: Router, request: Request<Body>) -> Response {
        app.oneshot(request).await.expect("request should succeed")
    }

    /// Collects the full response body.
    async fn read_body(response: Response) -> Bytes {
        to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable")
    }

    /// Collects the response body and parses it as JSON.
    async fn read_json(response: Response) -> Value {
        serde_json::from_slice(&read_body(response).await).expect("body should be valid json")
    }

    #[tokio::test]
    async fn status_endpoint_is_available_without_auth_by_default() {
        let app = router_with(new_state().await, HttpApiConfig::new(TEST_PORT));
        let response = send(app, get("/api/v1/status", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["phase"], "created");
        assert_eq!(json["running"], false);
        assert!(json.get("started_at_ms").is_none());
    }

    #[tokio::test]
    async fn intercepts_endpoint_uses_shared_snapshot_shape() {
        let app = router_with(new_state().await, HttpApiConfig::new(TEST_PORT));
        let response = send(app, get("/api/v1/intercepts", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["pending_count"], 0);
        assert_eq!(json["ws_pending_count"], 0);
    }

    #[tokio::test]
    async fn audit_endpoint_uses_shared_snapshot_shape() {
        let state = new_state().await;
        state.update_policy_from(
            AuditActor::Http,
            "policy".to_string(),
            ProxyPolicy {
                transparent_enabled: true,
                ..Default::default()
            },
        );
        // Activity from a different actor (Probe), which the `actor=http`
        // filter below is expected to exclude.
        let _ = state
            .resolve_intercept_with_modifications_from(
                AuditActor::Probe,
                "missing-flow:request".to_string(),
                "drop",
                None,
            )
            .await;

        let app = router_with(state, HttpApiConfig::new(TEST_PORT));
        let response = send(
            app,
            get(
                "/api/v1/audit?actor=http&kind=policy_updated&outcome=success&limit=1",
                &[],
            ),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert!(json["events"].is_array());
        assert_eq!(json["events"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["events"][0]["actor"], "http");
        assert_eq!(json["events"][0]["kind"], "policy_updated");
        assert_eq!(json["events"][0]["outcome"], "success");
    }

    #[tokio::test]
    async fn prometheus_metrics_endpoint_returns_text_format() {
        let app = router_with(new_state().await, HttpApiConfig::new(TEST_PORT));
        let response = send(app, get("/api/v1/metrics/prometheus", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();
        assert_eq!(content_type, "text/plain; version=0.0.4; charset=utf-8");
        let body = read_body(response).await;
        let text = String::from_utf8(body.to_vec()).expect("prometheus body should be utf-8");
        assert!(text.contains("relay_core_flows_total "));
        assert!(text.contains("relay_core_audit_events_total "));
    }

    #[tokio::test]
    async fn flows_endpoint_returns_pagination_metadata() {
        let state = new_state().await;
        let flow_a = sample_http_flow("api.example.com", "/a", "GET", 200, 1_700_000_001_000);
        let flow_b = sample_http_flow("api.example.com", "/b", "POST", 201, 1_700_000_002_000);
        let flow_c = sample_http_flow("api.example.com", "/c", "GET", 500, 1_700_000_003_000);
        let flow_b_id = flow_b.id.to_string();
        for flow in [flow_a, flow_b, flow_c] {
            state.upsert_flow(Box::new(flow));
        }
        // NOTE(review): upserts appear to be applied asynchronously; this fixed
        // delay gives them time to land and may be flaky on a loaded machine.
        sleep(Duration::from_millis(30)).await;

        let app = router_with(state, HttpApiConfig::new(TEST_PORT));
        let response = send(
            app,
            get("/api/v1/flows?host=api.example.com&limit=1&offset=1", &[]),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["returned"], 1);
        assert_eq!(json["limit"], 1);
        assert_eq!(json["offset"], 1);
        assert_eq!(json["items"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["items"][0]["id"], flow_b_id);
    }

    #[tokio::test]
    async fn status_endpoint_requires_bearer_token_when_configured() {
        let app = router_with(
            new_state().await,
            HttpApiConfig::new(TEST_PORT).with_bearer_token("secret-token"),
        );

        let unauthorized = send(app.clone(), get("/api/v1/status", &[])).await;
        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);

        let authorized = send(
            app,
            get(
                "/api/v1/status",
                &[("Authorization", "Bearer secret-token")],
            ),
        )
        .await;
        assert_eq!(authorized.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cors_is_not_open_by_default() {
        let app = router_with(new_state().await, HttpApiConfig::new(TEST_PORT));
        let response = send(
            app,
            get("/api/v1/status", &[("Origin", "https://example.com")]),
        )
        .await;

        assert!(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[tokio::test]
    async fn cors_allows_explicit_origin_only() {
        let app = router_with(
            new_state().await,
            HttpApiConfig::new(TEST_PORT)
                .with_allowed_origins([HeaderValue::from_static("https://allowed.example")]),
        );

        let allowed = send(
            app.clone(),
            get("/api/v1/status", &[("Origin", "https://allowed.example")]),
        )
        .await;
        assert_eq!(
            allowed.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );

        let denied = send(
            app,
            get("/api/v1/status", &[("Origin", "https://denied.example")]),
        )
        .await;
        assert!(denied.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }
}