Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    Json, Router,
6    error_handling::HandleErrorLayer,
7    extract::DefaultBodyLimit,
8    http::StatusCode,
9    middleware,
10    response::IntoResponse,
11    routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_batch_handler, rpc_function_handler, rpc_handler};
33use super::sse::{
34    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Upper bound on JSON request bodies: 1 MiB.
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Upper bound on multipart upload bodies: 20 MiB.
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// Maximum number of multipart upload requests handled at the same time.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
46
47/// Gateway server configuration.
48#[derive(Debug, Clone)]
49pub struct GatewayConfig {
50    /// Port to listen on.
51    pub port: u16,
52    /// Maximum number of connections.
53    pub max_connections: usize,
54    /// Maximum number of active SSE sessions.
55    pub sse_max_sessions: usize,
56    /// Request timeout in seconds.
57    pub request_timeout_secs: u64,
58    /// Enable CORS.
59    pub cors_enabled: bool,
60    /// Allowed CORS origins.
61    pub cors_origins: Vec<String>,
62    /// Authentication configuration.
63    pub auth: AuthConfig,
64    /// MCP configuration.
65    pub mcp: McpConfig,
66    /// Routes excluded from request logs, metrics, and traces.
67    pub quiet_routes: Vec<String>,
68}
69
70impl Default for GatewayConfig {
71    fn default() -> Self {
72        Self {
73            port: 8080,
74            max_connections: 512,
75            sse_max_sessions: 10_000,
76            request_timeout_secs: 30,
77            cors_enabled: false,
78            cors_origins: Vec::new(),
79            auth: AuthConfig::default(),
80            mcp: McpConfig::default(),
81            quiet_routes: Vec::new(),
82        }
83    }
84}
85
86/// Health check response.
87#[derive(Debug, Serialize)]
88pub struct HealthResponse {
89    pub status: String,
90    pub version: String,
91}
92
93/// Readiness check response.
94#[derive(Debug, Serialize)]
95pub struct ReadinessResponse {
96    pub ready: bool,
97    pub database: bool,
98    pub reactor: bool,
99    pub version: String,
100}
101
102/// State for readiness check.
103#[derive(Clone)]
104pub struct ReadinessState {
105    db_pool: sqlx::PgPool,
106    reactor: Arc<Reactor>,
107}
108
109/// Gateway HTTP server.
110pub struct GatewayServer {
111    config: GatewayConfig,
112    registry: FunctionRegistry,
113    db: Database,
114    reactor: Arc<Reactor>,
115    job_dispatcher: Option<Arc<dyn JobDispatch>>,
116    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
117    mcp_registry: Option<McpToolRegistry>,
118}
119
120impl GatewayServer {
121    /// Create a new gateway server.
122    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
123        let node_id = NodeId::new();
124        let reactor = Arc::new(Reactor::new(
125            node_id,
126            db.primary().clone(),
127            registry.clone(),
128            ReactorConfig::default(),
129        ));
130
131        Self {
132            config,
133            registry,
134            db,
135            reactor,
136            job_dispatcher: None,
137            workflow_dispatcher: None,
138            mcp_registry: None,
139        }
140    }
141
142    /// Set the job dispatcher.
143    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
144        self.job_dispatcher = Some(dispatcher);
145        self
146    }
147
148    /// Set the workflow dispatcher.
149    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
150        self.workflow_dispatcher = Some(dispatcher);
151        self
152    }
153
154    /// Set the MCP tool registry.
155    pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
156        self.mcp_registry = Some(registry);
157        self
158    }
159
160    /// Get a reference to the reactor.
161    pub fn reactor(&self) -> Arc<Reactor> {
162        self.reactor.clone()
163    }
164
165    /// Build the Axum router.
166    pub fn router(&self) -> Router {
167        let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
168            .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
169
170        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch_and_issuer(
171            self.registry.clone(),
172            self.db.clone(),
173            self.job_dispatcher.clone(),
174            self.workflow_dispatcher.clone(),
175            token_issuer,
176        ));
177
178        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
179
180        // Build CORS layer
181        let cors = if self.config.cors_enabled {
182            if self.config.cors_origins.iter().any(|o| o == "*") {
183                CorsLayer::new()
184                    .allow_origin(Any)
185                    .allow_methods(Any)
186                    .allow_headers(Any)
187            } else {
188                let origins: Vec<_> = self
189                    .config
190                    .cors_origins
191                    .iter()
192                    .filter_map(|o| o.parse().ok())
193                    .collect();
194                CorsLayer::new()
195                    .allow_origin(origins)
196                    .allow_methods(Any)
197                    .allow_headers(Any)
198            }
199        } else {
200            CorsLayer::new()
201        };
202
203        // SSE state for Server-Sent Events
204        let sse_state = Arc::new(SseState::with_config(
205            self.reactor.clone(),
206            auth_middleware_state.clone(),
207            super::sse::SseConfig {
208                max_sessions: self.config.sse_max_sessions,
209                ..Default::default()
210            },
211        ));
212
213        // Readiness state for DB + reactor health check
214        let readiness_state = Arc::new(ReadinessState {
215            db_pool: self.db.primary().clone(),
216            reactor: self.reactor.clone(),
217        });
218
219        // Build the main router with middleware
220        let mut main_router = Router::new()
221            // Health check endpoint (liveness)
222            .route("/health", get(health_handler))
223            // Readiness check endpoint (checks DB)
224            .route("/ready", get(readiness_handler).with_state(readiness_state))
225            // RPC endpoint
226            .route("/rpc", post(rpc_handler))
227            // Batch RPC endpoint
228            .route("/rpc/batch", post(rpc_batch_handler))
229            // REST-style function endpoint (JSON)
230            .route("/rpc/{function}", post(rpc_function_handler))
231            // Prevent oversized JSON payloads from exhausting memory.
232            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
233            // Add state
234            .with_state(rpc_handler_state.clone());
235
236        // Multipart RPC router (separate state needed for multipart)
237        let multipart_router = Router::new()
238            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
239            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
240            // Cap upload fan-out; each request buffers data in memory.
241            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
242            .with_state(rpc_handler_state);
243
244        // SSE router
245        let sse_router = Router::new()
246            .route("/events", get(sse_handler))
247            .route("/subscribe", post(sse_subscribe_handler))
248            .route("/unsubscribe", post(sse_unsubscribe_handler))
249            .route("/subscribe-job", post(sse_job_subscribe_handler))
250            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
251            .with_state(sse_state);
252
253        let mut mcp_router = Router::new();
254        if self.config.mcp.enabled {
255            let path = self.config.mcp.path.clone();
256            let mcp_state = Arc::new(McpState::new(
257                self.config.mcp.clone(),
258                self.mcp_registry.clone().unwrap_or_default(),
259                self.db.primary().clone(),
260                self.job_dispatcher.clone(),
261                self.workflow_dispatcher.clone(),
262            ));
263            mcp_router = mcp_router.route(
264                &path,
265                post(mcp_post_handler)
266                    .get(mcp_get_handler)
267                    .with_state(mcp_state),
268            );
269        }
270
271        main_router = main_router
272            .merge(multipart_router)
273            .merge(sse_router)
274            .merge(mcp_router);
275
276        // Build middleware stack
277        let service_builder = ServiceBuilder::new()
278            .layer(HandleErrorLayer::new(handle_middleware_error))
279            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
280            .layer(TimeoutLayer::new(Duration::from_secs(
281                self.config.request_timeout_secs,
282            )))
283            .layer(cors.clone())
284            .layer(middleware::from_fn_with_state(
285                auth_middleware_state,
286                auth_middleware,
287            ))
288            .layer(middleware::from_fn_with_state(
289                Arc::new(self.config.quiet_routes.clone()),
290                tracing_middleware,
291            ));
292
293        // Apply the remaining middleware layers
294        main_router.layer(service_builder)
295    }
296
297    /// Get the socket address to bind to.
298    pub fn addr(&self) -> std::net::SocketAddr {
299        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
300    }
301
302    /// Run the server (blocking).
303    pub async fn run(self) -> Result<(), std::io::Error> {
304        let addr = self.addr();
305        let router = self.router();
306
307        // Start the reactor for real-time updates
308        self.reactor
309            .start()
310            .await
311            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
312        tracing::info!("Reactor started for real-time updates");
313
314        tracing::info!("Gateway server listening on {}", addr);
315
316        let listener = tokio::net::TcpListener::bind(addr).await?;
317        axum::serve(listener, router.into_make_service()).await
318    }
319}
320
321/// Health check handler (liveness probe).
322async fn health_handler() -> Json<HealthResponse> {
323    Json(HealthResponse {
324        status: "healthy".to_string(),
325        version: env!("CARGO_PKG_VERSION").to_string(),
326    })
327}
328
329/// Readiness check handler (readiness probe).
330async fn readiness_handler(
331    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
332) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
333    // Check database connectivity
334    let db_ok = sqlx::query("SELECT 1")
335        .fetch_one(&state.db_pool)
336        .await
337        .is_ok();
338
339    // Check reactor health (change listener must be running for real-time updates)
340    let reactor_stats = state.reactor.stats().await;
341    let reactor_ok = reactor_stats.listener_running;
342
343    let ready = db_ok && reactor_ok;
344    let status_code = if ready {
345        axum::http::StatusCode::OK
346    } else {
347        axum::http::StatusCode::SERVICE_UNAVAILABLE
348    };
349
350    (
351        status_code,
352        Json(ReadinessResponse {
353            ready,
354            database: db_ok,
355            reactor: reactor_ok,
356            version: env!("CARGO_PKG_VERSION").to_string(),
357        }),
358    )
359}
360
361async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
362    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
363        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
364    } else {
365        (
366            StatusCode::SERVICE_UNAVAILABLE,
367            "SERVICE_UNAVAILABLE",
368            "Server overloaded",
369        )
370    };
371    (
372        status,
373        Json(RpcResponse::error(RpcError::new(code, message))),
374    )
375        .into_response()
376}
377
378fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
379    if let Ok(val) = trace_id.parse() {
380        response.headers_mut().insert(TRACE_ID_HEADER, val);
381    }
382    if let Ok(val) = request_id.parse() {
383        response.headers_mut().insert(REQUEST_ID_HEADER, val);
384    }
385}
386
387/// Extracts W3C traceparent context from HTTP headers.
388struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
389
390impl<'a> Extractor for HeaderExtractor<'a> {
391    fn get(&self, key: &str) -> Option<&str> {
392        self.0.get(key).and_then(|v| v.to_str().ok())
393    }
394
395    fn keys(&self) -> Vec<&str> {
396        self.0.keys().map(|k| k.as_str()).collect()
397    }
398}
399
400/// Wraps each request in a span with HTTP semantics and OpenTelemetry
401/// context propagation. Incoming `traceparent` headers are extracted so
402/// that spans join the caller's distributed trace.
403/// Quiet routes skip spans, logs, and metrics to avoid noise from
404/// probes or high-frequency internal endpoints.
405async fn tracing_middleware(
406    axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
407    req: axum::extract::Request,
408    next: axum::middleware::Next,
409) -> axum::response::Response {
410    let headers = req.headers();
411
412    // Extract W3C traceparent from incoming headers for distributed tracing
413    let parent_cx =
414        global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
415
416    let trace_id = headers
417        .get(TRACE_ID_HEADER)
418        .and_then(|v| v.to_str().ok())
419        .map(String::from)
420        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
421
422    let parent_span_id = headers
423        .get(SPAN_ID_HEADER)
424        .and_then(|v| v.to_str().ok())
425        .map(String::from);
426
427    let method = req.method().to_string();
428    let path = req.uri().path().to_string();
429
430    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
431    if let Some(span_id) = parent_span_id {
432        tracing_state = tracing_state.with_parent_span(span_id);
433    }
434
435    let mut req = req;
436    req.extensions_mut().insert(tracing_state.clone());
437
438    if req
439        .extensions()
440        .get::<forge_core::function::AuthContext>()
441        .is_none()
442    {
443        req.extensions_mut()
444            .insert(forge_core::function::AuthContext::unauthenticated());
445    }
446
447    // Config uses full paths (/_api/health) but axum strips the prefix
448    // for nested routers, so the middleware sees /health not /_api/health.
449    let full_path = format!("/_api{}", path);
450    let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
451
452    if is_quiet {
453        let mut response = next.run(req).await;
454        set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
455        return response;
456    }
457
458    let span = tracing::info_span!(
459        "http.request",
460        http.method = %method,
461        http.route = %path,
462        http.status_code = tracing::field::Empty,
463        trace_id = %trace_id,
464        request_id = %tracing_state.request_id,
465    );
466
467    // Link this span to the incoming distributed trace context so
468    // fn.execute and all downstream spans share the caller's trace ID
469    span.set_parent(parent_cx);
470
471    let mut response = next.run(req).instrument(span.clone()).await;
472
473    let status = response.status().as_u16();
474    let elapsed = tracing_state.elapsed();
475
476    span.record("http.status_code", status);
477    let duration_ms = elapsed.as_millis() as u64;
478    match status {
479        500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
480        400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
481        200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
482        _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
483    }
484    crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
485
486    set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
487    response
488}
489
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Every scalar default is pinned so accidental changes are caught.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 512);
        assert_eq!(config.sse_max_sessions, 10_000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(!config.cors_enabled);
        assert!(config.cors_origins.is_empty());
        assert!(config.quiet_routes.is_empty());
    }

    /// The liveness body has exactly the `status` and `version` keys.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "status": "healthy", "version": "0.1.0" })
        );
    }

    /// The readiness body reports each check individually.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: true,
            reactor: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "ready": false,
                "database": true,
                "reactor": false,
                "version": "0.1.0"
            })
        );
    }

    /// The propagator adapter exposes present headers and reports missing ones as `None`.
    #[test]
    fn test_header_extractor() {
        let traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01";
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "traceparent",
            axum::http::HeaderValue::from_static(traceparent),
        );
        let extractor = HeaderExtractor(&headers);
        assert_eq!(extractor.get("traceparent"), Some(traceparent));
        assert_eq!(extractor.get("missing"), None);
        assert_eq!(extractor.keys(), vec!["traceparent"]);
    }
}