1use std::sync::Arc;
63
64use serde::{Deserialize, Serialize};
65
66use crate::{
67 db::WhereClause,
68 error::{FraiseQLError, Result},
69 security::SecurityContext,
70 utils::clock::{Clock, SystemClock},
71};
72
/// A [`WhereClause`] that was produced by evaluating an RLS policy.
///
/// The field is private and the constructor is `pub(crate)`, so code outside this crate can
/// read or unwrap the clause but cannot wrap an arbitrary clause in this type.
#[derive(Debug, Clone, PartialEq)]
pub struct RlsWhereClause {
    /// The underlying row filter.
    inner: WhereClause,
}
103
104impl RlsWhereClause {
105 pub(crate) const fn new(inner: WhereClause) -> Self {
111 Self { inner }
112 }
113
114 pub const fn as_where_clause(&self) -> &WhereClause {
116 &self.inner
117 }
118
119 pub fn into_where_clause(self) -> WhereClause {
121 self.inner
122 }
123}
124
/// One cached evaluation result stored by `CompiledRLSPolicy`.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Filter produced by the rule; `None` means "no filter".
    result: Option<WhereClause>,
    /// Clock time (`Clock::now_secs`) from which the entry is stale; lookups only accept the
    /// entry while `now_secs() < expires_at`.
    expires_at: u64,
}
133
/// A row-level security (RLS) policy: maps a caller's [`SecurityContext`] and a type name to
/// an optional row filter.
///
/// Implementors must be `Send + Sync`.
pub trait RLSPolicy: Send + Sync {
    /// Computes the row filter to apply when `context` accesses rows of `type_name`.
    ///
    /// `Ok(None)` means no RLS filter is applied. In this module, admins
    /// (`context.is_admin()`) receive `Ok(None)` from `DefaultRLSPolicy` and
    /// `CompiledRLSPolicy`, and `NoRLSPolicy` returns it unconditionally.
    ///
    /// # Errors
    ///
    /// Implementation-defined; `CompiledRLSPolicy` returns a validation error when a rule
    /// expression cannot be evaluated.
    fn evaluate(
        &self,
        context: &SecurityContext,
        type_name: &str,
    ) -> Result<Option<RlsWhereClause>>;

    /// Stores `result` under `cache_key` for reuse by later evaluations.
    ///
    /// The default implementation does nothing, so policies without a cache can ignore it.
    fn cache_result(&self, _cache_key: &str, _result: &Option<WhereClause>) {}
}
194
/// Built-in policy: non-admins only see rows they own and, optionally, rows of their tenant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DefaultRLSPolicy {
    /// When `true`, non-admins are additionally filtered by `tenant_field` = their tenant
    /// (only if the security context carries a tenant id).
    pub enable_tenant_isolation: bool,
    /// Column compared with the caller's tenant id (`"tenant_id"` from `new()`).
    pub tenant_field: String,
    /// Column compared with the caller's user id (`"author_id"` from `new()`).
    pub owner_field: String,
}
212
213impl DefaultRLSPolicy {
214 pub fn new() -> Self {
216 Self {
217 enable_tenant_isolation: true,
218 tenant_field: "tenant_id".to_string(),
219 owner_field: "author_id".to_string(),
220 }
221 }
222
223 pub const fn with_single_tenant(mut self) -> Self {
225 self.enable_tenant_isolation = false;
226 self
227 }
228
229 pub fn with_tenant_field(mut self, field: String) -> Self {
231 self.tenant_field = field;
232 self
233 }
234
235 pub fn with_owner_field(mut self, field: String) -> Self {
237 self.owner_field = field;
238 self
239 }
240}
241
242impl Default for DefaultRLSPolicy {
243 fn default() -> Self {
244 Self::new()
245 }
246}
247
248impl RLSPolicy for DefaultRLSPolicy {
249 fn evaluate(
250 &self,
251 context: &SecurityContext,
252 _type_name: &str,
253 ) -> Result<Option<RlsWhereClause>> {
254 if context.is_admin() {
256 return Ok(None);
257 }
258
259 let mut filters = vec![];
260
261 if self.enable_tenant_isolation {
263 if let Some(ref tenant_id) = context.tenant_id {
264 filters.push(WhereClause::Field {
265 path: vec![self.tenant_field.clone()],
266 operator: crate::db::WhereOperator::Eq,
267 value: serde_json::json!(tenant_id.clone()),
268 });
269 }
270 }
271
272 filters.push(WhereClause::Field {
274 path: vec![self.owner_field.clone()],
275 operator: crate::db::WhereOperator::Eq,
276 value: serde_json::json!(context.user_id.clone()),
277 });
278
279 let clause = match filters.len() {
281 0 => return Ok(None),
282 1 => filters.into_iter().next().expect("len checked == 1"),
284 _ => WhereClause::And(filters),
285 };
286 Ok(Some(RlsWhereClause::new(clause)))
287 }
288}
289
/// Policy that never applies a row filter: every evaluation returns `Ok(None)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NoRLSPolicy;
293
294impl RLSPolicy for NoRLSPolicy {
295 fn evaluate(
296 &self,
297 _context: &SecurityContext,
298 _type_name: &str,
299 ) -> Result<Option<RlsWhereClause>> {
300 Ok(None)
301 }
302}
303
304fn default_system_clock() -> Arc<dyn Clock> {
307 Arc::new(SystemClock)
308}
309
/// Policy driven by per-type [`RLSRule`] expressions, with an in-memory result cache.
///
/// `Debug` is hand-written; the cache and clock are printed as placeholders.
#[derive(Clone, Serialize, Deserialize)]
pub struct CompiledRLSPolicy {
    /// Rules keyed by type name; evaluation only uses the first rule in each list.
    pub rules_by_type: std::collections::HashMap<String, Vec<RLSRule>>,
    /// Rule used when a type has no entry (or an empty list) in `rules_by_type`.
    pub default_rule: Option<RLSRule>,
    /// Cached results keyed by `"{user_id}:{type_name}"`. Not serialized, so a deserialized
    /// policy starts empty; the derived `Clone` shares it because it is behind an `Arc`.
    #[serde(skip)]
    cache: Arc<parking_lot::RwLock<std::collections::HashMap<String, CacheEntry>>>,
    /// Time source for cache expiry; deserialization installs the system clock.
    #[serde(skip, default = "default_system_clock")]
    clock: Arc<dyn Clock>,
}
327
328impl std::fmt::Debug for CompiledRLSPolicy {
329 #[cfg_attr(test, mutants::skip)]
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333 f.debug_struct("CompiledRLSPolicy")
334 .field("rules_by_type", &self.rules_by_type)
335 .field("default_rule", &self.default_rule)
336 .field("cache", &"<cached>")
337 .field("clock", &"<clock>")
338 .finish()
339 }
340}
341
342impl CompiledRLSPolicy {
343 pub fn new(
345 rules_by_type: std::collections::HashMap<String, Vec<RLSRule>>,
346 default_rule: Option<RLSRule>,
347 ) -> Self {
348 Self::new_with_clock(rules_by_type, default_rule, Arc::new(SystemClock))
349 }
350
351 pub fn new_with_clock(
353 rules_by_type: std::collections::HashMap<String, Vec<RLSRule>>,
354 default_rule: Option<RLSRule>,
355 clock: Arc<dyn Clock>,
356 ) -> Self {
357 Self {
358 rules_by_type,
359 default_rule,
360 cache: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
361 clock,
362 }
363 }
364}
365
/// A single row-level security rule evaluated by `CompiledRLSPolicy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RLSRule {
    /// Rule identifier (not consulted during evaluation in this module).
    pub name: String,
    /// Expression parsed by `evaluate_rls_expression`, e.g. `user.id == object.author_id`.
    pub expression: String,
    /// Whether `evaluate` may read results of this rule from, and store them in, the cache.
    pub cacheable: bool,
    /// Lifetime of cached results; when `None`, `evaluate` never writes this rule's results.
    pub cache_ttl_seconds: Option<u64>,
}
378
379impl RLSPolicy for CompiledRLSPolicy {
380 fn evaluate(
381 &self,
382 context: &SecurityContext,
383 type_name: &str,
384 ) -> Result<Option<RlsWhereClause>> {
385 if context.is_admin() {
387 return Ok(None);
388 }
389
390 let rule = self
392 .rules_by_type
393 .get(type_name)
394 .and_then(|rules| rules.first())
395 .or(self.default_rule.as_ref());
396
397 if let Some(rule) = rule {
398 let cache_key = if rule.cacheable {
400 Some(format!("{}:{}", context.user_id, type_name))
401 } else {
402 None
403 };
404
405 if let Some(ref key) = cache_key {
407 let cache = self.cache.read();
408 if let Some(entry) = cache.get(key) {
409 if self.clock.now_secs() < entry.expires_at {
410 return Ok(entry.result.clone().map(RlsWhereClause::new));
412 }
413 }
414 drop(cache);
415 }
416
417 let result: Option<WhereClause> = evaluate_rls_expression(&rule.expression, context)?;
419
420 if let Some(key) = cache_key {
422 if let Some(ttl_secs) = rule.cache_ttl_seconds {
423 let expires_at = self.clock.now_secs() + ttl_secs;
424 let entry = CacheEntry {
425 result: result.clone(),
426 expires_at,
427 };
428 let mut cache = self.cache.write();
429 cache.insert(key, entry);
430 }
431 }
432
433 Ok(result.map(RlsWhereClause::new))
434 } else {
435 Ok(None)
436 }
437 }
438
439 fn cache_result(&self, cache_key: &str, result: &Option<WhereClause>) {
440 let expires_at = self.clock.now_secs() + 300;
442 let entry = CacheEntry {
443 result: result.clone(),
444 expires_at,
445 };
446 let mut cache = self.cache.write();
447 cache.insert(cache_key.to_string(), entry);
448 }
449}
450
/// Translates a single rule expression into an optional row filter for `context`.
///
/// Only a few expression shapes are understood. They are tried in this order, and the first
/// one that produces a result wins:
///
/// 1. `user.<field> == object.<column>`: filter `<column> = <user value>`, where the user
///    value comes from [`extract_user_value`] (JSON `null` when it is absent).
/// 2. `user.<field> == <JSON literal>`: filter on the placeholder path `_literal_` whose value
///    is the user value (JSON `null` when absent). The literal is only checked for JSON
///    validity; it is not compared.
/// 3. `user.roles includes '<role>'`: `Ok(None)` (no filter) when the caller has the role.
/// 4. Any expression containing both `tenant_id` and `==` that was not handled above:
///    filter `tenant_id = <tenant>` when the context has a tenant id.
///
/// # Errors
///
/// Returns [`FraiseQLError::Validation`] ("Unrecognised RLS expression") when no shape
/// produced a result. This includes an `includes` rule whose role the caller lacks, and a
/// shape-4 rule evaluated for a context without a tenant id.
fn evaluate_rls_expression(
    expression: &str,
    context: &SecurityContext,
) -> Result<Option<WhereClause>> {
    let expr = expression.trim();

    // Shapes 1 and 2: equality with a `user.` attribute on the left-hand side.
    if let Some(eq_parts) = expr.split_once("==") {
        let left = eq_parts.0.trim();
        let right = eq_parts.1.trim();

        if let Some(user_field) = left.strip_prefix("user.") {
            let user_value = extract_user_value(user_field, context);

            if let Some(object_field) = right.strip_prefix("object.") {
                return Ok(Some(WhereClause::Field {
                    path: vec![object_field.to_string()],
                    operator: crate::db::WhereOperator::Eq,
                    value: user_value.unwrap_or(serde_json::Value::Null),
                }));
            } else if serde_json::from_str::<serde_json::Value>(right).is_ok() {
                // NOTE(review): this neither compares `user_value` with the literal nor uses
                // the literal's value; it emits a filter on a `_literal_` placeholder column.
                // Confirm the query layer gives `_literal_` a special meaning.
                return Ok(Some(WhereClause::Field {
                    path: vec!["_literal_".to_string()],
                    operator: crate::db::WhereOperator::Eq,
                    value: serde_json::json!(user_value),
                }));
            }
            // Any other right-hand side falls through to the checks below.
        }
    }

    // Shape 3: role membership. A caller without the role falls through (and, unless the
    // tenant check below matches, ends up with the "Unrecognised" error).
    if expr.contains("includes") {
        if let Some(includes_parts) = expr.split_once("includes") {
            let left = includes_parts.0.trim();
            let right = includes_parts.1.trim().trim_matches(|c| c == '\'' || c == '"');

            if left == "user.roles" && context.has_role(right) {
                return Ok(None);
            }
        }
    }

    // Shape 4: fallback tenant filter on the hard-coded `tenant_id` column.
    if expr.contains("tenant_id") && expr.contains("==") {
        if let Some(tenant_id) = &context.tenant_id {
            return Ok(Some(WhereClause::Field {
                path: vec!["tenant_id".to_string()],
                operator: crate::db::WhereOperator::Eq,
                value: serde_json::json!(tenant_id),
            }));
        }
    }

    Err(FraiseQLError::Validation {
        message: format!("Unrecognised RLS expression: '{expr}'"),
        path: None,
    })
}
526
527fn extract_user_value(field: &str, context: &SecurityContext) -> Option<serde_json::Value> {
529 match field {
530 "id" | "user_id" => Some(serde_json::json!(context.user_id)),
531 "tenant_id" => context.tenant_id.as_ref().map(|t| serde_json::json!(t)),
532 "roles" => Some(serde_json::json!(context.roles)),
533 custom => context.get_attribute(custom).cloned(),
534 }
535}
536
#[cfg(test)]
mod tests {
    // Panicking on an unexpected `Err`/`None` is the desired failure mode in tests.
    #![allow(clippy::unwrap_used)]

    use std::collections::HashMap;

    use super::*;

    /// Builds a minimal security context that stays valid for an hour.
    fn make_context(user_id: &str, roles: Vec<&str>, tenant_id: Option<&str>) -> SecurityContext {
        SecurityContext {
            user_id: user_id.to_string(),
            roles: roles.into_iter().map(String::from).collect(),
            tenant_id: tenant_id.map(String::from),
            scopes: vec![],
            attributes: HashMap::new(),
            request_id: "req1".to_string(),
            ip_address: None,
            authenticated_at: chrono::Utc::now(),
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
            issuer: None,
            audience: None,
        }
    }

    /// `column = value` equality predicate, as produced by the policies under test.
    fn field_eq(column: &str, value: &str) -> WhereClause {
        WhereClause::Field {
            path: vec![column.to_string()],
            operator: crate::db::WhereOperator::Eq,
            value: serde_json::json!(value),
        }
    }

    /// Owner-only rule whose results are cached for 300 seconds.
    fn cacheable_owner_rule() -> RLSRule {
        RLSRule {
            name: "owner_only".to_string(),
            expression: "user.id == object.author_id".to_string(),
            cacheable: true,
            cache_ttl_seconds: Some(300),
        }
    }

    /// Compiled policy with `rule` registered for the `Post` type.
    fn policy_with_rule(rule: RLSRule) -> CompiledRLSPolicy {
        let mut rules_by_type = std::collections::HashMap::new();
        rules_by_type.insert("Post".to_string(), vec![rule]);
        CompiledRLSPolicy::new(rules_by_type, None)
    }

    /// Like [`policy_with_rule`] but driven by an injected clock.
    fn policy_with_rule_and_clock(
        rule: RLSRule,
        clock: std::sync::Arc<dyn crate::utils::clock::Clock>,
    ) -> CompiledRLSPolicy {
        let mut rules_by_type = std::collections::HashMap::new();
        rules_by_type.insert("Post".to_string(), vec![rule]);
        CompiledRLSPolicy::new_with_clock(rules_by_type, None, clock)
    }

    #[test]
    fn test_with_tenant_field_sets_field_name() {
        let policy = DefaultRLSPolicy::new().with_tenant_field("org_id".to_string());
        assert_eq!(policy.tenant_field, "org_id", "with_tenant_field must update tenant_field");

        let context = make_context("user1", vec!["viewer"], Some("org42"));
        let result = policy.evaluate(&context, "Post").unwrap().unwrap();
        let sql = format!("{:?}", result.into_where_clause());
        assert!(sql.contains("org_id"), "custom tenant field must appear in WHERE clause: {sql}");
        assert!(!sql.contains("\"tenant_id\""), "default field name must not appear: {sql}");
    }

    #[test]
    fn test_with_owner_field_sets_field_name() {
        let policy = DefaultRLSPolicy::new().with_owner_field("creator_id".to_string());
        assert_eq!(policy.owner_field, "creator_id", "with_owner_field must update owner_field");

        let context = make_context("user1", vec!["viewer"], None);
        let result = policy.evaluate(&context, "Post").unwrap().unwrap();
        let sql = format!("{:?}", result.into_where_clause());
        assert!(
            sql.contains("creator_id"),
            "custom owner field must appear in WHERE clause: {sql}"
        );
        assert!(!sql.contains("author_id"), "default field name must not appear: {sql}");
    }

    #[test]
    fn test_default_rls_policy_admin_bypass() {
        let policy = DefaultRLSPolicy::new();
        let context = make_context("user123", vec!["admin"], Some("tenant1"));
        let result = policy.evaluate(&context, "Post").unwrap();
        assert_eq!(result, None, "Admins should bypass RLS");
    }

    #[test]
    fn test_default_rls_policy_tenant_isolation() {
        let policy = DefaultRLSPolicy::new();
        let context = make_context("user123", vec!["user"], Some("tenant1"));
        let result = policy.evaluate(&context, "Post").unwrap();
        assert!(result.is_some(), "Non-admin users should have RLS filter applied");
        assert_eq!(
            result.unwrap().into_where_clause(),
            WhereClause::And(vec![
                field_eq("tenant_id", "tenant1"),
                field_eq("author_id", "user123"),
            ]),
            "Non-admin filter must combine the tenant and owner predicates"
        );
    }

    #[test]
    fn test_default_rls_policy_single_tenant_omits_tenant_filter() {
        let policy = DefaultRLSPolicy::new().with_single_tenant();
        let context = make_context("user123", vec!["user"], Some("tenant1"));
        let clause = policy.evaluate(&context, "Post").unwrap().unwrap().into_where_clause();
        assert_eq!(
            clause,
            field_eq("author_id", "user123"),
            "Single-tenant mode must apply only the owner predicate"
        );
    }

    #[test]
    fn test_default_rls_policy_without_tenant_applies_owner_filter_only() {
        let policy = DefaultRLSPolicy::new();
        let context = make_context("user123", vec!["user"], None);
        let clause = policy.evaluate(&context, "Post").unwrap().unwrap().into_where_clause();
        assert_eq!(
            clause,
            field_eq("author_id", "user123"),
            "Without a tenant on the context only the owner predicate can be applied"
        );
    }

    #[test]
    fn test_no_rls_policy() {
        let policy = NoRLSPolicy;
        let context = make_context("user123", vec![], None);
        let result = policy.evaluate(&context, "Post").unwrap();
        assert_eq!(result, None, "NoRLSPolicy should never apply filters");
    }

    #[test]
    fn test_compiled_rls_cache_entry_has_correct_ttl() {
        use crate::utils::clock::ManualClock;
        let clock = std::sync::Arc::new(ManualClock::new());
        let t0 = clock.now_secs();
        let policy = policy_with_rule_and_clock(cacheable_owner_rule(), clock);
        let context = make_context("user1", vec!["viewer"], Some("t1"));

        policy.evaluate(&context, "Post").unwrap();

        let cache = policy.cache.read();
        let entry =
            cache.get("user1:Post").expect("cache should be populated after first evaluate");
        assert_eq!(entry.expires_at, t0 + 300, "expires_at must be now_secs + ttl_secs (300)");
    }

    #[test]
    fn test_compiled_rls_cache_hit_does_not_refresh_expiry() {
        use crate::utils::clock::ManualClock;
        let clock = std::sync::Arc::new(ManualClock::new());
        let t0 = clock.now_secs();

        let policy = policy_with_rule_and_clock(cacheable_owner_rule(), clock.clone());
        let context = make_context("user1", vec!["viewer"], Some("t1"));

        policy.evaluate(&context, "Post").unwrap();
        let first_expires_at = policy.cache.read().get("user1:Post").unwrap().expires_at;
        assert_eq!(first_expires_at, t0 + 300);

        clock.advance(std::time::Duration::from_secs(299));

        policy.evaluate(&context, "Post").unwrap();
        let second_expires_at = policy.cache.read().get("user1:Post").unwrap().expires_at;
        assert_eq!(
            second_expires_at, first_expires_at,
            "Cache hit must not update expires_at (would indicate a miss + re-cache)"
        );
    }

    #[test]
    fn test_compiled_rls_cache_miss_after_expiry_refreshes_entry() {
        use crate::utils::clock::ManualClock;
        let clock = std::sync::Arc::new(ManualClock::new());
        let t0 = clock.now_secs();

        let policy = policy_with_rule_and_clock(cacheable_owner_rule(), clock.clone());
        let context = make_context("user1", vec!["viewer"], Some("t1"));

        policy.evaluate(&context, "Post").unwrap();

        clock.advance(std::time::Duration::from_secs(301));

        policy.evaluate(&context, "Post").unwrap();
        let new_expires = policy.cache.read().get("user1:Post").unwrap().expires_at;
        assert_eq!(
            new_expires,
            t0 + 601,
            "After TTL expiry, cache must be refreshed with updated expires_at"
        );
    }

    #[test]
    fn test_compiled_rls_cache_expires_exactly_at_ttl_boundary() {
        use crate::utils::clock::ManualClock;
        let clock = std::sync::Arc::new(ManualClock::new());
        let t0 = clock.now_secs();

        let policy = policy_with_rule_and_clock(cacheable_owner_rule(), clock.clone());
        let context = make_context("user1", vec!["viewer"], Some("t1"));

        policy.evaluate(&context, "Post").unwrap();

        clock.advance(std::time::Duration::from_secs(300));
        assert_eq!(clock.now_secs(), t0 + 300);

        policy.evaluate(&context, "Post").unwrap();
        let refreshed_expires = policy.cache.read().get("user1:Post").unwrap().expires_at;
        assert_eq!(
            refreshed_expires,
            t0 + 600,
            "At exact TTL boundary, cache must expire and refresh (< not <=)"
        );
    }

    #[test]
    fn test_compiled_rls_non_cacheable_rule_is_not_cached() {
        let policy = policy_with_rule(RLSRule {
            name: "owner_only".to_string(),
            expression: "user.id == object.author_id".to_string(),
            cacheable: false,
            cache_ttl_seconds: Some(300),
        });
        let context = make_context("user1", vec!["viewer"], None);

        policy.evaluate(&context, "Post").unwrap();

        assert!(policy.cache.read().is_empty(), "non-cacheable rules must not populate the cache");
    }

    #[test]
    fn test_compiled_rls_admin_bypass() {
        let policy = policy_with_rule(cacheable_owner_rule());
        let context = make_context("root", vec!["admin"], None);

        assert_eq!(policy.evaluate(&context, "Post").unwrap(), None, "Admins should bypass RLS");
        assert!(policy.cache.read().is_empty(), "admin bypass must not touch the cache");
    }

    #[test]
    fn test_compiled_rls_falls_back_to_default_rule() {
        let policy = CompiledRLSPolicy::new(
            HashMap::new(),
            Some(RLSRule {
                name: "default_owner".to_string(),
                expression: "user.id == object.owner_id".to_string(),
                cacheable: false,
                cache_ttl_seconds: None,
            }),
        );
        let context = make_context("user7", vec!["viewer"], None);

        let clause = policy
            .evaluate(&context, "Comment")
            .unwrap()
            .expect("default rule must apply to unregistered types")
            .into_where_clause();
        assert_eq!(clause, field_eq("owner_id", "user7"));
    }

    #[test]
    fn test_compiled_rls_without_any_rule_applies_no_filter() {
        let policy = CompiledRLSPolicy::new(HashMap::new(), None);
        let context = make_context("user7", vec!["viewer"], None);
        assert_eq!(policy.evaluate(&context, "Comment").unwrap(), None);
    }

    #[test]
    fn test_compiled_rls_role_rule_grants_unfiltered_access() {
        let policy = policy_with_rule(RLSRule {
            name: "editors".to_string(),
            expression: "user.roles includes 'editor'".to_string(),
            cacheable: false,
            cache_ttl_seconds: None,
        });
        let context = make_context("user1", vec!["editor"], None);
        assert_eq!(
            policy.evaluate(&context, "Post").unwrap(),
            None,
            "A caller holding the required role must not be filtered"
        );
    }

    #[test]
    fn test_compiled_rls_role_rule_rejects_missing_role() {
        let policy = policy_with_rule(RLSRule {
            name: "editors".to_string(),
            expression: "user.roles includes 'editor'".to_string(),
            cacheable: false,
            cache_ttl_seconds: None,
        });
        let context = make_context("user1", vec!["viewer"], None);
        assert!(
            policy.evaluate(&context, "Post").is_err(),
            "A caller lacking the required role must not be granted access"
        );
    }

    #[test]
    fn test_compiled_rls_unrecognised_expression_is_validation_error() {
        let policy = policy_with_rule(RLSRule {
            name: "bogus".to_string(),
            expression: "object.published".to_string(),
            cacheable: true,
            cache_ttl_seconds: Some(300),
        });
        let context = make_context("user1", vec!["viewer"], None);

        let err = policy.evaluate(&context, "Post").unwrap_err();
        assert!(
            matches!(err, FraiseQLError::Validation { .. }),
            "unparseable expressions must be reported as validation errors: {err:?}"
        );
        assert!(policy.cache.read().is_empty(), "errors must never be cached");
    }

    #[test]
    fn test_cache_result_stores_with_300s_default_ttl() {
        use crate::utils::clock::ManualClock;
        let clock = std::sync::Arc::new(ManualClock::new());
        let t0 = clock.now_secs();

        let policy =
            CompiledRLSPolicy::new_with_clock(std::collections::HashMap::new(), None, clock);

        let result = Some(WhereClause::Field {
            path: vec!["author_id".to_string()],
            operator: crate::db::WhereOperator::Eq,
            value: serde_json::json!("user_x"),
        });

        policy.cache_result("user_x:Post", &result);

        let cache = policy.cache.read();
        let entry = cache.get("user_x:Post").expect("cache_result must insert entry");
        assert_eq!(entry.expires_at, t0 + 300, "cache_result must use 300s TTL");
        assert_eq!(entry.result, result, "cache_result must store the provided result");
    }

    #[test]
    fn test_cache_result_stores_none_result() {
        use crate::utils::clock::ManualClock;
        let policy = CompiledRLSPolicy::new_with_clock(
            std::collections::HashMap::new(),
            None,
            std::sync::Arc::new(ManualClock::new()),
        );

        policy.cache_result("user1:Post", &None);

        let cache = policy.cache.read();
        let entry = cache.get("user1:Post").expect("cache_result must store even None result");
        assert!(entry.result.is_none(), "cached None result must remain None");
    }

    #[test]
    fn test_extract_user_value_id_field() {
        let ctx = make_context("user_abc_123", vec![], None);
        assert_eq!(
            extract_user_value("id", &ctx),
            Some(serde_json::json!("user_abc_123")),
            "'id' must return the actual user_id"
        );
    }

    #[test]
    fn test_extract_user_value_user_id_alias() {
        let ctx = make_context("user_abc_123", vec![], None);
        assert_eq!(
            extract_user_value("user_id", &ctx),
            Some(serde_json::json!("user_abc_123")),
            "'user_id' must return the same user_id as 'id'"
        );
    }

    #[test]
    fn test_extract_user_value_tenant_id_present() {
        let ctx = make_context("u1", vec![], Some("tenant_xyz"));
        assert_eq!(
            extract_user_value("tenant_id", &ctx),
            Some(serde_json::json!("tenant_xyz")),
            "'tenant_id' must return the tenant id when present"
        );
    }

    #[test]
    fn test_extract_user_value_tenant_id_absent() {
        let ctx = make_context("u1", vec![], None);
        assert_eq!(
            extract_user_value("tenant_id", &ctx),
            None,
            "'tenant_id' must return None when absent, not Some(null)"
        );
    }

    #[test]
    fn test_extract_user_value_roles_field() {
        let ctx = make_context("u1", vec!["editor", "viewer"], None);
        assert_eq!(
            extract_user_value("roles", &ctx),
            Some(serde_json::json!(["editor", "viewer"])),
            "'roles' must return the full roles array"
        );
    }

    #[test]
    fn test_extract_user_value_custom_attribute() {
        let mut attrs = HashMap::new();
        attrs.insert("department".to_string(), serde_json::json!("engineering"));
        let ctx = SecurityContext {
            user_id: "u1".to_string(),
            roles: vec![],
            tenant_id: None,
            scopes: vec![],
            attributes: attrs,
            request_id: "r1".to_string(),
            ip_address: None,
            authenticated_at: chrono::Utc::now(),
            expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
            issuer: None,
            audience: None,
        };
        assert_eq!(
            extract_user_value("department", &ctx),
            Some(serde_json::json!("engineering")),
            "Custom attribute must be returned by name"
        );
    }

    #[test]
    fn test_extract_user_value_unknown_returns_none() {
        let ctx = make_context("u1", vec![], None);
        assert_eq!(
            extract_user_value("nonexistent_field", &ctx),
            None,
            "Unknown field must return None, not Some(null)"
        );
    }

    #[test]
    fn test_user_id_propagated_to_rls_where_clause() {
        let policy = policy_with_rule(RLSRule {
            name: "owner_only".to_string(),
            expression: "user.id == object.author_id".to_string(),
            cacheable: false,
            cache_ttl_seconds: None,
        });

        let context = make_context("specific_user_42", vec!["viewer"], None);
        let result = policy.evaluate(&context, "Post").unwrap();

        let clause = result.expect("non-admin user must receive an RLS filter").into_where_clause();
        match clause {
            WhereClause::Field { value, .. } => {
                assert_eq!(
                    value,
                    serde_json::json!("specific_user_42"),
                    "RLS WhereClause must embed the actual user_id, not null"
                );
            },
            other => panic!("Expected Field clause, got {other:?}"),
        }
    }
}