use std::sync::Arc;
use std::time::Duration;

use axum::{
    Json, Router,
    error_handling::HandleErrorLayer,
    extract::DefaultBodyLimit,
    http::StatusCode,
    middleware,
    response::IntoResponse,
    routing::{get, post},
};
use serde::Serialize;
use tower::BoxError;
use tower::ServiceBuilder;
use tower::limit::ConcurrencyLimitLayer;
use tower::timeout::TimeoutLayer;
use tower_http::cors::{Any, CorsLayer};

use forge_core::cluster::NodeId;
use forge_core::function::{JobDispatch, WorkflowDispatch};

use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
use super::multipart::rpc_multipart_handler;
use super::response::{RpcError, RpcResponse};
use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
use super::sse::{
    SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
    sse_unsubscribe_handler, sse_workflow_subscribe_handler,
};
use super::tracing::TracingState;
use crate::db::Database;
use crate::function::FunctionRegistry;
use crate::realtime::{Reactor, ReactorConfig};

/// Maximum JSON body size accepted by the RPC routes (1 MiB).
const MAX_JSON_BODY_SIZE: usize = 1024 * 1024;
/// Maximum body size accepted by the multipart upload route (20 MiB).
const MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024;
/// Maximum number of multipart uploads processed concurrently.
const MAX_MULTIPART_CONCURRENCY: usize = 32;

/// Configuration for the gateway HTTP server.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port the server binds to.
    pub port: u16,
    /// Maximum number of requests processed concurrently.
    pub max_connections: usize,
    /// Per-request timeout, in seconds.
    pub request_timeout_secs: u64,
    /// Whether CORS headers are added to responses.
    pub cors_enabled: bool,
    /// Allowed CORS origins; `"*"` permits any origin.
    pub cors_origins: Vec<String>,
    /// Authentication settings applied by the auth middleware.
    pub auth: AuthConfig,
}

impl Default for GatewayConfig {
    fn default() -> Self {
        Self {
            port: 8080,
            max_connections: 512,
            request_timeout_secs: 30,
            cors_enabled: false,
            cors_origins: Vec::new(),
            auth: AuthConfig::default(),
        }
    }
}

/// Payload returned by the `/health` endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    pub status: String,
    pub version: String,
}

/// Payload returned by the `/ready` endpoint.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    pub ready: bool,
    pub database: bool,
    pub reactor: bool,
    pub version: String,
}

/// Shared state for the readiness probe.
#[derive(Clone)]
pub struct ReadinessState {
    db_pool: sqlx::PgPool,
    reactor: Arc<Reactor>,
}

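/// HTTP gateway that exposes the function registry over JSON RPC, multipart
/// upload, and SSE endpoints, backed by its own realtime [`Reactor`].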
pub struct GatewayServer {
    config: GatewayConfig,
    registry: FunctionRegistry,
    db: Database,
    reactor: Arc<Reactor>,
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
}

impl GatewayServer {
    /// Creates a gateway server and its realtime reactor.
    pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
        let node_id = NodeId::new();
        let reactor = Arc::new(Reactor::new(
            node_id,
            db.read_pool().clone(),
            registry.clone(),
            ReactorConfig::default(),
        ));

        Self {
            config,
            registry,
            db,
            reactor,
            job_dispatcher: None,
            workflow_dispatcher: None,
        }
    }

    /// Sets the dispatcher the RPC handler uses to dispatch jobs.
    pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
        self.job_dispatcher = Some(dispatcher);
        self
    }

    /// Sets the dispatcher the RPC handler uses to dispatch workflows.
    pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
        self.workflow_dispatcher = Some(dispatcher);
        self
    }

    /// Returns a handle to the realtime reactor.
    pub fn reactor(&self) -> Arc<Reactor> {
        self.reactor.clone()
    }

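    /// Builds the axum [`Router`]: health/readiness probes, JSON RPC,
    /// multipart upload, and SSE routes, wrapped in the shared middleware
    /// stack (error handling, concurrency limit, timeout, CORS, auth, tracing).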
    pub fn router(&self) -> Router {
        let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
            self.registry.clone(),
            self.db.clone(),
            self.job_dispatcher.clone(),
            self.workflow_dispatcher.clone(),
        ));

        let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));

        // Build the CORS layer: "*" allows any origin; otherwise only the
        // configured origins are allowed (entries that fail to parse are skipped).
        let cors = if self.config.cors_enabled {
            if self.config.cors_origins.iter().any(|o| o == "*") {
                CorsLayer::new()
                    .allow_origin(Any)
                    .allow_methods(Any)
                    .allow_headers(Any)
            } else {
                let origins: Vec<_> = self
                    .config
                    .cors_origins
                    .iter()
                    .filter_map(|o| o.parse().ok())
                    .collect();
                CorsLayer::new()
                    .allow_origin(origins)
                    .allow_methods(Any)
                    .allow_headers(Any)
            }
        } else {
            CorsLayer::new()
        };

        let sse_state = Arc::new(SseState::new(
            self.reactor.clone(),
            auth_middleware_state.clone(),
        ));

        let readiness_state = Arc::new(ReadinessState {
            db_pool: self.db.primary().clone(),
            reactor: self.reactor.clone(),
        });

        // Core routes: health probes plus the JSON RPC endpoints.
        let mut main_router = Router::new()
            .route("/health", get(health_handler))
            .route("/ready", get(readiness_handler).with_state(readiness_state))
            .route("/rpc", post(rpc_handler))
            .route("/rpc/{function}", post(rpc_function_handler))
            .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
            .with_state(rpc_handler_state.clone());

        // Multipart uploads get a larger body limit and their own concurrency cap.
        let multipart_router = Router::new()
            .route("/rpc/{function}/upload", post(rpc_multipart_handler))
            .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
            .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
            .with_state(rpc_handler_state);

        // Server-sent events: the event stream plus subscription management.
        let sse_router = Router::new()
            .route("/events", get(sse_handler))
            .route("/subscribe", post(sse_subscribe_handler))
            .route("/unsubscribe", post(sse_unsubscribe_handler))
            .route("/subscribe-job", post(sse_job_subscribe_handler))
            .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
            .with_state(sse_state);

        main_router = main_router.merge(multipart_router).merge(sse_router);

        // HandleErrorLayer sits outermost so errors from the layers beneath it
        // (e.g. request timeouts) become JSON error responses.
        let service_builder = ServiceBuilder::new()
            .layer(HandleErrorLayer::new(handle_middleware_error))
            .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
            .layer(TimeoutLayer::new(Duration::from_secs(
                self.config.request_timeout_secs,
            )))
            .layer(cors.clone())
            .layer(middleware::from_fn_with_state(
                auth_middleware_state,
                auth_middleware,
            ))
            .layer(middleware::from_fn(tracing_middleware));

        main_router.layer(service_builder)
    }

    /// Returns the socket address the server binds to: all interfaces on the
    /// configured port.
    pub fn addr(&self) -> std::net::SocketAddr {
        std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
    }

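    /// Starts the realtime reactor, binds the listener, and serves requests
    /// until the server shuts down. A minimal startup sketch (illustrative
    /// only; assumes a `registry` and `db` constructed elsewhere):
    ///
    /// ```ignore
    /// let server = GatewayServer::new(GatewayConfig::default(), registry, db);
    /// server.run().await?;
    /// ```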
    pub async fn run(self) -> Result<(), std::io::Error> {
        let addr = self.addr();
        let router = self.router();

        // Start the realtime reactor before accepting connections.
        self.reactor
            .start()
            .await
            .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
        tracing::info!("Reactor started for real-time updates");

        tracing::info!("Gateway server listening on {}", addr);

        let listener = tokio::net::TcpListener::bind(addr).await?;
        axum::serve(listener, router.into_make_service()).await
    }
}

async fn health_handler() -> Json<HealthResponse> {
    Json(HealthResponse {
        status: "healthy".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
    })
}

async fn readiness_handler(
    axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
) -> (StatusCode, Json<ReadinessResponse>) {
    // Cheap connectivity check against the primary database pool.
    let db_ok = sqlx::query("SELECT 1")
        .fetch_one(&state.db_pool)
        .await
        .is_ok();

    // The reactor is considered ready while its notification listener is running.
    let reactor_stats = state.reactor.stats().await;
    let reactor_ok = reactor_stats.listener_running;

    let ready = db_ok && reactor_ok;
    let status_code = if ready {
        StatusCode::OK
    } else {
        StatusCode::SERVICE_UNAVAILABLE
    };

    (
        status_code,
        Json(ReadinessResponse {
            ready,
            database: db_ok,
            reactor: reactor_ok,
            version: env!("CARGO_PKG_VERSION").to_string(),
        }),
    )
}

/// Converts errors escaping the middleware stack into JSON RPC error responses.
async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
    let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
        (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            "SERVICE_UNAVAILABLE",
            "Server overloaded",
        )
    };
    (
        status,
        Json(RpcResponse::error(RpcError::new(code, message))),
    )
        .into_response()
}

/// Attaches tracing state and a fallback [`forge_core::function::AuthContext`]
/// to every request, and echoes the trace and request ids on the response.
async fn tracing_middleware(
    mut req: axum::extract::Request,
    next: axum::middleware::Next,
) -> axum::response::Response {
    use axum::http::header::HeaderName;

    // Propagate an incoming `x-trace-id` header or mint a fresh one.
    let trace_id = req
        .headers()
        .get(HeaderName::from_static("x-trace-id"))
        .and_then(|v| v.to_str().ok())
        .map(String::from)
        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

    let tracing_state = TracingState::with_trace_id(trace_id.clone());
    req.extensions_mut().insert(tracing_state.clone());

    // Ensure downstream handlers always find an AuthContext; fall back to an
    // unauthenticated context when none was inserted upstream.
    if req
        .extensions()
        .get::<forge_core::function::AuthContext>()
        .is_none()
    {
        req.extensions_mut()
            .insert(forge_core::function::AuthContext::unauthenticated());
    }

    let mut response = next.run(req).await;

    // Echo the identifiers back to the client; values that are not valid
    // header values are skipped.
    if let Ok(val) = trace_id.parse() {
        response.headers_mut().insert("x-trace-id", val);
    }
    if let Ok(val) = tracing_state.request_id.parse() {
        response.headers_mut().insert("x-request-id", val);
    }

    response
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 512);
        assert!(!config.cors_enabled);
    }

    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("healthy"));
    }
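
    // Exercises the payload shape served by the `/ready` probe; only types
    // defined in this module are used.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: true,
            reactor: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"ready\":false"));
        assert!(json.contains("\"database\":true"));
    }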
}