Skip to main content

relay_core_http/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    Router,
6    extract::State,
7    http::{
8        HeaderValue, Method, Request, StatusCode,
9        header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, ORIGIN, WWW_AUTHENTICATE},
10    },
11    middleware::{self, Next},
12    response::{IntoResponse, Response},
13};
14use tower_http::cors::CorsLayer;
15use tower_http::trace::TraceLayer;
16use relay_core_runtime::CoreState;
17use relay_core_runtime::services::{
18    FlowReadService, FlowEventHub, RuleService, InterceptService,
19    AuditService, RuntimeStatusService,
20};
21use tracing::info;
22
23use crate::routes;
24
25/// Configuration for the HTTP API server.
26#[derive(Debug, Clone)]
27pub struct HttpApiConfig {
28    /// Address to bind (e.g. `127.0.0.1:8082`)
29    pub addr: SocketAddr,
30    pub bearer_token: Option<String>,
31    pub allowed_origins: Vec<HeaderValue>,
32}
33
34impl HttpApiConfig {
35    pub fn new(port: u16) -> Self {
36        Self {
37            addr: SocketAddr::from(([127, 0, 0, 1], port)),
38            bearer_token: None,
39            allowed_origins: Vec::new(),
40        }
41    }
42
43    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
44        self.bearer_token = Some(token.into());
45        self
46    }
47
48    pub fn with_allowed_origins(mut self, origins: impl IntoIterator<Item = HeaderValue>) -> Self {
49        self.allowed_origins = origins.into_iter().collect();
50        self
51    }
52}
53
54/// Shared context for HTTP API handlers, exposing only narrow-capability traits.
55/// Constructed from a `CoreState` instance; handlers never import `CoreState` directly.
56pub struct HttpApiContext {
57    pub flows: Arc<dyn FlowReadService>,
58    pub events: Arc<dyn FlowEventHub>,
59    pub rules: Arc<dyn RuleService>,
60    pub intercepts: Arc<dyn InterceptService>,
61    pub audit: Arc<dyn AuditService>,
62    pub status: Arc<dyn RuntimeStatusService>,
63}
64
65impl HttpApiContext {
66    pub fn new(core: Arc<CoreState>) -> Self {
67        Self {
68            flows: core.clone(),
69            events: core.clone(),
70            rules: core.clone(),
71            intercepts: core.clone(),
72            audit: core.clone(),
73            status: core.clone(),
74        }
75    }
76}
77
78/// HTTP API server handle.
79pub struct HttpApiServer {
80    config: HttpApiConfig,
81    state: Arc<CoreState>,
82}
83
84impl HttpApiServer {
85    pub fn new(config: HttpApiConfig, state: Arc<CoreState>) -> Self {
86        Self { config, state }
87    }
88
89    /// Start the server; resolves when the server exits or an error occurs.
90    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
91        let ctx = Arc::new(HttpApiContext::new(self.state));
92        let app = build_router(ctx, Arc::new(self.config.clone()));
93        let listener = tokio::net::TcpListener::bind(self.config.addr).await?;
94        info!("relay-core HTTP API listening on {}", self.config.addr);
95        axum::serve(listener, app).await?;
96        Ok(())
97    }
98}
99
100fn build_router(ctx: Arc<HttpApiContext>, config: Arc<HttpApiConfig>) -> Router {
101    let router = Router::new()
102        .merge(routes::version::router())
103        .merge(routes::metrics::router(ctx.clone()))
104        .merge(routes::flows::router(ctx.clone()))
105        .merge(routes::rules::router(ctx.clone()))
106        .merge(routes::intercepts::router(ctx.clone()))
107        .merge(routes::events::router(ctx))
108        .route_layer(middleware::from_fn_with_state(config.clone(), require_bearer_token))
109        .layer(TraceLayer::new_for_http());
110
111    if config.allowed_origins.is_empty() {
112        router
113    } else {
114        router.layer(
115            CorsLayer::new()
116                .allow_origin(config.allowed_origins.clone())
117                .allow_methods([
118                    Method::GET,
119                    Method::POST,
120                    Method::PUT,
121                    Method::DELETE,
122                    Method::OPTIONS,
123                ])
124                .allow_headers([AUTHORIZATION, CONTENT_TYPE, ACCEPT, ORIGIN]),
125        )
126    }
127}
128
129async fn require_bearer_token(
130    State(config): State<Arc<HttpApiConfig>>,
131    request: Request<axum::body::Body>,
132    next: Next,
133) -> Response {
134    if request.method() == Method::OPTIONS {
135        return next.run(request).await;
136    }
137
138    let Some(expected_token) = config.bearer_token.as_deref() else {
139        return next.run(request).await;
140    };
141
142    let expected_value = format!("Bearer {}", expected_token);
143    let is_authorized = request
144        .headers()
145        .get(AUTHORIZATION)
146        .and_then(|value| value.to_str().ok())
147        .map(|value| value == expected_value)
148        .unwrap_or(false);
149
150    if is_authorized {
151        return next.run(request).await;
152    }
153
154    (
155        StatusCode::UNAUTHORIZED,
156        [
157            (WWW_AUTHENTICATE, HeaderValue::from_static("Bearer")),
158            (CONTENT_TYPE, HeaderValue::from_static("application/json")),
159        ],
160        serde_json::json!({
161            "error": "missing_or_invalid_bearer_token"
162        })
163        .to_string(),
164    )
165        .into_response()
166}
167
#[cfg(test)]
mod tests {
    use super::{HttpApiConfig, HttpApiContext, build_router};
    use axum::{
        body::{Body, to_bytes},
        http::{HeaderValue, Method, Request, StatusCode, header::ACCESS_CONTROL_ALLOW_ORIGIN},
    };
    use relay_core_api::flow::Flow;
    use relay_core_api::policy::ProxyPolicy;
    use relay_core_runtime::{CoreState, audit::AuditActor};
    use std::sync::Arc;
    use tokio::time::{Duration, sleep};
    use tower::ServiceExt;
    use serde_json::json;

    /// Creates a fresh shared `CoreState` for a single test.
    async fn fresh_state() -> Arc<CoreState> {
        Arc::new(CoreState::new(None).await)
    }

    /// Builds the API router over `state` using `config`.
    fn router_for(state: &Arc<CoreState>, config: HttpApiConfig) -> axum::Router {
        let ctx = Arc::new(HttpApiContext::new(state.clone()));
        build_router(ctx, Arc::new(config))
    }

    /// Builds a bodiless `GET` request for `uri` carrying the given headers.
    fn get_request(uri: &str, headers: &[(&str, &str)]) -> Request<Body> {
        let mut builder = Request::builder().uri(uri).method(Method::GET);
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder
            .body(Body::empty())
            .expect("request should build")
    }

    /// Sends one request through the router and returns the response.
    async fn send(app: axum::Router, request: Request<Body>) -> axum::response::Response {
        app.oneshot(request).await.expect("request should succeed")
    }

    /// Reads the whole response body and parses it as JSON.
    async fn read_json(response: axum::response::Response) -> serde_json::Value {
        let bytes = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        serde_json::from_slice(&bytes).expect("body should be valid json")
    }

    /// Deserializes a minimal HTTP flow whose id and timestamps are derived from `ts`.
    fn sample_http_flow(host: &str, path: &str, method: &str, status: u16, ts: i64) -> Flow {
        // The last UUID group is the timestamp, zero-padded to twelve digits.
        let id_suffix = (ts as u64) % 1_000_000_000_000;
        let flow_id = format!("00000000-0000-0000-0000-{id_suffix:012}");

        let minute = (ts / 60_000) % 60;
        let second = (ts / 1_000) % 60;
        let millis = (ts % 1_000).abs();
        let start_rfc3339 = format!("2023-11-14T22:{minute:02}:{second:02}.{millis:03}Z");

        let url = format!("http://{}{}", host, path);

        let flow = json!({
            "id": flow_id,
            "start_time": start_rfc3339,
            "end_time": start_rfc3339,
            "network": {
                "client_ip": "127.0.0.1",
                "client_port": 12000,
                "server_ip": "127.0.0.1",
                "server_port": 8080,
                "protocol": "TCP",
                "tls": false,
                "tls_version": null,
                "sni": null
            },
            "layer": {
                "type": "Http",
                "data": {
                    "request": {
                        "method": method,
                        "url": url,
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "query": [],
                        "body": null
                    },
                    "response": {
                        "status": status,
                        "status_text": "OK",
                        "version": "HTTP/1.1",
                        "headers": [],
                        "cookies": [],
                        "body": null,
                        "timing": {
                            "time_to_first_byte": null,
                            "time_to_last_byte": null
                        }
                    },
                    "error": null
                }
            },
            "tags": []
        });

        serde_json::from_value(flow).expect("flow json should deserialize")
    }

    #[tokio::test]
    async fn status_endpoint_is_available_without_auth_by_default() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = send(app, get_request("/api/v1/status", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["phase"], "created");
        assert_eq!(json["running"], false);
        assert!(json.get("started_at_ms").is_none());
    }

    #[tokio::test]
    async fn intercepts_endpoint_uses_shared_snapshot_shape() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = send(app, get_request("/api/v1/intercepts", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["pending_count"], 0);
        assert_eq!(json["ws_pending_count"], 0);
    }

    #[tokio::test]
    async fn audit_endpoint_uses_shared_snapshot_shape() {
        let state = fresh_state().await;

        // Record one policy update from the HTTP actor ...
        let policy = ProxyPolicy {
            transparent_enabled: true,
            ..Default::default()
        };
        state.update_policy_from(AuditActor::Http, "policy".to_string(), policy);

        // ... and one intercept resolution from a different actor, whose result is ignored.
        let _ = state
            .resolve_intercept_with_modifications_from(
                AuditActor::Probe,
                "missing-flow:request".to_string(),
                "drop",
                None,
            )
            .await;

        let app = router_for(&state, HttpApiConfig::new(8082));
        let uri = "/api/v1/audit?actor=http&kind=policy_updated&outcome=success&limit=1";
        let response = send(app, get_request(uri, &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert!(json["events"].is_array());
        assert_eq!(json["events"].as_array().map(|v| v.len()), Some(1));

        let event = &json["events"][0];
        assert_eq!(event["actor"], "http");
        assert_eq!(event["kind"], "policy_updated");
        assert_eq!(event["outcome"], "success");
    }

    #[tokio::test]
    async fn prometheus_metrics_endpoint_returns_text_format() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(8082));

        let response = send(app, get_request("/api/v1/metrics/prometheus", &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let content_type = response
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default();
        assert_eq!(content_type, "text/plain; version=0.0.4; charset=utf-8");

        let body = to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("body should be readable");
        let text = String::from_utf8(body.to_vec()).expect("prometheus body should be utf-8");
        assert!(text.contains("relay_core_flows_total "));
        assert!(text.contains("relay_core_audit_events_total "));
    }

    #[tokio::test]
    async fn flows_endpoint_returns_pagination_metadata() {
        let state = fresh_state().await;
        let flow_a = sample_http_flow("api.example.com", "/a", "GET", 200, 1_700_000_001_000);
        let flow_b = sample_http_flow("api.example.com", "/b", "POST", 201, 1_700_000_002_000);
        let flow_c = sample_http_flow("api.example.com", "/c", "GET", 500, 1_700_000_003_000);
        let flow_b_id = flow_b.id.to_string();
        for flow in [flow_a, flow_b, flow_c] {
            state.upsert_flow(Box::new(flow));
        }
        // Short pause before querying, presumably to let the upserts land — TODO confirm.
        sleep(Duration::from_millis(30)).await;

        let app = router_for(&state, HttpApiConfig::new(8082));
        let uri = "/api/v1/flows?host=api.example.com&limit=1&offset=1";
        let response = send(app, get_request(uri, &[])).await;

        assert_eq!(response.status(), StatusCode::OK);
        let json = read_json(response).await;
        assert_eq!(json["returned"], 1);
        assert_eq!(json["limit"], 1);
        assert_eq!(json["offset"], 1);
        assert_eq!(json["items"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(json["items"][0]["id"], flow_b_id);
    }

    #[tokio::test]
    async fn status_endpoint_requires_bearer_token_when_configured() {
        let state = fresh_state().await;
        let app = router_for(
            &state,
            HttpApiConfig::new(8082).with_bearer_token("secret-token"),
        );

        let unauthorized = send(app.clone(), get_request("/api/v1/status", &[])).await;
        assert_eq!(unauthorized.status(), StatusCode::UNAUTHORIZED);

        let with_token = get_request(
            "/api/v1/status",
            &[("Authorization", "Bearer secret-token")],
        );
        let authorized = send(app, with_token).await;
        assert_eq!(authorized.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn cors_is_not_open_by_default() {
        let state = fresh_state().await;
        let app = router_for(&state, HttpApiConfig::new(8082));

        let request = get_request("/api/v1/status", &[("Origin", "https://example.com")]);
        let response = send(app, request).await;

        assert!(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }

    #[tokio::test]
    async fn cors_allows_explicit_origin_only() {
        let state = fresh_state().await;
        let config = HttpApiConfig::new(8082)
            .with_allowed_origins([HeaderValue::from_static("https://allowed.example")]);
        let app = router_for(&state, config);

        let from_allowed = get_request(
            "/api/v1/status",
            &[("Origin", "https://allowed.example")],
        );
        let allowed = send(app.clone(), from_allowed).await;
        assert_eq!(
            allowed.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN),
            Some(&HeaderValue::from_static("https://allowed.example"))
        );

        let from_denied = get_request(
            "/api/v1/status",
            &[("Origin", "https://denied.example")],
        );
        let denied = send(app, from_denied).await;
        assert!(denied.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN).is_none());
    }
}