1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 Json, Router,
6 error_handling::HandleErrorLayer,
7 extract::DefaultBodyLimit,
8 http::StatusCode,
9 middleware,
10 response::IntoResponse,
11 routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_batch_handler, rpc_function_handler, rpc_handler};
33use super::sse::{
34 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Request-body cap for the JSON RPC routes (`/rpc`, `/rpc/batch`, `/rpc/{function}`): 1 MiB.
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Request-body cap for the multipart upload route (`/rpc/{function}/upload`): 20 MiB.
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// Concurrency limit applied to the multipart upload route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
46
47#[derive(Debug, Clone)]
49pub struct GatewayConfig {
50 pub port: u16,
52 pub max_connections: usize,
54 pub sse_max_sessions: usize,
56 pub request_timeout_secs: u64,
58 pub cors_enabled: bool,
60 pub cors_origins: Vec<String>,
62 pub auth: AuthConfig,
64 pub mcp: McpConfig,
66 pub quiet_routes: Vec<String>,
68 pub token_ttl: forge_core::AuthTokenTtl,
70 pub project_name: String,
72}
73
74impl Default for GatewayConfig {
75 fn default() -> Self {
76 Self {
77 port: 9081,
78 max_connections: 512,
79 sse_max_sessions: 10_000,
80 request_timeout_secs: 30,
81 cors_enabled: false,
82 cors_origins: Vec::new(),
83 auth: AuthConfig::default(),
84 mcp: McpConfig::default(),
85 quiet_routes: Vec::new(),
86 token_ttl: forge_core::AuthTokenTtl::default(),
87 project_name: "forge-app".to_string(),
88 }
89 }
90}
91
92#[derive(Debug, Serialize)]
94pub struct HealthResponse {
95 pub status: String,
96 pub version: String,
97}
98
99#[derive(Debug, Serialize)]
101pub struct ReadinessResponse {
102 pub ready: bool,
103 pub database: bool,
104 pub reactor: bool,
105 pub version: String,
106}
107
108#[derive(Clone)]
110pub struct ReadinessState {
111 db_pool: sqlx::PgPool,
112 reactor: Arc<Reactor>,
113}
114
115pub struct GatewayServer {
117 config: GatewayConfig,
118 registry: FunctionRegistry,
119 db: Database,
120 reactor: Arc<Reactor>,
121 job_dispatcher: Option<Arc<dyn JobDispatch>>,
122 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
123 mcp_registry: Option<McpToolRegistry>,
124 token_ttl: forge_core::AuthTokenTtl,
125}
126
127impl GatewayServer {
128 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
130 let node_id = NodeId::new();
131 let reactor = Arc::new(Reactor::new(
132 node_id,
133 db.primary().clone(),
134 registry.clone(),
135 ReactorConfig::default(),
136 ));
137
138 let token_ttl = config.token_ttl.clone();
139 Self {
140 config,
141 registry,
142 db,
143 reactor,
144 job_dispatcher: None,
145 workflow_dispatcher: None,
146 mcp_registry: None,
147 token_ttl,
148 }
149 }
150
151 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
153 self.job_dispatcher = Some(dispatcher);
154 self
155 }
156
157 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
159 self.workflow_dispatcher = Some(dispatcher);
160 self
161 }
162
163 pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
165 self.mcp_registry = Some(registry);
166 self
167 }
168
169 pub fn reactor(&self) -> Arc<Reactor> {
171 self.reactor.clone()
172 }
173
174 pub fn oauth_router(&self) -> Option<(Router, Arc<super::oauth::OAuthState>)> {
176 if !self.config.mcp.oauth {
177 return None;
178 }
179
180 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
181 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>)?;
182
183 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
184
185 let jwt_secret = self.config.auth.jwt_secret.clone().unwrap_or_default();
186
187 let oauth_state = Arc::new(super::oauth::OAuthState::new(
188 self.db.primary().clone(),
189 auth_middleware_state,
190 token_issuer,
191 self.token_ttl.access_token_secs,
192 self.token_ttl.refresh_token_days,
193 self.config.auth.is_hmac(),
194 self.config.project_name.clone(),
195 jwt_secret,
196 ));
197
198 let router = Router::new()
199 .route(
200 "/oauth/authorize",
201 get(super::oauth::oauth_authorize_get).post(super::oauth::oauth_authorize_post),
202 )
203 .route("/oauth/token", post(super::oauth::oauth_token))
204 .route("/oauth/register", post(super::oauth::oauth_register))
205 .with_state(oauth_state.clone());
206
207 Some((router, oauth_state))
208 }
209
210 pub fn router(&self) -> Router {
212 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
213 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
214
215 let mut rpc = RpcHandler::with_dispatch_and_issuer(
216 self.registry.clone(),
217 self.db.clone(),
218 self.job_dispatcher.clone(),
219 self.workflow_dispatcher.clone(),
220 token_issuer,
221 );
222 rpc.set_token_ttl(self.token_ttl.clone());
223 let rpc_handler_state = Arc::new(rpc);
224
225 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
226
227 let cors = if self.config.cors_enabled {
231 if self.config.cors_origins.iter().any(|o| o == "*") {
232 CorsLayer::new()
234 .allow_origin(Any)
235 .allow_methods(Any)
236 .allow_headers(Any)
237 } else {
238 let origins: Vec<_> = self
239 .config
240 .cors_origins
241 .iter()
242 .filter_map(|o| o.parse().ok())
243 .collect();
244 use axum::http::Method;
245 CorsLayer::new()
246 .allow_origin(origins)
247 .allow_methods([
248 Method::GET,
249 Method::POST,
250 Method::PUT,
251 Method::DELETE,
252 Method::PATCH,
253 Method::OPTIONS,
254 ])
255 .allow_headers([
256 axum::http::header::CONTENT_TYPE,
257 axum::http::header::AUTHORIZATION,
258 axum::http::header::ACCEPT,
259 ])
260 .allow_credentials(true)
261 }
262 } else {
263 CorsLayer::new()
264 };
265
266 let sse_state = Arc::new(SseState::with_config(
268 self.reactor.clone(),
269 auth_middleware_state.clone(),
270 super::sse::SseConfig {
271 max_sessions: self.config.sse_max_sessions,
272 ..Default::default()
273 },
274 ));
275
276 let readiness_state = Arc::new(ReadinessState {
278 db_pool: self.db.primary().clone(),
279 reactor: self.reactor.clone(),
280 });
281
282 let mut main_router = Router::new()
284 .route("/health", get(health_handler))
286 .route("/ready", get(readiness_handler).with_state(readiness_state))
288 .route("/rpc", post(rpc_handler))
290 .route("/rpc/batch", post(rpc_batch_handler))
292 .route("/rpc/{function}", post(rpc_function_handler))
294 .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
296 .with_state(rpc_handler_state.clone());
298
299 let multipart_router = Router::new()
301 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
302 .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
303 .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
305 .with_state(rpc_handler_state);
306
307 let sse_router = Router::new()
309 .route("/events", get(sse_handler))
310 .route("/subscribe", post(sse_subscribe_handler))
311 .route("/unsubscribe", post(sse_unsubscribe_handler))
312 .route("/subscribe-job", post(sse_job_subscribe_handler))
313 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
314 .with_state(sse_state);
315
316 let mut mcp_router = Router::new();
317 if self.config.mcp.enabled {
318 let path = self.config.mcp.path.clone();
319 let mcp_state = Arc::new(McpState::new(
320 self.config.mcp.clone(),
321 self.mcp_registry.clone().unwrap_or_default(),
322 self.db.primary().clone(),
323 self.job_dispatcher.clone(),
324 self.workflow_dispatcher.clone(),
325 ));
326 mcp_router = mcp_router.route(
327 &path,
328 post(mcp_post_handler)
329 .get(mcp_get_handler)
330 .with_state(mcp_state),
331 );
332 }
333
334 main_router = main_router
335 .merge(multipart_router)
336 .merge(sse_router)
337 .merge(mcp_router);
338
339 let service_builder = ServiceBuilder::new()
341 .layer(HandleErrorLayer::new(handle_middleware_error))
342 .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
343 .layer(TimeoutLayer::new(Duration::from_secs(
344 self.config.request_timeout_secs,
345 )))
346 .layer(cors.clone())
347 .layer(middleware::from_fn_with_state(
348 auth_middleware_state,
349 auth_middleware,
350 ))
351 .layer(middleware::from_fn_with_state(
352 Arc::new(self.config.quiet_routes.clone()),
353 tracing_middleware,
354 ));
355
356 main_router.layer(service_builder)
358 }
359
360 pub fn addr(&self) -> std::net::SocketAddr {
362 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
363 }
364
365 pub async fn run(self) -> Result<(), std::io::Error> {
367 let addr = self.addr();
368 let router = self.router();
369
370 self.reactor
372 .start()
373 .await
374 .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
375 tracing::info!("Reactor started for real-time updates");
376
377 tracing::info!("Gateway server listening on {}", addr);
378
379 let listener = tokio::net::TcpListener::bind(addr).await?;
380 axum::serve(listener, router.into_make_service()).await
381 }
382}
383
384async fn health_handler() -> Json<HealthResponse> {
386 Json(HealthResponse {
387 status: "healthy".to_string(),
388 version: env!("CARGO_PKG_VERSION").to_string(),
389 })
390}
391
392async fn readiness_handler(
394 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
395) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
396 let db_ok = sqlx::query("SELECT 1")
398 .fetch_one(&state.db_pool)
399 .await
400 .is_ok();
401
402 let reactor_stats = state.reactor.stats().await;
404 let reactor_ok = reactor_stats.listener_running;
405
406 let ready = db_ok && reactor_ok;
407 let status_code = if ready {
408 axum::http::StatusCode::OK
409 } else {
410 axum::http::StatusCode::SERVICE_UNAVAILABLE
411 };
412
413 (
414 status_code,
415 Json(ReadinessResponse {
416 ready,
417 database: db_ok,
418 reactor: reactor_ok,
419 version: env!("CARGO_PKG_VERSION").to_string(),
420 }),
421 )
422}
423
424async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
425 let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
426 (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
427 } else {
428 (
429 StatusCode::SERVICE_UNAVAILABLE,
430 "SERVICE_UNAVAILABLE",
431 "Server overloaded",
432 )
433 };
434 (
435 status,
436 Json(RpcResponse::error(RpcError::new(code, message))),
437 )
438 .into_response()
439}
440
441fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
442 if let Ok(val) = trace_id.parse() {
443 response.headers_mut().insert(TRACE_ID_HEADER, val);
444 }
445 if let Ok(val) = request_id.parse() {
446 response.headers_mut().insert(REQUEST_ID_HEADER, val);
447 }
448}
449
450struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
452
453impl<'a> Extractor for HeaderExtractor<'a> {
454 fn get(&self, key: &str) -> Option<&str> {
455 self.0.get(key).and_then(|v| v.to_str().ok())
456 }
457
458 fn keys(&self) -> Vec<&str> {
459 self.0.keys().map(|k| k.as_str()).collect()
460 }
461}
462
463async fn tracing_middleware(
469 axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
470 req: axum::extract::Request,
471 next: axum::middleware::Next,
472) -> axum::response::Response {
473 let headers = req.headers();
474
475 let parent_cx =
477 global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
478
479 let trace_id = headers
480 .get(TRACE_ID_HEADER)
481 .and_then(|v| v.to_str().ok())
482 .map(String::from)
483 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
484
485 let parent_span_id = headers
486 .get(SPAN_ID_HEADER)
487 .and_then(|v| v.to_str().ok())
488 .map(String::from);
489
490 let method = req.method().to_string();
491 let path = req.uri().path().to_string();
492
493 let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
494 if let Some(span_id) = parent_span_id {
495 tracing_state = tracing_state.with_parent_span(span_id);
496 }
497
498 let mut req = req;
499 req.extensions_mut().insert(tracing_state.clone());
500
501 if req
502 .extensions()
503 .get::<forge_core::function::AuthContext>()
504 .is_none()
505 {
506 req.extensions_mut()
507 .insert(forge_core::function::AuthContext::unauthenticated());
508 }
509
510 let full_path = format!("/_api{}", path);
513 let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
514
515 if is_quiet {
516 let mut response = next.run(req).await;
517 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
518 return response;
519 }
520
521 let span = tracing::info_span!(
522 "http.request",
523 http.method = %method,
524 http.route = %path,
525 http.status_code = tracing::field::Empty,
526 trace_id = %trace_id,
527 request_id = %tracing_state.request_id,
528 );
529
530 span.set_parent(parent_cx);
533
534 let mut response = next.run(req).instrument(span.clone()).await;
535
536 let status = response.status().as_u16();
537 let elapsed = tracing_state.elapsed();
538
539 span.record("http.status_code", status);
540 let duration_ms = elapsed.as_millis() as u64;
541 match status {
542 500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
543 400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
544 200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
545 _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
546 }
547 crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
548
549 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
550 response
551}
552
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 9081);
        assert_eq!(config.max_connections, 512);
        assert_eq!(config.sse_max_sessions, 10_000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(!config.cors_enabled);
        assert!(config.cors_origins.is_empty());
        assert!(config.quiet_routes.is_empty());
        assert_eq!(config.project_name, "forge-app");
    }

    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["status"], "healthy");
        assert_eq!(json["version"], "0.1.0");
    }

    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: true,
            reactor: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["ready"], false);
        assert_eq!(json["database"], true);
        assert_eq!(json["reactor"], false);
        assert_eq!(json["version"], "0.1.0");
    }

    #[test]
    fn test_header_extractor() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(
            "traceparent",
            axum::http::HeaderValue::from_static("00-abc-def-01"),
        );
        let extractor = HeaderExtractor(&headers);
        assert_eq!(extractor.get("traceparent"), Some("00-abc-def-01"));
        assert_eq!(extractor.get("missing"), None);
        assert_eq!(extractor.keys(), vec!["traceparent"]);
    }

    #[test]
    fn test_set_tracing_headers_skips_invalid_values() {
        let mut response = axum::response::Response::new(axum::body::Body::empty());
        set_tracing_headers(&mut response, "trace-1", "bad\nvalue");
        assert_eq!(response.headers().get(TRACE_ID_HEADER).unwrap(), "trace-1");
        assert!(response.headers().get(REQUEST_ID_HEADER).is_none());
    }
}