Skip to main content

relay_core_http/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    Router,
6    extract::State,
7    http::{
8        HeaderValue, Method, Request, StatusCode,
9        header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE},
10    },
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13};
14use tower_http::cors::CorsLayer;
15use tower_http::trace::TraceLayer;
16use relay_core_runtime::CoreState;
17use tracing::info;
18
19use crate::routes;
20
21/// Configuration for the HTTP API server.
22#[derive(Debug, Clone)]
23pub struct HttpApiConfig {
24    /// Address to bind (e.g. `127.0.0.1:8082`)
25    pub addr: SocketAddr,
26    pub bearer_token: Option<String>,
27    pub allowed_origins: Vec<HeaderValue>,
28}
29
30impl HttpApiConfig {
31    pub fn new(port: u16) -> Self {
32        Self {
33            addr: SocketAddr::from(([127, 0, 0, 1], port)),
34            bearer_token: None,
35            allowed_origins: Vec::new(),
36        }
37    }
38
39    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
40        self.bearer_token = Some(token.into());
41        self
42    }
43
44    pub fn with_allowed_origins(mut self, origins: impl IntoIterator<Item = HeaderValue>) -> Self {
45        self.allowed_origins = origins.into_iter().collect();
46        self
47    }
48}
49
50/// HTTP API server handle.
51pub struct HttpApiServer {
52    config: HttpApiConfig,
53    state: Arc<CoreState>,
54}
55
56impl HttpApiServer {
57    pub fn new(config: HttpApiConfig, state: Arc<CoreState>) -> Self {
58        Self { config, state }
59    }
60
61    /// Start the server; resolves when the server exits or an error occurs.
62    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
63        let app = build_router(self.state, Arc::new(self.config.clone()));
64        let listener = tokio::net::TcpListener::bind(self.config.addr).await?;
65        info!("relay-core HTTP API listening on {}", self.config.addr);
66        axum::serve(listener, app).await?;
67        Ok(())
68    }
69}
70
71fn build_router(state: Arc<CoreState>, config: Arc<HttpApiConfig>) -> Router {
72    let router = Router::new()
73        .merge(routes::version::router())
74        .merge(routes::metrics::router(state.clone()))
75        .merge(routes::flows::router(state.clone()))
76        .merge(routes::rules::router(state.clone()))
77        .merge(routes::intercepts::router(state.clone()))
78        .merge(routes::events::router(state))
79        .route_layer(middleware::from_fn_with_state(config.clone(), require_bearer_token))
80        .layer(TraceLayer::new_for_http());
81
82    if config.allowed_origins.is_empty() {
83        router
84    } else {
85        router.layer(
86            CorsLayer::new()
87                .allow_origin(config.allowed_origins.clone())
88                .allow_methods([
89                    Method::GET,
90                    Method::POST,
91                    Method::PUT,
92                    Method::DELETE,
93                    Method::OPTIONS,
94                ])
95                .allow_headers([AUTHORIZATION, CONTENT_TYPE, ACCEPT, ORIGIN]),
96        )
97    }
98}
99
100async fn require_bearer_token(
101    State(config): State<Arc<HttpApiConfig>>,
102    request: Request<axum::body::Body>,
103    next: Next,
104) -> Response {
105    if request.method() == Method::OPTIONS {
106        return next.run(request).await;
107    }
108
109    let Some(expected_token) = config.bearer_token.as_deref() else {
110        return next.run(request).await;
111    };
112
113    let expected_value = format!("Bearer {}", expected_token);
114    let is_authorized = request
115        .headers()
116        .get(AUTHORIZATION)
117        .and_then(|value| value.to_str().ok())
118        .map(|value| value == expected_value)
119        .unwrap_or(false);
120
121    if is_authorized {
122        return next.run(request).await;
123    }
124
125    (
126        StatusCode::UNAUTHORIZED,
127        [
128            (WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")),
129            (CONTENT_TYPE, HeaderValue::from_static("application/json")),
130        ],
131        serde_json::json!({
132            "error": "missing_or_invalid_bearer_token"
133        })
134        .to_string(),
135    )
136        .into_response()
137}
138
#[cfg(test)]
mod tests {
    use super::{HttpApiConfig, build_router};
    use axum::{
        body::{Body, to_bytes},
        http::{HeaderValue, Method, Request, StatusCode, header::ACCESS_CONTROL_ALLOW_ORIGIN},
    };
    use relay_core_api::flow::Flow;
    use relay_core_api::policy::ProxyPolicy;
    use relay_core_runtime::{CoreState, audit::AuditActor};
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};
    use tower::ServiceExt;
    use serde_json::json;

    /// Creates a fresh core state, as every test starts from one.
    async fn fresh_state() -> Arc<CoreState> {
        Arc::new(CoreState::new(None).await)
    }

    /// Builds the router under test from `state` and `config`.
    fn router_with(state: Arc<CoreState>, config: HttpApiConfig) -> axum::Router {
        build_router(state, Arc::new(config))
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get(uri: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .method(Method::GET)
            .body(Body::empty())
            .expect("request should build")
    }

    /// Builds a body-less `GET` request for `uri` carrying a single header.
    fn get_with_header(uri: &str, name: &str, value: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .method(Method::GET)
            .header(name, value)
            .body(Body::empty())
            .expect("request should build")
    }

    /// Drives one request through the router and returns the response.
    async fn send(app: axum::Router, request: Request<Body>) -> axum::response::Response {
        app.oneshot(request).await.expect("request should succeed")
    }

    /// Collects the whole response body into memory.
    async fn read_body(response: axum::response::Response) -> Vec<u8> {
        to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable")
            .to_vec()
    }

    /// Collects the response body and parses it as JSON.
    async fn read_json(response: axum::response::Response) -> serde_json::Value {
        let raw = read_body(response).await;
        serde_json::from_slice(&raw).expect("body should be valid json")
    }

    /// Builds an HTTP flow fixture by deserializing a JSON document.
    ///
    /// The flow id and the minute/second/millisecond parts of the timestamps
    /// are derived from `ts`, so distinct `ts` values yield distinct flows.
    fn sample_http_flow(host: &str, path: &str, method: &str, status: u16, ts: i64) -> Flow {
        let id_suffix = (ts as u64) % 1_000_000_000_000;
        let flow_id = format!("00000000-0000-0000-0000-{:012}", id_suffix);

        let minute = (ts / 60_000) % 60;
        let second = (ts / 1_000) % 60;
        let millis = (ts % 1_000).abs();
        let timestamp = format!("2023-11-14T22:{:02}:{:02}.{:03}Z", minute, second, millis);

        let url = format!("http://{}{}", host, path);

        let document = json!({
            "id": flow_id,
            "start_time": timestamp,
            "end_time": timestamp,
            "network": {
                "client_ip": "127.0.0.1",
                "client_port": 12000,
                "server_ip": "127.0.0.1",
                "server_port": 8080,
                "protocol": "TCP",
                "tls": false,
                "tls_version": null,
                "sni": null
            },
            "layer": {
                "type": "Http",
                "data": {
                    "request": {
                        "method": method,
                        "url": url,
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "query": [],
                        "body": null
                    },
                    "response": {
                        "status": status,
                        "status_text": "OK",
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "body": null,
                        "timing": {
                            "time_to_first_byte": null,
                            "time_to_last_byte": null
                        }
                    },
                    "error": null
                }
            },
            "tags": []
        });

        serde_json::from_value(document).expect("flow json should deserialize")
    }

    #[tokio::test]
    async fn status_endpoint_is_available_without_auth_by_default() {
        let app = router_with(fresh_state().await, HttpApiConfig::new(8082));

        let response = send(app, get("/api/v1/status")).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["phase"], "created");
        assert_eq!(json["running"], false);
        assert!(json.get("started_at_ms").is_none());
    }

    #[tokio::test]
    async fn intercepts_endpoint_uses_shared_snapshot_shape() {
        let app = router_with(fresh_state().await, HttpApiConfig::new(8082));

        let response = send(app, get("/api/v1/intercepts")).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["pending_count"], 0);
        assert_eq!(json["ws_pending_count"], 0);
    }

    #[tokio::test]
    async fn audit_endpoint_uses_shared_snapshot_shape() {
        let state = fresh_state().await;
        // Record one HTTP-actor policy update and one probe-actor intercept
        // resolution; the query below filters for the former.
        state.update_policy_from(
            AuditActor::Http,
            "policy".to_string(),
            ProxyPolicy {
                transparent_enabled: true,
                ..Default::default()
            },
        );
        let _ = state
            .resolve_intercept_with_modifications_from(
                AuditActor::Probe,
                "missing-flow:request".to_string(),
                "drop",
                None,
            )
            .await;
        let app = router_with(state, HttpApiConfig::new(8082));

        let response = send(
            app,
            get("/api/v1/audit?actor=http&kind=policy_updated&outcome=success&limit=1"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        let events = &json["events"];
        assert!(events.is_array());
        assert_eq!(events.as_array().map(|v| v.len()), Some(1));
        assert_eq!(events[0]["actor"], "http");
        assert_eq!(events[0]["kind"], "policy_updated");
        assert_eq!(events[0]["outcome"], "success");
    }

    #[tokio::test]
    async fn prometheus_metrics_endpoint_returns_text_format() {
        let app = router_with(fresh_state().await, HttpApiConfig::new(8082));

        let response = send(app, get("/api/v1/metrics/prometheus")).await;

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();
        assert_eq!(content_type, "text/plain; version=0.0.4; charset=utf-8");
        let text =
            String::from_utf8(read_body(response).await).expect("prometheus body should be utf-8");
        assert!(text.contains("relay_core_flows_total "));
        assert!(text.contains("relay_core_audit_events_total "));
    }

    #[tokio::test]
    async fn flows_endpoint_returns_pagination_metadata() {
        let state = fresh_state().await;
        let flow_a = sample_http_flow("api.example.com", "/a", "GET", 200, 1_700_000_001_000);
        let flow_b = sample_http_flow("api.example.com", "/b", "POST", 201, 1_700_000_002_000);
        let flow_c = sample_http_flow("api.example.com", "/c", "GET", 500, 1_700_000_003_000);
        let flow_b_id = flow_b.id.to_string();
        for flow in [flow_a, flow_b, flow_c] {
            state.upsert_flow(Box::new(flow));
        }
        // NOTE(review): presumably gives the state time to ingest the flows
        // before they are queried — confirm against `CoreState::upsert_flow`.
        sleep(Duration::from_millis(30)).await;

        let app = router_with(state, HttpApiConfig::new(8082));
        let response = send(
            app,
            get("/api/v1/flows?host=api.example.com&limit=1&offset=1"),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["returned"], 1);
        assert_eq!(json["limit"], 1);
        assert_eq!(json["offset"], 1);
        assert_eq!(json["items"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["items"][0]["id"], flow_b_id);
    }

    #[tokio::test]
    async fn status_endpoint_requires_bearer_token_when_configured() {
        let app = router_with(
            fresh_state().await,
            HttpApiConfig::new(8082).with_bearer_token("secret-token"),
        );

        let unauthorized = send(app.clone(), get("/api/v1/status")).await;
        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);

        let authorized = send(
            app,
            get_with_header("/api/v1/status", "Authorization", "Bearer secret-token"),
        )
        .await;
        assert_eq!(authorized.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cors_is_not_open_by_default() {
        let app = router_with(fresh_state().await, HttpApiConfig::new(8082));

        let response = send(
            app,
            get_with_header("/api/v1/status", "Origin", "https://example.com"),
        )
        .await;

        assert!(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[tokio::test]
    async fn cors_allows_explicit_origin_only() {
        let config = HttpApiConfig::new(8082)
            .with_allowed_origins([HeaderValue::from_static("https://allowed.example")]);
        let app = router_with(fresh_state().await, config);

        let allowed = send(
            app.clone(),
            get_with_header("/api/v1/status", "Origin", "https://allowed.example"),
        )
        .await;
        assert_eq!(
            allowed.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );

        let denied = send(
            app,
            get_with_header("/api/v1/status", "Origin", "https://denied.example"),
        )
        .await;
        assert!(denied.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }
}