Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    Json, Router,
6    error_handling::HandleErrorLayer,
7    extract::DefaultBodyLimit,
8    http::StatusCode,
9    middleware,
10    response::IntoResponse,
11    routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_batch_handler, rpc_function_handler, rpc_handler};
33use super::sse::{
34    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Largest accepted JSON request body on the main router (1 MiB).
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Largest accepted multipart upload body (20 MiB).
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// How many multipart uploads may be processed at the same time.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
/// Fallback for visitor ID hashing when no JWT secret is configured (dev only).
const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret";
48
49/// Gateway server configuration.
50#[derive(Debug, Clone)]
51pub struct GatewayConfig {
52    /// Port to listen on.
53    pub port: u16,
54    /// Maximum number of connections.
55    pub max_connections: usize,
56    /// Maximum number of active SSE sessions.
57    pub sse_max_sessions: usize,
58    /// Request timeout in seconds.
59    pub request_timeout_secs: u64,
60    /// Enable CORS.
61    pub cors_enabled: bool,
62    /// Allowed CORS origins.
63    pub cors_origins: Vec<String>,
64    /// Authentication configuration.
65    pub auth: AuthConfig,
66    /// MCP configuration.
67    pub mcp: McpConfig,
68    /// Routes excluded from request logs, metrics, and traces.
69    pub quiet_routes: Vec<String>,
70    /// Token TTL configuration for refresh token management.
71    pub token_ttl: forge_core::AuthTokenTtl,
72    /// Project name (displayed on OAuth consent page).
73    pub project_name: String,
74}
75
76impl Default for GatewayConfig {
77    fn default() -> Self {
78        Self {
79            port: 9081,
80            max_connections: 512,
81            sse_max_sessions: 10_000,
82            request_timeout_secs: 30,
83            cors_enabled: false,
84            cors_origins: Vec::new(),
85            auth: AuthConfig::default(),
86            mcp: McpConfig::default(),
87            quiet_routes: Vec::new(),
88            token_ttl: forge_core::AuthTokenTtl::default(),
89            project_name: "forge-app".to_string(),
90        }
91    }
92}
93
94/// Health check response.
95#[derive(Debug, Serialize)]
96pub struct HealthResponse {
97    pub status: String,
98    pub version: String,
99}
100
101/// Readiness check response.
102#[derive(Debug, Serialize)]
103pub struct ReadinessResponse {
104    pub ready: bool,
105    pub database: bool,
106    pub reactor: bool,
107    pub workflows: bool,
108    #[serde(skip_serializing_if = "Option::is_none")]
109    pub blocked_workflow_runs: Option<i64>,
110    pub version: String,
111}
112
113/// State for readiness check.
114#[derive(Clone)]
115pub struct ReadinessState {
116    db_pool: sqlx::PgPool,
117    reactor: Arc<Reactor>,
118}
119
120/// Gateway HTTP server.
121pub struct GatewayServer {
122    config: GatewayConfig,
123    registry: FunctionRegistry,
124    db: Database,
125    reactor: Arc<Reactor>,
126    job_dispatcher: Option<Arc<dyn JobDispatch>>,
127    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
128    mcp_registry: Option<McpToolRegistry>,
129    token_ttl: forge_core::AuthTokenTtl,
130    signals_collector: Option<crate::signals::SignalsCollector>,
131}
132
133impl GatewayServer {
134    /// Create a new gateway server.
135    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
136        let node_id = NodeId::new();
137        let reactor = Arc::new(Reactor::new(
138            node_id,
139            db.primary().clone(),
140            registry.clone(),
141            ReactorConfig::default(),
142        ));
143
144        let token_ttl = config.token_ttl.clone();
145        Self {
146            config,
147            registry,
148            db,
149            reactor,
150            job_dispatcher: None,
151            workflow_dispatcher: None,
152            mcp_registry: None,
153            token_ttl,
154            signals_collector: None,
155        }
156    }
157
158    /// Set the job dispatcher.
159    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
160        self.job_dispatcher = Some(dispatcher);
161        self
162    }
163
164    /// Set the workflow dispatcher.
165    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
166        self.workflow_dispatcher = Some(dispatcher);
167        self
168    }
169
170    /// Set the MCP tool registry.
171    pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
172        self.mcp_registry = Some(registry);
173        self
174    }
175
176    /// Set the signals collector for auto-capturing RPC events and
177    /// registering client signal ingestion endpoints.
178    pub fn with_signals_collector(mut self, collector: crate::signals::SignalsCollector) -> Self {
179        self.signals_collector = Some(collector);
180        self
181    }
182
183    /// Get a reference to the reactor.
184    pub fn reactor(&self) -> Arc<Reactor> {
185        self.reactor.clone()
186    }
187
188    /// Build an OAuth router (bypasses auth middleware). Returns None if OAuth is disabled.
189    pub fn oauth_router(&self) -> Option<(Router, Arc<super::oauth::OAuthState>)> {
190        if !self.config.mcp.oauth {
191            return None;
192        }
193
194        let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
195            .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>)?;
196
197        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
198
199        let jwt_secret = self.config.auth.jwt_secret.clone().unwrap_or_default();
200
201        let oauth_state = Arc::new(super::oauth::OAuthState::new(
202            self.db.primary().clone(),
203            auth_middleware_state,
204            token_issuer,
205            self.token_ttl.access_token_secs,
206            self.token_ttl.refresh_token_days,
207            self.config.auth.is_hmac(),
208            self.config.project_name.clone(),
209            jwt_secret,
210        ));
211
212        let router = Router::new()
213            .route(
214                "/oauth/authorize",
215                get(super::oauth::oauth_authorize_get).post(super::oauth::oauth_authorize_post),
216            )
217            .route("/oauth/token", post(super::oauth::oauth_token))
218            .route("/oauth/register", post(super::oauth::oauth_register))
219            .with_state(oauth_state.clone());
220
221        Some((router, oauth_state))
222    }
223
224    /// Build the Axum router.
225    pub fn router(&self) -> Router {
226        let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
227            .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
228
229        let mut rpc = RpcHandler::with_dispatch_and_issuer(
230            self.registry.clone(),
231            self.db.clone(),
232            self.job_dispatcher.clone(),
233            self.workflow_dispatcher.clone(),
234            token_issuer,
235        );
236        rpc.set_token_ttl(self.token_ttl.clone());
237        if let Some(collector) = &self.signals_collector {
238            let secret = self
239                .config
240                .auth
241                .jwt_secret
242                .clone()
243                .unwrap_or_else(|| DEFAULT_SIGNAL_SECRET.to_string());
244            rpc.set_signals_collector(collector.clone(), secret);
245        }
246        let rpc_handler_state = Arc::new(rpc);
247
248        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
249
250        // Build CORS layer. When specific origins are configured, allow
251        // credentials so the browser accepts cross-origin API responses
252        // (the forge-svelte client sends `credentials: "include"` for
253        // the SSE session cookie). Wildcard methods/headers are incompatible
254        // with credentials per the CORS spec, so we enumerate them.
255        let cors = if self.config.cors_enabled {
256            if self.config.cors_origins.iter().any(|o| o == "*") {
257                // Wildcard origin can't use credentials
258                CorsLayer::new()
259                    .allow_origin(Any)
260                    .allow_methods(Any)
261                    .allow_headers(Any)
262            } else {
263                use axum::http::Method;
264                let origins: Vec<_> = self
265                    .config
266                    .cors_origins
267                    .iter()
268                    .filter_map(|o| o.parse().ok())
269                    .collect();
270                CorsLayer::new()
271                    .allow_origin(origins)
272                    .allow_methods([
273                        Method::GET,
274                        Method::POST,
275                        Method::PUT,
276                        Method::DELETE,
277                        Method::PATCH,
278                        Method::OPTIONS,
279                    ])
280                    .allow_headers([
281                        axum::http::header::CONTENT_TYPE,
282                        axum::http::header::AUTHORIZATION,
283                        axum::http::header::ACCEPT,
284                        axum::http::HeaderName::from_static("x-webhook-signature"),
285                        axum::http::HeaderName::from_static("x-idempotency-key"),
286                        axum::http::HeaderName::from_static("x-correlation-id"),
287                        axum::http::HeaderName::from_static("x-session-id"),
288                        axum::http::HeaderName::from_static("x-forge-platform"),
289                    ])
290                    .allow_credentials(true)
291            }
292        } else {
293            CorsLayer::new()
294        };
295
296        // SSE state for Server-Sent Events
297        let sse_state = Arc::new(SseState::with_config(
298            self.reactor.clone(),
299            auth_middleware_state.clone(),
300            super::sse::SseConfig {
301                max_sessions: self.config.sse_max_sessions,
302                ..Default::default()
303            },
304        ));
305
306        // Readiness state for DB + reactor health check
307        let readiness_state = Arc::new(ReadinessState {
308            db_pool: self.db.primary().clone(),
309            reactor: self.reactor.clone(),
310        });
311
312        // Build the main router with middleware
313        let mut main_router = Router::new()
314            // Health check endpoint (liveness)
315            .route("/health", get(health_handler))
316            // Readiness check endpoint (checks DB)
317            .route("/ready", get(readiness_handler).with_state(readiness_state))
318            // RPC endpoint
319            .route("/rpc", post(rpc_handler))
320            // Batch RPC endpoint
321            .route("/rpc/batch", post(rpc_batch_handler))
322            // REST-style function endpoint (JSON)
323            .route("/rpc/{function}", post(rpc_function_handler))
324            // Prevent oversized JSON payloads from exhausting memory.
325            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
326            // Add state
327            .with_state(rpc_handler_state.clone());
328
329        // Multipart RPC router (separate state needed for multipart)
330        let multipart_router = Router::new()
331            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
332            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
333            // Cap upload fan-out; each request buffers data in memory.
334            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
335            .with_state(rpc_handler_state);
336
337        // SSE router
338        let sse_router = Router::new()
339            .route("/events", get(sse_handler))
340            .route("/subscribe", post(sse_subscribe_handler))
341            .route("/unsubscribe", post(sse_unsubscribe_handler))
342            .route("/subscribe-job", post(sse_job_subscribe_handler))
343            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
344            .with_state(sse_state);
345
346        let mut mcp_router = Router::new();
347        if self.config.mcp.enabled {
348            let path = self.config.mcp.path.clone();
349            let mcp_state = Arc::new(McpState::new(
350                self.config.mcp.clone(),
351                self.mcp_registry.clone().unwrap_or_default(),
352                self.db.primary().clone(),
353                self.job_dispatcher.clone(),
354                self.workflow_dispatcher.clone(),
355            ));
356            mcp_router = mcp_router.route(
357                &path,
358                post(mcp_post_handler)
359                    .get(mcp_get_handler)
360                    .with_state(mcp_state),
361            );
362        }
363
364        // Signal ingestion endpoints (product analytics + diagnostics)
365        let mut signals_router = Router::new();
366        if let Some(collector) = &self.signals_collector {
367            let signals_state = Arc::new(crate::signals::endpoints::SignalsState {
368                collector: collector.clone(),
369                pool: self.db.analytics_pool().clone(),
370                server_secret: self
371                    .config
372                    .auth
373                    .jwt_secret
374                    .clone()
375                    .unwrap_or_else(|| DEFAULT_SIGNAL_SECRET.to_string()),
376            });
377            signals_router = Router::new()
378                .route(
379                    "/signal/event",
380                    post(crate::signals::endpoints::event_handler),
381                )
382                .route(
383                    "/signal/view",
384                    post(crate::signals::endpoints::view_handler),
385                )
386                .route(
387                    "/signal/user",
388                    post(crate::signals::endpoints::user_handler),
389                )
390                .route(
391                    "/signal/report",
392                    post(crate::signals::endpoints::report_handler),
393                )
394                .with_state(signals_state);
395        }
396
397        main_router = main_router
398            .merge(multipart_router)
399            .merge(sse_router)
400            .merge(mcp_router)
401            .merge(signals_router);
402
403        // Build middleware stack
404        let service_builder = ServiceBuilder::new()
405            .layer(HandleErrorLayer::new(handle_middleware_error))
406            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
407            .layer(TimeoutLayer::new(Duration::from_secs(
408                self.config.request_timeout_secs,
409            )))
410            .layer(cors.clone())
411            .layer(middleware::from_fn_with_state(
412                auth_middleware_state,
413                auth_middleware,
414            ))
415            .layer(middleware::from_fn_with_state(
416                Arc::new(self.config.quiet_routes.clone()),
417                tracing_middleware,
418            ));
419
420        // Apply the remaining middleware layers
421        main_router.layer(service_builder)
422    }
423
424    /// Get the socket address to bind to.
425    pub fn addr(&self) -> std::net::SocketAddr {
426        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
427    }
428
429    /// Run the server (blocking).
430    pub async fn run(self) -> Result<(), std::io::Error> {
431        let addr = self.addr();
432        let router = self.router();
433
434        // Start the reactor for real-time updates
435        self.reactor
436            .start()
437            .await
438            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
439        tracing::info!("Reactor started for real-time updates");
440
441        tracing::info!("Gateway server listening on {}", addr);
442
443        let listener = tokio::net::TcpListener::bind(addr).await?;
444        axum::serve(listener, router.into_make_service()).await
445    }
446}
447
448/// Health check handler (liveness probe).
449async fn health_handler() -> Json<HealthResponse> {
450    Json(HealthResponse {
451        status: "healthy".to_string(),
452        version: env!("CARGO_PKG_VERSION").to_string(),
453    })
454}
455
456/// Readiness check handler (readiness probe).
457async fn readiness_handler(
458    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
459) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
460    // Check database connectivity
461    let db_ok = sqlx::query_scalar!("SELECT 1 as \"v!\"")
462        .fetch_one(&state.db_pool)
463        .await
464        .is_ok();
465
466    // Check reactor health (change listener must be running for real-time updates)
467    let reactor_stats = state.reactor.stats().await;
468    let reactor_ok = reactor_stats.listener_running;
469
470    // Check for blocked workflow runs (strict mode: unhealthy if any runs are blocked)
471    let (workflows_ok, blocked_count) = if db_ok {
472        match sqlx::query_scalar!(
473            r#"SELECT COUNT(*) as "count!" FROM forge_workflow_runs WHERE status LIKE 'blocked_%'"#,
474        )
475        .fetch_one(&state.db_pool)
476        .await
477        {
478            Ok(count) => (count == 0, if count > 0 { Some(count) } else { None }),
479            Err(_) => (true, None), // if query fails, don't block on this check
480        }
481    } else {
482        (true, None)
483    };
484
485    let ready = db_ok && reactor_ok && workflows_ok;
486    let status_code = if ready {
487        axum::http::StatusCode::OK
488    } else {
489        axum::http::StatusCode::SERVICE_UNAVAILABLE
490    };
491
492    (
493        status_code,
494        Json(ReadinessResponse {
495            ready,
496            database: db_ok,
497            reactor: reactor_ok,
498            workflows: workflows_ok,
499            blocked_workflow_runs: blocked_count,
500            version: env!("CARGO_PKG_VERSION").to_string(),
501        }),
502    )
503}
504
505async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
506    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
507        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
508    } else {
509        (
510            StatusCode::SERVICE_UNAVAILABLE,
511            "SERVICE_UNAVAILABLE",
512            "Server overloaded",
513        )
514    };
515    (
516        status,
517        Json(RpcResponse::error(RpcError::new(code, message))),
518    )
519        .into_response()
520}
521
522fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
523    if let Ok(val) = trace_id.parse() {
524        response.headers_mut().insert(TRACE_ID_HEADER, val);
525    }
526    if let Ok(val) = request_id.parse() {
527        response.headers_mut().insert(REQUEST_ID_HEADER, val);
528    }
529}
530
531/// Extracts W3C traceparent context from HTTP headers.
532struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
533
534impl<'a> Extractor for HeaderExtractor<'a> {
535    fn get(&self, key: &str) -> Option<&str> {
536        self.0.get(key).and_then(|v| v.to_str().ok())
537    }
538
539    fn keys(&self) -> Vec<&str> {
540        self.0.keys().map(|k| k.as_str()).collect()
541    }
542}
543
544/// Wraps each request in a span with HTTP semantics and OpenTelemetry
545/// context propagation. Incoming `traceparent` headers are extracted so
546/// that spans join the caller's distributed trace.
547/// Quiet routes skip spans, logs, and metrics to avoid noise from
548/// probes or high-frequency internal endpoints.
549async fn tracing_middleware(
550    axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
551    req: axum::extract::Request,
552    next: axum::middleware::Next,
553) -> axum::response::Response {
554    let headers = req.headers();
555
556    // Extract W3C traceparent from incoming headers for distributed tracing
557    let parent_cx =
558        global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
559
560    let trace_id = headers
561        .get(TRACE_ID_HEADER)
562        .and_then(|v| v.to_str().ok())
563        .map(String::from)
564        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
565
566    let parent_span_id = headers
567        .get(SPAN_ID_HEADER)
568        .and_then(|v| v.to_str().ok())
569        .map(String::from);
570
571    let method = req.method().to_string();
572    let path = req.uri().path().to_string();
573
574    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
575    if let Some(span_id) = parent_span_id {
576        tracing_state = tracing_state.with_parent_span(span_id);
577    }
578
579    let mut req = req;
580    req.extensions_mut().insert(tracing_state.clone());
581
582    if req
583        .extensions()
584        .get::<forge_core::function::AuthContext>()
585        .is_none()
586    {
587        req.extensions_mut()
588            .insert(forge_core::function::AuthContext::unauthenticated());
589    }
590
591    // Config uses full paths (/_api/health) but axum strips the prefix
592    // for nested routers, so the middleware sees /health not /_api/health.
593    let full_path = format!("/_api{}", path);
594    let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
595
596    if is_quiet {
597        let mut response = next.run(req).await;
598        set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
599        return response;
600    }
601
602    let span = tracing::info_span!(
603        "http.request",
604        http.method = %method,
605        http.route = %path,
606        http.status_code = tracing::field::Empty,
607        trace_id = %trace_id,
608        request_id = %tracing_state.request_id,
609    );
610
611    // Link this span to the incoming distributed trace context so
612    // fn.execute and all downstream spans share the caller's trace ID
613    span.set_parent(parent_cx);
614
615    let mut response = next.run(req).instrument(span.clone()).await;
616
617    let status = response.status().as_u16();
618    let elapsed = tracing_state.elapsed();
619
620    span.record("http.status_code", status);
621    let duration_ms = elapsed.as_millis() as u64;
622    match status {
623        500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
624        400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
625        200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
626        _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
627    }
628    crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
629
630    set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
631    response
632}
633
#[cfg(test)]
// Tests unwrap, index into JSON values, and panic on failure by design.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 9081);
        assert_eq!(config.max_connections, 512);
        assert_eq!(config.sse_max_sessions, 10_000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(!config.cors_enabled);
        assert!(config.cors_origins.is_empty());
        assert!(config.quiet_routes.is_empty());
        assert_eq!(config.project_name, "forge-app");
    }

    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["status"], "healthy");
        assert_eq!(json["version"], "0.1.0");
    }

    #[test]
    fn test_readiness_response_omits_blocked_runs_when_none() {
        let resp = ReadinessResponse {
            ready: true,
            database: true,
            reactor: true,
            workflows: true,
            blocked_workflow_runs: None,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert!(json.get("blocked_workflow_runs").is_none());
        assert_eq!(json["ready"], true);
    }

    #[test]
    fn test_readiness_response_includes_blocked_runs_when_some() {
        let resp = ReadinessResponse {
            ready: false,
            database: true,
            reactor: true,
            workflows: false,
            blocked_workflow_runs: Some(3),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["blocked_workflow_runs"], 3);
        assert_eq!(json["workflows"], false);
        assert_eq!(json["ready"], false);
    }
}