use std::{
sync::{Arc, atomic::Ordering},
time::Instant,
};
use axum::{
Json,
extract::{Query, State},
http::HeaderMap,
response::{IntoResponse, Response},
};
use fraiseql_core::{db::traits::DatabaseAdapter, runtime::Executor, security::SecurityContext};
use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, warn};
use crate::{
auth::rate_limiting::{KeyedRateLimiter, RateLimitConfig},
error::{ErrorResponse, GraphQLError},
extractors::OptionalSecurityContext,
metrics_server::MetricsCollector,
tracing_utils,
validation::RequestValidator,
};
/// A GraphQL request as sent in a POST body.
///
/// Keys are camelCase on the wire (`operationName`). `variables` and
/// `operationName` may be omitted, in which case they are `None`.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GraphQLRequest {
    /// The GraphQL document to execute.
    pub query: String,
    /// Variable values for the operation, as arbitrary JSON.
    #[serde(default)]
    pub variables: Option<serde_json::Value>,
    /// Name of the operation to run (`operationName` on the wire).
    #[serde(default)]
    pub operation_name: Option<String>,
}
/// GraphQL request parameters carried in the URL query string of a GET request.
///
/// Unlike [`GraphQLRequest`], `variables` arrives as a JSON-encoded *string*;
/// `graphql_get_handler` decodes it before execution.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GraphQLGetParams {
    /// The GraphQL document to execute.
    pub query: String,
    /// JSON-encoded variables object, still undecoded.
    #[serde(default)]
    pub variables: Option<String>,
    /// Name of the operation to run (`operationName` in the query string).
    #[serde(default)]
    pub operation_name: Option<String>,
}
/// A GraphQL response body produced by the executor.
///
/// Serializes as exactly the wrapped JSON value. `transparent` is used instead
/// of `flatten` because flattening a bare `serde_json::Value` only works when
/// the value is a JSON object and fails for anything else. It also keeps the
/// derived `Serialize` output identical to what `IntoResponse` emits, since
/// that impl serializes `body` directly.
#[derive(Debug, Serialize)]
#[serde(transparent)]
pub struct GraphQLResponse {
    /// The response JSON (typically `{"data": …, "errors": …}`).
    pub body: serde_json::Value,
}
impl IntoResponse for GraphQLResponse {
    /// Emits the wrapped JSON value as the HTTP response via axum's [`Json`].
    fn into_response(self) -> Response {
        let Self { body } = self;
        let payload = Json(body);
        payload.into_response()
    }
}
/// Shared state handed to every GraphQL route.
///
/// All members are reference-counted, so cloning the state (as axum does per
/// request) only bumps reference counts.
#[derive(Clone)]
pub struct AppState<A: DatabaseAdapter> {
    /// Query executor backed by the database adapter `A`.
    pub executor: Arc<Executor<A>>,
    /// Request/latency/error counters updated by the request pipeline.
    pub metrics: Arc<MetricsCollector>,
    /// Optional query cache; `None` unless built via a `with_cache*` constructor.
    pub cache: Option<Arc<fraiseql_arrow::cache::QueryCache>>,
    /// Optional server configuration; `None` unless built via `with_cache_and_config`.
    pub config: Option<Arc<crate::config::ServerConfig>>,
    /// Keyed rate limiter consulted when a GraphQL request fails validation.
    pub graphql_rate_limiter: Arc<KeyedRateLimiter>,
}
impl<A: DatabaseAdapter> AppState<A> {
    /// Single place where an `AppState` is assembled.
    ///
    /// Every public constructor delegates here, so they all get the same
    /// default GraphQL rate limiter (`RateLimitConfig::per_ip_standard()`).
    fn from_parts(
        executor: Arc<Executor<A>>,
        metrics: Arc<MetricsCollector>,
        cache: Option<Arc<fraiseql_arrow::cache::QueryCache>>,
        config: Option<Arc<crate::config::ServerConfig>>,
    ) -> Self {
        Self {
            executor,
            metrics,
            cache,
            config,
            graphql_rate_limiter: Arc::new(KeyedRateLimiter::new(
                RateLimitConfig::per_ip_standard(),
            )),
        }
    }

    /// Creates state with fresh metrics, no cache and no server config.
    #[must_use]
    pub fn new(executor: Arc<Executor<A>>) -> Self {
        Self::from_parts(executor, Arc::new(MetricsCollector::new()), None, None)
    }

    /// Creates state that reports into an existing metrics collector.
    #[must_use]
    pub fn with_metrics(executor: Arc<Executor<A>>, metrics: Arc<MetricsCollector>) -> Self {
        Self::from_parts(executor, metrics, None, None)
    }

    /// Creates state with a query cache and fresh metrics.
    #[must_use]
    pub fn with_cache(
        executor: Arc<Executor<A>>,
        cache: Arc<fraiseql_arrow::cache::QueryCache>,
    ) -> Self {
        Self::from_parts(executor, Arc::new(MetricsCollector::new()), Some(cache), None)
    }

    /// Creates state with both a query cache and a server configuration.
    #[must_use]
    pub fn with_cache_and_config(
        executor: Arc<Executor<A>>,
        cache: Arc<fraiseql_arrow::cache::QueryCache>,
        config: Arc<crate::config::ServerConfig>,
    ) -> Self {
        Self::from_parts(
            executor,
            Arc::new(MetricsCollector::new()),
            Some(cache),
            Some(config),
        )
    }

    /// Returns the query cache, if one was configured.
    #[must_use]
    pub fn cache(&self) -> Option<&Arc<fraiseql_arrow::cache::QueryCache>> {
        self.cache.as_ref()
    }

    /// Returns the raw server configuration, if one was configured.
    #[must_use]
    pub fn server_config(&self) -> Option<&Arc<crate::config::ServerConfig>> {
        self.config.as_ref()
    }

    /// Returns a sanitized view of the server configuration, suitable for
    /// exposure through API endpoints, or `None` when no config is set.
    #[must_use]
    pub fn sanitized_config(&self) -> Option<crate::routes::api::types::SanitizedConfig> {
        self.config
            .as_ref()
            .map(|cfg| crate::routes::api::types::SanitizedConfig::from_config(cfg))
    }
}
/// GraphQL POST handler.
///
/// Takes the request from the JSON body and an optional security context from
/// the request. It extracts any W3C trace context from the headers and hands
/// everything to `execute_graphql_request`.
pub async fn graphql_handler<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
    State(state): State<AppState<A>>,
    headers: HeaderMap,
    OptionalSecurityContext(security_context): OptionalSecurityContext,
    Json(request): Json<GraphQLRequest>,
) -> Result<GraphQLResponse, ErrorResponse> {
    let trace_context = tracing_utils::extract_trace_context(&headers);
    let has_trace_context = trace_context.is_some();
    let is_authenticated = security_context.is_some();

    if has_trace_context {
        debug!("Extracted W3C trace context from incoming request");
    }
    if is_authenticated {
        debug!("Authenticated request with security context");
    }

    execute_graphql_request(state, request, trace_context, security_context, &headers).await
}
/// GraphQL GET handler.
///
/// Reads `query`, `variables` and `operationName` from the URL query string and
/// decodes `variables` from its JSON-string form. It then forwards the request
/// to `execute_graphql_request`. No security context is extracted for GET
/// requests, so they always execute without one.
///
/// Mutations sent over GET are logged with a warning but are still executed.
///
/// # Errors
///
/// Returns a request error when `variables` is present but is not valid JSON.
/// Otherwise it returns whatever error `execute_graphql_request` produces.
pub async fn graphql_get_handler<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
    State(state): State<AppState<A>>,
    headers: HeaderMap,
    Query(params): Query<GraphQLGetParams>,
) -> Result<GraphQLResponse, ErrorResponse> {
    let variables = match params.variables.as_deref() {
        None => None,
        Some(vars_str) => match serde_json::from_str::<serde_json::Value>(vars_str) {
            Ok(v) => Some(v),
            Err(e) => {
                // Log only the size of the rejected payload: variables frequently
                // carry credentials or personal data that must not land in logs.
                warn!(
                    error = %e,
                    variables_length = vars_str.len(),
                    "Failed to parse variables JSON in GET request"
                );
                return Err(ErrorResponse::from_error(GraphQLError::request(format!(
                    "Invalid variables JSON: {e}"
                ))));
            },
        },
    };

    if looks_like_mutation(&params.query) {
        warn!(
            operation_name = ?params.operation_name,
            "Mutation sent via GET request - should use POST"
        );
    }

    let trace_context = tracing_utils::extract_trace_context(&headers);
    if trace_context.is_some() {
        debug!("Extracted W3C trace context from incoming request");
    }

    let request = GraphQLRequest {
        query: params.query,
        variables,
        operation_name: params.operation_name,
    };

    execute_graphql_request(state, request, trace_context, None, &headers).await
}

/// Heuristic: does the first definition in `query` start with `mutation`?
///
/// Skips what GraphQL treats as ignored tokens before the first definition:
/// whitespace, commas, a byte-order mark and `#` line comments. Only the first
/// definition is inspected. The result drives a log warning, not access control.
fn looks_like_mutation(query: &str) -> bool {
    let mut rest = query;
    loop {
        rest = rest.trim_start_matches(|c: char| c.is_whitespace() || c == ',' || c == '\u{feff}');
        match rest.strip_prefix('#') {
            // Drop the comment up to (and including) the next line terminator.
            Some(comment) => {
                rest = comment
                    .split_once(|c: char| c == '\n' || c == '\r')
                    .map_or("", |(_, tail)| tail);
            },
            None => break,
        }
    }
    rest.starts_with("mutation")
}
/// Returns the key used to bucket clients in `graphql_rate_limiter`.
///
/// Always yields `"unknown"`. Forwarding headers such as `X-Forwarded-For` and
/// `X-Real-IP` are supplied by the client and can be spoofed, so they are
/// deliberately ignored. As a consequence, every client shares a single
/// rate-limit bucket.
///
/// NOTE(review): a genuine per-client key would need the peer socket address
/// or a trusted-proxy configuration. Confirm before relying on per-IP limits.
fn extract_ip_from_headers(_headers: &HeaderMap) -> String {
    const UNKNOWN_CLIENT: &str = "unknown";
    UNKNOWN_CLIENT.to_owned()
}
async fn execute_graphql_request<A: DatabaseAdapter + Clone + Send + Sync + 'static>(
state: AppState<A>,
request: GraphQLRequest,
_trace_context: Option<fraiseql_core::federation::FederationTraceContext>,
security_context: Option<SecurityContext>,
headers: &HeaderMap,
) -> Result<GraphQLResponse, ErrorResponse> {
let start_time = Instant::now();
let metrics = &state.metrics;
metrics.queries_total.fetch_add(1, Ordering::Relaxed);
info!(
query_length = request.query.len(),
has_variables = request.variables.is_some(),
operation_name = ?request.operation_name,
"Executing GraphQL query"
);
let validator = RequestValidator::new();
if let Err(e) = validator.validate_query(&request.query) {
error!(
error = %e,
operation_name = ?request.operation_name,
"Query validation failed"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
metrics.validation_errors_total.fetch_add(1, Ordering::Relaxed);
let client_ip = extract_ip_from_headers(headers);
if state.graphql_rate_limiter.check(&client_ip).is_err() {
return Err(ErrorResponse::from_error(GraphQLError::rate_limited(
"Too many validation errors. Please reduce query complexity and try again.",
)));
}
let graphql_error = match e {
crate::validation::ValidationError::QueryTooDeep {
max_depth,
actual_depth,
} => GraphQLError::validation(format!(
"Query exceeds maximum depth: {actual_depth} > {max_depth}"
)),
crate::validation::ValidationError::QueryTooComplex {
max_complexity,
actual_complexity,
} => GraphQLError::validation(format!(
"Query exceeds maximum complexity: {actual_complexity} > {max_complexity}"
)),
crate::validation::ValidationError::MalformedQuery(msg) => {
metrics.parse_errors_total.fetch_add(1, Ordering::Relaxed);
GraphQLError::parse(msg)
},
crate::validation::ValidationError::InvalidVariables(msg) => GraphQLError::request(msg),
};
return Err(ErrorResponse::from_error(graphql_error));
}
if let Err(e) = validator.validate_variables(request.variables.as_ref()) {
error!(
error = %e,
operation_name = ?request.operation_name,
"Variables validation failed"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
metrics.validation_errors_total.fetch_add(1, Ordering::Relaxed);
let client_ip = extract_ip_from_headers(headers);
if state.graphql_rate_limiter.check(&client_ip).is_err() {
return Err(ErrorResponse::from_error(GraphQLError::rate_limited(
"Too many validation errors. Please reduce query complexity and try again.",
)));
}
return Err(ErrorResponse::from_error(GraphQLError::request(e.to_string())));
}
let result = if let Some(sec_ctx) = security_context {
state
.executor
.execute_with_security(&request.query, request.variables.as_ref(), &sec_ctx)
.await
.map_err(|e| {
let elapsed = start_time.elapsed();
error!(
error = %e,
elapsed_ms = elapsed.as_millis(),
operation_name = ?request.operation_name,
"Query execution failed"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
metrics.execution_errors_total.fetch_add(1, Ordering::Relaxed);
metrics
.queries_duration_us
.fetch_add(elapsed.as_micros() as u64, Ordering::Relaxed);
ErrorResponse::from_error(GraphQLError::execution(&e.to_string()))
})?
} else {
state
.executor
.execute(&request.query, request.variables.as_ref())
.await
.map_err(|e| {
let elapsed = start_time.elapsed();
error!(
error = %e,
elapsed_ms = elapsed.as_millis(),
operation_name = ?request.operation_name,
"Query execution failed"
);
metrics.queries_error.fetch_add(1, Ordering::Relaxed);
metrics.execution_errors_total.fetch_add(1, Ordering::Relaxed);
metrics
.queries_duration_us
.fetch_add(elapsed.as_micros() as u64, Ordering::Relaxed);
ErrorResponse::from_error(GraphQLError::execution(&e.to_string()))
})?
};
let elapsed = start_time.elapsed();
let elapsed_us = elapsed.as_micros() as u64;
metrics.queries_success.fetch_add(1, Ordering::Relaxed);
metrics.queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
metrics.db_queries_total.fetch_add(1, Ordering::Relaxed);
metrics.db_queries_duration_us.fetch_add(elapsed_us, Ordering::Relaxed);
if fraiseql_core::federation::is_federation_query(&request.query) {
metrics.record_entity_resolution(elapsed_us, true);
}
debug!(
response_length = result.len(),
elapsed_ms = elapsed.as_millis(),
operation_name = ?request.operation_name,
"Query executed successfully"
);
let response_json: serde_json::Value = serde_json::from_str(&result).map_err(|e| {
error!(
error = %e,
response_length = result.len(),
"Failed to deserialize executor response"
);
ErrorResponse::from_error(GraphQLError::internal(format!(
"Failed to process response: {e}"
)))
})?;
Ok(GraphQLResponse {
body: response_json,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graphql_request_deserialize() {
        let json = r#"{"query": "{ users { id } }"}"#;
        let request: GraphQLRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.query, "{ users { id } }");
        assert!(request.variables.is_none());
        assert!(request.operation_name.is_none());
    }

    #[test]
    fn test_graphql_request_with_variables() {
        let json = r#"{"query": "query($id: ID!) { user(id: $id) { name } }", "variables": {"id": "123"}}"#;
        let request: GraphQLRequest = serde_json::from_str(json).unwrap();
        let variables = request.variables.expect("variables should be deserialized");
        assert_eq!(variables["id"], "123");
    }

    #[test]
    fn test_graphql_request_operation_name_camel_case() {
        let json = r#"{"query": "query GetUsers { users { id } }", "operationName": "GetUsers"}"#;
        let request: GraphQLRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.operation_name.as_deref(), Some("GetUsers"));
    }

    #[test]
    fn test_graphql_get_params_deserialize() {
        let params: GraphQLGetParams = serde_json::from_value(serde_json::json!({
            "query": "{ users { id } }",
            "operationName": "GetUsers"
        }))
        .unwrap();
        assert_eq!(params.query, "{ users { id } }");
        assert_eq!(params.operation_name, Some("GetUsers".to_string()));
        assert!(params.variables.is_none());
    }

    #[test]
    fn test_graphql_get_params_with_variables() {
        let params: GraphQLGetParams = serde_json::from_value(serde_json::json!({
            "query": "query($id: ID!) { user(id: $id) { name } }",
            "variables": r#"{"id": "123"}"#
        }))
        .unwrap();
        assert!(params.variables.is_some());
        let vars_str = params.variables.unwrap();
        let vars: serde_json::Value = serde_json::from_str(&vars_str).unwrap();
        assert_eq!(vars["id"], "123");
    }

    #[test]
    fn test_graphql_get_params_camel_case() {
        let params: GraphQLGetParams = serde_json::from_value(serde_json::json!({
            "query": "{ users { id } }",
            "operationName": "TestOp"
        }))
        .unwrap();
        assert_eq!(params.operation_name, Some("TestOp".to_string()));
    }

    // The AppState tests below are compile-time API checks. Building an
    // `Executor` needs a real database adapter, so each test pins the shape of a
    // field, constructor or accessor inside an uncalled generic function. If that
    // shape changes, this module stops compiling.

    #[test]
    fn test_appstate_has_cache_field() {
        fn _check<A: DatabaseAdapter>(
            state: &AppState<A>,
        ) -> &Option<Arc<fraiseql_arrow::cache::QueryCache>> {
            &state.cache
        }
    }

    #[test]
    fn test_appstate_has_config_field() {
        fn _check<A: DatabaseAdapter>(
            state: &AppState<A>,
        ) -> &Option<Arc<crate::config::ServerConfig>> {
            &state.config
        }
    }

    #[test]
    fn test_appstate_with_cache_constructor() {
        fn _check<A: DatabaseAdapter>(
            executor: Arc<Executor<A>>,
            cache: Arc<fraiseql_arrow::cache::QueryCache>,
        ) -> AppState<A> {
            AppState::with_cache(executor, cache)
        }
    }

    #[test]
    fn test_appstate_with_cache_and_config_constructor() {
        fn _check<A: DatabaseAdapter>(
            executor: Arc<Executor<A>>,
            cache: Arc<fraiseql_arrow::cache::QueryCache>,
            config: Arc<crate::config::ServerConfig>,
        ) -> AppState<A> {
            AppState::with_cache_and_config(executor, cache, config)
        }
    }

    #[test]
    fn test_appstate_cache_accessor() {
        fn _check<A: DatabaseAdapter>(
            state: &AppState<A>,
        ) -> Option<&Arc<fraiseql_arrow::cache::QueryCache>> {
            state.cache()
        }
    }

    #[test]
    fn test_appstate_server_config_accessor() {
        fn _check<A: DatabaseAdapter>(
            state: &AppState<A>,
        ) -> Option<&Arc<crate::config::ServerConfig>> {
            state.server_config()
        }
    }

    #[test]
    fn test_sanitized_config_from_server_config() {
        use crate::routes::api::types::SanitizedConfig;
        let config = crate::config::ServerConfig {
            port: 8080,
            host: "0.0.0.0".to_string(),
            workers: Some(4),
            tls: None,
            limits: None,
        };
        let sanitized = SanitizedConfig::from_config(&config);
        assert_eq!(sanitized.port, 8080, "Port should be preserved");
        assert_eq!(sanitized.host, "0.0.0.0", "Host should be preserved");
        assert_eq!(sanitized.workers, Some(4), "Workers count should be preserved");
        assert!(!sanitized.tls_enabled, "TLS should be false when not configured");
        assert!(sanitized.is_sanitized(), "Should be marked as sanitized");
    }

    #[test]
    fn test_sanitized_config_indicates_tls_without_exposing_keys() {
        use std::path::PathBuf;

        use crate::routes::api::types::SanitizedConfig;
        let config = crate::config::ServerConfig {
            port: 8080,
            host: "localhost".to_string(),
            workers: None,
            tls: Some(crate::config::TlsConfig {
                cert_file: PathBuf::from("/path/to/cert.pem"),
                key_file: PathBuf::from("/path/to/key.pem"),
            }),
            limits: None,
        };
        let sanitized = SanitizedConfig::from_config(&config);
        assert!(sanitized.tls_enabled, "TLS should be true when configured");
        let json = serde_json::to_string(&sanitized).unwrap();
        assert!(!json.contains("cert"), "Certificate file path should not be exposed");
        assert!(!json.contains("key"), "Key file path should not be exposed");
    }

    #[test]
    fn test_sanitized_config_redaction() {
        use crate::routes::api::types::SanitizedConfig;
        let config1 = crate::config::ServerConfig {
            port: 8000,
            host: "127.0.0.1".to_string(),
            workers: None,
            tls: None,
            limits: None,
        };
        let config2 = crate::config::ServerConfig {
            port: 8000,
            host: "127.0.0.1".to_string(),
            workers: None,
            tls: Some(crate::config::TlsConfig {
                cert_file: std::path::PathBuf::from("secret.cert"),
                key_file: std::path::PathBuf::from("secret.key"),
            }),
            limits: None,
        };
        let san1 = SanitizedConfig::from_config(&config1);
        let san2 = SanitizedConfig::from_config(&config2);
        assert_eq!(san1.port, san2.port);
        assert_eq!(san1.host, san2.host);
        assert!(!san1.tls_enabled);
        assert!(san2.tls_enabled);
    }

    #[test]
    fn test_appstate_executor_provides_access_to_schema() {
        fn _check<A: DatabaseAdapter>(state: &AppState<A>) -> &Arc<Executor<A>> {
            &state.executor
        }
    }

    #[test]
    fn test_schema_access_for_api_endpoints() {
        // API routes can take their own owned handle on the executor.
        fn _check<A: DatabaseAdapter>(state: &AppState<A>) -> Arc<Executor<A>> {
            Arc::clone(&state.executor)
        }
    }

    #[test]
    fn test_extract_ip_ignores_x_forwarded_for() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1, 10.0.0.1".parse().unwrap());
        let ip = extract_ip_from_headers(&headers);
        assert_eq!(ip, "unknown", "Must not trust X-Forwarded-For header");
    }

    #[test]
    fn test_extract_ip_ignores_x_real_ip() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-real-ip", "10.0.0.2".parse().unwrap());
        let ip = extract_ip_from_headers(&headers);
        assert_eq!(ip, "unknown", "Must not trust X-Real-IP header");
    }

    #[test]
    fn test_extract_ip_from_headers_missing() {
        let headers = axum::http::HeaderMap::new();
        let ip = extract_ip_from_headers(&headers);
        assert_eq!(ip, "unknown");
    }

    #[test]
    fn test_extract_ip_ignores_all_spoofable_headers() {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1".parse().unwrap());
        headers.insert("x-real-ip", "10.0.0.2".parse().unwrap());
        let ip = extract_ip_from_headers(&headers);
        assert_eq!(ip, "unknown", "Must not trust any spoofable header");
    }

    #[test]
    fn test_graphql_rate_limiter_is_per_ip() {
        let config = RateLimitConfig {
            enabled: true,
            max_requests: 3,
            window_secs: 60,
        };
        let limiter = KeyedRateLimiter::new(config);
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(
            limiter.check("192.0.2.1").is_err(),
            "First key should be exhausted after max_requests"
        );
        // An exhausted key must not affect a different key's budget.
        assert!(limiter.check("10.0.0.1").is_ok());
        assert!(limiter.check("10.0.0.1").is_ok());
        assert!(limiter.check("10.0.0.1").is_ok());
    }

    #[test]
    fn test_graphql_rate_limiter_enforces_limit() {
        let config = RateLimitConfig {
            enabled: true,
            max_requests: 2,
            window_secs: 60,
        };
        let limiter = KeyedRateLimiter::new(config);
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(limiter.check("192.0.2.1").is_err());
    }

    #[test]
    fn test_graphql_rate_limiter_disabled() {
        let config = RateLimitConfig {
            enabled: false,
            max_requests: 1,
            window_secs: 60,
        };
        let limiter = KeyedRateLimiter::new(config);
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(limiter.check("192.0.2.1").is_ok());
    }

    #[test]
    fn test_graphql_rate_limiter_window_reset() {
        // With a zero-length window, the second request is expected to be allowed
        // even though max_requests is 1.
        let config = RateLimitConfig {
            enabled: true,
            max_requests: 1,
            window_secs: 0,
        };
        let limiter = KeyedRateLimiter::new(config);
        assert!(limiter.check("192.0.2.1").is_ok());
        assert!(limiter.check("192.0.2.1").is_ok());
    }
}