forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2
3use axum::{
4    Json, Router, middleware,
5    routing::{any, get, post},
6};
7use serde::Serialize;
8use tower::ServiceBuilder;
9use tower_http::cors::{Any, CorsLayer};
10
11use forge_core::cluster::NodeId;
12use forge_core::function::{JobDispatch, WorkflowDispatch};
13
14use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
15use super::metrics::{MetricsState, metrics_middleware};
16use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
17use super::tracing::TracingState;
18use super::websocket::{WsState, ws_handler};
19use crate::function::FunctionRegistry;
20use crate::observability::ObservabilityState;
21use crate::realtime::{Reactor, ReactorConfig};
22
23/// Gateway server configuration.
24#[derive(Debug, Clone)]
25pub struct GatewayConfig {
26    /// Port to listen on.
27    pub port: u16,
28    /// Maximum number of connections.
29    pub max_connections: usize,
30    /// Request timeout in seconds.
31    pub request_timeout_secs: u64,
32    /// Enable CORS.
33    pub cors_enabled: bool,
34    /// Allowed CORS origins.
35    pub cors_origins: Vec<String>,
36    /// Authentication configuration.
37    pub auth: AuthConfig,
38}
39
40impl Default for GatewayConfig {
41    fn default() -> Self {
42        Self {
43            port: 8080,
44            max_connections: 10000,
45            request_timeout_secs: 30,
46            cors_enabled: true,
47            cors_origins: vec!["*".to_string()],
48            auth: AuthConfig::default(),
49        }
50    }
51}
52
53/// Health check response.
54#[derive(Debug, Serialize)]
55pub struct HealthResponse {
56    pub status: String,
57    pub version: String,
58}
59
60/// Readiness check response.
61#[derive(Debug, Serialize)]
62pub struct ReadinessResponse {
63    pub ready: bool,
64    pub database: bool,
65    pub version: String,
66}
67
68/// State for readiness check.
69#[derive(Clone)]
70pub struct ReadinessState {
71    db_pool: sqlx::PgPool,
72}
73
74/// Gateway HTTP server.
75pub struct GatewayServer {
76    config: GatewayConfig,
77    registry: FunctionRegistry,
78    db_pool: sqlx::PgPool,
79    reactor: Arc<Reactor>,
80    observability: Option<ObservabilityState>,
81    job_dispatcher: Option<Arc<dyn JobDispatch>>,
82    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
83}
84
85impl GatewayServer {
86    /// Create a new gateway server.
87    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
88        let node_id = NodeId::new();
89        let reactor = Arc::new(Reactor::new(
90            node_id,
91            db_pool.clone(),
92            registry.clone(),
93            ReactorConfig::default(),
94        ));
95
96        Self {
97            config,
98            registry,
99            db_pool,
100            reactor,
101            observability: None,
102            job_dispatcher: None,
103            workflow_dispatcher: None,
104        }
105    }
106
107    /// Create a new gateway server with observability.
108    pub fn with_observability(
109        config: GatewayConfig,
110        registry: FunctionRegistry,
111        db_pool: sqlx::PgPool,
112        observability: ObservabilityState,
113    ) -> Self {
114        let node_id = NodeId::new();
115        let reactor = Arc::new(Reactor::new(
116            node_id,
117            db_pool.clone(),
118            registry.clone(),
119            ReactorConfig::default(),
120        ));
121
122        Self {
123            config,
124            registry,
125            db_pool,
126            reactor,
127            observability: Some(observability),
128            job_dispatcher: None,
129            workflow_dispatcher: None,
130        }
131    }
132
133    /// Set the job dispatcher.
134    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
135        self.job_dispatcher = Some(dispatcher);
136        self
137    }
138
139    /// Set the workflow dispatcher.
140    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
141        self.workflow_dispatcher = Some(dispatcher);
142        self
143    }
144
145    /// Get a reference to the reactor.
146    pub fn reactor(&self) -> Arc<Reactor> {
147        self.reactor.clone()
148    }
149
150    /// Build the Axum router.
151    pub fn router(&self) -> Router {
152        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
153            self.registry.clone(),
154            self.db_pool.clone(),
155            self.job_dispatcher.clone(),
156            self.workflow_dispatcher.clone(),
157        ));
158
159        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
160
161        // Build CORS layer
162        let cors = if self.config.cors_enabled {
163            if self.config.cors_origins.contains(&"*".to_string()) {
164                CorsLayer::new()
165                    .allow_origin(Any)
166                    .allow_methods(Any)
167                    .allow_headers(Any)
168            } else {
169                let origins: Vec<_> = self
170                    .config
171                    .cors_origins
172                    .iter()
173                    .filter_map(|o| o.parse().ok())
174                    .collect();
175                CorsLayer::new()
176                    .allow_origin(origins)
177                    .allow_methods(Any)
178                    .allow_headers(Any)
179            }
180        } else {
181            CorsLayer::new()
182        };
183
184        // WebSocket state uses the reactor, db_pool, and auth middleware for session tracking
185        let node_id = self.reactor.node_id();
186        let ws_state = Arc::new(WsState::with_auth(
187            self.reactor.clone(),
188            self.db_pool.clone(),
189            node_id,
190            auth_middleware_state.clone(),
191        ));
192
193        // Readiness state for DB health check
194        let readiness_state = Arc::new(ReadinessState {
195            db_pool: self.db_pool.clone(),
196        });
197
198        // Build the main router with middleware
199        let mut main_router = Router::new()
200            // Health check endpoint (liveness)
201            .route("/health", get(health_handler))
202            // Readiness check endpoint (checks DB)
203            .route("/ready", get(readiness_handler).with_state(readiness_state))
204            // RPC endpoint
205            .route("/rpc", post(rpc_handler))
206            // REST-style function endpoint
207            .route("/rpc/{function}", post(rpc_function_handler))
208            // Add state
209            .with_state(rpc_handler_state);
210
211        // Build middleware stack
212        let service_builder = ServiceBuilder::new()
213            .layer(cors.clone())
214            .layer(middleware::from_fn_with_state(
215                auth_middleware_state,
216                auth_middleware,
217            ))
218            .layer(middleware::from_fn(tracing_middleware));
219
220        // Add metrics middleware if observability is enabled
221        if let Some(ref observability) = self.observability {
222            let metrics_state = Arc::new(MetricsState::new(observability.clone()));
223            main_router = main_router.layer(middleware::from_fn_with_state(
224                metrics_state,
225                metrics_middleware,
226            ));
227        }
228
229        // Apply the remaining middleware layers
230        main_router = main_router.layer(service_builder);
231
232        // WebSocket router without auth middleware (just CORS)
233        let ws_router = Router::new()
234            .route("/ws", any(ws_handler).with_state(ws_state))
235            .layer(cors);
236
237        // Merge routers - WebSocket route is separate from middleware stack
238        main_router.merge(ws_router)
239    }
240
241    /// Get the socket address to bind to.
242    pub fn addr(&self) -> std::net::SocketAddr {
243        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
244    }
245
246    /// Run the server (blocking).
247    pub async fn run(self) -> Result<(), std::io::Error> {
248        let addr = self.addr();
249        let router = self.router();
250
251        // Start the reactor for real-time updates
252        if let Err(e) = self.reactor.start().await {
253            tracing::error!("Failed to start reactor: {}", e);
254        } else {
255            tracing::info!("Reactor started for real-time updates");
256        }
257
258        tracing::info!("Gateway server listening on {}", addr);
259
260        let listener = tokio::net::TcpListener::bind(addr).await?;
261        axum::serve(listener, router.into_make_service()).await
262    }
263}
264
265/// Health check handler (liveness probe).
266async fn health_handler() -> Json<HealthResponse> {
267    Json(HealthResponse {
268        status: "healthy".to_string(),
269        version: env!("CARGO_PKG_VERSION").to_string(),
270    })
271}
272
273/// Readiness check handler (readiness probe).
274async fn readiness_handler(
275    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
276) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
277    // Check database connectivity
278    let db_ok = sqlx::query("SELECT 1")
279        .fetch_one(&state.db_pool)
280        .await
281        .is_ok();
282
283    let ready = db_ok;
284    let status_code = if ready {
285        axum::http::StatusCode::OK
286    } else {
287        axum::http::StatusCode::SERVICE_UNAVAILABLE
288    };
289
290    (
291        status_code,
292        Json(ReadinessResponse {
293            ready,
294            database: db_ok,
295            version: env!("CARGO_PKG_VERSION").to_string(),
296        }),
297    )
298}
299
300/// Simple tracing middleware that adds TracingState to extensions.
301async fn tracing_middleware(
302    req: axum::extract::Request,
303    next: axum::middleware::Next,
304) -> axum::response::Response {
305    use axum::http::header::HeaderName;
306
307    // Extract or generate trace ID
308    let trace_id = req
309        .headers()
310        .get(HeaderName::from_static("x-trace-id"))
311        .and_then(|v| v.to_str().ok())
312        .map(String::from)
313        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
314
315    let tracing_state = TracingState::with_trace_id(trace_id.clone());
316
317    let mut req = req;
318    req.extensions_mut().insert(tracing_state.clone());
319
320    // Also insert AuthContext default if not present
321    if req
322        .extensions()
323        .get::<forge_core::function::AuthContext>()
324        .is_none()
325    {
326        req.extensions_mut()
327            .insert(forge_core::function::AuthContext::unauthenticated());
328    }
329
330    let mut response = next.run(req).await;
331
332    // Add trace ID to response headers
333    if let Ok(val) = trace_id.parse() {
334        response.headers_mut().insert("x-trace-id", val);
335    }
336    if let Ok(val) = tracing_state.request_id.parse() {
337        response.headers_mut().insert("x-request-id", val);
338    }
339
340    response
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default value except `auth`, whose contents belong to the auth module.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 10000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins, vec!["*".to_string()]);
    }

    /// The exact wire format matters to probes and clients, so pin it fully.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(json, r#"{"status":"healthy","version":"0.1.0"}"#);
    }

    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(json, r#"{"ready":false,"database":false,"version":"0.1.0"}"#);
    }
}