Skip to main content

heliosdb_proxy/graphql/
resolver.rs

1//! Field Resolver
2//!
3//! Resolves GraphQL fields to database values.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use super::{ExecutionContext, GraphQLSchema};
9
/// Resolver context passed to field resolvers.
///
/// One value is handed to every [`FieldResolver::resolve`] call. Build it with
/// [`ResolverContext::new`] and refine it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ResolverContext {
    /// Execution context of the current request; `has_role` and `user_id` read from it.
    pub execution: ExecutionContext,
    /// Parent value for nested resolvers (`None` at the root); `parent_field` looks keys up here.
    pub parent: Option<serde_json::Value>,
    /// Field arguments keyed by argument name, as raw JSON values.
    pub arguments: HashMap<String, serde_json::Value>,
    /// Field path segments, appended by `push_path`.
    pub path: Vec<String>,
    /// Schema reference (shared, reference-counted).
    pub schema: Arc<GraphQLSchema>,
}
24
25impl ResolverContext {
26    /// Create a new resolver context
27    pub fn new(schema: Arc<GraphQLSchema>, execution: ExecutionContext) -> Self {
28        Self {
29            execution,
30            parent: None,
31            schema,
32            arguments: HashMap::new(),
33            path: Vec::new(),
34        }
35    }
36
37    /// Set the parent value
38    pub fn with_parent(mut self, parent: serde_json::Value) -> Self {
39        self.parent = Some(parent);
40        self
41    }
42
43    /// Set arguments
44    pub fn with_arguments(mut self, arguments: HashMap<String, serde_json::Value>) -> Self {
45        self.arguments = arguments;
46        self
47    }
48
49    /// Add to path
50    pub fn push_path(&mut self, segment: impl Into<String>) {
51        self.path.push(segment.into());
52    }
53
54    /// Get an argument value
55    pub fn arg<T: serde::de::DeserializeOwned>(&self, name: &str) -> Option<T> {
56        self.arguments
57            .get(name)
58            .and_then(|v| serde_json::from_value(v.clone()).ok())
59    }
60
61    /// Get a required argument
62    pub fn required_arg<T: serde::de::DeserializeOwned>(
63        &self,
64        name: &str,
65    ) -> Result<T, ResolverError> {
66        self.arg(name)
67            .ok_or_else(|| ResolverError::MissingArgument(name.to_string()))
68    }
69
70    /// Get a field from the parent
71    pub fn parent_field(&self, name: &str) -> Option<&serde_json::Value> {
72        self.parent.as_ref()?.get(name)
73    }
74
75    /// Check if user has a role
76    pub fn has_role(&self, role: &str) -> bool {
77        self.execution.has_role(role)
78    }
79
80    /// Get the current user ID
81    pub fn user_id(&self) -> Option<&str> {
82        self.execution.user_id.as_deref()
83    }
84}
85
/// Outcome of a single field resolution.
///
/// Convert it to the JSON that goes into the response with `ResolverResult::to_json`.
#[derive(Debug, Clone)]
pub enum ResolverResult {
    /// Resolved value.
    Value(serde_json::Value),
    /// Null value (nothing was resolved).
    Null,
    /// An error occurred while resolving.
    Error(ResolverError),
    /// Deferred to a DataLoader. The `String` is presumably the loader/batch key;
    /// NOTE(review): confirm against the DataLoader integration.
    Deferred(String),
}
98
99impl ResolverResult {
100    /// Create a value result
101    pub fn value(val: impl Into<serde_json::Value>) -> Self {
102        ResolverResult::Value(val.into())
103    }
104
105    /// Create a null result
106    pub fn null() -> Self {
107        ResolverResult::Null
108    }
109
110    /// Create an error result
111    pub fn error(err: impl Into<ResolverError>) -> Self {
112        ResolverResult::Error(err.into())
113    }
114
115    /// Check if result is an error
116    pub fn is_error(&self) -> bool {
117        matches!(self, ResolverResult::Error(_))
118    }
119
120    /// Convert to JSON value
121    pub fn to_json(self) -> serde_json::Value {
122        match self {
123            ResolverResult::Value(v) => v,
124            ResolverResult::Null | ResolverResult::Deferred(_) => serde_json::Value::Null,
125            ResolverResult::Error(_) => serde_json::Value::Null,
126        }
127    }
128}
129
130impl From<serde_json::Value> for ResolverResult {
131    fn from(val: serde_json::Value) -> Self {
132        ResolverResult::Value(val)
133    }
134}
135
136impl<T: Into<serde_json::Value>> From<Option<T>> for ResolverResult {
137    fn from(opt: Option<T>) -> Self {
138        match opt {
139            Some(v) => ResolverResult::Value(v.into()),
140            None => ResolverResult::Null,
141        }
142    }
143}
144
/// Errors a field resolver can report.
///
/// Every payload is a plain `String`, so the type is cheaply comparable. It derives
/// `PartialEq`/`Eq` so callers and tests can match on exact errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolverError {
    /// Missing required argument (argument name).
    MissingArgument(String),
    /// Invalid argument value (argument name, reason).
    InvalidArgument(String, String),
    /// Field not found (field name).
    FieldNotFound(String),
    /// Authorization failed (message).
    Unauthorized(String),
    /// Database error (message).
    DatabaseError(String),
    /// Internal error (message).
    Internal(String),
}
161
162impl std::fmt::Display for ResolverError {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        match self {
165            ResolverError::MissingArgument(name) => {
166                write!(f, "Missing required argument: {}", name)
167            }
168            ResolverError::InvalidArgument(name, msg) => {
169                write!(f, "Invalid argument '{}': {}", name, msg)
170            }
171            ResolverError::FieldNotFound(name) => write!(f, "Field not found: {}", name),
172            ResolverError::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
173            ResolverError::DatabaseError(msg) => write!(f, "Database error: {}", msg),
174            ResolverError::Internal(msg) => write!(f, "Internal error: {}", msg),
175        }
176    }
177}
178
/// Marks `ResolverError` as a standard error type. No method is overridden, so the
/// trait's default `source()` (returning `None`) applies: these errors carry no
/// wrapped cause.
impl std::error::Error for ResolverError {}
180
181impl From<String> for ResolverError {
182    fn from(s: String) -> Self {
183        ResolverError::Internal(s)
184    }
185}
186
187impl From<&str> for ResolverError {
188    fn from(s: &str) -> Self {
189        ResolverError::Internal(s.to_string())
190    }
191}
192
/// Field resolver trait.
///
/// Implementors turn a [`ResolverContext`] into a [`ResolverResult`] for one
/// `type_name.field_name` pair. The `Send + Sync` bound is required because
/// resolvers are stored as `Arc<dyn FieldResolver>` in [`ResolverRegistry`] and
/// [`ResolverChain`].
pub trait FieldResolver: Send + Sync {
    /// Resolve the field for the given context.
    fn resolve(&self, ctx: &ResolverContext) -> ResolverResult;

    /// Get the field name (the second half of the registry key).
    fn field_name(&self) -> &str;

    /// Get the type name this resolver belongs to (the first half of the registry key).
    fn type_name(&self) -> &str;
}
204
/// Default field resolver: extracts the field from the parent value.
///
/// Maps the GraphQL field `field_name` on `type_name` to the key `column_name`
/// of the parent JSON object.
#[derive(Debug)]
pub struct DefaultResolver {
    /// Type name
    type_name: String,
    /// Field name
    field_name: String,
    /// Key looked up in the parent JSON value (presumably the database column name).
    column_name: String,
}
215
216impl DefaultResolver {
217    /// Create a new default resolver
218    pub fn new(
219        type_name: impl Into<String>,
220        field_name: impl Into<String>,
221        column_name: impl Into<String>,
222    ) -> Self {
223        Self {
224            type_name: type_name.into(),
225            field_name: field_name.into(),
226            column_name: column_name.into(),
227        }
228    }
229}
230
231impl FieldResolver for DefaultResolver {
232    fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
233        match &ctx.parent {
234            Some(parent) => parent.get(&self.column_name).cloned().into(),
235            None => ResolverResult::Null,
236        }
237    }
238
239    fn field_name(&self) -> &str {
240        &self.field_name
241    }
242
243    fn type_name(&self) -> &str {
244        &self.type_name
245    }
246}
247
/// Computed field resolver: delegates resolution to a closure.
///
/// The `F: Fn(&ResolverContext) -> ResolverResult + Send + Sync` requirement lives on
/// the `impl` blocks that need it, not on the type itself (Rust API guideline
/// C-STRUCT-BOUNDS). Relaxing the definition is backward-compatible: every
/// previously valid `ComputedResolver<F>` remains valid.
pub struct ComputedResolver<F> {
    /// Type name this resolver belongs to.
    type_name: String,
    /// Field name this resolver answers.
    field_name: String,
    /// Closure invoked for every resolution.
    resolver_fn: F,
}
257
258impl<F> ComputedResolver<F>
259where
260    F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
261{
262    /// Create a computed resolver
263    pub fn new(
264        type_name: impl Into<String>,
265        field_name: impl Into<String>,
266        resolver_fn: F,
267    ) -> Self {
268        Self {
269            type_name: type_name.into(),
270            field_name: field_name.into(),
271            resolver_fn,
272        }
273    }
274}
275
276impl<F> FieldResolver for ComputedResolver<F>
277where
278    F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
279{
280    fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
281        (self.resolver_fn)(ctx)
282    }
283
284    fn field_name(&self) -> &str {
285        &self.field_name
286    }
287
288    fn type_name(&self) -> &str {
289        &self.type_name
290    }
291}
292
293impl<F> std::fmt::Debug for ComputedResolver<F>
294where
295    F: Fn(&ResolverContext) -> ResolverResult + Send + Sync,
296{
297    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298        f.debug_struct("ComputedResolver")
299            .field("type_name", &self.type_name)
300            .field("field_name", &self.field_name)
301            .finish()
302    }
303}
304
/// Resolver registry: looks field resolvers up by type and field name.
#[derive(Default)]
pub struct ResolverRegistry {
    /// Resolvers keyed by the string `"{type_name}.{field_name}"`; a later
    /// registration under the same key replaces the earlier one (`HashMap::insert`).
    resolvers: HashMap<String, Arc<dyn FieldResolver>>,
}
311
312impl std::fmt::Debug for ResolverRegistry {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        f.debug_struct("ResolverRegistry")
315            .field("resolvers_count", &self.resolvers.len())
316            .finish()
317    }
318}
319
320impl ResolverRegistry {
321    /// Create a new registry
322    pub fn new() -> Self {
323        Self::default()
324    }
325
326    /// Register a resolver
327    pub fn register(&mut self, resolver: impl FieldResolver + 'static) {
328        let key = format!("{}.{}", resolver.type_name(), resolver.field_name());
329        self.resolvers.insert(key, Arc::new(resolver));
330    }
331
332    /// Get a resolver
333    pub fn get(&self, type_name: &str, field_name: &str) -> Option<Arc<dyn FieldResolver>> {
334        let key = format!("{}.{}", type_name, field_name);
335        self.resolvers.get(&key).cloned()
336    }
337
338    /// Check if a resolver exists
339    pub fn has(&self, type_name: &str, field_name: &str) -> bool {
340        let key = format!("{}.{}", type_name, field_name);
341        self.resolvers.contains_key(&key)
342    }
343
344    /// Get all resolvers for a type
345    pub fn resolvers_for(&self, type_name: &str) -> Vec<Arc<dyn FieldResolver>> {
346        let prefix = format!("{}.", type_name);
347        self.resolvers
348            .iter()
349            .filter(|(k, _)| k.starts_with(&prefix))
350            .map(|(_, v)| v.clone())
351            .collect()
352    }
353}
354
/// Resolver chain for applying multiple resolvers as fallbacks.
///
/// Resolvers are consulted in insertion order by `ResolverChain::resolve`; see
/// that method for the exact fall-through rule.
pub struct ResolverChain {
    /// Resolvers in the order they were added.
    resolvers: Vec<Arc<dyn FieldResolver>>,
}
360
361impl std::fmt::Debug for ResolverChain {
362    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
363        f.debug_struct("ResolverChain")
364            .field("resolvers_count", &self.resolvers.len())
365            .finish()
366    }
367}
368
369impl ResolverChain {
370    /// Create a new chain
371    pub fn new() -> Self {
372        Self {
373            resolvers: Vec::new(),
374        }
375    }
376
377    /// Add a resolver to the chain
378    #[allow(clippy::should_implement_trait)]
379    pub fn add(mut self, resolver: impl FieldResolver + 'static) -> Self {
380        self.resolvers.push(Arc::new(resolver));
381        self
382    }
383
384    /// Resolve through the chain
385    pub fn resolve(&self, ctx: &ResolverContext) -> ResolverResult {
386        for resolver in &self.resolvers {
387            let result = resolver.resolve(ctx);
388            if !matches!(result, ResolverResult::Null) {
389                return result;
390            }
391        }
392        ResolverResult::Null
393    }
394}
395
396impl Default for ResolverChain {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graphql::introspector::GraphQLSchema;

    /// Context backed by an empty schema and a default execution context.
    fn create_test_context() -> ResolverContext {
        ResolverContext::new(Arc::new(GraphQLSchema::new()), ExecutionContext::default())
    }

    /// Default test context with `parent` installed as the parent value.
    fn context_with_parent(parent: serde_json::Value) -> ResolverContext {
        create_test_context().with_parent(parent)
    }

    #[test]
    fn test_resolver_context_args() {
        let args: HashMap<String, serde_json::Value> = vec![
            ("limit".to_string(), serde_json::json!(10)),
            ("name".to_string(), serde_json::json!("test")),
        ]
        .into_iter()
        .collect();
        let ctx = create_test_context().with_arguments(args);

        assert_eq!(ctx.arg::<i32>("limit"), Some(10));
        assert_eq!(ctx.arg::<String>("name"), Some("test".to_string()));
        assert_eq!(ctx.arg::<i32>("missing"), None);
    }

    #[test]
    fn test_resolver_context_required_arg() {
        let ctx = create_test_context().with_arguments(
            vec![("id".to_string(), serde_json::json!("123"))]
                .into_iter()
                .collect(),
        );

        assert!(ctx.required_arg::<String>("id").is_ok());
        assert!(ctx.required_arg::<String>("missing").is_err());
    }

    #[test]
    fn test_resolver_context_parent() {
        let ctx = context_with_parent(serde_json::json!({ "id": "123", "name": "Test" }));

        assert_eq!(ctx.parent_field("id"), Some(&serde_json::json!("123")));
        assert!(ctx.parent_field("missing").is_none());
    }

    #[test]
    fn test_default_resolver() {
        let resolver = DefaultResolver::new("User", "name", "name");
        let ctx = context_with_parent(serde_json::json!({ "id": "123", "name": "John" }));

        if let ResolverResult::Value(v) = resolver.resolve(&ctx) {
            assert_eq!(v, serde_json::json!("John"));
        } else {
            panic!("Expected value");
        }
    }

    #[test]
    fn test_computed_resolver() {
        // Reads a string key from the parent, defaulting to "" when absent.
        fn name_part<'a>(ctx: &'a ResolverContext, key: &str) -> &'a str {
            ctx.parent_field(key)
                .and_then(|v| v.as_str())
                .unwrap_or("")
        }

        let resolver = ComputedResolver::new("User", "fullName", |ctx| {
            ResolverResult::value(format!(
                "{} {}",
                name_part(ctx, "firstName"),
                name_part(ctx, "lastName")
            ))
        });
        let ctx = context_with_parent(serde_json::json!({
            "firstName": "John",
            "lastName": "Doe"
        }));

        if let ResolverResult::Value(v) = resolver.resolve(&ctx) {
            assert_eq!(v, serde_json::json!("John Doe"));
        } else {
            panic!("Expected value");
        }
    }

    #[test]
    fn test_resolver_registry() {
        let mut registry = ResolverRegistry::new();
        for &(ty, field) in &[("User", "id"), ("User", "name"), ("Post", "title")] {
            registry.register(DefaultResolver::new(ty, field, field));
        }

        assert!(registry.has("User", "id"));
        assert!(registry.has("User", "name"));
        assert!(registry.has("Post", "title"));
        assert!(!registry.has("User", "email"));
        assert_eq!(registry.resolvers_for("User").len(), 2);
    }

    #[test]
    fn test_resolver_result_conversions() {
        let from_json: ResolverResult = serde_json::json!("test").into();
        assert!(!from_json.is_error());

        let from_some: ResolverResult = Some("test").into();
        assert!(matches!(from_some, ResolverResult::Value(_)));

        let from_none: ResolverResult = Option::<String>::None.into();
        assert!(matches!(from_none, ResolverResult::Null));
    }

    #[test]
    fn test_resolver_chain() {
        let chain = ["display_name", "name", "email"]
            .iter()
            .fold(ResolverChain::new(), |acc, column| {
                acc.add(DefaultResolver::new("User", "displayName", *column))
            });

        // `display_name` is absent, so the chain falls through to `name`.
        let ctx = context_with_parent(serde_json::json!({ "name": "John" }));

        if let ResolverResult::Value(v) = chain.resolve(&ctx) {
            assert_eq!(v, serde_json::json!("John"));
        } else {
            panic!("Expected value from second resolver");
        }
    }

    #[test]
    fn test_resolver_error_display() {
        let missing = ResolverError::MissingArgument("id".to_string()).to_string();
        assert!(missing.contains("id"));

        let unauthorized =
            ResolverError::Unauthorized("Not authenticated".to_string()).to_string();
        assert!(unauthorized.contains("Not authenticated"));
    }

    #[test]
    fn test_resolver_context_roles() {
        let schema = Arc::new(GraphQLSchema::new());
        let ctx = ResolverContext::new(
            schema,
            ExecutionContext::default()
                .with_user("user1")
                .with_role("admin"),
        );

        assert!(ctx.has_role("admin"));
        assert!(!ctx.has_role("superuser"));
        assert_eq!(ctx.user_id(), Some("user1"));
    }

    #[test]
    fn test_resolver_result_to_json() {
        let cases = vec![
            (ResolverResult::value("test"), serde_json::json!("test")),
            (ResolverResult::null(), serde_json::Value::Null),
            (ResolverResult::error("err"), serde_json::Value::Null),
        ];
        for (result, expected) in cases {
            assert_eq!(result.to_json(), expected);
        }
    }
}