// forge_runtime/gateway/server.rs

1use std::sync::Arc;
2
3use axum::{
4    Json, Router, middleware,
5    routing::{get, post},
6};
7use serde::Serialize;
8use tower::ServiceBuilder;
9use tower_http::cors::{Any, CorsLayer};
10
11use forge_core::cluster::NodeId;
12use forge_core::function::{JobDispatch, WorkflowDispatch};
13
14use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
15use super::metrics::{MetricsState, metrics_middleware};
16use super::multipart::rpc_multipart_handler;
17use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
18use super::sse::{
19    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
20    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
21};
22use super::tracing::TracingState;
23use crate::function::FunctionRegistry;
24use crate::observability::ObservabilityState;
25use crate::realtime::{Reactor, ReactorConfig};
26
27/// Gateway server configuration.
28#[derive(Debug, Clone)]
29pub struct GatewayConfig {
30    /// Port to listen on.
31    pub port: u16,
32    /// Maximum number of connections.
33    pub max_connections: usize,
34    /// Request timeout in seconds.
35    pub request_timeout_secs: u64,
36    /// Enable CORS.
37    pub cors_enabled: bool,
38    /// Allowed CORS origins.
39    pub cors_origins: Vec<String>,
40    /// Authentication configuration.
41    pub auth: AuthConfig,
42}
43
44impl Default for GatewayConfig {
45    fn default() -> Self {
46        Self {
47            port: 8080,
48            max_connections: 10000,
49            request_timeout_secs: 30,
50            cors_enabled: true,
51            cors_origins: vec!["*".to_string()],
52            auth: AuthConfig::default(),
53        }
54    }
55}
56
57/// Health check response.
58#[derive(Debug, Serialize)]
59pub struct HealthResponse {
60    pub status: String,
61    pub version: String,
62}
63
64/// Readiness check response.
65#[derive(Debug, Serialize)]
66pub struct ReadinessResponse {
67    pub ready: bool,
68    pub database: bool,
69    pub version: String,
70}
71
72/// State for readiness check.
73#[derive(Clone)]
74pub struct ReadinessState {
75    db_pool: sqlx::PgPool,
76}
77
78/// Gateway HTTP server.
79pub struct GatewayServer {
80    config: GatewayConfig,
81    registry: FunctionRegistry,
82    db_pool: sqlx::PgPool,
83    reactor: Arc<Reactor>,
84    observability: Option<ObservabilityState>,
85    job_dispatcher: Option<Arc<dyn JobDispatch>>,
86    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
87}
88
89impl GatewayServer {
90    /// Create a new gateway server.
91    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
92        let node_id = NodeId::new();
93        let reactor = Arc::new(Reactor::new(
94            node_id,
95            db_pool.clone(),
96            registry.clone(),
97            ReactorConfig::default(),
98        ));
99
100        Self {
101            config,
102            registry,
103            db_pool,
104            reactor,
105            observability: None,
106            job_dispatcher: None,
107            workflow_dispatcher: None,
108        }
109    }
110
111    /// Create a new gateway server with observability.
112    pub fn with_observability(
113        config: GatewayConfig,
114        registry: FunctionRegistry,
115        db_pool: sqlx::PgPool,
116        observability: ObservabilityState,
117    ) -> Self {
118        let node_id = NodeId::new();
119        let reactor = Arc::new(Reactor::new(
120            node_id,
121            db_pool.clone(),
122            registry.clone(),
123            ReactorConfig::default(),
124        ));
125
126        Self {
127            config,
128            registry,
129            db_pool,
130            reactor,
131            observability: Some(observability),
132            job_dispatcher: None,
133            workflow_dispatcher: None,
134        }
135    }
136
137    /// Set the job dispatcher.
138    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
139        self.job_dispatcher = Some(dispatcher);
140        self
141    }
142
143    /// Set the workflow dispatcher.
144    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
145        self.workflow_dispatcher = Some(dispatcher);
146        self
147    }
148
149    /// Get a reference to the reactor.
150    pub fn reactor(&self) -> Arc<Reactor> {
151        self.reactor.clone()
152    }
153
154    /// Build the Axum router.
155    pub fn router(&self) -> Router {
156        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
157            self.registry.clone(),
158            self.db_pool.clone(),
159            self.job_dispatcher.clone(),
160            self.workflow_dispatcher.clone(),
161        ));
162
163        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
164
165        // Build CORS layer
166        let cors = if self.config.cors_enabled {
167            if self.config.cors_origins.contains(&"*".to_string()) {
168                CorsLayer::new()
169                    .allow_origin(Any)
170                    .allow_methods(Any)
171                    .allow_headers(Any)
172            } else {
173                let origins: Vec<_> = self
174                    .config
175                    .cors_origins
176                    .iter()
177                    .filter_map(|o| o.parse().ok())
178                    .collect();
179                CorsLayer::new()
180                    .allow_origin(origins)
181                    .allow_methods(Any)
182                    .allow_headers(Any)
183            }
184        } else {
185            CorsLayer::new()
186        };
187
188        // SSE state for Server-Sent Events
189        let sse_state = Arc::new(SseState::new(
190            self.reactor.clone(),
191            auth_middleware_state.clone(),
192        ));
193
194        // Readiness state for DB health check
195        let readiness_state = Arc::new(ReadinessState {
196            db_pool: self.db_pool.clone(),
197        });
198
199        // Build the main router with middleware
200        let mut main_router = Router::new()
201            // Health check endpoint (liveness)
202            .route("/health", get(health_handler))
203            // Readiness check endpoint (checks DB)
204            .route("/ready", get(readiness_handler).with_state(readiness_state))
205            // RPC endpoint
206            .route("/rpc", post(rpc_handler))
207            // REST-style function endpoint (JSON)
208            .route("/rpc/{function}", post(rpc_function_handler))
209            // Add state
210            .with_state(rpc_handler_state.clone());
211
212        // Multipart RPC router (separate state needed for multipart)
213        let multipart_router = Router::new()
214            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
215            .with_state(rpc_handler_state);
216
217        // SSE router
218        let sse_router = Router::new()
219            .route("/events", get(sse_handler))
220            .route("/subscribe", post(sse_subscribe_handler))
221            .route("/unsubscribe", post(sse_unsubscribe_handler))
222            .route("/subscribe-job", post(sse_job_subscribe_handler))
223            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
224            .with_state(sse_state);
225
226        main_router = main_router.merge(multipart_router).merge(sse_router);
227
228        // Build middleware stack
229        let service_builder = ServiceBuilder::new()
230            .layer(cors.clone())
231            .layer(middleware::from_fn_with_state(
232                auth_middleware_state,
233                auth_middleware,
234            ))
235            .layer(middleware::from_fn(tracing_middleware));
236
237        // Add metrics middleware if observability is enabled
238        if let Some(ref observability) = self.observability {
239            let metrics_state = Arc::new(MetricsState::new(observability.clone()));
240            main_router = main_router.layer(middleware::from_fn_with_state(
241                metrics_state,
242                metrics_middleware,
243            ));
244        }
245
246        // Apply the remaining middleware layers
247        main_router.layer(service_builder)
248    }
249
250    /// Get the socket address to bind to.
251    pub fn addr(&self) -> std::net::SocketAddr {
252        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
253    }
254
255    /// Run the server (blocking).
256    pub async fn run(self) -> Result<(), std::io::Error> {
257        let addr = self.addr();
258        let router = self.router();
259
260        // Start the reactor for real-time updates
261        if let Err(e) = self.reactor.start().await {
262            tracing::error!("Failed to start reactor: {}", e);
263        } else {
264            tracing::info!("Reactor started for real-time updates");
265        }
266
267        tracing::info!("Gateway server listening on {}", addr);
268
269        let listener = tokio::net::TcpListener::bind(addr).await?;
270        axum::serve(listener, router.into_make_service()).await
271    }
272}
273
274/// Health check handler (liveness probe).
275async fn health_handler() -> Json<HealthResponse> {
276    Json(HealthResponse {
277        status: "healthy".to_string(),
278        version: env!("CARGO_PKG_VERSION").to_string(),
279    })
280}
281
282/// Readiness check handler (readiness probe).
283async fn readiness_handler(
284    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
285) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
286    // Check database connectivity
287    let db_ok = sqlx::query("SELECT 1")
288        .fetch_one(&state.db_pool)
289        .await
290        .is_ok();
291
292    let ready = db_ok;
293    let status_code = if ready {
294        axum::http::StatusCode::OK
295    } else {
296        axum::http::StatusCode::SERVICE_UNAVAILABLE
297    };
298
299    (
300        status_code,
301        Json(ReadinessResponse {
302            ready,
303            database: db_ok,
304            version: env!("CARGO_PKG_VERSION").to_string(),
305        }),
306    )
307}
308
309/// Simple tracing middleware that adds TracingState to extensions.
310async fn tracing_middleware(
311    req: axum::extract::Request,
312    next: axum::middleware::Next,
313) -> axum::response::Response {
314    use axum::http::header::HeaderName;
315
316    // Extract or generate trace ID
317    let trace_id = req
318        .headers()
319        .get(HeaderName::from_static("x-trace-id"))
320        .and_then(|v| v.to_str().ok())
321        .map(String::from)
322        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
323
324    let tracing_state = TracingState::with_trace_id(trace_id.clone());
325
326    let mut req = req;
327    req.extensions_mut().insert(tracing_state.clone());
328
329    // Also insert AuthContext default if not present
330    if req
331        .extensions()
332        .get::<forge_core::function::AuthContext>()
333        .is_none()
334    {
335        req.extensions_mut()
336            .insert(forge_core::function::AuthContext::unauthenticated());
337    }
338
339    let mut response = next.run(req).await;
340
341    // Add trace ID to response headers
342    if let Ok(val) = trace_id.parse() {
343        response.headers_mut().insert("x-trace-id", val);
344    }
345    if let Ok(val) = tracing_state.request_id.parse() {
346        response.headers_mut().insert("x-request-id", val);
347    }
348
349    response
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 10_000);
        assert!(cors_enabled);
    }

    #[test]
    fn test_health_response_serialization() {
        let response = HealthResponse {
            status: String::from("healthy"),
            version: String::from("0.1.0"),
        };

        let serialized = serde_json::to_string(&response).unwrap();
        assert!(serialized.contains("healthy"));
    }
}