Skip to main content

mockforge_graphql/
handlers.rs

1//! GraphQL Handler System
2//!
3//! Provides a flexible handler-based system for GraphQL operations, similar to the WebSocket handler architecture.
4//! Handlers can intercept and customize query, mutation, and subscription resolution.
5
6use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
7use async_trait::async_trait;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use thiserror::Error;
12
/// Result type for handler operations.
///
/// Alias for `Result<T, HandlerError>`; the handler hooks and registry methods in this
/// module return it.
pub type HandlerResult<T> = Result<T, HandlerError>;
15
/// Errors that can occur during handler execution.
///
/// Every variant except [`HandlerError::JsonError`] carries a free-form message that is
/// interpolated into the `Display` output.
#[derive(Debug, Error)]
pub enum HandlerError {
    /// Error sending response
    #[error("Send error: {0}")]
    SendError(String),

    /// JSON serialization/deserialization error.
    ///
    /// Thanks to `#[from]`, a `serde_json::Error` converts into this variant automatically
    /// when propagated with `?`.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Operation matching error
    #[error("Operation error: {0}")]
    OperationError(String),

    /// Upstream passthrough error.
    ///
    /// Returned by `HandlerRegistry::passthrough` when no upstream URL is configured, the
    /// HTTP request fails, or the response body cannot be decoded as JSON.
    #[error("Upstream error: {0}")]
    UpstreamError(String),

    /// Generic handler error; the message is displayed verbatim with no prefix.
    #[error("{0}")]
    Generic(String),
}
39
/// Context for GraphQL handler execution.
///
/// Bundles the description of a single incoming operation together with two string-keyed
/// maps (`metadata` and `data`) that handlers can read and write.
pub struct GraphQLContext {
    /// Operation name (query/mutation name); `None` when the operation is anonymous
    pub operation_name: Option<String>,

    /// Operation type (query, mutation, subscription)
    pub operation_type: OperationType,

    /// GraphQL query string
    pub query: String,

    /// Variables passed to the operation
    pub variables: Variables,

    /// Request metadata (headers, etc.); empty right after `GraphQLContext::new`
    pub metadata: HashMap<String, String>,

    /// Custom data storage for handlers; empty right after `GraphQLContext::new`
    pub data: HashMap<String, serde_json::Value>,
}
60
/// Type of GraphQL operation.
///
/// Passed to `GraphQLHandler::handles_operation` so a handler can restrict itself to
/// particular kinds of operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OperationType {
    /// Query operation
    Query,
    /// Mutation operation
    Mutation,
    /// Subscription operation
    Subscription,
}
71
72impl GraphQLContext {
73    /// Create a new GraphQL context
74    pub fn new(
75        operation_name: Option<String>,
76        operation_type: OperationType,
77        query: String,
78        variables: Variables,
79    ) -> Self {
80        Self {
81            operation_name,
82            operation_type,
83            query,
84            variables,
85            metadata: HashMap::new(),
86            data: HashMap::new(),
87        }
88    }
89
90    /// Get a variable value
91    pub fn get_variable(&self, name: &str) -> Option<&Value> {
92        self.variables.get(&Name::new(name))
93    }
94
95    /// Set custom data
96    pub fn set_data(&mut self, key: String, value: serde_json::Value) {
97        self.data.insert(key, value);
98    }
99
100    /// Get custom data
101    pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
102        self.data.get(key)
103    }
104
105    /// Set metadata
106    pub fn set_metadata(&mut self, key: String, value: String) {
107        self.metadata.insert(key, value);
108    }
109
110    /// Get metadata
111    pub fn get_metadata(&self, key: &str) -> Option<&String> {
112        self.metadata.get(key)
113    }
114}
115
116/// Trait for handling GraphQL operations
117#[async_trait]
118pub trait GraphQLHandler: Send + Sync {
119    /// Called before query/mutation execution
120    /// Return None to proceed with default resolution, Some(Response) to override
121    async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
122        Ok(None)
123    }
124
125    /// Called after successful query/mutation execution
126    /// Allows modification of the response
127    async fn after_operation(
128        &self,
129        _ctx: &GraphQLContext,
130        response: Response,
131    ) -> HandlerResult<Response> {
132        Ok(response)
133    }
134
135    /// Called when an error occurs
136    async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
137        let server_error = ServerError::new(error, None);
138        Ok(Response::from_errors(vec![server_error]))
139    }
140
141    /// Check if this handler should handle the given operation
142    fn handles_operation(
143        &self,
144        operation_name: Option<&str>,
145        _operation_type: &OperationType,
146    ) -> bool {
147        // Default: handle all operations
148        operation_name.is_some()
149    }
150
151    /// Priority of this handler (higher = executes first)
152    fn priority(&self) -> i32 {
153        0
154    }
155}
156
157/// Registry for managing GraphQL handlers
158pub struct HandlerRegistry {
159    handlers: Vec<Arc<dyn GraphQLHandler>>,
160    /// Upstream GraphQL server URL for passthrough
161    upstream_url: Option<String>,
162}
163
164impl HandlerRegistry {
165    /// Create a new handler registry
166    pub fn new() -> Self {
167        Self {
168            handlers: Vec::new(),
169            upstream_url: None,
170        }
171    }
172
173    /// Create a handler registry with upstream URL
174    pub fn with_upstream(upstream_url: Option<String>) -> Self {
175        Self {
176            handlers: Vec::new(),
177            upstream_url,
178        }
179    }
180
181    /// Register a handler
182    pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
183        self.handlers.push(Arc::new(handler));
184        // Sort by priority (highest first)
185        self.handlers.sort_by(|a, b| b.priority().cmp(&a.priority()));
186    }
187
188    /// Get handlers for a specific operation
189    pub fn get_handlers(
190        &self,
191        operation_name: Option<&str>,
192        operation_type: &OperationType,
193    ) -> Vec<Arc<dyn GraphQLHandler>> {
194        self.handlers
195            .iter()
196            .filter(|h| h.handles_operation(operation_name, operation_type))
197            .cloned()
198            .collect()
199    }
200
201    /// Execute handlers for an operation
202    pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
203        let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
204
205        for handler in handlers {
206            if let Some(response) = handler.on_operation(ctx).await? {
207                return Ok(Some(response));
208            }
209        }
210
211        Ok(None)
212    }
213
214    /// Execute after_operation hooks
215    pub async fn after_operation(
216        &self,
217        ctx: &GraphQLContext,
218        mut response: Response,
219    ) -> HandlerResult<Response> {
220        let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
221
222        for handler in handlers {
223            response = handler.after_operation(ctx, response).await?;
224        }
225
226        Ok(response)
227    }
228
229    /// Passthrough request to upstream server
230    pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
231        let upstream = self
232            .upstream_url
233            .as_ref()
234            .ok_or_else(|| HandlerError::UpstreamError("No upstream URL configured".to_string()))?;
235
236        let client = reqwest::Client::new();
237        let body = json!({
238            "query": request.query.clone(),
239            "variables": request.variables.clone(),
240            "operationName": request.operation_name.clone(),
241        });
242
243        let resp = client
244            .post(upstream)
245            .json(&body)
246            .send()
247            .await
248            .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
249
250        let response_data: serde_json::Value =
251            resp.json().await.map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
252
253        // Convert JSON response to GraphQL Response
254        let errors: Vec<ServerError> = response_data
255            .get("errors")
256            .and_then(|e| e.as_array())
257            .map(|arr| {
258                arr.iter()
259                    .map(|e| {
260                        let msg = e
261                            .get("message")
262                            .and_then(|m| m.as_str())
263                            .unwrap_or("Upstream GraphQL error");
264                        ServerError::new(msg.to_string(), None)
265                    })
266                    .collect()
267            })
268            .unwrap_or_default();
269
270        let data = response_data
271            .get("data")
272            .map(|d| json_to_graphql_value(d))
273            .unwrap_or(Value::Null);
274
275        let mut response = Response::new(data);
276        response.errors = errors;
277        Ok(response)
278    }
279
280    /// Get upstream URL
281    pub fn upstream_url(&self) -> Option<&str> {
282        self.upstream_url.as_deref()
283    }
284}
285
286impl Default for HandlerRegistry {
287    fn default() -> Self {
288        Self::new()
289    }
290}
291
292/// Convert a `serde_json::Value` to an `async_graphql::Value`
293fn json_to_graphql_value(json: &serde_json::Value) -> Value {
294    match json {
295        serde_json::Value::Null => Value::Null,
296        serde_json::Value::Bool(b) => Value::Boolean(*b),
297        serde_json::Value::Number(n) => {
298            if let Some(i) = n.as_i64() {
299                Value::Number(i.into())
300            } else if let Some(f) = n.as_f64() {
301                Value::Number(async_graphql::Number::from_f64(f).unwrap_or_else(|| 0i32.into()))
302            } else {
303                Value::Null
304            }
305        }
306        serde_json::Value::String(s) => Value::String(s.clone()),
307        serde_json::Value::Array(arr) => {
308            Value::List(arr.iter().map(json_to_graphql_value).collect())
309        }
310        serde_json::Value::Object(obj) => {
311            let map = obj.iter().map(|(k, v)| (Name::new(k), json_to_graphql_value(v))).collect();
312            Value::Object(map)
313        }
314    }
315}
316
317/// Variable matcher for filtering operations by variable values
318#[derive(Debug, Clone)]
319pub struct VariableMatcher {
320    patterns: HashMap<String, VariablePattern>,
321}
322
323impl VariableMatcher {
324    /// Create a new variable matcher
325    pub fn new() -> Self {
326        Self {
327            patterns: HashMap::new(),
328        }
329    }
330
331    /// Add a pattern for a variable
332    pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
333        self.patterns.insert(name, pattern);
334        self
335    }
336
337    /// Check if variables match the patterns
338    pub fn matches(&self, variables: &Variables) -> bool {
339        for (name, pattern) in &self.patterns {
340            if !pattern.matches(variables.get(&Name::new(name))) {
341                return false;
342            }
343        }
344        true
345    }
346}
347
348impl Default for VariableMatcher {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
/// Pattern for matching variable values.
///
/// Evaluated by `VariablePattern::matches` against an optional value, where `None` means
/// the variable was not supplied at all.
#[derive(Debug, Clone)]
pub enum VariablePattern {
    /// Exact value match; never matches an absent variable
    Exact(Value),
    /// Regular expression match (for strings); non-string and absent values never match
    Regex(String),
    /// Any value (always matches), including an absent variable
    Any,
    /// Value must be present; an explicitly supplied null counts as present
    Present,
    /// Value must be null or absent
    Null,
}
368
369impl VariablePattern {
370    /// Check if a value matches this pattern
371    pub fn matches(&self, value: Option<&Value>) -> bool {
372        match (self, value) {
373            (VariablePattern::Any, _) => true,
374            (VariablePattern::Present, Some(_)) => true,
375            (VariablePattern::Present, None) => false,
376            (VariablePattern::Null, None) | (VariablePattern::Null, Some(Value::Null)) => true,
377            (VariablePattern::Null, Some(_)) => false,
378            (VariablePattern::Exact(expected), Some(actual)) => expected == actual,
379            (VariablePattern::Exact(_), None) => false,
380            (VariablePattern::Regex(pattern), Some(Value::String(s))) => {
381                regex::Regex::new(pattern).ok().map(|re| re.is_match(s)).unwrap_or(false)
382            }
383            (VariablePattern::Regex(_), _) => false,
384        }
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that only accepts, and answers, a single named operation.
    struct TestHandler {
        operation_name: String,
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            if ctx.operation_name.as_deref() == Some(&self.operation_name) {
                // Return a simple null response for testing
                Ok(Some(Response::new(Value::Null)))
            } else {
                Ok(None)
            }
        }

        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(&self.operation_name)
        }
    }

    /// Handler that customizes nothing but its priority.
    struct PriorityHandler {
        priority: i32,
    }

    #[async_trait]
    impl GraphQLHandler for PriorityHandler {
        fn priority(&self) -> i32 {
            self.priority
        }
    }

    /// Build a `TestHandler` bound to `name`.
    fn handler_for(name: &str) -> TestHandler {
        TestHandler {
            operation_name: name.to_string(),
        }
    }

    /// Build a context for the operation `name` with the given type and variables.
    fn context_with(
        name: &str,
        operation_type: OperationType,
        variables: Variables,
    ) -> GraphQLContext {
        GraphQLContext::new(
            Some(name.to_string()),
            operation_type,
            "query { user { id } }".to_string(),
            variables,
        )
    }

    /// Build a query context for the operation `name` without variables.
    fn query_context(name: &str) -> GraphQLContext {
        context_with(name, OperationType::Query, Variables::default())
    }

    /// Build a variable set holding a single string variable.
    fn string_var(name: &str, value: &str) -> Variables {
        let mut vars = Variables::default();
        vars.insert(Name::new(name), Value::String(value.to_string()));
        vars
    }

    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert!(registry.handlers.is_empty());
        assert!(registry.upstream_url.is_none());
    }

    #[test]
    fn test_handler_registry_with_upstream() {
        let registry =
            HandlerRegistry::with_upstream(Some("http://example.com/graphql".to_string()));
        assert_eq!(registry.upstream_url(), Some("http://example.com/graphql"));
    }

    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));
        assert_eq!(registry.handlers.len(), 1);
    }

    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));

        let result = registry.execute_operation(&query_context("getUser")).await;
        assert!(matches!(result, Ok(Some(_))));
    }

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);
        assert!(matcher.matches(&string_var("id", "123")));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let matcher = VariableMatcher::new().with_pattern(
            "id".to_string(),
            VariablePattern::Exact(Value::String("123".to_string())),
        );

        assert!(matcher.matches(&string_var("id", "123")));
        assert!(!matcher.matches(&string_var("id", "456")));
    }

    #[test]
    fn test_variable_pattern_present() {
        assert!(VariablePattern::Present.matches(Some(&Value::String("test".to_string()))));
        assert!(!VariablePattern::Present.matches(None));
    }

    #[test]
    fn test_variable_pattern_null() {
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&Value::String("test".to_string()))));
    }

    #[test]
    fn test_graphql_context_new() {
        let ctx = query_context("getUser");

        assert_eq!(ctx.operation_name, Some("getUser".to_string()));
        assert_eq!(ctx.operation_type, OperationType::Query);
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = query_context("getUser");

        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        assert_eq!(ctx.get_metadata("Authorization"), Some(&"Bearer token".to_string()));
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = query_context("getUser");

        ctx.set_data("custom_key".to_string(), json!({"test": "value"}));
        assert_eq!(ctx.get_data("custom_key"), Some(&json!({"test": "value"})));
    }

    #[test]
    fn test_operation_type_eq() {
        assert_eq!(OperationType::Query, OperationType::Query);
        assert_ne!(OperationType::Query, OperationType::Mutation);
        assert_ne!(OperationType::Mutation, OperationType::Subscription);
    }

    #[test]
    fn test_operation_type_clone() {
        let op = OperationType::Query;
        let cloned = op.clone();
        assert_eq!(op, cloned);
    }

    #[test]
    fn test_handler_error_display() {
        let err = HandlerError::SendError("test error".to_string());
        assert!(err.to_string().contains("Send error"));

        let err = HandlerError::OperationError("op error".to_string());
        assert!(err.to_string().contains("Operation error"));

        let err = HandlerError::UpstreamError("upstream error".to_string());
        assert!(err.to_string().contains("Upstream error"));

        let err = HandlerError::Generic("generic error".to_string());
        assert!(err.to_string().contains("generic error"));
    }

    #[test]
    fn test_handler_error_from_json() {
        let json_err = serde_json::from_str::<i32>("not a number").unwrap_err();
        let err: HandlerError = json_err.into();
        assert!(matches!(err, HandlerError::JsonError(_)));
    }

    #[test]
    fn test_variable_matcher_default() {
        let matcher = VariableMatcher::default();
        assert!(matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_variable_pattern_regex() {
        let pattern = VariablePattern::Regex(r"^user-\d+$".to_string());
        assert!(pattern.matches(Some(&Value::String("user-123".to_string()))));
        assert!(!pattern.matches(Some(&Value::String("invalid".to_string()))));
        assert!(!pattern.matches(None));
    }

    #[test]
    fn test_variable_matcher_multiple_patterns() {
        let matcher = VariableMatcher::new()
            .with_pattern("id".to_string(), VariablePattern::Present)
            .with_pattern("name".to_string(), VariablePattern::Any);

        // "name" is absent, but `Any` accepts absent variables too.
        assert!(matcher.matches(&string_var("id", "123")));
    }

    #[test]
    fn test_variable_matcher_fails_on_missing() {
        let matcher =
            VariableMatcher::new().with_pattern("required".to_string(), VariablePattern::Present);

        assert!(!matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_graphql_context_get_variable() {
        let ctx = context_with("getUser", OperationType::Query, string_var("userId", "123"));

        assert!(ctx.get_variable("userId").is_some());
        assert!(ctx.get_variable("nonexistent").is_none());
    }

    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.upstream_url().is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_no_match() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));

        let result = registry.execute_operation(&query_context("getProduct")).await;
        assert!(matches!(result, Ok(None)));
    }

    #[tokio::test]
    async fn test_handler_registry_after_operation() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));

        let response = Response::new(Value::Null);
        let result = registry.after_operation(&query_context("getUser"), response).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(handler_for("getUser"));
        registry.register(handler_for("getProduct"));

        let handlers = registry.get_handlers(Some("getUser"), &OperationType::Query);
        assert_eq!(handlers.len(), 1);

        let handlers = registry.get_handlers(Some("unknown"), &OperationType::Query);
        assert_eq!(handlers.len(), 0);
    }

    #[test]
    fn test_handler_priority() {
        let handler = PriorityHandler { priority: 10 };
        assert_eq!(handler.priority(), 10);

        // The registry must keep handlers ordered by descending priority regardless of
        // the order in which they were registered.
        let mut registry = HandlerRegistry::new();
        for priority in [1, 10, 5] {
            registry.register(PriorityHandler { priority });
        }
        let order: Vec<i32> = registry.handlers.iter().map(|h| h.priority()).collect();
        assert_eq!(order, vec![10, 5, 1]);
    }

    #[test]
    fn test_context_all_operation_types() {
        let operation_types = [
            OperationType::Query,
            OperationType::Mutation,
            OperationType::Subscription,
        ];

        for operation_type in operation_types {
            let ctx = context_with("op", operation_type.clone(), Variables::default());
            assert_eq!(ctx.operation_type, operation_type);
        }
    }

    #[test]
    fn test_variable_pattern_debug() {
        let pattern = VariablePattern::Any;
        let debug = format!("{:?}", pattern);
        assert!(debug.contains("Any"));
    }

    #[test]
    fn test_variable_matcher_debug() {
        let matcher = VariableMatcher::new();
        let debug = format!("{:?}", matcher);
        assert!(debug.contains("VariableMatcher"));
    }

    #[test]
    fn test_json_to_graphql_value_nested() {
        let converted = json_to_graphql_value(&json!({"id": 7, "tags": ["a", null], "ok": true}));
        let fields = match converted {
            Value::Object(fields) => fields,
            other => panic!("expected an object, got {:?}", other),
        };

        assert_eq!(fields.get(&Name::new("id")), Some(&Value::Number(7i64.into())));
        assert_eq!(fields.get(&Name::new("ok")), Some(&Value::Boolean(true)));
        assert_eq!(
            fields.get(&Name::new("tags")),
            Some(&Value::List(vec![Value::String("a".to_string()), Value::Null]))
        );
    }
}