1use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
7use async_trait::async_trait;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use thiserror::Error;
12
13pub type HandlerResult<T> = Result<T, HandlerError>;
15
16#[derive(Debug, Error)]
18pub enum HandlerError {
19 #[error("Send error: {0}")]
21 SendError(String),
22
23 #[error("JSON error: {0}")]
25 JsonError(#[from] serde_json::Error),
26
27 #[error("Operation error: {0}")]
29 OperationError(String),
30
31 #[error("Upstream error: {0}")]
33 UpstreamError(String),
34
35 #[error("{0}")]
37 Generic(String),
38}
39
40pub struct GraphQLContext {
42 pub operation_name: Option<String>,
44
45 pub operation_type: OperationType,
47
48 pub query: String,
50
51 pub variables: Variables,
53
54 pub metadata: HashMap<String, String>,
56
57 pub data: HashMap<String, serde_json::Value>,
59}
60
/// The three kinds of GraphQL operation.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OperationType {
    /// A `query` operation.
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
71
72impl GraphQLContext {
73 pub fn new(
75 operation_name: Option<String>,
76 operation_type: OperationType,
77 query: String,
78 variables: Variables,
79 ) -> Self {
80 Self {
81 operation_name,
82 operation_type,
83 query,
84 variables,
85 metadata: HashMap::new(),
86 data: HashMap::new(),
87 }
88 }
89
90 pub fn get_variable(&self, name: &str) -> Option<&Value> {
92 self.variables.get(&Name::new(name))
93 }
94
95 pub fn set_data(&mut self, key: String, value: serde_json::Value) {
97 self.data.insert(key, value);
98 }
99
100 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
102 self.data.get(key)
103 }
104
105 pub fn set_metadata(&mut self, key: String, value: String) {
107 self.metadata.insert(key, value);
108 }
109
110 pub fn get_metadata(&self, key: &str) -> Option<&String> {
112 self.metadata.get(key)
113 }
114}
115
116#[async_trait]
118pub trait GraphQLHandler: Send + Sync {
119 async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
122 Ok(None)
123 }
124
125 async fn after_operation(
128 &self,
129 _ctx: &GraphQLContext,
130 response: Response,
131 ) -> HandlerResult<Response> {
132 Ok(response)
133 }
134
135 async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
137 let server_error = ServerError::new(error, None);
138 Ok(Response::from_errors(vec![server_error]))
139 }
140
141 fn handles_operation(
143 &self,
144 operation_name: Option<&str>,
145 _operation_type: &OperationType,
146 ) -> bool {
147 operation_name.is_some()
149 }
150
151 fn priority(&self) -> i32 {
153 0
154 }
155}
156
157pub struct HandlerRegistry {
159 handlers: Vec<Arc<dyn GraphQLHandler>>,
160 upstream_url: Option<String>,
162}
163
164impl HandlerRegistry {
165 pub fn new() -> Self {
167 Self {
168 handlers: Vec::new(),
169 upstream_url: None,
170 }
171 }
172
173 pub fn with_upstream(upstream_url: Option<String>) -> Self {
175 Self {
176 handlers: Vec::new(),
177 upstream_url,
178 }
179 }
180
181 pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
183 self.handlers.push(Arc::new(handler));
184 self.handlers.sort_by(|a, b| b.priority().cmp(&a.priority()));
186 }
187
188 pub fn get_handlers(
190 &self,
191 operation_name: Option<&str>,
192 operation_type: &OperationType,
193 ) -> Vec<Arc<dyn GraphQLHandler>> {
194 self.handlers
195 .iter()
196 .filter(|h| h.handles_operation(operation_name, operation_type))
197 .cloned()
198 .collect()
199 }
200
201 pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
203 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
204
205 for handler in handlers {
206 if let Some(response) = handler.on_operation(ctx).await? {
207 return Ok(Some(response));
208 }
209 }
210
211 Ok(None)
212 }
213
214 pub async fn after_operation(
216 &self,
217 ctx: &GraphQLContext,
218 mut response: Response,
219 ) -> HandlerResult<Response> {
220 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
221
222 for handler in handlers {
223 response = handler.after_operation(ctx, response).await?;
224 }
225
226 Ok(response)
227 }
228
229 pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
231 let upstream = self
232 .upstream_url
233 .as_ref()
234 .ok_or_else(|| HandlerError::UpstreamError("No upstream URL configured".to_string()))?;
235
236 let client = reqwest::Client::new();
237 let body = json!({
238 "query": request.query.clone(),
239 "variables": request.variables.clone(),
240 "operationName": request.operation_name.clone(),
241 });
242
243 let resp = client
244 .post(upstream)
245 .json(&body)
246 .send()
247 .await
248 .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
249
250 let response_data: serde_json::Value =
251 resp.json().await.map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
252
253 let has_errors = response_data.get("errors").is_some();
256
257 if has_errors {
260 let error_msg = response_data
261 .get("errors")
262 .and_then(|e| e.as_array())
263 .and_then(|arr| arr.first())
264 .and_then(|e| e.get("message"))
265 .and_then(|m| m.as_str())
266 .unwrap_or("Upstream GraphQL error");
267
268 let server_error = async_graphql::ServerError::new(error_msg.to_string(), None);
269 Ok(Response::from_errors(vec![server_error]))
270 } else {
271 Ok(Response::new(Value::Null))
274 }
275 }
276
277 pub fn upstream_url(&self) -> Option<&str> {
279 self.upstream_url.as_deref()
280 }
281}
282
283impl Default for HandlerRegistry {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289#[derive(Debug, Clone)]
291pub struct VariableMatcher {
292 patterns: HashMap<String, VariablePattern>,
293}
294
295impl VariableMatcher {
296 pub fn new() -> Self {
298 Self {
299 patterns: HashMap::new(),
300 }
301 }
302
303 pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
305 self.patterns.insert(name, pattern);
306 self
307 }
308
309 pub fn matches(&self, variables: &Variables) -> bool {
311 for (name, pattern) in &self.patterns {
312 if !pattern.matches(variables.get(&Name::new(name))) {
313 return false;
314 }
315 }
316 true
317 }
318}
319
320impl Default for VariableMatcher {
321 fn default() -> Self {
322 Self::new()
323 }
324}
325
326#[derive(Debug, Clone)]
328pub enum VariablePattern {
329 Exact(Value),
331 Regex(String),
333 Any,
335 Present,
337 Null,
339}
340
341impl VariablePattern {
342 pub fn matches(&self, value: Option<&Value>) -> bool {
344 match (self, value) {
345 (VariablePattern::Any, _) => true,
346 (VariablePattern::Present, Some(_)) => true,
347 (VariablePattern::Present, None) => false,
348 (VariablePattern::Null, None) | (VariablePattern::Null, Some(Value::Null)) => true,
349 (VariablePattern::Null, Some(_)) => false,
350 (VariablePattern::Exact(expected), Some(actual)) => expected == actual,
351 (VariablePattern::Exact(_), None) => false,
352 (VariablePattern::Regex(pattern), Some(Value::String(s))) => {
353 regex::Regex::new(pattern).ok().map(|re| re.is_match(s)).unwrap_or(false)
354 }
355 (VariablePattern::Regex(_), _) => false,
356 }
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that claims exactly one named operation and answers it with `null` data.
    struct TestHandler {
        operation_name: String,
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            if ctx.operation_name.as_deref() == Some(self.operation_name.as_str()) {
                Ok(Some(Response::new(Value::Null)))
            } else {
                Ok(None)
            }
        }

        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(self.operation_name.as_str())
        }
    }

    /// Handler that only overrides `priority`, inheriting every other default hook.
    struct PriorityHandler {
        priority: i32,
    }

    #[async_trait]
    impl GraphQLHandler for PriorityHandler {
        fn priority(&self) -> i32 {
            self.priority
        }
    }

    /// Builds a query context for `operation_name` with no variables.
    fn query_ctx(operation_name: &str) -> GraphQLContext {
        GraphQLContext::new(
            Some(operation_name.to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        )
    }

    // These registry tests never await, so they do not need a tokio runtime.
    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert_eq!(registry.handlers.len(), 0);
        assert!(registry.upstream_url.is_none());
    }

    #[test]
    fn test_handler_registry_with_upstream() {
        let registry =
            HandlerRegistry::with_upstream(Some("http://example.com/graphql".to_string()));
        assert_eq!(registry.upstream_url(), Some("http://example.com/graphql"));
    }

    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });
        assert_eq!(registry.handlers.len(), 1);
    }

    #[test]
    fn test_handler_registry_orders_by_priority() {
        let mut registry = HandlerRegistry::new();
        registry.register(PriorityHandler { priority: 1 });
        registry.register(PriorityHandler { priority: 10 });
        registry.register(PriorityHandler { priority: 5 });

        let priorities: Vec<i32> = registry
            .get_handlers(Some("anyOperation"), &OperationType::Query)
            .iter()
            .map(|handler| handler.priority())
            .collect();
        assert_eq!(priorities, vec![10, 5, 1]);
    }

    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let result = registry.execute_operation(&query_ctx("getUser")).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_handler_registry_no_match() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let result = registry.execute_operation(&query_ctx("getProduct")).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_after_operation() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let response = Response::new(Value::Null);
        let result = registry.after_operation(&query_ctx("getUser"), response).await;
        let response = result.expect("default after_operation never fails");
        // The default hook must hand the response back untouched.
        assert_eq!(response.data, Value::Null);
        assert!(response.errors.is_empty());
    }

    #[tokio::test]
    async fn test_passthrough_without_upstream_fails() {
        let registry = HandlerRegistry::new();
        let result = registry.passthrough(&Request::new("{ user { id } }")).await;
        assert!(matches!(result, Err(HandlerError::UpstreamError(_))));
    }

    #[tokio::test]
    async fn test_default_on_error_builds_error_response() {
        let handler = PriorityHandler { priority: 0 };
        let response = handler
            .on_error(&query_ctx("getUser"), "boom".to_string())
            .await
            .expect("default on_error never fails");
        assert_eq!(response.errors.len(), 1);
        assert_eq!(response.errors[0].message, "boom");
    }

    #[test]
    fn test_default_handles_operation_skips_anonymous() {
        let handler = PriorityHandler { priority: 0 };
        assert!(handler.handles_operation(Some("op"), &OperationType::Mutation));
        assert!(!handler.handles_operation(None, &OperationType::Query));
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });
        registry.register(TestHandler {
            operation_name: "getProduct".to_string(),
        });

        let handlers = registry.get_handlers(Some("getUser"), &OperationType::Query);
        assert_eq!(handlers.len(), 1);

        let handlers = registry.get_handlers(Some("unknown"), &OperationType::Query);
        assert_eq!(handlers.len(), 0);
    }

    #[test]
    fn test_handler_priority() {
        let handler = PriorityHandler { priority: 10 };
        assert_eq!(handler.priority(), 10);

        let default_priority = TestHandler {
            operation_name: "x".to_string(),
        };
        assert_eq!(default_priority.priority(), 0);
    }

    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.upstream_url().is_none());
    }

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let matcher = VariableMatcher::new().with_pattern(
            "id".to_string(),
            VariablePattern::Exact(Value::String("123".to_string())),
        );

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));

        let mut vars2 = Variables::default();
        vars2.insert(Name::new("id"), Value::String("456".to_string()));

        assert!(!matcher.matches(&vars2));
    }

    #[test]
    fn test_variable_pattern_present() {
        assert!(VariablePattern::Present.matches(Some(&Value::String("test".to_string()))));
        assert!(!VariablePattern::Present.matches(None));
    }

    #[test]
    fn test_variable_pattern_present_accepts_explicit_null() {
        assert!(VariablePattern::Present.matches(Some(&Value::Null)));
    }

    #[test]
    fn test_variable_pattern_exact_requires_value() {
        assert!(!VariablePattern::Exact(Value::Null).matches(None));
    }

    #[test]
    fn test_variable_pattern_null() {
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&Value::String("test".to_string()))));
    }

    #[test]
    fn test_variable_pattern_regex() {
        let pattern = VariablePattern::Regex(r"^user-\d+$".to_string());
        assert!(pattern.matches(Some(&Value::String("user-123".to_string()))));
        assert!(!pattern.matches(Some(&Value::String("invalid".to_string()))));
        assert!(!pattern.matches(None));
    }

    #[test]
    fn test_variable_pattern_regex_edge_cases() {
        // An invalid pattern never matches rather than panicking.
        let invalid = VariablePattern::Regex("(".to_string());
        assert!(!invalid.matches(Some(&Value::String("(".to_string()))));

        // Non-string values never match a regex, even a match-everything one.
        let everything = VariablePattern::Regex(".*".to_string());
        assert!(!everything.matches(Some(&Value::Boolean(true))));
    }

    #[test]
    fn test_variable_matcher_default() {
        let matcher = VariableMatcher::default();
        assert!(matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_variable_matcher_multiple_patterns() {
        let matcher = VariableMatcher::new()
            .with_pattern("id".to_string(), VariablePattern::Present)
            .with_pattern("name".to_string(), VariablePattern::Any);

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));
    }

    #[test]
    fn test_variable_matcher_fails_on_missing() {
        let matcher =
            VariableMatcher::new().with_pattern("required".to_string(), VariablePattern::Present);

        let vars = Variables::default();
        assert!(!matcher.matches(&vars));
    }

    #[test]
    fn test_graphql_context_new() {
        let ctx = query_ctx("getUser");

        assert_eq!(ctx.operation_name, Some("getUser".to_string()));
        assert_eq!(ctx.operation_type, OperationType::Query);
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = query_ctx("getUser");

        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        assert_eq!(ctx.get_metadata("Authorization"), Some(&"Bearer token".to_string()));
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = query_ctx("getUser");

        ctx.set_data("custom_key".to_string(), json!({"test": "value"}));
        assert_eq!(ctx.get_data("custom_key"), Some(&json!({"test": "value"})));
    }

    #[test]
    fn test_graphql_context_get_variable() {
        let mut vars = Variables::default();
        vars.insert(Name::new("userId"), Value::String("123".to_string()));

        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            vars,
        );

        assert!(ctx.get_variable("userId").is_some());
        assert!(ctx.get_variable("nonexistent").is_none());
    }

    #[test]
    fn test_context_all_operation_types() {
        let query_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Query,
            "query".to_string(),
            Variables::default(),
        );
        assert_eq!(query_ctx.operation_type, OperationType::Query);

        let mutation_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Mutation,
            "mutation".to_string(),
            Variables::default(),
        );
        assert_eq!(mutation_ctx.operation_type, OperationType::Mutation);

        let subscription_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Subscription,
            "subscription".to_string(),
            Variables::default(),
        );
        assert_eq!(subscription_ctx.operation_type, OperationType::Subscription);
    }

    #[test]
    fn test_operation_type_eq() {
        assert_eq!(OperationType::Query, OperationType::Query);
        assert_ne!(OperationType::Query, OperationType::Mutation);
        assert_ne!(OperationType::Mutation, OperationType::Subscription);
    }

    #[test]
    fn test_operation_type_clone() {
        let op = OperationType::Query;
        let cloned = op.clone();
        assert_eq!(op, cloned);
    }

    #[test]
    fn test_handler_error_display() {
        let err = HandlerError::SendError("test error".to_string());
        assert!(err.to_string().contains("Send error"));

        let err = HandlerError::OperationError("op error".to_string());
        assert!(err.to_string().contains("Operation error"));

        let err = HandlerError::UpstreamError("upstream error".to_string());
        assert!(err.to_string().contains("Upstream error"));

        let err = HandlerError::Generic("generic error".to_string());
        assert!(err.to_string().contains("generic error"));
    }

    #[test]
    fn test_handler_error_from_json() {
        let json_err = serde_json::from_str::<i32>("not a number").unwrap_err();
        let err: HandlerError = json_err.into();
        assert!(matches!(err, HandlerError::JsonError(_)));
    }

    #[test]
    fn test_variable_pattern_debug() {
        let pattern = VariablePattern::Any;
        let debug = format!("{:?}", pattern);
        assert!(debug.contains("Any"));
    }

    #[test]
    fn test_variable_matcher_debug() {
        let matcher = VariableMatcher::new();
        let debug = format!("{:?}", matcher);
        assert!(debug.contains("VariableMatcher"));
    }
}