Skip to main content

fraiseql_core/runtime/executor/
security.rs

1//! Security-aware execution — field access, RBAC filtering, JWT inject resolution,
2//! `execute_with_context()`, `execute_with_security()`, `execute_json()`.
3
4use std::time::Duration;
5
6use super::{Executor, QueryType};
7use crate::{
8    db::traits::DatabaseAdapter,
9    error::{FraiseQLError, Result},
10    runtime::{ExecutionContext, classify_field_access},
11    schema::{SessionVariableSource, SessionVariablesConfig},
12    security::{FieldAccessError, SecurityContext},
13};
14
15/// Resolve session variable mappings against the current security context.
16///
17/// Returns a list of `(name, value)` pairs to inject as PostgreSQL transaction-scoped
18/// session variables via `set_config()`.
19///
20/// Resolution rules:
21/// - [`SessionVariableSource::Jwt`] — looks up the claim in `security_context.attributes`; falls
22///   back to `user_id` for `"sub"` and to `tenant_id` for `"tenant_id"`.  Missing claims are
23///   silently skipped.
24/// - [`SessionVariableSource::Header`] — looks up the header name in `security_context.attributes`.
25///   Missing headers are silently skipped.
26/// - [`SessionVariableSource::Literal`] — uses the fixed value as-is.
27///
28/// When `config.inject_started_at` is `true`, the pair
29/// `("fraiseql.started_at", <RFC 3339 now>)` is **prepended** to the returned list.
30#[must_use]
31pub fn resolve_session_variables(
32    config: &SessionVariablesConfig,
33    security_context: &SecurityContext,
34) -> Vec<(String, String)> {
35    use chrono::Utc;
36
37    let mut vars: Vec<(String, String)> = Vec::new();
38
39    if config.inject_started_at {
40        vars.push(("fraiseql.started_at".to_string(), Utc::now().to_rfc3339()));
41    }
42
43    for mapping in &config.variables {
44        let value: Option<String> = match &mapping.source {
45            SessionVariableSource::Jwt { claim } => {
46                // Check custom attributes first (raw JWT claims forwarded there).
47                // Fall back to well-known SecurityContext fields for `sub`/`user_id`
48                // and `tenant_id` so that schemas that populate only those fields
49                // (not attributes) still work.
50                if let Some(v) = security_context.attributes.get(claim.as_str()) {
51                    Some(if let serde_json::Value::String(s) = v {
52                        s.clone()
53                    } else {
54                        v.to_string()
55                    })
56                } else if claim == "sub" || claim == "user_id" {
57                    Some(security_context.user_id.clone())
58                } else if claim == "tenant_id" {
59                    security_context.tenant_id.clone()
60                } else {
61                    None
62                }
63            },
64            SessionVariableSource::Header { header } => {
65                // HTTP headers are forwarded into attributes
66                security_context.attributes.get(header.as_str()).map(|v| {
67                    if let serde_json::Value::String(s) = v {
68                        s.clone()
69                    } else {
70                        v.to_string()
71                    }
72                })
73            },
74            SessionVariableSource::Literal { value } => Some(value.clone()),
75        };
76        if let Some(v) = value {
77            vars.push((mapping.name.clone(), v));
78        }
79    }
80
81    vars
82}
83
84impl<A: DatabaseAdapter> Executor<A> {
85    /// Validate that user has access to all requested fields.
86    pub(super) fn validate_field_access(
87        &self,
88        query: &str,
89        variables: Option<&serde_json::Value>,
90        user_scopes: &[String],
91        filter: &crate::security::FieldFilter,
92    ) -> Result<()> {
93        // Parse query to get field selections
94        let query_match = self.matcher.match_query(query, variables)?;
95
96        // Get the return type name from the query definition
97        let type_name = &query_match.query_def.return_type;
98
99        // Validate each requested field
100        let field_refs: Vec<&str> = query_match.fields.iter().map(String::as_str).collect();
101        let errors = filter.validate_fields(type_name, &field_refs, user_scopes);
102
103        if errors.is_empty() {
104            Ok(())
105        } else {
106            // Return the first error (could aggregate all errors if desired)
107            let first_error = &errors[0];
108            Err(FraiseQLError::Authorization {
109                message:  first_error.message.clone(),
110                action:   Some("read".to_string()),
111                resource: Some(format!("{}.{}", first_error.type_name, first_error.field_name)),
112            })
113        }
114    }
115
116    /// Execute a GraphQL query with cancellation support via `ExecutionContext`.
117    ///
118    /// This method allows graceful cancellation of long-running queries through a
119    /// cancellation token. If the token is cancelled during execution, the query
120    /// returns a `FraiseQLError::Cancelled` error.
121    ///
122    /// # Arguments
123    ///
124    /// * `query` - GraphQL query string
125    /// * `variables` - Query variables (optional)
126    /// * `ctx` - `ExecutionContext` with cancellation token
127    ///
128    /// # Returns
129    ///
130    /// GraphQL response as JSON string, or error if cancelled or execution fails
131    ///
132    /// # Errors
133    ///
134    /// * [`FraiseQLError::Cancelled`] — the cancellation token was triggered before or during
135    ///   execution.
136    /// * Propagates any error from the underlying [`execute`](Self::execute) call.
137    ///
138    /// # Example
139    ///
140    /// ```no_run
141    /// // Requires: a live database adapter and running tokio runtime.
142    /// // See: tests/integration/ for runnable examples.
143    /// use fraiseql_core::runtime::ExecutionContext;
144    /// use fraiseql_core::error::FraiseQLError;
145    /// use std::time::Duration;
146    ///
147    /// let ctx = ExecutionContext::new("user-query-123".to_string());
148    /// let cancel_token = ctx.cancellation_token().clone();
149    ///
150    /// // Spawn a task to cancel after 5 seconds
151    /// tokio::spawn(async move {
152    ///     tokio::time::sleep(Duration::from_secs(5)).await;
153    ///     cancel_token.cancel();
154    /// });
155    ///
156    /// // let result = executor.execute_with_context(query, None, &ctx).await;
157    /// ```
158    pub async fn execute_with_context(
159        &self,
160        query: &str,
161        variables: Option<&serde_json::Value>,
162        ctx: &ExecutionContext,
163    ) -> Result<serde_json::Value> {
164        // Check if already cancelled before starting
165        if ctx.is_cancelled() {
166            return Err(FraiseQLError::cancelled(
167                ctx.query_id().to_string(),
168                "Query cancelled before execution".to_string(),
169            ));
170        }
171
172        let token = ctx.cancellation_token().clone();
173
174        // Use tokio::select! to race between execution and cancellation
175        tokio::select! {
176            result = self.execute(query, variables) => {
177                result
178            }
179            () = token.cancelled() => {
180                Err(FraiseQLError::cancelled(
181                    ctx.query_id().to_string(),
182                    "Query cancelled during execution".to_string(),
183                ))
184            }
185        }
186    }
187
188    /// Execute a GraphQL query or mutation with a JWT [`SecurityContext`].
189    ///
190    /// This is the **main authenticated entry point** for the executor. It routes the
191    /// incoming request to the appropriate handler based on the query type:
192    ///
193    /// - **Regular queries**: RLS `WHERE` clauses are applied so each user only sees their own
194    ///   rows, as determined by the RLS policy in `RuntimeConfig`.
195    /// - **Mutations**: The security context is forwarded to `execute_mutation_query_with_security`
196    ///   so server-side `inject` parameters (e.g. `jwt:sub`) are resolved from the caller's JWT
197    ///   claims.
198    /// - **Aggregations, window queries, federation, introspection**: Delegated to their respective
199    ///   handlers (security context is not yet applied to these).
200    ///
201    /// If `query_timeout_ms` is non-zero in the `RuntimeConfig`, the entire
202    /// execution is raced against a Tokio deadline and returns
203    /// [`FraiseQLError::Timeout`] when the deadline is exceeded.
204    ///
205    /// # Arguments
206    ///
207    /// * `query` - GraphQL query string (e.g. `"query { posts { id title } }"`)
208    /// * `variables` - Optional JSON object of GraphQL variable values
209    /// * `security_context` - Authenticated user context extracted from a validated JWT
210    ///
211    /// # Returns
212    ///
213    /// A JSON-encoded GraphQL response string on success, conforming to the
214    /// [GraphQL over HTTP](https://graphql.github.io/graphql-over-http/) specification.
215    ///
216    /// # Errors
217    ///
218    /// * [`FraiseQLError::Parse`] — the query string is not valid GraphQL
219    /// * [`FraiseQLError::Validation`] — unknown mutation name, missing `sql_source`, or a mutation
220    ///   requires `inject` params but the security context is absent
221    /// * [`FraiseQLError::Database`] — the underlying adapter returns an error
222    /// * [`FraiseQLError::Timeout`] — execution exceeded `query_timeout_ms`
223    ///
224    /// # Example
225    ///
226    /// ```no_run
227    /// // Requires: a live database adapter and a SecurityContext from authentication.
228    /// // See: tests/integration/ for runnable examples.
229    /// use fraiseql_core::security::SecurityContext;
230    ///
231    /// // let query = r#"query { posts { id title } }"#;
232    /// // Returns a JSON string: {"data":{"posts":[...]}}
233    /// // let result = executor.execute_with_security(query, None, &context).await?;
234    /// ```
235    pub async fn execute_with_security(
236        &self,
237        query: &str,
238        variables: Option<&serde_json::Value>,
239        security_context: &SecurityContext,
240    ) -> Result<serde_json::Value> {
241        // Apply query timeout if configured
242        if self.config.query_timeout_ms > 0 {
243            let timeout_duration = Duration::from_millis(self.config.query_timeout_ms);
244            tokio::time::timeout(
245                timeout_duration,
246                self.execute_with_security_internal(query, variables, security_context),
247            )
248            .await
249            .map_err(|_| {
250                let query_snippet = if query.len() > 100 {
251                    format!("{}...", &query[..100])
252                } else {
253                    query.to_string()
254                };
255                FraiseQLError::Timeout {
256                    timeout_ms: self.config.query_timeout_ms,
257                    query:      Some(query_snippet),
258                }
259            })?
260        } else {
261            self.execute_with_security_internal(query, variables, security_context).await
262        }
263    }
264
265    /// Internal execution logic with security context (called by `execute_with_security` with
266    /// timeout wrapper).
267    async fn execute_with_security_internal(
268        &self,
269        query: &str,
270        variables: Option<&serde_json::Value>,
271        security_context: &SecurityContext,
272    ) -> Result<serde_json::Value> {
273        // 1. Classify query type
274        let query_type = self.classify_query(query)?;
275
276        // 2. Route to appropriate handler (with RLS support for regular queries)
277        match query_type {
278            QueryType::Regular => {
279                self.execute_regular_query_with_security(query, variables, security_context)
280                    .await
281            },
282            // Other query types don't support RLS yet (relay is handled inside
283            // execute_regular_query_with_security)
284            QueryType::Aggregate(query_name) => {
285                self.execute_aggregate_dispatch(&query_name, variables).await
286            },
287            QueryType::Window(query_name) => {
288                self.execute_window_dispatch(&query_name, variables).await
289            },
290            #[cfg(feature = "federation")]
291            QueryType::Federation(query_name) => {
292                self.execute_federation_query(&query_name, query, variables).await
293            },
294            #[cfg(not(feature = "federation"))]
295            QueryType::Federation(_) => {
296                let _ = (query, variables);
297                Err(FraiseQLError::Validation {
298                    message: "Federation is not enabled in this build".to_string(),
299                    path:    None,
300                })
301            },
302            QueryType::IntrospectionSchema => {
303                Ok(self.introspection.schema_response.as_ref().clone())
304            },
305            QueryType::IntrospectionType(type_name) => {
306                Ok(self.introspection.get_type_response(&type_name))
307            },
308            QueryType::Mutation {
309                name,
310                type_selections,
311            } => {
312                self.execute_mutation_query_with_security(
313                    &name,
314                    variables,
315                    Some(security_context),
316                    &type_selections,
317                )
318                .await
319            },
320            QueryType::NodeQuery { selections } => {
321                self.execute_node_query(query, variables, &selections).await
322            },
323        }
324    }
325
326    /// Check if a specific field can be accessed with given scopes.
327    ///
328    /// This is a convenience method for checking field access without executing a query.
329    ///
330    /// # Arguments
331    ///
332    /// * `type_name` - The GraphQL type name
333    /// * `field_name` - The field name
334    /// * `user_scopes` - User's scopes from JWT token
335    ///
336    /// # Returns
337    ///
338    /// `Ok(())` if access is allowed, `Err(FieldAccessError)` if denied
339    ///
340    /// # Errors
341    ///
342    /// Returns `FieldAccessError::AccessDenied` if the user's scopes do not include the
343    /// required scope for the field.
344    pub fn check_field_access(
345        &self,
346        type_name: &str,
347        field_name: &str,
348        user_scopes: &[String],
349    ) -> std::result::Result<(), FieldAccessError> {
350        if let Some(ref filter) = self.config.field_filter {
351            filter.can_access(type_name, field_name, user_scopes)
352        } else {
353            // No filter configured, allow all access
354            Ok(())
355        }
356    }
357
358    /// Apply field-level RBAC filtering to projection fields.
359    ///
360    /// Classifies each requested field against the user's security context:
361    /// - **Allowed**: user has the required scope (or field is public)
362    /// - **Masked**: user lacks scope, but `on_deny = Mask` → field value will be nulled
363    /// - **Rejected**: user lacks scope, `on_deny = Reject` → query fails with FORBIDDEN
364    ///
365    /// # Errors
366    ///
367    /// Returns `FraiseQLError::Forbidden` if any requested field has `on_deny = Reject`
368    /// and the user lacks the required scope.
369    pub(super) fn apply_field_rbac_filtering(
370        &self,
371        return_type: &str,
372        projection_fields: Vec<String>,
373        security_context: &SecurityContext,
374    ) -> Result<super::super::field_filter::FieldAccessResult> {
375        use super::super::field_filter::FieldAccessResult;
376
377        // Try to extract security config from compiled schema
378        if let Some(security_config) = self.schema.security.as_ref() {
379            if let Some(type_def) = self.schema.types.iter().find(|t| t.name == return_type) {
380                return classify_field_access(
381                    security_context,
382                    security_config,
383                    &type_def.fields,
384                    projection_fields,
385                )
386                .map_err(|rejected_field| FraiseQLError::Authorization {
387                    message:  format!(
388                        "Access denied: field '{rejected_field}' on type '{return_type}' \
389                         requires a scope you do not have"
390                    ),
391                    action:   Some("read".to_string()),
392                    resource: Some(format!("{return_type}.{rejected_field}")),
393                });
394            }
395        }
396
397        // No security config or type not found → all fields allowed, none masked
398        Ok(FieldAccessResult {
399            allowed: projection_fields,
400            masked:  Vec::new(),
401        })
402    }
403
404    /// Execute a query and return parsed JSON.
405    ///
406    /// This method is now equivalent to `execute()` since `execute()` already
407    /// returns `serde_json::Value`.
408    ///
409    /// # Errors
410    ///
411    /// Returns any error from `execute()`.
412    #[deprecated(
413        since = "2.2.0",
414        note = "use execute() directly — it now returns Value"
415    )]
416    pub async fn execute_json(
417        &self,
418        query: &str,
419        variables: Option<&serde_json::Value>,
420    ) -> Result<serde_json::Value> {
421        self.execute(query, variables).await
422    }
423}
424
#[cfg(test)]
mod session_variable_tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use chrono::Utc;

    use super::resolve_session_variables;
    use crate::{
        schema::{SessionVariableMapping, SessionVariableSource, SessionVariablesConfig},
        security::SecurityContext,
    };

    /// Security context shared by the tests.
    ///
    /// The `tenant_id` attribute (`"tenant-abc"`) intentionally differs from the `tenant_id`
    /// field (`"tenant-123"`) so tests can tell which source a value was resolved from.
    fn make_context() -> SecurityContext {
        let mut attributes = std::collections::HashMap::new();
        attributes.insert("tenant_id".to_string(), serde_json::json!("tenant-abc"));
        attributes.insert("x-tenant-id".to_string(), serde_json::json!("header-tenant"));
        attributes.insert("region".to_string(), serde_json::json!("eu-west-1"));
        SecurityContext {
            user_id: "user-42".to_string(),
            roles: vec!["admin".to_string()],
            tenant_id: Some("tenant-123".to_string()),
            scopes: vec![],
            attributes,
            request_id: "req-test".to_string(),
            ip_address: None,
            authenticated_at: Utc::now(),
            expires_at: Utc::now(),
            issuer: None,
            audience: None,
        }
    }

    /// Build a config from `(name, source)` mappings with `inject_started_at` disabled.
    fn config_of(mappings: Vec<(&str, SessionVariableSource)>) -> SessionVariablesConfig {
        SessionVariablesConfig {
            variables:         mappings
                .into_iter()
                .map(|(name, source)| SessionVariableMapping {
                    name: name.to_string(),
                    source,
                })
                .collect(),
            inject_started_at: false,
        }
    }

    fn jwt(claim: &str) -> SessionVariableSource {
        SessionVariableSource::Jwt {
            claim: claim.to_string(),
        }
    }

    fn header(name: &str) -> SessionVariableSource {
        SessionVariableSource::Header {
            header: name.to_string(),
        }
    }

    fn literal(value: &str) -> SessionVariableSource {
        SessionVariableSource::Literal {
            value: value.to_string(),
        }
    }

    /// Owned `(name, value)` pairs for comparison against the resolver output.
    fn pairs(items: &[(&str, &str)]) -> Vec<(String, String)> {
        items.iter().map(|(k, v)| ((*k).to_string(), (*v).to_string())).collect()
    }

    #[test]
    fn resolve_session_variables_jwt_claim() {
        let ctx = make_context();
        let config = config_of(vec![("app.tenant_id", jwt("tenant_id"))]);
        let vars = resolve_session_variables(&config, &ctx);
        // tenant_id is in attributes, which take precedence over the `tenant_id` field
        assert_eq!(vars, pairs(&[("app.tenant_id", "tenant-abc")]));
    }

    #[test]
    fn resolve_session_variables_jwt_well_known_sub() {
        let ctx = make_context();
        let config = config_of(vec![("app.user_id", jwt("sub"))]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(vars, pairs(&[("app.user_id", "user-42")]));
    }

    #[test]
    fn resolve_session_variables_jwt_sub_attribute_overrides_user_id() {
        let mut ctx = make_context();
        ctx.attributes.insert("sub".to_string(), serde_json::json!("jwt-sub"));
        let config = config_of(vec![("app.user_id", jwt("sub"))]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(vars, pairs(&[("app.user_id", "jwt-sub")]));
    }

    #[test]
    fn resolve_session_variables_jwt_user_id_alias() {
        let ctx = make_context();
        let config = config_of(vec![("app.user_id", jwt("user_id"))]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(vars, pairs(&[("app.user_id", "user-42")]));
    }

    #[test]
    fn resolve_session_variables_jwt_tenant_id_falls_back_to_field() {
        let mut ctx = make_context();
        assert!(ctx.attributes.remove("tenant_id").is_some());
        let config = config_of(vec![("app.tenant_id", jwt("tenant_id"))]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(vars, pairs(&[("app.tenant_id", "tenant-123")]));
    }

    #[test]
    fn resolve_session_variables_jwt_tenant_id_absent_everywhere_is_skipped() {
        let mut ctx = make_context();
        assert!(ctx.attributes.remove("tenant_id").is_some());
        ctx.tenant_id = None;
        let config = config_of(vec![("app.tenant_id", jwt("tenant_id"))]);
        assert!(resolve_session_variables(&config, &ctx).is_empty());
    }

    #[test]
    fn resolve_session_variables_jwt_unknown_claim_is_skipped() {
        let ctx = make_context();
        let config = config_of(vec![("app.missing", jwt("does-not-exist"))]);
        assert!(resolve_session_variables(&config, &ctx).is_empty());
    }

    #[test]
    fn resolve_session_variables_literal() {
        let ctx = make_context();
        let config = config_of(vec![("app.locale", literal("en"))]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(vars, pairs(&[("app.locale", "en")]));
    }

    #[test]
    fn inject_started_at_prepended() {
        let ctx = make_context();
        let mut config = config_of(vec![("app.locale", literal("en"))]);
        config.inject_started_at = true;
        let vars = resolve_session_variables(&config, &ctx);
        // started_at must come first
        assert_eq!(vars.len(), 2);
        assert_eq!(vars[0].0, "fraiseql.started_at");
        assert!(
            chrono::DateTime::parse_from_rfc3339(&vars[0].1).is_ok(),
            "started_at should be RFC 3339, got {:?}",
            vars[0].1
        );
        assert_eq!(vars[1], ("app.locale".to_string(), "en".to_string()));
    }

    #[test]
    fn inject_started_at_disabled() {
        let ctx = make_context();
        let config = config_of(vec![]);
        let vars = resolve_session_variables(&config, &ctx);
        assert!(vars.is_empty());
    }

    #[test]
    fn resolve_session_variables_header() {
        let ctx = make_context();
        let config = config_of(vec![("app.tenant", header("x-tenant-id"))]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(vars, pairs(&[("app.tenant", "header-tenant")]));
    }

    #[test]
    fn resolve_session_variables_header_missing_is_skipped() {
        let ctx = make_context();
        let config = config_of(vec![("app.missing", header("x-missing"))]);
        assert!(resolve_session_variables(&config, &ctx).is_empty());
    }

    #[test]
    fn resolve_session_variables_non_string_attributes_use_json_rendering() {
        let mut ctx = make_context();
        ctx.attributes.insert("level".to_string(), serde_json::json!(3));
        ctx.attributes.insert("flags".to_string(), serde_json::json!(true));
        ctx.attributes.insert("groups".to_string(), serde_json::json!(["a", "b"]));
        let config = config_of(vec![
            ("app.level", jwt("level")),
            ("app.flags", header("flags")),
            ("app.groups", jwt("groups")),
        ]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(
            vars,
            pairs(&[("app.level", "3"), ("app.flags", "true"), ("app.groups", r#"["a","b"]"#)])
        );
    }

    #[test]
    fn resolve_session_variables_preserves_mapping_order_and_skips_unresolved() {
        let ctx = make_context();
        let config = config_of(vec![
            ("app.a", literal("1")),
            ("app.missing", jwt("nope")),
            ("app.tenant", header("x-tenant-id")),
            ("app.user", jwt("sub")),
        ]);
        let vars = resolve_session_variables(&config, &ctx);
        assert_eq!(
            vars,
            pairs(&[("app.a", "1"), ("app.tenant", "header-tenant"), ("app.user", "user-42")])
        );
    }
}