Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    Json, Router,
6    error_handling::HandleErrorLayer,
7    extract::DefaultBodyLimit,
8    http::StatusCode,
9    middleware,
10    response::IntoResponse,
11    routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use tracing::Instrument;
24
25use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
26use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
27use super::multipart::rpc_multipart_handler;
28use super::response::{RpcError, RpcResponse};
29use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
30use super::sse::{
31    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
32    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
33};
34use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
35use crate::db::Database;
36use crate::function::FunctionRegistry;
37use crate::mcp::McpToolRegistry;
38use crate::realtime::{Reactor, ReactorConfig};
39
/// Number of bytes in one mebibyte; keeps the size limits below readable.
const BYTES_PER_MIB: usize = 1024 * 1024;
/// Request-body cap (1 MiB) applied to the routes of the main router.
const MAX_JSON_BODY_SIZE: usize = BYTES_PER_MIB;
/// Request-body cap (20 MiB) applied to the multipart upload route.
const MAX_MULTIPART_BODY_SIZE: usize = 20 * BYTES_PER_MIB;
/// Maximum number of multipart upload requests handled at the same time.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
43
/// Gateway server configuration.
///
/// Consumed by `GatewayServer` when it builds its router and computes the
/// address to bind.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port the server binds to (on the unspecified IPv4 address).
    pub port: u16,
    /// Limit handed to the router-wide concurrency-limit layer.
    pub max_connections: usize,
    /// Request timeout in seconds, enforced by the router's timeout layer.
    pub request_timeout_secs: u64,
    /// Enable CORS. When `false`, a default `CorsLayer` with no rules is used.
    pub cors_enabled: bool,
    /// Allowed CORS origins. A `"*"` entry allows any origin; entries that
    /// cannot be parsed as an origin header value are ignored.
    pub cors_origins: Vec<String>,
    /// Authentication configuration.
    pub auth: AuthConfig,
    /// MCP configuration.
    pub mcp: McpConfig,
    /// Routes excluded from request logs, metrics, and traces. Entries are
    /// compared for exact equality against the request path, both as-is and
    /// with a `/_api` prefix prepended to the path.
    pub quiet_routes: Vec<String>,
}
64
65impl Default for GatewayConfig {
66    fn default() -> Self {
67        Self {
68            port: 8080,
69            max_connections: 512,
70            request_timeout_secs: 30,
71            cors_enabled: false,
72            cors_origins: Vec::new(),
73            auth: AuthConfig::default(),
74            mcp: McpConfig::default(),
75            quiet_routes: Vec::new(),
76        }
77    }
78}
79
/// Health check response.
///
/// Serialized as the JSON body of the `/health` liveness endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Liveness status string; the handler in this file reports `"healthy"`.
    pub status: String,
    /// Version string; the handler fills it from `CARGO_PKG_VERSION`.
    pub version: String,
}
86
/// Readiness check response.
///
/// Serialized as the JSON body of the `/ready` endpoint.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// Overall readiness: `true` only when both `database` and `reactor` are `true`.
    pub ready: bool,
    /// Whether a `SELECT 1` probe against the database pool succeeded.
    pub database: bool,
    /// Whether the reactor reports its listener as running.
    pub reactor: bool,
    /// Version string; the handler fills it from `CARGO_PKG_VERSION`.
    pub version: String,
}
95
/// State for readiness check.
///
/// Shared with the `/ready` handler so it can probe its dependencies.
#[derive(Clone)]
pub struct ReadinessState {
    /// Pool used for the `SELECT 1` connectivity probe (the router passes the
    /// database's primary pool).
    db_pool: sqlx::PgPool,
    /// Reactor whose stats are inspected for a running listener.
    reactor: Arc<Reactor>,
}
102
/// Gateway HTTP server.
///
/// Holds everything needed to build the router and serve requests.
pub struct GatewayServer {
    /// Server configuration (port, limits, CORS, auth, MCP, quiet routes).
    config: GatewayConfig,
    /// Function registry handed to the RPC handler and the reactor.
    registry: FunctionRegistry,
    /// Database handle; its pools are passed to the RPC, readiness and MCP state.
    db: Database,
    /// Reactor shared with the SSE and readiness state; started by `run`.
    reactor: Arc<Reactor>,
    /// Optional job dispatcher forwarded to the RPC and MCP state.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded to the RPC and MCP state.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Optional MCP tool registry; the default registry is used when unset.
    mcp_registry: Option<McpToolRegistry>,
}
113
114impl GatewayServer {
115    /// Create a new gateway server.
116    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
117        let node_id = NodeId::new();
118        let reactor = Arc::new(Reactor::new(
119            node_id,
120            db.read_pool().clone(),
121            registry.clone(),
122            ReactorConfig::default(),
123        ));
124
125        Self {
126            config,
127            registry,
128            db,
129            reactor,
130            job_dispatcher: None,
131            workflow_dispatcher: None,
132            mcp_registry: None,
133        }
134    }
135
136    /// Set the job dispatcher.
137    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
138        self.job_dispatcher = Some(dispatcher);
139        self
140    }
141
142    /// Set the workflow dispatcher.
143    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
144        self.workflow_dispatcher = Some(dispatcher);
145        self
146    }
147
148    /// Set the MCP tool registry.
149    pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
150        self.mcp_registry = Some(registry);
151        self
152    }
153
154    /// Get a reference to the reactor.
155    pub fn reactor(&self) -> Arc<Reactor> {
156        self.reactor.clone()
157    }
158
159    /// Build the Axum router.
160    pub fn router(&self) -> Router {
161        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
162            self.registry.clone(),
163            self.db.clone(),
164            self.job_dispatcher.clone(),
165            self.workflow_dispatcher.clone(),
166        ));
167
168        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
169
170        // Build CORS layer
171        let cors = if self.config.cors_enabled {
172            if self.config.cors_origins.iter().any(|o| o == "*") {
173                CorsLayer::new()
174                    .allow_origin(Any)
175                    .allow_methods(Any)
176                    .allow_headers(Any)
177            } else {
178                let origins: Vec<_> = self
179                    .config
180                    .cors_origins
181                    .iter()
182                    .filter_map(|o| o.parse().ok())
183                    .collect();
184                CorsLayer::new()
185                    .allow_origin(origins)
186                    .allow_methods(Any)
187                    .allow_headers(Any)
188            }
189        } else {
190            CorsLayer::new()
191        };
192
193        // SSE state for Server-Sent Events
194        let sse_state = Arc::new(SseState::new(
195            self.reactor.clone(),
196            auth_middleware_state.clone(),
197        ));
198
199        // Readiness state for DB + reactor health check
200        let readiness_state = Arc::new(ReadinessState {
201            db_pool: self.db.primary().clone(),
202            reactor: self.reactor.clone(),
203        });
204
205        // Build the main router with middleware
206        let mut main_router = Router::new()
207            // Health check endpoint (liveness)
208            .route("/health", get(health_handler))
209            // Readiness check endpoint (checks DB)
210            .route("/ready", get(readiness_handler).with_state(readiness_state))
211            // RPC endpoint
212            .route("/rpc", post(rpc_handler))
213            // REST-style function endpoint (JSON)
214            .route("/rpc/{function}", post(rpc_function_handler))
215            // Prevent oversized JSON payloads from exhausting memory.
216            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
217            // Add state
218            .with_state(rpc_handler_state.clone());
219
220        // Multipart RPC router (separate state needed for multipart)
221        let multipart_router = Router::new()
222            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
223            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
224            // Cap upload fan-out; each request buffers data in memory.
225            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
226            .with_state(rpc_handler_state);
227
228        // SSE router
229        let sse_router = Router::new()
230            .route("/events", get(sse_handler))
231            .route("/subscribe", post(sse_subscribe_handler))
232            .route("/unsubscribe", post(sse_unsubscribe_handler))
233            .route("/subscribe-job", post(sse_job_subscribe_handler))
234            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
235            .with_state(sse_state);
236
237        let mut mcp_router = Router::new();
238        if self.config.mcp.enabled {
239            let path = self.config.mcp.path.clone();
240            let mcp_state = Arc::new(McpState::new(
241                self.config.mcp.clone(),
242                self.mcp_registry.clone().unwrap_or_default(),
243                self.db.primary().clone(),
244                self.job_dispatcher.clone(),
245                self.workflow_dispatcher.clone(),
246            ));
247            mcp_router = mcp_router.route(
248                &path,
249                post(mcp_post_handler)
250                    .get(mcp_get_handler)
251                    .with_state(mcp_state),
252            );
253        }
254
255        main_router = main_router
256            .merge(multipart_router)
257            .merge(sse_router)
258            .merge(mcp_router);
259
260        // Build middleware stack
261        let service_builder = ServiceBuilder::new()
262            .layer(HandleErrorLayer::new(handle_middleware_error))
263            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
264            .layer(TimeoutLayer::new(Duration::from_secs(
265                self.config.request_timeout_secs,
266            )))
267            .layer(cors.clone())
268            .layer(middleware::from_fn_with_state(
269                auth_middleware_state,
270                auth_middleware,
271            ))
272            .layer(middleware::from_fn_with_state(
273                Arc::new(self.config.quiet_routes.clone()),
274                tracing_middleware,
275            ));
276
277        // Apply the remaining middleware layers
278        main_router.layer(service_builder)
279    }
280
281    /// Get the socket address to bind to.
282    pub fn addr(&self) -> std::net::SocketAddr {
283        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
284    }
285
286    /// Run the server (blocking).
287    pub async fn run(self) -> Result<(), std::io::Error> {
288        let addr = self.addr();
289        let router = self.router();
290
291        // Start the reactor for real-time updates
292        self.reactor
293            .start()
294            .await
295            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
296        tracing::info!("Reactor started for real-time updates");
297
298        tracing::info!("Gateway server listening on {}", addr);
299
300        let listener = tokio::net::TcpListener::bind(addr).await?;
301        axum::serve(listener, router.into_make_service()).await
302    }
303}
304
305/// Health check handler (liveness probe).
306async fn health_handler() -> Json<HealthResponse> {
307    Json(HealthResponse {
308        status: "healthy".to_string(),
309        version: env!("CARGO_PKG_VERSION").to_string(),
310    })
311}
312
313/// Readiness check handler (readiness probe).
314async fn readiness_handler(
315    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
316) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
317    // Check database connectivity
318    let db_ok = sqlx::query("SELECT 1")
319        .fetch_one(&state.db_pool)
320        .await
321        .is_ok();
322
323    // Check reactor health (change listener must be running for real-time updates)
324    let reactor_stats = state.reactor.stats().await;
325    let reactor_ok = reactor_stats.listener_running;
326
327    let ready = db_ok && reactor_ok;
328    let status_code = if ready {
329        axum::http::StatusCode::OK
330    } else {
331        axum::http::StatusCode::SERVICE_UNAVAILABLE
332    };
333
334    (
335        status_code,
336        Json(ReadinessResponse {
337            ready,
338            database: db_ok,
339            reactor: reactor_ok,
340            version: env!("CARGO_PKG_VERSION").to_string(),
341        }),
342    )
343}
344
345async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
346    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
347        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
348    } else {
349        (
350            StatusCode::SERVICE_UNAVAILABLE,
351            "SERVICE_UNAVAILABLE",
352            "Server overloaded",
353        )
354    };
355    (
356        status,
357        Json(RpcResponse::error(RpcError::new(code, message))),
358    )
359        .into_response()
360}
361
362fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
363    if let Ok(val) = trace_id.parse() {
364        response.headers_mut().insert(TRACE_ID_HEADER, val);
365    }
366    if let Ok(val) = request_id.parse() {
367        response.headers_mut().insert(REQUEST_ID_HEADER, val);
368    }
369}
370
371/// Wraps each request in a span with HTTP semantics.
372/// Quiet routes skip spans, logs, and metrics to avoid noise from
373/// probes or high-frequency internal endpoints.
374async fn tracing_middleware(
375    axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
376    req: axum::extract::Request,
377    next: axum::middleware::Next,
378) -> axum::response::Response {
379    let headers = req.headers();
380
381    let trace_id = headers
382        .get(TRACE_ID_HEADER)
383        .and_then(|v| v.to_str().ok())
384        .map(String::from)
385        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
386
387    let parent_span_id = headers
388        .get(SPAN_ID_HEADER)
389        .and_then(|v| v.to_str().ok())
390        .map(String::from);
391
392    let method = req.method().to_string();
393    let path = req.uri().path().to_string();
394
395    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
396    if let Some(span_id) = parent_span_id {
397        tracing_state = tracing_state.with_parent_span(span_id);
398    }
399
400    let mut req = req;
401    req.extensions_mut().insert(tracing_state.clone());
402
403    if req
404        .extensions()
405        .get::<forge_core::function::AuthContext>()
406        .is_none()
407    {
408        req.extensions_mut()
409            .insert(forge_core::function::AuthContext::unauthenticated());
410    }
411
412    // Config uses full paths (/_api/health) but axum strips the prefix
413    // for nested routers, so the middleware sees /health not /_api/health.
414    let full_path = format!("/_api{}", path);
415    let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
416
417    if is_quiet {
418        let mut response = next.run(req).await;
419        set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
420        return response;
421    }
422
423    let span = tracing::info_span!(
424        "http.request",
425        http.method = %method,
426        http.route = %path,
427        http.status_code = tracing::field::Empty,
428        trace_id = %trace_id,
429        request_id = %tracing_state.request_id,
430    );
431
432    let mut response = next.run(req).instrument(span.clone()).await;
433
434    let status = response.status().as_u16();
435    let elapsed = tracing_state.elapsed();
436
437    span.record("http.status_code", status);
438    tracing::info!(parent: &span, duration_ms = elapsed.as_millis() as u64, "Request completed");
439    crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
440
441    set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
442    response
443}
444
#[cfg(test)]
// Test code may unwrap, index and panic: a failure there is the signal we want.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The default configuration listens on 8080, allows 512 connections
    /// and keeps CORS off.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 512);
        assert!(!cors_enabled);
    }

    /// A health response serializes to JSON that contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let json = serde_json::to_string(&HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        })
        .unwrap();

        assert!(json.contains("healthy"));
    }
}
467}