forge_runtime/gateway/
server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    middleware,
6    routing::{any, get, post},
7    Json, Router,
8};
9use serde::Serialize;
10use tower::ServiceBuilder;
11use tower_http::cors::{Any, CorsLayer};
12
13use forge_core::cluster::NodeId;
14use forge_core::function::{JobDispatch, WorkflowDispatch};
15
16use super::auth::{auth_middleware, AuthConfig, AuthMiddleware};
17use super::metrics::{metrics_middleware, MetricsState};
18use super::rpc::{rpc_function_handler, rpc_handler, RpcHandler};
19use super::tracing::TracingState;
20use super::websocket::{ws_handler, WsState};
21use crate::function::FunctionRegistry;
22use crate::observability::ObservabilityState;
23use crate::realtime::{Reactor, ReactorConfig};
24
/// Gateway server configuration.
///
/// Read by [`GatewayServer`]: `port` determines the bind address,
/// `cors_enabled` / `cors_origins` shape the CORS layer, and `auth` is handed
/// to the auth middleware.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// Port to listen on (the server binds all IPv4 interfaces, `0.0.0.0`).
    pub port: u16,
    /// Maximum number of connections.
    ///
    /// NOTE(review): not read by `GatewayServer::router` or `run` in this
    /// file — confirm whether it is enforced anywhere.
    pub max_connections: usize,
    /// Request timeout in seconds.
    ///
    /// NOTE(review): no timeout layer is built from this in this file —
    /// confirm whether it is enforced anywhere.
    pub request_timeout_secs: u64,
    /// Enable CORS. When `false`, a default `CorsLayer::new()` is installed.
    pub cors_enabled: bool,
    /// Allowed CORS origins. Any entry equal to `"*"` allows every origin;
    /// otherwise each entry is parsed as a header value and entries that fail
    /// to parse are skipped.
    pub cors_origins: Vec<String>,
    /// Authentication configuration passed to the auth middleware.
    pub auth: AuthConfig,
}
41
42impl Default for GatewayConfig {
43    fn default() -> Self {
44        Self {
45            port: 8080,
46            max_connections: 10000,
47            request_timeout_secs: 30,
48            cors_enabled: true,
49            cors_origins: vec!["*".to_string()],
50            auth: AuthConfig::default(),
51        }
52    }
53}
54
/// JSON body returned by the liveness endpoint (`GET /health`).
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Health status string; `health_handler` always sets `"healthy"`.
    pub status: String,
    /// Crate version (`CARGO_PKG_VERSION`) at compile time.
    pub version: String,
}
61
/// JSON body returned by the readiness endpoint (`GET /ready`).
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// Overall readiness; currently identical to `database`.
    pub ready: bool,
    /// Whether the `SELECT 1` probe against the pool succeeded.
    pub database: bool,
    /// Crate version (`CARGO_PKG_VERSION`) at compile time.
    pub version: String,
}
69
/// Router state for the readiness probe.
#[derive(Clone)]
pub struct ReadinessState {
    /// Pool that `readiness_handler` runs its `SELECT 1` check against.
    db_pool: sqlx::PgPool,
}
75
/// Gateway HTTP server: owns the configuration and shared resources needed to
/// build the Axum router and serve it.
pub struct GatewayServer {
    /// Server configuration (port, CORS, auth, ...).
    config: GatewayConfig,
    /// Function registry; cloned into the RPC handler and the reactor.
    registry: FunctionRegistry,
    /// Postgres pool shared by the RPC handler, reactor, WebSocket state and
    /// readiness probe.
    db_pool: sqlx::PgPool,
    /// Real-time reactor; constructed with the server and started in `run`.
    reactor: Arc<Reactor>,
    /// When `Some`, the router installs the metrics middleware.
    observability: Option<ObservabilityState>,
    /// Optional job dispatcher forwarded to the RPC handler.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded to the RPC handler.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
}
86
87impl GatewayServer {
88    /// Create a new gateway server.
89    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
90        let node_id = NodeId::new();
91        let reactor = Arc::new(Reactor::new(
92            node_id,
93            db_pool.clone(),
94            registry.clone(),
95            ReactorConfig::default(),
96        ));
97
98        Self {
99            config,
100            registry,
101            db_pool,
102            reactor,
103            observability: None,
104            job_dispatcher: None,
105            workflow_dispatcher: None,
106        }
107    }
108
109    /// Create a new gateway server with observability.
110    pub fn with_observability(
111        config: GatewayConfig,
112        registry: FunctionRegistry,
113        db_pool: sqlx::PgPool,
114        observability: ObservabilityState,
115    ) -> Self {
116        let node_id = NodeId::new();
117        let reactor = Arc::new(Reactor::new(
118            node_id,
119            db_pool.clone(),
120            registry.clone(),
121            ReactorConfig::default(),
122        ));
123
124        Self {
125            config,
126            registry,
127            db_pool,
128            reactor,
129            observability: Some(observability),
130            job_dispatcher: None,
131            workflow_dispatcher: None,
132        }
133    }
134
135    /// Set the job dispatcher.
136    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
137        self.job_dispatcher = Some(dispatcher);
138        self
139    }
140
141    /// Set the workflow dispatcher.
142    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
143        self.workflow_dispatcher = Some(dispatcher);
144        self
145    }
146
147    /// Get a reference to the reactor.
148    pub fn reactor(&self) -> Arc<Reactor> {
149        self.reactor.clone()
150    }
151
152    /// Build the Axum router.
153    pub fn router(&self) -> Router {
154        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
155            self.registry.clone(),
156            self.db_pool.clone(),
157            self.job_dispatcher.clone(),
158            self.workflow_dispatcher.clone(),
159        ));
160
161        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
162
163        // Build CORS layer
164        let cors = if self.config.cors_enabled {
165            if self.config.cors_origins.contains(&"*".to_string()) {
166                CorsLayer::new()
167                    .allow_origin(Any)
168                    .allow_methods(Any)
169                    .allow_headers(Any)
170            } else {
171                let origins: Vec<_> = self
172                    .config
173                    .cors_origins
174                    .iter()
175                    .filter_map(|o| o.parse().ok())
176                    .collect();
177                CorsLayer::new()
178                    .allow_origin(origins)
179                    .allow_methods(Any)
180                    .allow_headers(Any)
181            }
182        } else {
183            CorsLayer::new()
184        };
185
186        // WebSocket state uses the reactor and db_pool for session tracking
187        let node_id = self.reactor.node_id();
188        let ws_state = Arc::new(WsState::new(
189            self.reactor.clone(),
190            self.db_pool.clone(),
191            node_id,
192        ));
193
194        // Readiness state for DB health check
195        let readiness_state = Arc::new(ReadinessState {
196            db_pool: self.db_pool.clone(),
197        });
198
199        // Build the main router with middleware
200        let mut main_router = Router::new()
201            // Health check endpoint (liveness)
202            .route("/health", get(health_handler))
203            // Readiness check endpoint (checks DB)
204            .route("/ready", get(readiness_handler).with_state(readiness_state))
205            // RPC endpoint
206            .route("/rpc", post(rpc_handler))
207            // REST-style function endpoint
208            .route("/rpc/{function}", post(rpc_function_handler))
209            // Add state
210            .with_state(rpc_handler_state);
211
212        // Build middleware stack
213        let service_builder = ServiceBuilder::new()
214            .layer(cors.clone())
215            .layer(middleware::from_fn_with_state(
216                auth_middleware_state,
217                auth_middleware,
218            ))
219            .layer(middleware::from_fn(tracing_middleware));
220
221        // Add metrics middleware if observability is enabled
222        if let Some(ref observability) = self.observability {
223            let metrics_state = Arc::new(MetricsState::new(observability.clone()));
224            main_router = main_router.layer(middleware::from_fn_with_state(
225                metrics_state,
226                metrics_middleware,
227            ));
228        }
229
230        // Apply the remaining middleware layers
231        main_router = main_router.layer(service_builder);
232
233        // WebSocket router without auth middleware (just CORS)
234        let ws_router = Router::new()
235            .route("/ws", any(ws_handler).with_state(ws_state))
236            .layer(cors);
237
238        // Merge routers - WebSocket route is separate from middleware stack
239        main_router.merge(ws_router)
240    }
241
242    /// Get the socket address to bind to.
243    pub fn addr(&self) -> SocketAddr {
244        SocketAddr::from(([0, 0, 0, 0], self.config.port))
245    }
246
247    /// Run the server (blocking).
248    pub async fn run(self) -> Result<(), std::io::Error> {
249        let addr = self.addr();
250        let router = self.router();
251
252        // Start the reactor for real-time updates
253        if let Err(e) = self.reactor.start().await {
254            tracing::error!("Failed to start reactor: {}", e);
255        } else {
256            tracing::info!("Reactor started for real-time updates");
257        }
258
259        tracing::info!("Gateway server listening on {}", addr);
260
261        let listener = tokio::net::TcpListener::bind(addr).await?;
262        axum::serve(listener, router).await
263    }
264}
265
266/// Health check handler (liveness probe).
267async fn health_handler() -> Json<HealthResponse> {
268    Json(HealthResponse {
269        status: "healthy".to_string(),
270        version: env!("CARGO_PKG_VERSION").to_string(),
271    })
272}
273
274/// Readiness check handler (readiness probe).
275async fn readiness_handler(
276    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
277) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
278    // Check database connectivity
279    let db_ok = sqlx::query("SELECT 1")
280        .fetch_one(&state.db_pool)
281        .await
282        .is_ok();
283
284    let ready = db_ok;
285    let status_code = if ready {
286        axum::http::StatusCode::OK
287    } else {
288        axum::http::StatusCode::SERVICE_UNAVAILABLE
289    };
290
291    (
292        status_code,
293        Json(ReadinessResponse {
294            ready,
295            database: db_ok,
296            version: env!("CARGO_PKG_VERSION").to_string(),
297        }),
298    )
299}
300
301/// Simple tracing middleware that adds TracingState to extensions.
302async fn tracing_middleware(
303    req: axum::extract::Request,
304    next: axum::middleware::Next,
305) -> axum::response::Response {
306    use axum::http::header::HeaderName;
307
308    // Extract or generate trace ID
309    let trace_id = req
310        .headers()
311        .get(HeaderName::from_static("x-trace-id"))
312        .and_then(|v| v.to_str().ok())
313        .map(String::from)
314        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
315
316    let tracing_state = TracingState::with_trace_id(trace_id.clone());
317
318    let mut req = req;
319    req.extensions_mut().insert(tracing_state.clone());
320
321    // Also insert AuthContext default if not present
322    if req
323        .extensions()
324        .get::<forge_core::function::AuthContext>()
325        .is_none()
326    {
327        req.extensions_mut()
328            .insert(forge_core::function::AuthContext::unauthenticated());
329    }
330
331    let mut response = next.run(req).await;
332
333    // Add trace ID to response headers
334    if let Ok(val) = trace_id.parse() {
335        response.headers_mut().insert("x-trace-id", val);
336    }
337    if let Ok(val) = tracing_state.request_id.parse() {
338        response.headers_mut().insert("x-request-id", val);
339    }
340
341    response
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Every stdlib-typed default is pinned so accidental changes are caught.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 10000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins, vec!["*".to_string()]);
    }

    /// The liveness body serializes to exactly the expected JSON object.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(json, r#"{"status":"healthy","version":"0.1.0"}"#);
    }

    /// The readiness body exposes `ready`, `database` and `version`.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: false,
            version: "0.1.0".to_string(),
        };
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "ready": false, "database": false, "version": "0.1.0" })
        );
    }
}