1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 Json, Router,
6 error_handling::HandleErrorLayer,
7 extract::DefaultBodyLimit,
8 http::StatusCode,
9 middleware,
10 response::IntoResponse,
11 routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_batch_handler, rpc_function_handler, rpc_handler};
33use super::sse::{
34 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Request body cap (1 MiB) applied via `DefaultBodyLimit` to the main router
/// (`/health`, `/ready`, `/rpc`, `/rpc/batch`, `/rpc/{function}`).
const MAX_JSON_BODY_SIZE: usize = 1024 * 1024;
/// Request body cap (20 MiB) applied via `DefaultBodyLimit` to the multipart upload route.
const MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024;
/// Value given to the `ConcurrencyLimitLayer` that wraps the multipart upload route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
46
/// Configuration for the HTTP gateway built by `GatewayServer`.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port the gateway listens on (bound on `0.0.0.0` by `GatewayServer::addr`).
    pub port: u16,
    /// Value handed to the global `ConcurrencyLimitLayer` in the middleware stack.
    pub max_connections: usize,
    /// Forwarded to `SseConfig::max_sessions` when the SSE state is created.
    pub sse_max_sessions: usize,
    /// Duration, in seconds, given to the global `TimeoutLayer`.
    pub request_timeout_secs: u64,
    /// Enables the configured CORS rules. When `false`, the router installs a
    /// `CorsLayer::new()` with no allow rules configured.
    pub cors_enabled: bool,
    /// Allowed CORS origins. A literal `"*"` entry allows any origin; otherwise each
    /// entry is parsed as a header value and entries that fail to parse are not applied.
    pub cors_origins: Vec<String>,
    /// Authentication settings, used to build the token issuer and the auth middleware.
    pub auth: AuthConfig,
    /// MCP settings; the router reads `enabled` and `path` to decide whether and where
    /// to mount the MCP endpoint.
    pub mcp: McpConfig,
    /// Request paths for which the tracing middleware skips span creation, request
    /// logging and metrics. An entry matches either the request path itself or the
    /// request path prefixed with `/_api`.
    pub quiet_routes: Vec<String>,
}
69
70impl Default for GatewayConfig {
71 fn default() -> Self {
72 Self {
73 port: 8080,
74 max_connections: 512,
75 sse_max_sessions: 10_000,
76 request_timeout_secs: 30,
77 cors_enabled: false,
78 cors_origins: Vec::new(),
79 auth: AuthConfig::default(),
80 mcp: McpConfig::default(),
81 quiet_routes: Vec::new(),
82 }
83 }
84}
85
/// JSON body returned by the `/health` endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Health label; the handler in this file always reports `"healthy"`.
    pub status: String,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time by the handler.
    pub version: String,
}
92
/// JSON body returned by the `/ready` endpoint.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// Overall readiness: `true` only when both `database` and `reactor` are `true`.
    pub ready: bool,
    /// Whether the `SELECT 1` probe against the database pool succeeded.
    pub database: bool,
    /// Whether the reactor's stats report its listener as running.
    pub reactor: bool,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time by the handler.
    pub version: String,
}
101
/// State handed to the `/ready` handler.
#[derive(Clone)]
pub struct ReadinessState {
    /// Pool the readiness probe runs `SELECT 1` against.
    db_pool: sqlx::PgPool,
    /// Reactor whose stats are inspected to decide readiness.
    reactor: Arc<Reactor>,
}
108
/// HTTP gateway: owns the configuration and shared services and assembles the
/// axum router that serves RPC, multipart upload, SSE and (optionally) MCP routes.
pub struct GatewayServer {
    /// Gateway settings (port, limits, CORS, auth, MCP, quiet routes).
    config: GatewayConfig,
    /// Function registry shared with the RPC handler and the reactor.
    registry: FunctionRegistry,
    /// Database handle; its primary pool backs the reactor, readiness probe and MCP state.
    db: Database,
    /// Real-time reactor, created in `new` and started in `run`.
    reactor: Arc<Reactor>,
    /// Optional job dispatcher passed through to the RPC and MCP handlers.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher passed through to the RPC and MCP handlers.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Optional MCP tool registry; a default registry is used when unset and MCP is enabled.
    mcp_registry: Option<McpToolRegistry>,
}
119
120impl GatewayServer {
121 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
123 let node_id = NodeId::new();
124 let reactor = Arc::new(Reactor::new(
125 node_id,
126 db.primary().clone(),
127 registry.clone(),
128 ReactorConfig::default(),
129 ));
130
131 Self {
132 config,
133 registry,
134 db,
135 reactor,
136 job_dispatcher: None,
137 workflow_dispatcher: None,
138 mcp_registry: None,
139 }
140 }
141
142 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
144 self.job_dispatcher = Some(dispatcher);
145 self
146 }
147
148 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
150 self.workflow_dispatcher = Some(dispatcher);
151 self
152 }
153
154 pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
156 self.mcp_registry = Some(registry);
157 self
158 }
159
160 pub fn reactor(&self) -> Arc<Reactor> {
162 self.reactor.clone()
163 }
164
165 pub fn router(&self) -> Router {
167 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
168 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
169
170 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch_and_issuer(
171 self.registry.clone(),
172 self.db.clone(),
173 self.job_dispatcher.clone(),
174 self.workflow_dispatcher.clone(),
175 token_issuer,
176 ));
177
178 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
179
180 let cors = if self.config.cors_enabled {
182 if self.config.cors_origins.iter().any(|o| o == "*") {
183 CorsLayer::new()
184 .allow_origin(Any)
185 .allow_methods(Any)
186 .allow_headers(Any)
187 } else {
188 let origins: Vec<_> = self
189 .config
190 .cors_origins
191 .iter()
192 .filter_map(|o| o.parse().ok())
193 .collect();
194 CorsLayer::new()
195 .allow_origin(origins)
196 .allow_methods(Any)
197 .allow_headers(Any)
198 }
199 } else {
200 CorsLayer::new()
201 };
202
203 let sse_state = Arc::new(SseState::with_config(
205 self.reactor.clone(),
206 auth_middleware_state.clone(),
207 super::sse::SseConfig {
208 max_sessions: self.config.sse_max_sessions,
209 ..Default::default()
210 },
211 ));
212
213 let readiness_state = Arc::new(ReadinessState {
215 db_pool: self.db.primary().clone(),
216 reactor: self.reactor.clone(),
217 });
218
219 let mut main_router = Router::new()
221 .route("/health", get(health_handler))
223 .route("/ready", get(readiness_handler).with_state(readiness_state))
225 .route("/rpc", post(rpc_handler))
227 .route("/rpc/batch", post(rpc_batch_handler))
229 .route("/rpc/{function}", post(rpc_function_handler))
231 .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
233 .with_state(rpc_handler_state.clone());
235
236 let multipart_router = Router::new()
238 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
239 .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
240 .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
242 .with_state(rpc_handler_state);
243
244 let sse_router = Router::new()
246 .route("/events", get(sse_handler))
247 .route("/subscribe", post(sse_subscribe_handler))
248 .route("/unsubscribe", post(sse_unsubscribe_handler))
249 .route("/subscribe-job", post(sse_job_subscribe_handler))
250 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
251 .with_state(sse_state);
252
253 let mut mcp_router = Router::new();
254 if self.config.mcp.enabled {
255 let path = self.config.mcp.path.clone();
256 let mcp_state = Arc::new(McpState::new(
257 self.config.mcp.clone(),
258 self.mcp_registry.clone().unwrap_or_default(),
259 self.db.primary().clone(),
260 self.job_dispatcher.clone(),
261 self.workflow_dispatcher.clone(),
262 ));
263 mcp_router = mcp_router.route(
264 &path,
265 post(mcp_post_handler)
266 .get(mcp_get_handler)
267 .with_state(mcp_state),
268 );
269 }
270
271 main_router = main_router
272 .merge(multipart_router)
273 .merge(sse_router)
274 .merge(mcp_router);
275
276 let service_builder = ServiceBuilder::new()
278 .layer(HandleErrorLayer::new(handle_middleware_error))
279 .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
280 .layer(TimeoutLayer::new(Duration::from_secs(
281 self.config.request_timeout_secs,
282 )))
283 .layer(cors.clone())
284 .layer(middleware::from_fn_with_state(
285 auth_middleware_state,
286 auth_middleware,
287 ))
288 .layer(middleware::from_fn_with_state(
289 Arc::new(self.config.quiet_routes.clone()),
290 tracing_middleware,
291 ));
292
293 main_router.layer(service_builder)
295 }
296
297 pub fn addr(&self) -> std::net::SocketAddr {
299 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
300 }
301
302 pub async fn run(self) -> Result<(), std::io::Error> {
304 let addr = self.addr();
305 let router = self.router();
306
307 self.reactor
309 .start()
310 .await
311 .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
312 tracing::info!("Reactor started for real-time updates");
313
314 tracing::info!("Gateway server listening on {}", addr);
315
316 let listener = tokio::net::TcpListener::bind(addr).await?;
317 axum::serve(listener, router.into_make_service()).await
318 }
319}
320
321async fn health_handler() -> Json<HealthResponse> {
323 Json(HealthResponse {
324 status: "healthy".to_string(),
325 version: env!("CARGO_PKG_VERSION").to_string(),
326 })
327}
328
329async fn readiness_handler(
331 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
332) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
333 let db_ok = sqlx::query("SELECT 1")
335 .fetch_one(&state.db_pool)
336 .await
337 .is_ok();
338
339 let reactor_stats = state.reactor.stats().await;
341 let reactor_ok = reactor_stats.listener_running;
342
343 let ready = db_ok && reactor_ok;
344 let status_code = if ready {
345 axum::http::StatusCode::OK
346 } else {
347 axum::http::StatusCode::SERVICE_UNAVAILABLE
348 };
349
350 (
351 status_code,
352 Json(ReadinessResponse {
353 ready,
354 database: db_ok,
355 reactor: reactor_ok,
356 version: env!("CARGO_PKG_VERSION").to_string(),
357 }),
358 )
359}
360
361async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
362 let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
363 (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
364 } else {
365 (
366 StatusCode::SERVICE_UNAVAILABLE,
367 "SERVICE_UNAVAILABLE",
368 "Server overloaded",
369 )
370 };
371 (
372 status,
373 Json(RpcResponse::error(RpcError::new(code, message))),
374 )
375 .into_response()
376}
377
378fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
379 if let Ok(val) = trace_id.parse() {
380 response.headers_mut().insert(TRACE_ID_HEADER, val);
381 }
382 if let Ok(val) = request_id.parse() {
383 response.headers_mut().insert(REQUEST_ID_HEADER, val);
384 }
385}
386
387struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
389
390impl<'a> Extractor for HeaderExtractor<'a> {
391 fn get(&self, key: &str) -> Option<&str> {
392 self.0.get(key).and_then(|v| v.to_str().ok())
393 }
394
395 fn keys(&self) -> Vec<&str> {
396 self.0.keys().map(|k| k.as_str()).collect()
397 }
398}
399
400async fn tracing_middleware(
406 axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
407 req: axum::extract::Request,
408 next: axum::middleware::Next,
409) -> axum::response::Response {
410 let headers = req.headers();
411
412 let parent_cx =
414 global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
415
416 let trace_id = headers
417 .get(TRACE_ID_HEADER)
418 .and_then(|v| v.to_str().ok())
419 .map(String::from)
420 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
421
422 let parent_span_id = headers
423 .get(SPAN_ID_HEADER)
424 .and_then(|v| v.to_str().ok())
425 .map(String::from);
426
427 let method = req.method().to_string();
428 let path = req.uri().path().to_string();
429
430 let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
431 if let Some(span_id) = parent_span_id {
432 tracing_state = tracing_state.with_parent_span(span_id);
433 }
434
435 let mut req = req;
436 req.extensions_mut().insert(tracing_state.clone());
437
438 if req
439 .extensions()
440 .get::<forge_core::function::AuthContext>()
441 .is_none()
442 {
443 req.extensions_mut()
444 .insert(forge_core::function::AuthContext::unauthenticated());
445 }
446
447 let full_path = format!("/_api{}", path);
450 let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
451
452 if is_quiet {
453 let mut response = next.run(req).await;
454 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
455 return response;
456 }
457
458 let span = tracing::info_span!(
459 "http.request",
460 http.method = %method,
461 http.route = %path,
462 http.status_code = tracing::field::Empty,
463 trace_id = %trace_id,
464 request_id = %tracing_state.request_id,
465 );
466
467 span.set_parent(parent_cx);
470
471 let mut response = next.run(req).instrument(span.clone()).await;
472
473 let status = response.status().as_u16();
474 let elapsed = tracing_state.elapsed();
475
476 span.record("http.status_code", status);
477 let duration_ms = elapsed.as_millis() as u64;
478 match status {
479 500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
480 400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
481 200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
482 _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
483 }
484 crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
485
486 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
487 response
488}
489
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The default configuration listens on 8080, allows 512 concurrent requests and
    /// keeps CORS disabled.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 512);
        assert!(!cors_enabled);
    }

    /// A serialized health response contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let status = "healthy".to_string();
        let version = "0.1.0".to_string();

        let json = serde_json::to_string(&HealthResponse { status, version }).unwrap();

        assert!(json.contains("healthy"));
    }
}
512}