Skip to main content

relay_core_http/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    Router,
6    extract::State,
7    http::{
8        HeaderValue, Method, Request, StatusCode,
9        header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE},
10    },
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13};
14use relay_core_runtime::CoreState;
15#[cfg(feature = "script")]
16use relay_core_runtime::services::ScriptService;
17use relay_core_runtime::services::{
18    AuditService, FlowEventHub, FlowReadService, InterceptService, PolicyService, RuleService,
19    RuntimeStatusService,
20};
21use tower_http::cors::CorsLayer;
22use tower_http::trace::TraceLayer;
23use tracing::info;
24
25use crate::routes;
26
27/// Configuration for the HTTP API server.
28#[derive(Debug, Clone)]
29pub struct HttpApiConfig {
30    /// Address to bind (e.g. `127.0.0.1:8082`)
31    pub addr: SocketAddr,
32    pub bearer_token: Option<String>,
33    pub allowed_origins: Vec<HeaderValue>,
34}
35
36impl HttpApiConfig {
37    pub fn new(port: u16) -> Self {
38        Self {
39            addr: SocketAddr::from(([127, 0, 0, 1], port)),
40            bearer_token: None,
41            allowed_origins: Vec::new(),
42        }
43    }
44
45    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
46        self.bearer_token = Some(token.into());
47        self
48    }
49
50    pub fn with_allowed_origins(mut self, origins: impl IntoIterator<Item = HeaderValue>) -> Self {
51        self.allowed_origins = origins.into_iter().collect();
52        self
53    }
54}
55
56/// Shared context for HTTP API handlers, exposing only narrow-capability traits.
57/// Constructed from a `CoreState` instance; handlers never import `CoreState` directly.
58pub struct HttpApiContext {
59    pub flows: Arc<dyn FlowReadService>,
60    pub events: Arc<dyn FlowEventHub>,
61    pub rules: Arc<dyn RuleService>,
62    pub intercepts: Arc<dyn InterceptService>,
63    pub audit: Arc<dyn AuditService>,
64    pub policy: Arc<dyn PolicyService>,
65    pub status: Arc<dyn RuntimeStatusService>,
66    #[cfg(feature = "script")]
67    pub scripts: Arc<dyn ScriptService>,
68}
69
70impl HttpApiContext {
71    pub fn new(core: Arc<CoreState>) -> Self {
72        Self {
73            flows: core.clone(),
74            events: core.clone(),
75            rules: core.clone(),
76            intercepts: core.clone(),
77            audit: core.clone(),
78            policy: core.clone(),
79            status: core.clone(),
80            #[cfg(feature = "script")]
81            scripts: core.clone(),
82        }
83    }
84}
85
86/// HTTP API server handle.
87pub struct HttpApiServer {
88    config: HttpApiConfig,
89    state: Arc<CoreState>,
90}
91
92impl HttpApiServer {
93    pub fn new(config: HttpApiConfig, state: Arc<CoreState>) -> Self {
94        Self { config, state }
95    }
96
97    /// Start the server; resolves when the server exits or an error occurs.
98    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
99        let ctx = Arc::new(HttpApiContext::new(self.state));
100        let config = Arc::new(self.config.clone());
101        let listener = tokio::net::TcpListener::bind(config.addr).await?;
102        info!("relay-core HTTP API listening on {}", config.addr);
103
104        if !config.addr.ip().is_loopback() && config.bearer_token.is_none() {
105            tracing::warn!(
106                "HTTP API is listening on {} without a bearer token configured. \
107                 This exposes the API to the network without authentication. \
108                 Set a bearer token for production use, or use --bind 127.0.0.1 for local-only access.",
109                config.addr
110            );
111        }
112
113        let app = build_router(ctx, config);
114        axum::serve(listener, app).await?;
115        Ok(())
116    }
117}
118
119fn build_router(ctx: Arc<HttpApiContext>, config: Arc<HttpApiConfig>) -> Router {
120    let router = Router::new()
121        .merge(routes::version::router())
122        .merge(routes::metrics::router(ctx.clone()))
123        .merge(routes::flows::router(ctx.clone()))
124        .merge(routes::rules::router(ctx.clone()))
125        .merge(routes::intercepts::router(ctx.clone()))
126        .merge(routes::events::router(ctx.clone()))
127        .merge(routes::policy::router(ctx.clone()));
128
129    #[cfg(feature = "script")]
130    let router = router.merge(routes::scripts::router(ctx.clone()));
131
132    let router = router
133        .route_layer(middleware::from_fn_with_state(
134            config.clone(),
135            require_bearer_token,
136        ))
137        .layer(TraceLayer::new_for_http());
138
139    if config.allowed_origins.is_empty() {
140        router
141    } else {
142        router.layer(
143            CorsLayer::new()
144                .allow_origin(config.allowed_origins.clone())
145                .allow_methods([
146                    Method::GET,
147                    Method::POST,
148                    Method::PUT,
149                    Method::DELETE,
150                    Method::OPTIONS,
151                ])
152                .allow_headers([AUTHORIZATION, CONTENT_TYPE, ACCEPT, ORIGIN]),
153        )
154    }
155}
156
157async fn require_bearer_token(
158    State(config): State<Arc<HttpApiConfig>>,
159    request: Request<axum::body::Body>,
160    next: Next,
161) -> Response {
162    if request.method() == Method::OPTIONS {
163        return next.run(request).await;
164    }
165
166    let Some(expected_token) = config.bearer_token.as_deref() else {
167        return next.run(request).await;
168    };
169
170    let expected_value = format!("Bearer {}", expected_token);
171    let is_authorized = request
172        .headers()
173        .get(AUTHORIZATION)
174        .and_then(|value| value.to_str().ok())
175        .map(|value| value == expected_value)
176        .unwrap_or(false);
177
178    if is_authorized {
179        return next.run(request).await;
180    }
181
182    (
183        StatusCode::UNAUTHORIZED,
184        [
185            (WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")),
186            (CONTENT_TYPE, HeaderValue::from_static("application/json")),
187        ],
188        serde_json::json!({
189            "error": "missing_or_invalid_bearer_token"
190        })
191        .to_string(),
192    )
193        .into_response()
194}
195
#[cfg(test)]
mod tests {
    use super::{HttpApiConfig, HttpApiContext, build_router};
    use axum::{
        Router,
        body::{Body, to_bytes},
        http::{HeaderValue, Method, Request, StatusCode, header::ACCESS_CONTROL_ALLOW_ORIGIN},
        response::Response,
    };
    use relay_core_api::flow::Flow;
    use relay_core_api::policy::ProxyPolicy;
    use relay_core_runtime::{CoreState, audit::AuditActor};
    use serde_json::json;
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};
    use tower::ServiceExt;

    /// Builds a completed HTTP/1.1 flow for `method http://{host}{path}` with
    /// response `status`. The flow id and start/end time are derived from `ts`,
    /// interpreted as milliseconds.
    ///
    /// Only the minute, second and millisecond parts of `ts` reach the
    /// timestamp; the date and hour are fixed at `2023-11-14T22`. Flows
    /// therefore keep their relative time order only when their `ts` values
    /// fall within the same hour.
    fn sample_http_flow(host: &str, path: &str, method: &str, status: u16, ts: i64) -> Flow {
        // The last UUID group holds 12 digits, so keep the low 12 decimal digits.
        let flow_id = format!(
            "00000000-0000-0000-0000-{:012}",
            (ts as u64) % 1_000_000_000_000
        );
        let minute = (ts / 60_000) % 60;
        let second = (ts / 1_000) % 60;
        let millis = (ts % 1_000).abs();
        let start_rfc3339 = format!("2023-11-14T22:{:02}:{:02}.{:03}Z", minute, second, millis);
        serde_json::from_value(json!({
            "id": flow_id,
            "start_time": start_rfc3339,
            "end_time": start_rfc3339,
            "network": {
                "client_ip": "127.0.0.1",
                "client_port": 12000,
                "server_ip": "127.0.0.1",
                "server_port": 8080,
                "protocol": "TCP",
                "tls": false,
                "tls_version": null,
                "sni": null
            },
            "layer": {
                "type": "Http",
                "data": {
                    "request": {
                        "method": method,
                        "url": format!("http://{}{}", host, path),
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "query": [],
                        "body": null
                    },
                    "response": {
                        "status": status,
                        "status_text": "OK",
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "body": null,
                        "timing": {
                            "time_to_first_byte": null,
                            "time_to_last_byte": null
                        }
                    },
                    "error": null
                }
            },
            "tags": []
        }))
        .expect("flow json should deserialize")
    }

    /// Builds the API router over a fresh [`HttpApiContext`] wrapping `state`.
    fn router_for(state: &Arc<CoreState>, config: HttpApiConfig) -> Router {
        let ctx = Arc::new(HttpApiContext::new(Arc::clone(state)));
        build_router(ctx, Arc::new(config))
    }

    /// Sends `GET uri` through `app`, adding each `(name, value)` pair as a header.
    async fn get(app: Router, uri: &str, headers: &[(&str, &str)]) -> Response {
        let request = headers
            .iter()
            .fold(
                Request::builder().uri(uri).method(Method::GET),
                |builder, &(name, value)| builder.header(name, value),
            )
            .body(Body::empty())
            .expect("request should build");
        app.oneshot(request).await.expect("request should succeed")
    }

    /// Reads the whole body of `response` and parses it as JSON.
    async fn json_body(response: Response) -> serde_json::Value {
        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        serde_json::from_slice(&body).expect("body should be valid json")
    }

    #[tokio::test]
    async fn status_endpoint_is_available_without_auth_by_default() {
        let state = Arc::new(CoreState::new(None).await);
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = get(app, "/api/v1/status", &[]).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["phase"], "created");
        assert_eq!(json["running"], false);
        assert!(json.get("started_at_ms").is_none());
    }

    #[tokio::test]
    async fn intercepts_endpoint_uses_shared_snapshot_shape() {
        let state = Arc::new(CoreState::new(None).await);
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = get(app, "/api/v1/intercepts", &[]).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["pending_count"], 0);
        assert_eq!(json["ws_pending_count"], 0);
    }

    #[tokio::test]
    async fn audit_endpoint_uses_shared_snapshot_shape() {
        let state = Arc::new(CoreState::new(None).await);
        state.update_policy_from(
            AuditActor::Http,
            "policy".to_string(),
            ProxyPolicy {
                transparent_enabled: true,
                ..Default::default()
            },
        );
        let _ = state
            .resolve_intercept_with_modifications_from(
                AuditActor::Probe,
                "missing-flow:request".to_string(),
                "drop",
                None,
            )
            .await;
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = get(
            app,
            "/api/v1/audit?actor=http&kind=policy_updated&outcome=success&limit=1",
            &[],
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert!(json["events"].is_array());
        assert_eq!(json["events"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["events"][0]["actor"], "http");
        assert_eq!(json["events"][0]["kind"], "policy_updated");
        assert_eq!(json["events"][0]["outcome"], "success");
    }

    #[tokio::test]
    async fn prometheus_metrics_endpoint_returns_text_format() {
        let state = Arc::new(CoreState::new(None).await);
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = get(app, "/api/v1/metrics/prometheus", &[]).await;

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();
        assert_eq!(content_type, "text/plain; version=0.0.4; charset=utf-8");
        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        let text = String::from_utf8(body.to_vec()).expect("prometheus body should be utf-8");
        assert!(text.contains("relay_core_flows_total "));
        assert!(text.contains("relay_core_audit_events_total "));
    }

    #[tokio::test]
    async fn flows_endpoint_returns_pagination_metadata() {
        let state = Arc::new(CoreState::new(None).await);
        let flow_a = sample_http_flow("api.example.com", "/a", "GET", 200, 1_700_000_001_000);
        let flow_b = sample_http_flow("api.example.com", "/b", "POST", 201, 1_700_000_002_000);
        let flow_c = sample_http_flow("api.example.com", "/c", "GET", 500, 1_700_000_003_000);
        let flow_b_id = flow_b.id.to_string();
        state.upsert_flow(Box::new(flow_a));
        state.upsert_flow(Box::new(flow_b));
        state.upsert_flow(Box::new(flow_c));
        // NOTE(review): this fixed delay presumably gives the upserts time to
        // become visible to readers; polling until visible would be less
        // timing-sensitive.
        sleep(Duration::from_millis(30)).await;

        let app = router_for(&state, HttpApiConfig::new(8082));
        let response = get(
            app,
            "/api/v1/flows?host=api.example.com&limit=1&offset=1",
            &[],
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["returned"], 1);
        assert_eq!(json["limit"], 1);
        assert_eq!(json["offset"], 1);
        assert_eq!(json["items"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["items"][0]["id"], flow_b_id);
    }

    #[tokio::test]
    async fn status_endpoint_requires_bearer_token_when_configured() {
        let state = Arc::new(CoreState::new(None).await);
        let app = router_for(
            &state,
            HttpApiConfig::new(8082).with_bearer_token("secret-token"),
        );

        let unauthorized = get(app.clone(), "/api/v1/status", &[]).await;
        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);

        let wrong_token = get(
            app.clone(),
            "/api/v1/status",
            &[("Authorization", "Bearer wrong-token")],
        )
        .await;
        assert_eq!(wrong_token.status(), StatusCode::UNAUTHORIZED);

        let authorized = get(
            app,
            "/api/v1/status",
            &[("Authorization", "Bearer secret-token")],
        )
        .await;
        assert_eq!(authorized.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cors_is_not_open_by_default() {
        let state = Arc::new(CoreState::new(None).await);
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = get(app, "/api/v1/status", &[("Origin", "https://example.com")]).await;

        assert!(
            response
                .headers()
                .get(ACCESS_CONTROL_ALLOW_ORIGIN)
                .is_none()
        );
    }

    #[tokio::test]
    async fn cors_allows_explicit_origin_only() {
        let state = Arc::new(CoreState::new(None).await);
        let app = router_for(
            &state,
            HttpApiConfig::new(8082)
                .with_allowed_origins([HeaderValue::from_static("https://allowed.example")]),
        );

        let allowed = get(
            app.clone(),
            "/api/v1/status",
            &[("Origin", "https://allowed.example")],
        )
        .await;
        assert_eq!(
            allowed.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );

        let denied = get(
            app,
            "/api/v1/status",
            &[("Origin", "https://denied.example")],
        )
        .await;
        assert!(denied.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }
}