1use std::collections::HashMap;
6use std::sync::Arc;
7
8use super::{ExecutionContext, GraphQLSchema};
9
10#[derive(Debug, Clone)]
12pub struct ResolverContext {
13 pub execution: ExecutionContext,
15 pub parent: Option<serde_json::Value>,
17 pub arguments: HashMap<String, serde_json::Value>,
19 pub path: Vec<String>,
21 pub schema: Arc<GraphQLSchema>,
23}
24
25impl ResolverContext {
26 pub fn new(schema: Arc<GraphQLSchema>, execution: ExecutionContext) -> Self {
28 Self {
29 execution,
30 parent: None,
31 schema,
32 arguments: HashMap::new(),
33 path: Vec::new(),
34 }
35 }
36
37 pub fn with_parent(mut self, parent: serde_json::Value) -> Self {
39 self.parent = Some(parent);
40 self
41 }
42
43 pub fn with_arguments(mut self, arguments: HashMap<String, serde_json::Value>) -> Self {
45 self.arguments = arguments;
46 self
47 }
48
49 pub fn push_path(&mut self, segment: impl Into<String>) {
51 self.path.push(segment.into());
52 }
53
54 pub fn arg<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
56 self.arguments
57 .get(name)
58 .and_then(|v| serde_json::from_value(v.clone()).ok())
59 }
60
61 pub fn required_arg<T: serde::de::DeserializeOwned>(
63 &self,
64 name: &str,
65 ) -> Result<T, ResolverError> {
66 self.arg(name)
67 .ok_or_else(|| ResolverError::MissingArgument(name.to_string()))
68 }
69
70 pub fn parent_field(&self, name: &str) -> Option<&serde_json::Value> {
72 self.parent.as_ref()?.get(name)
73 }
74
75 pub fn has_role(&self, role: &str) -> bool {
77 self.execution.has_role(role)
78 }
79
80 pub fn user_id(&self) -> Option<&str> {
82 self.execution.user_id.as_deref()
83 }
84}
85
86#[derive(Debug, Clone)]
88pub enum ResolverResult {
89 Value(serde_json::Value),
91 Null,
93 Error(ResolverError),
95 Deferred(String),
97}
98
99impl ResolverResult {
100 pub fn value(val: impl Into<serde_json::Value>) -> Self {
102 ResolverResult::Value(val.into())
103 }
104
105 pub fn null() -> Self {
107 ResolverResult::Null
108 }
109
110 pub fn error(err: impl Into<ResolverError>) -> Self {
112 ResolverResult::Error(err.into())
113 }
114
115 pub fn is_error(&self) -> bool {
117 matches!(self, ResolverResult::Error(_))
118 }
119
120 pub fn to_json(self) -> serde_json::Value {
122 match self {
123 ResolverResult::Value(v) => v,
124 ResolverResult::Null | ResolverResult::Deferred(_) => serde_json::Value::Null,
125 ResolverResult::Error(_) => serde_json::Value::Null,
126 }
127 }
128}
129
130impl From<serde_json::Value> for ResolverResult {
131 fn from(val: serde_json::Value) -> Self {
132 ResolverResult::Value(val)
133 }
134}
135
136impl<T: Into<serde_json::Value>> From<Option<T>> for ResolverResult {
137 fn from(opt: Option<T>) -> Self {
138 match opt {
139 Some(v) => ResolverResult::Value(v.into()),
140 None => ResolverResult::Null,
141 }
142 }
143}
144
/// Errors a resolver can report.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// instead of matching on their rendered text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolverError {
    /// A required argument was not supplied; carries the argument name.
    MissingArgument(String),
    /// An argument was rejected; carries the argument name and a message.
    InvalidArgument(String, String),
    /// A field could not be found; carries the field name.
    FieldNotFound(String),
    /// Access was denied; carries a message.
    Unauthorized(String),
    /// A database failure; carries a message.
    DatabaseError(String),
    /// Any other failure; carries a message.
    Internal(String),
}
161
162impl std::fmt::Display for ResolverError {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 match self {
165 ResolverError::MissingArgument(name) => {
166 write!(f, "Missing required argument: {}", name)
167 }
168 ResolverError::InvalidArgument(name, msg) => {
169 write!(f, "Invalid argument '{}': {}", name, msg)
170 }
171 ResolverError::FieldNotFound(name) => write!(f, "Field not found: {}", name),
172 ResolverError::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
173 ResolverError::DatabaseError(msg) => write!(f, "Database error: {}", msg),
174 ResolverError::Internal(msg) => write!(f, "Internal error: {}", msg),
175 }
176 }
177}
178
179impl std::error::Error for ResolverError {}
180
181impl From<String> for ResolverError {
182 fn from(s: String) -> Self {
183 ResolverError::Internal(s)
184 }
185}
186
187impl From<&str> for ResolverError {
188 fn from(s: &str) -> Self {
189 ResolverError::Internal(s.to_string())
190 }
191}
192
193pub trait FieldResolver: Send + Sync {
195 fn resolve(&self, ctx: &ResolverContext) -> ResolverResult;
197
198 fn field_name(&self) -> &str;
200
201 fn type_name(&self) -> &str;
203}
204
/// Resolver that reads a single entry straight from the parent value.
#[derive(Debug)]
pub struct DefaultResolver {
    /// Name of the type that owns the field.
    type_name: String,
    /// Name of the field being resolved.
    field_name: String,
    /// Key looked up on the parent value to obtain the field's value.
    column_name: String,
}
215
216impl DefaultResolver {
217 pub fn new(
219 type_name: impl Into<String>,
220 field_name: impl Into<String>,
221 column_name: impl Into<String>,
222 ) -> Self {
223 Self {
224 type_name: type_name.into(),
225 field_name: field_name.into(),
226 column_name: column_name.into(),
227 }
228 }
229}
230
231impl FieldResolver for DefaultResolver {
232 fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
233 match &ctx.parent {
234 Some(parent) => parent.get(&self.column_name).cloned().into(),
235 None => ResolverResult::Null,
236 }
237 }
238
239 fn field_name(&self) -> &str {
240 &self.field_name
241 }
242
243 fn type_name(&self) -> &str {
244 &self.type_name
245 }
246}
247
/// Resolver whose value is computed by a user-supplied function.
///
/// The `Fn(&ResolverContext) -> ResolverResult + Send + Sync` requirement on `F` is
/// stated on the `impl` blocks that need it rather than on the type itself, so the
/// struct definition stays free of redundant bounds.
pub struct ComputedResolver<F> {
    /// Name of the type that owns the field.
    type_name: String,
    /// Name of the field being resolved.
    field_name: String,
    /// Function invoked to compute the field's outcome.
    resolver_fn: F,
}
257
258impl<F> ComputedResolver<F>
259where
260 F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
261{
262 pub fn new(
264 type_name: impl Into<String>,
265 field_name: impl Into<String>,
266 resolver_fn: F,
267 ) -> Self {
268 Self {
269 type_name: type_name.into(),
270 field_name: field_name.into(),
271 resolver_fn,
272 }
273 }
274}
275
276impl<F> FieldResolver for ComputedResolver<F>
277where
278 F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
279{
280 fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
281 (self.resolver_fn)(ctx)
282 }
283
284 fn field_name(&self) -> &str {
285 &self.field_name
286 }
287
288 fn type_name(&self) -> &str {
289 &self.type_name
290 }
291}
292
293impl<F> std::fmt::Debug for ComputedResolver<F>
294where
295 F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
296{
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("ComputedResolver")
299 .field("type_name", &self.type_name)
300 .field("field_name", &self.field_name)
301 .finish()
302 }
303}
304
305#[derive(Default)]
307pub struct ResolverRegistry {
308 resolvers: HashMap<String, Arc<dyn FieldResolver>>,
310}
311
312impl std::fmt::Debug for ResolverRegistry {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 f.debug_struct("ResolverRegistry")
315 .field("resolvers_count", &self.resolvers.len())
316 .finish()
317 }
318}
319
320impl ResolverRegistry {
321 pub fn new() -> Self {
323 Self::default()
324 }
325
326 pub fn register(&mut self, resolver: impl FieldResolver + 'static) {
328 let key = format!("{}.{}", resolver.type_name(), resolver.field_name());
329 self.resolvers.insert(key, Arc::new(resolver));
330 }
331
332 pub fn get(&self, type_name: &str, field_name: &str) -> Option<Arc<dyn FieldResolver>> {
334 let key = format!("{}.{}", type_name, field_name);
335 self.resolvers.get(&key).cloned()
336 }
337
338 pub fn has(&self, type_name: &str, field_name: &str) -> bool {
340 let key = format!("{}.{}", type_name, field_name);
341 self.resolvers.contains_key(&key)
342 }
343
344 pub fn resolvers_for(&self, type_name: &str) -> Vec<Arc<dyn FieldResolver>> {
346 let prefix = format!("{}.", type_name);
347 self.resolvers
348 .iter()
349 .filter(|(k, _)| k.starts_with(&prefix))
350 .map(|(_, v)| v.clone())
351 .collect()
352 }
353}
354
355pub struct ResolverChain {
357 resolvers: Vec<Arc<dyn FieldResolver>>,
359}
360
361impl std::fmt::Debug for ResolverChain {
362 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363 f.debug_struct("ResolverChain")
364 .field("resolvers_count", &self.resolvers.len())
365 .finish()
366 }
367}
368
369impl ResolverChain {
370 pub fn new() -> Self {
372 Self {
373 resolvers: Vec::new(),
374 }
375 }
376
377 #[allow(clippy::should_implement_trait)]
379 pub fn add(mut self, resolver: impl FieldResolver + 'static) -> Self {
380 self.resolvers.push(Arc::new(resolver));
381 self
382 }
383
384 pub fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
386 for resolver in &self.resolvers {
387 let result = resolver.resolve(ctx);
388 if !matches!(result, ResolverResult::Null) {
389 return result;
390 }
391 }
392 ResolverResult::Null
393 }
394}
395
396impl Default for ResolverChain {
397 fn default() -> Self {
398 Self::new()
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::introspector::GraphQLSchema;

    /// Builds a context over a fresh schema and a default execution context.
    fn create_test_context() -> ResolverContext {
        ResolverContext::new(Arc::new(GraphQLSchema::new()), ExecutionContext::default())
    }

    /// Unwraps a `Value` outcome, panicking with `failure` for any other variant.
    fn expect_value(result: ResolverResult, failure: &str) -> serde_json::Value {
        match result {
            ResolverResult::Value(v) => v,
            _ => panic!("{}", failure),
        }
    }

    #[test]
    fn test_resolver_context_args() {
        let args: HashMap<String, serde_json::Value> = vec![
            ("limit".to_string(), serde_json::json!(10)),
            ("name".to_string(), serde_json::json!("test")),
        ]
        .into_iter()
        .collect();

        let ctx = create_test_context().with_arguments(args);

        assert_eq!(ctx.arg::<i32>("limit"), Some(10));
        assert_eq!(ctx.arg::<String>("name"), Some("test".to_string()));
        assert_eq!(ctx.arg::<i32>("missing"), None);
    }

    #[test]
    fn test_resolver_context_required_arg() {
        let args: HashMap<String, serde_json::Value> =
            std::iter::once(("id".to_string(), serde_json::json!("123"))).collect();

        let ctx = create_test_context().with_arguments(args);

        assert!(ctx.required_arg::<String>("id").is_ok());
        assert!(ctx.required_arg::<String>("missing").is_err());
    }

    #[test]
    fn test_resolver_context_parent() {
        let ctx = create_test_context().with_parent(serde_json::json!({
            "id": "123",
            "name": "Test"
        }));

        let expected_id = serde_json::json!("123");
        assert_eq!(ctx.parent_field("id"), Some(&expected_id));
        assert_eq!(ctx.parent_field("missing"), None);
    }

    #[test]
    fn test_default_resolver() {
        let ctx = create_test_context().with_parent(serde_json::json!({
            "id": "123",
            "name": "John"
        }));

        let resolver = DefaultResolver::new("User", "name", "name");
        let value = expect_value(resolver.resolve(&ctx), "Expected value");

        assert_eq!(value, serde_json::json!("John"));
    }

    #[test]
    fn test_computed_resolver() {
        // Reads a string entry from the parent, falling back to the empty string.
        fn text<'a>(ctx: &'a ResolverContext, key: &str) -> &'a str {
            ctx.parent_field(key).and_then(|v| v.as_str()).unwrap_or("")
        }

        let resolver = ComputedResolver::new("User", "fullName", |ctx| {
            let first = text(ctx, "firstName");
            let last = text(ctx, "lastName");
            ResolverResult::value(format!("{} {}", first, last))
        });

        let ctx = create_test_context().with_parent(serde_json::json!({
            "firstName": "John",
            "lastName": "Doe"
        }));

        let value = expect_value(resolver.resolve(&ctx), "Expected value");
        assert_eq!(value, serde_json::json!("John Doe"));
    }

    #[test]
    fn test_resolver_registry() {
        let mut registry = ResolverRegistry::new();
        for (type_name, field) in vec![("User", "id"), ("User", "name"), ("Post", "title")] {
            registry.register(DefaultResolver::new(type_name, field, field));
        }

        assert!(registry.has("User", "id"));
        assert!(registry.has("User", "name"));
        assert!(registry.has("Post", "title"));
        assert!(!registry.has("User", "email"));

        assert_eq!(registry.resolvers_for("User").len(), 2);
    }

    #[test]
    fn test_resolver_result_conversions() {
        let from_json = ResolverResult::from(serde_json::json!("test"));
        assert!(!from_json.is_error());

        let from_some = ResolverResult::from(Some("test"));
        assert!(matches!(from_some, ResolverResult::Value(_)));

        let from_none = ResolverResult::from(Option::<String>::None);
        assert!(matches!(from_none, ResolverResult::Null));
    }

    #[test]
    fn test_resolver_chain() {
        // Only `name` exists on the parent, so the second resolver must supply the value.
        let ctx = create_test_context().with_parent(serde_json::json!({
            "name": "John"
        }));

        let chain = ResolverChain::new()
            .add(DefaultResolver::new("User", "displayName", "display_name"))
            .add(DefaultResolver::new("User", "displayName", "name"))
            .add(DefaultResolver::new("User", "displayName", "email"));

        let value = expect_value(chain.resolve(&ctx), "Expected value from second resolver");
        assert_eq!(value, serde_json::json!("John"));
    }

    #[test]
    fn test_resolver_error_display() {
        let missing = ResolverError::MissingArgument("id".to_string());
        assert!(missing.to_string().contains("id"));

        let unauthorized = ResolverError::Unauthorized("Not authenticated".to_string());
        assert!(unauthorized.to_string().contains("Not authenticated"));
    }

    #[test]
    fn test_resolver_context_roles() {
        let execution = ExecutionContext::default()
            .with_user("user1")
            .with_role("admin");
        let ctx = ResolverContext::new(Arc::new(GraphQLSchema::new()), execution);

        assert!(ctx.has_role("admin"));
        assert!(!ctx.has_role("superuser"));
        assert_eq!(ctx.user_id(), Some("user1"));
    }

    #[test]
    fn test_resolver_result_to_json() {
        let from_value = ResolverResult::value("test").to_json();
        assert_eq!(from_value, serde_json::json!("test"));

        let from_null = ResolverResult::null().to_json();
        assert_eq!(from_null, serde_json::Value::Null);

        let from_error = ResolverResult::error("err").to_json();
        assert_eq!(from_error, serde_json::Value::Null);
    }
}