fraiseql_core/security/rls_policy.rs
1//! Row-Level Security (RLS) Policy Evaluation
2//!
3//! This module provides the trait for evaluating RLS rules at runtime.
4//!
5//! RLS rules are defined in fraiseql.toml at authoring time and compiled into
6//! schema.compiled.json. At runtime, the executor evaluates these rules using
7//! the SecurityContext to determine what rows a user can access.
8//!
9//! # Architecture
10//!
11//! ```text
12//! fraiseql.toml (authoring)
13//! ├── [[security.policies]] # Define policies
14//! └── [[security.rules]] # Define RLS rules
15//! ↓
16//! schema.compiled.json (compiled)
17//! ├── "policies": [...] # Serialized policies
18//! └── "rules": [...] # Serialized rules
19//! ↓
20//! Executor.execute_regular_query() # Runtime
21//! ├── SecurityContext (user info)
22//! └── RLSPolicy::evaluate() # Evaluate rules
23//! ↓
24//! WHERE clause composition
25//! └── WhereClause::And([user_where, rls_filter])
26//! ```
27//!
28//! # Example RLS Rules (in fraiseql.toml)
29//!
30//! ```toml
31//! # Users can only read their own posts
32//! [[security.rules]]
33//! name = "own_posts_only"
34//! rule = "user.id == object.author_id"
35//! cacheable = true
36//! cache_ttl_seconds = 300
37//!
38//! # Admins can read everything
39//! [[security.rules]]
40//! name = "admin_can_read_all"
41//! rule = "user.roles includes 'admin'"
42//! cacheable = false
43//! ```
44//!
45//! # Example RLS Policies (in fraiseql.toml)
46//!
47//! ```toml
48//! [[security.policies]]
49//! name = "read_own_posts"
50//! type = "rls"
51//! rules = ["own_posts_only"]
52//! description = "Users can only read their own posts"
53//!
54//! [[security.policies]]
55//! name = "admin_access"
56//! type = "rbac"
57//! roles = ["admin"]
58//! strategy = "any"
59//! description = "Admins have full access"
60//! ```
61
62use std::{sync::Arc, time::SystemTime};
63
64use serde::{Deserialize, Serialize};
65
66use crate::{db::WhereClause, error::Result, security::SecurityContext};
67
/// Cache entry for RLS policy decisions with TTL support.
///
/// Stored in `CompiledRLSPolicy`'s cache. `CompiledRLSPolicy::evaluate` only
/// serves an entry while `SystemTime::now() < expires_at`.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached RLS evaluation result (`None` means "no row filter").
    result: Option<WhereClause>,
    /// Wall-clock time after which this entry is considered stale.
    ///
    /// NOTE(review): `SystemTime` is not monotonic; a backwards clock jump
    /// lengthens the effective TTL. `std::time::Instant` would avoid that.
    expires_at: SystemTime,
}
76
77/// Row-Level Security (RLS) policy for runtime evaluation.
78///
79/// Implementations of this trait evaluate compiled RLS rules with the user's
80/// SecurityContext to determine what rows they can access.
81///
82/// # Type Safety
83///
84/// The trait returns `Option<WhereClause>` to support composition:
85/// - `None`: No RLS filter (unrestricted access)
86/// - `Some(clause)`: Filter to apply to the query
87///
88/// The executor composes this with user-provided filters via `WhereClause::And()`.
89pub trait RLSPolicy: Send + Sync {
90 /// Evaluate RLS rules for the given type and security context.
91 ///
92 /// # Arguments
93 ///
94 /// * `context` - Security context with user information and permissions
95 /// * `type_name` - GraphQL type name being accessed (e.g., "Post", "User")
96 ///
97 /// # Returns
98 ///
99 /// - `Ok(Some(clause))`: RLS filter to apply to query
100 /// - `Ok(None)`: No RLS filter (full access)
101 /// - `Err(e)`: Policy evaluation error (access denied)
102 ///
103 /// # Example
104 ///
105 /// ```ignore
106 /// let rls = DefaultRLSPolicy::new(schema);
107 /// let context = SecurityContext { user_id: "u1", roles: vec!["user"] };
108 /// let filter = rls.evaluate(&context, "Post")?;
109 /// // filter is Some(WhereClause::Field { path: ["author_id"], operator: Eq, value: "u1" })
110 /// ```
111 fn evaluate(&self, context: &SecurityContext, type_name: &str) -> Result<Option<WhereClause>>;
112
113 /// Optional: Cache RLS decisions for performance.
114 ///
115 /// The executor may call this to cache policy decisions per user/type
116 /// combination to avoid repeated evaluations.
117 ///
118 /// # Arguments
119 ///
120 /// * `cache_key` - Cache key (typically "user_id:type_name")
121 /// * `result` - The policy evaluation result to cache
122 fn cache_result(&self, _cache_key: &str, _result: &Option<WhereClause>) {
123 // Default: no caching. Implementers can override.
124 }
125}
126
127/// Default RLS policy that enforces tenant isolation and owner-based access.
128///
129/// This is a reference implementation showing how to build RLS policies.
130///
131/// Rules:
132/// 1. Multi-tenant: Filter to rows matching user's tenant_id
133/// 2. Admin bypass: Admins can access all rows in their tenant
134/// 3. Owner-based: Regular users can only access their own rows (author_id == user_id)
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct DefaultRLSPolicy {
137 /// Enable multi-tenant isolation
138 pub enable_tenant_isolation: bool,
139 /// Field name for tenant isolation (default: "tenant_id")
140 pub tenant_field: String,
141 /// Field name for owner-based access (default: "author_id")
142 pub owner_field: String,
143}
144
145impl DefaultRLSPolicy {
146 /// Create a new default RLS policy.
147 pub fn new() -> Self {
148 Self {
149 enable_tenant_isolation: true,
150 tenant_field: "tenant_id".to_string(),
151 owner_field: "author_id".to_string(),
152 }
153 }
154
155 /// Disable tenant isolation (single-tenant mode).
156 pub fn with_single_tenant(mut self) -> Self {
157 self.enable_tenant_isolation = false;
158 self
159 }
160
161 /// Set custom tenant field name.
162 pub fn with_tenant_field(mut self, field: String) -> Self {
163 self.tenant_field = field;
164 self
165 }
166
167 /// Set custom owner field name.
168 pub fn with_owner_field(mut self, field: String) -> Self {
169 self.owner_field = field;
170 self
171 }
172}
173
174impl Default for DefaultRLSPolicy {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180impl RLSPolicy for DefaultRLSPolicy {
181 fn evaluate(&self, context: &SecurityContext, _type_name: &str) -> Result<Option<WhereClause>> {
182 // Admins bypass RLS
183 if context.is_admin() {
184 return Ok(None);
185 }
186
187 let mut filters = vec![];
188
189 // Rule 1: Multi-tenant isolation
190 if self.enable_tenant_isolation {
191 if let Some(ref tenant_id) = context.tenant_id {
192 filters.push(WhereClause::Field {
193 path: vec![self.tenant_field.clone()],
194 operator: crate::db::WhereOperator::Eq,
195 value: serde_json::json!(tenant_id.clone()),
196 });
197 }
198 }
199
200 // Rule 2: Owner-based access (users can only access their own rows)
201 filters.push(WhereClause::Field {
202 path: vec![self.owner_field.clone()],
203 operator: crate::db::WhereOperator::Eq,
204 value: serde_json::json!(context.user_id.clone()),
205 });
206
207 // Combine all filters with AND
208 match filters.len() {
209 0 => Ok(None),
210 1 => Ok(Some(filters.into_iter().next().unwrap())),
211 _ => Ok(Some(WhereClause::And(filters))),
212 }
213 }
214}
215
216/// No-op RLS policy that allows all access (for testing or fully open APIs).
217#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct NoRLSPolicy;
219
220impl RLSPolicy for NoRLSPolicy {
221 fn evaluate(
222 &self,
223 _context: &SecurityContext,
224 _type_name: &str,
225 ) -> Result<Option<WhereClause>> {
226 Ok(None)
227 }
228}
229
/// Custom RLS policy that can be configured from schema.compiled.json
///
/// This allows schema authors to define RLS rules without writing Rust code.
/// Supports caching of policy evaluation results for performance optimization.
///
/// Caching notes:
/// - The cache sits behind an `Arc`, so clones of a policy (via `Clone`) share
///   one cache.
/// - The cache is not serialized. A deserialized policy starts with an empty
///   cache, because `#[serde(skip)]` falls back to `Default`.
/// - Nothing in this file removes entries; they are only overwritten. Expired
///   entries for keys that are never requested again stay in memory.
///   NOTE(review): consider pruning if the number of distinct users is large.
#[derive(Clone, Serialize, Deserialize)]
pub struct CompiledRLSPolicy {
    /// RLS rules indexed by type name.
    ///
    /// Only the first rule in each list is evaluated.
    pub rules_by_type: std::collections::HashMap<String, Vec<RLSRule>>,
    /// Default RLS rule used when no type-specific rule exists (the type is
    /// missing from `rules_by_type` or its list is empty).
    pub default_rule: Option<RLSRule>,
    /// Cache for policy evaluation results (not serialized).
    ///
    /// `evaluate` uses keys of the form `"{user_id}:{type_name}"`.
    #[serde(skip)]
    cache: Arc<parking_lot::RwLock<std::collections::HashMap<String, CacheEntry>>>,
}
244
245impl std::fmt::Debug for CompiledRLSPolicy {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 f.debug_struct("CompiledRLSPolicy")
248 .field("rules_by_type", &self.rules_by_type)
249 .field("default_rule", &self.default_rule)
250 .field("cache", &"<cached>")
251 .finish()
252 }
253}
254
255impl CompiledRLSPolicy {
256 /// Create a new compiled RLS policy with caching enabled
257 pub fn new(
258 rules_by_type: std::collections::HashMap<String, Vec<RLSRule>>,
259 default_rule: Option<RLSRule>,
260 ) -> Self {
261 Self {
262 rules_by_type,
263 default_rule,
264 cache: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
265 }
266 }
267}
268
/// A single RLS rule for a type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLSRule {
    /// Rule name (for debugging).
    pub name: String,
    /// Expression to evaluate (e.g., "user.id == object.author_id").
    ///
    /// Interpreted by this module's simple expression evaluator, which only
    /// recognizes a few shapes (`user.X == object.Y`, `user.roles includes 'R'`,
    /// and tenant equality). Unrecognized expressions produce no filter.
    pub expression: String,
    /// Whether `CompiledRLSPolicy` may consult its cache for this rule.
    pub cacheable: bool,
    /// Cache TTL in seconds (if cacheable).
    ///
    /// `CompiledRLSPolicy::evaluate` only stores freshly computed results when
    /// this is `Some`. With `cacheable = true` and no TTL, `evaluate` reads the
    /// cache but never writes it.
    pub cache_ttl_seconds: Option<u64>,
}
281
282impl RLSPolicy for CompiledRLSPolicy {
283 fn evaluate(&self, context: &SecurityContext, type_name: &str) -> Result<Option<WhereClause>> {
284 // Admins bypass all RLS (never cache admin access)
285 if context.is_admin() {
286 return Ok(None);
287 }
288
289 // Find rule for type or use default
290 let rule = self
291 .rules_by_type
292 .get(type_name)
293 .and_then(|rules| rules.first())
294 .or(self.default_rule.as_ref());
295
296 if let Some(rule) = rule {
297 // Check cache for cacheable rules
298 let cache_key = if rule.cacheable {
299 Some(format!("{}:{}", context.user_id, type_name))
300 } else {
301 None
302 };
303
304 // Try to retrieve from cache
305 if let Some(ref key) = cache_key {
306 let cache = self.cache.read();
307 if let Some(entry) = cache.get(key) {
308 if SystemTime::now() < entry.expires_at {
309 return Ok(entry.result.clone());
310 }
311 }
312 drop(cache);
313 }
314
315 // Evaluate the RLS expression and generate WHERE clause
316 let result = evaluate_rls_expression(&rule.expression, context)?;
317
318 // Cache the result if rule is cacheable
319 if let Some(key) = cache_key {
320 if let Some(ttl_secs) = rule.cache_ttl_seconds {
321 let expires_at = SystemTime::now() + std::time::Duration::from_secs(ttl_secs);
322 let entry = CacheEntry {
323 result: result.clone(),
324 expires_at,
325 };
326 let mut cache = self.cache.write();
327 cache.insert(key, entry);
328 }
329 }
330
331 Ok(result)
332 } else {
333 Ok(None)
334 }
335 }
336
337 fn cache_result(&self, cache_key: &str, result: &Option<WhereClause>) {
338 // Direct cache storage with default TTL of 300 seconds
339 let expires_at = SystemTime::now() + std::time::Duration::from_secs(300);
340 let entry = CacheEntry {
341 result: result.clone(),
342 expires_at,
343 };
344 let mut cache = self.cache.write();
345 cache.insert(cache_key.to_string(), entry);
346 }
347}
348
/// Helper function to evaluate RLS expressions.
///
/// Supports simple expressions like:
/// - `user.id == object.author_id` - Equality comparison
/// - `user.roles includes 'admin'` - Role/array membership
/// - `user.tenant_id == object.tenant_id` - Tenant isolation
///
/// Patterns are tried in order: equality, then membership, then tenant. The
/// first one that returns wins. Anything else falls through to the final
/// `Ok(None)`.
///
/// # Returns
///
/// - `Ok(Some(clause))`: row filter to apply
/// - `Ok(None)`: no row filter, i.e. unrestricted access
///
/// This function currently never returns `Err`.
///
/// # Security
///
/// NOTE(review): several paths fail *open* (return `Ok(None)`, meaning full
/// access) instead of denying:
/// - a membership rule such as `user.roles includes 'admin'` when the user does
///   NOT have the role;
/// - a tenant rule that reaches Pattern 3 when the context has no `tenant_id`;
/// - any expression that matches none of the supported patterns.
///
/// Returning an error (which `RLSPolicy::evaluate` documents as "access
/// denied") or a deny-all clause is probably what is intended. TODO confirm.
///
/// In production, consider using:
/// - Rhai for dynamic expression evaluation
/// - WASM for sandboxed custom policies
/// - A domain-specific language (DSL)
fn evaluate_rls_expression(
    expression: &str,
    context: &SecurityContext,
) -> Result<Option<WhereClause>> {
    let expr = expression.trim();

    // Pattern 1: Simple equality - "user.id == object.field_name"
    // Splits on the first "==" only.
    if let Some(eq_parts) = expr.split_once("==") {
        let left = eq_parts.0.trim();
        let right = eq_parts.1.trim();

        // Left side: user.{field}
        if let Some(user_field) = left.strip_prefix("user.") {
            // `None` when the context lacks this field (e.g. no tenant_id, or an
            // unknown custom attribute).
            let user_value = extract_user_value(user_field, context);

            // Right side: object.{field} or literal
            if let Some(object_field) = right.strip_prefix("object.") {
                // Return a field comparison filter.
                // NOTE(review): a missing user value becomes `Null` here. Whether
                // `Eq` against `Null` matches no rows or matches rows with a NULL
                // column depends on how WhereClause is rendered to SQL. TODO confirm.
                return Ok(Some(WhereClause::Field {
                    path: vec![object_field.to_string()],
                    operator: crate::db::WhereOperator::Eq,
                    value: user_value.unwrap_or(serde_json::Value::Null),
                }));
            } else if serde_json::from_str::<serde_json::Value>(right).is_ok() {
                // Literal value comparison.
                // NOTE(review): the parsed literal is discarded. The filter targets a
                // column literally named `_literal_` with the *user* value (a missing
                // value serializes as null), so this branch does not implement
                // `user.x == <literal>`. Single-quoted literals such as `'u1'` are not
                // valid JSON and never reach this branch.
                return Ok(Some(WhereClause::Field {
                    path: vec!["_literal_".to_string()],
                    operator: crate::db::WhereOperator::Eq,
                    value: serde_json::json!(user_value),
                }));
            }
            // Otherwise fall through to the remaining patterns.
        }
    }

    // Pattern 2: Membership test - "user.roles includes 'admin'"
    // Plain substring match on "includes"; quotes around the role are stripped.
    if expr.contains("includes") {
        if let Some(includes_parts) = expr.split_once("includes") {
            let left = includes_parts.0.trim();
            let right = includes_parts.1.trim().trim_matches(|c| c == '\'' || c == '"');

            if left == "user.roles" && context.has_role(right) {
                // User has the required role - no RLS filter needed
                return Ok(None);
            }
            // NOTE(review): a user WITHOUT the role falls through. Because such
            // expressions normally contain no "==", this ends at the final
            // `Ok(None)`: the same unrestricted result as a role holder (fail-open).
        }
    }

    // Pattern 3: Tenant isolation - "user.tenant_id == object.tenant_id"
    // Only reached for "==" expressions that Pattern 1 did not return on, e.g. the
    // reversed `object.tenant_id == user.tenant_id`. The match is by substring and
    // always filters the hard-coded `tenant_id` column.
    if expr.contains("tenant_id") && expr.contains("==") {
        if let Some(tenant_id) = &context.tenant_id {
            return Ok(Some(WhereClause::Field {
                path: vec!["tenant_id".to_string()],
                operator: crate::db::WhereOperator::Eq,
                value: serde_json::json!(tenant_id),
            }));
        }
        // NOTE(review): with no tenant_id in the context this falls through to
        // `Ok(None)`, i.e. no tenant filter at all (fail-open).
    }

    // If no pattern matched or couldn't evaluate, return None (no filter)
    // In production, this should probably return an error for unparseable expressions
    Ok(None)
}
422
423/// Extract a value from user context by field name
424fn extract_user_value(field: &str, context: &SecurityContext) -> Option<serde_json::Value> {
425 match field {
426 "id" | "user_id" => Some(serde_json::json!(context.user_id)),
427 "tenant_id" => context.tenant_id.as_ref().map(|t| serde_json::json!(t)),
428 "roles" => Some(serde_json::json!(context.roles)),
429 custom => context.get_attribute(custom).cloned(),
430 }
431}
432
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::db::WhereOperator;

    /// Build a `SecurityContext` for user `"user123"` with the given roles and
    /// tenant; every other field gets a fixed, test-irrelevant value.
    fn context(roles: &[&str], tenant_id: Option<&str>) -> SecurityContext {
        SecurityContext {
            user_id: "user123".to_string(),
            roles: roles.iter().map(|&role| role.to_string()).collect(),
            tenant_id: tenant_id.map(str::to_string),
            scopes: vec![],
            attributes: HashMap::new(),
            request_id: "req1".to_string(),
            ip_address: None,
            authenticated_at: chrono::Utc::now(),
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
            issuer: None,
            audience: None,
        }
    }

    /// Expected `field == value` equality clause.
    fn field_eq(field: &str, value: &str) -> WhereClause {
        WhereClause::Field {
            path: vec![field.to_string()],
            operator: WhereOperator::Eq,
            value: serde_json::json!(value),
        }
    }

    /// A cacheable owner rule: `user.id == object.author_id`.
    fn owner_rule() -> RLSRule {
        RLSRule {
            name: "own_posts_only".to_string(),
            expression: "user.id == object.author_id".to_string(),
            cacheable: true,
            cache_ttl_seconds: Some(300),
        }
    }

    #[test]
    fn test_default_rls_policy_admin_bypass() {
        let policy = DefaultRLSPolicy::new();
        let result = policy.evaluate(&context(&["admin"], Some("tenant1")), "Post").unwrap();
        assert_eq!(result, None, "Admins should bypass RLS");
    }

    #[test]
    fn test_default_rls_policy_tenant_isolation() {
        let policy = DefaultRLSPolicy::new();
        let result = policy.evaluate(&context(&["user"], Some("tenant1")), "Post").unwrap();
        assert_eq!(
            result,
            Some(WhereClause::And(vec![
                field_eq("tenant_id", "tenant1"),
                field_eq("author_id", "user123"),
            ])),
            "Non-admin users should be filtered by tenant and owner"
        );
    }

    #[test]
    fn test_default_rls_policy_single_tenant_owner_only() {
        let policy = DefaultRLSPolicy::new().with_single_tenant();
        let result = policy.evaluate(&context(&["user"], Some("tenant1")), "Post").unwrap();
        assert_eq!(result, Some(field_eq("author_id", "user123")));
    }

    #[test]
    fn test_default_rls_policy_custom_fields() {
        let policy = DefaultRLSPolicy::new()
            .with_tenant_field("org_id".to_string())
            .with_owner_field("owner_id".to_string());
        let result = policy.evaluate(&context(&["user"], Some("org1")), "Post").unwrap();
        assert_eq!(
            result,
            Some(WhereClause::And(vec![
                field_eq("org_id", "org1"),
                field_eq("owner_id", "user123"),
            ]))
        );
    }

    #[test]
    fn test_no_rls_policy() {
        let result = NoRLSPolicy.evaluate(&context(&[], None), "Post").unwrap();
        assert_eq!(result, None, "NoRLSPolicy should never apply filters");
    }

    #[test]
    fn test_compiled_policy_type_rule_and_cache() {
        let rules = HashMap::from([("Post".to_string(), vec![owner_rule()])]);
        let policy = CompiledRLSPolicy::new(rules, None);
        let ctx = context(&["user"], None);
        let expected = Some(field_eq("author_id", "user123"));

        // First call evaluates and caches; the second is served from the cache.
        assert_eq!(policy.evaluate(&ctx, "Post").unwrap(), expected);
        assert_eq!(policy.evaluate(&ctx, "Post").unwrap(), expected);

        // No rule for the type and no default rule: no filter.
        assert_eq!(policy.evaluate(&ctx, "Comment").unwrap(), None);
    }

    #[test]
    fn test_compiled_policy_default_rule_fallback() {
        let policy = CompiledRLSPolicy::new(HashMap::new(), Some(owner_rule()));
        let result = policy.evaluate(&context(&["user"], None), "Comment").unwrap();
        assert_eq!(result, Some(field_eq("author_id", "user123")));
    }
}