Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    Json, Router,
6    error_handling::HandleErrorLayer,
7    extract::DefaultBodyLimit,
8    http::StatusCode,
9    middleware,
10    response::IntoResponse,
11    routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_batch_handler, rpc_function_handler, rpc_handler};
33use super::sse::{
34    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Body-size cap, in bytes, applied to the JSON RPC routes (1 MiB).
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Body-size cap, in bytes, applied to the multipart upload route (20 MiB).
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// Number of multipart upload requests allowed to run at the same time.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
46
47/// Gateway server configuration.
48#[derive(Debug, Clone)]
49pub struct GatewayConfig {
50    /// Port to listen on.
51    pub port: u16,
52    /// Maximum number of connections.
53    pub max_connections: usize,
54    /// Maximum number of active SSE sessions.
55    pub sse_max_sessions: usize,
56    /// Request timeout in seconds.
57    pub request_timeout_secs: u64,
58    /// Enable CORS.
59    pub cors_enabled: bool,
60    /// Allowed CORS origins.
61    pub cors_origins: Vec<String>,
62    /// Authentication configuration.
63    pub auth: AuthConfig,
64    /// MCP configuration.
65    pub mcp: McpConfig,
66    /// Routes excluded from request logs, metrics, and traces.
67    pub quiet_routes: Vec<String>,
68    /// Token TTL configuration for refresh token management.
69    pub token_ttl: forge_core::AuthTokenTtl,
70    /// Project name (displayed on OAuth consent page).
71    pub project_name: String,
72}
73
74impl Default for GatewayConfig {
75    fn default() -> Self {
76        Self {
77            port: 9081,
78            max_connections: 512,
79            sse_max_sessions: 10_000,
80            request_timeout_secs: 30,
81            cors_enabled: false,
82            cors_origins: Vec::new(),
83            auth: AuthConfig::default(),
84            mcp: McpConfig::default(),
85            quiet_routes: Vec::new(),
86            token_ttl: forge_core::AuthTokenTtl::default(),
87            project_name: "forge-app".to_string(),
88        }
89    }
90}
91
92/// Health check response.
93#[derive(Debug, Serialize)]
94pub struct HealthResponse {
95    pub status: String,
96    pub version: String,
97}
98
99/// Readiness check response.
100#[derive(Debug, Serialize)]
101pub struct ReadinessResponse {
102    pub ready: bool,
103    pub database: bool,
104    pub reactor: bool,
105    pub version: String,
106}
107
108/// State for readiness check.
109#[derive(Clone)]
110pub struct ReadinessState {
111    db_pool: sqlx::PgPool,
112    reactor: Arc<Reactor>,
113}
114
115/// Gateway HTTP server.
116pub struct GatewayServer {
117    config: GatewayConfig,
118    registry: FunctionRegistry,
119    db: Database,
120    reactor: Arc<Reactor>,
121    job_dispatcher: Option<Arc<dyn JobDispatch>>,
122    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
123    mcp_registry: Option<McpToolRegistry>,
124    token_ttl: forge_core::AuthTokenTtl,
125}
126
127impl GatewayServer {
128    /// Create a new gateway server.
129    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
130        let node_id = NodeId::new();
131        let reactor = Arc::new(Reactor::new(
132            node_id,
133            db.primary().clone(),
134            registry.clone(),
135            ReactorConfig::default(),
136        ));
137
138        let token_ttl = config.token_ttl.clone();
139        Self {
140            config,
141            registry,
142            db,
143            reactor,
144            job_dispatcher: None,
145            workflow_dispatcher: None,
146            mcp_registry: None,
147            token_ttl,
148        }
149    }
150
151    /// Set the job dispatcher.
152    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
153        self.job_dispatcher = Some(dispatcher);
154        self
155    }
156
157    /// Set the workflow dispatcher.
158    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
159        self.workflow_dispatcher = Some(dispatcher);
160        self
161    }
162
163    /// Set the MCP tool registry.
164    pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
165        self.mcp_registry = Some(registry);
166        self
167    }
168
169    /// Get a reference to the reactor.
170    pub fn reactor(&self) -> Arc<Reactor> {
171        self.reactor.clone()
172    }
173
174    /// Build an OAuth router (bypasses auth middleware). Returns None if OAuth is disabled.
175    pub fn oauth_router(&self) -> Option<(Router, Arc<super::oauth::OAuthState>)> {
176        if !self.config.mcp.oauth {
177            return None;
178        }
179
180        let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
181            .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>)?;
182
183        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
184
185        let jwt_secret = self.config.auth.jwt_secret.clone().unwrap_or_default();
186
187        let oauth_state = Arc::new(super::oauth::OAuthState::new(
188            self.db.primary().clone(),
189            auth_middleware_state,
190            token_issuer,
191            self.token_ttl.access_token_secs,
192            self.token_ttl.refresh_token_days,
193            self.config.auth.is_hmac(),
194            self.config.project_name.clone(),
195            jwt_secret,
196        ));
197
198        let router = Router::new()
199            .route(
200                "/oauth/authorize",
201                get(super::oauth::oauth_authorize_get).post(super::oauth::oauth_authorize_post),
202            )
203            .route("/oauth/token", post(super::oauth::oauth_token))
204            .route("/oauth/register", post(super::oauth::oauth_register))
205            .with_state(oauth_state.clone());
206
207        Some((router, oauth_state))
208    }
209
210    /// Build the Axum router.
211    pub fn router(&self) -> Router {
212        let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
213            .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
214
215        let mut rpc = RpcHandler::with_dispatch_and_issuer(
216            self.registry.clone(),
217            self.db.clone(),
218            self.job_dispatcher.clone(),
219            self.workflow_dispatcher.clone(),
220            token_issuer,
221        );
222        rpc.set_token_ttl(self.token_ttl.clone());
223        let rpc_handler_state = Arc::new(rpc);
224
225        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
226
227        // Build CORS layer. When MCP OAuth is enabled and specific origins
228        // are configured, allow credentials so the browser accepts Set-Cookie
229        // on cross-origin API responses (needed for forge_session cookie).
230        let cors = if self.config.cors_enabled {
231            if self.config.cors_origins.iter().any(|o| o == "*") {
232                // Wildcard origin can't use credentials
233                CorsLayer::new()
234                    .allow_origin(Any)
235                    .allow_methods(Any)
236                    .allow_headers(Any)
237            } else {
238                let origins: Vec<_> = self
239                    .config
240                    .cors_origins
241                    .iter()
242                    .filter_map(|o| o.parse().ok())
243                    .collect();
244                use axum::http::Method;
245                CorsLayer::new()
246                    .allow_origin(origins)
247                    .allow_methods([
248                        Method::GET,
249                        Method::POST,
250                        Method::PUT,
251                        Method::DELETE,
252                        Method::PATCH,
253                        Method::OPTIONS,
254                    ])
255                    .allow_headers([
256                        axum::http::header::CONTENT_TYPE,
257                        axum::http::header::AUTHORIZATION,
258                        axum::http::header::ACCEPT,
259                    ])
260                    .allow_credentials(true)
261            }
262        } else {
263            CorsLayer::new()
264        };
265
266        // SSE state for Server-Sent Events
267        let sse_state = Arc::new(SseState::with_config(
268            self.reactor.clone(),
269            auth_middleware_state.clone(),
270            super::sse::SseConfig {
271                max_sessions: self.config.sse_max_sessions,
272                ..Default::default()
273            },
274        ));
275
276        // Readiness state for DB + reactor health check
277        let readiness_state = Arc::new(ReadinessState {
278            db_pool: self.db.primary().clone(),
279            reactor: self.reactor.clone(),
280        });
281
282        // Build the main router with middleware
283        let mut main_router = Router::new()
284            // Health check endpoint (liveness)
285            .route("/health", get(health_handler))
286            // Readiness check endpoint (checks DB)
287            .route("/ready", get(readiness_handler).with_state(readiness_state))
288            // RPC endpoint
289            .route("/rpc", post(rpc_handler))
290            // Batch RPC endpoint
291            .route("/rpc/batch", post(rpc_batch_handler))
292            // REST-style function endpoint (JSON)
293            .route("/rpc/{function}", post(rpc_function_handler))
294            // Prevent oversized JSON payloads from exhausting memory.
295            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
296            // Add state
297            .with_state(rpc_handler_state.clone());
298
299        // Multipart RPC router (separate state needed for multipart)
300        let multipart_router = Router::new()
301            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
302            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
303            // Cap upload fan-out; each request buffers data in memory.
304            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
305            .with_state(rpc_handler_state);
306
307        // SSE router
308        let sse_router = Router::new()
309            .route("/events", get(sse_handler))
310            .route("/subscribe", post(sse_subscribe_handler))
311            .route("/unsubscribe", post(sse_unsubscribe_handler))
312            .route("/subscribe-job", post(sse_job_subscribe_handler))
313            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
314            .with_state(sse_state);
315
316        let mut mcp_router = Router::new();
317        if self.config.mcp.enabled {
318            let path = self.config.mcp.path.clone();
319            let mcp_state = Arc::new(McpState::new(
320                self.config.mcp.clone(),
321                self.mcp_registry.clone().unwrap_or_default(),
322                self.db.primary().clone(),
323                self.job_dispatcher.clone(),
324                self.workflow_dispatcher.clone(),
325            ));
326            mcp_router = mcp_router.route(
327                &path,
328                post(mcp_post_handler)
329                    .get(mcp_get_handler)
330                    .with_state(mcp_state),
331            );
332        }
333
334        main_router = main_router
335            .merge(multipart_router)
336            .merge(sse_router)
337            .merge(mcp_router);
338
339        // Build middleware stack
340        let service_builder = ServiceBuilder::new()
341            .layer(HandleErrorLayer::new(handle_middleware_error))
342            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
343            .layer(TimeoutLayer::new(Duration::from_secs(
344                self.config.request_timeout_secs,
345            )))
346            .layer(cors.clone())
347            .layer(middleware::from_fn_with_state(
348                auth_middleware_state,
349                auth_middleware,
350            ))
351            .layer(middleware::from_fn_with_state(
352                Arc::new(self.config.quiet_routes.clone()),
353                tracing_middleware,
354            ));
355
356        // Apply the remaining middleware layers
357        main_router.layer(service_builder)
358    }
359
360    /// Get the socket address to bind to.
361    pub fn addr(&self) -> std::net::SocketAddr {
362        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
363    }
364
365    /// Run the server (blocking).
366    pub async fn run(self) -> Result<(), std::io::Error> {
367        let addr = self.addr();
368        let router = self.router();
369
370        // Start the reactor for real-time updates
371        self.reactor
372            .start()
373            .await
374            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
375        tracing::info!("Reactor started for real-time updates");
376
377        tracing::info!("Gateway server listening on {}", addr);
378
379        let listener = tokio::net::TcpListener::bind(addr).await?;
380        axum::serve(listener, router.into_make_service()).await
381    }
382}
383
384/// Health check handler (liveness probe).
385async fn health_handler() -> Json<HealthResponse> {
386    Json(HealthResponse {
387        status: "healthy".to_string(),
388        version: env!("CARGO_PKG_VERSION").to_string(),
389    })
390}
391
392/// Readiness check handler (readiness probe).
393async fn readiness_handler(
394    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
395) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
396    // Check database connectivity
397    let db_ok = sqlx::query("SELECT 1")
398        .fetch_one(&state.db_pool)
399        .await
400        .is_ok();
401
402    // Check reactor health (change listener must be running for real-time updates)
403    let reactor_stats = state.reactor.stats().await;
404    let reactor_ok = reactor_stats.listener_running;
405
406    let ready = db_ok && reactor_ok;
407    let status_code = if ready {
408        axum::http::StatusCode::OK
409    } else {
410        axum::http::StatusCode::SERVICE_UNAVAILABLE
411    };
412
413    (
414        status_code,
415        Json(ReadinessResponse {
416            ready,
417            database: db_ok,
418            reactor: reactor_ok,
419            version: env!("CARGO_PKG_VERSION").to_string(),
420        }),
421    )
422}
423
424async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
425    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
426        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
427    } else {
428        (
429            StatusCode::SERVICE_UNAVAILABLE,
430            "SERVICE_UNAVAILABLE",
431            "Server overloaded",
432        )
433    };
434    (
435        status,
436        Json(RpcResponse::error(RpcError::new(code, message))),
437    )
438        .into_response()
439}
440
441fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
442    if let Ok(val) = trace_id.parse() {
443        response.headers_mut().insert(TRACE_ID_HEADER, val);
444    }
445    if let Ok(val) = request_id.parse() {
446        response.headers_mut().insert(REQUEST_ID_HEADER, val);
447    }
448}
449
450/// Extracts W3C traceparent context from HTTP headers.
451struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
452
453impl<'a> Extractor for HeaderExtractor<'a> {
454    fn get(&self, key: &str) -> Option<&str> {
455        self.0.get(key).and_then(|v| v.to_str().ok())
456    }
457
458    fn keys(&self) -> Vec<&str> {
459        self.0.keys().map(|k| k.as_str()).collect()
460    }
461}
462
463/// Wraps each request in a span with HTTP semantics and OpenTelemetry
464/// context propagation. Incoming `traceparent` headers are extracted so
465/// that spans join the caller's distributed trace.
466/// Quiet routes skip spans, logs, and metrics to avoid noise from
467/// probes or high-frequency internal endpoints.
468async fn tracing_middleware(
469    axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
470    req: axum::extract::Request,
471    next: axum::middleware::Next,
472) -> axum::response::Response {
473    let headers = req.headers();
474
475    // Extract W3C traceparent from incoming headers for distributed tracing
476    let parent_cx =
477        global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
478
479    let trace_id = headers
480        .get(TRACE_ID_HEADER)
481        .and_then(|v| v.to_str().ok())
482        .map(String::from)
483        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
484
485    let parent_span_id = headers
486        .get(SPAN_ID_HEADER)
487        .and_then(|v| v.to_str().ok())
488        .map(String::from);
489
490    let method = req.method().to_string();
491    let path = req.uri().path().to_string();
492
493    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
494    if let Some(span_id) = parent_span_id {
495        tracing_state = tracing_state.with_parent_span(span_id);
496    }
497
498    let mut req = req;
499    req.extensions_mut().insert(tracing_state.clone());
500
501    if req
502        .extensions()
503        .get::<forge_core::function::AuthContext>()
504        .is_none()
505    {
506        req.extensions_mut()
507            .insert(forge_core::function::AuthContext::unauthenticated());
508    }
509
510    // Config uses full paths (/_api/health) but axum strips the prefix
511    // for nested routers, so the middleware sees /health not /_api/health.
512    let full_path = format!("/_api{}", path);
513    let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
514
515    if is_quiet {
516        let mut response = next.run(req).await;
517        set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
518        return response;
519    }
520
521    let span = tracing::info_span!(
522        "http.request",
523        http.method = %method,
524        http.route = %path,
525        http.status_code = tracing::field::Empty,
526        trace_id = %trace_id,
527        request_id = %tracing_state.request_id,
528    );
529
530    // Link this span to the incoming distributed trace context so
531    // fn.execute and all downstream spans share the caller's trace ID
532    span.set_parent(parent_cx);
533
534    let mut response = next.run(req).instrument(span.clone()).await;
535
536    let status = response.status().as_u16();
537    let elapsed = tracing_state.elapsed();
538
539    span.record("http.status_code", status);
540    let duration_ms = elapsed.as_millis() as u64;
541    match status {
542        500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
543        400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
544        200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
545        _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
546    }
547    crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
548
549    set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
550    response
551}
552
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The default configuration keeps its documented port, concurrency
    /// limit, and disabled CORS.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 9081);
        assert_eq!(max_connections, 512);
        assert!(!cors_enabled);
    }

    /// A health response serializes to JSON that contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let json = serde_json::to_string(&HealthResponse {
            status: String::from("healthy"),
            version: String::from("0.1.0"),
        })
        .unwrap();

        assert!(json.contains("healthy"));
    }
}