use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
pub type HandlerResult<T> = Result<T, HandlerError>;
#[derive(Debug, Error)]
pub enum HandlerError {
#[error("Send error: {0}")]
SendError(String),
#[error("JSON error: {0}")]
JsonError(#[from] serde_json::Error),
#[error("Operation error: {0}")]
OperationError(String),
#[error("Upstream error: {0}")]
UpstreamError(String),
#[error("{0}")]
Generic(String),
}
/// Per-operation state passed to every [`GraphQLHandler`] hook.
///
/// Derives `Debug` (so contexts can be logged/asserted on) and `Clone` (so a handler can
/// snapshot the context); every field type already supports both.
#[derive(Debug, Clone)]
pub struct GraphQLContext {
    /// Operation name from the request, or `None` for an anonymous operation.
    pub operation_name: Option<String>,
    /// Whether the operation is a query, mutation, or subscription.
    pub operation_type: OperationType,
    /// The GraphQL document text.
    pub query: String,
    /// Variables supplied with the request.
    pub variables: Variables,
    /// String key/value pairs attached via `set_metadata`.
    /// NOTE(review): presumably request headers such as `Authorization` — confirm with callers.
    pub metadata: HashMap<String, String>,
    /// Free-form JSON values attached via `set_data`.
    pub data: HashMap<String, serde_json::Value>,
}
/// The kind of GraphQL operation being executed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OperationType {
    /// A `query` operation.
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
impl GraphQLContext {
    /// Builds a context for a single operation; `metadata` and `data` start out empty.
    pub fn new(
        operation_name: Option<String>,
        operation_type: OperationType,
        query: String,
        variables: Variables,
    ) -> Self {
        Self {
            query,
            variables,
            operation_name,
            operation_type,
            data: HashMap::default(),
            metadata: HashMap::default(),
        }
    }

    /// Looks up the request variable called `name`, if it was supplied.
    pub fn get_variable(&self, name: &str) -> Option<&Value> {
        let key = Name::new(name);
        self.variables.get(&key)
    }

    /// Stores a metadata entry, overwriting any previous value under `key`.
    pub fn set_metadata(&mut self, key: String, value: String) {
        let _previous = self.metadata.insert(key, value);
    }

    /// Returns the metadata entry stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }

    /// Stores a JSON value, overwriting any previous value under `key`.
    pub fn set_data(&mut self, key: String, value: serde_json::Value) {
        let _previous = self.data.insert(key, value);
    }

    /// Returns the JSON value stored under `key`, if any.
    pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
        self.data.get(key)
    }
}
#[async_trait]
pub trait GraphQLHandler: Send + Sync {
async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
Ok(None)
}
async fn after_operation(
&self,
_ctx: &GraphQLContext,
response: Response,
) -> HandlerResult<Response> {
Ok(response)
}
async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
let server_error = ServerError::new(error, None);
Ok(Response::from_errors(vec![server_error]))
}
fn handles_operation(
&self,
operation_name: Option<&str>,
_operation_type: &OperationType,
) -> bool {
operation_name.is_some()
}
fn priority(&self) -> i32 {
0
}
}
/// Priority-ordered collection of [`GraphQLHandler`]s, plus an optional upstream GraphQL
/// endpoint that requests can be forwarded to with [`HandlerRegistry::passthrough`].
pub struct HandlerRegistry {
    /// Registered handlers, kept sorted by descending `priority()`; handlers with equal
    /// priority stay in registration order because the sort is stable.
    handlers: Vec<Arc<dyn GraphQLHandler>>,
    /// Endpoint that `passthrough` POSTs to, if one was configured.
    upstream_url: Option<String>,
    /// HTTP client reused by every `passthrough` call so its connection pool is shared
    /// instead of being rebuilt per request. Present exactly when `upstream_url` is.
    client: Option<reqwest::Client>,
}

impl HandlerRegistry {
    /// Creates an empty registry with no upstream endpoint.
    pub fn new() -> Self {
        Self::with_upstream(None)
    }

    /// Creates an empty registry that forwards `passthrough` requests to `upstream_url`.
    pub fn with_upstream(upstream_url: Option<String>) -> Self {
        // Build the client once, and only when there is somewhere to send requests.
        let client = upstream_url.as_ref().map(|_| reqwest::Client::new());
        Self {
            handlers: Vec::new(),
            upstream_url,
            client,
        }
    }

    /// Adds `handler`, keeping the handler list ordered by descending priority.
    pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
        self.handlers.push(Arc::new(handler));
        // `sort_by_key` is stable, so equal priorities keep their registration order.
        self.handlers.sort_by_key(|h| std::cmp::Reverse(h.priority()));
    }

    /// Returns (in priority order) the handlers whose `handles_operation` accepts the
    /// given operation name and type.
    pub fn get_handlers(
        &self,
        operation_name: Option<&str>,
        operation_type: &OperationType,
    ) -> Vec<Arc<dyn GraphQLHandler>> {
        self.handlers
            .iter()
            .filter(|h| h.handles_operation(operation_name, operation_type))
            .cloned()
            .collect()
    }

    /// Offers the operation to each matching handler in priority order and returns the
    /// first response produced, or `Ok(None)` if no handler answers.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a handler's `on_operation`; later handlers
    /// are not consulted.
    pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
        let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
        for handler in handlers {
            if let Some(response) = handler.on_operation(ctx).await? {
                return Ok(Some(response));
            }
        }
        Ok(None)
    }

    /// Threads `response` through every matching handler's `after_operation`, in priority
    /// order, and returns the final result.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a handler; remaining handlers are skipped.
    pub async fn after_operation(
        &self,
        ctx: &GraphQLContext,
        mut response: Response,
    ) -> HandlerResult<Response> {
        let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
        for handler in handlers {
            response = handler.after_operation(ctx, response).await?;
        }
        Ok(response)
    }

    /// Forwards `request` (query, variables, operation name) to the configured upstream
    /// endpoint as a JSON POST and converts the upstream GraphQL response into a
    /// [`Response`].
    ///
    /// Only each upstream error's `message` is kept; errors without a string `message`
    /// become `"Upstream GraphQL error"`. Non-2xx statuses whose body is still a GraphQL
    /// response (has `data` or `errors`) are passed through like successful ones.
    ///
    /// # Errors
    ///
    /// Returns [`HandlerError::UpstreamError`] if no upstream is configured, the HTTP
    /// request fails, the body is not JSON, or the body contains neither `data` nor
    /// `errors` (for example a proxy error page); the last case includes the HTTP status.
    pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
        let (upstream, client) = match (self.upstream_url.as_deref(), self.client.as_ref()) {
            (Some(upstream), Some(client)) => (upstream, client),
            _ => {
                return Err(HandlerError::UpstreamError(
                    "No upstream URL configured".to_string(),
                ))
            }
        };
        // `json!` serializes through references, so no clones of the request are needed.
        let body = json!({
            "query": &request.query,
            "variables": &request.variables,
            "operationName": &request.operation_name,
        });
        let resp = client
            .post(upstream)
            .json(&body)
            .send()
            .await
            .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
        let status = resp.status();
        let response_data: serde_json::Value = resp
            .json()
            .await
            .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
        Self::response_from_upstream_json(status, &response_data)
    }

    /// Returns the configured upstream endpoint, if any.
    pub fn upstream_url(&self) -> Option<&str> {
        self.upstream_url.as_deref()
    }

    /// Converts a decoded upstream body into a [`Response`]; `status` is used only to
    /// describe bodies that are not GraphQL responses at all.
    fn response_from_upstream_json(
        status: reqwest::StatusCode,
        body: &serde_json::Value,
    ) -> HandlerResult<Response> {
        let data = body.get("data");
        let raw_errors = body.get("errors");
        // A GraphQL response carries `data` and/or `errors`; anything else (non-object
        // bodies, gateway error JSON) must not masquerade as a successful null result.
        if data.is_none() && raw_errors.is_none() {
            return Err(HandlerError::UpstreamError(format!(
                "Upstream returned HTTP {} without GraphQL `data` or `errors`",
                status
            )));
        }
        let errors: Vec<ServerError> = raw_errors
            .and_then(serde_json::Value::as_array)
            .map(|arr| {
                arr.iter()
                    .map(|e| {
                        let msg = e
                            .get("message")
                            .and_then(serde_json::Value::as_str)
                            .unwrap_or("Upstream GraphQL error");
                        ServerError::new(msg.to_string(), None)
                    })
                    .collect()
            })
            .unwrap_or_default();
        let mut response = Response::new(data.map(json_to_graphql_value).unwrap_or(Value::Null));
        response.errors = errors;
        Ok(response)
    }
}
impl Default for HandlerRegistry {
fn default() -> Self {
Self::new()
}
}
/// Recursively converts a `serde_json::Value` into the equivalent `async_graphql` [`Value`].
///
/// Numbers:
/// - Integers that fit in `i64` or `u64` are converted exactly.
/// - Any other number is converted through `f64`.
/// - A float that cannot be represented falls back to `0`.
///
/// Object entries are converted in `serde_json`'s map iteration order.
fn json_to_graphql_value(json: &serde_json::Value) -> Value {
    match json {
        serde_json::Value::Null => Value::Null,
        serde_json::Value::Bool(b) => Value::Boolean(*b),
        serde_json::Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                Value::Number(i.into())
            } else if let Some(u) = n.as_u64() {
                // Integers above i64::MAX: keep them exact instead of rounding through f64.
                Value::Number(u.into())
            } else if let Some(f) = n.as_f64() {
                Value::Number(async_graphql::Number::from_f64(f).unwrap_or_else(|| 0i32.into()))
            } else {
                Value::Null
            }
        }
        serde_json::Value::String(s) => Value::String(s.clone()),
        serde_json::Value::Array(arr) => {
            Value::List(arr.iter().map(json_to_graphql_value).collect())
        }
        serde_json::Value::Object(obj) => {
            let map = obj.iter().map(|(k, v)| (Name::new(k), json_to_graphql_value(v))).collect();
            Value::Object(map)
        }
    }
}
/// A set of per-variable patterns that must all match for the matcher to match.
#[derive(Clone, Debug)]
pub struct VariableMatcher {
    /// Pattern to apply to each variable, keyed by variable name.
    patterns: HashMap<String, VariablePattern>,
}
impl VariableMatcher {
    /// Creates a matcher with no patterns, which accepts any set of variables.
    pub fn new() -> Self {
        Self {
            patterns: HashMap::default(),
        }
    }

    /// Builder-style setter: adds the pattern for variable `name`, replacing any earlier
    /// pattern for the same name.
    pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
        let _replaced = self.patterns.insert(name, pattern);
        self
    }

    /// Returns `true` when every registered pattern accepts its variable.
    ///
    /// A variable that is absent from `variables` is passed to its pattern as `None`.
    /// Evaluation stops at the first pattern that rejects.
    pub fn matches(&self, variables: &Variables) -> bool {
        self.patterns.iter().all(|(var_name, pattern)| {
            let lookup = Name::new(var_name);
            pattern.matches(variables.get(&lookup))
        })
    }
}
impl Default for VariableMatcher {
fn default() -> Self {
Self::new()
}
}
/// A condition on a single request variable, evaluated by `VariablePattern::matches`.
#[derive(Clone, Debug)]
pub enum VariablePattern {
    /// Matches when the variable is present and equal to this value.
    Exact(Value),
    /// Matches when the variable is a string matching this regular expression.
    /// Non-string values, absent variables, and invalid expressions never match.
    Regex(String),
    /// Always matches, even when the variable is absent.
    Any,
    /// Matches whenever the variable is supplied, with any value (including null).
    Present,
    /// Matches when the variable is absent or explicitly null.
    Null,
}
impl VariablePattern {
    /// Tests `value`, which is the variable's value or `None` if it was not supplied,
    /// against this pattern.
    ///
    /// `Regex` compiles its expression on every call; an expression that fails to compile
    /// is treated as a non-match rather than an error.
    pub fn matches(&self, value: Option<&Value>) -> bool {
        match self {
            VariablePattern::Any => true,
            VariablePattern::Present => value.is_some(),
            VariablePattern::Null => matches!(value, None | Some(Value::Null)),
            VariablePattern::Exact(expected) => value == Some(expected),
            VariablePattern::Regex(pattern) => match value {
                Some(Value::String(text)) => {
                    matches!(regex::Regex::new(pattern), Ok(re) if re.is_match(text))
                }
                _ => false,
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that only claims, and answers, the operation named `operation_name`.
    struct TestHandler {
        operation_name: String,
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            if ctx.operation_name.as_deref() == Some(&self.operation_name) {
                Ok(Some(Response::new(Value::Null)))
            } else {
                Ok(None)
            }
        }
        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(&self.operation_name)
        }
    }

    /// Handler that keeps the default `handles_operation` and answers every operation it
    /// sees with a string tag, so tests can tell which handler responded.
    struct TaggedHandler {
        priority: i32,
        tag: &'static str,
    }

    #[async_trait]
    impl GraphQLHandler for TaggedHandler {
        async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            Ok(Some(Response::new(Value::String(self.tag.to_string()))))
        }
        fn priority(&self) -> i32 {
            self.priority
        }
    }

    #[tokio::test]
    async fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert_eq!(registry.handlers.len(), 0);
        assert!(registry.upstream_url.is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_with_upstream() {
        let registry =
            HandlerRegistry::with_upstream(Some("http://example.com/graphql".to_string()));
        assert_eq!(registry.upstream_url(), Some("http://example.com/graphql"));
    }

    #[tokio::test]
    async fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        let handler = TestHandler {
            operation_name: "getUser".to_string(),
        };
        registry.register(handler);
        assert_eq!(registry.handlers.len(), 1);
    }

    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });
        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );
        let result = registry.execute_operation(&ctx).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_registry_runs_highest_priority_first() {
        let mut registry = HandlerRegistry::new();
        // Register the low-priority handler first so ordering must come from the sort.
        registry.register(TaggedHandler { priority: 1, tag: "low" });
        registry.register(TaggedHandler { priority: 10, tag: "high" });
        let ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Query,
            "query { a }".to_string(),
            Variables::default(),
        );
        let response = registry.execute_operation(&ctx).await.unwrap().unwrap();
        assert_eq!(response.data, Value::String("high".to_string()));
    }

    #[tokio::test]
    async fn test_registry_equal_priority_keeps_registration_order() {
        let mut registry = HandlerRegistry::new();
        registry.register(TaggedHandler { priority: 0, tag: "first" });
        registry.register(TaggedHandler { priority: 0, tag: "second" });
        let ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Query,
            "query { a }".to_string(),
            Variables::default(),
        );
        let response = registry.execute_operation(&ctx).await.unwrap().unwrap();
        assert_eq!(response.data, Value::String("first".to_string()));
    }

    #[test]
    fn test_default_handles_operation_skips_anonymous() {
        let mut registry = HandlerRegistry::new();
        registry.register(TaggedHandler { priority: 0, tag: "t" });
        assert_eq!(registry.get_handlers(Some("op"), &OperationType::Mutation).len(), 1);
        assert!(registry.get_handlers(None, &OperationType::Query).is_empty());
    }

    #[tokio::test]
    async fn test_default_on_error_wraps_message() {
        let handler = TaggedHandler { priority: 0, tag: "t" };
        let ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Query,
            "query { a }".to_string(),
            Variables::default(),
        );
        let response = handler.on_error(&ctx, "boom".to_string()).await.unwrap();
        assert_eq!(response.errors.len(), 1);
        assert_eq!(response.errors[0].message, "boom");
    }

    #[tokio::test]
    async fn test_passthrough_without_upstream_errors() {
        let registry = HandlerRegistry::new();
        let result = registry.passthrough(&Request::new("query { a }")).await;
        assert!(matches!(result, Err(HandlerError::UpstreamError(_))));
    }

    #[test]
    fn test_json_to_graphql_value_nested() {
        let converted = json_to_graphql_value(&json!({"items": [1, "x", null, true]}));
        match converted {
            Value::Object(map) => assert_eq!(
                map.get(&Name::new("items")),
                Some(&Value::List(vec![
                    Value::Number(1i64.into()),
                    Value::String("x".to_string()),
                    Value::Null,
                    Value::Boolean(true),
                ]))
            ),
            other => panic!("expected an object, got {:?}", other),
        }
    }

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);
        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));
        assert!(matcher.matches(&vars));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let matcher = VariableMatcher::new().with_pattern(
            "id".to_string(),
            VariablePattern::Exact(Value::String("123".to_string())),
        );
        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));
        assert!(matcher.matches(&vars));
        let mut vars2 = Variables::default();
        vars2.insert(Name::new("id"), Value::String("456".to_string()));
        assert!(!matcher.matches(&vars2));
    }

    #[test]
    fn test_variable_pattern_present() {
        assert!(VariablePattern::Present.matches(Some(&Value::String("test".to_string()))));
        assert!(!VariablePattern::Present.matches(None));
    }

    #[test]
    fn test_variable_pattern_null() {
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&Value::String("test".to_string()))));
    }

    #[test]
    fn test_graphql_context_new() {
        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );
        assert_eq!(ctx.operation_name, Some("getUser".to_string()));
        assert_eq!(ctx.operation_type, OperationType::Query);
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );
        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        assert_eq!(ctx.get_metadata("Authorization"), Some(&"Bearer token".to_string()));
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );
        ctx.set_data("custom_key".to_string(), json!({"test": "value"}));
        assert_eq!(ctx.get_data("custom_key"), Some(&json!({"test": "value"})));
    }

    #[test]
    fn test_operation_type_eq() {
        assert_eq!(OperationType::Query, OperationType::Query);
        assert_ne!(OperationType::Query, OperationType::Mutation);
        assert_ne!(OperationType::Mutation, OperationType::Subscription);
    }

    #[test]
    fn test_operation_type_clone() {
        let op = OperationType::Query;
        let cloned = op.clone();
        assert_eq!(op, cloned);
    }

    #[test]
    fn test_handler_error_display() {
        let err = HandlerError::SendError("test error".to_string());
        assert!(err.to_string().contains("Send error"));
        let err = HandlerError::OperationError("op error".to_string());
        assert!(err.to_string().contains("Operation error"));
        let err = HandlerError::UpstreamError("upstream error".to_string());
        assert!(err.to_string().contains("Upstream error"));
        let err = HandlerError::Generic("generic error".to_string());
        assert!(err.to_string().contains("generic error"));
    }

    #[test]
    fn test_handler_error_from_json() {
        let json_err = serde_json::from_str::<i32>("not a number").unwrap_err();
        let err: HandlerError = json_err.into();
        assert!(matches!(err, HandlerError::JsonError(_)));
    }

    #[test]
    fn test_variable_matcher_default() {
        let matcher = VariableMatcher::default();
        assert!(matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_variable_pattern_regex() {
        let pattern = VariablePattern::Regex(r"^user-\d+$".to_string());
        assert!(pattern.matches(Some(&Value::String("user-123".to_string()))));
        assert!(!pattern.matches(Some(&Value::String("invalid".to_string()))));
        assert!(!pattern.matches(None));
    }

    #[test]
    fn test_variable_matcher_multiple_patterns() {
        let matcher = VariableMatcher::new()
            .with_pattern("id".to_string(), VariablePattern::Present)
            .with_pattern("name".to_string(), VariablePattern::Any);
        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));
        assert!(matcher.matches(&vars));
    }

    #[test]
    fn test_variable_matcher_fails_on_missing() {
        let matcher =
            VariableMatcher::new().with_pattern("required".to_string(), VariablePattern::Present);
        let vars = Variables::default();
        assert!(!matcher.matches(&vars));
    }

    #[test]
    fn test_graphql_context_get_variable() {
        let mut vars = Variables::default();
        vars.insert(Name::new("userId"), Value::String("123".to_string()));
        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            vars,
        );
        assert!(ctx.get_variable("userId").is_some());
        assert!(ctx.get_variable("nonexistent").is_none());
    }

    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.upstream_url().is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_no_match() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });
        let ctx = GraphQLContext::new(
            Some("getProduct".to_string()),
            OperationType::Query,
            "query { product { id } }".to_string(),
            Variables::default(),
        );
        let result = registry.execute_operation(&ctx).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_after_operation() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });
        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );
        let response = Response::new(Value::String("payload".to_string()));
        let result = registry.after_operation(&ctx, response).await;
        // TestHandler keeps the default `after_operation`, which must not alter the data.
        assert_eq!(result.unwrap().data, Value::String("payload".to_string()));
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });
        registry.register(TestHandler {
            operation_name: "getProduct".to_string(),
        });
        let handlers = registry.get_handlers(Some("getUser"), &OperationType::Query);
        assert_eq!(handlers.len(), 1);
        let handlers = registry.get_handlers(Some("unknown"), &OperationType::Query);
        assert_eq!(handlers.len(), 0);
    }

    #[test]
    fn test_handler_priority() {
        struct PriorityHandler {
            priority: i32,
        }
        #[async_trait]
        impl GraphQLHandler for PriorityHandler {
            fn priority(&self) -> i32 {
                self.priority
            }
        }
        let handler = PriorityHandler { priority: 10 };
        assert_eq!(handler.priority(), 10);
    }

    #[test]
    fn test_context_all_operation_types() {
        let query_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Query,
            "query".to_string(),
            Variables::default(),
        );
        assert_eq!(query_ctx.operation_type, OperationType::Query);
        let mutation_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Mutation,
            "mutation".to_string(),
            Variables::default(),
        );
        assert_eq!(mutation_ctx.operation_type, OperationType::Mutation);
        let subscription_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Subscription,
            "subscription".to_string(),
            Variables::default(),
        );
        assert_eq!(subscription_ctx.operation_type, OperationType::Subscription);
    }

    #[test]
    fn test_variable_pattern_debug() {
        let pattern = VariablePattern::Any;
        let debug = format!("{:?}", pattern);
        assert!(debug.contains("Any"));
    }

    #[test]
    fn test_variable_matcher_debug() {
        let matcher = VariableMatcher::new();
        let debug = format!("{:?}", matcher);
        assert!(debug.contains("VariableMatcher"));
    }
}