Skip to main content

fraiseql_server/routes/graphql/
handler.rs

1//! GraphQL HTTP handlers and execution logic.
2
3use std::{sync::atomic::Ordering, time::Instant};
4
5use axum::{
6    Json,
7    extract::{Query, State},
8    http::HeaderMap,
9};
10use fraiseql_core::{
11    apq::{ApqMetrics, ApqStorage},
12    db::traits::DatabaseAdapter,
13    security::SecurityContext,
14};
15use tracing::{debug, error, info, warn};
16
17use super::{
18    app_state::AppState,
19    request::{GraphQLGetParams, GraphQLRequest, GraphQLResponse},
20};
21use crate::{
22    error::{ErrorResponse, GraphQLError},
23    extractors::OptionalSecurityContext,
24    tracing_utils,
25};
26
27/// GraphQL HTTP handler for POST requests.
28///
29/// Handles POST requests to the GraphQL endpoint:
30/// 1. Extract W3C trace context from traceparent header (if present)
31/// 2. Validate GraphQL request (depth, complexity)
32/// 3. Parse GraphQL request body
33/// 4. Execute query via Executor with optional `SecurityContext`
34/// 5. Return GraphQL response with proper error formatting
35///
36/// Tracks execution timing and operation name for monitoring.
37/// Provides GraphQL spec-compliant error responses.
38/// Supports W3C Trace Context for distributed tracing.
39/// Supports OIDC authentication for RLS policy evaluation.
40///
41/// # Errors
42///
43/// Returns appropriate HTTP status codes based on error type.
44#[tracing::instrument(skip_all, fields(operation_name))]
45pub async fn graphql_handler<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
46    State(state): State<AppState<A>>,
47    headers: HeaderMap,
48    OptionalSecurityContext(security_context): OptionalSecurityContext,
49    Json(request): Json<GraphQLRequest>,
50) -> Result<GraphQLResponse, ErrorResponse> {
51    // Extract trace context from W3C headers
52    let trace_context = tracing_utils::extract_trace_context(&headers);
53    if trace_context.is_some() {
54        debug!("Extracted W3C trace context from incoming request");
55    }
56
57    if security_context.is_some() {
58        debug!("Authenticated request with security context");
59    }
60
61    execute_graphql_request(state, request, trace_context, security_context, &headers).await
62}
63
64/// GraphQL HTTP handler for GET requests.
65///
66/// Handles GET requests to the GraphQL endpoint per the GraphQL over HTTP spec.
67/// Query parameters:
68/// - `query`: Required, the GraphQL query string (URL-encoded)
69/// - `variables`: Optional, JSON-encoded variables object (URL-encoded)
70/// - `operationName`: Optional, name of the operation to execute
71///
72/// Supports W3C Trace Context via traceparent header for distributed tracing.
73///
74/// Example:
75/// ```text
76/// GET /graphql?query={users{id,name}}&variables={"limit":10}
77/// ```
78///
79/// # Errors
80///
81/// Returns `413 Payload Too Large` (via `ErrorResponse`) when the query string
82/// exceeds `AppState::max_get_query_bytes` (default 100 `KiB`, configurable via
83/// `ServerConfig::max_get_query_bytes`). Returns other HTTP status codes for
84/// additional error conditions.
85///
86/// # Note
87///
88/// Per GraphQL over HTTP spec, GET requests should only be used for queries,
89/// not mutations (which should use POST). This handler does not enforce that
90/// restriction but logs a warning for mutation-like queries.
91#[tracing::instrument(skip_all, fields(operation_name))]
92pub async fn graphql_get_handler<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
93    State(state): State<AppState<A>>,
94    headers: HeaderMap,
95    OptionalSecurityContext(security_context): OptionalSecurityContext,
96    Query(params): Query<GraphQLGetParams>,
97) -> Result<GraphQLResponse, ErrorResponse> {
98    // Reject oversized GET queries early to prevent DoS via query parsing.
99    let max_get_bytes = state.max_get_query_bytes;
100    if params.query.len() > max_get_bytes {
101        return Err(ErrorResponse::from_error(GraphQLError::request(format!(
102            "GET query string exceeds maximum allowed length ({max_get_bytes} bytes)"
103        ))));
104    }
105
106    // Parse variables from JSON string.
107    // Apply the same size cap as the query string — the URL-length limit imposed
108    // by reverse proxies/OS is real but not enforced by axum itself, so we guard
109    // explicitly to prevent parser DoS from a very large variables value.
110    let variables = if let Some(vars_str) = params.variables {
111        if vars_str.len() > max_get_bytes {
112            return Err(ErrorResponse::from_error(GraphQLError::request(format!(
113                "GET variables string exceeds maximum allowed length ({max_get_bytes} bytes)"
114            ))));
115        }
116        match serde_json::from_str::<serde_json::Value>(&vars_str) {
117            Ok(v) => Some(v),
118            Err(e) => {
119                warn!(
120                    error = %e,
121                    variables = %vars_str,
122                    "Failed to parse variables JSON in GET request"
123                );
124                return Err(ErrorResponse::from_error(GraphQLError::request(format!(
125                    "Invalid variables JSON: {e}"
126                ))));
127            },
128        }
129    } else {
130        None
131    };
132
133    // Warn if this looks like a mutation (GET should be for queries only)
134    if params.query.trim_start().starts_with("mutation") {
135        warn!(
136            operation_name = ?params.operation_name,
137            "Mutation sent via GET request - should use POST"
138        );
139    }
140
141    let trace_context = tracing_utils::extract_trace_context(&headers);
142    if trace_context.is_some() {
143        debug!("Extracted W3C trace context from incoming request");
144    }
145
146    let request = GraphQLRequest {
147        query: Some(params.query),
148        variables,
149        operation_name: params.operation_name,
150        extensions: None,
151        document_id: None,
152    };
153
154    if security_context.is_some() {
155        debug!("Authenticated GET request with security context");
156    }
157
158    execute_graphql_request(state, request, trace_context, security_context, &headers).await
159}
160
/// Produce the client-IP string for a request.
///
/// The headers argument is intentionally ignored and the literal `"unknown"` is
/// always returned.
///
/// # Security
///
/// `X-Forwarded-For` and `X-Real-IP` are deliberately NOT consulted: any client can
/// set them, so trusting them would let an attacker sidestep rate limiting. Callers
/// that need a real address should use `ConnectInfo<SocketAddr>` or
/// `ProxyConfig::extract_client_ip()` with a validated proxy chain instead.
#[cfg(feature = "auth")]
pub(crate) fn extract_ip_from_headers(_headers: &HeaderMap) -> String {
    // SECURITY: spoofable headers are never read here; the constant is the safe fallback.
    String::from("unknown")
}
175
176/// Extract the APQ SHA-256 hash from the `extensions.persistedQuery` field, if present.
177pub(crate) fn extract_apq_hash(extensions: Option<&serde_json::Value>) -> Option<&str> {
178    extensions?.get("persistedQuery")?.get("sha256Hash")?.as_str()
179}
180
181/// Extract a trusted document ID from the request.
182///
183/// Supports three formats:
184/// 1. `documentId` (GraphQL over HTTP spec)
185/// 2. `extensions.persistedQuery.sha256Hash` (Apollo APQ format)
186/// 3. `extensions.doc_id` (Relay format)
187fn extract_document_id(request: &GraphQLRequest) -> Option<String> {
188    // 1. Top-level documentId field (GraphQL over HTTP spec)
189    if let Some(ref doc_id) = request.document_id {
190        return Some(doc_id.clone());
191    }
192    // 2. Extensions-based formats
193    if let Some(ext) = request.extensions.as_ref() {
194        // Relay format: extensions.doc_id
195        if let Some(doc_id) = ext.get("doc_id").and_then(|v| v.as_str()) {
196            return Some(doc_id.to_string());
197        }
198        // Apollo APQ format: extensions.persistedQuery.sha256Hash (also used for APQ)
199        if let Some(hash) = ext
200            .get("persistedQuery")
201            .and_then(|pq| pq.get("sha256Hash"))
202            .and_then(|h| h.as_str())
203        {
204            return Some(hash.to_string());
205        }
206    }
207    None
208}
209
210/// Resolve an APQ request: look up or register a persisted query.
211///
212/// Returns the resolved query body, or an error if the query is not found and no body was
213/// provided (the client should resend with the full body).
214///
215/// # Errors
216///
217/// Returns [`ErrorResponse`] if the hash doesn't match the body, or if the
218/// hash is unknown and no query body was provided (client must retry with full body).
219pub(crate) async fn resolve_apq(
220    apq_store: &dyn ApqStorage,
221    apq_metrics: &ApqMetrics,
222    hash: &str,
223    query_body: Option<&str>,
224) -> Result<String, ErrorResponse> {
225    if let Some(body) = query_body {
226        // Hash + body present: verify and register.
227        if !fraiseql_core::apq::verify_hash(body, hash) {
228            apq_metrics.record_error();
229            return Err(ErrorResponse::from_error(GraphQLError::persisted_query_mismatch()));
230        }
231        // Store the query (best-effort; log on failure).
232        if let Err(e) = apq_store.set(hash.to_owned(), body.to_owned()).await {
233            warn!(error = %e, "Failed to store APQ query — proceeding without caching");
234            apq_metrics.record_error();
235        } else {
236            apq_metrics.record_store();
237        }
238        Ok(body.to_owned())
239    } else {
240        // Hash only: look up.
241        match apq_store.get(hash).await {
242            Ok(Some(stored)) => {
243                apq_metrics.record_hit();
244                Ok(stored)
245            },
246            Ok(None) => {
247                apq_metrics.record_miss();
248                Err(ErrorResponse::from_error(GraphQLError::persisted_query_not_found()))
249            },
250            Err(e) => {
251                warn!(error = %e, "APQ store lookup failed — treating as miss");
252                apq_metrics.record_error();
253                Err(ErrorResponse::from_error(GraphQLError::persisted_query_not_found()))
254            },
255        }
256    }
257}
258
/// Shared GraphQL execution logic for both GET and POST handlers.
///
/// Stages, in order:
/// 1. **API key authentication** — only attempted when no security context was
///    supplied and an API key authenticator is configured.
/// 2. **Trusted documents** — when a trusted document store is configured, the
///    request's document ID / raw query is resolved through it and the resolved
///    body replaces `request.query`.
/// 3. **APQ** — when the request carries a persisted-query hash, the query body is
///    resolved via the APQ store (or taken from the request when no store exists).
/// 4. **Validation** — the query and then the variables are validated; failures
///    update the error counters and, with the `auth` feature, consult the rate limiter.
/// 5. **Federation circuit breaker** (`federation` feature) — `_entities` queries
///    are rejected while a breaker for one of their entity types is open.
/// 6. **Tenant resolution and execution** — the tenant key selects the executor,
///    which runs the query with or without the security context.
/// 7. **Bookkeeping** — circuit breaker outcome, counters, and durations are recorded.
/// 8. **Field decryption** (`secrets` feature) — applied to the response when
///    field encryption is configured and has encrypted fields.
///
/// `_trace_context` is accepted for both feature configurations but is not read
/// anywhere in this function.
///
/// # Errors
///
/// Returns [`ErrorResponse`] for an invalid API key, a rejected or unknown trusted
/// document, an APQ mismatch/miss, a missing query body, validation failures, rate
/// limiting, an open circuit breaker, tenant/executor resolution failures, execution
/// failures (sanitized), and field decryption failures.
#[tracing::instrument(skip_all, fields(operation_name = request.operation_name.as_deref().unwrap_or("anonymous")))]
async fn execute_graphql_request<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
    state: AppState<A>,
    mut request: GraphQLRequest,
    #[cfg(feature = "federation")] _trace_context: Option<
        fraiseql_core::federation::FederationTraceContext,
    >,
    #[cfg(not(feature = "federation"))] _trace_context: Option<()>,
    mut security_context: Option<SecurityContext>,
    headers: &HeaderMap,
) -> Result<GraphQLResponse, ErrorResponse> {
    // API key auth: if configured, try it before falling through to JWT/OIDC.
    // Skipped entirely when the extractor already produced a security context.
    if security_context.is_none() {
        if let Some(ref api_key_auth) = state.api_key_authenticator {
            match api_key_auth.authenticate(headers).await {
                crate::api_key::ApiKeyResult::Authenticated(ctx) => {
                    debug!("Authenticated via API key");
                    security_context = Some(*ctx);
                },
                // A key was presented but rejected: fail the request rather than
                // silently continuing unauthenticated.
                crate::api_key::ApiKeyResult::Invalid => {
                    return Err(ErrorResponse::from_error(GraphQLError::new(
                        "Invalid API key",
                        crate::error::ErrorCode::Unauthenticated,
                    )));
                },
                crate::api_key::ApiKeyResult::NotPresent => {
                    // Fall through to JWT/OIDC (or unauthenticated).
                },
            }
        }
    }

    // Resolve query body — trusted documents take priority over APQ.
    // If a trusted document store is configured, resolve the document ID first.
    if let Some(ref td_store) = state.trusted_docs {
        let doc_id = extract_document_id(&request);
        match td_store.resolve(doc_id.as_deref(), request.query.as_deref()).await {
            Ok(resolved) => {
                // A hit is only recorded when the request actually named a document.
                if doc_id.is_some() {
                    crate::trusted_documents::record_hit();
                    debug!(document_id = ?doc_id, "Trusted document resolved");
                }
                // Replace the query with the resolved body so APQ and execution use it.
                request.query = Some(resolved);
            },
            Err(crate::trusted_documents::TrustedDocumentError::ForbiddenRawQuery) => {
                crate::trusted_documents::record_rejected();
                return Err(ErrorResponse::from_error(GraphQLError::forbidden_query()));
            },
            Err(crate::trusted_documents::TrustedDocumentError::DocumentNotFound { id }) => {
                crate::trusted_documents::record_miss();
                return Err(ErrorResponse::from_error(GraphQLError::document_not_found(&id)));
            },
            Err(crate::trusted_documents::TrustedDocumentError::ManifestLoad(msg)) => {
                // The detailed message is logged; the client only sees a generic error.
                error!(error = %msg, "Trusted document manifest error");
                return Err(ErrorResponse::from_error(GraphQLError::internal(
                    "Trusted documents unavailable",
                )));
            },
        }
    }

    // Resolve query body — either from APQ or from the request payload.
    // NOTE(review): after a trusted-document resolution `request.query` is always
    // `Some`, so an APQ hash in the same request goes down the "hash + body" path.
    let query = if let Some(hash) = extract_apq_hash(request.extensions.as_ref()) {
        if let Some(ref store) = state.apq_store {
            resolve_apq(store.as_ref(), &state.apq_metrics, hash, request.query.as_deref()).await?
        } else {
            // APQ extension present but no store configured — use the body if available.
            request.query.ok_or_else(|| {
                ErrorResponse::from_error(GraphQLError::request(
                    "APQ is not enabled on this server and no query body was provided",
                ))
            })?
        }
    } else {
        request
            .query
            .ok_or_else(|| ErrorResponse::from_error(GraphQLError::request("No query provided")))?
    };

    // Timing starts here, so the durations recorded below include validation but
    // exclude the authentication and query-resolution stages above.
    let start_time = Instant::now();
    let metrics = &state.metrics;

    // Increment total queries counter.
    // Requests that failed in the stages above returned early and are not counted.
    metrics.queries_total.fetch_add(1, Ordering::Relaxed);

    info!(
        query_length = query.len(),
        has_variables = request.variables.is_some(),
        operation_name = ?request.operation_name,
        "Executing GraphQL query"
    );

    // Validate request
    let validator = &state.validator;

    // Validate query
    if let Err(e) = validator.validate_query(&query) {
        error!(
            error = %e,
            operation_name = ?request.operation_name,
            "Query validation failed"
        );
        metrics.queries_error.fetch_add(1, Ordering::Relaxed);
        metrics.validation_errors_total.fetch_add(1, Ordering::Relaxed);

        // Check rate limiting for validation errors.
        // NOTE(review): `extract_ip_from_headers` in this file always returns
        // "unknown", so every client shares a single rate-limit key here.
        #[cfg(feature = "auth")]
        {
            let client_ip = extract_ip_from_headers(headers);
            if state.graphql_rate_limiter.check(&client_ip).is_err() {
                return Err(ErrorResponse::from_error(GraphQLError::rate_limited(
                    "Too many validation errors. Please reduce query complexity and try again.",
                )));
            }
        }

        // Map each validation failure onto the GraphQL error category reported to
        // the client.
        let graphql_error = match e {
            crate::validation::ComplexityValidationError::QueryTooDeep {
                max_depth,
                actual_depth,
            } => GraphQLError::validation(format!(
                "Query exceeds maximum depth: {actual_depth} > {max_depth}"
            )),
            crate::validation::ComplexityValidationError::QueryTooComplex {
                max_complexity,
                actual_complexity,
            } => GraphQLError::validation(format!(
                "Query exceeds maximum complexity: {actual_complexity} > {max_complexity}"
            )),
            crate::validation::ComplexityValidationError::MalformedQuery(msg) => {
                // Malformed queries count as parse errors in addition to the
                // validation-error counter incremented above.
                metrics.parse_errors_total.fetch_add(1, Ordering::Relaxed);
                GraphQLError::parse(msg)
            },
            crate::validation::ComplexityValidationError::InvalidVariables(msg) => {
                GraphQLError::request(msg)
            },
            crate::validation::ComplexityValidationError::TooManyAliases {
                max_aliases,
                actual_aliases,
            } => GraphQLError::validation(format!(
                "Query exceeds maximum alias count: {actual_aliases} > {max_aliases}"
            )),
            // Reason: non_exhaustive requires catch-all for cross-crate matches
            _ => GraphQLError::validation("Validation error"),
        };
        return Err(ErrorResponse::from_error(graphql_error));
    }

    // Validate variables
    if let Err(e) = validator.validate_variables(request.variables.as_ref()) {
        error!(
            error = %e,
            operation_name = ?request.operation_name,
            "Variables validation failed"
        );
        metrics.queries_error.fetch_add(1, Ordering::Relaxed);
        metrics.validation_errors_total.fetch_add(1, Ordering::Relaxed);

        // Check rate limiting for validation errors (same shared key as above).
        #[cfg(feature = "auth")]
        {
            let client_ip = extract_ip_from_headers(headers);
            if state.graphql_rate_limiter.check(&client_ip).is_err() {
                return Err(ErrorResponse::from_error(GraphQLError::rate_limited(
                    "Too many validation errors. Please reduce query complexity and try again.",
                )));
            }
        }

        return Err(ErrorResponse::from_error(GraphQLError::request(e.to_string())));
    }

    // Check federation circuit breaker for _entities queries before execution.
    // The entity types are kept so the execution outcome can be reported back to
    // the breaker afterwards; the list is empty for non-federation queries or
    // when no circuit breaker is configured.
    #[cfg(feature = "federation")]
    let cb_entity_types: Vec<String> = if fraiseql_core::federation::is_federation_query(&query) {
        if let Some(ref cb_manager) = state.circuit_breaker {
            let entity_types = crate::federation::circuit_breaker::extract_entity_types(
                request.variables.as_ref(),
            );
            for entity_type in &entity_types {
                // The first entity type with an open breaker rejects the whole request.
                if let Some(retry_after) = cb_manager.check(entity_type) {
                    warn!(
                        entity = %entity_type,
                        retry_after_secs = retry_after,
                        "Federation circuit breaker open — rejecting _entities request"
                    );
                    metrics.queries_error.fetch_add(1, Ordering::Relaxed);
                    return Err(ErrorResponse::from_error(GraphQLError::circuit_breaker_open(
                        entity_type,
                        retry_after,
                    )));
                }
            }
            entity_types
        } else {
            vec![]
        }
    } else {
        vec![]
    };
    // Without the federation feature the binding is an unused placeholder.
    #[cfg(not(feature = "federation"))]
    let _cb_entity_types: Vec<String> = vec![];

    // Resolve tenant key from JWT / X-Tenant-ID header / Host header.
    // NOTE(review): a failure here (and in the executor lookup below) returns
    // without touching `queries_error` or the circuit breaker.
    let tenant_key = super::TenantKeyResolver::resolve(
        security_context.as_ref(),
        headers,
        state.domain_registry(),
    )
    .map_err(|e| {
        ErrorResponse::from_error(GraphQLError::new(
            e.to_string(),
            crate::error::ErrorCode::ValidationError,
        ))
    })?;

    // Execute query (defer error propagation to record circuit breaker outcome first)
    let executor = state.executor_for_tenant(tenant_key.as_deref()).map_err(|e| {
        ErrorResponse::from_error(GraphQLError::new(
            e.to_string(),
            crate::error::ErrorCode::Forbidden,
        ))
    })?;
    // The security context is consumed here; it selects the security-aware
    // execution entry point when present.
    let exec_result = if let Some(sec_ctx) = security_context {
        executor
            .execute_with_security(&query, request.variables.as_ref(), &sec_ctx)
            .await
    } else {
        executor.execute(&query, request.variables.as_ref()).await
    };

    // Record circuit breaker outcome for federation entity queries.
    // Every entity type named in the request receives the same outcome.
    #[cfg(feature = "federation")]
    if !cb_entity_types.is_empty() {
        if let Some(ref cb_manager) = state.circuit_breaker {
            if exec_result.is_ok() {
                for entity_type in &cb_entity_types {
                    cb_manager.record_success(entity_type);
                }
            } else {
                for entity_type in &cb_entity_types {
                    cb_manager.record_failure(entity_type);
                }
            }
        }
    }

    // Propagate execution errors with metrics.
    // Anonymous operations are recorded under the empty operation name.
    let op_name = request.operation_name.as_deref().unwrap_or("");
    let result = exec_result.map_err(|e| {
        let elapsed = start_time.elapsed();
        #[allow(clippy::cast_possible_truncation)]
        // Reason: microsecond counter cannot exceed u64 in any practical uptime
        let elapsed_us = elapsed.as_micros() as u64;
        error!(
            error = %e,
            elapsed_ms = elapsed.as_millis(),
            operation_name = ?request.operation_name,
            "Query execution failed"
        );
        metrics.queries_error.fetch_add(1, Ordering::Relaxed);
        metrics.execution_errors_total.fetch_add(1, Ordering::Relaxed);
        // Record duration even for failed queries
        metrics.queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
        metrics.operation_metrics.record(op_name, elapsed_us, true);
        // The error is passed through the sanitizer before it reaches the client.
        let err = state.error_sanitizer.sanitize(GraphQLError::from_fraiseql_error(&e));
        ErrorResponse::from_error(err)
    })?;

    let elapsed = start_time.elapsed();
    #[allow(clippy::cast_possible_truncation)]
    // Reason: microsecond counter cannot exceed u64 in any practical uptime
    let elapsed_us = elapsed.as_micros() as u64;

    // Record successful query metrics.
    // NOTE(review): the DB counters are bumped once per request and reuse the
    // total elapsed time (validation included), not a separately measured DB time.
    metrics.queries_success.fetch_add(1, Ordering::Relaxed);
    metrics.queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
    metrics.db_queries_total.fetch_add(1, Ordering::Relaxed);
    metrics.db_queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
    metrics.operation_metrics.record(op_name, elapsed_us, false);

    // Record federation-specific metrics for federation queries
    #[cfg(feature = "federation")]
    if fraiseql_core::federation::is_federation_query(&query) {
        metrics.record_entity_resolution(elapsed_us, true);
    }

    debug!(
        elapsed_ms = elapsed.as_millis(),
        operation_name = ?request.operation_name,
        "Query executed successfully"
    );

    #[allow(unused_mut)]
    // Reason: mut is required by decrypt_response(&mut ...) when the secrets feature is enabled
    let mut response_json = result;

    // Decrypt encrypted fields if field encryption is configured.
    // This runs after the success metrics above, so a decryption failure is
    // returned as an error without adjusting those counters.
    #[cfg(feature = "secrets")]
    if let Some(ref encryption) = state.field_encryption {
        if encryption.has_encrypted_fields() {
            encryption.decrypt_response(&mut response_json).await.map_err(|e| {
                error!(error = %e, "Field decryption failed");
                let err = state
                    .error_sanitizer
                    .sanitize(GraphQLError::internal("Field decryption failed".to_string()));
                ErrorResponse::from_error(err)
            })?;
        }
    }

    Ok(GraphQLResponse {
        body: response_json,
    })
}