1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 Json, Router,
6 error_handling::HandleErrorLayer,
7 extract::DefaultBodyLimit,
8 http::StatusCode,
9 middleware,
10 response::IntoResponse,
11 routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use tracing::Instrument;
24
25use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
26use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
27use super::multipart::rpc_multipart_handler;
28use super::response::{RpcError, RpcResponse};
29use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
30use super::sse::{
31 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
32 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
33};
34use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
35use crate::db::Database;
36use crate::function::FunctionRegistry;
37use crate::mcp::McpToolRegistry;
38use crate::realtime::{Reactor, ReactorConfig};
39
/// Maximum accepted request body for the JSON RPC routes (1 MiB).
const MAX_JSON_BODY_SIZE: usize = 1024 * 1024;
/// Maximum accepted request body for multipart upload routes (20 MiB).
const MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024;
/// Cap on concurrently processed multipart uploads.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
43
/// Runtime configuration for the HTTP gateway server.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port to bind; the server listens on all interfaces (see `GatewayServer::addr`).
    pub port: u16,
    /// Maximum number of in-flight requests, enforced by a concurrency-limit layer.
    pub max_connections: usize,
    /// Per-request timeout in seconds, enforced by a timeout layer.
    pub request_timeout_secs: u64,
    /// Whether a CORS layer is configured; when false a no-op `CorsLayer` is used.
    pub cors_enabled: bool,
    /// Allowed CORS origins; the literal "*" switches to a fully permissive policy.
    pub cors_origins: Vec<String>,
    /// Settings handed to the authentication middleware.
    pub auth: AuthConfig,
    /// MCP endpoint settings; `enabled` and `path` are read when building the router.
    pub mcp: McpConfig,
    /// Route paths excluded from per-request spans/metrics in the tracing middleware.
    pub quiet_routes: Vec<String>,
}
64
65impl Default for GatewayConfig {
66 fn default() -> Self {
67 Self {
68 port: 8080,
69 max_connections: 512,
70 request_timeout_secs: 30,
71 cors_enabled: false,
72 cors_origins: Vec::new(),
73 auth: AuthConfig::default(),
74 mcp: McpConfig::default(),
75 quiet_routes: Vec::new(),
76 }
77 }
78}
79
/// JSON payload returned by the `/health` liveness probe.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Always "healthy" as emitted by `health_handler`.
    pub status: String,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    pub version: String,
}
86
/// JSON payload returned by the `/ready` readiness probe.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// Overall readiness: true only when both subsystem checks pass.
    pub ready: bool,
    /// True when the database answered a trivial query.
    pub database: bool,
    /// True when the reactor's listener reports itself running.
    pub reactor: bool,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    pub version: String,
}
95
/// Shared state for the `/ready` probe: handles used to check the
/// database and the realtime reactor.
#[derive(Clone)]
pub struct ReadinessState {
    // Pool the readiness check runs `SELECT 1` against (the primary pool).
    db_pool: sqlx::PgPool,
    // Reactor whose listener status feeds the readiness result.
    reactor: Arc<Reactor>,
}
102
/// The gateway HTTP server: owns the function registry, database handles,
/// the realtime reactor, and optional job/workflow dispatchers and MCP
/// tool registry used when assembling the router.
pub struct GatewayServer {
    // Server settings (port, limits, CORS, auth, MCP, quiet routes).
    config: GatewayConfig,
    // Registry of callable functions served over RPC.
    registry: FunctionRegistry,
    // Database handles; exposes primary and read pools.
    db: Database,
    // Realtime reactor driving SSE updates; created in `new`.
    reactor: Arc<Reactor>,
    // Optional dispatcher for background jobs (set via builder method).
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    // Optional dispatcher for workflows (set via builder method).
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    // Optional MCP tool registry; defaulted if MCP is enabled without one.
    mcp_registry: Option<McpToolRegistry>,
}
113
114impl GatewayServer {
115 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
117 let node_id = NodeId::new();
118 let reactor = Arc::new(Reactor::new(
119 node_id,
120 db.read_pool().clone(),
121 registry.clone(),
122 ReactorConfig::default(),
123 ));
124
125 Self {
126 config,
127 registry,
128 db,
129 reactor,
130 job_dispatcher: None,
131 workflow_dispatcher: None,
132 mcp_registry: None,
133 }
134 }
135
136 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
138 self.job_dispatcher = Some(dispatcher);
139 self
140 }
141
142 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
144 self.workflow_dispatcher = Some(dispatcher);
145 self
146 }
147
148 pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
150 self.mcp_registry = Some(registry);
151 self
152 }
153
154 pub fn reactor(&self) -> Arc<Reactor> {
156 self.reactor.clone()
157 }
158
159 pub fn router(&self) -> Router {
161 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
162 self.registry.clone(),
163 self.db.clone(),
164 self.job_dispatcher.clone(),
165 self.workflow_dispatcher.clone(),
166 ));
167
168 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
169
170 let cors = if self.config.cors_enabled {
172 if self.config.cors_origins.iter().any(|o| o == "*") {
173 CorsLayer::new()
174 .allow_origin(Any)
175 .allow_methods(Any)
176 .allow_headers(Any)
177 } else {
178 let origins: Vec<_> = self
179 .config
180 .cors_origins
181 .iter()
182 .filter_map(|o| o.parse().ok())
183 .collect();
184 CorsLayer::new()
185 .allow_origin(origins)
186 .allow_methods(Any)
187 .allow_headers(Any)
188 }
189 } else {
190 CorsLayer::new()
191 };
192
193 let sse_state = Arc::new(SseState::new(
195 self.reactor.clone(),
196 auth_middleware_state.clone(),
197 ));
198
199 let readiness_state = Arc::new(ReadinessState {
201 db_pool: self.db.primary().clone(),
202 reactor: self.reactor.clone(),
203 });
204
205 let mut main_router = Router::new()
207 .route("/health", get(health_handler))
209 .route("/ready", get(readiness_handler).with_state(readiness_state))
211 .route("/rpc", post(rpc_handler))
213 .route("/rpc/{function}", post(rpc_function_handler))
215 .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
217 .with_state(rpc_handler_state.clone());
219
220 let multipart_router = Router::new()
222 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
223 .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
224 .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
226 .with_state(rpc_handler_state);
227
228 let sse_router = Router::new()
230 .route("/events", get(sse_handler))
231 .route("/subscribe", post(sse_subscribe_handler))
232 .route("/unsubscribe", post(sse_unsubscribe_handler))
233 .route("/subscribe-job", post(sse_job_subscribe_handler))
234 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
235 .with_state(sse_state);
236
237 let mut mcp_router = Router::new();
238 if self.config.mcp.enabled {
239 let path = self.config.mcp.path.clone();
240 let mcp_state = Arc::new(McpState::new(
241 self.config.mcp.clone(),
242 self.mcp_registry.clone().unwrap_or_default(),
243 self.db.primary().clone(),
244 self.job_dispatcher.clone(),
245 self.workflow_dispatcher.clone(),
246 ));
247 mcp_router = mcp_router.route(
248 &path,
249 post(mcp_post_handler)
250 .get(mcp_get_handler)
251 .with_state(mcp_state),
252 );
253 }
254
255 main_router = main_router
256 .merge(multipart_router)
257 .merge(sse_router)
258 .merge(mcp_router);
259
260 let service_builder = ServiceBuilder::new()
262 .layer(HandleErrorLayer::new(handle_middleware_error))
263 .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
264 .layer(TimeoutLayer::new(Duration::from_secs(
265 self.config.request_timeout_secs,
266 )))
267 .layer(cors.clone())
268 .layer(middleware::from_fn_with_state(
269 auth_middleware_state,
270 auth_middleware,
271 ))
272 .layer(middleware::from_fn_with_state(
273 Arc::new(self.config.quiet_routes.clone()),
274 tracing_middleware,
275 ));
276
277 main_router.layer(service_builder)
279 }
280
281 pub fn addr(&self) -> std::net::SocketAddr {
283 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
284 }
285
286 pub async fn run(self) -> Result<(), std::io::Error> {
288 let addr = self.addr();
289 let router = self.router();
290
291 self.reactor
293 .start()
294 .await
295 .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
296 tracing::info!("Reactor started for real-time updates");
297
298 tracing::info!("Gateway server listening on {}", addr);
299
300 let listener = tokio::net::TcpListener::bind(addr).await?;
301 axum::serve(listener, router.into_make_service()).await
302 }
303}
304
305async fn health_handler() -> Json<HealthResponse> {
307 Json(HealthResponse {
308 status: "healthy".to_string(),
309 version: env!("CARGO_PKG_VERSION").to_string(),
310 })
311}
312
313async fn readiness_handler(
315 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
316) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
317 let db_ok = sqlx::query("SELECT 1")
319 .fetch_one(&state.db_pool)
320 .await
321 .is_ok();
322
323 let reactor_stats = state.reactor.stats().await;
325 let reactor_ok = reactor_stats.listener_running;
326
327 let ready = db_ok && reactor_ok;
328 let status_code = if ready {
329 axum::http::StatusCode::OK
330 } else {
331 axum::http::StatusCode::SERVICE_UNAVAILABLE
332 };
333
334 (
335 status_code,
336 Json(ReadinessResponse {
337 ready,
338 database: db_ok,
339 reactor: reactor_ok,
340 version: env!("CARGO_PKG_VERSION").to_string(),
341 }),
342 )
343}
344
345async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
346 let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
347 (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
348 } else {
349 (
350 StatusCode::SERVICE_UNAVAILABLE,
351 "SERVICE_UNAVAILABLE",
352 "Server overloaded",
353 )
354 };
355 (
356 status,
357 Json(RpcResponse::error(RpcError::new(code, message))),
358 )
359 .into_response()
360}
361
362fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
363 if let Ok(val) = trace_id.parse() {
364 response.headers_mut().insert(TRACE_ID_HEADER, val);
365 }
366 if let Ok(val) = request_id.parse() {
367 response.headers_mut().insert(REQUEST_ID_HEADER, val);
368 }
369}
370
/// Per-request tracing middleware: propagates or creates trace identifiers,
/// attaches a `TracingState` (and a fallback `AuthContext`) to request
/// extensions, wraps the handler in an info span, records status/duration
/// metrics, and echoes trace headers on the response. Routes listed in
/// `quiet_routes` skip the span and metrics but still get headers.
async fn tracing_middleware(
    axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    let headers = req.headers();

    // Reuse an upstream trace id when the header is present and valid
    // UTF-8; otherwise start a fresh trace.
    let trace_id = headers
        .get(TRACE_ID_HEADER)
        .and_then(|v| v.to_str().ok())
        .map(String::from)
        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

    // Optional parent span id supplied by the caller.
    let parent_span_id = headers
        .get(SPAN_ID_HEADER)
        .and_then(|v| v.to_str().ok())
        .map(String::from);

    let method = req.method().to_string();
    let path = req.uri().path().to_string();

    let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
    if let Some(span_id) = parent_span_id {
        tracing_state = tracing_state.with_parent_span(span_id);
    }

    // Expose the tracing state to downstream handlers via extensions.
    let mut req = req;
    req.extensions_mut().insert(tracing_state.clone());

    // Guarantee an AuthContext always exists downstream. NOTE(review):
    // presumably the auth middleware inserts one on authenticated
    // requests and this is the fallback — confirm against auth.rs.
    if req
        .extensions()
        .get::<forge_core::function::AuthContext>()
        .is_none()
    {
        req.extensions_mut()
            .insert(forge_core::function::AuthContext::unauthenticated());
    }

    // Quiet routes match either the raw path or an "/_api"-prefixed form
    // — presumably the externally visible mount point; TODO confirm.
    let full_path = format!("/_api{}", path);
    let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);

    if is_quiet {
        // Skip span creation and metrics, but still propagate trace headers.
        let mut response = next.run(req).await;
        set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
        return response;
    }

    let span = tracing::info_span!(
        "http.request",
        http.method = %method,
        http.route = %path,
        // Recorded after the handler runs, once the status is known.
        http.status_code = tracing::field::Empty,
        trace_id = %trace_id,
        request_id = %tracing_state.request_id,
    );

    let mut response = next.run(req).instrument(span.clone()).await;

    let status = response.status().as_u16();
    let elapsed = tracing_state.elapsed();

    // Backfill the status, emit the completion event inside the span, and
    // feed the HTTP metrics recorder.
    span.record("http.status_code", status);
    tracing::info!(parent: &span, duration_ms = elapsed.as_millis() as u64, "Request completed");
    crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());

    set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
    response
}
444
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Defaults must match the documented out-of-the-box settings.
    #[test]
    fn test_gateway_config_default() {
        let cfg = GatewayConfig::default();
        assert_eq!(8080, cfg.port);
        assert_eq!(512, cfg.max_connections);
        assert!(!cfg.cors_enabled);
    }

    /// The health payload must serialize with its status string visible.
    #[test]
    fn test_health_response_serialization() {
        let payload = HealthResponse {
            status: String::from("healthy"),
            version: String::from("0.1.0"),
        };
        let encoded = serde_json::to_string(&payload).unwrap();
        assert!(encoded.contains("healthy"));
    }
}