1use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
7use async_trait::async_trait;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use thiserror::Error;
12
13pub type HandlerResult<T> = Result<T, HandlerError>;
15
16#[derive(Debug, Error)]
18pub enum HandlerError {
19 #[error("Send error: {0}")]
21 SendError(String),
22
23 #[error("JSON error: {0}")]
25 JsonError(#[from] serde_json::Error),
26
27 #[error("Operation error: {0}")]
29 OperationError(String),
30
31 #[error("Upstream error: {0}")]
33 UpstreamError(String),
34
35 #[error("{0}")]
37 Generic(String),
38}
39
40pub struct GraphQLContext {
42 pub operation_name: Option<String>,
44
45 pub operation_type: OperationType,
47
48 pub query: String,
50
51 pub variables: Variables,
53
54 pub metadata: HashMap<String, String>,
56
57 pub data: HashMap<String, serde_json::Value>,
59}
60
/// The kind of GraphQL operation being executed.
///
/// This is a fieldless enum, so it derives `Copy` (pass it by value freely) and
/// `Hash` (usable as a `HashMap`/`HashSet` key, e.g. to route per operation kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// A `query` operation.
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
71
72impl GraphQLContext {
73 pub fn new(
75 operation_name: Option<String>,
76 operation_type: OperationType,
77 query: String,
78 variables: Variables,
79 ) -> Self {
80 Self {
81 operation_name,
82 operation_type,
83 query,
84 variables,
85 metadata: HashMap::new(),
86 data: HashMap::new(),
87 }
88 }
89
90 pub fn get_variable(&self, name: &str) -> Option<&Value> {
92 self.variables.get(&Name::new(name))
93 }
94
95 pub fn set_data(&mut self, key: String, value: serde_json::Value) {
97 self.data.insert(key, value);
98 }
99
100 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
102 self.data.get(key)
103 }
104
105 pub fn set_metadata(&mut self, key: String, value: String) {
107 self.metadata.insert(key, value);
108 }
109
110 pub fn get_metadata(&self, key: &str) -> Option<&String> {
112 self.metadata.get(key)
113 }
114}
115
116#[async_trait]
118pub trait GraphQLHandler: Send + Sync {
119 async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
122 Ok(None)
123 }
124
125 async fn after_operation(
128 &self,
129 _ctx: &GraphQLContext,
130 response: Response,
131 ) -> HandlerResult<Response> {
132 Ok(response)
133 }
134
135 async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
137 let server_error = ServerError::new(error, None);
138 Ok(Response::from_errors(vec![server_error]))
139 }
140
141 fn handles_operation(
143 &self,
144 operation_name: Option<&str>,
145 _operation_type: &OperationType,
146 ) -> bool {
147 operation_name.is_some()
149 }
150
151 fn priority(&self) -> i32 {
153 0
154 }
155}
156
157pub struct HandlerRegistry {
159 handlers: Vec<Arc<dyn GraphQLHandler>>,
160 upstream_url: Option<String>,
162}
163
164impl HandlerRegistry {
165 pub fn new() -> Self {
167 Self {
168 handlers: Vec::new(),
169 upstream_url: None,
170 }
171 }
172
173 pub fn with_upstream(upstream_url: Option<String>) -> Self {
175 Self {
176 handlers: Vec::new(),
177 upstream_url,
178 }
179 }
180
181 pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
183 self.handlers.push(Arc::new(handler));
184 self.handlers.sort_by_key(|b| std::cmp::Reverse(b.priority()));
186 }
187
188 pub fn get_handlers(
190 &self,
191 operation_name: Option<&str>,
192 operation_type: &OperationType,
193 ) -> Vec<Arc<dyn GraphQLHandler>> {
194 self.handlers
195 .iter()
196 .filter(|h| h.handles_operation(operation_name, operation_type))
197 .cloned()
198 .collect()
199 }
200
201 pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
203 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
204
205 for handler in handlers {
206 if let Some(response) = handler.on_operation(ctx).await? {
207 return Ok(Some(response));
208 }
209 }
210
211 Ok(None)
212 }
213
214 pub async fn after_operation(
216 &self,
217 ctx: &GraphQLContext,
218 mut response: Response,
219 ) -> HandlerResult<Response> {
220 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
221
222 for handler in handlers {
223 response = handler.after_operation(ctx, response).await?;
224 }
225
226 Ok(response)
227 }
228
229 pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
231 let upstream = self
232 .upstream_url
233 .as_ref()
234 .ok_or_else(|| HandlerError::UpstreamError("No upstream URL configured".to_string()))?;
235
236 let client = reqwest::Client::new();
237 let body = json!({
238 "query": request.query.clone(),
239 "variables": request.variables.clone(),
240 "operationName": request.operation_name.clone(),
241 });
242
243 let resp = client
244 .post(upstream)
245 .json(&body)
246 .send()
247 .await
248 .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
249
250 let response_data: serde_json::Value =
251 resp.json().await.map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
252
253 let errors: Vec<ServerError> = response_data
255 .get("errors")
256 .and_then(|e| e.as_array())
257 .map(|arr| {
258 arr.iter()
259 .map(|e| {
260 let msg = e
261 .get("message")
262 .and_then(|m| m.as_str())
263 .unwrap_or("Upstream GraphQL error");
264 ServerError::new(msg.to_string(), None)
265 })
266 .collect()
267 })
268 .unwrap_or_default();
269
270 let data = response_data.get("data").map(json_to_graphql_value).unwrap_or(Value::Null);
271
272 let mut response = Response::new(data);
273 response.errors = errors;
274 Ok(response)
275 }
276
277 pub fn upstream_url(&self) -> Option<&str> {
279 self.upstream_url.as_deref()
280 }
281}
282
283impl Default for HandlerRegistry {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289fn json_to_graphql_value(json: &serde_json::Value) -> Value {
291 match json {
292 serde_json::Value::Null => Value::Null,
293 serde_json::Value::Bool(b) => Value::Boolean(*b),
294 serde_json::Value::Number(n) => {
295 if let Some(i) = n.as_i64() {
296 Value::Number(i.into())
297 } else if let Some(f) = n.as_f64() {
298 Value::Number(async_graphql::Number::from_f64(f).unwrap_or_else(|| 0i32.into()))
299 } else {
300 Value::Null
301 }
302 }
303 serde_json::Value::String(s) => Value::String(s.clone()),
304 serde_json::Value::Array(arr) => {
305 Value::List(arr.iter().map(json_to_graphql_value).collect())
306 }
307 serde_json::Value::Object(obj) => {
308 let map = obj.iter().map(|(k, v)| (Name::new(k), json_to_graphql_value(v))).collect();
309 Value::Object(map)
310 }
311 }
312}
313
314#[derive(Debug, Clone)]
316pub struct VariableMatcher {
317 patterns: HashMap<String, VariablePattern>,
318}
319
320impl VariableMatcher {
321 pub fn new() -> Self {
323 Self {
324 patterns: HashMap::new(),
325 }
326 }
327
328 pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
330 self.patterns.insert(name, pattern);
331 self
332 }
333
334 pub fn matches(&self, variables: &Variables) -> bool {
336 for (name, pattern) in &self.patterns {
337 if !pattern.matches(variables.get(&Name::new(name))) {
338 return false;
339 }
340 }
341 true
342 }
343}
344
345impl Default for VariableMatcher {
346 fn default() -> Self {
347 Self::new()
348 }
349}
350
351#[derive(Debug, Clone)]
353pub enum VariablePattern {
354 Exact(Value),
356 Regex(String),
358 Any,
360 Present,
362 Null,
364}
365
366impl VariablePattern {
367 pub fn matches(&self, value: Option<&Value>) -> bool {
369 match (self, value) {
370 (VariablePattern::Any, _) => true,
371 (VariablePattern::Present, Some(_)) => true,
372 (VariablePattern::Present, None) => false,
373 (VariablePattern::Null, None) | (VariablePattern::Null, Some(Value::Null)) => true,
374 (VariablePattern::Null, Some(_)) => false,
375 (VariablePattern::Exact(expected), Some(actual)) => expected == actual,
376 (VariablePattern::Exact(_), None) => false,
377 (VariablePattern::Regex(pattern), Some(Value::String(s))) => {
378 regex::Regex::new(pattern).ok().map(|re| re.is_match(s)).unwrap_or(false)
379 }
380 (VariablePattern::Regex(_), _) => false,
381 }
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a named `Query` context with no variables.
    fn query_ctx(name: &str, query: &str) -> GraphQLContext {
        GraphQLContext::new(
            Some(name.to_string()),
            OperationType::Query,
            query.to_string(),
            Variables::default(),
        )
    }

    /// Builds a `Variables` map whose values are all GraphQL strings.
    fn string_vars(pairs: &[(&str, &str)]) -> Variables {
        let mut vars = Variables::default();
        for &(name, value) in pairs {
            vars.insert(Name::new(name), Value::String(value.to_string()));
        }
        vars
    }

    /// Handler that claims, and answers, exactly one named operation.
    struct TestHandler {
        operation_name: String,
    }

    impl TestHandler {
        fn named(name: &str) -> Self {
            TestHandler {
                operation_name: name.to_string(),
            }
        }
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            let is_mine = ctx.operation_name.as_deref() == Some(self.operation_name.as_str());
            Ok(is_mine.then(|| Response::new(Value::Null)))
        }

        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(self.operation_name.as_str())
        }
    }

    #[tokio::test]
    async fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert!(registry.handlers.is_empty());
        assert_eq!(registry.upstream_url, None);
    }

    #[tokio::test]
    async fn test_handler_registry_with_upstream() {
        let url = "http://example.com/graphql";
        let registry = HandlerRegistry::with_upstream(Some(url.to_string()));
        assert_eq!(registry.upstream_url(), Some(url));
    }

    #[tokio::test]
    async fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler::named("getUser"));
        assert_eq!(registry.handlers.len(), 1);
    }

    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler::named("getUser"));

        let ctx = query_ctx("getUser", "query { user { id } }");
        let outcome = registry
            .execute_operation(&ctx)
            .await
            .expect("execute_operation should succeed");
        assert!(outcome.is_some());
    }

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);
        assert!(matcher.matches(&string_vars(&[("id", "123")])));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let matcher = VariableMatcher::new().with_pattern(
            "id".to_string(),
            VariablePattern::Exact(Value::String("123".to_string())),
        );

        assert!(matcher.matches(&string_vars(&[("id", "123")])));
        assert!(!matcher.matches(&string_vars(&[("id", "456")])));
    }

    #[test]
    fn test_variable_pattern_present() {
        let text = Value::String("test".to_string());
        assert!(VariablePattern::Present.matches(Some(&text)));
        assert!(!VariablePattern::Present.matches(None));
    }

    #[test]
    fn test_variable_pattern_null() {
        let text = Value::String("test".to_string());
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&text)));
    }

    #[test]
    fn test_graphql_context_new() {
        let ctx = query_ctx("getUser", "query { user { id } }");
        assert_eq!(ctx.operation_name.as_deref(), Some("getUser"));
        assert_eq!(ctx.operation_type, OperationType::Query);
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = query_ctx("getUser", "query { user { id } }");
        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        assert_eq!(
            ctx.get_metadata("Authorization").map(String::as_str),
            Some("Bearer token")
        );
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = query_ctx("getUser", "query { user { id } }");
        let payload = json!({"test": "value"});
        ctx.set_data("custom_key".to_string(), payload.clone());
        assert_eq!(ctx.get_data("custom_key"), Some(&payload));
    }

    #[test]
    fn test_operation_type_eq() {
        assert_eq!(OperationType::Query, OperationType::Query);
        assert_ne!(OperationType::Query, OperationType::Mutation);
        assert_ne!(OperationType::Mutation, OperationType::Subscription);
    }

    #[test]
    fn test_operation_type_clone() {
        let original = OperationType::Query;
        assert_eq!(original.clone(), original);
    }

    #[test]
    fn test_handler_error_display() {
        let cases = [
            (HandlerError::SendError("test error".to_string()), "Send error"),
            (HandlerError::OperationError("op error".to_string()), "Operation error"),
            (HandlerError::UpstreamError("upstream error".to_string()), "Upstream error"),
            (HandlerError::Generic("generic error".to_string()), "generic error"),
        ];
        for (err, expected) in cases {
            assert!(
                err.to_string().contains(expected),
                "{} should mention {:?}",
                err,
                expected
            );
        }
    }

    #[test]
    fn test_handler_error_from_json() {
        let json_err = serde_json::from_str::<i32>("not a number").unwrap_err();
        let err = HandlerError::from(json_err);
        assert!(matches!(err, HandlerError::JsonError(_)));
    }

    #[test]
    fn test_variable_matcher_default() {
        assert!(VariableMatcher::default().matches(&Variables::default()));
    }

    #[test]
    fn test_variable_pattern_regex() {
        let pattern = VariablePattern::Regex(r"^user-\d+$".to_string());
        let check = |s: &str| pattern.matches(Some(&Value::String(s.to_string())));
        assert!(check("user-123"));
        assert!(!check("invalid"));
        assert!(!pattern.matches(None));
    }

    #[test]
    fn test_variable_matcher_multiple_patterns() {
        let matcher = VariableMatcher::new()
            .with_pattern("id".to_string(), VariablePattern::Present)
            .with_pattern("name".to_string(), VariablePattern::Any);
        assert!(matcher.matches(&string_vars(&[("id", "123")])));
    }

    #[test]
    fn test_variable_matcher_fails_on_missing() {
        let matcher =
            VariableMatcher::new().with_pattern("required".to_string(), VariablePattern::Present);
        assert!(!matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_graphql_context_get_variable() {
        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            string_vars(&[("userId", "123")]),
        );

        assert!(ctx.get_variable("userId").is_some());
        assert!(ctx.get_variable("nonexistent").is_none());
    }

    #[test]
    fn test_handler_registry_default() {
        assert!(HandlerRegistry::default().upstream_url().is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_no_match() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler::named("getUser"));

        let ctx = query_ctx("getProduct", "query { product { id } }");
        let outcome = registry
            .execute_operation(&ctx)
            .await
            .expect("execute_operation should succeed");
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_after_operation() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler::named("getUser"));

        let ctx = query_ctx("getUser", "query { user { id } }");
        let result = registry.after_operation(&ctx, Response::new(Value::Null)).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        for name in ["getUser", "getProduct"] {
            registry.register(TestHandler::named(name));
        }

        assert_eq!(registry.get_handlers(Some("getUser"), &OperationType::Query).len(), 1);
        assert!(registry.get_handlers(Some("unknown"), &OperationType::Query).is_empty());
    }

    #[test]
    fn test_handler_priority() {
        struct PriorityHandler {
            priority: i32,
        }

        #[async_trait]
        impl GraphQLHandler for PriorityHandler {
            fn priority(&self) -> i32 {
                self.priority
            }
        }

        assert_eq!(PriorityHandler { priority: 10 }.priority(), 10);
    }

    #[test]
    fn test_context_all_operation_types() {
        let cases = [
            (OperationType::Query, "query"),
            (OperationType::Mutation, "mutation"),
            (OperationType::Subscription, "subscription"),
        ];
        for (op_type, query) in cases {
            let ctx = GraphQLContext::new(
                Some("op".to_string()),
                op_type.clone(),
                query.to_string(),
                Variables::default(),
            );
            assert_eq!(ctx.operation_type, op_type);
        }
    }

    #[test]
    fn test_variable_pattern_debug() {
        assert!(format!("{:?}", VariablePattern::Any).contains("Any"));
    }

    #[test]
    fn test_variable_matcher_debug() {
        assert!(format!("{:?}", VariableMatcher::new()).contains("VariableMatcher"));
    }
}