Skip to main content

relay_core_http/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    Router,
6    extract::State,
7    http::{
8        HeaderValue, Method, Request, StatusCode,
9        header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE},
10    },
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13};
14use relay_core_runtime::CoreState;
15use relay_core_runtime::services::{
16    AuditService, FlowEventHub, FlowReadService, InterceptService, RuleService,
17    RuntimeStatusService,
18};
19use tower_http::cors::CorsLayer;
20use tower_http::trace::TraceLayer;
21use tracing::info;
22
23use crate::routes;
24
25/// Configuration for the HTTP API server.
26#[derive(Debug, Clone)]
27pub struct HttpApiConfig {
28    /// Address to bind (e.g. `127.0.0.1:8082`)
29    pub addr: SocketAddr,
30    pub bearer_token: Option<String>,
31    pub allowed_origins: Vec<HeaderValue>,
32}
33
34impl HttpApiConfig {
35    pub fn new(port: u16) -> Self {
36        Self {
37            addr: SocketAddr::from(([127, 0, 0, 1], port)),
38            bearer_token: None,
39            allowed_origins: Vec::new(),
40        }
41    }
42
43    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
44        self.bearer_token = Some(token.into());
45        self
46    }
47
48    pub fn with_allowed_origins(mut self, origins: impl IntoIterator<Item = HeaderValue>) -> Self {
49        self.allowed_origins = origins.into_iter().collect();
50        self
51    }
52}
53
54/// Shared context for HTTP API handlers, exposing only narrow-capability traits.
55/// Constructed from a `CoreState` instance; handlers never import `CoreState` directly.
56pub struct HttpApiContext {
57    pub flows: Arc<dyn FlowReadService>,
58    pub events: Arc<dyn FlowEventHub>,
59    pub rules: Arc<dyn RuleService>,
60    pub intercepts: Arc<dyn InterceptService>,
61    pub audit: Arc<dyn AuditService>,
62    pub status: Arc<dyn RuntimeStatusService>,
63}
64
65impl HttpApiContext {
66    pub fn new(core: Arc<CoreState>) -> Self {
67        Self {
68            flows: core.clone(),
69            events: core.clone(),
70            rules: core.clone(),
71            intercepts: core.clone(),
72            audit: core.clone(),
73            status: core.clone(),
74        }
75    }
76}
77
78/// HTTP API server handle.
79pub struct HttpApiServer {
80    config: HttpApiConfig,
81    state: Arc<CoreState>,
82}
83
84impl HttpApiServer {
85    pub fn new(config: HttpApiConfig, state: Arc<CoreState>) -> Self {
86        Self { config, state }
87    }
88
89    /// Start the server; resolves when the server exits or an error occurs.
90    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
91        let ctx = Arc::new(HttpApiContext::new(self.state));
92        let config = Arc::new(self.config.clone());
93        let listener = tokio::net::TcpListener::bind(config.addr).await?;
94        info!("relay-core HTTP API listening on {}", config.addr);
95
96        if !config.addr.ip().is_loopback() && config.bearer_token.is_none() {
97            tracing::warn!(
98                "HTTP API is listening on {} without a bearer token configured. \
99                 This exposes the API to the network without authentication. \
100                 Set a bearer token for production use, or use --bind 127.0.0.1 for local-only access.",
101                config.addr
102            );
103        }
104
105        let app = build_router(ctx, config);
106        axum::serve(listener, app).await?;
107        Ok(())
108    }
109}
110
111fn build_router(ctx: Arc<HttpApiContext>, config: Arc<HttpApiConfig>) -> Router {
112    let router = Router::new()
113        .merge(routes::version::router())
114        .merge(routes::metrics::router(ctx.clone()))
115        .merge(routes::flows::router(ctx.clone()))
116        .merge(routes::rules::router(ctx.clone()))
117        .merge(routes::intercepts::router(ctx.clone()))
118        .merge(routes::events::router(ctx))
119        .route_layer(middleware::from_fn_with_state(
120            config.clone(),
121            require_bearer_token,
122        ))
123        .layer(TraceLayer::new_for_http());
124
125    if config.allowed_origins.is_empty() {
126        router
127    } else {
128        router.layer(
129            CorsLayer::new()
130                .allow_origin(config.allowed_origins.clone())
131                .allow_methods([
132                    Method::GET,
133                    Method::POST,
134                    Method::PUT,
135                    Method::DELETE,
136                    Method::OPTIONS,
137                ])
138                .allow_headers([AUTHORIZATION, CONTENT_TYPE, ACCEPT, ORIGIN]),
139        )
140    }
141}
142
143async fn require_bearer_token(
144    State(config): State<Arc<HttpApiConfig>>,
145    request: Request<axum::body::Body>,
146    next: Next,
147) -> Response {
148    if request.method() == Method::OPTIONS {
149        return next.run(request).await;
150    }
151
152    let Some(expected_token) = config.bearer_token.as_deref() else {
153        return next.run(request).await;
154    };
155
156    let expected_value = format!("Bearer {}", expected_token);
157    let is_authorized = request
158        .headers()
159        .get(AUTHORIZATION)
160        .and_then(|value| value.to_str().ok())
161        .map(|value| value == expected_value)
162        .unwrap_or(false);
163
164    if is_authorized {
165        return next.run(request).await;
166    }
167
168    (
169        StatusCode::UNAUTHORIZED,
170        [
171            (WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")),
172            (CONTENT_TYPE, HeaderValue::from_static("application/json")),
173        ],
174        serde_json::json!({
175            "error": "missing_or_invalid_bearer_token"
176        })
177        .to_string(),
178    )
179        .into_response()
180}
181
#[cfg(test)]
mod tests {
    use super::{HttpApiConfig, HttpApiContext, build_router};
    use axum::{
        Router,
        body::{Body, to_bytes},
        http::{
            HeaderValue, Method, Request, StatusCode,
            header::{ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE, WWW_AUTHENTICATE},
        },
        response::Response,
    };
    use relay_core_api::flow::Flow;
    use relay_core_api::policy::ProxyPolicy;
    use relay_core_runtime::{CoreState, audit::AuditActor};
    use serde_json::{Value, json};
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};
    use tower::ServiceExt;

    /// Builds the API router over `state` with `config`, the same way `HttpApiServer::run` does.
    fn app_for(state: Arc<CoreState>, config: HttpApiConfig) -> Router {
        build_router(Arc::new(HttpApiContext::new(state)), Arc::new(config))
    }

    /// Builds the API router over a freshly created, empty `CoreState`.
    async fn fresh_app(config: HttpApiConfig) -> Router {
        app_for(Arc::new(CoreState::new(None).await), config)
    }

    /// Builds a `GET uri` request with an empty body and the given extra headers.
    fn get_request(uri: &str, headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri(uri).method(Method::GET);
        for &(name, value) in headers {
            builder = builder.header(name, value);
        }
        builder.body(Body::empty()).expect("request should build")
    }

    /// Sends one request through a clone of `app`, leaving `app` usable for further requests.
    async fn send(app: &Router, request: Request<Body>) -> Response {
        app.clone()
            .oneshot(request)
            .await
            .expect("request should succeed")
    }

    /// Reads the whole response body and parses it as JSON.
    async fn json_body(response: Response) -> Value {
        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        serde_json::from_slice(&body).expect("body should be valid json")
    }

    /// Builds a completed HTTP flow whose id and start/end time are derived from `ts` (ms).
    fn sample_http_flow(host: &str, path: &str, method: &str, status: u16, ts: i64) -> Flow {
        let flow_id = format!(
            "00000000-0000-0000-0000-{:012}",
            (ts as u64) % 1_000_000_000_000
        );
        // `ts` is already i64, so no casts are needed for the time components.
        let minute = (ts / 60_000) % 60;
        let second = (ts / 1_000) % 60;
        let millis = (ts % 1_000).abs();
        let start_rfc3339 = format!("2023-11-14T22:{:02}:{:02}.{:03}Z", minute, second, millis);
        serde_json::from_value(json!({
            "id": flow_id,
            "start_time": start_rfc3339,
            "end_time": start_rfc3339,
            "network": {
                "client_ip": "127.0.0.1",
                "client_port": 12000,
                "server_ip": "127.0.0.1",
                "server_port": 8080,
                "protocol": "TCP",
                "tls": false,
                "tls_version": null,
                "sni": null
            },
            "layer": {
                "type": "Http",
                "data": {
                    "request": {
                        "method": method,
                        "url": format!("http://{}{}", host, path),
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "query": [],
                        "body": null
                    },
                    "response": {
                        "status": status,
                        "status_text": "OK",
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "body": null,
                        "timing": {
                            "time_to_first_byte": null,
                            "time_to_last_byte": null
                        }
                    },
                    "error": null
                }
            },
            "tags": []
        }))
        .expect("flow json should deserialize")
    }

    #[tokio::test]
    async fn status_endpoint_is_available_without_auth_by_default() {
        let app = fresh_app(HttpApiConfig::new(8082)).await;

        let response = send(&app, get_request("/api/v1/status", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["phase"], "created");
        assert_eq!(json["running"], false);
        assert!(json.get("started_at_ms").is_none());
    }

    #[tokio::test]
    async fn intercepts_endpoint_uses_shared_snapshot_shape() {
        let app = fresh_app(HttpApiConfig::new(8082)).await;

        let response = send(&app, get_request("/api/v1/intercepts", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["pending_count"], 0);
        assert_eq!(json["ws_pending_count"], 0);
    }

    #[tokio::test]
    async fn audit_endpoint_uses_shared_snapshot_shape() {
        let state = Arc::new(CoreState::new(None).await);
        state.update_policy_from(
            AuditActor::Http,
            "policy".to_string(),
            ProxyPolicy {
                transparent_enabled: true,
                ..Default::default()
            },
        );
        let _ = state
            .resolve_intercept_with_modifications_from(
                AuditActor::Probe,
                "missing-flow:request".to_string(),
                "drop",
                None,
            )
            .await;
        let app = app_for(state, HttpApiConfig::new(8082));

        let response = send(
            &app,
            get_request(
                "/api/v1/audit?actor=http&kind=policy_updated&outcome=success&limit=1",
                &[],
            ),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert!(json["events"].is_array());
        assert_eq!(json["events"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["events"][0]["actor"], "http");
        assert_eq!(json["events"][0]["kind"], "policy_updated");
        assert_eq!(json["events"][0]["outcome"], "success");
    }

    #[tokio::test]
    async fn prometheus_metrics_endpoint_returns_text_format() {
        let app = fresh_app(HttpApiConfig::new(8082)).await;

        let response = send(&app, get_request("/api/v1/metrics/prometheus", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();
        assert_eq!(content_type, "text/plain; version=0.0.4; charset=utf-8");
        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        let text = String::from_utf8(body.to_vec()).expect("prometheus body should be utf-8");
        assert!(text.contains("relay_core_flows_total "));
        assert!(text.contains("relay_core_audit_events_total "));
    }

    #[tokio::test]
    async fn flows_endpoint_returns_pagination_metadata() {
        let state = Arc::new(CoreState::new(None).await);
        let flow_a = sample_http_flow("api.example.com", "/a", "GET", 200, 1_700_000_001_000);
        let flow_b = sample_http_flow("api.example.com", "/b", "POST", 201, 1_700_000_002_000);
        let flow_c = sample_http_flow("api.example.com", "/c", "GET", 500, 1_700_000_003_000);
        let flow_b_id = flow_b.id.to_string();
        state.upsert_flow(Box::new(flow_a));
        state.upsert_flow(Box::new(flow_b));
        state.upsert_flow(Box::new(flow_c));
        // NOTE(review): presumably upserts land asynchronously; this waits for them.
        sleep(Duration::from_millis(30)).await;
        let app = app_for(state, HttpApiConfig::new(8082));

        let response = send(
            &app,
            get_request("/api/v1/flows?host=api.example.com&limit=1&offset=1", &[]),
        )
        .await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = json_body(response).await;
        assert_eq!(json["returned"], 1);
        assert_eq!(json["limit"], 1);
        assert_eq!(json["offset"], 1);
        assert_eq!(json["items"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["items"][0]["id"], flow_b_id);
    }

    #[tokio::test]
    async fn status_endpoint_requires_bearer_token_when_configured() {
        let app = fresh_app(HttpApiConfig::new(8082).with_bearer_token("secret-token")).await;

        let unauthorized = send(&app, get_request("/api/v1/status", &[])).await;
        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            unauthorized.headers().get(WWW_AUTHENTICATE),
            Some(&HeaderValue::from_static("Bearer"))
        );

        let wrong_token = send(
            &app,
            get_request(
                "/api/v1/status",
                &[("Authorization", "Bearer wrong-token")],
            ),
        )
        .await;
        assert_eq!(wrong_token.status(), StatusCode::UNAUTHORIZED);

        let authorized = send(
            &app,
            get_request(
                "/api/v1/status",
                &[("Authorization", "Bearer secret-token")],
            ),
        )
        .await;
        assert_eq!(authorized.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cors_is_not_open_by_default() {
        let app = fresh_app(HttpApiConfig::new(8082)).await;

        let response = send(
            &app,
            get_request("/api/v1/status", &[("Origin", "https://example.com")]),
        )
        .await;

        assert!(
            response
                .headers()
                .get(ACCESS_CONTROL_ALLOW_ORIGIN)
                .is_none()
        );
    }

    #[tokio::test]
    async fn cors_allows_explicit_origin_only() {
        let app = fresh_app(
            HttpApiConfig::new(8082)
                .with_allowed_origins([HeaderValue::from_static("https://allowed.example")]),
        )
        .await;

        let allowed = send(
            &app,
            get_request("/api/v1/status", &[("Origin", "https://allowed.example")]),
        )
        .await;
        assert_eq!(
            allowed.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );

        let denied = send(
            &app,
            get_request("/api/v1/status", &[("Origin", "https://denied.example")]),
        )
        .await;
        assert!(denied.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }
}