1use std::collections::HashMap;
6use std::sync::Arc;
7
8use super::{GraphQLSchema, ExecutionContext};
9
10#[derive(Debug, Clone)]
12pub struct ResolverContext {
13 pub execution: ExecutionContext,
15 pub parent: Option<serde_json::Value>,
17 pub arguments: HashMap<String, serde_json::Value>,
19 pub path: Vec<String>,
21 pub schema: Arc<GraphQLSchema>,
23}
24
25impl ResolverContext {
26 pub fn new(schema: Arc<GraphQLSchema>, execution: ExecutionContext) -> Self {
28 Self {
29 execution,
30 parent: None,
31 schema,
32 arguments: HashMap::new(),
33 path: Vec::new(),
34 }
35 }
36
37 pub fn with_parent(mut self, parent: serde_json::Value) -> Self {
39 self.parent = Some(parent);
40 self
41 }
42
43 pub fn with_arguments(mut self, arguments: HashMap<String, serde_json::Value>) -> Self {
45 self.arguments = arguments;
46 self
47 }
48
49 pub fn push_path(&mut self, segment: impl Into<String>) {
51 self.path.push(segment.into());
52 }
53
54 pub fn arg<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
56 self.arguments
57 .get(name)
58 .and_then(|v| serde_json::from_value(v.clone()).ok())
59 }
60
61 pub fn required_arg<T: serde::de::DeserializeOwned>(&self, name: &str) -> Result<T, ResolverError> {
63 self.arg(name)
64 .ok_or_else(|| ResolverError::MissingArgument(name.to_string()))
65 }
66
67 pub fn parent_field(&self, name: &str) -> Option<&serde_json::Value> {
69 self.parent.as_ref()?.get(name)
70 }
71
72 pub fn has_role(&self, role: &str) -> bool {
74 self.execution.has_role(role)
75 }
76
77 pub fn user_id(&self) -> Option<&str> {
79 self.execution.user_id.as_deref()
80 }
81}
82
83#[derive(Debug, Clone)]
85pub enum ResolverResult {
86 Value(serde_json::Value),
88 Null,
90 Error(ResolverError),
92 Deferred(String),
94}
95
96impl ResolverResult {
97 pub fn value(val: impl Into<serde_json::Value>) -> Self {
99 ResolverResult::Value(val.into())
100 }
101
102 pub fn null() -> Self {
104 ResolverResult::Null
105 }
106
107 pub fn error(err: impl Into<ResolverError>) -> Self {
109 ResolverResult::Error(err.into())
110 }
111
112 pub fn is_error(&self) -> bool {
114 matches!(self, ResolverResult::Error(_))
115 }
116
117 pub fn to_json(self) -> serde_json::Value {
119 match self {
120 ResolverResult::Value(v) => v,
121 ResolverResult::Null | ResolverResult::Deferred(_) => serde_json::Value::Null,
122 ResolverResult::Error(_) => serde_json::Value::Null,
123 }
124 }
125}
126
127impl From<serde_json::Value> for ResolverResult {
128 fn from(val: serde_json::Value) -> Self {
129 ResolverResult::Value(val)
130 }
131}
132
133impl<T: Into<serde_json::Value>> From<Option<T>> for ResolverResult {
134 fn from(opt: Option<T>) -> Self {
135 match opt {
136 Some(v) => ResolverResult::Value(v.into()),
137 None => ResolverResult::Null,
138 }
139 }
140}
141
/// Errors a field resolver can report.
#[derive(Clone, Debug)]
pub enum ResolverError {
    /// A required argument was not supplied; holds the argument name.
    MissingArgument(String),
    /// An argument was supplied with an unusable value; holds `(name, message)`.
    InvalidArgument(String, String),
    /// A requested field does not exist; holds the field name.
    FieldNotFound(String),
    /// The caller is not allowed to perform the operation; holds a message.
    Unauthorized(String),
    /// A database-layer failure; holds a message.
    DatabaseError(String),
    /// Any other failure; produced by the `From<String>`/`From<&str>` impls.
    Internal(String),
}

impl std::fmt::Display for ResolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Internal(msg) => write!(f, "Internal error: {}", msg),
            Self::DatabaseError(msg) => write!(f, "Database error: {}", msg),
            Self::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
            Self::FieldNotFound(name) => write!(f, "Field not found: {}", name),
            Self::InvalidArgument(name, msg) => write!(f, "Invalid argument '{}': {}", name, msg),
            Self::MissingArgument(name) => write!(f, "Missing required argument: {}", name),
        }
    }
}

impl std::error::Error for ResolverError {}

impl From<String> for ResolverError {
    /// Free-form messages are classified as [`ResolverError::Internal`].
    fn from(s: String) -> Self {
        Self::Internal(s)
    }
}

impl From<&str> for ResolverError {
    /// Same as the `String` conversion: yields [`ResolverError::Internal`].
    fn from(s: &str) -> Self {
        Self::from(s.to_owned())
    }
}
187
188pub trait FieldResolver: Send + Sync {
190 fn resolve(&self, ctx: &ResolverContext) -> ResolverResult;
192
193 fn field_name(&self) -> &str;
195
196 fn type_name(&self) -> &str;
198}
199
200#[derive(Debug)]
202pub struct DefaultResolver {
203 type_name: String,
205 field_name: String,
207 column_name: String,
209}
210
211impl DefaultResolver {
212 pub fn new(
214 type_name: impl Into<String>,
215 field_name: impl Into<String>,
216 column_name: impl Into<String>,
217 ) -> Self {
218 Self {
219 type_name: type_name.into(),
220 field_name: field_name.into(),
221 column_name: column_name.into(),
222 }
223 }
224}
225
226impl FieldResolver for DefaultResolver {
227 fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
228 match &ctx.parent {
229 Some(parent) => {
230 parent.get(&self.column_name).cloned().into()
231 }
232 None => ResolverResult::Null,
233 }
234 }
235
236 fn field_name(&self) -> &str {
237 &self.field_name
238 }
239
240 fn type_name(&self) -> &str {
241 &self.type_name
242 }
243}
244
/// Resolver whose value is computed by a user-supplied closure.
///
/// The closure bound (`Fn(&ResolverContext) -> ResolverResult + Send + Sync`)
/// is declared on the `impl` blocks that need it rather than on the struct
/// itself. Bounds on a struct definition must otherwise be repeated by every
/// impl and every type that names `ComputedResolver<F>`, without adding any
/// guarantee.
pub struct ComputedResolver<F> {
    /// Type that owns the field.
    type_name: String,
    /// Name of the computed field.
    field_name: String,
    /// Closure invoked on every resolution.
    resolver_fn: F,
}
254
255impl<F> ComputedResolver<F>
256where
257 F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
258{
259 pub fn new(
261 type_name: impl Into<String>,
262 field_name: impl Into<String>,
263 resolver_fn: F,
264 ) -> Self {
265 Self {
266 type_name: type_name.into(),
267 field_name: field_name.into(),
268 resolver_fn,
269 }
270 }
271}
272
273impl<F> FieldResolver for ComputedResolver<F>
274where
275 F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
276{
277 fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
278 (self.resolver_fn)(ctx)
279 }
280
281 fn field_name(&self) -> &str {
282 &self.field_name
283 }
284
285 fn type_name(&self) -> &str {
286 &self.type_name
287 }
288}
289
290impl<F> std::fmt::Debug for ComputedResolver<F>
291where
292 F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
293{
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 f.debug_struct("ComputedResolver")
296 .field("type_name", &self.type_name)
297 .field("field_name", &self.field_name)
298 .finish()
299 }
300}
301
302#[derive(Default)]
304pub struct ResolverRegistry {
305 resolvers: HashMap<String, Arc<dyn FieldResolver>>,
307}
308
309impl std::fmt::Debug for ResolverRegistry {
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 f.debug_struct("ResolverRegistry")
312 .field("resolvers_count", &self.resolvers.len())
313 .finish()
314 }
315}
316
317impl ResolverRegistry {
318 pub fn new() -> Self {
320 Self::default()
321 }
322
323 pub fn register(&mut self, resolver: impl FieldResolver + 'static) {
325 let key = format!("{}.{}", resolver.type_name(), resolver.field_name());
326 self.resolvers.insert(key, Arc::new(resolver));
327 }
328
329 pub fn get(&self, type_name: &str, field_name: &str) -> Option<Arc<dyn FieldResolver>> {
331 let key = format!("{}.{}", type_name, field_name);
332 self.resolvers.get(&key).cloned()
333 }
334
335 pub fn has(&self, type_name: &str, field_name: &str) -> bool {
337 let key = format!("{}.{}", type_name, field_name);
338 self.resolvers.contains_key(&key)
339 }
340
341 pub fn resolvers_for(&self, type_name: &str) -> Vec<Arc<dyn FieldResolver>> {
343 let prefix = format!("{}.", type_name);
344 self.resolvers
345 .iter()
346 .filter(|(k, _)| k.starts_with(&prefix))
347 .map(|(_, v)| v.clone())
348 .collect()
349 }
350}
351
352pub struct ResolverChain {
354 resolvers: Vec<Arc<dyn FieldResolver>>,
356}
357
358impl std::fmt::Debug for ResolverChain {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.debug_struct("ResolverChain")
361 .field("resolvers_count", &self.resolvers.len())
362 .finish()
363 }
364}
365
366impl ResolverChain {
367 pub fn new() -> Self {
369 Self {
370 resolvers: Vec::new(),
371 }
372 }
373
374 pub fn add(mut self, resolver: impl FieldResolver + 'static) -> Self {
376 self.resolvers.push(Arc::new(resolver));
377 self
378 }
379
380 pub fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
382 for resolver in &self.resolvers {
383 let result = resolver.resolve(ctx);
384 if !matches!(result, ResolverResult::Null) {
385 return result;
386 }
387 }
388 ResolverResult::Null
389 }
390}
391
392impl Default for ResolverChain {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::introspector::GraphQLSchema;

    /// Builds a bare context from a fresh schema and a default execution context.
    fn test_ctx() -> ResolverContext {
        ResolverContext::new(Arc::new(GraphQLSchema::new()), ExecutionContext::default())
    }

    /// Unwraps a `ResolverResult::Value`, panicking with `msg` for any other variant.
    fn expect_value(result: ResolverResult, msg: &str) -> serde_json::Value {
        match result {
            ResolverResult::Value(v) => v,
            _ => panic!("{}", msg),
        }
    }

    /// Reads a string key from the context's parent, defaulting to "".
    fn parent_str<'a>(ctx: &'a ResolverContext, key: &str) -> &'a str {
        ctx.parent_field(key).and_then(|v| v.as_str()).unwrap_or("")
    }

    #[test]
    fn test_resolver_context_args() {
        let args: HashMap<String, serde_json::Value> = vec![
            ("limit".to_string(), serde_json::json!(10)),
            ("name".to_string(), serde_json::json!("test")),
        ]
        .into_iter()
        .collect();
        let ctx = test_ctx().with_arguments(args);

        assert_eq!(ctx.arg::<i32>("limit"), Some(10));
        assert_eq!(ctx.arg::<String>("name"), Some("test".to_string()));
        assert_eq!(ctx.arg::<i32>("missing"), None);
    }

    #[test]
    fn test_resolver_context_required_arg() {
        let args: HashMap<String, serde_json::Value> =
            vec![("id".to_string(), serde_json::json!("123"))]
                .into_iter()
                .collect();
        let ctx = test_ctx().with_arguments(args);

        assert!(ctx.required_arg::<String>("id").is_ok());
        assert!(ctx.required_arg::<String>("missing").is_err());
    }

    #[test]
    fn test_resolver_context_parent() {
        let ctx = test_ctx().with_parent(serde_json::json!({
            "id": "123",
            "name": "Test"
        }));

        assert_eq!(ctx.parent_field("id"), Some(&serde_json::json!("123")));
        assert_eq!(ctx.parent_field("missing"), None);
    }

    #[test]
    fn test_default_resolver() {
        let resolver = DefaultResolver::new("User", "name", "name");
        let ctx = test_ctx().with_parent(serde_json::json!({
            "id": "123",
            "name": "John"
        }));

        let got = expect_value(resolver.resolve(&ctx), "Expected value");
        assert_eq!(got, serde_json::json!("John"));
    }

    #[test]
    fn test_computed_resolver() {
        let resolver = ComputedResolver::new("User", "fullName", |c: &ResolverContext| {
            ResolverResult::value(format!(
                "{} {}",
                parent_str(c, "firstName"),
                parent_str(c, "lastName")
            ))
        });
        let ctx = test_ctx().with_parent(serde_json::json!({
            "firstName": "John",
            "lastName": "Doe"
        }));

        let got = expect_value(resolver.resolve(&ctx), "Expected value");
        assert_eq!(got, serde_json::json!("John Doe"));
    }

    #[test]
    fn test_resolver_registry() {
        let mut registry = ResolverRegistry::new();
        for (ty, field) in [("User", "id"), ("User", "name"), ("Post", "title")].iter() {
            registry.register(DefaultResolver::new(*ty, *field, *field));
        }

        assert!(registry.has("User", "id"));
        assert!(registry.has("User", "name"));
        assert!(registry.has("Post", "title"));
        assert!(!registry.has("User", "email"));

        assert_eq!(registry.resolvers_for("User").len(), 2);
    }

    #[test]
    fn test_resolver_result_conversions() {
        let from_json: ResolverResult = serde_json::json!("test").into();
        assert!(!from_json.is_error());

        let from_some: ResolverResult = Some("test").into();
        assert!(matches!(from_some, ResolverResult::Value(_)));

        let from_none: ResolverResult = Option::<String>::None.into();
        assert!(matches!(from_none, ResolverResult::Null));
    }

    #[test]
    fn test_resolver_chain() {
        let chain = ResolverChain::new()
            .add(DefaultResolver::new("User", "displayName", "display_name"))
            .add(DefaultResolver::new("User", "displayName", "name"))
            .add(DefaultResolver::new("User", "displayName", "email"));
        // The parent has no "display_name", so the second resolver should win.
        let ctx = test_ctx().with_parent(serde_json::json!({
            "name": "John"
        }));

        let got = expect_value(chain.resolve(&ctx), "Expected value from second resolver");
        assert_eq!(got, serde_json::json!("John"));
    }

    #[test]
    fn test_resolver_error_display() {
        let missing = ResolverError::MissingArgument("id".to_string());
        assert!(missing.to_string().contains("id"));

        let unauthorized = ResolverError::Unauthorized("Not authenticated".to_string());
        assert!(unauthorized.to_string().contains("Not authenticated"));
    }

    #[test]
    fn test_resolver_context_roles() {
        let ctx = ResolverContext::new(
            Arc::new(GraphQLSchema::new()),
            ExecutionContext::default()
                .with_user("user1")
                .with_role("admin"),
        );

        assert!(ctx.has_role("admin"));
        assert!(!ctx.has_role("superuser"));
        assert_eq!(ctx.user_id(), Some("user1"));
    }

    #[test]
    fn test_resolver_result_to_json() {
        let cases = vec![
            (ResolverResult::value("test"), serde_json::json!("test")),
            (ResolverResult::null(), serde_json::Value::Null),
            (ResolverResult::error("err"), serde_json::Value::Null),
        ];
        for (result, expected) in cases {
            assert_eq!(result.to_json(), expected);
        }
    }
}