Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5    Json, Router,
6    error_handling::HandleErrorLayer,
7    extract::DefaultBodyLimit,
8    http::StatusCode,
9    middleware,
10    response::IntoResponse,
11    routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::function::{JobDispatch, WorkflowDispatch};
22
23use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
24use super::multipart::rpc_multipart_handler;
25use super::response::{RpcError, RpcResponse};
26use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
27use super::sse::{
28    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
29    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
30};
31use super::tracing::TracingState;
32use crate::db::Database;
33use crate::function::FunctionRegistry;
34use crate::realtime::{Reactor, ReactorConfig};
35
/// Request-body cap in bytes (1 MiB) applied via `DefaultBodyLimit` to the health, readiness
/// and JSON RPC routes.
const MAX_JSON_BODY_SIZE: usize = 1024 * 1024;
/// Request-body cap in bytes (20 MiB) applied via `DefaultBodyLimit` to the multipart upload
/// route.
const MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024;
/// Concurrency limit applied via `ConcurrencyLimitLayer` to the multipart upload route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
39
/// Gateway server configuration.
///
/// Consumed by [`GatewayServer`] when it builds its router and computes its bind address.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port to listen on. The server binds to `0.0.0.0` on this port.
    pub port: u16,
    /// Concurrency limit for the whole router.
    ///
    /// Despite the name, this value is handed to a tower `ConcurrencyLimitLayer`, so it bounds
    /// requests being processed at once rather than open TCP connections.
    pub max_connections: usize,
    /// Request timeout in seconds, enforced by a tower `TimeoutLayer`. Requests that exceed it
    /// are answered with `408 Request Timeout` and the RPC error code `TIMEOUT`.
    pub request_timeout_secs: u64,
    /// Enable CORS. When `false`, the router installs a bare `CorsLayer::new()` with no allow
    /// rules and `cors_origins` is ignored.
    pub cors_enabled: bool,
    /// Allowed CORS origins (only read when `cors_enabled` is `true`).
    ///
    /// If any entry is exactly `"*"`, every origin is allowed. Otherwise each entry is parsed
    /// into a header value and entries that fail to parse are skipped.
    pub cors_origins: Vec<String>,
    /// Authentication configuration, passed to the auth middleware.
    pub auth: AuthConfig,
}
56
57impl Default for GatewayConfig {
58    fn default() -> Self {
59        Self {
60            port: 8080,
61            max_connections: 512,
62            request_timeout_secs: 30,
63            cors_enabled: false,
64            cors_origins: Vec::new(),
65            auth: AuthConfig::default(),
66        }
67    }
68}
69
/// Health check response.
///
/// JSON body returned by the `/health` liveness endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Health status string; the liveness handler always reports `"healthy"`.
    pub status: String,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time by the handler.
    pub version: String,
}
76
/// Readiness check response.
///
/// JSON body returned by the `/ready` endpoint, alongside `200 OK` when `ready` is `true` and
/// `503 Service Unavailable` otherwise.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// Overall readiness: `true` only when both `database` and `reactor` are `true`.
    pub ready: bool,
    /// Whether the `SELECT 1` probe against the database pool succeeded.
    pub database: bool,
    /// Whether the reactor reports its change listener as running.
    pub reactor: bool,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time by the handler.
    pub version: String,
}
85
/// State for readiness check.
///
/// Built by `GatewayServer::router` and attached to the `/ready` route.
#[derive(Clone)]
pub struct ReadinessState {
    /// Pool probed with `SELECT 1` by the readiness handler (the database's primary pool).
    db_pool: sqlx::PgPool,
    /// Reactor whose `listener_running` stat is reported by the readiness handler.
    reactor: Arc<Reactor>,
}
92
/// Gateway HTTP server.
///
/// Owns the configuration and the shared handles from which `router` assembles the Axum
/// router and `run` serves it.
pub struct GatewayServer {
    /// Listener, limit, CORS and auth settings.
    config: GatewayConfig,
    /// Function registry handed to the RPC handler and to the reactor.
    registry: FunctionRegistry,
    /// Database handle: its read pool feeds the reactor, its primary pool backs `/ready`.
    db: Database,
    /// Real-time reactor, shared with the SSE state and the readiness state.
    reactor: Arc<Reactor>,
    /// Optional job dispatcher forwarded to the RPC handler; `None` until set via
    /// `with_job_dispatcher`.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded to the RPC handler; `None` until set via
    /// `with_workflow_dispatcher`.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
}
102
103impl GatewayServer {
104    /// Create a new gateway server.
105    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
106        let node_id = NodeId::new();
107        let reactor = Arc::new(Reactor::new(
108            node_id,
109            db.read_pool().clone(),
110            registry.clone(),
111            ReactorConfig::default(),
112        ));
113
114        Self {
115            config,
116            registry,
117            db,
118            reactor,
119            job_dispatcher: None,
120            workflow_dispatcher: None,
121        }
122    }
123
124    /// Set the job dispatcher.
125    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
126        self.job_dispatcher = Some(dispatcher);
127        self
128    }
129
130    /// Set the workflow dispatcher.
131    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
132        self.workflow_dispatcher = Some(dispatcher);
133        self
134    }
135
136    /// Get a reference to the reactor.
137    pub fn reactor(&self) -> Arc<Reactor> {
138        self.reactor.clone()
139    }
140
141    /// Build the Axum router.
142    pub fn router(&self) -> Router {
143        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
144            self.registry.clone(),
145            self.db.clone(),
146            self.job_dispatcher.clone(),
147            self.workflow_dispatcher.clone(),
148        ));
149
150        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
151
152        // Build CORS layer
153        let cors = if self.config.cors_enabled {
154            if self.config.cors_origins.iter().any(|o| o == "*") {
155                CorsLayer::new()
156                    .allow_origin(Any)
157                    .allow_methods(Any)
158                    .allow_headers(Any)
159            } else {
160                let origins: Vec<_> = self
161                    .config
162                    .cors_origins
163                    .iter()
164                    .filter_map(|o| o.parse().ok())
165                    .collect();
166                CorsLayer::new()
167                    .allow_origin(origins)
168                    .allow_methods(Any)
169                    .allow_headers(Any)
170            }
171        } else {
172            CorsLayer::new()
173        };
174
175        // SSE state for Server-Sent Events
176        let sse_state = Arc::new(SseState::new(
177            self.reactor.clone(),
178            auth_middleware_state.clone(),
179        ));
180
181        // Readiness state for DB + reactor health check
182        let readiness_state = Arc::new(ReadinessState {
183            db_pool: self.db.primary().clone(),
184            reactor: self.reactor.clone(),
185        });
186
187        // Build the main router with middleware
188        let mut main_router = Router::new()
189            // Health check endpoint (liveness)
190            .route("/health", get(health_handler))
191            // Readiness check endpoint (checks DB)
192            .route("/ready", get(readiness_handler).with_state(readiness_state))
193            // RPC endpoint
194            .route("/rpc", post(rpc_handler))
195            // REST-style function endpoint (JSON)
196            .route("/rpc/{function}", post(rpc_function_handler))
197            // Prevent oversized JSON payloads from exhausting memory.
198            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
199            // Add state
200            .with_state(rpc_handler_state.clone());
201
202        // Multipart RPC router (separate state needed for multipart)
203        let multipart_router = Router::new()
204            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
205            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
206            // Cap upload fan-out; each request buffers data in memory.
207            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
208            .with_state(rpc_handler_state);
209
210        // SSE router
211        let sse_router = Router::new()
212            .route("/events", get(sse_handler))
213            .route("/subscribe", post(sse_subscribe_handler))
214            .route("/unsubscribe", post(sse_unsubscribe_handler))
215            .route("/subscribe-job", post(sse_job_subscribe_handler))
216            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
217            .with_state(sse_state);
218
219        main_router = main_router.merge(multipart_router).merge(sse_router);
220
221        // Build middleware stack
222        let service_builder = ServiceBuilder::new()
223            .layer(HandleErrorLayer::new(handle_middleware_error))
224            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
225            .layer(TimeoutLayer::new(Duration::from_secs(
226                self.config.request_timeout_secs,
227            )))
228            .layer(cors.clone())
229            .layer(middleware::from_fn_with_state(
230                auth_middleware_state,
231                auth_middleware,
232            ))
233            .layer(middleware::from_fn(tracing_middleware));
234
235        // Apply the remaining middleware layers
236        main_router.layer(service_builder)
237    }
238
239    /// Get the socket address to bind to.
240    pub fn addr(&self) -> std::net::SocketAddr {
241        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
242    }
243
244    /// Run the server (blocking).
245    pub async fn run(self) -> Result<(), std::io::Error> {
246        let addr = self.addr();
247        let router = self.router();
248
249        // Start the reactor for real-time updates
250        self.reactor
251            .start()
252            .await
253            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
254        tracing::info!("Reactor started for real-time updates");
255
256        tracing::info!("Gateway server listening on {}", addr);
257
258        let listener = tokio::net::TcpListener::bind(addr).await?;
259        axum::serve(listener, router.into_make_service()).await
260    }
261}
262
263/// Health check handler (liveness probe).
264async fn health_handler() -> Json<HealthResponse> {
265    Json(HealthResponse {
266        status: "healthy".to_string(),
267        version: env!("CARGO_PKG_VERSION").to_string(),
268    })
269}
270
271/// Readiness check handler (readiness probe).
272async fn readiness_handler(
273    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
274) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
275    // Check database connectivity
276    let db_ok = sqlx::query("SELECT 1")
277        .fetch_one(&state.db_pool)
278        .await
279        .is_ok();
280
281    // Check reactor health (change listener must be running for real-time updates)
282    let reactor_stats = state.reactor.stats().await;
283    let reactor_ok = reactor_stats.listener_running;
284
285    let ready = db_ok && reactor_ok;
286    let status_code = if ready {
287        axum::http::StatusCode::OK
288    } else {
289        axum::http::StatusCode::SERVICE_UNAVAILABLE
290    };
291
292    (
293        status_code,
294        Json(ReadinessResponse {
295            ready,
296            database: db_ok,
297            reactor: reactor_ok,
298            version: env!("CARGO_PKG_VERSION").to_string(),
299        }),
300    )
301}
302
303async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
304    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
305        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
306    } else {
307        (
308            StatusCode::SERVICE_UNAVAILABLE,
309            "SERVICE_UNAVAILABLE",
310            "Server overloaded",
311        )
312    };
313    (
314        status,
315        Json(RpcResponse::error(RpcError::new(code, message))),
316    )
317        .into_response()
318}
319
320/// Simple tracing middleware that adds TracingState to extensions.
321async fn tracing_middleware(
322    req: axum::extract::Request,
323    next: axum::middleware::Next,
324) -> axum::response::Response {
325    use axum::http::header::HeaderName;
326
327    // Extract or generate trace ID
328    let trace_id = req
329        .headers()
330        .get(HeaderName::from_static("x-trace-id"))
331        .and_then(|v| v.to_str().ok())
332        .map(String::from)
333        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
334
335    let tracing_state = TracingState::with_trace_id(trace_id.clone());
336
337    let mut req = req;
338    req.extensions_mut().insert(tracing_state.clone());
339
340    // Also insert AuthContext default if not present
341    if req
342        .extensions()
343        .get::<forge_core::function::AuthContext>()
344        .is_none()
345    {
346        req.extensions_mut()
347            .insert(forge_core::function::AuthContext::unauthenticated());
348    }
349
350    let mut response = next.run(req).await;
351
352    // Add trace ID to response headers
353    if let Ok(val) = trace_id.parse() {
354        response.headers_mut().insert("x-trace-id", val);
355    }
356    if let Ok(val) = tracing_state.request_id.parse() {
357        response.headers_mut().insert("x-request-id", val);
358    }
359
360    response
361}
362
#[cfg(test)]
// Test code may unwrap, index and panic freely: a failure here is a test failure.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Every scalar and collection default of `GatewayConfig` is pinned.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 512);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(!config.cors_enabled);
        assert!(config.cors_origins.is_empty());
    }

    /// `HealthResponse` serializes to exactly its two fields.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "status": "healthy", "version": "0.1.0" })
        );
    }

    /// `ReadinessResponse` serializes to exactly its four fields.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: true,
            reactor: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "ready": false,
                "database": true,
                "reactor": false,
                "version": "0.1.0"
            })
        );
    }
}