Skip to main content

forge_runtime/gateway/
server.rs

1use std::sync::Arc;
2
3use axum::{
4    Json, Router, middleware,
5    routing::{get, post},
6};
7use serde::Serialize;
8use tower::ServiceBuilder;
9use tower_http::cors::{Any, CorsLayer};
10
11use forge_core::cluster::NodeId;
12use forge_core::function::{JobDispatch, WorkflowDispatch};
13
14use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
15use super::multipart::rpc_multipart_handler;
16use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
17use super::sse::{
18    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
19    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
20};
21use super::tracing::TracingState;
22use crate::db::Database;
23use crate::function::FunctionRegistry;
24use crate::realtime::{Reactor, ReactorConfig};
25
26/// Gateway server configuration.
27#[derive(Debug, Clone)]
28pub struct GatewayConfig {
29    /// Port to listen on.
30    pub port: u16,
31    /// Maximum number of connections.
32    pub max_connections: usize,
33    /// Request timeout in seconds.
34    pub request_timeout_secs: u64,
35    /// Enable CORS.
36    pub cors_enabled: bool,
37    /// Allowed CORS origins.
38    pub cors_origins: Vec<String>,
39    /// Authentication configuration.
40    pub auth: AuthConfig,
41}
42
43impl Default for GatewayConfig {
44    fn default() -> Self {
45        Self {
46            port: 8080,
47            max_connections: 10000,
48            request_timeout_secs: 30,
49            cors_enabled: true,
50            cors_origins: vec!["*".to_string()],
51            auth: AuthConfig::default(),
52        }
53    }
54}
55
56/// Health check response.
57#[derive(Debug, Serialize)]
58pub struct HealthResponse {
59    pub status: String,
60    pub version: String,
61}
62
63/// Readiness check response.
64#[derive(Debug, Serialize)]
65pub struct ReadinessResponse {
66    pub ready: bool,
67    pub database: bool,
68    pub version: String,
69}
70
71/// State for readiness check.
72#[derive(Clone)]
73pub struct ReadinessState {
74    db_pool: sqlx::PgPool,
75}
76
77/// Gateway HTTP server.
78pub struct GatewayServer {
79    config: GatewayConfig,
80    registry: FunctionRegistry,
81    db: Database,
82    reactor: Arc<Reactor>,
83    job_dispatcher: Option<Arc<dyn JobDispatch>>,
84    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
85}
86
87impl GatewayServer {
88    /// Create a new gateway server.
89    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
90        let node_id = NodeId::new();
91        let reactor = Arc::new(Reactor::new(
92            node_id,
93            db.read_pool().clone(),
94            registry.clone(),
95            ReactorConfig::default(),
96        ));
97
98        Self {
99            config,
100            registry,
101            db,
102            reactor,
103            job_dispatcher: None,
104            workflow_dispatcher: None,
105        }
106    }
107
108    /// Set the job dispatcher.
109    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
110        self.job_dispatcher = Some(dispatcher);
111        self
112    }
113
114    /// Set the workflow dispatcher.
115    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
116        self.workflow_dispatcher = Some(dispatcher);
117        self
118    }
119
120    /// Get a reference to the reactor.
121    pub fn reactor(&self) -> Arc<Reactor> {
122        self.reactor.clone()
123    }
124
125    /// Build the Axum router.
126    pub fn router(&self) -> Router {
127        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
128            self.registry.clone(),
129            self.db.clone(),
130            self.job_dispatcher.clone(),
131            self.workflow_dispatcher.clone(),
132        ));
133
134        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
135
136        // Build CORS layer
137        let cors = if self.config.cors_enabled {
138            if self.config.cors_origins.contains(&"*".to_string()) {
139                CorsLayer::new()
140                    .allow_origin(Any)
141                    .allow_methods(Any)
142                    .allow_headers(Any)
143            } else {
144                let origins: Vec<_> = self
145                    .config
146                    .cors_origins
147                    .iter()
148                    .filter_map(|o| o.parse().ok())
149                    .collect();
150                CorsLayer::new()
151                    .allow_origin(origins)
152                    .allow_methods(Any)
153                    .allow_headers(Any)
154            }
155        } else {
156            CorsLayer::new()
157        };
158
159        // SSE state for Server-Sent Events
160        let sse_state = Arc::new(SseState::new(
161            self.reactor.clone(),
162            auth_middleware_state.clone(),
163        ));
164
165        // Readiness state for DB health check
166        let readiness_state = Arc::new(ReadinessState {
167            db_pool: self.db.primary().clone(),
168        });
169
170        // Build the main router with middleware
171        let mut main_router = Router::new()
172            // Health check endpoint (liveness)
173            .route("/health", get(health_handler))
174            // Readiness check endpoint (checks DB)
175            .route("/ready", get(readiness_handler).with_state(readiness_state))
176            // RPC endpoint
177            .route("/rpc", post(rpc_handler))
178            // REST-style function endpoint (JSON)
179            .route("/rpc/{function}", post(rpc_function_handler))
180            // Add state
181            .with_state(rpc_handler_state.clone());
182
183        // Multipart RPC router (separate state needed for multipart)
184        let multipart_router = Router::new()
185            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
186            .with_state(rpc_handler_state);
187
188        // SSE router
189        let sse_router = Router::new()
190            .route("/events", get(sse_handler))
191            .route("/subscribe", post(sse_subscribe_handler))
192            .route("/unsubscribe", post(sse_unsubscribe_handler))
193            .route("/subscribe-job", post(sse_job_subscribe_handler))
194            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
195            .with_state(sse_state);
196
197        main_router = main_router.merge(multipart_router).merge(sse_router);
198
199        // Build middleware stack
200        let service_builder = ServiceBuilder::new()
201            .layer(cors.clone())
202            .layer(middleware::from_fn_with_state(
203                auth_middleware_state,
204                auth_middleware,
205            ))
206            .layer(middleware::from_fn(tracing_middleware));
207
208        // Apply the remaining middleware layers
209        main_router.layer(service_builder)
210    }
211
212    /// Get the socket address to bind to.
213    pub fn addr(&self) -> std::net::SocketAddr {
214        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
215    }
216
217    /// Run the server (blocking).
218    pub async fn run(self) -> Result<(), std::io::Error> {
219        let addr = self.addr();
220        let router = self.router();
221
222        // Start the reactor for real-time updates
223        if let Err(e) = self.reactor.start().await {
224            tracing::error!("Failed to start reactor: {}", e);
225        } else {
226            tracing::info!("Reactor started for real-time updates");
227        }
228
229        tracing::info!("Gateway server listening on {}", addr);
230
231        let listener = tokio::net::TcpListener::bind(addr).await?;
232        axum::serve(listener, router.into_make_service()).await
233    }
234}
235
236/// Health check handler (liveness probe).
237async fn health_handler() -> Json<HealthResponse> {
238    Json(HealthResponse {
239        status: "healthy".to_string(),
240        version: env!("CARGO_PKG_VERSION").to_string(),
241    })
242}
243
244/// Readiness check handler (readiness probe).
245async fn readiness_handler(
246    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
247) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
248    // Check database connectivity
249    let db_ok = sqlx::query("SELECT 1")
250        .fetch_one(&state.db_pool)
251        .await
252        .is_ok();
253
254    let ready = db_ok;
255    let status_code = if ready {
256        axum::http::StatusCode::OK
257    } else {
258        axum::http::StatusCode::SERVICE_UNAVAILABLE
259    };
260
261    (
262        status_code,
263        Json(ReadinessResponse {
264            ready,
265            database: db_ok,
266            version: env!("CARGO_PKG_VERSION").to_string(),
267        }),
268    )
269}
270
271/// Simple tracing middleware that adds TracingState to extensions.
272async fn tracing_middleware(
273    req: axum::extract::Request,
274    next: axum::middleware::Next,
275) -> axum::response::Response {
276    use axum::http::header::HeaderName;
277
278    // Extract or generate trace ID
279    let trace_id = req
280        .headers()
281        .get(HeaderName::from_static("x-trace-id"))
282        .and_then(|v| v.to_str().ok())
283        .map(String::from)
284        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
285
286    let tracing_state = TracingState::with_trace_id(trace_id.clone());
287
288    let mut req = req;
289    req.extensions_mut().insert(tracing_state.clone());
290
291    // Also insert AuthContext default if not present
292    if req
293        .extensions()
294        .get::<forge_core::function::AuthContext>()
295        .is_none()
296    {
297        req.extensions_mut()
298            .insert(forge_core::function::AuthContext::unauthenticated());
299    }
300
301    let mut response = next.run(req).await;
302
303    // Add trace ID to response headers
304    if let Ok(val) = trace_id.parse() {
305        response.headers_mut().insert("x-trace-id", val);
306    }
307    if let Ok(val) = tracing_state.request_id.parse() {
308        response.headers_mut().insert("x-request-id", val);
309    }
310
311    response
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the expected port, connection cap
    /// and CORS flag.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 10000);
        assert!(cors_enabled);
    }

    /// A serialized health response contains its status string.
    #[test]
    fn test_health_response_serialization() {
        let encoded = serde_json::to_string(&HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        })
        .unwrap();

        assert!(encoded.contains("healthy"));
    }
}
335}