1use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
7use async_trait::async_trait;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use thiserror::Error;
12
13pub type HandlerResult<T> = Result<T, HandlerError>;
15
16#[derive(Debug, Error)]
18pub enum HandlerError {
19 #[error("Send error: {0}")]
21 SendError(String),
22
23 #[error("JSON error: {0}")]
25 JsonError(#[from] serde_json::Error),
26
27 #[error("Operation error: {0}")]
29 OperationError(String),
30
31 #[error("Upstream error: {0}")]
33 UpstreamError(String),
34
35 #[error("{0}")]
37 Generic(String),
38}
39
40pub struct GraphQLContext {
42 pub operation_name: Option<String>,
44
45 pub operation_type: OperationType,
47
48 pub query: String,
50
51 pub variables: Variables,
53
54 pub metadata: HashMap<String, String>,
56
57 pub data: HashMap<String, serde_json::Value>,
59}
60
/// The three kinds of GraphQL operation.
///
/// The enum carries no data, so it is `Copy` (callers may pass it by value)
/// and `Hash` (it can be used as a map or set key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// A `query` operation.
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
71
72impl GraphQLContext {
73 pub fn new(
75 operation_name: Option<String>,
76 operation_type: OperationType,
77 query: String,
78 variables: Variables,
79 ) -> Self {
80 Self {
81 operation_name,
82 operation_type,
83 query,
84 variables,
85 metadata: HashMap::new(),
86 data: HashMap::new(),
87 }
88 }
89
90 pub fn get_variable(&self, name: &str) -> Option<&Value> {
92 self.variables.get(&Name::new(name))
93 }
94
95 pub fn set_data(&mut self, key: String, value: serde_json::Value) {
97 self.data.insert(key, value);
98 }
99
100 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
102 self.data.get(key)
103 }
104
105 pub fn set_metadata(&mut self, key: String, value: String) {
107 self.metadata.insert(key, value);
108 }
109
110 pub fn get_metadata(&self, key: &str) -> Option<&String> {
112 self.metadata.get(key)
113 }
114}
115
116#[async_trait]
118pub trait GraphQLHandler: Send + Sync {
119 async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
122 Ok(None)
123 }
124
125 async fn after_operation(
128 &self,
129 _ctx: &GraphQLContext,
130 response: Response,
131 ) -> HandlerResult<Response> {
132 Ok(response)
133 }
134
135 async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
137 let server_error = ServerError::new(error, None);
138 Ok(Response::from_errors(vec![server_error]))
139 }
140
141 fn handles_operation(
143 &self,
144 operation_name: Option<&str>,
145 _operation_type: &OperationType,
146 ) -> bool {
147 operation_name.is_some()
149 }
150
151 fn priority(&self) -> i32 {
153 0
154 }
155}
156
157pub struct HandlerRegistry {
159 handlers: Vec<Arc<dyn GraphQLHandler>>,
160 upstream_url: Option<String>,
162}
163
164impl HandlerRegistry {
165 pub fn new() -> Self {
167 Self {
168 handlers: Vec::new(),
169 upstream_url: None,
170 }
171 }
172
173 pub fn with_upstream(upstream_url: Option<String>) -> Self {
175 Self {
176 handlers: Vec::new(),
177 upstream_url,
178 }
179 }
180
181 pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
183 self.handlers.push(Arc::new(handler));
184 self.handlers.sort_by(|a, b| b.priority().cmp(&a.priority()));
186 }
187
188 pub fn get_handlers(
190 &self,
191 operation_name: Option<&str>,
192 operation_type: &OperationType,
193 ) -> Vec<Arc<dyn GraphQLHandler>> {
194 self.handlers
195 .iter()
196 .filter(|h| h.handles_operation(operation_name, operation_type))
197 .cloned()
198 .collect()
199 }
200
201 pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
203 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
204
205 for handler in handlers {
206 if let Some(response) = handler.on_operation(ctx).await? {
207 return Ok(Some(response));
208 }
209 }
210
211 Ok(None)
212 }
213
214 pub async fn after_operation(
216 &self,
217 ctx: &GraphQLContext,
218 mut response: Response,
219 ) -> HandlerResult<Response> {
220 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
221
222 for handler in handlers {
223 response = handler.after_operation(ctx, response).await?;
224 }
225
226 Ok(response)
227 }
228
229 pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
231 let upstream = self
232 .upstream_url
233 .as_ref()
234 .ok_or_else(|| HandlerError::UpstreamError("No upstream URL configured".to_string()))?;
235
236 let client = reqwest::Client::new();
237 let body = json!({
238 "query": request.query.clone(),
239 "variables": request.variables.clone(),
240 "operationName": request.operation_name.clone(),
241 });
242
243 let resp = client
244 .post(upstream)
245 .json(&body)
246 .send()
247 .await
248 .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
249
250 let response_data: serde_json::Value =
251 resp.json().await.map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
252
253 let has_errors = response_data.get("errors").is_some();
256
257 if has_errors {
260 let error_msg = response_data
261 .get("errors")
262 .and_then(|e| e.as_array())
263 .and_then(|arr| arr.first())
264 .and_then(|e| e.get("message"))
265 .and_then(|m| m.as_str())
266 .unwrap_or("Upstream GraphQL error");
267
268 let server_error = async_graphql::ServerError::new(error_msg.to_string(), None);
269 Ok(Response::from_errors(vec![server_error]))
270 } else {
271 Ok(Response::new(Value::Null))
274 }
275 }
276
277 pub fn upstream_url(&self) -> Option<&str> {
279 self.upstream_url.as_deref()
280 }
281}
282
283impl Default for HandlerRegistry {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289#[derive(Debug, Clone)]
291pub struct VariableMatcher {
292 patterns: HashMap<String, VariablePattern>,
293}
294
295impl VariableMatcher {
296 pub fn new() -> Self {
298 Self {
299 patterns: HashMap::new(),
300 }
301 }
302
303 pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
305 self.patterns.insert(name, pattern);
306 self
307 }
308
309 pub fn matches(&self, variables: &Variables) -> bool {
311 for (name, pattern) in &self.patterns {
312 if !pattern.matches(variables.get(&Name::new(name))) {
313 return false;
314 }
315 }
316 true
317 }
318}
319
320impl Default for VariableMatcher {
321 fn default() -> Self {
322 Self::new()
323 }
324}
325
326#[derive(Debug, Clone)]
328pub enum VariablePattern {
329 Exact(Value),
331 Regex(String),
333 Any,
335 Present,
337 Null,
339}
340
341impl VariablePattern {
342 pub fn matches(&self, value: Option<&Value>) -> bool {
344 match (self, value) {
345 (VariablePattern::Any, _) => true,
346 (VariablePattern::Present, Some(_)) => true,
347 (VariablePattern::Present, None) => false,
348 (VariablePattern::Null, None) | (VariablePattern::Null, Some(Value::Null)) => true,
349 (VariablePattern::Null, Some(_)) => false,
350 (VariablePattern::Exact(expected), Some(actual)) => expected == actual,
351 (VariablePattern::Exact(_), None) => false,
352 (VariablePattern::Regex(pattern), Some(Value::String(s))) => {
353 regex::Regex::new(pattern).ok().map(|re| re.is_match(s)).unwrap_or(false)
354 }
355 (VariablePattern::Regex(_), _) => false,
356 }
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that reacts to exactly one operation name.
    struct TestHandler {
        operation_name: String,
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            let is_target = ctx.operation_name.as_deref() == Some(self.operation_name.as_str());
            Ok(is_target.then(|| Response::new(Value::Null)))
        }

        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(self.operation_name.as_str())
        }
    }

    /// Context for a named `getUser` query with no variables.
    fn get_user_context() -> GraphQLContext {
        GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        )
    }

    /// Handler fixture bound to the `getUser` operation.
    fn get_user_handler() -> TestHandler {
        TestHandler {
            operation_name: "getUser".to_string(),
        }
    }

    /// Variables containing a single string `id` entry.
    fn id_variables(id: &str) -> Variables {
        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String(id.to_string()));
        vars
    }

    #[tokio::test]
    async fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert_eq!(registry.handlers.len(), 0);
        assert!(registry.upstream_url.is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_with_upstream() {
        let url = "http://example.com/graphql";
        let registry = HandlerRegistry::with_upstream(Some(url.to_string()));
        assert_eq!(registry.upstream_url(), Some("http://example.com/graphql"));
    }

    #[tokio::test]
    async fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(get_user_handler());
        assert_eq!(registry.handlers.len(), 1);
    }

    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(get_user_handler());

        let ctx = get_user_context();

        let outcome = registry.execute_operation(&ctx).await;
        assert!(matches!(outcome, Ok(Some(_))));
    }

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);

        assert!(matcher.matches(&id_variables("123")));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let wanted = VariablePattern::Exact(Value::String("123".to_string()));
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), wanted);

        // Same value matches ...
        assert!(matcher.matches(&id_variables("123")));

        // ... a different value does not.
        assert!(!matcher.matches(&id_variables("456")));
    }

    #[test]
    fn test_variable_pattern_present() {
        let supplied = Value::String("test".to_string());
        assert!(VariablePattern::Present.matches(Some(&supplied)));
        assert!(!VariablePattern::Present.matches(None));
    }

    #[test]
    fn test_variable_pattern_null() {
        let non_null = Value::String("test".to_string());
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&non_null)));
    }

    #[test]
    fn test_graphql_context_new() {
        let ctx = get_user_context();

        assert_eq!(ctx.operation_name, Some("getUser".to_string()));
        assert_eq!(ctx.operation_type, OperationType::Query);
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = get_user_context();

        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        let stored = ctx.get_metadata("Authorization").map(String::as_str);
        assert_eq!(stored, Some("Bearer token"));
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = get_user_context();
        let payload = json!({"test": "value"});

        ctx.set_data("custom_key".to_string(), payload.clone());
        assert_eq!(ctx.get_data("custom_key"), Some(&payload));
    }
}