1use async_graphql::{Name, Request, Response, ServerError, Value, Variables};
7use async_trait::async_trait;
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::Arc;
11use thiserror::Error;
12
13pub type HandlerResult<T> = Result<T, HandlerError>;
15
16#[derive(Debug, Error)]
18pub enum HandlerError {
19 #[error("Send error: {0}")]
21 SendError(String),
22
23 #[error("JSON error: {0}")]
25 JsonError(#[from] serde_json::Error),
26
27 #[error("Operation error: {0}")]
29 OperationError(String),
30
31 #[error("Upstream error: {0}")]
33 UpstreamError(String),
34
35 #[error("{0}")]
37 Generic(String),
38}
39
40pub struct GraphQLContext {
42 pub operation_name: Option<String>,
44
45 pub operation_type: OperationType,
47
48 pub query: String,
50
51 pub variables: Variables,
53
54 pub metadata: HashMap<String, String>,
56
57 pub data: HashMap<String, serde_json::Value>,
59}
60
/// Kind of GraphQL operation carried by a request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OperationType {
    /// A read-only `query` operation.
    Query,
    /// A `mutation` operation.
    Mutation,
    /// A `subscription` operation.
    Subscription,
}
71
72impl GraphQLContext {
73 pub fn new(
75 operation_name: Option<String>,
76 operation_type: OperationType,
77 query: String,
78 variables: Variables,
79 ) -> Self {
80 Self {
81 operation_name,
82 operation_type,
83 query,
84 variables,
85 metadata: HashMap::new(),
86 data: HashMap::new(),
87 }
88 }
89
90 pub fn get_variable(&self, name: &str) -> Option<&Value> {
92 self.variables.get(&Name::new(name))
93 }
94
95 pub fn set_data(&mut self, key: String, value: serde_json::Value) {
97 self.data.insert(key, value);
98 }
99
100 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
102 self.data.get(key)
103 }
104
105 pub fn set_metadata(&mut self, key: String, value: String) {
107 self.metadata.insert(key, value);
108 }
109
110 pub fn get_metadata(&self, key: &str) -> Option<&String> {
112 self.metadata.get(key)
113 }
114}
115
116#[async_trait]
118pub trait GraphQLHandler: Send + Sync {
119 async fn on_operation(&self, _ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
122 Ok(None)
123 }
124
125 async fn after_operation(
128 &self,
129 _ctx: &GraphQLContext,
130 response: Response,
131 ) -> HandlerResult<Response> {
132 Ok(response)
133 }
134
135 async fn on_error(&self, _ctx: &GraphQLContext, error: String) -> HandlerResult<Response> {
137 let server_error = ServerError::new(error, None);
138 Ok(Response::from_errors(vec![server_error]))
139 }
140
141 fn handles_operation(
143 &self,
144 operation_name: Option<&str>,
145 _operation_type: &OperationType,
146 ) -> bool {
147 operation_name.is_some()
149 }
150
151 fn priority(&self) -> i32 {
153 0
154 }
155}
156
157pub struct HandlerRegistry {
159 handlers: Vec<Arc<dyn GraphQLHandler>>,
160 upstream_url: Option<String>,
162}
163
164impl HandlerRegistry {
165 pub fn new() -> Self {
167 Self {
168 handlers: Vec::new(),
169 upstream_url: None,
170 }
171 }
172
173 pub fn with_upstream(upstream_url: Option<String>) -> Self {
175 Self {
176 handlers: Vec::new(),
177 upstream_url,
178 }
179 }
180
181 pub fn register<H: GraphQLHandler + 'static>(&mut self, handler: H) {
183 self.handlers.push(Arc::new(handler));
184 self.handlers.sort_by(|a, b| b.priority().cmp(&a.priority()));
186 }
187
188 pub fn get_handlers(
190 &self,
191 operation_name: Option<&str>,
192 operation_type: &OperationType,
193 ) -> Vec<Arc<dyn GraphQLHandler>> {
194 self.handlers
195 .iter()
196 .filter(|h| h.handles_operation(operation_name, operation_type))
197 .cloned()
198 .collect()
199 }
200
201 pub async fn execute_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
203 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
204
205 for handler in handlers {
206 if let Some(response) = handler.on_operation(ctx).await? {
207 return Ok(Some(response));
208 }
209 }
210
211 Ok(None)
212 }
213
214 pub async fn after_operation(
216 &self,
217 ctx: &GraphQLContext,
218 mut response: Response,
219 ) -> HandlerResult<Response> {
220 let handlers = self.get_handlers(ctx.operation_name.as_deref(), &ctx.operation_type);
221
222 for handler in handlers {
223 response = handler.after_operation(ctx, response).await?;
224 }
225
226 Ok(response)
227 }
228
229 pub async fn passthrough(&self, request: &Request) -> HandlerResult<Response> {
231 let upstream = self
232 .upstream_url
233 .as_ref()
234 .ok_or_else(|| HandlerError::UpstreamError("No upstream URL configured".to_string()))?;
235
236 let client = reqwest::Client::new();
237 let body = json!({
238 "query": request.query.clone(),
239 "variables": request.variables.clone(),
240 "operationName": request.operation_name.clone(),
241 });
242
243 let resp = client
244 .post(upstream)
245 .json(&body)
246 .send()
247 .await
248 .map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
249
250 let response_data: serde_json::Value =
251 resp.json().await.map_err(|e| HandlerError::UpstreamError(e.to_string()))?;
252
253 let has_errors = response_data.get("errors").is_some();
256
257 if has_errors {
260 let error_msg = response_data
261 .get("errors")
262 .and_then(|e| e.as_array())
263 .and_then(|arr| arr.first())
264 .and_then(|e| e.get("message"))
265 .and_then(|m| m.as_str())
266 .unwrap_or("Upstream GraphQL error");
267
268 let server_error = async_graphql::ServerError::new(error_msg.to_string(), None);
269 Ok(Response::from_errors(vec![server_error]))
270 } else {
271 Ok(Response::new(Value::Null))
274 }
275 }
276
277 pub fn upstream_url(&self) -> Option<&str> {
279 self.upstream_url.as_deref()
280 }
281}
282
283impl Default for HandlerRegistry {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289#[derive(Debug, Clone)]
291pub struct VariableMatcher {
292 patterns: HashMap<String, VariablePattern>,
293}
294
295impl VariableMatcher {
296 pub fn new() -> Self {
298 Self {
299 patterns: HashMap::new(),
300 }
301 }
302
303 pub fn with_pattern(mut self, name: String, pattern: VariablePattern) -> Self {
305 self.patterns.insert(name, pattern);
306 self
307 }
308
309 pub fn matches(&self, variables: &Variables) -> bool {
311 for (name, pattern) in &self.patterns {
312 if !pattern.matches(variables.get(&Name::new(name))) {
313 return false;
314 }
315 }
316 true
317 }
318}
319
320impl Default for VariableMatcher {
321 fn default() -> Self {
322 Self::new()
323 }
324}
325
326#[derive(Debug, Clone)]
328pub enum VariablePattern {
329 Exact(Value),
331 Regex(String),
333 Any,
335 Present,
337 Null,
339}
340
341impl VariablePattern {
342 pub fn matches(&self, value: Option<&Value>) -> bool {
344 match (self, value) {
345 (VariablePattern::Any, _) => true,
346 (VariablePattern::Present, Some(_)) => true,
347 (VariablePattern::Present, None) => false,
348 (VariablePattern::Null, None) | (VariablePattern::Null, Some(Value::Null)) => true,
349 (VariablePattern::Null, Some(_)) => false,
350 (VariablePattern::Exact(expected), Some(actual)) => expected == actual,
351 (VariablePattern::Exact(_), None) => false,
352 (VariablePattern::Regex(pattern), Some(Value::String(s))) => regex::Regex::new(pattern)
353 .ok()
354 .and_then(|re| Some(re.is_match(s)))
355 .unwrap_or(false),
356 (VariablePattern::Regex(_), _) => false,
357 }
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that answers only the operation whose name it was built with.
    struct TestHandler {
        operation_name: String,
    }

    #[async_trait]
    impl GraphQLHandler for TestHandler {
        async fn on_operation(&self, ctx: &GraphQLContext) -> HandlerResult<Option<Response>> {
            if ctx.operation_name.as_deref() == Some(self.operation_name.as_str()) {
                Ok(Some(Response::new(Value::Null)))
            } else {
                Ok(None)
            }
        }

        fn handles_operation(&self, operation_name: Option<&str>, _: &OperationType) -> bool {
            operation_name == Some(self.operation_name.as_str())
        }
    }

    /// Handler that uses every trait default except `priority`.
    struct PriorityHandler {
        priority: i32,
    }

    #[async_trait]
    impl GraphQLHandler for PriorityHandler {
        fn priority(&self) -> i32 {
            self.priority
        }
    }

    fn query_context(operation_name: &str) -> GraphQLContext {
        GraphQLContext::new(
            Some(operation_name.to_string()),
            OperationType::Query,
            "query { user { id } }".to_string(),
            Variables::default(),
        )
    }

    #[test]
    fn test_handler_registry_new() {
        let registry = HandlerRegistry::new();
        assert_eq!(registry.handlers.len(), 0);
        assert!(registry.upstream_url.is_none());
    }

    #[test]
    fn test_handler_registry_with_upstream() {
        let registry =
            HandlerRegistry::with_upstream(Some("http://example.com/graphql".to_string()));
        assert_eq!(registry.upstream_url(), Some("http://example.com/graphql"));
    }

    #[test]
    fn test_handler_registry_register() {
        let mut registry = HandlerRegistry::new();
        let handler = TestHandler {
            operation_name: "getUser".to_string(),
        };
        registry.register(handler);
        assert_eq!(registry.handlers.len(), 1);
    }

    #[test]
    fn test_register_orders_by_descending_priority() {
        let mut registry = HandlerRegistry::new();
        for priority in [1, 10, 5] {
            registry.register(PriorityHandler { priority });
        }

        let priorities: Vec<i32> = registry
            .get_handlers(Some("anyOperation"), &OperationType::Query)
            .iter()
            .map(|h| h.priority())
            .collect();
        assert_eq!(priorities, vec![10, 5, 1]);
    }

    #[test]
    fn test_default_handles_operation_skips_anonymous() {
        let mut registry = HandlerRegistry::new();
        registry.register(PriorityHandler { priority: 0 });

        assert!(registry.get_handlers(None, &OperationType::Query).is_empty());
        assert_eq!(
            registry
                .get_handlers(Some("getUser"), &OperationType::Mutation)
                .len(),
            1
        );
    }

    #[tokio::test]
    async fn test_handler_execution() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let ctx = query_context("getUser");

        let result = registry.execute_operation(&ctx).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_execute_operation_without_matching_handler() {
        let mut registry = HandlerRegistry::new();
        registry.register(TestHandler {
            operation_name: "getUser".to_string(),
        });

        let ctx = query_context("listUsers");

        let result = registry
            .execute_operation(&ctx)
            .await
            .expect("declining handlers must not produce an error");
        assert!(result.is_none());
    }

    #[test]
    fn test_variable_matcher_any() {
        let matcher = VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Any);

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));
    }

    #[test]
    fn test_variable_matcher_exact() {
        let matcher = VariableMatcher::new().with_pattern(
            "id".to_string(),
            VariablePattern::Exact(Value::String("123".to_string())),
        );

        let mut vars = Variables::default();
        vars.insert(Name::new("id"), Value::String("123".to_string()));

        assert!(matcher.matches(&vars));

        let mut vars2 = Variables::default();
        vars2.insert(Name::new("id"), Value::String("456".to_string()));

        assert!(!matcher.matches(&vars2));
    }

    #[test]
    fn test_variable_matcher_missing_variable() {
        let matcher =
            VariableMatcher::new().with_pattern("id".to_string(), VariablePattern::Present);
        assert!(!matcher.matches(&Variables::default()));

        // A matcher without constraints accepts everything.
        assert!(VariableMatcher::new().matches(&Variables::default()));
    }

    #[test]
    fn test_variable_pattern_present() {
        assert!(VariablePattern::Present.matches(Some(&Value::String("test".to_string()))));
        assert!(!VariablePattern::Present.matches(None));
    }

    #[test]
    fn test_variable_pattern_null() {
        assert!(VariablePattern::Null.matches(None));
        assert!(VariablePattern::Null.matches(Some(&Value::Null)));
        assert!(!VariablePattern::Null.matches(Some(&Value::String("test".to_string()))));
    }

    #[test]
    fn test_variable_pattern_exact_missing() {
        let pattern = VariablePattern::Exact(Value::Null);
        assert!(!pattern.matches(None));
        assert!(pattern.matches(Some(&Value::Null)));
    }

    #[test]
    fn test_variable_pattern_regex() {
        let digits = VariablePattern::Regex(r"^\d+$".to_string());
        assert!(digits.matches(Some(&Value::String("123".to_string()))));
        assert!(!digits.matches(Some(&Value::String("12a".to_string()))));

        // Non-string and missing values never match a regex pattern.
        assert!(!digits.matches(Some(&Value::Null)));
        assert!(!digits.matches(None));

        // An expression that fails to compile is a non-match, not a panic.
        let invalid = VariablePattern::Regex("(".to_string());
        assert!(!invalid.matches(Some(&Value::String("(".to_string()))));
    }

    #[test]
    fn test_graphql_context_new() {
        let ctx = query_context("getUser");

        assert_eq!(ctx.operation_name, Some("getUser".to_string()));
        assert_eq!(ctx.operation_type, OperationType::Query);
    }

    #[test]
    fn test_graphql_context_metadata() {
        let mut ctx = query_context("getUser");

        ctx.set_metadata("Authorization".to_string(), "Bearer token".to_string());
        assert_eq!(ctx.get_metadata("Authorization"), Some(&"Bearer token".to_string()));
    }

    #[test]
    fn test_graphql_context_data() {
        let mut ctx = query_context("getUser");

        ctx.set_data("custom_key".to_string(), json!({"test": "value"}));
        assert_eq!(ctx.get_data("custom_key"), Some(&json!({"test": "value"})));
    }
}