Skip to main content

heliosdb_proxy/graphql/
resolver.rs

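//! Field resolvers.
//!
//! Resolves GraphQL fields to values — read from the parent database row by
//! default, or computed by a closure — and provides a registry and a fallback
//! chain for organizing resolvers.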
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use super::{GraphQLSchema, ExecutionContext};
9
10/// Resolver context passed to field resolvers
11#[derive(Debug, Clone)]
12pub struct ResolverContext {
13    /// Execution context
14    pub execution: ExecutionContext,
15    /// Parent value (for nested resolvers)
16    pub parent: Option<serde_json::Value>,
17    /// Field arguments
18    pub arguments: HashMap<String, serde_json::Value>,
19    /// Field path
20    pub path: Vec<String>,
21    /// Schema reference
22    pub schema: Arc<GraphQLSchema>,
23}
24
25impl ResolverContext {
26    /// Create a new resolver context
27    pub fn new(schema: Arc<GraphQLSchema>, execution: ExecutionContext) -> Self {
28        Self {
29            execution,
30            parent: None,
31            schema,
32            arguments: HashMap::new(),
33            path: Vec::new(),
34        }
35    }
36
37    /// Set the parent value
38    pub fn with_parent(mut self, parent: serde_json::Value) -> Self {
39        self.parent = Some(parent);
40        self
41    }
42
43    /// Set arguments
44    pub fn with_arguments(mut self, arguments: HashMap<String, serde_json::Value>) -> Self {
45        self.arguments = arguments;
46        self
47    }
48
49    /// Add to path
50    pub fn push_path(&mut self, segment: impl Into<String>) {
51        self.path.push(segment.into());
52    }
53
54    /// Get an argument value
55    pub fn arg<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
56        self.arguments
57            .get(name)
58            .and_then(|v| serde_json::from_value(v.clone()).ok())
59    }
60
61    /// Get a required argument
62    pub fn required_arg<T: serde::de::DeserializeOwned>(&self, name: &str) -> Result<T, ResolverError> {
63        self.arg(name)
64            .ok_or_else(|| ResolverError::MissingArgument(name.to_string()))
65    }
66
67    /// Get a field from the parent
68    pub fn parent_field(&self, name: &str) -> Option<&serde_json::Value> {
69        self.parent.as_ref()?.get(name)
70    }
71
72    /// Check if user has a role
73    pub fn has_role(&self, role: &str) -> bool {
74        self.execution.has_role(role)
75    }
76
77    /// Get the current user ID
78    pub fn user_id(&self) -> Option<&str> {
79        self.execution.user_id.as_deref()
80    }
81}
82
83/// Resolver result
84#[derive(Debug, Clone)]
85pub enum ResolverResult {
86    /// Resolved value
87    Value(serde_json::Value),
88    /// Null value
89    Null,
90    /// Error occurred
91    Error(ResolverError),
92    /// Deferred to DataLoader
93    Deferred(String),
94}
95
96impl ResolverResult {
97    /// Create a value result
98    pub fn value(val: impl Into<serde_json::Value>) -> Self {
99        ResolverResult::Value(val.into())
100    }
101
102    /// Create a null result
103    pub fn null() -> Self {
104        ResolverResult::Null
105    }
106
107    /// Create an error result
108    pub fn error(err: impl Into<ResolverError>) -> Self {
109        ResolverResult::Error(err.into())
110    }
111
112    /// Check if result is an error
113    pub fn is_error(&self) -> bool {
114        matches!(self, ResolverResult::Error(_))
115    }
116
117    /// Convert to JSON value
118    pub fn to_json(self) -> serde_json::Value {
119        match self {
120            ResolverResult::Value(v) => v,
121            ResolverResult::Null | ResolverResult::Deferred(_) => serde_json::Value::Null,
122            ResolverResult::Error(_) => serde_json::Value::Null,
123        }
124    }
125}
126
127impl From<serde_json::Value> for ResolverResult {
128    fn from(val: serde_json::Value) -> Self {
129        ResolverResult::Value(val)
130    }
131}
132
133impl<T: Into<serde_json::Value>> From<Option<T>> for ResolverResult {
134    fn from(opt: Option<T>) -> Self {
135        match opt {
136            Some(v) => ResolverResult::Value(v.into()),
137            None => ResolverResult::Null,
138        }
139    }
140}
141
/// Error produced while resolving a field.
#[derive(Debug, Clone)]
pub enum ResolverError {
    /// A required argument was not supplied (payload: argument name).
    MissingArgument(String),
    /// An argument was supplied but rejected (payload: argument name, reason).
    InvalidArgument(String, String),
    /// The field could not be found (payload: field name).
    FieldNotFound(String),
    /// Authorization failed (payload: message).
    Unauthorized(String),
    /// A database operation failed (payload: message).
    DatabaseError(String),
    /// Any other failure (payload: message).
    Internal(String),
}

impl std::fmt::Display for ResolverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidArgument(arg, reason) => {
                write!(f, "Invalid argument '{}': {}", arg, reason)
            }
            Self::MissingArgument(arg) => write!(f, "Missing required argument: {}", arg),
            Self::FieldNotFound(field) => write!(f, "Field not found: {}", field),
            Self::Unauthorized(detail) => write!(f, "Unauthorized: {}", detail),
            Self::DatabaseError(detail) => write!(f, "Database error: {}", detail),
            Self::Internal(detail) => write!(f, "Internal error: {}", detail),
        }
    }
}

impl std::error::Error for ResolverError {}

/// Bare owned strings become [`ResolverError::Internal`].
impl From<String> for ResolverError {
    fn from(message: String) -> Self {
        Self::Internal(message)
    }
}

/// Bare string slices become [`ResolverError::Internal`], same as owned strings.
impl From<&str> for ResolverError {
    fn from(message: &str) -> Self {
        Self::from(message.to_owned())
    }
}
187
188/// Field resolver trait
189pub trait FieldResolver: Send + Sync {
190    /// Resolve the field
191    fn resolve(&self, ctx: &ResolverContext) -> ResolverResult;
192
193    /// Get the field name
194    fn field_name(&self) -> &str;
195
196    /// Get the type name this resolver belongs to
197    fn type_name(&self) -> &str;
198}
199
200/// Default field resolver (extracts from parent)
201#[derive(Debug)]
202pub struct DefaultResolver {
203    /// Type name
204    type_name: String,
205    /// Field name
206    field_name: String,
207    /// Column name in database
208    column_name: String,
209}
210
211impl DefaultResolver {
212    /// Create a new default resolver
213    pub fn new(
214        type_name: impl Into<String>,
215        field_name: impl Into<String>,
216        column_name: impl Into<String>,
217    ) -> Self {
218        Self {
219            type_name: type_name.into(),
220            field_name: field_name.into(),
221            column_name: column_name.into(),
222        }
223    }
224}
225
226impl FieldResolver for DefaultResolver {
227    fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
228        match &ctx.parent {
229            Some(parent) => {
230                parent.get(&self.column_name).cloned().into()
231            }
232            None => ResolverResult::Null,
233        }
234    }
235
236    fn field_name(&self) -> &str {
237        &self.field_name
238    }
239
240    fn type_name(&self) -> &str {
241        &self.type_name
242    }
243}
244
245/// Computed field resolver
246pub struct ComputedResolver<F>
247where
248    F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
249{
250    type_name: String,
251    field_name: String,
252    resolver_fn: F,
253}
254
255impl<F> ComputedResolver<F>
256where
257    F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
258{
259    /// Create a computed resolver
260    pub fn new(
261        type_name: impl Into<String>,
262        field_name: impl Into<String>,
263        resolver_fn: F,
264    ) -> Self {
265        Self {
266            type_name: type_name.into(),
267            field_name: field_name.into(),
268            resolver_fn,
269        }
270    }
271}
272
273impl<F> FieldResolver for ComputedResolver<F>
274where
275    F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
276{
277    fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
278        (self.resolver_fn)(ctx)
279    }
280
281    fn field_name(&self) -> &str {
282        &self.field_name
283    }
284
285    fn type_name(&self) -> &str {
286        &self.type_name
287    }
288}
289
290impl<F> std::fmt::Debug for ComputedResolver<F>
291where
292    F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
293{
294    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295        f.debug_struct("ComputedResolver")
296            .field("type_name", &self.type_name)
297            .field("field_name", &self.field_name)
298            .finish()
299    }
300}
301
302/// Resolver registry
303#[derive(Default)]
304pub struct ResolverRegistry {
305    /// Resolvers by type.field
306    resolvers: HashMap<String, Arc<dyn FieldResolver>>,
307}
308
309impl std::fmt::Debug for ResolverRegistry {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        f.debug_struct("ResolverRegistry")
312            .field("resolvers_count", &self.resolvers.len())
313            .finish()
314    }
315}
316
317impl ResolverRegistry {
318    /// Create a new registry
319    pub fn new() -> Self {
320        Self::default()
321    }
322
323    /// Register a resolver
324    pub fn register(&mut self, resolver: impl FieldResolver + 'static) {
325        let key = format!("{}.{}", resolver.type_name(), resolver.field_name());
326        self.resolvers.insert(key, Arc::new(resolver));
327    }
328
329    /// Get a resolver
330    pub fn get(&self, type_name: &str, field_name: &str) -> Option<Arc<dyn FieldResolver>> {
331        let key = format!("{}.{}", type_name, field_name);
332        self.resolvers.get(&key).cloned()
333    }
334
335    /// Check if a resolver exists
336    pub fn has(&self, type_name: &str, field_name: &str) -> bool {
337        let key = format!("{}.{}", type_name, field_name);
338        self.resolvers.contains_key(&key)
339    }
340
341    /// Get all resolvers for a type
342    pub fn resolvers_for(&self, type_name: &str) -> Vec<Arc<dyn FieldResolver>> {
343        let prefix = format!("{}.", type_name);
344        self.resolvers
345            .iter()
346            .filter(|(k, _)| k.starts_with(&prefix))
347            .map(|(_, v)| v.clone())
348            .collect()
349    }
350}
351
352/// Resolver chain for applying multiple resolvers
353pub struct ResolverChain {
354    /// Resolvers in order
355    resolvers: Vec<Arc<dyn FieldResolver>>,
356}
357
358impl std::fmt::Debug for ResolverChain {
359    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360        f.debug_struct("ResolverChain")
361            .field("resolvers_count", &self.resolvers.len())
362            .finish()
363    }
364}
365
366impl ResolverChain {
367    /// Create a new chain
368    pub fn new() -> Self {
369        Self {
370            resolvers: Vec::new(),
371        }
372    }
373
374    /// Add a resolver to the chain
375    pub fn add(mut self, resolver: impl FieldResolver + 'static) -> Self {
376        self.resolvers.push(Arc::new(resolver));
377        self
378    }
379
380    /// Resolve through the chain
381    pub fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
382        for resolver in &self.resolvers {
383            let result = resolver.resolve(ctx);
384            if !matches!(result, ResolverResult::Null) {
385                return result;
386            }
387        }
388        ResolverResult::Null
389    }
390}
391
392impl Default for ResolverChain {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::introspector::GraphQLSchema;
    use serde_json::json;

    /// Builds a context over an empty schema and a default execution context.
    fn create_test_context() -> ResolverContext {
        ResolverContext::new(Arc::new(GraphQLSchema::new()), ExecutionContext::default())
    }

    /// Reads a string field from the context's parent, defaulting to `""`.
    fn parent_str<'a>(ctx: &'a ResolverContext, key: &str) -> &'a str {
        ctx.parent_field(key).and_then(|v| v.as_str()).unwrap_or("")
    }

    #[test]
    fn test_resolver_context_args() {
        let args: HashMap<String, serde_json::Value> = vec![
            ("limit".to_string(), json!(10)),
            ("name".to_string(), json!("test")),
        ]
        .into_iter()
        .collect();
        let ctx = create_test_context().with_arguments(args);

        assert_eq!(ctx.arg::<i32>("limit"), Some(10));
        assert_eq!(ctx.arg::<String>("name").as_deref(), Some("test"));
        assert!(ctx.arg::<i32>("missing").is_none());
    }

    #[test]
    fn test_resolver_context_required_arg() {
        let mut args = HashMap::new();
        args.insert(String::from("id"), json!("123"));
        let ctx = create_test_context().with_arguments(args);

        let present: Result<String, ResolverError> = ctx.required_arg("id");
        let absent: Result<String, ResolverError> = ctx.required_arg("missing");
        assert!(present.is_ok());
        assert!(absent.is_err());
    }

    #[test]
    fn test_resolver_context_parent() {
        let ctx = create_test_context().with_parent(json!({ "id": "123", "name": "Test" }));

        assert_eq!(ctx.parent_field("id"), Some(&json!("123")));
        assert!(ctx.parent_field("missing").is_none());
    }

    #[test]
    fn test_default_resolver() {
        let ctx = create_test_context().with_parent(json!({ "id": "123", "name": "John" }));
        let resolved = DefaultResolver::new("User", "name", "name").resolve(&ctx);

        if let ResolverResult::Value(v) = resolved {
            assert_eq!(v, json!("John"));
        } else {
            panic!("Expected value");
        }
    }

    #[test]
    fn test_computed_resolver() {
        let resolver = ComputedResolver::new("User", "fullName", |ctx| {
            ResolverResult::value(format!(
                "{} {}",
                parent_str(ctx, "firstName"),
                parent_str(ctx, "lastName")
            ))
        });
        let ctx = create_test_context().with_parent(json!({
            "firstName": "John",
            "lastName": "Doe"
        }));

        if let ResolverResult::Value(v) = resolver.resolve(&ctx) {
            assert_eq!(v, json!("John Doe"));
        } else {
            panic!("Expected value");
        }
    }

    #[test]
    fn test_resolver_registry() {
        let mut registry = ResolverRegistry::new();
        for &(ty, field) in &[("User", "id"), ("User", "name"), ("Post", "title")] {
            registry.register(DefaultResolver::new(ty, field, field));
        }

        assert!(registry.has("User", "id"));
        assert!(registry.has("User", "name"));
        assert!(registry.has("Post", "title"));
        assert!(!registry.has("User", "email"));
        assert_eq!(registry.resolvers_for("User").len(), 2);
    }

    #[test]
    fn test_resolver_result_conversions() {
        let from_json = ResolverResult::from(json!("test"));
        assert!(!from_json.is_error());

        assert!(matches!(ResolverResult::from(Some("test")), ResolverResult::Value(_)));
        assert!(matches!(ResolverResult::from(None::<String>), ResolverResult::Null));
    }

    #[test]
    fn test_resolver_chain() {
        let chain = ["display_name", "name", "email"]
            .iter()
            .fold(ResolverChain::new(), |chain, column| {
                chain.add(DefaultResolver::new("User", "displayName", *column))
            });

        // `display_name` is absent from the parent, so the `name` resolver supplies the value.
        let ctx = create_test_context().with_parent(json!({ "name": "John" }));

        match chain.resolve(&ctx) {
            ResolverResult::Value(v) => assert_eq!(v, json!("John")),
            _ => panic!("Expected value from second resolver"),
        }
    }

    #[test]
    fn test_resolver_error_display() {
        let missing = ResolverError::MissingArgument("id".to_string());
        let unauthorized = ResolverError::Unauthorized("Not authenticated".to_string());

        assert!(missing.to_string().contains("id"));
        assert!(unauthorized.to_string().contains("Not authenticated"));
    }

    #[test]
    fn test_resolver_context_roles() {
        let execution = ExecutionContext::default()
            .with_user("user1")
            .with_role("admin");
        let ctx = ResolverContext::new(Arc::new(GraphQLSchema::new()), execution);

        assert!(ctx.has_role("admin"));
        assert!(!ctx.has_role("superuser"));
        assert_eq!(ctx.user_id(), Some("user1"));
    }

    #[test]
    fn test_resolver_result_to_json() {
        let cases = vec![
            (ResolverResult::value("test"), json!("test")),
            (ResolverResult::null(), serde_json::Value::Null),
            (ResolverResult::error("err"), serde_json::Value::Null),
        ];
        for (result, expected) in cases {
            assert_eq!(result.to_json(), expected);
        }
    }
}