use std::{sync::atomic::Ordering, time::Instant};
use axum::{
Json,
extract::{Query, State},
http::HeaderMap,
};
use fraiseql_core::{
apq::{ApqMetrics, ApqStorage},
db::traits::DatabaseAdapter,
security::SecurityContext,
};
use tracing::{debug, error, info, warn};
use super::{
app_state::AppState,
request::{GraphQLGetParams, GraphQLRequest, GraphQLResponse},
};
use crate::{
error::{ErrorResponse, GraphQLError},
extractors::OptionalSecurityContext,
tracing_utils,
};
/// POST entry point for GraphQL requests.
///
/// Extracts the trace context from the incoming headers, emits debug logs
/// describing what was found (trace context and/or security context), and
/// then delegates all real work to `execute_graphql_request`.
///
/// # Errors
///
/// Propagates the `ErrorResponse` produced by `execute_graphql_request`.
#[tracing::instrument(skip_all, fields(operation_name))]
pub async fn graphql_handler<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
    State(state): State<AppState<A>>,
    headers: HeaderMap,
    OptionalSecurityContext(security_context): OptionalSecurityContext,
    Json(request): Json<GraphQLRequest>,
) -> Result<GraphQLResponse, ErrorResponse> {
    let trace_context = tracing_utils::extract_trace_context(&headers);

    // Compute both flags up front; the log order (trace first, auth second)
    // is kept so log output stays the same.
    let has_trace_context = trace_context.is_some();
    let is_authenticated = security_context.is_some();

    if has_trace_context {
        debug!("Extracted W3C trace context from incoming request");
    }
    if is_authenticated {
        debug!("Authenticated request with security context");
    }

    execute_graphql_request(state, request, trace_context, security_context, &headers).await
}
/// GET entry point for GraphQL requests.
///
/// Builds a `GraphQLRequest` from the URL query parameters and delegates to
/// `execute_graphql_request`. Before delegating it:
///
/// 1. rejects a `query` parameter longer than `state.max_get_query_bytes`;
/// 2. rejects a `variables` parameter longer than the same limit, and parses
///    it as JSON;
/// 3. logs a warning when the query text starts with `mutation` (the request
///    is still executed — this is a heuristic warning only).
///
/// The resulting request never carries `extensions` or a `document_id`.
///
/// # Errors
///
/// Returns a request-level `ErrorResponse` when a size limit is exceeded or
/// the `variables` parameter is not valid JSON; otherwise propagates whatever
/// `execute_graphql_request` returns.
#[tracing::instrument(skip_all, fields(operation_name))]
pub async fn graphql_get_handler<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
    State(state): State<AppState<A>>,
    headers: HeaderMap,
    OptionalSecurityContext(security_context): OptionalSecurityContext,
    Query(params): Query<GraphQLGetParams>,
) -> Result<GraphQLResponse, ErrorResponse> {
    // `str::len` is a byte count, which matches the wording of the error.
    let max_get_bytes = state.max_get_query_bytes;
    if params.query.len() > max_get_bytes {
        return Err(ErrorResponse::from_error(GraphQLError::request(format!(
            "GET query string exceeds maximum allowed length ({max_get_bytes} bytes)"
        ))));
    }

    let variables = match params.variables {
        None => None,
        Some(vars_str) => {
            if vars_str.len() > max_get_bytes {
                return Err(ErrorResponse::from_error(GraphQLError::request(format!(
                    "GET variables string exceeds maximum allowed length ({max_get_bytes} bytes)"
                ))));
            }
            match serde_json::from_str::<serde_json::Value>(&vars_str) {
                Ok(parsed) => Some(parsed),
                Err(e) => {
                    // Log only the size of the payload, never its contents:
                    // variables are client-controlled and may hold secrets or
                    // personal data, and echoing them would also let a client
                    // write arbitrary text into the server logs.
                    warn!(
                        error = %e,
                        variables_len = vars_str.len(),
                        "Failed to parse variables JSON in GET request"
                    );
                    return Err(ErrorResponse::from_error(GraphQLError::request(format!(
                        "Invalid variables JSON: {e}"
                    ))));
                },
            }
        },
    };

    // Prefix heuristic: only catches documents whose first token is
    // `mutation`. The request is not rejected here.
    if params.query.trim_start().starts_with("mutation") {
        warn!(
            operation_name = ?params.operation_name,
            "Mutation sent via GET request - should use POST"
        );
    }

    let trace_context = tracing_utils::extract_trace_context(&headers);
    if trace_context.is_some() {
        debug!("Extracted W3C trace context from incoming request");
    }

    let request = GraphQLRequest {
        query: Some(params.query),
        variables,
        operation_name: params.operation_name,
        extensions: None,
        document_id: None,
    };

    if security_context.is_some() {
        debug!("Authenticated GET request with security context");
    }

    execute_graphql_request(state, request, trace_context, security_context, &headers).await
}
/// Produces the client identifier used for per-client bookkeeping.
///
/// The headers are currently ignored and the fixed placeholder `"unknown"`
/// is always returned, so every caller that keys on this value sees the same
/// key for all clients.
#[cfg(feature = "auth")]
pub(crate) fn extract_ip_from_headers(_headers: &HeaderMap) -> String {
    String::from("unknown")
}
/// Reads the APQ hash from a request's `extensions` value.
///
/// Returns the string at `extensions.persistedQuery.sha256Hash`, or `None`
/// when `extensions` is absent, either key is missing, or the hash is not a
/// JSON string.
pub(crate) fn extract_apq_hash(extensions: Option<&serde_json::Value>) -> Option<&str> {
    extensions
        .and_then(|ext| ext.get("persistedQuery"))
        .and_then(|persisted| persisted.get("sha256Hash"))
        .and_then(|hash| hash.as_str())
}
fn extract_document_id(request: &GraphQLRequest) -> Option<String> {
if let Some(ref doc_id) = request.document_id {
return Some(doc_id.clone());
}
if let Some(ext) = request.extensions.as_ref() {
if let Some(doc_id) = ext.get("doc_id").and_then(|v| v.as_str()) {
return Some(doc_id.to_string());
}
if let Some(hash) = ext
.get("persistedQuery")
.and_then(|pq| pq.get("sha256Hash"))
.and_then(|h| h.as_str())
{
return Some(hash.to_string());
}
}
None
}
/// Resolves an Automatic Persisted Query request to its query text.
///
/// * With a query body: the body is checked against `hash`; on mismatch an
///   error is recorded and a "persisted query mismatch" response is returned.
///   Otherwise the body is stored under `hash` (a storage failure is logged
///   and counted but does not fail the request) and the body is returned.
/// * Without a query body: the store is consulted for `hash`. A hit returns
///   the stored text; a miss — or a store failure, which is logged and
///   counted as an error — yields a "persisted query not found" response.
///
/// # Errors
///
/// Returns an `ErrorResponse` for a hash mismatch, a missing entry, or a
/// failed store lookup.
pub(crate) async fn resolve_apq(
    apq_store: &dyn ApqStorage,
    apq_metrics: &ApqMetrics,
    hash: &str,
    query_body: Option<&str>,
) -> Result<String, ErrorResponse> {
    match query_body {
        // Hash-only request: the text must already be in the store.
        None => match apq_store.get(hash).await {
            Ok(Some(stored)) => {
                apq_metrics.record_hit();
                Ok(stored)
            },
            Ok(None) => {
                apq_metrics.record_miss();
                Err(ErrorResponse::from_error(GraphQLError::persisted_query_not_found()))
            },
            Err(e) => {
                warn!(error = %e, "APQ store lookup failed — treating as miss");
                apq_metrics.record_error();
                Err(ErrorResponse::from_error(GraphQLError::persisted_query_not_found()))
            },
        },
        // Registration request: verify first, then cache on a best-effort basis.
        Some(body) => {
            let hash_matches = fraiseql_core::apq::verify_hash(body, hash);
            if !hash_matches {
                apq_metrics.record_error();
                return Err(ErrorResponse::from_error(GraphQLError::persisted_query_mismatch()));
            }

            match apq_store.set(hash.to_owned(), body.to_owned()).await {
                Ok(_) => {
                    apq_metrics.record_store();
                },
                Err(e) => {
                    warn!(error = %e, "Failed to store APQ query — proceeding without caching");
                    apq_metrics.record_error();
                },
            }

            Ok(body.to_owned())
        },
    }
}
/// Shared execution pipeline behind the GET and POST handlers.
///
/// Stages, in order:
///
/// 1. **API-key fallback** — when no security context was supplied and an
///    API-key authenticator is configured, try to authenticate from headers.
/// 2. **Trusted documents** — when a trusted-document store is configured,
///    resolve the request's document id / raw query through it.
/// 3. **APQ** — when `extensions` carries a persisted-query hash, resolve the
///    query text via `resolve_apq` (or require a query body if no APQ store
///    is configured).
/// 4. **Validation** — validate the query and the variables.
/// 5. **Federation circuit breaker** (`federation` feature only).
/// 6. **Tenant routing** — resolve the tenant key and pick its executor.
/// 7. **Execution** — with or without the security context.
/// 8. **Metrics / logging**, then **field decryption** (`secrets` feature only).
///
/// # Errors
///
/// Returns an `ErrorResponse` as soon as any stage fails; see the inline
/// comments for which error each stage produces.
#[tracing::instrument(skip_all, fields(operation_name = request.operation_name.as_deref().unwrap_or("anonymous")))]
async fn execute_graphql_request<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
state: AppState<A>,
mut request: GraphQLRequest,
#[cfg(feature = "federation")] _trace_context: Option<
fraiseql_core::federation::FederationTraceContext,
>,
#[cfg(not(feature = "federation"))] _trace_context: Option<()>,
mut security_context: Option<SecurityContext>,
headers: &HeaderMap,
) -> Result<GraphQLResponse, ErrorResponse> {
// --- Stage 1: API-key fallback -------------------------------------------
// Only attempted when the caller did not already supply a security context.
if security_context.is_none() {
if let Some(ref api_key_auth) = state.api_key_authenticator {
match api_key_auth.authenticate(headers).await {
crate::api_key::ApiKeyResult::Authenticated(ctx) => {
debug!("Authenticated via API key");
security_context = Some(*ctx);
},
// A key was presented but rejected: fail the request outright.
crate::api_key::ApiKeyResult::Invalid => {
return Err(ErrorResponse::from_error(GraphQLError::new(
"Invalid API key",
crate::error::ErrorCode::Unauthenticated,
)));
},
// No key presented: continue without a security context.
crate::api_key::ApiKeyResult::NotPresent => {
},
}
}
}
// --- Stage 2: trusted documents ------------------------------------------
// When a store is configured, whatever it resolves replaces `request.query`.
if let Some(ref td_store) = state.trusted_docs {
let doc_id = extract_document_id(&request);
match td_store.resolve(doc_id.as_deref(), request.query.as_deref()).await {
Ok(resolved) => {
// A hit is only recorded when the request actually named a document.
if doc_id.is_some() {
crate::trusted_documents::record_hit();
debug!(document_id = ?doc_id, "Trusted document resolved");
}
request.query = Some(resolved);
},
Err(crate::trusted_documents::TrustedDocumentError::ForbiddenRawQuery) => {
crate::trusted_documents::record_rejected();
return Err(ErrorResponse::from_error(GraphQLError::forbidden_query()));
},
Err(crate::trusted_documents::TrustedDocumentError::DocumentNotFound { id }) => {
crate::trusted_documents::record_miss();
return Err(ErrorResponse::from_error(GraphQLError::document_not_found(&id)));
},
// The manifest detail is logged server-side; the client only sees a
// generic internal error.
Err(crate::trusted_documents::TrustedDocumentError::ManifestLoad(msg)) => {
error!(error = %msg, "Trusted document manifest error");
return Err(ErrorResponse::from_error(GraphQLError::internal(
"Trusted documents unavailable",
)));
},
}
}
// --- Stage 3: APQ / final query text -------------------------------------
// After this expression `query` is an owned `String`; `request.query` has
// been moved out on the branches that consume it.
let query = if let Some(hash) = extract_apq_hash(request.extensions.as_ref()) {
if let Some(ref store) = state.apq_store {
resolve_apq(store.as_ref(), &state.apq_metrics, hash, request.query.as_deref()).await?
} else {
// Hash supplied but no APQ store configured: a query body is mandatory.
request.query.ok_or_else(|| {
ErrorResponse::from_error(GraphQLError::request(
"APQ is not enabled on this server and no query body was provided",
))
})?
}
} else {
request
.query
.ok_or_else(|| ErrorResponse::from_error(GraphQLError::request("No query provided")))?
};
// The timer starts here, so stages 1-3 are not included in the recorded
// duration. `queries_total` likewise only counts requests that got this far.
let start_time = Instant::now();
let metrics = &state.metrics;
metrics.queries_total.fetch_add(1, Ordering::Relaxed);
info!(
query_length = query.len(),
has_variables = request.variables.is_some(),
operation_name = ?request.operation_name,
"Executing GraphQL query"
);
// --- Stage 4a: query validation ------------------------------------------
let validator = &state.validator;
if let Err(e) = validator.validate_query(&query) {
error!(
error = %e,
operation_name = ?request.operation_name,
"Query validation failed"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
metrics.validation_errors_total.fetch_add(1, Ordering::Relaxed);
// With the `auth` feature, validation failures are fed to the rate
// limiter; once it trips, the rate-limit error replaces the validation error.
#[cfg(feature = "auth")]
{
let client_ip = extract_ip_from_headers(headers);
if state.graphql_rate_limiter.check(&client_ip).is_err() {
return Err(ErrorResponse::from_error(GraphQLError::rate_limited(
"Too many validation errors. Please reduce query complexity and try again.",
)));
}
}
// Translate the validator's error into the client-facing error type.
let graphql_error = match e {
crate::validation::ComplexityValidationError::QueryTooDeep {
max_depth,
actual_depth,
} => GraphQLError::validation(format!(
"Query exceeds maximum depth: {actual_depth} > {max_depth}"
)),
crate::validation::ComplexityValidationError::QueryTooComplex {
max_complexity,
actual_complexity,
} => GraphQLError::validation(format!(
"Query exceeds maximum complexity: {actual_complexity} > {max_complexity}"
)),
// Malformed queries additionally bump the parse-error counter.
crate::validation::ComplexityValidationError::MalformedQuery(msg) => {
metrics.parse_errors_total.fetch_add(1, Ordering::Relaxed);
GraphQLError::parse(msg)
},
crate::validation::ComplexityValidationError::InvalidVariables(msg) => {
GraphQLError::request(msg)
},
crate::validation::ComplexityValidationError::TooManyAliases {
max_aliases,
actual_aliases,
} => GraphQLError::validation(format!(
"Query exceeds maximum alias count: {actual_aliases} > {max_aliases}"
)),
// Any other variant is reported with a generic message.
_ => GraphQLError::validation("Validation error"),
};
return Err(ErrorResponse::from_error(graphql_error));
}
// --- Stage 4b: variables validation --------------------------------------
// Same bookkeeping as above; the error text is forwarded via `to_string()`.
if let Err(e) = validator.validate_variables(request.variables.as_ref()) {
error!(
error = %e,
operation_name = ?request.operation_name,
"Variables validation failed"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
metrics.validation_errors_total.fetch_add(1, Ordering::Relaxed);
#[cfg(feature = "auth")]
{
let client_ip = extract_ip_from_headers(headers);
if state.graphql_rate_limiter.check(&client_ip).is_err() {
return Err(ErrorResponse::from_error(GraphQLError::rate_limited(
"Too many validation errors. Please reduce query complexity and try again.",
)));
}
}
return Err(ErrorResponse::from_error(GraphQLError::request(e.to_string())));
}
// --- Stage 5: federation circuit breaker (pre-check) ---------------------
// For federation queries with a configured breaker, collect the entity types
// from the variables and reject the request if any of them reports a
// retry-after value. The collected types are kept so the outcome can be
// reported back to the breaker after execution.
#[cfg(feature = "federation")]
let cb_entity_types: Vec<String> = if fraiseql_core::federation::is_federation_query(&query) {
if let Some(ref cb_manager) = state.circuit_breaker {
let entity_types = crate::federation::circuit_breaker::extract_entity_types(
request.variables.as_ref(),
);
for entity_type in &entity_types {
if let Some(retry_after) = cb_manager.check(entity_type) {
warn!(
entity = %entity_type,
retry_after_secs = retry_after,
"Federation circuit breaker open — rejecting _entities request"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
return Err(ErrorResponse::from_error(GraphQLError::circuit_breaker_open(
entity_type,
retry_after,
)));
}
}
entity_types
} else {
vec![]
}
} else {
vec![]
};
// Without the `federation` feature this is an unused empty placeholder.
#[cfg(not(feature = "federation"))]
let _cb_entity_types: Vec<String> = vec![];
// --- Stage 6: tenant routing ---------------------------------------------
// A tenant-key resolution failure is reported as a validation error ...
let tenant_key = super::TenantKeyResolver::resolve(
security_context.as_ref(),
headers,
state.domain_registry(),
)
.map_err(|e| {
ErrorResponse::from_error(GraphQLError::new(
e.to_string(),
crate::error::ErrorCode::ValidationError,
))
})?;
// ... whereas failing to obtain an executor for that tenant is `Forbidden`.
let executor = state.executor_for_tenant(tenant_key.as_deref()).map_err(|e| {
ErrorResponse::from_error(GraphQLError::new(
e.to_string(),
crate::error::ErrorCode::Forbidden,
))
})?;
// --- Stage 7: execution --------------------------------------------------
// `security_context` is consumed here.
let exec_result = if let Some(sec_ctx) = security_context {
executor
.execute_with_security(&query, request.variables.as_ref(), &sec_ctx)
.await
} else {
executor.execute(&query, request.variables.as_ref()).await
};
// Report the outcome to the circuit breaker for every entity type that was
// checked in stage 5.
#[cfg(feature = "federation")]
if !cb_entity_types.is_empty() {
if let Some(ref cb_manager) = state.circuit_breaker {
if exec_result.is_ok() {
for entity_type in &cb_entity_types {
cb_manager.record_success(entity_type);
}
} else {
for entity_type in &cb_entity_types {
cb_manager.record_failure(entity_type);
}
}
}
}
// --- Stage 8: metrics and logging ----------------------------------------
// Per-operation metrics use "" for unnamed operations (the tracing span on
// this function uses "anonymous" instead).
let op_name = request.operation_name.as_deref().unwrap_or("");
let result = exec_result.map_err(|e| {
let elapsed = start_time.elapsed();
// `as_micros` yields a u128; the cast truncates only for durations far
// beyond any realistic request time.
#[allow(clippy::cast_possible_truncation)]
let elapsed_us = elapsed.as_micros() as u64;
error!(
error = %e,
elapsed_ms = elapsed.as_millis(),
operation_name = ?request.operation_name,
"Query execution failed"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
metrics.execution_errors_total.fetch_add(1, Ordering::Relaxed);
metrics.queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
metrics.operation_metrics.record(op_name, elapsed_us, true);
// The error passes through the configured sanitizer before being returned.
let err = state.error_sanitizer.sanitize(GraphQLError::from_fraiseql_error(&e));
ErrorResponse::from_error(err)
})?;
let elapsed = start_time.elapsed();
#[allow(clippy::cast_possible_truncation)]
let elapsed_us = elapsed.as_micros() as u64;
metrics.queries_success.fetch_add(1, Ordering::Relaxed);
metrics.queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
// NOTE(review): each successful request is counted as exactly one DB query
// and the whole elapsed time is attributed to DB duration — presumably an
// approximation; confirm this is intended.
metrics.db_queries_total.fetch_add(1, Ordering::Relaxed);
metrics.db_queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
metrics.operation_metrics.record(op_name, elapsed_us, false);
#[cfg(feature = "federation")]
if fraiseql_core::federation::is_federation_query(&query) {
metrics.record_entity_resolution(elapsed_us, true);
}
debug!(
elapsed_ms = elapsed.as_millis(),
operation_name = ?request.operation_name,
"Query executed successfully"
);
// `mut` is only needed when the `secrets` feature is enabled.
#[allow(unused_mut)]
let mut response_json = result;
// --- Field decryption (`secrets` feature only) ---------------------------
// Runs after the success metrics above have been recorded, so a decryption
// failure still leaves the request counted as successful.
#[cfg(feature = "secrets")]
if let Some(ref encryption) = state.field_encryption {
if encryption.has_encrypted_fields() {
encryption.decrypt_response(&mut response_json).await.map_err(|e| {
error!(error = %e, "Field decryption failed");
let err = state
.error_sanitizer
.sanitize(GraphQLError::internal("Field decryption failed".to_string()));
ErrorResponse::from_error(err)
})?;
}
}
Ok(GraphQLResponse {
body: response_json,
})
}