Skip to main content

fraiseql_core/security/
rls_policy.rs

1//! Row-Level Security (RLS) Policy Evaluation
2//!
3//! This module provides the trait for evaluating RLS rules at runtime.
4//!
5//! RLS rules are defined in fraiseql.toml at authoring time and compiled into
6//! schema.compiled.json. At runtime, the executor evaluates these rules using
7//! the SecurityContext to determine what rows a user can access.
8//!
9//! # Architecture
10//!
11//! ```text
12//! fraiseql.toml (authoring)
13//!     ├── [[security.policies]]          # Define policies
14//!     └── [[security.rules]]             # Define RLS rules
15//!     ↓
16//! schema.compiled.json (compiled)
17//!     ├── "policies": [...]              # Serialized policies
18//!     └── "rules": [...]                 # Serialized rules
19//!     ↓
20//! Executor.execute_regular_query()       # Runtime
21//!     ├── SecurityContext (user info)
22//!     └── RLSPolicy::evaluate()          # Evaluate rules
23//!     ↓
24//! WHERE clause composition
25//!     └── WhereClause::And([user_where, rls_filter])
26//! ```
27//!
28//! # Example RLS Rules (in fraiseql.toml)
29//!
30//! ```toml
31//! # Users can only read their own posts
32//! [[security.rules]]
33//! name = "own_posts_only"
34//! rule = "user.id == object.author_id"
35//! cacheable = true
36//! cache_ttl_seconds = 300
37//!
38//! # Admins can read everything
39//! [[security.rules]]
40//! name = "admin_can_read_all"
41//! rule = "user.roles includes 'admin'"
42//! cacheable = false
43//! ```
44//!
45//! # Example RLS Policies (in fraiseql.toml)
46//!
47//! ```toml
48//! [[security.policies]]
49//! name = "read_own_posts"
50//! type = "rls"
51//! rules = ["own_posts_only"]
52//! description = "Users can only read their own posts"
53//!
54//! [[security.policies]]
55//! name = "admin_access"
56//! type = "rbac"
57//! roles = ["admin"]
58//! strategy = "any"
59//! description = "Admins have full access"
60//! ```
61
62use std::{sync::Arc, time::SystemTime};
63
64use serde::{Deserialize, Serialize};
65
66use crate::{db::WhereClause, error::Result, security::SecurityContext};
67
68/// Cache entry for RLS policy decisions with TTL support
69#[derive(Debug, Clone)]
70struct CacheEntry {
71    /// The cached RLS evaluation result
72    result:     Option<WhereClause>,
73    /// When this cache entry expires
74    expires_at: SystemTime,
75}
76
77/// Row-Level Security (RLS) policy for runtime evaluation.
78///
79/// Implementations of this trait evaluate compiled RLS rules with the user's
80/// SecurityContext to determine what rows they can access.
81///
82/// # Type Safety
83///
84/// The trait returns `Option<WhereClause>` to support composition:
85/// - `None`: No RLS filter (unrestricted access)
86/// - `Some(clause)`: Filter to apply to the query
87///
88/// The executor composes this with user-provided filters via `WhereClause::And()`.
89pub trait RLSPolicy: Send + Sync {
90    /// Evaluate RLS rules for the given type and security context.
91    ///
92    /// # Arguments
93    ///
94    /// * `context` - Security context with user information and permissions
95    /// * `type_name` - GraphQL type name being accessed (e.g., "Post", "User")
96    ///
97    /// # Returns
98    ///
99    /// - `Ok(Some(clause))`: RLS filter to apply to query
100    /// - `Ok(None)`: No RLS filter (full access)
101    /// - `Err(e)`: Policy evaluation error (access denied)
102    ///
103    /// # Example
104    ///
105    /// ```ignore
106    /// let rls = DefaultRLSPolicy::new(schema);
107    /// let context = SecurityContext { user_id: "u1", roles: vec!["user"] };
108    /// let filter = rls.evaluate(&context, "Post")?;
109    /// // filter is Some(WhereClause::Field { path: ["author_id"], operator: Eq, value: "u1" })
110    /// ```
111    fn evaluate(&self, context: &SecurityContext, type_name: &str) -> Result<Option<WhereClause>>;
112
113    /// Optional: Cache RLS decisions for performance.
114    ///
115    /// The executor may call this to cache policy decisions per user/type
116    /// combination to avoid repeated evaluations.
117    ///
118    /// # Arguments
119    ///
120    /// * `cache_key` - Cache key (typically "user_id:type_name")
121    /// * `result` - The policy evaluation result to cache
122    fn cache_result(&self, _cache_key: &str, _result: &Option<WhereClause>) {
123        // Default: no caching. Implementers can override.
124    }
125}
126
127/// Default RLS policy that enforces tenant isolation and owner-based access.
128///
129/// This is a reference implementation showing how to build RLS policies.
130///
131/// Rules:
132/// 1. Multi-tenant: Filter to rows matching user's tenant_id
133/// 2. Admin bypass: Admins can access all rows in their tenant
134/// 3. Owner-based: Regular users can only access their own rows (author_id == user_id)
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct DefaultRLSPolicy {
137    /// Enable multi-tenant isolation
138    pub enable_tenant_isolation: bool,
139    /// Field name for tenant isolation (default: "tenant_id")
140    pub tenant_field:            String,
141    /// Field name for owner-based access (default: "author_id")
142    pub owner_field:             String,
143}
144
145impl DefaultRLSPolicy {
146    /// Create a new default RLS policy.
147    pub fn new() -> Self {
148        Self {
149            enable_tenant_isolation: true,
150            tenant_field:            "tenant_id".to_string(),
151            owner_field:             "author_id".to_string(),
152        }
153    }
154
155    /// Disable tenant isolation (single-tenant mode).
156    pub fn with_single_tenant(mut self) -> Self {
157        self.enable_tenant_isolation = false;
158        self
159    }
160
161    /// Set custom tenant field name.
162    pub fn with_tenant_field(mut self, field: String) -> Self {
163        self.tenant_field = field;
164        self
165    }
166
167    /// Set custom owner field name.
168    pub fn with_owner_field(mut self, field: String) -> Self {
169        self.owner_field = field;
170        self
171    }
172}
173
174impl Default for DefaultRLSPolicy {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
180impl RLSPolicy for DefaultRLSPolicy {
181    fn evaluate(&self, context: &SecurityContext, _type_name: &str) -> Result<Option<WhereClause>> {
182        // Admins bypass RLS
183        if context.is_admin() {
184            return Ok(None);
185        }
186
187        let mut filters = vec![];
188
189        // Rule 1: Multi-tenant isolation
190        if self.enable_tenant_isolation {
191            if let Some(ref tenant_id) = context.tenant_id {
192                filters.push(WhereClause::Field {
193                    path:     vec![self.tenant_field.clone()],
194                    operator: crate::db::WhereOperator::Eq,
195                    value:    serde_json::json!(tenant_id.clone()),
196                });
197            }
198        }
199
200        // Rule 2: Owner-based access (users can only access their own rows)
201        filters.push(WhereClause::Field {
202            path:     vec![self.owner_field.clone()],
203            operator: crate::db::WhereOperator::Eq,
204            value:    serde_json::json!(context.user_id.clone()),
205        });
206
207        // Combine all filters with AND
208        match filters.len() {
209            0 => Ok(None),
210            1 => Ok(Some(filters.into_iter().next().unwrap())),
211            _ => Ok(Some(WhereClause::And(filters))),
212        }
213    }
214}
215
216/// No-op RLS policy that allows all access (for testing or fully open APIs).
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct NoRLSPolicy;
219
220impl RLSPolicy for NoRLSPolicy {
221    fn evaluate(
222        &self,
223        _context: &SecurityContext,
224        _type_name: &str,
225    ) -> Result<Option<WhereClause>> {
226        Ok(None)
227    }
228}
229
230/// Custom RLS policy that can be configured from schema.compiled.json
231///
232/// This allows schema authors to define RLS rules without writing Rust code.
233/// Supports caching of policy evaluation results for performance optimization.
234#[derive(Clone, Serialize, Deserialize)]
235pub struct CompiledRLSPolicy {
236    /// RLS rules indexed by type name
237    pub rules_by_type: std::collections::HashMap<String, Vec<RLSRule>>,
238    /// Default RLS rule if no type-specific rule exists
239    pub default_rule:  Option<RLSRule>,
240    /// Cache for policy evaluation results (not serialized)
241    #[serde(skip)]
242    cache:             Arc<parking_lot::RwLock<std::collections::HashMap<String, CacheEntry>>>,
243}
244
245impl std::fmt::Debug for CompiledRLSPolicy {
246    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247        f.debug_struct("CompiledRLSPolicy")
248            .field("rules_by_type", &self.rules_by_type)
249            .field("default_rule", &self.default_rule)
250            .field("cache", &"<cached>")
251            .finish()
252    }
253}
254
255impl CompiledRLSPolicy {
256    /// Create a new compiled RLS policy with caching enabled
257    pub fn new(
258        rules_by_type: std::collections::HashMap<String, Vec<RLSRule>>,
259        default_rule: Option<RLSRule>,
260    ) -> Self {
261        Self {
262            rules_by_type,
263            default_rule,
264            cache: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
265        }
266    }
267}
268
269/// A single RLS rule for a type
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct RLSRule {
272    /// Rule name (for debugging)
273    pub name:              String,
274    /// Expression to evaluate (e.g., "user.id == object.author_id")
275    pub expression:        String,
276    /// Whether this rule result can be cached
277    pub cacheable:         bool,
278    /// Cache TTL in seconds (if cacheable)
279    pub cache_ttl_seconds: Option<u64>,
280}
281
282impl RLSPolicy for CompiledRLSPolicy {
283    fn evaluate(&self, context: &SecurityContext, type_name: &str) -> Result<Option<WhereClause>> {
284        // Admins bypass all RLS (never cache admin access)
285        if context.is_admin() {
286            return Ok(None);
287        }
288
289        // Find rule for type or use default
290        let rule = self
291            .rules_by_type
292            .get(type_name)
293            .and_then(|rules| rules.first())
294            .or(self.default_rule.as_ref());
295
296        if let Some(rule) = rule {
297            // Check cache for cacheable rules
298            let cache_key = if rule.cacheable {
299                Some(format!("{}:{}", context.user_id, type_name))
300            } else {
301                None
302            };
303
304            // Try to retrieve from cache
305            if let Some(ref key) = cache_key {
306                let cache = self.cache.read();
307                if let Some(entry) = cache.get(key) {
308                    if SystemTime::now() < entry.expires_at {
309                        return Ok(entry.result.clone());
310                    }
311                }
312                drop(cache);
313            }
314
315            // Evaluate the RLS expression and generate WHERE clause
316            let result = evaluate_rls_expression(&rule.expression, context)?;
317
318            // Cache the result if rule is cacheable
319            if let Some(key) = cache_key {
320                if let Some(ttl_secs) = rule.cache_ttl_seconds {
321                    let expires_at = SystemTime::now() + std::time::Duration::from_secs(ttl_secs);
322                    let entry = CacheEntry {
323                        result: result.clone(),
324                        expires_at,
325                    };
326                    let mut cache = self.cache.write();
327                    cache.insert(key, entry);
328                }
329            }
330
331            Ok(result)
332        } else {
333            Ok(None)
334        }
335    }
336
337    fn cache_result(&self, cache_key: &str, result: &Option<WhereClause>) {
338        // Direct cache storage with default TTL of 300 seconds
339        let expires_at = SystemTime::now() + std::time::Duration::from_secs(300);
340        let entry = CacheEntry {
341            result: result.clone(),
342            expires_at,
343        };
344        let mut cache = self.cache.write();
345        cache.insert(cache_key.to_string(), entry);
346    }
347}
348
349/// Helper function to evaluate RLS expressions
350///
351/// Supports simple expressions like:
352/// - `user.id == object.author_id` - Equality comparison
353/// - `user.roles includes 'admin'` - Role/array membership
354/// - `user.tenant_id == object.tenant_id` - Tenant isolation
355///
356/// In production, consider using:
357/// - Rhai for dynamic expression evaluation
358/// - WASM for sandboxed custom policies
359/// - A domain-specific language (DSL)
360fn evaluate_rls_expression(
361    expression: &str,
362    context: &SecurityContext,
363) -> Result<Option<WhereClause>> {
364    let expr = expression.trim();
365
366    // Pattern 1: Simple equality - "user.id == object.field_name"
367    if let Some(eq_parts) = expr.split_once("==") {
368        let left = eq_parts.0.trim();
369        let right = eq_parts.1.trim();
370
371        // Left side: user.{field}
372        if let Some(user_field) = left.strip_prefix("user.") {
373            let user_value = extract_user_value(user_field, context);
374
375            // Right side: object.{field} or literal
376            if let Some(object_field) = right.strip_prefix("object.") {
377                // Return a field comparison filter
378                return Ok(Some(WhereClause::Field {
379                    path:     vec![object_field.to_string()],
380                    operator: crate::db::WhereOperator::Eq,
381                    value:    user_value.unwrap_or(serde_json::Value::Null),
382                }));
383            } else if serde_json::from_str::<serde_json::Value>(right).is_ok() {
384                // Literal value comparison
385                return Ok(Some(WhereClause::Field {
386                    path:     vec!["_literal_".to_string()],
387                    operator: crate::db::WhereOperator::Eq,
388                    value:    serde_json::json!(user_value),
389                }));
390            }
391        }
392    }
393
394    // Pattern 2: Membership test - "user.roles includes 'admin'"
395    if expr.contains("includes") {
396        if let Some(includes_parts) = expr.split_once("includes") {
397            let left = includes_parts.0.trim();
398            let right = includes_parts.1.trim().trim_matches(|c| c == '\'' || c == '"');
399
400            if left == "user.roles" && context.has_role(right) {
401                // User has the required role - no RLS filter needed
402                return Ok(None);
403            }
404        }
405    }
406
407    // Pattern 3: Tenant isolation - "user.tenant_id == object.tenant_id"
408    if expr.contains("tenant_id") && expr.contains("==") {
409        if let Some(tenant_id) = &context.tenant_id {
410            return Ok(Some(WhereClause::Field {
411                path:     vec!["tenant_id".to_string()],
412                operator: crate::db::WhereOperator::Eq,
413                value:    serde_json::json!(tenant_id),
414            }));
415        }
416    }
417
418    // If no pattern matched or couldn't evaluate, return None (no filter)
419    // In production, this should probably return an error for unparseable expressions
420    Ok(None)
421}
422
423/// Extract a value from user context by field name
424fn extract_user_value(field: &str, context: &SecurityContext) -> Option<serde_json::Value> {
425    match field {
426        "id" | "user_id" => Some(serde_json::json!(context.user_id)),
427        "tenant_id" => context.tenant_id.as_ref().map(|t| serde_json::json!(t)),
428        "roles" => Some(serde_json::json!(context.roles)),
429        custom => context.get_attribute(custom).cloned(),
430    }
431}
432
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Build a security context for `user123` with the given roles and tenant.
    fn context(roles: &[&str], tenant_id: Option<&str>) -> SecurityContext {
        let now = chrono::Utc::now();
        SecurityContext {
            user_id:          "user123".to_string(),
            roles:            roles.iter().map(|role| (*role).to_string()).collect(),
            tenant_id:        tenant_id.map(str::to_string),
            scopes:           vec![],
            attributes:       HashMap::new(),
            request_id:       "req1".to_string(),
            ip_address:       None,
            authenticated_at: now,
            expires_at:       now + chrono::Duration::hours(1),
            issuer:           None,
            audience:         None,
        }
    }

    /// `field == value` filter, as produced by `DefaultRLSPolicy`.
    fn eq_filter(field: &str, value: &str) -> WhereClause {
        WhereClause::Field {
            path:     vec![field.to_string()],
            operator: crate::db::WhereOperator::Eq,
            value:    serde_json::json!(value),
        }
    }

    #[test]
    fn test_default_rls_policy_admin_bypass() {
        let policy = DefaultRLSPolicy::new();
        let result = policy.evaluate(&context(&["admin"], Some("tenant1")), "Post").unwrap();
        assert_eq!(result, None, "Admins should bypass RLS");
    }

    #[test]
    fn test_default_rls_policy_tenant_isolation() {
        let policy = DefaultRLSPolicy::new();
        let result = policy.evaluate(&context(&["user"], Some("tenant1")), "Post").unwrap();
        assert_eq!(
            result,
            Some(WhereClause::And(vec![
                eq_filter("tenant_id", "tenant1"),
                eq_filter("author_id", "user123"),
            ])),
            "Non-admin users should be filtered by tenant and owner"
        );
    }

    #[test]
    fn test_default_rls_policy_single_tenant_owner_only() {
        let policy = DefaultRLSPolicy::new().with_single_tenant();
        let result = policy.evaluate(&context(&["user"], Some("tenant1")), "Post").unwrap();
        assert_eq!(
            result,
            Some(eq_filter("author_id", "user123")),
            "Single-tenant mode should only apply the owner filter"
        );
    }

    #[test]
    fn test_no_rls_policy() {
        let policy = NoRLSPolicy;
        let result = policy.evaluate(&context(&[], None), "Post").unwrap();
        assert_eq!(result, None, "NoRLSPolicy should never apply filters");
    }
}