1use std::sync::Arc;
2
3use axum::{
4 Json, Router, middleware,
5 routing::{get, post},
6};
7use serde::Serialize;
8use tower::ServiceBuilder;
9use tower_http::cors::{Any, CorsLayer};
10
11use forge_core::cluster::NodeId;
12use forge_core::function::{JobDispatch, WorkflowDispatch};
13
14use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
15use super::metrics::{MetricsState, metrics_middleware};
16use super::multipart::rpc_multipart_handler;
17use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
18use super::sse::{
19 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
20 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
21};
22use super::tracing::TracingState;
23use crate::function::FunctionRegistry;
24use crate::observability::ObservabilityState;
25use crate::realtime::{Reactor, ReactorConfig};
26
27#[derive(Debug, Clone)]
29pub struct GatewayConfig {
30 pub port: u16,
32 pub max_connections: usize,
34 pub request_timeout_secs: u64,
36 pub cors_enabled: bool,
38 pub cors_origins: Vec<String>,
40 pub auth: AuthConfig,
42}
43
44impl Default for GatewayConfig {
45 fn default() -> Self {
46 Self {
47 port: 8080,
48 max_connections: 10000,
49 request_timeout_secs: 30,
50 cors_enabled: true,
51 cors_origins: vec!["*".to_string()],
52 auth: AuthConfig::default(),
53 }
54 }
55}
56
57#[derive(Debug, Serialize)]
59pub struct HealthResponse {
60 pub status: String,
61 pub version: String,
62}
63
64#[derive(Debug, Serialize)]
66pub struct ReadinessResponse {
67 pub ready: bool,
68 pub database: bool,
69 pub version: String,
70}
71
72#[derive(Clone)]
74pub struct ReadinessState {
75 db_pool: sqlx::PgPool,
76}
77
78pub struct GatewayServer {
80 config: GatewayConfig,
81 registry: FunctionRegistry,
82 db_pool: sqlx::PgPool,
83 reactor: Arc<Reactor>,
84 observability: Option<ObservabilityState>,
85 job_dispatcher: Option<Arc<dyn JobDispatch>>,
86 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
87}
88
89impl GatewayServer {
90 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
92 let node_id = NodeId::new();
93 let reactor = Arc::new(Reactor::new(
94 node_id,
95 db_pool.clone(),
96 registry.clone(),
97 ReactorConfig::default(),
98 ));
99
100 Self {
101 config,
102 registry,
103 db_pool,
104 reactor,
105 observability: None,
106 job_dispatcher: None,
107 workflow_dispatcher: None,
108 }
109 }
110
111 pub fn with_observability(
113 config: GatewayConfig,
114 registry: FunctionRegistry,
115 db_pool: sqlx::PgPool,
116 observability: ObservabilityState,
117 ) -> Self {
118 let node_id = NodeId::new();
119 let reactor = Arc::new(Reactor::new(
120 node_id,
121 db_pool.clone(),
122 registry.clone(),
123 ReactorConfig::default(),
124 ));
125
126 Self {
127 config,
128 registry,
129 db_pool,
130 reactor,
131 observability: Some(observability),
132 job_dispatcher: None,
133 workflow_dispatcher: None,
134 }
135 }
136
137 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
139 self.job_dispatcher = Some(dispatcher);
140 self
141 }
142
143 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
145 self.workflow_dispatcher = Some(dispatcher);
146 self
147 }
148
149 pub fn reactor(&self) -> Arc<Reactor> {
151 self.reactor.clone()
152 }
153
154 pub fn router(&self) -> Router {
156 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
157 self.registry.clone(),
158 self.db_pool.clone(),
159 self.job_dispatcher.clone(),
160 self.workflow_dispatcher.clone(),
161 ));
162
163 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
164
165 let cors = if self.config.cors_enabled {
167 if self.config.cors_origins.contains(&"*".to_string()) {
168 CorsLayer::new()
169 .allow_origin(Any)
170 .allow_methods(Any)
171 .allow_headers(Any)
172 } else {
173 let origins: Vec<_> = self
174 .config
175 .cors_origins
176 .iter()
177 .filter_map(|o| o.parse().ok())
178 .collect();
179 CorsLayer::new()
180 .allow_origin(origins)
181 .allow_methods(Any)
182 .allow_headers(Any)
183 }
184 } else {
185 CorsLayer::new()
186 };
187
188 let sse_state = Arc::new(SseState::new(
190 self.reactor.clone(),
191 auth_middleware_state.clone(),
192 ));
193
194 let readiness_state = Arc::new(ReadinessState {
196 db_pool: self.db_pool.clone(),
197 });
198
199 let mut main_router = Router::new()
201 .route("/health", get(health_handler))
203 .route("/ready", get(readiness_handler).with_state(readiness_state))
205 .route("/rpc", post(rpc_handler))
207 .route("/rpc/{function}", post(rpc_function_handler))
209 .with_state(rpc_handler_state.clone());
211
212 let multipart_router = Router::new()
214 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
215 .with_state(rpc_handler_state);
216
217 let sse_router = Router::new()
219 .route("/events", get(sse_handler))
220 .route("/subscribe", post(sse_subscribe_handler))
221 .route("/unsubscribe", post(sse_unsubscribe_handler))
222 .route("/subscribe-job", post(sse_job_subscribe_handler))
223 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
224 .with_state(sse_state);
225
226 main_router = main_router.merge(multipart_router).merge(sse_router);
227
228 let service_builder = ServiceBuilder::new()
230 .layer(cors.clone())
231 .layer(middleware::from_fn_with_state(
232 auth_middleware_state,
233 auth_middleware,
234 ))
235 .layer(middleware::from_fn(tracing_middleware));
236
237 if let Some(ref observability) = self.observability {
239 let metrics_state = Arc::new(MetricsState::new(observability.clone()));
240 main_router = main_router.layer(middleware::from_fn_with_state(
241 metrics_state,
242 metrics_middleware,
243 ));
244 }
245
246 main_router.layer(service_builder)
248 }
249
250 pub fn addr(&self) -> std::net::SocketAddr {
252 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
253 }
254
255 pub async fn run(self) -> Result<(), std::io::Error> {
257 let addr = self.addr();
258 let router = self.router();
259
260 if let Err(e) = self.reactor.start().await {
262 tracing::error!("Failed to start reactor: {}", e);
263 } else {
264 tracing::info!("Reactor started for real-time updates");
265 }
266
267 tracing::info!("Gateway server listening on {}", addr);
268
269 let listener = tokio::net::TcpListener::bind(addr).await?;
270 axum::serve(listener, router.into_make_service()).await
271 }
272}
273
274async fn health_handler() -> Json<HealthResponse> {
276 Json(HealthResponse {
277 status: "healthy".to_string(),
278 version: env!("CARGO_PKG_VERSION").to_string(),
279 })
280}
281
282async fn readiness_handler(
284 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
285) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
286 let db_ok = sqlx::query("SELECT 1")
288 .fetch_one(&state.db_pool)
289 .await
290 .is_ok();
291
292 let ready = db_ok;
293 let status_code = if ready {
294 axum::http::StatusCode::OK
295 } else {
296 axum::http::StatusCode::SERVICE_UNAVAILABLE
297 };
298
299 (
300 status_code,
301 Json(ReadinessResponse {
302 ready,
303 database: db_ok,
304 version: env!("CARGO_PKG_VERSION").to_string(),
305 }),
306 )
307}
308
309async fn tracing_middleware(
311 req: axum::extract::Request,
312 next: axum::middleware::Next,
313) -> axum::response::Response {
314 use axum::http::header::HeaderName;
315
316 let trace_id = req
318 .headers()
319 .get(HeaderName::from_static("x-trace-id"))
320 .and_then(|v| v.to_str().ok())
321 .map(String::from)
322 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
323
324 let tracing_state = TracingState::with_trace_id(trace_id.clone());
325
326 let mut req = req;
327 req.extensions_mut().insert(tracing_state.clone());
328
329 if req
331 .extensions()
332 .get::<forge_core::function::AuthContext>()
333 .is_none()
334 {
335 req.extensions_mut()
336 .insert(forge_core::function::AuthContext::unauthenticated());
337 }
338
339 let mut response = next.run(req).await;
340
341 if let Ok(val) = trace_id.parse() {
343 response.headers_mut().insert("x-trace-id", val);
344 }
345 if let Ok(val) = tracing_state.request_id.parse() {
346 response.headers_mut().insert("x-request-id", val);
347 }
348
349 response
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default value is pinned so accidental changes are caught.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 10000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins, vec!["*".to_string()]);
    }

    /// The full JSON shape is compared, not just a substring.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "status": "healthy", "version": "0.1.0" })
        );
    }

    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "ready": false, "database": false, "version": "0.1.0" })
        );
    }
}