1use std::sync::Arc;
2
3use axum::{
4 Json, Router, middleware,
5 routing::{get, post},
6};
7use serde::Serialize;
8use tower::ServiceBuilder;
9use tower_http::cors::{Any, CorsLayer};
10
11use forge_core::cluster::NodeId;
12use forge_core::function::{JobDispatch, WorkflowDispatch};
13
14use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
15use super::multipart::rpc_multipart_handler;
16use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
17use super::sse::{
18 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
19 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
20};
21use super::tracing::TracingState;
22use crate::db::Database;
23use crate::function::FunctionRegistry;
24use crate::realtime::{Reactor, ReactorConfig};
25
26#[derive(Debug, Clone)]
28pub struct GatewayConfig {
29 pub port: u16,
31 pub max_connections: usize,
33 pub request_timeout_secs: u64,
35 pub cors_enabled: bool,
37 pub cors_origins: Vec<String>,
39 pub auth: AuthConfig,
41}
42
43impl Default for GatewayConfig {
44 fn default() -> Self {
45 Self {
46 port: 8080,
47 max_connections: 10000,
48 request_timeout_secs: 30,
49 cors_enabled: true,
50 cors_origins: vec!["*".to_string()],
51 auth: AuthConfig::default(),
52 }
53 }
54}
55
56#[derive(Debug, Serialize)]
58pub struct HealthResponse {
59 pub status: String,
60 pub version: String,
61}
62
63#[derive(Debug, Serialize)]
65pub struct ReadinessResponse {
66 pub ready: bool,
67 pub database: bool,
68 pub version: String,
69}
70
71#[derive(Clone)]
73pub struct ReadinessState {
74 db_pool: sqlx::PgPool,
75}
76
77pub struct GatewayServer {
79 config: GatewayConfig,
80 registry: FunctionRegistry,
81 db: Database,
82 reactor: Arc<Reactor>,
83 job_dispatcher: Option<Arc<dyn JobDispatch>>,
84 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
85}
86
87impl GatewayServer {
88 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
90 let node_id = NodeId::new();
91 let reactor = Arc::new(Reactor::new(
92 node_id,
93 db.read_pool().clone(),
94 registry.clone(),
95 ReactorConfig::default(),
96 ));
97
98 Self {
99 config,
100 registry,
101 db,
102 reactor,
103 job_dispatcher: None,
104 workflow_dispatcher: None,
105 }
106 }
107
108 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
110 self.job_dispatcher = Some(dispatcher);
111 self
112 }
113
114 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
116 self.workflow_dispatcher = Some(dispatcher);
117 self
118 }
119
120 pub fn reactor(&self) -> Arc<Reactor> {
122 self.reactor.clone()
123 }
124
125 pub fn router(&self) -> Router {
127 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
128 self.registry.clone(),
129 self.db.clone(),
130 self.job_dispatcher.clone(),
131 self.workflow_dispatcher.clone(),
132 ));
133
134 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
135
136 let cors = if self.config.cors_enabled {
138 if self.config.cors_origins.contains(&"*".to_string()) {
139 CorsLayer::new()
140 .allow_origin(Any)
141 .allow_methods(Any)
142 .allow_headers(Any)
143 } else {
144 let origins: Vec<_> = self
145 .config
146 .cors_origins
147 .iter()
148 .filter_map(|o| o.parse().ok())
149 .collect();
150 CorsLayer::new()
151 .allow_origin(origins)
152 .allow_methods(Any)
153 .allow_headers(Any)
154 }
155 } else {
156 CorsLayer::new()
157 };
158
159 let sse_state = Arc::new(SseState::new(
161 self.reactor.clone(),
162 auth_middleware_state.clone(),
163 ));
164
165 let readiness_state = Arc::new(ReadinessState {
167 db_pool: self.db.primary().clone(),
168 });
169
170 let mut main_router = Router::new()
172 .route("/health", get(health_handler))
174 .route("/ready", get(readiness_handler).with_state(readiness_state))
176 .route("/rpc", post(rpc_handler))
178 .route("/rpc/{function}", post(rpc_function_handler))
180 .with_state(rpc_handler_state.clone());
182
183 let multipart_router = Router::new()
185 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
186 .with_state(rpc_handler_state);
187
188 let sse_router = Router::new()
190 .route("/events", get(sse_handler))
191 .route("/subscribe", post(sse_subscribe_handler))
192 .route("/unsubscribe", post(sse_unsubscribe_handler))
193 .route("/subscribe-job", post(sse_job_subscribe_handler))
194 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
195 .with_state(sse_state);
196
197 main_router = main_router.merge(multipart_router).merge(sse_router);
198
199 let service_builder = ServiceBuilder::new()
201 .layer(cors.clone())
202 .layer(middleware::from_fn_with_state(
203 auth_middleware_state,
204 auth_middleware,
205 ))
206 .layer(middleware::from_fn(tracing_middleware));
207
208 main_router.layer(service_builder)
210 }
211
212 pub fn addr(&self) -> std::net::SocketAddr {
214 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
215 }
216
217 pub async fn run(self) -> Result<(), std::io::Error> {
219 let addr = self.addr();
220 let router = self.router();
221
222 if let Err(e) = self.reactor.start().await {
224 tracing::error!("Failed to start reactor: {}", e);
225 } else {
226 tracing::info!("Reactor started for real-time updates");
227 }
228
229 tracing::info!("Gateway server listening on {}", addr);
230
231 let listener = tokio::net::TcpListener::bind(addr).await?;
232 axum::serve(listener, router.into_make_service()).await
233 }
234}
235
236async fn health_handler() -> Json<HealthResponse> {
238 Json(HealthResponse {
239 status: "healthy".to_string(),
240 version: env!("CARGO_PKG_VERSION").to_string(),
241 })
242}
243
244async fn readiness_handler(
246 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
247) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
248 let db_ok = sqlx::query("SELECT 1")
250 .fetch_one(&state.db_pool)
251 .await
252 .is_ok();
253
254 let ready = db_ok;
255 let status_code = if ready {
256 axum::http::StatusCode::OK
257 } else {
258 axum::http::StatusCode::SERVICE_UNAVAILABLE
259 };
260
261 (
262 status_code,
263 Json(ReadinessResponse {
264 ready,
265 database: db_ok,
266 version: env!("CARGO_PKG_VERSION").to_string(),
267 }),
268 )
269}
270
271async fn tracing_middleware(
273 req: axum::extract::Request,
274 next: axum::middleware::Next,
275) -> axum::response::Response {
276 use axum::http::header::HeaderName;
277
278 let trace_id = req
280 .headers()
281 .get(HeaderName::from_static("x-trace-id"))
282 .and_then(|v| v.to_str().ok())
283 .map(String::from)
284 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
285
286 let tracing_state = TracingState::with_trace_id(trace_id.clone());
287
288 let mut req = req;
289 req.extensions_mut().insert(tracing_state.clone());
290
291 if req
293 .extensions()
294 .get::<forge_core::function::AuthContext>()
295 .is_none()
296 {
297 req.extensions_mut()
298 .insert(forge_core::function::AuthContext::unauthenticated());
299 }
300
301 let mut response = next.run(req).await;
302
303 if let Ok(val) = trace_id.parse() {
305 response.headers_mut().insert("x-trace-id", val);
306 }
307 if let Ok(val) = tracing_state.request_id.parse() {
308 response.headers_mut().insert("x-request-id", val);
309 }
310
311 response
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration listens on 8080, allows 10000 connections and
    /// has CORS enabled.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 8080);
        assert_eq!(max_connections, 10000);
        assert!(cors_enabled);
    }

    /// A serialized health response contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let status = String::from("healthy");
        let version = String::from("0.1.0");

        let json = serde_json::to_string(&HealthResponse { status, version }).unwrap();

        assert!(json.contains("healthy"));
    }
}
335}