Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    Json, Router,
6    error_handling::HandleErrorLayer,
7    extract::DefaultBodyLimit,
8    http::StatusCode,
9    middleware,
10    response::IntoResponse,
11    routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23
24use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
25use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
26use super::multipart::rpc_multipart_handler;
27use super::response::{RpcError, RpcResponse};
28use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
29use super::sse::{
30    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
31    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
32};
33use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
34use crate::db::Database;
35use crate::function::FunctionRegistry;
36use crate::mcp::McpToolRegistry;
37use crate::realtime::{Reactor, ReactorConfig};
38
/// Largest JSON request body accepted, in bytes (1 MiB).
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Largest multipart upload body accepted, in bytes (20 MiB).
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// How many multipart upload requests may be in flight at the same time.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
42
43/// Gateway server configuration.
44#[derive(Debug, Clone)]
45pub struct GatewayConfig {
46    /// Port to listen on.
47    pub port: u16,
48    /// Maximum number of connections.
49    pub max_connections: usize,
50    /// Request timeout in seconds.
51    pub request_timeout_secs: u64,
52    /// Enable CORS.
53    pub cors_enabled: bool,
54    /// Allowed CORS origins.
55    pub cors_origins: Vec<String>,
56    /// Authentication configuration.
57    pub auth: AuthConfig,
58    /// MCP configuration.
59    pub mcp: McpConfig,
60}
61
62impl Default for GatewayConfig {
63    fn default() -> Self {
64        Self {
65            port: 8080,
66            max_connections: 512,
67            request_timeout_secs: 30,
68            cors_enabled: false,
69            cors_origins: Vec::new(),
70            auth: AuthConfig::default(),
71            mcp: McpConfig::default(),
72        }
73    }
74}
75
76/// Health check response.
77#[derive(Debug, Serialize)]
78pub struct HealthResponse {
79    pub status: String,
80    pub version: String,
81}
82
83/// Readiness check response.
84#[derive(Debug, Serialize)]
85pub struct ReadinessResponse {
86    pub ready: bool,
87    pub database: bool,
88    pub reactor: bool,
89    pub version: String,
90}
91
92/// State for readiness check.
93#[derive(Clone)]
94pub struct ReadinessState {
95    db_pool: sqlx::PgPool,
96    reactor: Arc<Reactor>,
97}
98
99/// Gateway HTTP server.
100pub struct GatewayServer {
101    config: GatewayConfig,
102    registry: FunctionRegistry,
103    db: Database,
104    reactor: Arc<Reactor>,
105    job_dispatcher: Option<Arc<dyn JobDispatch>>,
106    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
107    mcp_registry: Option<McpToolRegistry>,
108}
109
110impl GatewayServer {
111    /// Create a new gateway server.
112    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
113        let node_id = NodeId::new();
114        let reactor = Arc::new(Reactor::new(
115            node_id,
116            db.read_pool().clone(),
117            registry.clone(),
118            ReactorConfig::default(),
119        ));
120
121        Self {
122            config,
123            registry,
124            db,
125            reactor,
126            job_dispatcher: None,
127            workflow_dispatcher: None,
128            mcp_registry: None,
129        }
130    }
131
132    /// Set the job dispatcher.
133    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
134        self.job_dispatcher = Some(dispatcher);
135        self
136    }
137
138    /// Set the workflow dispatcher.
139    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
140        self.workflow_dispatcher = Some(dispatcher);
141        self
142    }
143
144    /// Set the MCP tool registry.
145    pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
146        self.mcp_registry = Some(registry);
147        self
148    }
149
150    /// Get a reference to the reactor.
151    pub fn reactor(&self) -> Arc<Reactor> {
152        self.reactor.clone()
153    }
154
155    /// Build the Axum router.
156    pub fn router(&self) -> Router {
157        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
158            self.registry.clone(),
159            self.db.clone(),
160            self.job_dispatcher.clone(),
161            self.workflow_dispatcher.clone(),
162        ));
163
164        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
165
166        // Build CORS layer
167        let cors = if self.config.cors_enabled {
168            if self.config.cors_origins.iter().any(|o| o == "*") {
169                CorsLayer::new()
170                    .allow_origin(Any)
171                    .allow_methods(Any)
172                    .allow_headers(Any)
173            } else {
174                let origins: Vec<_> = self
175                    .config
176                    .cors_origins
177                    .iter()
178                    .filter_map(|o| o.parse().ok())
179                    .collect();
180                CorsLayer::new()
181                    .allow_origin(origins)
182                    .allow_methods(Any)
183                    .allow_headers(Any)
184            }
185        } else {
186            CorsLayer::new()
187        };
188
189        // SSE state for Server-Sent Events
190        let sse_state = Arc::new(SseState::new(
191            self.reactor.clone(),
192            auth_middleware_state.clone(),
193        ));
194
195        // Readiness state for DB + reactor health check
196        let readiness_state = Arc::new(ReadinessState {
197            db_pool: self.db.primary().clone(),
198            reactor: self.reactor.clone(),
199        });
200
201        // Build the main router with middleware
202        let mut main_router = Router::new()
203            // Health check endpoint (liveness)
204            .route("/health", get(health_handler))
205            // Readiness check endpoint (checks DB)
206            .route("/ready", get(readiness_handler).with_state(readiness_state))
207            // RPC endpoint
208            .route("/rpc", post(rpc_handler))
209            // REST-style function endpoint (JSON)
210            .route("/rpc/{function}", post(rpc_function_handler))
211            // Prevent oversized JSON payloads from exhausting memory.
212            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
213            // Add state
214            .with_state(rpc_handler_state.clone());
215
216        // Multipart RPC router (separate state needed for multipart)
217        let multipart_router = Router::new()
218            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
219            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
220            // Cap upload fan-out; each request buffers data in memory.
221            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
222            .with_state(rpc_handler_state);
223
224        // SSE router
225        let sse_router = Router::new()
226            .route("/events", get(sse_handler))
227            .route("/subscribe", post(sse_subscribe_handler))
228            .route("/unsubscribe", post(sse_unsubscribe_handler))
229            .route("/subscribe-job", post(sse_job_subscribe_handler))
230            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
231            .with_state(sse_state);
232
233        let mut mcp_router = Router::new();
234        if self.config.mcp.enabled {
235            let path = self.config.mcp.path.clone();
236            let mcp_state = Arc::new(McpState::new(
237                self.config.mcp.clone(),
238                self.mcp_registry.clone().unwrap_or_default(),
239                self.db.primary().clone(),
240                self.job_dispatcher.clone(),
241                self.workflow_dispatcher.clone(),
242            ));
243            mcp_router = mcp_router.route(
244                &path,
245                post(mcp_post_handler)
246                    .get(mcp_get_handler)
247                    .with_state(mcp_state),
248            );
249        }
250
251        main_router = main_router
252            .merge(multipart_router)
253            .merge(sse_router)
254            .merge(mcp_router);
255
256        // Build middleware stack
257        let service_builder = ServiceBuilder::new()
258            .layer(HandleErrorLayer::new(handle_middleware_error))
259            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
260            .layer(TimeoutLayer::new(Duration::from_secs(
261                self.config.request_timeout_secs,
262            )))
263            .layer(cors.clone())
264            .layer(middleware::from_fn_with_state(
265                auth_middleware_state,
266                auth_middleware,
267            ))
268            .layer(middleware::from_fn(tracing_middleware));
269
270        // Apply the remaining middleware layers
271        main_router.layer(service_builder)
272    }
273
274    /// Get the socket address to bind to.
275    pub fn addr(&self) -> std::net::SocketAddr {
276        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
277    }
278
279    /// Run the server (blocking).
280    pub async fn run(self) -> Result<(), std::io::Error> {
281        let addr = self.addr();
282        let router = self.router();
283
284        // Start the reactor for real-time updates
285        self.reactor
286            .start()
287            .await
288            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
289        tracing::info!("Reactor started for real-time updates");
290
291        tracing::info!("Gateway server listening on {}", addr);
292
293        let listener = tokio::net::TcpListener::bind(addr).await?;
294        axum::serve(listener, router.into_make_service()).await
295    }
296}
297
298/// Health check handler (liveness probe).
299async fn health_handler() -> Json<HealthResponse> {
300    Json(HealthResponse {
301        status: "healthy".to_string(),
302        version: env!("CARGO_PKG_VERSION").to_string(),
303    })
304}
305
306/// Readiness check handler (readiness probe).
307async fn readiness_handler(
308    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
309) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
310    // Check database connectivity
311    let db_ok = sqlx::query("SELECT 1")
312        .fetch_one(&state.db_pool)
313        .await
314        .is_ok();
315
316    // Check reactor health (change listener must be running for real-time updates)
317    let reactor_stats = state.reactor.stats().await;
318    let reactor_ok = reactor_stats.listener_running;
319
320    let ready = db_ok && reactor_ok;
321    let status_code = if ready {
322        axum::http::StatusCode::OK
323    } else {
324        axum::http::StatusCode::SERVICE_UNAVAILABLE
325    };
326
327    (
328        status_code,
329        Json(ReadinessResponse {
330            ready,
331            database: db_ok,
332            reactor: reactor_ok,
333            version: env!("CARGO_PKG_VERSION").to_string(),
334        }),
335    )
336}
337
338async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
339    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
340        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
341    } else {
342        (
343            StatusCode::SERVICE_UNAVAILABLE,
344            "SERVICE_UNAVAILABLE",
345            "Server overloaded",
346        )
347    };
348    (
349        status,
350        Json(RpcResponse::error(RpcError::new(code, message))),
351    )
352        .into_response()
353}
354
355/// Simple tracing middleware that adds TracingState to extensions.
356async fn tracing_middleware(
357    req: axum::extract::Request,
358    next: axum::middleware::Next,
359) -> axum::response::Response {
360    let headers = req.headers();
361
362    let trace_id = headers
363        .get(TRACE_ID_HEADER)
364        .and_then(|v| v.to_str().ok())
365        .map(String::from)
366        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
367
368    let parent_span_id = headers
369        .get(SPAN_ID_HEADER)
370        .and_then(|v| v.to_str().ok())
371        .map(String::from);
372
373    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
374    if let Some(span_id) = parent_span_id {
375        tracing_state = tracing_state.with_parent_span(span_id);
376    }
377
378    let mut req = req;
379    req.extensions_mut().insert(tracing_state.clone());
380
381    if req
382        .extensions()
383        .get::<forge_core::function::AuthContext>()
384        .is_none()
385    {
386        req.extensions_mut()
387            .insert(forge_core::function::AuthContext::unauthenticated());
388    }
389
390    let mut response = next.run(req).await;
391
392    let elapsed = tracing_state.elapsed();
393    tracing::debug!(
394        trace_id = %trace_id,
395        request_id = %tracing_state.request_id,
396        status = %response.status().as_u16(),
397        duration_ms = %elapsed.as_millis(),
398        "Request completed"
399    );
400
401    if let Ok(val) = trace_id.parse() {
402        response.headers_mut().insert(TRACE_ID_HEADER, val);
403    }
404    if let Ok(val) = tracing_state.request_id.parse() {
405        response.headers_mut().insert(REQUEST_ID_HEADER, val);
406    }
407
408    response
409}
410
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The default configuration listens on 8080, allows 512 connections and
    /// keeps CORS off.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 512);
        assert!(!cors_enabled);
    }

    /// A serialized health response contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let json = serde_json::to_string(&HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        })
        .unwrap();

        assert!(json.contains("healthy"));
    }
}