1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 Json, Router,
6 error_handling::HandleErrorLayer,
7 extract::DefaultBodyLimit,
8 http::StatusCode,
9 middleware,
10 response::IntoResponse,
11 routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23
24use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
25use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
26use super::multipart::rpc_multipart_handler;
27use super::response::{RpcError, RpcResponse};
28use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
29use super::sse::{
30 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
31 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
32};
33use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
34use crate::db::Database;
35use crate::function::FunctionRegistry;
36use crate::mcp::McpToolRegistry;
37use crate::realtime::{Reactor, ReactorConfig};
38
/// Maximum accepted body size for the health, readiness and JSON RPC routes (1 MiB).
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Maximum accepted body size for multipart upload requests (20 MiB).
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// Concurrency limit applied to the multipart upload route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
42
43#[derive(Debug, Clone)]
45pub struct GatewayConfig {
46 pub port: u16,
48 pub max_connections: usize,
50 pub request_timeout_secs: u64,
52 pub cors_enabled: bool,
54 pub cors_origins: Vec<String>,
56 pub auth: AuthConfig,
58 pub mcp: McpConfig,
60}
61
62impl Default for GatewayConfig {
63 fn default() -> Self {
64 Self {
65 port: 8080,
66 max_connections: 512,
67 request_timeout_secs: 30,
68 cors_enabled: false,
69 cors_origins: Vec::new(),
70 auth: AuthConfig::default(),
71 mcp: McpConfig::default(),
72 }
73 }
74}
75
/// JSON body returned by the `GET /health` liveness endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Status string; the health handler always sets `"healthy"`.
    pub status: String,
    /// Crate version (`CARGO_PKG_VERSION`) the server was built from.
    pub version: String,
}

/// JSON body returned by the `GET /ready` readiness endpoint.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// `true` only when every individual check below passed.
    pub ready: bool,
    /// Outcome of the database connectivity check.
    pub database: bool,
    /// Whether the reactor reported its listener as running.
    pub reactor: bool,
    /// Crate version (`CARGO_PKG_VERSION`) the server was built from.
    pub version: String,
}

/// State shared with the readiness handler.
#[derive(Clone)]
pub struct ReadinessState {
    // Pool queried by the readiness check (the router wires in the primary pool).
    db_pool: sqlx::PgPool,
    // Reactor whose listener status is reported.
    reactor: Arc<Reactor>,
}
98
/// HTTP gateway serving RPC, SSE, health/readiness and (optionally) MCP endpoints.
pub struct GatewayServer {
    // Listener, limits, CORS, auth and MCP settings.
    config: GatewayConfig,
    // Function registry, shared with the RPC handler and the reactor.
    registry: FunctionRegistry,
    // Database handle: the read pool feeds the reactor, the primary pool is
    // used for readiness checks and MCP.
    db: Database,
    // Real-time reactor handed to the SSE state.
    reactor: Arc<Reactor>,
    // Optional job dispatcher passed to the RPC and MCP handlers.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    // Optional workflow dispatcher passed to the RPC and MCP handlers.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    // MCP tool registry; `McpToolRegistry::default()` is used when unset.
    mcp_registry: Option<McpToolRegistry>,
}
109
110impl GatewayServer {
111 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
113 let node_id = NodeId::new();
114 let reactor = Arc::new(Reactor::new(
115 node_id,
116 db.read_pool().clone(),
117 registry.clone(),
118 ReactorConfig::default(),
119 ));
120
121 Self {
122 config,
123 registry,
124 db,
125 reactor,
126 job_dispatcher: None,
127 workflow_dispatcher: None,
128 mcp_registry: None,
129 }
130 }
131
132 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
134 self.job_dispatcher = Some(dispatcher);
135 self
136 }
137
138 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
140 self.workflow_dispatcher = Some(dispatcher);
141 self
142 }
143
144 pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
146 self.mcp_registry = Some(registry);
147 self
148 }
149
150 pub fn reactor(&self) -> Arc<Reactor> {
152 self.reactor.clone()
153 }
154
155 pub fn router(&self) -> Router {
157 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
158 self.registry.clone(),
159 self.db.clone(),
160 self.job_dispatcher.clone(),
161 self.workflow_dispatcher.clone(),
162 ));
163
164 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
165
166 let cors = if self.config.cors_enabled {
168 if self.config.cors_origins.iter().any(|o| o == "*") {
169 CorsLayer::new()
170 .allow_origin(Any)
171 .allow_methods(Any)
172 .allow_headers(Any)
173 } else {
174 let origins: Vec<_> = self
175 .config
176 .cors_origins
177 .iter()
178 .filter_map(|o| o.parse().ok())
179 .collect();
180 CorsLayer::new()
181 .allow_origin(origins)
182 .allow_methods(Any)
183 .allow_headers(Any)
184 }
185 } else {
186 CorsLayer::new()
187 };
188
189 let sse_state = Arc::new(SseState::new(
191 self.reactor.clone(),
192 auth_middleware_state.clone(),
193 ));
194
195 let readiness_state = Arc::new(ReadinessState {
197 db_pool: self.db.primary().clone(),
198 reactor: self.reactor.clone(),
199 });
200
201 let mut main_router = Router::new()
203 .route("/health", get(health_handler))
205 .route("/ready", get(readiness_handler).with_state(readiness_state))
207 .route("/rpc", post(rpc_handler))
209 .route("/rpc/{function}", post(rpc_function_handler))
211 .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
213 .with_state(rpc_handler_state.clone());
215
216 let multipart_router = Router::new()
218 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
219 .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
220 .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
222 .with_state(rpc_handler_state);
223
224 let sse_router = Router::new()
226 .route("/events", get(sse_handler))
227 .route("/subscribe", post(sse_subscribe_handler))
228 .route("/unsubscribe", post(sse_unsubscribe_handler))
229 .route("/subscribe-job", post(sse_job_subscribe_handler))
230 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
231 .with_state(sse_state);
232
233 let mut mcp_router = Router::new();
234 if self.config.mcp.enabled {
235 let path = self.config.mcp.path.clone();
236 let mcp_state = Arc::new(McpState::new(
237 self.config.mcp.clone(),
238 self.mcp_registry.clone().unwrap_or_default(),
239 self.db.primary().clone(),
240 self.job_dispatcher.clone(),
241 self.workflow_dispatcher.clone(),
242 ));
243 mcp_router = mcp_router.route(
244 &path,
245 post(mcp_post_handler)
246 .get(mcp_get_handler)
247 .with_state(mcp_state),
248 );
249 }
250
251 main_router = main_router
252 .merge(multipart_router)
253 .merge(sse_router)
254 .merge(mcp_router);
255
256 let service_builder = ServiceBuilder::new()
258 .layer(HandleErrorLayer::new(handle_middleware_error))
259 .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
260 .layer(TimeoutLayer::new(Duration::from_secs(
261 self.config.request_timeout_secs,
262 )))
263 .layer(cors.clone())
264 .layer(middleware::from_fn_with_state(
265 auth_middleware_state,
266 auth_middleware,
267 ))
268 .layer(middleware::from_fn(tracing_middleware));
269
270 main_router.layer(service_builder)
272 }
273
274 pub fn addr(&self) -> std::net::SocketAddr {
276 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
277 }
278
279 pub async fn run(self) -> Result<(), std::io::Error> {
281 let addr = self.addr();
282 let router = self.router();
283
284 self.reactor
286 .start()
287 .await
288 .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
289 tracing::info!("Reactor started for real-time updates");
290
291 tracing::info!("Gateway server listening on {}", addr);
292
293 let listener = tokio::net::TcpListener::bind(addr).await?;
294 axum::serve(listener, router.into_make_service()).await
295 }
296}
297
298async fn health_handler() -> Json<HealthResponse> {
300 Json(HealthResponse {
301 status: "healthy".to_string(),
302 version: env!("CARGO_PKG_VERSION").to_string(),
303 })
304}
305
306async fn readiness_handler(
308 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
309) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
310 let db_ok = sqlx::query("SELECT 1")
312 .fetch_one(&state.db_pool)
313 .await
314 .is_ok();
315
316 let reactor_stats = state.reactor.stats().await;
318 let reactor_ok = reactor_stats.listener_running;
319
320 let ready = db_ok && reactor_ok;
321 let status_code = if ready {
322 axum::http::StatusCode::OK
323 } else {
324 axum::http::StatusCode::SERVICE_UNAVAILABLE
325 };
326
327 (
328 status_code,
329 Json(ReadinessResponse {
330 ready,
331 database: db_ok,
332 reactor: reactor_ok,
333 version: env!("CARGO_PKG_VERSION").to_string(),
334 }),
335 )
336}
337
338async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
339 let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
340 (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
341 } else {
342 (
343 StatusCode::SERVICE_UNAVAILABLE,
344 "SERVICE_UNAVAILABLE",
345 "Server overloaded",
346 )
347 };
348 (
349 status,
350 Json(RpcResponse::error(RpcError::new(code, message))),
351 )
352 .into_response()
353}
354
355async fn tracing_middleware(
357 req: axum::extract::Request,
358 next: axum::middleware::Next,
359) -> axum::response::Response {
360 let headers = req.headers();
361
362 let trace_id = headers
363 .get(TRACE_ID_HEADER)
364 .and_then(|v| v.to_str().ok())
365 .map(String::from)
366 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
367
368 let parent_span_id = headers
369 .get(SPAN_ID_HEADER)
370 .and_then(|v| v.to_str().ok())
371 .map(String::from);
372
373 let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
374 if let Some(span_id) = parent_span_id {
375 tracing_state = tracing_state.with_parent_span(span_id);
376 }
377
378 let mut req = req;
379 req.extensions_mut().insert(tracing_state.clone());
380
381 if req
382 .extensions()
383 .get::<forge_core::function::AuthContext>()
384 .is_none()
385 {
386 req.extensions_mut()
387 .insert(forge_core::function::AuthContext::unauthenticated());
388 }
389
390 let mut response = next.run(req).await;
391
392 let elapsed = tracing_state.elapsed();
393 tracing::debug!(
394 trace_id = %trace_id,
395 request_id = %tracing_state.request_id,
396 status = %response.status().as_u16(),
397 duration_ms = %elapsed.as_millis(),
398 "Request completed"
399 );
400
401 if let Ok(val) = trace_id.parse() {
402 response.headers_mut().insert(TRACE_ID_HEADER, val);
403 }
404 if let Ok(val) = tracing_state.request_id.parse() {
405 response.headers_mut().insert(REQUEST_ID_HEADER, val);
406 }
407
408 response
409}
410
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Pins every scalar default of `GatewayConfig`.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 512);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(!config.cors_enabled);
        assert!(config.cors_origins.is_empty());
    }

    /// The health body must serialize to exactly these keys and values; a
    /// substring check would miss renamed or dropped fields.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "status": "healthy", "version": "0.1.0" })
        );
    }

    /// The readiness body exposes each check result under a stable key.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: true,
            reactor: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "ready": false,
                "database": true,
                "reactor": false,
                "version": "0.1.0"
            })
        );
    }
}