Skip to main content

relay_core_http/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    Router,
6    extract::State,
7    http::{
8        HeaderValue, Method, Request, StatusCode,
9        header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE},
10    },
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13};
14use tower_http::cors::CorsLayer;
15use tower_http::trace::TraceLayer;
16use relay_core_runtime::CoreState;
17use relay_core_runtime::services::{
18    FlowReadService, FlowEventHub, RuleService, InterceptService,
19    AuditService, RuntimeStatusService,
20};
21use tracing::info;
22
23use crate::routes;
24
25/// Configuration for the HTTP API server.
26#[derive(Debug, Clone)]
27pub struct HttpApiConfig {
28    /// Address to bind (e.g. `127.0.0.1:8082`)
29    pub addr: SocketAddr,
30    pub bearer_token: Option<String>,
31    pub allowed_origins: Vec<HeaderValue>,
32}
33
34impl HttpApiConfig {
35    pub fn new(port: u16) -> Self {
36        Self {
37            addr: SocketAddr::from(([127, 0, 0, 1], port)),
38            bearer_token: None,
39            allowed_origins: Vec::new(),
40        }
41    }
42
43    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
44        self.bearer_token = Some(token.into());
45        self
46    }
47
48    pub fn with_allowed_origins(mut self, origins: impl IntoIterator<Item = HeaderValue>) -> Self {
49        self.allowed_origins = origins.into_iter().collect();
50        self
51    }
52}
53
54/// Shared context for HTTP API handlers, exposing only narrow-capability traits.
55/// Constructed from a `CoreState` instance; handlers never import `CoreState` directly.
56pub struct HttpApiContext {
57    pub flows: Arc<dyn FlowReadService>,
58    pub events: Arc<dyn FlowEventHub>,
59    pub rules: Arc<dyn RuleService>,
60    pub intercepts: Arc<dyn InterceptService>,
61    pub audit: Arc<dyn AuditService>,
62    pub status: Arc<dyn RuntimeStatusService>,
63}
64
65impl HttpApiContext {
66    pub fn new(core: Arc<CoreState>) -> Self {
67        Self {
68            flows: core.clone(),
69            events: core.clone(),
70            rules: core.clone(),
71            intercepts: core.clone(),
72            audit: core.clone(),
73            status: core.clone(),
74        }
75    }
76}
77
78/// HTTP API server handle.
79pub struct HttpApiServer {
80    config: HttpApiConfig,
81    state: Arc<CoreState>,
82}
83
84impl HttpApiServer {
85    pub fn new(config: HttpApiConfig, state: Arc<CoreState>) -> Self {
86        Self { config, state }
87    }
88
89    /// Start the server; resolves when the server exits or an error occurs.
90    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
91        let ctx = Arc::new(HttpApiContext::new(self.state));
92        let config = Arc::new(self.config.clone());
93        let listener = tokio::net::TcpListener::bind(config.addr).await?;
94        info!("relay-core HTTP API listening on {}", config.addr);
95
96        if !config.addr.ip().is_loopback() && config.bearer_token.is_none() {
97            tracing::warn!(
98                "HTTP API is listening on {} without a bearer token configured. \
99                 This exposes the API to the network without authentication. \
100                 Set a bearer token for production use, or use --bind 127.0.0.1 for local-only access.",
101                config.addr
102            );
103        }
104
105        let app = build_router(ctx, config);
106        axum::serve(listener, app).await?;
107        Ok(())
108    }
109}
110
111fn build_router(ctx: Arc<HttpApiContext>, config: Arc<HttpApiConfig>) -> Router {
112    let router = Router::new()
113        .merge(routes::version::router())
114        .merge(routes::metrics::router(ctx.clone()))
115        .merge(routes::flows::router(ctx.clone()))
116        .merge(routes::rules::router(ctx.clone()))
117        .merge(routes::intercepts::router(ctx.clone()))
118        .merge(routes::events::router(ctx))
119        .route_layer(middleware::from_fn_with_state(config.clone(), require_bearer_token))
120        .layer(TraceLayer::new_for_http());
121
122    if config.allowed_origins.is_empty() {
123        router
124    } else {
125        router.layer(
126            CorsLayer::new()
127                .allow_origin(config.allowed_origins.clone())
128                .allow_methods([
129                    Method::GET,
130                    Method::POST,
131                    Method::PUT,
132                    Method::DELETE,
133                    Method::OPTIONS,
134                ])
135                .allow_headers([AUTHORIZATION, CONTENT_TYPE, ACCEPT, ORIGIN]),
136        )
137    }
138}
139
140async fn require_bearer_token(
141    State(config): State<Arc<HttpApiConfig>>,
142    request: Request<axum::body::Body>,
143    next: Next,
144) -> Response {
145    if request.method() == Method::OPTIONS {
146        return next.run(request).await;
147    }
148
149    let Some(expected_token) = config.bearer_token.as_deref() else {
150        return next.run(request).await;
151    };
152
153    let expected_value = format!("Bearer {}", expected_token);
154    let is_authorized = request
155        .headers()
156        .get(AUTHORIZATION)
157        .and_then(|value| value.to_str().ok())
158        .map(|value| value == expected_value)
159        .unwrap_or(false);
160
161    if is_authorized {
162        return next.run(request).await;
163    }
164
165    (
166        StatusCode::UNAUTHORIZED,
167        [
168            (WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")),
169            (CONTENT_TYPE, HeaderValue::from_static("application/json")),
170        ],
171        serde_json::json!({
172            "error": "missing_or_invalid_bearer_token"
173        })
174        .to_string(),
175    )
176        .into_response()
177}
178
#[cfg(test)]
mod tests {
    use super::{HttpApiConfig, HttpApiContext, build_router};
    use axum::{
        Router,
        body::{Body, to_bytes},
        http::{
            HeaderValue, Method, Request, StatusCode,
            header::{ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE},
        },
        response::Response,
    };
    use relay_core_api::flow::Flow;
    use relay_core_api::policy::ProxyPolicy;
    use relay_core_runtime::{CoreState, audit::AuditActor};
    use serde_json::json;
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};
    use tower::ServiceExt;

    /// Port placed in every test config; requests are driven in-process, so nothing binds it.
    const TEST_PORT: u16 = 8082;

    /// Creates an empty runtime state.
    async fn fresh_state() -> Arc<CoreState> {
        Arc::new(CoreState::new(None).await)
    }

    /// Builds the API router over `state` using `config`.
    fn router_for(state: &Arc<CoreState>, config: HttpApiConfig) -> Router {
        let ctx = Arc::new(HttpApiContext::new(Arc::clone(state)));
        build_router(ctx, Arc::new(config))
    }

    /// Sends a GET for `uri` carrying `headers` and returns the router's response.
    async fn get(app: Router, uri: &str, headers: &[(&str, &str)]) -> Response {
        let request = headers
            .iter()
            .fold(
                Request::builder().uri(uri).method(Method::GET),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(Body::empty())
            .expect("request should build");
        app.oneshot(request).await.expect("request should succeed")
    }

    /// Drains the response body into a byte vector.
    async fn body_bytes(response: Response) -> Vec<u8> {
        to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable")
            .to_vec()
    }

    /// Drains the response body and parses it as JSON.
    async fn body_json(response: Response) -> serde_json::Value {
        let raw = body_bytes(response).await;
        serde_json::from_slice(&raw).expect("body should be valid json")
    }

    /// Builds an HTTP flow fixture whose id and start/end timestamps derive from `ts` (ms).
    fn sample_http_flow(host: &str, path: &str, method: &str, status: u16, ts: i64) -> Flow {
        let id_suffix = (ts as u64) % 1_000_000_000_000;
        let flow_id = format!("00000000-0000-0000-0000-{:012}", id_suffix);
        let (minute, second, millis) = ((ts / 60_000) % 60, (ts / 1_000) % 60, (ts % 1_000).abs());
        let timestamp = format!("2023-11-14T22:{:02}:{:02}.{:03}Z", minute, second, millis);
        let url = format!("http://{}{}", host, path);

        let document = json!({
            "id": flow_id,
            "start_time": timestamp,
            "end_time": timestamp,
            "network": {
                "client_ip": "127.0.0.1",
                "client_port": 12000,
                "server_ip": "127.0.0.1",
                "server_port": 8080,
                "protocol": "TCP",
                "tls": false,
                "tls_version": null,
                "sni": null
            },
            "layer": {
                "type": "Http",
                "data": {
                    "request": {
                        "method": method,
                        "url": url,
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "query": [],
                        "body": null
                    },
                    "response": {
                        "status": status,
                        "status_text": "OK",
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "body": null,
                        "timing": {
                            "time_to_first_byte": null,
                            "time_to_last_byte": null
                        }
                    },
                    "error": null
                }
            },
            "tags": []
        });
        serde_json::from_value(document).expect("flow json should deserialize")
    }

    #[tokio::test]
    async fn status_endpoint_is_available_without_auth_by_default() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = get(app, "/api/v1/status", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);

        let json = body_json(response).await;
        assert_eq!(json["phase"], "created");
        assert_eq!(json["running"], false);
        assert!(json.get("started_at_ms").is_none());
    }

    #[tokio::test]
    async fn intercepts_endpoint_uses_shared_snapshot_shape() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = get(app, "/api/v1/intercepts", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);

        let json = body_json(response).await;
        assert_eq!(json["pending_count"], 0);
        assert_eq!(json["ws_pending_count"], 0);
    }

    #[tokio::test]
    async fn audit_endpoint_uses_shared_snapshot_shape() {
        let state = fresh_state().await;
        state.update_policy_from(
            AuditActor::Http,
            "policy".to_string(),
            ProxyPolicy {
                transparent_enabled: true,
                ..Default::default()
            },
        );
        let _ = state
            .resolve_intercept_with_modifications_from(
                AuditActor::Probe,
                "missing-flow:request".to_string(),
                "drop",
                None,
            )
            .await;
        let app = router_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = get(
            app,
            "/api/v1/audit?actor=http&kind=policy_updated&outcome=success&limit=1",
            &[],
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);

        let json = body_json(response).await;
        let events = &json["events"];
        assert!(events.is_array());
        assert_eq!(events.as_array().map(Vec::len), Some(1));
        assert_eq!(events[0]["actor"], "http");
        assert_eq!(events[0]["kind"], "policy_updated");
        assert_eq!(events[0]["outcome"], "success");
    }

    #[tokio::test]
    async fn prometheus_metrics_endpoint_returns_text_format() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = get(app, "/api/v1/metrics/prometheus", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);

        let content_type = response
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();
        assert_eq!(content_type, "text/plain; version=0.0.4; charset=utf-8");

        let text = String::from_utf8(body_bytes(response).await)
            .expect("prometheus body should be utf-8");
        assert!(text.contains("relay_core_flows_total "));
        assert!(text.contains("relay_core_audit_events_total "));
    }

    #[tokio::test]
    async fn flows_endpoint_returns_pagination_metadata() {
        let state = fresh_state().await;
        let fixtures = [
            sample_http_flow("api.example.com", "/a", "GET", 200, 1_700_000_001_000),
            sample_http_flow("api.example.com", "/b", "POST", 201, 1_700_000_002_000),
            sample_http_flow("api.example.com", "/c", "GET", 500, 1_700_000_003_000),
        ];
        let flow_b_id = fixtures[1].id.to_string();
        for flow in fixtures {
            state.upsert_flow(Box::new(flow));
        }
        // Presumably gives flow ingestion time to settle before querying; confirm whether
        // `upsert_flow` is asynchronous.
        sleep(Duration::from_millis(30)).await;

        let app = router_for(&state, HttpApiConfig::new(TEST_PORT));
        let response = get(app, "/api/v1/flows?host=api.example.com&limit=1&offset=1", &[]).await;
        assert_eq!(response.status(), StatusCode::OK);

        let json = body_json(response).await;
        assert_eq!(json["returned"], 1);
        assert_eq!(json["limit"], 1);
        assert_eq!(json["offset"], 1);
        assert_eq!(json["items"].as_array().map(Vec::len), Some(1));
        assert_eq!(json["items"][0]["id"], flow_b_id);
    }

    #[tokio::test]
    async fn status_endpoint_requires_bearer_token_when_configured() {
        let state = fresh_state().await;
        let app = router_for(
            &state,
            HttpApiConfig::new(TEST_PORT).with_bearer_token("secret-token"),
        );

        let unauthorized = get(app.clone(), "/api/v1/status", &[]).await;
        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);

        let authorized = get(
            app,
            "/api/v1/status",
            &[("Authorization", "Bearer secret-token")],
        )
        .await;
        assert_eq!(authorized.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cors_is_not_open_by_default() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(TEST_PORT));

        let response = get(app, "/api/v1/status", &[("Origin", "https://example.com")]).await;
        assert!(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[tokio::test]
    async fn cors_allows_explicit_origin_only() {
        let state = fresh_state().await;
        let allowed_origin = HeaderValue::from_static("https://allowed.example");
        let app = router_for(
            &state,
            HttpApiConfig::new(TEST_PORT).with_allowed_origins([allowed_origin]),
        );

        let allowed = get(
            app.clone(),
            "/api/v1/status",
            &[("Origin", "https://allowed.example")],
        )
        .await;
        assert_eq!(
            allowed.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );

        let denied = get(app, "/api/v1/status", &[("Origin", "https://denied.example")]).await;
        assert!(denied.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }
}