Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    Json, Router,
6    error_handling::HandleErrorLayer,
7    extract::DefaultBodyLimit,
8    http::StatusCode,
9    middleware,
10    response::IntoResponse,
11    routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
33use super::sse::{
34    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Body-size cap applied to the JSON routes: 1 MiB.
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Body-size cap applied to the multipart upload route: 20 MiB.
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// Number of multipart uploads that may be processed at the same time.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
46
/// Gateway server configuration.
///
/// Read by [`GatewayServer`] when it builds the router and picks the bind
/// address.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port to listen on; the server binds it on `0.0.0.0`.
    pub port: u16,
    /// Limit handed to the router's concurrency-limit middleware. Despite the
    /// name it bounds in-flight requests, not TCP connections.
    pub max_connections: usize,
    /// Request timeout in seconds. A request that exceeds it is answered with
    /// `408 Request Timeout`.
    pub request_timeout_secs: u64,
    /// Enable CORS. When `false`, `cors_origins` is not consulted.
    pub cors_enabled: bool,
    /// Allowed CORS origins. A literal `"*"` entry allows any origin; entries
    /// that do not parse as header values are dropped.
    pub cors_origins: Vec<String>,
    /// Authentication configuration.
    pub auth: AuthConfig,
    /// MCP configuration; its `enabled` and `path` fields decide whether and
    /// where the MCP route is mounted.
    pub mcp: McpConfig,
    /// Routes excluded from request logs, metrics, and traces. Entries are
    /// compared exactly against the request path, with or without a leading
    /// `/_api` prefix.
    pub quiet_routes: Vec<String>,
}
67
68impl Default for GatewayConfig {
69    fn default() -> Self {
70        Self {
71            port: 8080,
72            max_connections: 512,
73            request_timeout_secs: 30,
74            cors_enabled: false,
75            cors_origins: Vec::new(),
76            auth: AuthConfig::default(),
77            mcp: McpConfig::default(),
78            quiet_routes: Vec::new(),
79        }
80    }
81}
82
/// Health check response.
///
/// JSON body returned by the `/health` liveness endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Liveness status; the handler always reports `"healthy"`.
    pub status: String,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    pub version: String,
}
89
/// Readiness check response.
///
/// JSON body returned by the `/ready` endpoint.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// `true` only when both `database` and `reactor` are `true`.
    pub ready: bool,
    /// Whether a `SELECT 1` probe against the database pool succeeded.
    pub database: bool,
    /// Whether the reactor reports its change listener as running.
    pub reactor: bool,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    pub version: String,
}
98
/// State for readiness check.
///
/// Bundles the two dependencies that the `/ready` handler probes.
#[derive(Clone)]
pub struct ReadinessState {
    /// Pool probed with `SELECT 1`; the router fills it from `Database::primary()`.
    db_pool: sqlx::PgPool,
    /// Reactor whose stats are read to see whether its listener is running.
    reactor: Arc<Reactor>,
}
105
/// Gateway HTTP server.
///
/// Owns everything needed to build the Axum router: configuration, the
/// function registry, database handles, the shared reactor, and the optional
/// dispatchers and MCP tool registry supplied through the `with_*` builders.
pub struct GatewayServer {
    /// Server configuration (port, limits, CORS, auth, MCP, quiet routes).
    config: GatewayConfig,
    /// Function registry shared with the RPC handler and the reactor.
    registry: FunctionRegistry,
    /// Database handle; the router uses its primary pool and the reactor its read pool.
    db: Database,
    /// Reactor created in `new` and started by `run`.
    reactor: Arc<Reactor>,
    /// Optional job dispatcher passed on to the RPC and MCP state.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher passed on to the RPC and MCP state.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Optional MCP tool registry; a default registry is used when unset.
    mcp_registry: Option<McpToolRegistry>,
}
116
117impl GatewayServer {
118    /// Create a new gateway server.
119    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
120        let node_id = NodeId::new();
121        let reactor = Arc::new(Reactor::new(
122            node_id,
123            db.read_pool().clone(),
124            registry.clone(),
125            ReactorConfig::default(),
126        ));
127
128        Self {
129            config,
130            registry,
131            db,
132            reactor,
133            job_dispatcher: None,
134            workflow_dispatcher: None,
135            mcp_registry: None,
136        }
137    }
138
139    /// Set the job dispatcher.
140    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
141        self.job_dispatcher = Some(dispatcher);
142        self
143    }
144
145    /// Set the workflow dispatcher.
146    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
147        self.workflow_dispatcher = Some(dispatcher);
148        self
149    }
150
151    /// Set the MCP tool registry.
152    pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
153        self.mcp_registry = Some(registry);
154        self
155    }
156
157    /// Get a reference to the reactor.
158    pub fn reactor(&self) -> Arc<Reactor> {
159        self.reactor.clone()
160    }
161
162    /// Build the Axum router.
163    pub fn router(&self) -> Router {
164        let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
165            .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
166
167        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch_and_issuer(
168            self.registry.clone(),
169            self.db.clone(),
170            self.job_dispatcher.clone(),
171            self.workflow_dispatcher.clone(),
172            token_issuer,
173        ));
174
175        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
176
177        // Build CORS layer
178        let cors = if self.config.cors_enabled {
179            if self.config.cors_origins.iter().any(|o| o == "*") {
180                CorsLayer::new()
181                    .allow_origin(Any)
182                    .allow_methods(Any)
183                    .allow_headers(Any)
184            } else {
185                let origins: Vec<_> = self
186                    .config
187                    .cors_origins
188                    .iter()
189                    .filter_map(|o| o.parse().ok())
190                    .collect();
191                CorsLayer::new()
192                    .allow_origin(origins)
193                    .allow_methods(Any)
194                    .allow_headers(Any)
195            }
196        } else {
197            CorsLayer::new()
198        };
199
200        // SSE state for Server-Sent Events
201        let sse_state = Arc::new(SseState::new(
202            self.reactor.clone(),
203            auth_middleware_state.clone(),
204        ));
205
206        // Readiness state for DB + reactor health check
207        let readiness_state = Arc::new(ReadinessState {
208            db_pool: self.db.primary().clone(),
209            reactor: self.reactor.clone(),
210        });
211
212        // Build the main router with middleware
213        let mut main_router = Router::new()
214            // Health check endpoint (liveness)
215            .route("/health", get(health_handler))
216            // Readiness check endpoint (checks DB)
217            .route("/ready", get(readiness_handler).with_state(readiness_state))
218            // RPC endpoint
219            .route("/rpc", post(rpc_handler))
220            // REST-style function endpoint (JSON)
221            .route("/rpc/{function}", post(rpc_function_handler))
222            // Prevent oversized JSON payloads from exhausting memory.
223            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
224            // Add state
225            .with_state(rpc_handler_state.clone());
226
227        // Multipart RPC router (separate state needed for multipart)
228        let multipart_router = Router::new()
229            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
230            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
231            // Cap upload fan-out; each request buffers data in memory.
232            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
233            .with_state(rpc_handler_state);
234
235        // SSE router
236        let sse_router = Router::new()
237            .route("/events", get(sse_handler))
238            .route("/subscribe", post(sse_subscribe_handler))
239            .route("/unsubscribe", post(sse_unsubscribe_handler))
240            .route("/subscribe-job", post(sse_job_subscribe_handler))
241            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
242            .with_state(sse_state);
243
244        let mut mcp_router = Router::new();
245        if self.config.mcp.enabled {
246            let path = self.config.mcp.path.clone();
247            let mcp_state = Arc::new(McpState::new(
248                self.config.mcp.clone(),
249                self.mcp_registry.clone().unwrap_or_default(),
250                self.db.primary().clone(),
251                self.job_dispatcher.clone(),
252                self.workflow_dispatcher.clone(),
253            ));
254            mcp_router = mcp_router.route(
255                &path,
256                post(mcp_post_handler)
257                    .get(mcp_get_handler)
258                    .with_state(mcp_state),
259            );
260        }
261
262        main_router = main_router
263            .merge(multipart_router)
264            .merge(sse_router)
265            .merge(mcp_router);
266
267        // Build middleware stack
268        let service_builder = ServiceBuilder::new()
269            .layer(HandleErrorLayer::new(handle_middleware_error))
270            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
271            .layer(TimeoutLayer::new(Duration::from_secs(
272                self.config.request_timeout_secs,
273            )))
274            .layer(cors.clone())
275            .layer(middleware::from_fn_with_state(
276                auth_middleware_state,
277                auth_middleware,
278            ))
279            .layer(middleware::from_fn_with_state(
280                Arc::new(self.config.quiet_routes.clone()),
281                tracing_middleware,
282            ));
283
284        // Apply the remaining middleware layers
285        main_router.layer(service_builder)
286    }
287
288    /// Get the socket address to bind to.
289    pub fn addr(&self) -> std::net::SocketAddr {
290        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
291    }
292
293    /// Run the server (blocking).
294    pub async fn run(self) -> Result<(), std::io::Error> {
295        let addr = self.addr();
296        let router = self.router();
297
298        // Start the reactor for real-time updates
299        self.reactor
300            .start()
301            .await
302            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
303        tracing::info!("Reactor started for real-time updates");
304
305        tracing::info!("Gateway server listening on {}", addr);
306
307        let listener = tokio::net::TcpListener::bind(addr).await?;
308        axum::serve(listener, router.into_make_service()).await
309    }
310}
311
312/// Health check handler (liveness probe).
313async fn health_handler() -> Json<HealthResponse> {
314    Json(HealthResponse {
315        status: "healthy".to_string(),
316        version: env!("CARGO_PKG_VERSION").to_string(),
317    })
318}
319
320/// Readiness check handler (readiness probe).
321async fn readiness_handler(
322    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
323) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
324    // Check database connectivity
325    let db_ok = sqlx::query("SELECT 1")
326        .fetch_one(&state.db_pool)
327        .await
328        .is_ok();
329
330    // Check reactor health (change listener must be running for real-time updates)
331    let reactor_stats = state.reactor.stats().await;
332    let reactor_ok = reactor_stats.listener_running;
333
334    let ready = db_ok && reactor_ok;
335    let status_code = if ready {
336        axum::http::StatusCode::OK
337    } else {
338        axum::http::StatusCode::SERVICE_UNAVAILABLE
339    };
340
341    (
342        status_code,
343        Json(ReadinessResponse {
344            ready,
345            database: db_ok,
346            reactor: reactor_ok,
347            version: env!("CARGO_PKG_VERSION").to_string(),
348        }),
349    )
350}
351
352async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
353    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
354        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
355    } else {
356        (
357            StatusCode::SERVICE_UNAVAILABLE,
358            "SERVICE_UNAVAILABLE",
359            "Server overloaded",
360        )
361    };
362    (
363        status,
364        Json(RpcResponse::error(RpcError::new(code, message))),
365    )
366        .into_response()
367}
368
369fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
370    if let Ok(val) = trace_id.parse() {
371        response.headers_mut().insert(TRACE_ID_HEADER, val);
372    }
373    if let Ok(val) = request_id.parse() {
374        response.headers_mut().insert(REQUEST_ID_HEADER, val);
375    }
376}
377
378/// Extracts W3C traceparent context from HTTP headers.
379struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
380
381impl<'a> Extractor for HeaderExtractor<'a> {
382    fn get(&self, key: &str) -> Option<&str> {
383        self.0.get(key).and_then(|v| v.to_str().ok())
384    }
385
386    fn keys(&self) -> Vec<&str> {
387        self.0.keys().map(|k| k.as_str()).collect()
388    }
389}
390
391/// Wraps each request in a span with HTTP semantics and OpenTelemetry
392/// context propagation. Incoming `traceparent` headers are extracted so
393/// that spans join the caller's distributed trace.
394/// Quiet routes skip spans, logs, and metrics to avoid noise from
395/// probes or high-frequency internal endpoints.
396async fn tracing_middleware(
397    axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
398    req: axum::extract::Request,
399    next: axum::middleware::Next,
400) -> axum::response::Response {
401    let headers = req.headers();
402
403    // Extract W3C traceparent from incoming headers for distributed tracing
404    let parent_cx =
405        global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
406
407    let trace_id = headers
408        .get(TRACE_ID_HEADER)
409        .and_then(|v| v.to_str().ok())
410        .map(String::from)
411        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
412
413    let parent_span_id = headers
414        .get(SPAN_ID_HEADER)
415        .and_then(|v| v.to_str().ok())
416        .map(String::from);
417
418    let method = req.method().to_string();
419    let path = req.uri().path().to_string();
420
421    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
422    if let Some(span_id) = parent_span_id {
423        tracing_state = tracing_state.with_parent_span(span_id);
424    }
425
426    let mut req = req;
427    req.extensions_mut().insert(tracing_state.clone());
428
429    if req
430        .extensions()
431        .get::<forge_core::function::AuthContext>()
432        .is_none()
433    {
434        req.extensions_mut()
435            .insert(forge_core::function::AuthContext::unauthenticated());
436    }
437
438    // Config uses full paths (/_api/health) but axum strips the prefix
439    // for nested routers, so the middleware sees /health not /_api/health.
440    let full_path = format!("/_api{}", path);
441    let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
442
443    if is_quiet {
444        let mut response = next.run(req).await;
445        set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
446        return response;
447    }
448
449    let span = tracing::info_span!(
450        "http.request",
451        http.method = %method,
452        http.route = %path,
453        http.status_code = tracing::field::Empty,
454        trace_id = %trace_id,
455        request_id = %tracing_state.request_id,
456    );
457
458    // Link this span to the incoming distributed trace context so
459    // fn.execute and all downstream spans share the caller's trace ID
460    span.set_parent(parent_cx);
461
462    let mut response = next.run(req).instrument(span.clone()).await;
463
464    let status = response.status().as_u16();
465    let elapsed = tracing_state.elapsed();
466
467    span.record("http.status_code", status);
468    let duration_ms = elapsed.as_millis() as u64;
469    match status {
470        500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
471        400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
472        200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
473        _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
474    }
475    crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
476
477    set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
478    response
479}
480
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The default config listens on 8080, allows 512 concurrent requests,
    /// and leaves CORS off.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 512);
        assert!(!cors_enabled);
    }

    /// A health response serializes to JSON that carries its status text.
    #[test]
    fn test_health_response_serialization() {
        let json = serde_json::to_string(&HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        })
        .unwrap();

        assert!(json.contains("healthy"));
    }
}