1use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
7use async_trait::async_trait;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use thiserror::Error;
12
13pub type HandlerResult<T> = Result<T, HandlerError>;
15
16#[derive(Debug, Error)]
18pub enum HandlerError {
19 #[error("Send error: {0}")]
21 SendError(String),
22
23 #[error("JSON error: {0}")]
25 JsonError(#[from] serde_json::Error),
26
27 #[error("Operation error: {0}")]
29 OperationError(String),
30
31 #[error("Upstream error: {0}")]
33 UpstreamError(String),
34
35 #[error("{0}")]
37 Generic(String),
38}
39
40pub struct GraphQLContext {
42 pub operation_name: Option<String>,
44
45 pub operation_type: OperationType,
47
48 pub query: String,
50
51 pub variables: Variables,
53
54 pub metadata: HashMap<String, String>,
56
57 pub data: HashMap<String, serde_json::Value>,
59}
60
/// The three kinds of GraphQL operation.
///
/// This is a fieldless enum, so it is `Copy` (no explicit clones needed) and
/// `Hash` (usable as a `HashMap`/`HashSet` key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OperationType {
    /// A `query` operation.
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
71
72impl GraphQLContext {
73 pub fn new(
75 operation_name: Option<String>,
76 operation_type: OperationType,
77 query: String,
78 variables: Variables,
79 ) -> Self {
80 Self {
81 operation_name,
82 operation_type,
83 query,
84 variables,
85 metadata: HashMap::new(),
86 data: HashMap::new(),
87 }
88 }
89
90 pub fn get_variable(&self, name: &str) -> Option<&Value> {
92 self.variables.get(&Name::new(name))
93 }
94
95 pub fn set_data(&mut self, key: String, value: serde_json::Value) {
97 self.data.insert(key, value);
98 }
99
100 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
102 self.data.get(key)
103 }
104
105 pub fn set_metadata(&mut self, key: String, value: String) {
107 self.metadata.insert(key, value);
108 }
109
110 pub fn get_metadata(&self, key: &str) -> Option<&String> {
112 self.metadata.get(key)
113 }
114}
115
116#[async_trait]
118pub trait GraphQLHandler: Send + Sync {
119 async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
122 Ok(None)
123 }
124
125 async fn after_operation(
128 &self,
129 _ctx: &GraphQLContext,
130 response: Response,
131 ) -> HandlerResult<Response> {
132 Ok(response)
133 }
134
135 async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
137 let server_error = ServerError::new(error, None);
138 Ok(Response::from_errors(vec![server_error]))
139 }
140
141 fn handles_operation(
143 &self,
144 operation_name: Option<&str>,
145 _operation_type: &OperationType,
146 ) -> bool {
147 operation_name.is_some()
149 }
150
151 fn priority(&self) -> i32 {
153 0
154 }
155}
156
157pub struct HandlerRegistry {
159 handlers: Vec<Arc<dyn GraphQLHandler>>,
160 upstream_url: Option<String>,
162}
163
164impl HandlerRegistry {
165 pub fn new() -> Self {
167 Self {
168 handlers: Vec::new(),
169 upstream_url: None,
170 }
171 }
172
173 pub fn with_upstream(upstream_url: Option<String>) -> Self {
175 Self {
176 handlers: Vec::new(),
177 upstream_url,
178 }
179 }
180
181 pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
183 self.handlers.push(Arc::new(handler));
184 self.handlers.sort_by(|a, b| b.priority().cmp(&a.priority()));
186 }
187
188 pub fn get_handlers(
190 &self,
191 operation_name: Option<&str>,
192 operation_type: &OperationType,
193 ) -> Vec<Arc<dyn GraphQLHandler>> {
194 self.handlers
195 .iter()
196 .filter(|h| h.handles_operation(operation_name, operation_type))
197 .cloned()
198 .collect()
199 }
200
201 pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
203 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
204
205 for handler in handlers {
206 if let Some(response) = handler.on_operation(ctx).await? {
207 return Ok(Some(response));
208 }
209 }
210
211 Ok(None)
212 }
213
214 pub async fn after_operation(
216 &self,
217 ctx: &GraphQLContext,
218 mut response: Response,
219 ) -> HandlerResult<Response> {
220 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
221
222 for handler in handlers {
223 response = handler.after_operation(ctx, response).await?;
224 }
225
226 Ok(response)
227 }
228
229 pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
231 let upstream = self
232 .upstream_url
233 .as_ref()
234 .ok_or_else(|| HandlerError::UpstreamError("No upstream URL configured".to_string()))?;
235
236 let client = reqwest::Client::new();
237 let body = json!({
238 "query": request.query.clone(),
239 "variables": request.variables.clone(),
240 "operationName": request.operation_name.clone(),
241 });
242
243 let resp = client
244 .post(upstream)
245 .json(&body)
246 .send()
247 .await
248 .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
249
250 let response_data: serde_json::Value =
251 resp.json().await.map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
252
253 let errors: Vec<ServerError> = response_data
255 .get("errors")
256 .and_then(|e| e.as_array())
257 .map(|arr| {
258 arr.iter()
259 .map(|e| {
260 let msg = e
261 .get("message")
262 .and_then(|m| m.as_str())
263 .unwrap_or("Upstream GraphQL error");
264 ServerError::new(msg.to_string(), None)
265 })
266 .collect()
267 })
268 .unwrap_or_default();
269
270 let data = response_data
271 .get("data")
272 .map(|d| json_to_graphql_value(d))
273 .unwrap_or(Value::Null);
274
275 let mut response = Response::new(data);
276 response.errors = errors;
277 Ok(response)
278 }
279
280 pub fn upstream_url(&self) -> Option<&str> {
282 self.upstream_url.as_deref()
283 }
284}
285
286impl Default for HandlerRegistry {
287 fn default() -> Self {
288 Self::new()
289 }
290}
291
292fn json_to_graphql_value(json: &serde_json::Value) -> Value {
294 match json {
295 serde_json::Value::Null => Value::Null,
296 serde_json::Value::Bool(b) => Value::Boolean(*b),
297 serde_json::Value::Number(n) => {
298 if let Some(i) = n.as_i64() {
299 Value::Number(i.into())
300 } else if let Some(f) = n.as_f64() {
301 Value::Number(async_graphql::Number::from_f64(f).unwrap_or_else(|| 0i32.into()))
302 } else {
303 Value::Null
304 }
305 }
306 serde_json::Value::String(s) => Value::String(s.clone()),
307 serde_json::Value::Array(arr) => {
308 Value::List(arr.iter().map(json_to_graphql_value).collect())
309 }
310 serde_json::Value::Object(obj) => {
311 let map = obj.iter().map(|(k, v)| (Name::new(k), json_to_graphql_value(v))).collect();
312 Value::Object(map)
313 }
314 }
315}
316
317#[derive(Debug, Clone)]
319pub struct VariableMatcher {
320 patterns: HashMap<String, VariablePattern>,
321}
322
323impl VariableMatcher {
324 pub fn new() -> Self {
326 Self {
327 patterns: HashMap::new(),
328 }
329 }
330
331 pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
333 self.patterns.insert(name, pattern);
334 self
335 }
336
337 pub fn matches(&self, variables: &Variables) -> bool {
339 for (name, pattern) in &self.patterns {
340 if !pattern.matches(variables.get(&Name::new(name))) {
341 return false;
342 }
343 }
344 true
345 }
346}
347
348impl Default for VariableMatcher {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354#[derive(Debug, Clone)]
356pub enum VariablePattern {
357 Exact(Value),
359 Regex(String),
361 Any,
363 Present,
365 Null,
367}
368
369impl VariablePattern {
370 pub fn matches(&self, value: Option<&Value>) -> bool {
372 match (self, value) {
373 (VariablePattern::Any, _) => true,
374 (VariablePattern::Present, Some(_)) => true,
375 (VariablePattern::Present, None) => false,
376 (VariablePattern::Null, None) | (VariablePattern::Null, Some(Value::Null)) => true,
377 (VariablePattern::Null, Some(_)) => false,
378 (VariablePattern::Exact(expected), Some(actual)) => expected == actual,
379 (VariablePattern::Exact(_), None) => false,
380 (VariablePattern::Regex(pattern), Some(Value::String(s))) => {
381 regex::Regex::new(pattern).ok().map(|re| re.is_match(s)).unwrap_or(false)
382 }
383 (VariablePattern::Regex(_), _) => false,
384 }
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that reacts to exactly one operation name.
    struct TestHandler {
        operation_name: String,
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            if ctx.operation_name.as_deref() == Some(self.operation_name.as_str()) {
                Ok(Some(Response::new(Value::Null)))
            } else {
                Ok(None)
            }
        }

        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(self.operation_name.as_str())
        }
    }

    /// Builds a [`TestHandler`] bound to `name`.
    fn handler_for(name: &str) -> TestHandler {
        TestHandler {
            operation_name: name.to_string(),
        }
    }

    /// Builds a context with the given operation name, type, query and
    /// variables.
    fn ctx_with(
        operation: &str,
        operation_type: OperationType,
        query: &str,
        variables: Variables,
    ) -> GraphQLContext {
        GraphQLContext::new(
            Some(operation.to_string()),
            operation_type,
            query.to_string(),
            variables,
        )
    }

    /// The `getUser` query context shared by most tests.
    fn get_user_ctx() -> GraphQLContext {
        ctx_with(
            "getUser",
            OperationType::Query,
            "query { user { id } }",
            Variables::default(),
        )
    }

    /// Variables containing the single entry `name` = string `value`.
    fn string_var(name: &str, value: &str) -> Variables {
        let mut vars = Variables::default();
        vars.insert(Name::new(name), Value::String(value.to_string()));
        vars
    }

    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert_eq!(registry.handlers.len(), 0);
        assert!(registry.upstream_url.is_none());
    }

    #[test]
    fn test_handler_registry_with_upstream() {
        let registry =
            HandlerRegistry::with_upstream(Some("http://example.com/graphql".to_string()));
        assert_eq!(registry.upstream_url(), Some("http://example.com/graphql"));
    }

    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));
        assert_eq!(registry.handlers.len(), 1);
    }

    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));

        let result = registry.execute_operation(&get_user_ctx()).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);

        assert!(matcher.matches(&string_var("id", "123")));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let matcher = VariableMatcher::new().with_pattern(
            "id".to_string(),
            VariablePattern::Exact(Value::String("123".to_string())),
        );

        assert!(matcher.matches(&string_var("id", "123")));
        assert!(!matcher.matches(&string_var("id", "456")));
    }

    #[test]
    fn test_variable_pattern_present() {
        assert!(VariablePattern::Present.matches(Some(&Value::String("test".to_string()))));
        assert!(!VariablePattern::Present.matches(None));
    }

    #[test]
    fn test_variable_pattern_null() {
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&Value::String("test".to_string()))));
    }

    #[test]
    fn test_graphql_context_new() {
        let ctx = get_user_ctx();

        assert_eq!(ctx.operation_name, Some("getUser".to_string()));
        assert_eq!(ctx.operation_type, OperationType::Query);
        assert_eq!(ctx.query, "query { user { id } }");
        assert!(ctx.metadata.is_empty());
        assert!(ctx.data.is_empty());
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = get_user_ctx();

        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        assert_eq!(ctx.get_metadata("Authorization"), Some(&"Bearer token".to_string()));
        assert_eq!(ctx.get_metadata("missing"), None);
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = get_user_ctx();

        ctx.set_data("custom_key".to_string(), json!({"test": "value"}));
        assert_eq!(ctx.get_data("custom_key"), Some(&json!({"test": "value"})));
        assert_eq!(ctx.get_data("missing"), None);
    }

    #[test]
    fn test_operation_type_eq() {
        assert_eq!(OperationType::Query, OperationType::Query);
        assert_ne!(OperationType::Query, OperationType::Mutation);
        assert_ne!(OperationType::Mutation, OperationType::Subscription);
    }

    #[test]
    fn test_operation_type_clone() {
        let op = OperationType::Query;
        // Call through the trait so the `Clone` impl itself is exercised.
        let cloned = Clone::clone(&op);
        assert_eq!(op, cloned);
    }

    #[test]
    fn test_handler_error_display() {
        let err = HandlerError::SendError("test error".to_string());
        assert!(err.to_string().contains("Send error"));

        let err = HandlerError::OperationError("op error".to_string());
        assert!(err.to_string().contains("Operation error"));

        let err = HandlerError::UpstreamError("upstream error".to_string());
        assert!(err.to_string().contains("Upstream error"));

        let err = HandlerError::Generic("generic error".to_string());
        assert!(err.to_string().contains("generic error"));
    }

    #[test]
    fn test_handler_error_from_json() {
        let json_err = serde_json::from_str::<i32>("not a number").unwrap_err();
        let err: HandlerError = json_err.into();
        assert!(matches!(err, HandlerError::JsonError(_)));
    }

    #[test]
    fn test_variable_matcher_default() {
        let matcher = VariableMatcher::default();
        assert!(matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_variable_pattern_regex() {
        let pattern = VariablePattern::Regex(r"^user-\d+$".to_string());
        assert!(pattern.matches(Some(&Value::String("user-123".to_string()))));
        assert!(!pattern.matches(Some(&Value::String("invalid".to_string()))));
        assert!(!pattern.matches(None));
        // Only string values can match a regex pattern.
        assert!(!pattern.matches(Some(&Value::Boolean(true))));

        // An expression that does not compile never matches.
        let broken = VariablePattern::Regex("(".to_string());
        assert!(!broken.matches(Some(&Value::String("(".to_string()))));
    }

    #[test]
    fn test_variable_matcher_multiple_patterns() {
        let matcher = VariableMatcher::new()
            .with_pattern("id".to_string(), VariablePattern::Present)
            .with_pattern("name".to_string(), VariablePattern::Any);

        assert!(matcher.matches(&string_var("id", "123")));
    }

    #[test]
    fn test_variable_matcher_fails_on_missing() {
        let matcher =
            VariableMatcher::new().with_pattern("required".to_string(), VariablePattern::Present);

        assert!(!matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_graphql_context_get_variable() {
        let ctx = ctx_with(
            "getUser",
            OperationType::Query,
            "query { user { id } }",
            string_var("userId", "123"),
        );

        assert_eq!(
            ctx.get_variable("userId"),
            Some(&Value::String("123".to_string()))
        );
        assert!(ctx.get_variable("nonexistent").is_none());
    }

    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.upstream_url().is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_no_match() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));

        let ctx = ctx_with(
            "getProduct",
            OperationType::Query,
            "query { product { id } }",
            Variables::default(),
        );

        let result = registry.execute_operation(&ctx).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_after_operation() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));

        let response = Response::new(Value::Null);
        let result = registry.after_operation(&get_user_ctx(), response).await;
        assert!(result.is_ok());
        // The default `after_operation` hook hands the response back as-is.
        assert!(result.unwrap().errors.is_empty());
    }

    #[tokio::test]
    async fn test_passthrough_without_upstream() {
        let registry = HandlerRegistry::new();
        let result = registry.passthrough(&Request::new("{ user { id } }")).await;
        assert!(matches!(result, Err(HandlerError::UpstreamError(_))));
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));
        registry.register(handler_for("getProduct"));

        let handlers = registry.get_handlers(Some("getUser"), &OperationType::Query);
        assert_eq!(handlers.len(), 1);

        let handlers = registry.get_handlers(Some("unknown"), &OperationType::Query);
        assert_eq!(handlers.len(), 0);
    }

    #[test]
    fn test_handler_priority() {
        struct PriorityHandler {
            priority: i32,
        }

        #[async_trait]
        impl GraphQLHandler for PriorityHandler {
            fn priority(&self) -> i32 {
                self.priority
            }
        }

        let handler = PriorityHandler { priority: 10 };
        assert_eq!(handler.priority(), 10);

        // Default `handles_operation`: named operations only.
        assert!(handler.handles_operation(Some("op"), &OperationType::Query));
        assert!(!handler.handles_operation(None, &OperationType::Query));

        // The registry keeps handlers sorted by descending priority,
        // regardless of registration order.
        let mut registry = HandlerRegistry::new();
        for priority in vec![1, 10, 5] {
            registry.register(PriorityHandler { priority });
        }
        let order: Vec<i32> = registry.handlers.iter().map(|h| h.priority()).collect();
        assert_eq!(order, vec![10, 5, 1]);
    }

    #[test]
    fn test_context_all_operation_types() {
        let query_ctx = ctx_with("op", OperationType::Query, "query", Variables::default());
        assert_eq!(query_ctx.operation_type, OperationType::Query);

        let mutation_ctx =
            ctx_with("op", OperationType::Mutation, "mutation", Variables::default());
        assert_eq!(mutation_ctx.operation_type, OperationType::Mutation);

        let subscription_ctx = ctx_with(
            "op",
            OperationType::Subscription,
            "subscription",
            Variables::default(),
        );
        assert_eq!(subscription_ctx.operation_type, OperationType::Subscription);
    }

    #[test]
    fn test_json_to_graphql_value() {
        assert_eq!(json_to_graphql_value(&json!(null)), Value::Null);
        assert_eq!(json_to_graphql_value(&json!(true)), Value::Boolean(true));
        assert_eq!(
            json_to_graphql_value(&json!(42)),
            Value::Number(42i64.into())
        );
        assert_eq!(
            json_to_graphql_value(&json!(-7)),
            Value::Number((-7i64).into())
        );
        assert_eq!(
            json_to_graphql_value(&json!("hi")),
            Value::String("hi".to_string())
        );
        assert_eq!(
            json_to_graphql_value(&json!([1, "a"])),
            Value::List(vec![
                Value::Number(1i64.into()),
                Value::String("a".to_string()),
            ])
        );

        match json_to_graphql_value(&json!({"id": "1"})) {
            Value::Object(fields) => {
                assert_eq!(fields.len(), 1);
                assert_eq!(fields.get("id"), Some(&Value::String("1".to_string())));
            }
            other => panic!("expected an object, got {:?}", other),
        }
    }

    #[test]
    fn test_variable_pattern_debug() {
        let pattern = VariablePattern::Any;
        let debug = format!("{:?}", pattern);
        assert!(debug.contains("Any"));
    }

    #[test]
    fn test_variable_matcher_debug() {
        let matcher = VariableMatcher::new();
        let debug = format!("{:?}", matcher);
        assert!(debug.contains("VariableMatcher"));
    }
}