Skip to main content

relay_core_http/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    Router,
6    extract::State,
7    http::{
8        HeaderValue, Method, Request, StatusCode,
9        header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE},
10    },
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13};
14use relay_core_runtime::CoreState;
15#[cfg(feature = "script")]
16use relay_core_runtime::services::ScriptService;
17use relay_core_runtime::services::{
18    AuditService, FlowEventHub, FlowReadService, InterceptService, RuleService,
19    RuntimeStatusService,
20};
21use tower_http::cors::CorsLayer;
22use tower_http::trace::TraceLayer;
23use tracing::info;
24
25use crate::routes;
26
27/// Configuration for the HTTP API server.
28#[derive(Debug, Clone)]
29pub struct HttpApiConfig {
30    /// Address to bind (e.g. `127.0.0.1:8082`)
31    pub addr: SocketAddr,
32    pub bearer_token: Option<String>,
33    pub allowed_origins: Vec<HeaderValue>,
34}
35
36impl HttpApiConfig {
37    pub fn new(port: u16) -> Self {
38        Self {
39            addr: SocketAddr::from(([127, 0, 0, 1], port)),
40            bearer_token: None,
41            allowed_origins: Vec::new(),
42        }
43    }
44
45    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
46        self.bearer_token = Some(token.into());
47        self
48    }
49
50    pub fn with_allowed_origins(mut self, origins: impl IntoIterator<Item = HeaderValue>) -> Self {
51        self.allowed_origins = origins.into_iter().collect();
52        self
53    }
54}
55
56/// Shared context for HTTP API handlers, exposing only narrow-capability traits.
57/// Constructed from a `CoreState` instance; handlers never import `CoreState` directly.
58pub struct HttpApiContext {
59    pub flows: Arc<dyn FlowReadService>,
60    pub events: Arc<dyn FlowEventHub>,
61    pub rules: Arc<dyn RuleService>,
62    pub intercepts: Arc<dyn InterceptService>,
63    pub audit: Arc<dyn AuditService>,
64    pub status: Arc<dyn RuntimeStatusService>,
65    #[cfg(feature = "script")]
66    pub scripts: Arc<dyn ScriptService>,
67}
68
69impl HttpApiContext {
70    pub fn new(core: Arc<CoreState>) -> Self {
71        Self {
72            flows: core.clone(),
73            events: core.clone(),
74            rules: core.clone(),
75            intercepts: core.clone(),
76            audit: core.clone(),
77            status: core.clone(),
78            #[cfg(feature = "script")]
79            scripts: core.clone(),
80        }
81    }
82}
83
84/// HTTP API server handle.
85pub struct HttpApiServer {
86    config: HttpApiConfig,
87    state: Arc<CoreState>,
88}
89
90impl HttpApiServer {
91    pub fn new(config: HttpApiConfig, state: Arc<CoreState>) -> Self {
92        Self { config, state }
93    }
94
95    /// Start the server; resolves when the server exits or an error occurs.
96    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
97        let ctx = Arc::new(HttpApiContext::new(self.state));
98        let config = Arc::new(self.config.clone());
99        let listener = tokio::net::TcpListener::bind(config.addr).await?;
100        info!("relay-core HTTP API listening on {}", config.addr);
101
102        if !config.addr.ip().is_loopback() && config.bearer_token.is_none() {
103            tracing::warn!(
104                "HTTP API is listening on {} without a bearer token configured. \
105                 This exposes the API to the network without authentication. \
106                 Set a bearer token for production use, or use --bind 127.0.0.1 for local-only access.",
107                config.addr
108            );
109        }
110
111        let app = build_router(ctx, config);
112        axum::serve(listener, app).await?;
113        Ok(())
114    }
115}
116
117fn build_router(ctx: Arc<HttpApiContext>, config: Arc<HttpApiConfig>) -> Router {
118    let router = Router::new()
119        .merge(routes::version::router())
120        .merge(routes::metrics::router(ctx.clone()))
121        .merge(routes::flows::router(ctx.clone()))
122        .merge(routes::rules::router(ctx.clone()))
123        .merge(routes::intercepts::router(ctx.clone()))
124        .merge(routes::events::router(ctx.clone()));
125
126    #[cfg(feature = "script")]
127    let router = router.merge(routes::scripts::router(ctx.clone()));
128
129    let router = router
130        .route_layer(middleware::from_fn_with_state(
131            config.clone(),
132            require_bearer_token,
133        ))
134        .layer(TraceLayer::new_for_http());
135
136    if config.allowed_origins.is_empty() {
137        router
138    } else {
139        router.layer(
140            CorsLayer::new()
141                .allow_origin(config.allowed_origins.clone())
142                .allow_methods([
143                    Method::GET,
144                    Method::POST,
145                    Method::PUT,
146                    Method::DELETE,
147                    Method::OPTIONS,
148                ])
149                .allow_headers([AUTHORIZATION, CONTENT_TYPE, ACCEPT, ORIGIN]),
150        )
151    }
152}
153
154async fn require_bearer_token(
155    State(config): State<Arc<HttpApiConfig>>,
156    request: Request<axum::body::Body>,
157    next: Next,
158) -> Response {
159    if request.method() == Method::OPTIONS {
160        return next.run(request).await;
161    }
162
163    let Some(expected_token) = config.bearer_token.as_deref() else {
164        return next.run(request).await;
165    };
166
167    let expected_value = format!("Bearer {}", expected_token);
168    let is_authorized = request
169        .headers()
170        .get(AUTHORIZATION)
171        .and_then(|value| value.to_str().ok())
172        .map(|value| value == expected_value)
173        .unwrap_or(false);
174
175    if is_authorized {
176        return next.run(request).await;
177    }
178
179    (
180        StatusCode::UNAUTHORIZED,
181        [
182            (WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")),
183            (CONTENT_TYPE, HeaderValue::from_static("application/json")),
184        ],
185        serde_json::json!({
186            "error": "missing_or_invalid_bearer_token"
187        })
188        .to_string(),
189    )
190        .into_response()
191}
192
#[cfg(test)]
mod tests {
    use super::{HttpApiConfig, HttpApiContext, build_router};
    use axum::{
        Router,
        body::{Body, to_bytes},
        http::{HeaderValue, Method, Request, StatusCode, header::ACCESS_CONTROL_ALLOW_ORIGIN},
        response::Response,
    };
    use relay_core_api::flow::Flow;
    use relay_core_api::policy::ProxyPolicy;
    use relay_core_runtime::{CoreState, audit::AuditActor};
    use serde_json::json;
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};
    use tower::ServiceExt;

    /// Port placed in every test config; the router is driven in-process, so
    /// nothing is ever bound to it.
    const TEST_PORT: u16 = 8082;

    /// Builds the API router over `state` using `config`.
    fn app_for(state: &Arc<CoreState>, config: HttpApiConfig) -> Router {
        let ctx = Arc::new(HttpApiContext::new(Arc::clone(state)));
        build_router(ctx, Arc::new(config))
    }

    /// Drives a single `GET uri` (plus any extra headers) through `app`.
    async fn send_get(app: Router, uri: &str, headers: &[(&str, &str)]) -> Response {
        let mut builder = Request::builder().uri(uri).method(Method::GET);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        let request = builder.body(Body::empty()).expect("request should build");
        app.oneshot(request).await.expect("request should succeed")
    }

    /// Reads the whole response body and parses it as JSON.
    async fn json_body(response: Response) -> serde_json::Value {
        let raw = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        serde_json::from_slice(&raw).expect("body should be valid json")
    }

    /// Builds a completed HTTP flow fixture whose id and start/end times are
    /// derived from the millisecond timestamp `ts`.
    fn sample_http_flow(host: &str, path: &str, method: &str, status: u16, ts: i64) -> Flow {
        let flow_id = format!(
            "00000000-0000-0000-0000-{:012}",
            (ts as u64) % 1_000_000_000_000
        );
        // Only minute/second/millis follow `ts`; date and hour are pinned, which
        // is adequate for fixtures that fall within the same hour.
        let start_rfc3339 = format!(
            "2023-11-14T22:{:02}:{:02}.{:03}Z",
            (ts / 60_000) % 60,
            (ts / 1_000) % 60,
            (ts % 1_000).abs()
        );
        let fixture = json!({
            "id": flow_id,
            "start_time": start_rfc3339,
            "end_time": start_rfc3339,
            "network": {
                "client_ip": "127.0.0.1",
                "client_port": 12000,
                "server_ip": "127.0.0.1",
                "server_port": 8080,
                "protocol": "TCP",
                "tls": false,
                "tls_version": null,
                "sni": null
            },
            "layer": {
                "type": "Http",
                "data": {
                    "request": {
                        "method": method,
                        "url": format!("http://{}{}", host, path),
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "query": [],
                        "body": null
                    },
                    "response": {
                        "status": status,
                        "status_text": "OK",
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "body": null,
                        "timing": {
                            "time_to_first_byte": null,
                            "time_to_last_byte": null
                        }
                    },
                    "error": null
                }
            },
            "tags": []
        });
        serde_json::from_value(fixture).expect("flow json should deserialize")
    }

    #[tokio::test]
    async fn status_endpoint_is_available_without_auth_by_default() {
        let state = Arc::new(CoreState::new(None).await);
        let app = app_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = send_get(app, "/api/v1/status", &[]).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["phase"], "created");
        assert_eq!(json["running"], false);
        assert!(json.get("started_at_ms").is_none());
    }

    #[tokio::test]
    async fn intercepts_endpoint_uses_shared_snapshot_shape() {
        let state = Arc::new(CoreState::new(None).await);
        let app = app_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = send_get(app, "/api/v1/intercepts", &[]).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["pending_count"], 0);
        assert_eq!(json["ws_pending_count"], 0);
    }

    #[tokio::test]
    async fn audit_endpoint_uses_shared_snapshot_shape() {
        let state = Arc::new(CoreState::new(None).await);
        state.update_policy_from(
            AuditActor::Http,
            "policy".to_string(),
            ProxyPolicy {
                transparent_enabled: true,
                ..Default::default()
            },
        );
        // An extra, Probe-originated event; the query below filters on actor=http.
        let _ = state
            .resolve_intercept_with_modifications_from(
                AuditActor::Probe,
                "missing-flow:request".to_string(),
                "drop",
                None,
            )
            .await;
        let app = app_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = send_get(
            app,
            "/api/v1/audit?actor=http&kind=policy_updated&outcome=success&limit=1",
            &[],
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert!(json["events"].is_array());
        assert_eq!(json["events"].as_array().map(|v| v.len()), Some(1));
        let first = &json["events"][0];
        assert_eq!(first["actor"], "http");
        assert_eq!(first["kind"], "policy_updated");
        assert_eq!(first["outcome"], "success");
    }

    #[tokio::test]
    async fn prometheus_metrics_endpoint_returns_text_format() {
        let state = Arc::new(CoreState::new(None).await);
        let app = app_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = send_get(app, "/api/v1/metrics/prometheus", &[]).await;

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();
        assert_eq!(content_type, "text/plain; version=0.0.4; charset=utf-8");
        let raw = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        let text = String::from_utf8(raw.to_vec()).expect("prometheus body should be utf-8");
        assert!(text.contains("relay_core_flows_total "));
        assert!(text.contains("relay_core_audit_events_total "));
    }

    #[tokio::test]
    async fn flows_endpoint_returns_pagination_metadata() {
        let state = Arc::new(CoreState::new(None).await);
        let fixtures = [
            sample_http_flow("api.example.com", "/a", "GET", 200, 1_700_000_001_000),
            sample_http_flow("api.example.com", "/b", "POST", 201, 1_700_000_002_000),
            sample_http_flow("api.example.com", "/c", "GET", 500, 1_700_000_003_000),
        ];
        let flow_b_id = fixtures[1].id.to_string();
        for flow in fixtures {
            state.upsert_flow(Box::new(flow));
        }
        sleep(Duration::from_millis(30)).await;

        let app = app_for(&state, HttpApiConfig::new(TEST_PORT));
        let response = send_get(app, "/api/v1/flows?host=api.example.com&limit=1&offset=1", &[]).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["returned"], 1);
        assert_eq!(json["limit"], 1);
        assert_eq!(json["offset"], 1);
        assert_eq!(json["items"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["items"][0]["id"], flow_b_id);
    }

    #[tokio::test]
    async fn status_endpoint_requires_bearer_token_when_configured() {
        let state = Arc::new(CoreState::new(None).await);
        let app = app_for(
            &state,
            HttpApiConfig::new(TEST_PORT).with_bearer_token("secret-token"),
        );

        let unauthorized = send_get(app.clone(), "/api/v1/status", &[]).await;
        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);

        let authorized = send_get(
            app,
            "/api/v1/status",
            &[("Authorization", "Bearer secret-token")],
        )
        .await;
        assert_eq!(authorized.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cors_is_not_open_by_default() {
        let state = Arc::new(CoreState::new(None).await);
        let app = app_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = send_get(app, "/api/v1/status", &[("Origin", "https://example.com")]).await;

        assert!(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[tokio::test]
    async fn cors_allows_explicit_origin_only() {
        let state = Arc::new(CoreState::new(None).await);
        let app = app_for(
            &state,
            HttpApiConfig::new(TEST_PORT)
                .with_allowed_origins([HeaderValue::from_static("https://allowed.example")]),
        );

        let allowed = send_get(
            app.clone(),
            "/api/v1/status",
            &[("Origin", "https://allowed.example")],
        )
        .await;
        assert_eq!(
            allowed.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );

        let denied = send_get(app, "/api/v1/status", &[("Origin", "https://denied.example")]).await;
        assert!(denied.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }
}