// forge_runtime/gateway/server.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5    Json, Router, middleware,
6    routing::{any, get, post},
7};
8use serde::Serialize;
9use tower::ServiceBuilder;
10use tower_http::cors::{Any, CorsLayer};
11
12use forge_core::cluster::NodeId;
13use forge_core::function::{JobDispatch, WorkflowDispatch};
14
15use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
16use super::metrics::{MetricsState, metrics_middleware};
17use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
18use super::tracing::TracingState;
19use super::websocket::{WsState, ws_handler};
20use crate::function::FunctionRegistry;
21use crate::observability::ObservabilityState;
22use crate::realtime::{Reactor, ReactorConfig};
23
24/// Gateway server configuration.
25#[derive(Debug, Clone)]
26pub struct GatewayConfig {
27    /// Port to listen on.
28    pub port: u16,
29    /// Maximum number of connections.
30    pub max_connections: usize,
31    /// Request timeout in seconds.
32    pub request_timeout_secs: u64,
33    /// Enable CORS.
34    pub cors_enabled: bool,
35    /// Allowed CORS origins.
36    pub cors_origins: Vec<String>,
37    /// Authentication configuration.
38    pub auth: AuthConfig,
39}
40
41impl Default for GatewayConfig {
42    fn default() -> Self {
43        Self {
44            port: 8080,
45            max_connections: 10000,
46            request_timeout_secs: 30,
47            cors_enabled: true,
48            cors_origins: vec!["*".to_string()],
49            auth: AuthConfig::default(),
50        }
51    }
52}
53
54/// Health check response.
55#[derive(Debug, Serialize)]
56pub struct HealthResponse {
57    pub status: String,
58    pub version: String,
59}
60
61/// Readiness check response.
62#[derive(Debug, Serialize)]
63pub struct ReadinessResponse {
64    pub ready: bool,
65    pub database: bool,
66    pub version: String,
67}
68
69/// State for readiness check.
70#[derive(Clone)]
71pub struct ReadinessState {
72    db_pool: sqlx::PgPool,
73}
74
75/// Gateway HTTP server.
76pub struct GatewayServer {
77    config: GatewayConfig,
78    registry: FunctionRegistry,
79    db_pool: sqlx::PgPool,
80    reactor: Arc<Reactor>,
81    observability: Option<ObservabilityState>,
82    job_dispatcher: Option<Arc<dyn JobDispatch>>,
83    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
84}
85
86impl GatewayServer {
87    /// Create a new gateway server.
88    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
89        let node_id = NodeId::new();
90        let reactor = Arc::new(Reactor::new(
91            node_id,
92            db_pool.clone(),
93            registry.clone(),
94            ReactorConfig::default(),
95        ));
96
97        Self {
98            config,
99            registry,
100            db_pool,
101            reactor,
102            observability: None,
103            job_dispatcher: None,
104            workflow_dispatcher: None,
105        }
106    }
107
108    /// Create a new gateway server with observability.
109    pub fn with_observability(
110        config: GatewayConfig,
111        registry: FunctionRegistry,
112        db_pool: sqlx::PgPool,
113        observability: ObservabilityState,
114    ) -> Self {
115        let node_id = NodeId::new();
116        let reactor = Arc::new(Reactor::new(
117            node_id,
118            db_pool.clone(),
119            registry.clone(),
120            ReactorConfig::default(),
121        ));
122
123        Self {
124            config,
125            registry,
126            db_pool,
127            reactor,
128            observability: Some(observability),
129            job_dispatcher: None,
130            workflow_dispatcher: None,
131        }
132    }
133
134    /// Set the job dispatcher.
135    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
136        self.job_dispatcher = Some(dispatcher);
137        self
138    }
139
140    /// Set the workflow dispatcher.
141    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
142        self.workflow_dispatcher = Some(dispatcher);
143        self
144    }
145
146    /// Get a reference to the reactor.
147    pub fn reactor(&self) -> Arc<Reactor> {
148        self.reactor.clone()
149    }
150
151    /// Build the Axum router.
152    pub fn router(&self) -> Router {
153        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
154            self.registry.clone(),
155            self.db_pool.clone(),
156            self.job_dispatcher.clone(),
157            self.workflow_dispatcher.clone(),
158        ));
159
160        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
161
162        // Build CORS layer
163        let cors = if self.config.cors_enabled {
164            if self.config.cors_origins.contains(&"*".to_string()) {
165                CorsLayer::new()
166                    .allow_origin(Any)
167                    .allow_methods(Any)
168                    .allow_headers(Any)
169            } else {
170                let origins: Vec<_> = self
171                    .config
172                    .cors_origins
173                    .iter()
174                    .filter_map(|o| o.parse().ok())
175                    .collect();
176                CorsLayer::new()
177                    .allow_origin(origins)
178                    .allow_methods(Any)
179                    .allow_headers(Any)
180            }
181        } else {
182            CorsLayer::new()
183        };
184
185        // WebSocket state uses the reactor and db_pool for session tracking
186        let node_id = self.reactor.node_id();
187        let ws_state = Arc::new(WsState::new(
188            self.reactor.clone(),
189            self.db_pool.clone(),
190            node_id,
191        ));
192
193        // Readiness state for DB health check
194        let readiness_state = Arc::new(ReadinessState {
195            db_pool: self.db_pool.clone(),
196        });
197
198        // Build the main router with middleware
199        let mut main_router = Router::new()
200            // Health check endpoint (liveness)
201            .route("/health", get(health_handler))
202            // Readiness check endpoint (checks DB)
203            .route("/ready", get(readiness_handler).with_state(readiness_state))
204            // RPC endpoint
205            .route("/rpc", post(rpc_handler))
206            // REST-style function endpoint
207            .route("/rpc/{function}", post(rpc_function_handler))
208            // Add state
209            .with_state(rpc_handler_state);
210
211        // Build middleware stack
212        let service_builder = ServiceBuilder::new()
213            .layer(cors.clone())
214            .layer(middleware::from_fn_with_state(
215                auth_middleware_state,
216                auth_middleware,
217            ))
218            .layer(middleware::from_fn(tracing_middleware));
219
220        // Add metrics middleware if observability is enabled
221        if let Some(ref observability) = self.observability {
222            let metrics_state = Arc::new(MetricsState::new(observability.clone()));
223            main_router = main_router.layer(middleware::from_fn_with_state(
224                metrics_state,
225                metrics_middleware,
226            ));
227        }
228
229        // Apply the remaining middleware layers
230        main_router = main_router.layer(service_builder);
231
232        // WebSocket router without auth middleware (just CORS)
233        let ws_router = Router::new()
234            .route("/ws", any(ws_handler).with_state(ws_state))
235            .layer(cors);
236
237        // Merge routers - WebSocket route is separate from middleware stack
238        main_router.merge(ws_router)
239    }
240
241    /// Get the socket address to bind to.
242    pub fn addr(&self) -> SocketAddr {
243        SocketAddr::from(([0, 0, 0, 0], self.config.port))
244    }
245
246    /// Run the server (blocking).
247    pub async fn run(self) -> Result<(), std::io::Error> {
248        let addr = self.addr();
249        let router = self.router();
250
251        // Start the reactor for real-time updates
252        if let Err(e) = self.reactor.start().await {
253            tracing::error!("Failed to start reactor: {}", e);
254        } else {
255            tracing::info!("Reactor started for real-time updates");
256        }
257
258        tracing::info!("Gateway server listening on {}", addr);
259
260        let listener = tokio::net::TcpListener::bind(addr).await?;
261        axum::serve(listener, router).await
262    }
263}
264
265/// Health check handler (liveness probe).
266async fn health_handler() -> Json<HealthResponse> {
267    Json(HealthResponse {
268        status: "healthy".to_string(),
269        version: env!("CARGO_PKG_VERSION").to_string(),
270    })
271}
272
273/// Readiness check handler (readiness probe).
274async fn readiness_handler(
275    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
276) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
277    // Check database connectivity
278    let db_ok = sqlx::query("SELECT 1")
279        .fetch_one(&state.db_pool)
280        .await
281        .is_ok();
282
283    let ready = db_ok;
284    let status_code = if ready {
285        axum::http::StatusCode::OK
286    } else {
287        axum::http::StatusCode::SERVICE_UNAVAILABLE
288    };
289
290    (
291        status_code,
292        Json(ReadinessResponse {
293            ready,
294            database: db_ok,
295            version: env!("CARGO_PKG_VERSION").to_string(),
296        }),
297    )
298}
299
300/// Simple tracing middleware that adds TracingState to extensions.
301async fn tracing_middleware(
302    req: axum::extract::Request,
303    next: axum::middleware::Next,
304) -> axum::response::Response {
305    use axum::http::header::HeaderName;
306
307    // Extract or generate trace ID
308    let trace_id = req
309        .headers()
310        .get(HeaderName::from_static("x-trace-id"))
311        .and_then(|v| v.to_str().ok())
312        .map(String::from)
313        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
314
315    let tracing_state = TracingState::with_trace_id(trace_id.clone());
316
317    let mut req = req;
318    req.extensions_mut().insert(tracing_state.clone());
319
320    // Also insert AuthContext default if not present
321    if req
322        .extensions()
323        .get::<forge_core::function::AuthContext>()
324        .is_none()
325    {
326        req.extensions_mut()
327            .insert(forge_core::function::AuthContext::unauthenticated());
328    }
329
330    let mut response = next.run(req).await;
331
332    // Add trace ID to response headers
333    if let Ok(val) = trace_id.parse() {
334        response.headers_mut().insert("x-trace-id", val);
335    }
336    if let Ok(val) = tracing_state.request_id.parse() {
337        response.headers_mut().insert("x-request-id", val);
338    }
339
340    response
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// The stock configuration exposes the expected port, connection cap and
    /// CORS flag.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 10000);
        assert!(cors_enabled);
    }

    /// A health response serializes to JSON that contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let payload = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };

        let encoded = serde_json::to_string(&payload).unwrap();

        assert!(encoded.contains("healthy"));
    }
}