1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 Json, Router,
6 error_handling::HandleErrorLayer,
7 extract::DefaultBodyLimit,
8 http::StatusCode,
9 middleware,
10 response::IntoResponse,
11 routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
33use super::sse::{
34 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Body-size limit (1 MiB) for the routes of the main JSON router.
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Body-size limit (20 MiB) for the multipart upload route.
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// Maximum number of in-flight requests on the multipart upload route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
46
47#[derive(Debug, Clone)]
49pub struct GatewayConfig {
50 pub port: u16,
52 pub max_connections: usize,
54 pub request_timeout_secs: u64,
56 pub cors_enabled: bool,
58 pub cors_origins: Vec<String>,
60 pub auth: AuthConfig,
62 pub mcp: McpConfig,
64 pub quiet_routes: Vec<String>,
66}
67
68impl Default for GatewayConfig {
69 fn default() -> Self {
70 Self {
71 port: 8080,
72 max_connections: 512,
73 request_timeout_secs: 30,
74 cors_enabled: false,
75 cors_origins: Vec::new(),
76 auth: AuthConfig::default(),
77 mcp: McpConfig::default(),
78 quiet_routes: Vec::new(),
79 }
80 }
81}
82
83#[derive(Debug, Serialize)]
85pub struct HealthResponse {
86 pub status: String,
87 pub version: String,
88}
89
90#[derive(Debug, Serialize)]
92pub struct ReadinessResponse {
93 pub ready: bool,
94 pub database: bool,
95 pub reactor: bool,
96 pub version: String,
97}
98
99#[derive(Clone)]
101pub struct ReadinessState {
102 db_pool: sqlx::PgPool,
103 reactor: Arc<Reactor>,
104}
105
106pub struct GatewayServer {
108 config: GatewayConfig,
109 registry: FunctionRegistry,
110 db: Database,
111 reactor: Arc<Reactor>,
112 job_dispatcher: Option<Arc<dyn JobDispatch>>,
113 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
114 mcp_registry: Option<McpToolRegistry>,
115}
116
117impl GatewayServer {
118 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
120 let node_id = NodeId::new();
121 let reactor = Arc::new(Reactor::new(
122 node_id,
123 db.read_pool().clone(),
124 registry.clone(),
125 ReactorConfig::default(),
126 ));
127
128 Self {
129 config,
130 registry,
131 db,
132 reactor,
133 job_dispatcher: None,
134 workflow_dispatcher: None,
135 mcp_registry: None,
136 }
137 }
138
139 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
141 self.job_dispatcher = Some(dispatcher);
142 self
143 }
144
145 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
147 self.workflow_dispatcher = Some(dispatcher);
148 self
149 }
150
151 pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
153 self.mcp_registry = Some(registry);
154 self
155 }
156
157 pub fn reactor(&self) -> Arc<Reactor> {
159 self.reactor.clone()
160 }
161
162 pub fn router(&self) -> Router {
164 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
165 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
166
167 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch_and_issuer(
168 self.registry.clone(),
169 self.db.clone(),
170 self.job_dispatcher.clone(),
171 self.workflow_dispatcher.clone(),
172 token_issuer,
173 ));
174
175 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
176
177 let cors = if self.config.cors_enabled {
179 if self.config.cors_origins.iter().any(|o| o == "*") {
180 CorsLayer::new()
181 .allow_origin(Any)
182 .allow_methods(Any)
183 .allow_headers(Any)
184 } else {
185 let origins: Vec<_> = self
186 .config
187 .cors_origins
188 .iter()
189 .filter_map(|o| o.parse().ok())
190 .collect();
191 CorsLayer::new()
192 .allow_origin(origins)
193 .allow_methods(Any)
194 .allow_headers(Any)
195 }
196 } else {
197 CorsLayer::new()
198 };
199
200 let sse_state = Arc::new(SseState::new(
202 self.reactor.clone(),
203 auth_middleware_state.clone(),
204 ));
205
206 let readiness_state = Arc::new(ReadinessState {
208 db_pool: self.db.primary().clone(),
209 reactor: self.reactor.clone(),
210 });
211
212 let mut main_router = Router::new()
214 .route("/health", get(health_handler))
216 .route("/ready", get(readiness_handler).with_state(readiness_state))
218 .route("/rpc", post(rpc_handler))
220 .route("/rpc/{function}", post(rpc_function_handler))
222 .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
224 .with_state(rpc_handler_state.clone());
226
227 let multipart_router = Router::new()
229 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
230 .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
231 .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
233 .with_state(rpc_handler_state);
234
235 let sse_router = Router::new()
237 .route("/events", get(sse_handler))
238 .route("/subscribe", post(sse_subscribe_handler))
239 .route("/unsubscribe", post(sse_unsubscribe_handler))
240 .route("/subscribe-job", post(sse_job_subscribe_handler))
241 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
242 .with_state(sse_state);
243
244 let mut mcp_router = Router::new();
245 if self.config.mcp.enabled {
246 let path = self.config.mcp.path.clone();
247 let mcp_state = Arc::new(McpState::new(
248 self.config.mcp.clone(),
249 self.mcp_registry.clone().unwrap_or_default(),
250 self.db.primary().clone(),
251 self.job_dispatcher.clone(),
252 self.workflow_dispatcher.clone(),
253 ));
254 mcp_router = mcp_router.route(
255 &path,
256 post(mcp_post_handler)
257 .get(mcp_get_handler)
258 .with_state(mcp_state),
259 );
260 }
261
262 main_router = main_router
263 .merge(multipart_router)
264 .merge(sse_router)
265 .merge(mcp_router);
266
267 let service_builder = ServiceBuilder::new()
269 .layer(HandleErrorLayer::new(handle_middleware_error))
270 .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
271 .layer(TimeoutLayer::new(Duration::from_secs(
272 self.config.request_timeout_secs,
273 )))
274 .layer(cors.clone())
275 .layer(middleware::from_fn_with_state(
276 auth_middleware_state,
277 auth_middleware,
278 ))
279 .layer(middleware::from_fn_with_state(
280 Arc::new(self.config.quiet_routes.clone()),
281 tracing_middleware,
282 ));
283
284 main_router.layer(service_builder)
286 }
287
288 pub fn addr(&self) -> std::net::SocketAddr {
290 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
291 }
292
293 pub async fn run(self) -> Result<(), std::io::Error> {
295 let addr = self.addr();
296 let router = self.router();
297
298 self.reactor
300 .start()
301 .await
302 .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
303 tracing::info!("Reactor started for real-time updates");
304
305 tracing::info!("Gateway server listening on {}", addr);
306
307 let listener = tokio::net::TcpListener::bind(addr).await?;
308 axum::serve(listener, router.into_make_service()).await
309 }
310}
311
312async fn health_handler() -> Json<HealthResponse> {
314 Json(HealthResponse {
315 status: "healthy".to_string(),
316 version: env!("CARGO_PKG_VERSION").to_string(),
317 })
318}
319
320async fn readiness_handler(
322 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
323) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
324 let db_ok = sqlx::query("SELECT 1")
326 .fetch_one(&state.db_pool)
327 .await
328 .is_ok();
329
330 let reactor_stats = state.reactor.stats().await;
332 let reactor_ok = reactor_stats.listener_running;
333
334 let ready = db_ok && reactor_ok;
335 let status_code = if ready {
336 axum::http::StatusCode::OK
337 } else {
338 axum::http::StatusCode::SERVICE_UNAVAILABLE
339 };
340
341 (
342 status_code,
343 Json(ReadinessResponse {
344 ready,
345 database: db_ok,
346 reactor: reactor_ok,
347 version: env!("CARGO_PKG_VERSION").to_string(),
348 }),
349 )
350}
351
352async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
353 let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
354 (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
355 } else {
356 (
357 StatusCode::SERVICE_UNAVAILABLE,
358 "SERVICE_UNAVAILABLE",
359 "Server overloaded",
360 )
361 };
362 (
363 status,
364 Json(RpcResponse::error(RpcError::new(code, message))),
365 )
366 .into_response()
367}
368
369fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
370 if let Ok(val) = trace_id.parse() {
371 response.headers_mut().insert(TRACE_ID_HEADER, val);
372 }
373 if let Ok(val) = request_id.parse() {
374 response.headers_mut().insert(REQUEST_ID_HEADER, val);
375 }
376}
377
378struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
380
381impl<'a> Extractor for HeaderExtractor<'a> {
382 fn get(&self, key: &str) -> Option<&str> {
383 self.0.get(key).and_then(|v| v.to_str().ok())
384 }
385
386 fn keys(&self) -> Vec<&str> {
387 self.0.keys().map(|k| k.as_str()).collect()
388 }
389}
390
391async fn tracing_middleware(
397 axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
398 req: axum::extract::Request,
399 next: axum::middleware::Next,
400) -> axum::response::Response {
401 let headers = req.headers();
402
403 let parent_cx =
405 global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
406
407 let trace_id = headers
408 .get(TRACE_ID_HEADER)
409 .and_then(|v| v.to_str().ok())
410 .map(String::from)
411 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
412
413 let parent_span_id = headers
414 .get(SPAN_ID_HEADER)
415 .and_then(|v| v.to_str().ok())
416 .map(String::from);
417
418 let method = req.method().to_string();
419 let path = req.uri().path().to_string();
420
421 let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
422 if let Some(span_id) = parent_span_id {
423 tracing_state = tracing_state.with_parent_span(span_id);
424 }
425
426 let mut req = req;
427 req.extensions_mut().insert(tracing_state.clone());
428
429 if req
430 .extensions()
431 .get::<forge_core::function::AuthContext>()
432 .is_none()
433 {
434 req.extensions_mut()
435 .insert(forge_core::function::AuthContext::unauthenticated());
436 }
437
438 let full_path = format!("/_api{}", path);
441 let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
442
443 if is_quiet {
444 let mut response = next.run(req).await;
445 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
446 return response;
447 }
448
449 let span = tracing::info_span!(
450 "http.request",
451 http.method = %method,
452 http.route = %path,
453 http.status_code = tracing::field::Empty,
454 trace_id = %trace_id,
455 request_id = %tracing_state.request_id,
456 );
457
458 span.set_parent(parent_cx);
461
462 let mut response = next.run(req).instrument(span.clone()).await;
463
464 let status = response.status().as_u16();
465 let elapsed = tracing_state.elapsed();
466
467 span.record("http.status_code", status);
468 let duration_ms = elapsed.as_millis() as u64;
469 match status {
470 500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
471 400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
472 200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
473 _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
474 }
475 crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
476
477 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
478 response
479}
480
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Every documented default of `GatewayConfig` is pinned, not just a subset.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 512);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(!config.cors_enabled);
        assert!(config.cors_origins.is_empty());
        assert!(config.quiet_routes.is_empty());
    }

    /// The exact wire format (key names and order) is what probes and clients consume.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(json, r#"{"status":"healthy","version":"0.1.0"}"#);
    }

    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: true,
            reactor: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert_eq!(
            json,
            r#"{"ready":false,"database":true,"reactor":false,"version":"0.1.0"}"#
        );
    }
}