use std::{collections::HashMap, sync::Arc, time::Duration};
use axum::{
Json, Router,
extract::State,
http::StatusCode,
response::IntoResponse,
routing::{get, post},
};
use serde::Deserialize;
use serde_json::{Value, json};
use tokio::net::TcpListener;
use super::{
config::{GatewayConfig, SubgraphConfig},
merger::{self, MergedResponse, SubgraphResponse},
planner::{self, FieldOwnership, QueryPlan},
};
/// Shared state handed to every axum handler in this gateway.
///
/// Handlers take it by value through axum's `State` extractor, which is why it is `Clone`;
/// the field-ownership table sits behind an `Arc` so cloning the state does not copy it.
#[derive(Clone)]
pub struct GatewayState {
    /// HTTP client used for every outbound subgraph request.
    pub(crate) client: reqwest::Client,
    /// Subgraph request timeout, derived from `timeouts.subgraph_request_ms` in [`serve`].
    pub(crate) subgraph_timeout: Duration,
    /// Subgraph configurations keyed by subgraph name, as referenced by query-plan fetches.
    pub(crate) subgraphs: HashMap<String, SubgraphConfig>,
    /// Field-to-subgraph ownership table consulted by `planner::plan_query`.
    pub(crate) ownership: Arc<FieldOwnership>,
}
/// JSON body of an incoming `POST /graphql` request.
///
/// Only `query` is required; `variables` and `operationName` may be omitted by the client.
#[derive(Debug, Deserialize)]
struct GraphQLRequest {
    /// The GraphQL document sent by the client.
    query: String,
    /// Operation variables; forwarded to subgraphs (as `null` when absent).
    #[serde(default)]
    variables: Option<Value>,
    /// Name of the operation to execute, spelled `operationName` on the wire.
    #[serde(rename = "operationName")]
    #[serde(default)]
    operation_name: Option<String>,
}
/// Starts the gateway HTTP server and runs it until the server stops.
///
/// Builds a shared [`GatewayState`] from `config` and `ownership`, binds `config.listen`,
/// prints the available routes to stderr and serves the router from [`build_router`].
///
/// # Errors
///
/// Fails if the HTTP client cannot be built, the listen address cannot be bound, or the
/// server itself returns an error.
pub async fn serve(config: &GatewayConfig, ownership: FieldOwnership) -> anyhow::Result<()> {
    // A single timeout value drives both the client-wide default and the copy kept in state.
    let subgraph_timeout = Duration::from_millis(config.timeouts.subgraph_request_ms);

    let state = GatewayState {
        client: reqwest::Client::builder().timeout(subgraph_timeout).build()?,
        subgraphs: config.subgraphs.clone(),
        ownership: Arc::new(ownership),
        subgraph_timeout,
    };
    let app = build_router(state);

    let listener = TcpListener::bind(&config.listen).await?;
    eprintln!("Gateway listening on {}", config.listen);
    eprintln!(" POST /graphql — GraphQL endpoint");
    eprintln!(" GET /health — Health check");
    eprintln!(" GET /ready — Readiness check");

    axum::serve(listener, app).await?;
    Ok(())
}
/// Builds the gateway's axum router: `POST /graphql`, `GET /health` and `GET /ready`,
/// all sharing `state`.
pub fn build_router(state: GatewayState) -> Router {
    let routes: Router<GatewayState> = Router::new()
        .route("/graphql", post(handle_graphql))
        .route("/health", get(handle_health))
        .route("/ready", get(handle_ready));
    routes.with_state(state)
}
/// `POST /graphql`: plans the incoming operation across subgraphs and returns the merged result.
///
/// Responds `400 Bad Request` with a GraphQL-style `errors` array when no root fields can be
/// extracted from the query or when planning fails. Otherwise it executes the plan and responds
/// `200 OK` with the merged `data`/`errors` body.
async fn handle_graphql(
    State(state): State<GatewayState>,
    Json(request): Json<GraphQLRequest>,
) -> impl IntoResponse {
    // Shared shape for both rejection paths below.
    let bad_request = |message: &str| {
        (
            StatusCode::BAD_REQUEST,
            Json(json!({ "errors": [{ "message": message }] })),
        )
    };

    let root_fields = planner::extract_root_fields(&request.query);
    if root_fields.is_empty() {
        return bad_request("Could not extract root fields from query");
    }

    let plan = match planner::plan_query(&root_fields, &state.ownership) {
        Ok(plan) => plan,
        Err(err) => return bad_request(&err.to_string()),
    };

    let responses = execute_plan(&state, &plan, &request).await;
    let body = merged_to_value(&merger::merge_responses(&responses));
    (StatusCode::OK, Json(body))
}
async fn execute_plan(
state: &GatewayState,
plan: &QueryPlan,
original: &GraphQLRequest,
) -> Vec<(String, SubgraphResponse)> {
let mut handles = Vec::new();
for fetch in &plan.fetches {
let client = state.client.clone();
let subgraph_name = fetch.subgraph.clone();
let query = if fetch.is_entity_fetch {
fetch.query.clone()
} else {
if plan.fetches.len() == 1 {
original.query.clone()
} else {
fetch.query.clone()
}
};
let variables = original.variables.clone().unwrap_or(Value::Null);
let operation_name = original.operation_name.clone();
let url = state.subgraphs.get(&fetch.subgraph).map(|s| s.url.clone()).unwrap_or_default();
let timeout = state.subgraph_timeout;
handles.push(tokio::spawn(async move {
let result = execute_subgraph_request(
&client,
&url,
&query,
&variables,
operation_name.as_deref(),
timeout,
)
.await;
let response = match result {
Ok(resp) => resp,
Err(e) => SubgraphResponse {
data: None,
errors: vec![merger::GraphQLError {
message: format!("Subgraph '{subgraph_name}' request failed: {e}"),
path: None,
locations: None,
extensions: Some(json!({"code": "SUBGRAPH_REQUEST_FAILED"})),
}],
},
};
(subgraph_name, response)
}));
}
let mut results = Vec::new();
for handle in handles {
match handle.await {
Ok(pair) => results.push(pair),
Err(e) => {
results.push((
"unknown".to_string(),
SubgraphResponse {
data: None,
errors: vec![merger::GraphQLError {
message: format!("Task join error: {e}"),
path: None,
locations: None,
extensions: None,
}],
},
));
},
}
}
results
}
/// POSTs a single GraphQL request to a subgraph and decodes its JSON response.
///
/// The body always contains `query` and `variables`; `operationName` is added only when
/// one is given. `timeout` bounds this request from connect through reading the body, and
/// it overrides any timeout configured on `client`. Previously this argument was ignored, so
/// the gateway's `subgraph_timeout` setting had no effect here.
///
/// # Errors
///
/// Returns the `reqwest::Error` produced in any of these cases:
/// - building or sending the request fails,
/// - the timeout elapses,
/// - the response body is not a valid `SubgraphResponse`.
///
/// A non-2xx status is not an error by itself: subgraphs may legitimately return a GraphQL
/// error body on a 4xx status, so the body is always decoded.
async fn execute_subgraph_request(
    client: &reqwest::Client,
    url: &str,
    query: &str,
    variables: &Value,
    operation_name: Option<&str>,
    timeout: Duration,
) -> Result<SubgraphResponse, reqwest::Error> {
    let mut body = json!({
        "query": query,
        "variables": variables,
    });
    if let Some(op) = operation_name {
        body["operationName"] = Value::String(op.to_string());
    }
    client
        .post(url)
        .timeout(timeout)
        .json(&body)
        .send()
        .await?
        .json::<SubgraphResponse>()
        .await
}
/// Converts a merged response into the JSON body sent to the client.
///
/// `data` is always present. `errors` is included only when at least one error was
/// collected; if serializing the errors fails, an empty array is emitted instead.
fn merged_to_value(merged: &MergedResponse) -> Value {
    let mut body = serde_json::Map::new();
    body.insert(String::from("data"), merged.data.clone());
    if merged.errors.is_empty() {
        return Value::Object(body);
    }
    let errors =
        serde_json::to_value(&merged.errors).unwrap_or_else(|_| Value::Array(Vec::new()));
    body.insert(String::from("errors"), errors);
    Value::Object(body)
}
/// `GET /health`: liveness probe that always reports `{"status": "healthy"}`.
async fn handle_health() -> impl IntoResponse {
    let body = json!({ "status": "healthy" });
    Json(body)
}
/// `GET /ready`: reports `"ready"` together with the number of configured subgraphs.
///
/// No subgraph is contacted; the count comes straight from the gateway state.
async fn handle_ready(State(state): State<GatewayState>) -> impl IntoResponse {
    Json(json!({
        "status": "ready",
        "subgraphs": state.subgraphs.len(),
    }))
}