Skip to main content

mockforge_graphql/
handlers.rs

1//! GraphQL Handler System
2//!
3//! Provides a flexible handler-based system for GraphQL operations, similar to the WebSocket handler architecture.
4//! Handlers can intercept and customize query, mutation, and subscription resolution.
5
6use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
7use async_trait::async_trait;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use thiserror::Error;
12
/// Result type for handler operations.
///
/// Shorthand for `Result<T, HandlerError>`; used as the return type of the
/// handler hooks and registry methods in this module.
pub type HandlerResult<T> = Result<T, HandlerError>;
15
/// Errors that can occur during handler execution.
///
/// The `Display` text of each variant is produced by `thiserror` from the
/// `#[error(...)]` attribute on that variant.
#[derive(Debug, Error)]
pub enum HandlerError {
    /// Error sending response; displayed as `Send error: <message>`.
    #[error("Send error: {0}")]
    SendError(String),

    /// JSON serialization/deserialization error.
    ///
    /// `#[from]` generates `From<serde_json::Error>`, so `?` converts
    /// `serde_json` failures into this variant automatically.
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),

    /// Operation matching error; displayed as `Operation error: <message>`.
    #[error("Operation error: {0}")]
    OperationError(String),

    /// Upstream passthrough error (missing upstream URL, transport failure,
    /// or an upstream body that could not be decoded as JSON).
    #[error("Upstream error: {0}")]
    UpstreamError(String),

    /// Generic handler error; displayed as the bare message.
    #[error("{0}")]
    Generic(String),
}
39
40/// Context for GraphQL handler execution
41pub struct GraphQLContext {
42    /// Operation name (query/mutation name)
43    pub operation_name: Option<String>,
44
45    /// Operation type (query, mutation, subscription)
46    pub operation_type: OperationType,
47
48    /// GraphQL query string
49    pub query: String,
50
51    /// Variables passed to the operation
52    pub variables: Variables,
53
54    /// Request metadata (headers, etc.)
55    pub metadata: HashMap<String, String>,
56
57    /// Custom data storage for handlers
58    pub data: HashMap<String, serde_json::Value>,
59}
60
/// Type of GraphQL operation.
///
/// Stored in [`GraphQLContext`] and passed to
/// [`GraphQLHandler::handles_operation`] so handlers can filter on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OperationType {
    /// Query operation
    Query,
    /// Mutation operation
    Mutation,
    /// Subscription operation
    Subscription,
}
71
72impl GraphQLContext {
73    /// Create a new GraphQL context
74    pub fn new(
75        operation_name: Option<String>,
76        operation_type: OperationType,
77        query: String,
78        variables: Variables,
79    ) -> Self {
80        Self {
81            operation_name,
82            operation_type,
83            query,
84            variables,
85            metadata: HashMap::new(),
86            data: HashMap::new(),
87        }
88    }
89
90    /// Get a variable value
91    pub fn get_variable(&self, name: &str) -> Option<&Value> {
92        self.variables.get(&Name::new(name))
93    }
94
95    /// Set custom data
96    pub fn set_data(&mut self, key: String, value: serde_json::Value) {
97        self.data.insert(key, value);
98    }
99
100    /// Get custom data
101    pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
102        self.data.get(key)
103    }
104
105    /// Set metadata
106    pub fn set_metadata(&mut self, key: String, value: String) {
107        self.metadata.insert(key, value);
108    }
109
110    /// Get metadata
111    pub fn get_metadata(&self, key: &str) -> Option<&String> {
112        self.metadata.get(key)
113    }
114}
115
116/// Trait for handling GraphQL operations
117#[async_trait]
118pub trait GraphQLHandler: Send + Sync {
119    /// Called before query/mutation execution
120    /// Return None to proceed with default resolution, Some(Response) to override
121    async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
122        Ok(None)
123    }
124
125    /// Called after successful query/mutation execution
126    /// Allows modification of the response
127    async fn after_operation(
128        &self,
129        _ctx: &GraphQLContext,
130        response: Response,
131    ) -> HandlerResult<Response> {
132        Ok(response)
133    }
134
135    /// Called when an error occurs
136    async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
137        let server_error = ServerError::new(error, None);
138        Ok(Response::from_errors(vec![server_error]))
139    }
140
141    /// Check if this handler should handle the given operation
142    fn handles_operation(
143        &self,
144        operation_name: Option<&str>,
145        _operation_type: &OperationType,
146    ) -> bool {
147        // Default: handle all operations
148        operation_name.is_some()
149    }
150
151    /// Priority of this handler (higher = executes first)
152    fn priority(&self) -> i32 {
153        0
154    }
155}
156
/// Registry for managing GraphQL handlers.
///
/// Holds the registered handlers (kept sorted by descending priority by
/// `register`) and the optional upstream endpoint used by `passthrough`.
pub struct HandlerRegistry {
    /// Registered handlers, shared via `Arc` so lookups can hand out clones.
    handlers: Vec<Arc<dyn GraphQLHandler>>,
    /// Upstream GraphQL server URL for passthrough
    upstream_url: Option<String>,
}
163
164impl HandlerRegistry {
165    /// Create a new handler registry
166    pub fn new() -> Self {
167        Self {
168            handlers: Vec::new(),
169            upstream_url: None,
170        }
171    }
172
173    /// Create a handler registry with upstream URL
174    pub fn with_upstream(upstream_url: Option<String>) -> Self {
175        Self {
176            handlers: Vec::new(),
177            upstream_url,
178        }
179    }
180
181    /// Register a handler
182    pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
183        self.handlers.push(Arc::new(handler));
184        // Sort by priority (highest first)
185        self.handlers.sort_by_key(|b| std::cmp::Reverse(b.priority()));
186    }
187
188    /// Get handlers for a specific operation
189    pub fn get_handlers(
190        &self,
191        operation_name: Option<&str>,
192        operation_type: &OperationType,
193    ) -> Vec<Arc<dyn GraphQLHandler>> {
194        self.handlers
195            .iter()
196            .filter(|h| h.handles_operation(operation_name, operation_type))
197            .cloned()
198            .collect()
199    }
200
201    /// Execute handlers for an operation
202    pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
203        let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
204
205        for handler in handlers {
206            if let Some(response) = handler.on_operation(ctx).await? {
207                return Ok(Some(response));
208            }
209        }
210
211        Ok(None)
212    }
213
214    /// Execute after_operation hooks
215    pub async fn after_operation(
216        &self,
217        ctx: &GraphQLContext,
218        mut response: Response,
219    ) -> HandlerResult<Response> {
220        let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
221
222        for handler in handlers {
223            response = handler.after_operation(ctx, response).await?;
224        }
225
226        Ok(response)
227    }
228
229    /// Passthrough request to upstream server
230    pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
231        let upstream = self
232            .upstream_url
233            .as_ref()
234            .ok_or_else(|| HandlerError::UpstreamError("No upstream URL configured".to_string()))?;
235
236        let client = reqwest::Client::new();
237        let body = json!({
238            "query": request.query.clone(),
239            "variables": request.variables.clone(),
240            "operationName": request.operation_name.clone(),
241        });
242
243        let resp = client
244            .post(upstream)
245            .json(&body)
246            .send()
247            .await
248            .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
249
250        let response_data: serde_json::Value =
251            resp.json().await.map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
252
253        // Convert JSON response to GraphQL Response
254        let errors: Vec<ServerError> = response_data
255            .get("errors")
256            .and_then(|e| e.as_array())
257            .map(|arr| {
258                arr.iter()
259                    .map(|e| {
260                        let msg = e
261                            .get("message")
262                            .and_then(|m| m.as_str())
263                            .unwrap_or("Upstream GraphQL error");
264                        ServerError::new(msg.to_string(), None)
265                    })
266                    .collect()
267            })
268            .unwrap_or_default();
269
270        let data = response_data.get("data").map(json_to_graphql_value).unwrap_or(Value::Null);
271
272        let mut response = Response::new(data);
273        response.errors = errors;
274        Ok(response)
275    }
276
277    /// Get upstream URL
278    pub fn upstream_url(&self) -> Option<&str> {
279        self.upstream_url.as_deref()
280    }
281}
282
283impl Default for HandlerRegistry {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289/// Convert a `serde_json::Value` to an `async_graphql::Value`
290fn json_to_graphql_value(json: &serde_json::Value) -> Value {
291    match json {
292        serde_json::Value::Null => Value::Null,
293        serde_json::Value::Bool(b) => Value::Boolean(*b),
294        serde_json::Value::Number(n) => {
295            if let Some(i) = n.as_i64() {
296                Value::Number(i.into())
297            } else if let Some(f) = n.as_f64() {
298                Value::Number(async_graphql::Number::from_f64(f).unwrap_or_else(|| 0i32.into()))
299            } else {
300                Value::Null
301            }
302        }
303        serde_json::Value::String(s) => Value::String(s.clone()),
304        serde_json::Value::Array(arr) => {
305            Value::List(arr.iter().map(json_to_graphql_value).collect())
306        }
307        serde_json::Value::Object(obj) => {
308            let map = obj.iter().map(|(k, v)| (Name::new(k), json_to_graphql_value(v))).collect();
309            Value::Object(map)
310        }
311    }
312}
313
/// Variable matcher for filtering operations by variable values.
///
/// Holds one [`VariablePattern`] per variable name; a set of variables
/// matches only if every pattern is satisfied.
#[derive(Debug, Clone)]
pub struct VariableMatcher {
    /// Patterns keyed by the variable name they apply to.
    patterns: HashMap<String, VariablePattern>,
}
319
320impl VariableMatcher {
321    /// Create a new variable matcher
322    pub fn new() -> Self {
323        Self {
324            patterns: HashMap::new(),
325        }
326    }
327
328    /// Add a pattern for a variable
329    pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
330        self.patterns.insert(name, pattern);
331        self
332    }
333
334    /// Check if variables match the patterns
335    pub fn matches(&self, variables: &Variables) -> bool {
336        for (name, pattern) in &self.patterns {
337            if !pattern.matches(variables.get(&Name::new(name))) {
338                return false;
339            }
340        }
341        true
342    }
343}
344
345impl Default for VariableMatcher {
346    fn default() -> Self {
347        Self::new()
348    }
349}
350
/// Pattern for matching variable values.
///
/// Evaluated by `VariablePattern::matches` against an optional variable
/// value (`None` meaning the variable is absent).
#[derive(Debug, Clone)]
pub enum VariablePattern {
    /// Exact value match; an absent variable never matches
    Exact(Value),
    /// Regular expression match (for strings); non-string and absent values
    /// never match, and a pattern that fails to compile matches nothing
    Regex(String),
    /// Any value (always matches, even when the variable is absent)
    Any,
    /// Value must be present (an explicit null counts as present)
    Present,
    /// Value must be null or absent
    Null,
}
365
366impl VariablePattern {
367    /// Check if a value matches this pattern
368    pub fn matches(&self, value: Option<&Value>) -> bool {
369        match (self, value) {
370            (VariablePattern::Any, _) => true,
371            (VariablePattern::Present, Some(_)) => true,
372            (VariablePattern::Present, None) => false,
373            (VariablePattern::Null, None) | (VariablePattern::Null, Some(Value::Null)) => true,
374            (VariablePattern::Null, Some(_)) => false,
375            (VariablePattern::Exact(expected), Some(actual)) => expected == actual,
376            (VariablePattern::Exact(_), None) => false,
377            (VariablePattern::Regex(pattern), Some(Value::String(s))) => {
378                regex::Regex::new(pattern).ok().map(|re| re.is_match(s)).unwrap_or(false)
379            }
380            (VariablePattern::Regex(_), _) => false,
381        }
382    }
383}
384
// Unit tests for the handler registry, context accessors, error type and
// variable matching. None of them perform network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that only claims (and answers) one specific operation name.
    struct TestHandler {
        operation_name: String,
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            if ctx.operation_name.as_deref() == Some(&self.operation_name) {
                // Return a simple null response for testing
                Ok(Some(Response::new(Value::Null)))
            } else {
                Ok(None)
            }
        }

        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(&self.operation_name)
        }
    }

    // --- HandlerRegistry construction and registration ---

    #[tokio::test]
    async fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert_eq!(registry.handlers.len(), 0);
        assert!(registry.upstream_url.is_none());
    }

    #[tokio::test]
    async fn test_handler_registry_with_upstream() {
        let registry =
            HandlerRegistry::with_upstream(Some("http://example.com/graphql".to_string()));
        assert_eq!(registry.upstream_url(), Some("http://example.com/graphql"));
    }

    #[tokio::test]
    async fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        let handler = TestHandler {
            operation_name: "getUser".to_string(),
        };
        registry.register(handler);
        assert_eq!(registry.handlers.len(), 1);
    }

    // A matching handler's `on_operation` override must be returned.
    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );

        let result = registry.execute_operation(&ctx).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    // --- VariableMatcher / VariablePattern basics ---

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let matcher = VariableMatcher::new().with_pattern(
            "id".to_string(),
            VariablePattern::Exact(Value::String("123".to_string())),
        );

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));

        // A different value for the same variable must be rejected.
        let mut vars2 = Variables::default();
        vars2.insert(Name::new("id"), Value::String("456".to_string()));

        assert!(!matcher.matches(&vars2));
    }

    #[test]
    fn test_variable_pattern_present() {
        assert!(VariablePattern::Present.matches(Some(&Value::String("test".to_string()))));
        assert!(!VariablePattern::Present.matches(None));
    }

    // `Null` accepts both an absent variable and an explicit null.
    #[test]
    fn test_variable_pattern_null() {
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&Value::String("test".to_string()))));
    }

    // --- GraphQLContext accessors ---

    #[test]
    fn test_graphql_context_new() {
        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );

        assert_eq!(ctx.operation_name, Some("getUser".to_string()));
        assert_eq!(ctx.operation_type, OperationType::Query);
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );

        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        assert_eq!(ctx.get_metadata("Authorization"), Some(&"Bearer token".to_string()));
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );

        ctx.set_data("custom_key".to_string(), json!({"test": "value"}));
        assert_eq!(ctx.get_data("custom_key"), Some(&json!({"test": "value"})));
    }

    // --- OperationType derives ---

    #[test]
    fn test_operation_type_eq() {
        assert_eq!(OperationType::Query, OperationType::Query);
        assert_ne!(OperationType::Query, OperationType::Mutation);
        assert_ne!(OperationType::Mutation, OperationType::Subscription);
    }

    #[test]
    fn test_operation_type_clone() {
        let op = OperationType::Query;
        let cloned = op.clone();
        assert_eq!(op, cloned);
    }

    // --- HandlerError display and conversion ---

    #[test]
    fn test_handler_error_display() {
        let err = HandlerError::SendError("test error".to_string());
        assert!(err.to_string().contains("Send error"));

        let err = HandlerError::OperationError("op error".to_string());
        assert!(err.to_string().contains("Operation error"));

        let err = HandlerError::UpstreamError("upstream error".to_string());
        assert!(err.to_string().contains("Upstream error"));

        let err = HandlerError::Generic("generic error".to_string());
        assert!(err.to_string().contains("generic error"));
    }

    // Exercises the `#[from] serde_json::Error` conversion.
    #[test]
    fn test_handler_error_from_json() {
        let json_err = serde_json::from_str::<i32>("not a number").unwrap_err();
        let err: HandlerError = json_err.into();
        assert!(matches!(err, HandlerError::JsonError(_)));
    }

    // --- More VariableMatcher / VariablePattern cases ---

    // A matcher without patterns accepts any variables.
    #[test]
    fn test_variable_matcher_default() {
        let matcher = VariableMatcher::default();
        assert!(matcher.matches(&Variables::default()));
    }

    #[test]
    fn test_variable_pattern_regex() {
        let pattern = VariablePattern::Regex(r"^user-\d+$".to_string());
        assert!(pattern.matches(Some(&Value::String("user-123".to_string()))));
        assert!(!pattern.matches(Some(&Value::String("invalid".to_string()))));
        assert!(!pattern.matches(None));
    }

    // `name` is absent here, but `Any` accepts absent variables, so the
    // matcher as a whole still matches.
    #[test]
    fn test_variable_matcher_multiple_patterns() {
        let matcher = VariableMatcher::new()
            .with_pattern("id".to_string(), VariablePattern::Present)
            .with_pattern("name".to_string(), VariablePattern::Any);

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));
    }

    #[test]
    fn test_variable_matcher_fails_on_missing() {
        let matcher =
            VariableMatcher::new().with_pattern("required".to_string(), VariablePattern::Present);

        let vars = Variables::default();
        assert!(!matcher.matches(&vars));
    }

    #[test]
    fn test_graphql_context_get_variable() {
        let mut vars = Variables::default();
        vars.insert(Name::new("userId"), Value::String("123".to_string()));

        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            vars,
        );

        assert!(ctx.get_variable("userId").is_some());
        assert!(ctx.get_variable("nonexistent").is_none());
    }

    // --- HandlerRegistry dispatch ---

    #[test]
    fn test_handler_registry_default() {
        let registry = HandlerRegistry::default();
        assert!(registry.upstream_url().is_none());
    }

    // No registered handler claims "getProduct", so nothing overrides it.
    #[tokio::test]
    async fn test_handler_registry_no_match() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let ctx = GraphQLContext::new(
            Some("getProduct".to_string()),
            OperationType::Query,
            "query { product { id } }".to_string(),
            Variables::default(),
        );

        let result = registry.execute_operation(&ctx).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    // `TestHandler` does not override `after_operation`, so this exercises
    // the trait's default pass-through implementation.
    #[tokio::test]
    async fn test_handler_registry_after_operation() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let ctx = GraphQLContext::new(
            Some("getUser".to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        );

        let response = Response::new(Value::Null);
        let result = registry.after_operation(&ctx, response).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_handler_registry_get_handlers() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });
        registry.register(TestHandler {
            operation_name: "getProduct".to_string(),
        });

        let handlers = registry.get_handlers(Some("getUser"), &OperationType::Query);
        assert_eq!(handlers.len(), 1);

        let handlers = registry.get_handlers(Some("unknown"), &OperationType::Query);
        assert_eq!(handlers.len(), 0);
    }

    // Only checks that a `priority` override is returned; ordering inside the
    // registry is not asserted here.
    #[test]
    fn test_handler_priority() {
        struct PriorityHandler {
            priority: i32,
        }

        #[async_trait]
        impl GraphQLHandler for PriorityHandler {
            fn priority(&self) -> i32 {
                self.priority
            }
        }

        let handler = PriorityHandler { priority: 10 };
        assert_eq!(handler.priority(), 10);
    }

    #[test]
    fn test_context_all_operation_types() {
        let query_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Query,
            "query".to_string(),
            Variables::default(),
        );
        assert_eq!(query_ctx.operation_type, OperationType::Query);

        let mutation_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Mutation,
            "mutation".to_string(),
            Variables::default(),
        );
        assert_eq!(mutation_ctx.operation_type, OperationType::Mutation);

        let subscription_ctx = GraphQLContext::new(
            Some("op".to_string()),
            OperationType::Subscription,
            "subscription".to_string(),
            Variables::default(),
        );
        assert_eq!(subscription_ctx.operation_type, OperationType::Subscription);
    }

    // --- Debug derives ---

    #[test]
    fn test_variable_pattern_debug() {
        let pattern = VariablePattern::Any;
        let debug = format!("{:?}", pattern);
        assert!(debug.contains("Any"));
    }

    #[test]
    fn test_variable_matcher_debug() {
        let matcher = VariableMatcher::new();
        let debug = format!("{:?}", matcher);
        assert!(debug.contains("VariableMatcher"));
    }
}