Skip to main content

kimberlite_rbac/
enforcement.rs

1//! Policy enforcement logic.
2//!
3//! Enforces access control policies at query time.
4
5use crate::policy::{AccessPolicy, RowFilter};
6use thiserror::Error;
7use tracing::{error, info, warn};
8
9/// Error type for policy enforcement.
10#[derive(Debug, Error)]
11pub enum EnforcementError {
12    /// Access denied by policy.
13    #[error("Access denied: {reason}")]
14    AccessDenied { reason: String },
15
16    /// Insufficient permissions for operation.
17    #[error("Insufficient permissions: {operation} requires {required_permission}")]
18    InsufficientPermissions {
19        operation: String,
20        required_permission: String,
21    },
22
23    /// Policy evaluation error.
24    #[error("Policy evaluation failed: {0}")]
25    PolicyEvaluationFailed(String),
26}
27
28/// Result type for enforcement operations.
29pub type Result<T> = std::result::Result<T, EnforcementError>;
30
31/// Policy enforcement engine.
32///
33/// Enforces access control policies at query time:
34/// - Stream-level access control
35/// - Column filtering (field-level security)
36/// - Row-level security (RLS)
37/// - Audit logging
38pub struct PolicyEnforcer {
39    /// Current access policy.
40    policy: AccessPolicy,
41
42    /// Whether to log access attempts.
43    audit_enabled: bool,
44}
45
46impl PolicyEnforcer {
47    /// Creates a new policy enforcer.
48    pub fn new(policy: AccessPolicy) -> Self {
49        Self {
50            policy,
51            audit_enabled: true,
52        }
53    }
54
55    /// Disables audit logging (for testing).
56    pub fn without_audit(mut self) -> Self {
57        self.audit_enabled = false;
58        self
59    }
60
61    /// Enforces stream-level access control.
62    ///
63    /// Returns `Ok(())` if access is allowed, `Err` otherwise.
64    ///
65    /// **Audit:** Logs all access attempts.
66    pub fn enforce_stream_access(&self, stream_name: &str) -> Result<()> {
67        let allowed = self.policy.allows_stream(stream_name);
68
69        if self.audit_enabled {
70            if allowed {
71                info!(
72                    stream = %stream_name,
73                    role = ?self.policy.role,
74                    "Stream access granted"
75                );
76            } else {
77                warn!(
78                    stream = %stream_name,
79                    role = ?self.policy.role,
80                    "Stream access denied"
81                );
82            }
83        }
84
85        if allowed {
86            Ok(())
87        } else {
88            Err(EnforcementError::AccessDenied {
89                reason: format!("Access to stream '{stream_name}' denied by policy"),
90            })
91        }
92    }
93
94    /// Filters columns based on policy.
95    ///
96    /// Removes unauthorized columns from the query result.
97    ///
98    /// # Arguments
99    ///
100    /// * `columns` - List of column names requested by the query
101    ///
102    /// # Returns
103    ///
104    /// Filtered list of column names that the policy allows.
105    ///
106    /// **Audit:** Logs denied columns (if any).
107    pub fn filter_columns(&self, columns: &[String]) -> Vec<String> {
108        let allowed: Vec<String> = columns
109            .iter()
110            .filter(|col| self.policy.allows_column(col))
111            .cloned()
112            .collect();
113
114        if self.audit_enabled {
115            let denied: Vec<&String> = columns
116                .iter()
117                .filter(|col| !self.policy.allows_column(col))
118                .collect();
119
120            if !denied.is_empty() {
121                warn!(
122                    role = ?self.policy.role,
123                    denied_columns = ?denied,
124                    "Columns filtered by policy"
125                );
126            }
127        }
128
129        allowed
130    }
131
132    /// Returns row-level security filters to inject into the query.
133    ///
134    /// These filters are added as WHERE clauses to restrict rows
135    /// visible to the user.
136    ///
137    /// # Examples
138    ///
139    /// For a User role with `tenant_id=42`:
140    /// ```sql
141    /// WHERE tenant_id = 42
142    /// ```
143    ///
144    /// For multiple filters:
145    /// ```sql
146    /// WHERE tenant_id = 42 AND status = 'active'
147    /// ```
148    pub fn row_filters(&self) -> &[RowFilter] {
149        self.policy.row_filters()
150    }
151
152    /// Generates SQL WHERE clause from row filters.
153    ///
154    /// # Returns
155    ///
156    /// SQL WHERE clause (without "WHERE" keyword), or empty string if no filters.
157    ///
158    /// # Errors
159    ///
160    /// Returns [`EnforcementError::PolicyEvaluationFailed`] if a filter value
161    /// fails SQL literal validation (e.g., contains SQL injection attempts).
162    ///
163    /// # Examples
164    ///
165    /// ```
166    /// use kimberlite_rbac::enforcement::PolicyEnforcer;
167    /// use kimberlite_rbac::policy::{AccessPolicy, RowFilter, RowFilterOperator, StandardPolicies};
168    /// use kimberlite_types::TenantId;
169    ///
170    /// let policy = StandardPolicies::user(TenantId::new(42));
171    /// let enforcer = PolicyEnforcer::new(policy).without_audit();
172    ///
173    /// let where_clause = enforcer.generate_where_clause().unwrap();
174    /// assert_eq!(where_clause, "tenant_id = 42");
175    /// ```
176    pub fn generate_where_clause(&self) -> Result<String> {
177        let filters = self.row_filters();
178
179        if filters.is_empty() {
180            return Ok(String::new());
181        }
182
183        let mut parts = Vec::with_capacity(filters.len());
184        for f in filters {
185            validate_sql_literal(&f.value)?;
186            let op = f.operator.to_sql();
187            parts.push(format!("{} {op} {}", f.column, f.value));
188        }
189
190        Ok(parts.join(" AND "))
191    }
192
193    /// Enforces policy for a complete query.
194    ///
195    /// Validates:
196    /// 1. Stream access is allowed
197    /// 2. Columns are filtered
198    /// 3. Row filters are applied
199    ///
200    /// Returns filtered columns and WHERE clause.
201    pub fn enforce_query(
202        &self,
203        stream_name: &str,
204        requested_columns: &[String],
205    ) -> Result<(Vec<String>, String)> {
206        // 1. Check stream access
207        self.enforce_stream_access(stream_name)?;
208
209        // 2. Filter columns
210        let allowed_columns = self.filter_columns(requested_columns);
211
212        if allowed_columns.is_empty() {
213            return Err(EnforcementError::AccessDenied {
214                reason: "No authorized columns in query".to_string(),
215            });
216        }
217
218        // 3. Generate row filters (validates SQL literals)
219        let where_clause = self.generate_where_clause()?;
220
221        if self.audit_enabled {
222            info!(
223                stream = %stream_name,
224                role = ?self.policy.role,
225                columns = ?allowed_columns,
226                where_clause = %where_clause,
227                "Query access granted"
228            );
229        }
230
231        Ok((allowed_columns, where_clause))
232    }
233
234    /// Returns the current policy.
235    pub fn policy(&self) -> &AccessPolicy {
236        &self.policy
237    }
238}
239
240/// Validates that a value is a safe SQL literal.
241///
242/// Accepts: integers, booleans (`true`/`false`), `NULL`, and simple quoted
243/// strings (single-quoted, no embedded quotes or backslashes).
244///
245/// Rejects everything else to prevent SQL injection via row filter values.
246fn validate_sql_literal(value: &str) -> Result<()> {
247    // Integer literals (including negative)
248    if value.parse::<i64>().is_ok() {
249        return Ok(());
250    }
251
252    // Boolean literals
253    if value.eq_ignore_ascii_case("true") || value.eq_ignore_ascii_case("false") {
254        return Ok(());
255    }
256
257    // NULL literal
258    if value.eq_ignore_ascii_case("null") {
259        return Ok(());
260    }
261
262    // Simple single-quoted string: 'content' with no embedded quotes or backslashes
263    if value.len() >= 2
264        && value.starts_with('\'')
265        && value.ends_with('\'')
266        && !value[1..value.len() - 1].contains('\'')
267        && !value[1..value.len() - 1].contains('\\')
268    {
269        return Ok(());
270    }
271
272    Err(EnforcementError::PolicyEvaluationFailed(format!(
273        "Invalid SQL literal in row filter: {value:?}"
274    )))
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::{RowFilter, RowFilterOperator, StandardPolicies};
    use crate::roles::Role;
    use kimberlite_types::TenantId;

    #[test]
    fn test_enforce_stream_access_allowed() {
        let policy = StandardPolicies::admin();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        assert!(enforcer.enforce_stream_access("any_stream").is_ok());
    }

    #[test]
    fn test_enforce_stream_access_denied() {
        let policy = StandardPolicies::auditor();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        // Auditor can only access audit_* streams
        assert!(enforcer.enforce_stream_access("audit_log").is_ok());
        assert!(enforcer.enforce_stream_access("patient_records").is_err());
    }

    #[test]
    fn test_filter_columns() {
        let policy = AccessPolicy::new(Role::Analyst)
            .allow_column("*")
            .deny_column("ssn");

        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let requested = vec!["name".to_string(), "email".to_string(), "ssn".to_string()];

        let allowed = enforcer.filter_columns(&requested);

        assert_eq!(allowed.len(), 2);
        assert!(allowed.contains(&"name".to_string()));
        assert!(allowed.contains(&"email".to_string()));
        assert!(!allowed.contains(&"ssn".to_string()));
    }

    #[test]
    fn test_filter_columns_with_audit_enabled_preserves_order() {
        let policy = AccessPolicy::new(Role::Analyst)
            .allow_column("*")
            .deny_column("ssn");

        // Audit left enabled: logging must not change which columns survive
        // or the order they come back in.
        let enforcer = PolicyEnforcer::new(policy);

        let requested = vec!["ssn".to_string(), "name".to_string(), "email".to_string()];

        assert_eq!(
            enforcer.filter_columns(&requested),
            vec!["name".to_string(), "email".to_string()]
        );
    }

    #[test]
    fn test_generate_where_clause_single_filter() {
        let tenant_id = TenantId::new(42);
        let policy = StandardPolicies::user(tenant_id);
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let where_clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(where_clause, "tenant_id = 42");
    }

    #[test]
    fn test_generate_where_clause_multiple_filters() {
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("*")
            .allow_column("*")
            .with_row_filter(RowFilter::new("tenant_id", RowFilterOperator::Eq, "42"))
            .with_row_filter(RowFilter::new("status", RowFilterOperator::Eq, "'active'"));

        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let where_clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(where_clause, "tenant_id = 42 AND status = 'active'");
    }

    #[test]
    fn test_generate_where_clause_no_filters() {
        let policy = StandardPolicies::admin();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let where_clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(where_clause, "");
    }

    #[test]
    fn test_generate_where_clause_rejects_injection() {
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("*")
            .allow_column("*")
            .with_row_filter(RowFilter::new(
                "tenant_id",
                RowFilterOperator::Eq,
                "1; DROP TABLE users",
            ));

        let enforcer = PolicyEnforcer::new(policy).without_audit();
        let result = enforcer.generate_where_clause();
        assert!(
            matches!(result, Err(EnforcementError::PolicyEvaluationFailed(_))),
            "expected PolicyEvaluationFailed, got {result:?}"
        );
    }

    #[test]
    fn test_generate_where_clause_rejects_quote_breakout() {
        // A quoted value that tries to close its own literal and append an
        // always-true disjunct must be rejected.
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("*")
            .allow_column("*")
            .with_row_filter(RowFilter::new(
                "status",
                RowFilterOperator::Eq,
                "'a' OR '1'='1'",
            ));

        let enforcer = PolicyEnforcer::new(policy).without_audit();
        let result = enforcer.generate_where_clause();
        assert!(
            matches!(result, Err(EnforcementError::PolicyEvaluationFailed(_))),
            "expected PolicyEvaluationFailed, got {result:?}"
        );
    }

    #[test]
    fn test_validate_sql_literal_accepts_safe_literals() {
        for value in ["42", "-7", "0", "true", "FALSE", "NULL", "null", "''", "'active'"] {
            assert!(
                validate_sql_literal(value).is_ok(),
                "expected {value:?} to be accepted"
            );
        }
    }

    #[test]
    fn test_validate_sql_literal_rejects_unsafe_values() {
        for value in [
            "",
            "'",
            "active",
            "1.5",
            " 42",
            "'unterminated",
            "'it''s'",
            "'a\\'",
            "'x' OR '1'='1'",
            "1 OR 1=1",
        ] {
            assert!(
                matches!(
                    validate_sql_literal(value),
                    Err(EnforcementError::PolicyEvaluationFailed(_))
                ),
                "expected {value:?} to be rejected"
            );
        }
    }

    #[test]
    fn test_enforce_query_full_flow() {
        let policy = AccessPolicy::new(Role::User)
            .with_tenant(TenantId::new(42))
            .allow_stream("patient_*")
            .allow_column("*")
            .deny_column("ssn")
            .with_row_filter(RowFilter::new("tenant_id", RowFilterOperator::Eq, "42"));

        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let requested_columns = vec!["name".to_string(), "email".to_string(), "ssn".to_string()];

        let (allowed_columns, where_clause) = enforcer
            .enforce_query("patient_records", &requested_columns)
            .unwrap();

        assert_eq!(allowed_columns.len(), 2);
        assert!(allowed_columns.contains(&"name".to_string()));
        assert!(allowed_columns.contains(&"email".to_string()));
        assert!(!allowed_columns.contains(&"ssn".to_string()));

        assert_eq!(where_clause, "tenant_id = 42");
    }

    #[test]
    fn test_enforce_query_stream_denied() {
        let policy = StandardPolicies::auditor();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let columns = vec!["name".to_string()];

        match enforcer.enforce_query("patient_records", &columns) {
            Err(EnforcementError::AccessDenied { reason }) => {
                assert!(reason.contains("patient_records"));
            }
            other => panic!("Expected AccessDenied error, got {other:?}"),
        }
    }

    #[test]
    fn test_enforce_query_no_authorized_columns() {
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("*")
            .allow_column("public_*"); // Only public columns allowed

        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let requested = vec!["private_ssn".to_string(), "private_address".to_string()];

        match enforcer.enforce_query("patient_records", &requested) {
            Err(EnforcementError::AccessDenied { reason }) => {
                assert!(reason.contains("No authorized columns"));
            }
            other => panic!("Expected AccessDenied error, got {other:?}"),
        }
    }
}