kimberlite_rbac/
enforcement.rs1use crate::policy::{AccessPolicy, RowFilter};
6use thiserror::Error;
7use tracing::{error, info, warn};
8
9#[derive(Debug, Error)]
11pub enum EnforcementError {
12 #[error("Access denied: {reason}")]
14 AccessDenied { reason: String },
15
16 #[error("Insufficient permissions: {operation} requires {required_permission}")]
18 InsufficientPermissions {
19 operation: String,
20 required_permission: String,
21 },
22
23 #[error("Policy evaluation failed: {0}")]
25 PolicyEvaluationFailed(String),
26}
27
28pub type Result<T> = std::result::Result<T, EnforcementError>;
30
31pub struct PolicyEnforcer {
39 policy: AccessPolicy,
41
42 audit_enabled: bool,
44}
45
46impl PolicyEnforcer {
47 pub fn new(policy: AccessPolicy) -> Self {
49 Self {
50 policy,
51 audit_enabled: true,
52 }
53 }
54
55 pub fn without_audit(mut self) -> Self {
57 self.audit_enabled = false;
58 self
59 }
60
61 pub fn enforce_stream_access(&self, stream_name: &str) -> Result<()> {
67 let allowed = self.policy.allows_stream(stream_name);
68
69 if self.audit_enabled {
70 if allowed {
71 info!(
72 stream = %stream_name,
73 role = ?self.policy.role,
74 "Stream access granted"
75 );
76 } else {
77 warn!(
78 stream = %stream_name,
79 role = ?self.policy.role,
80 "Stream access denied"
81 );
82 }
83 }
84
85 if allowed {
86 Ok(())
87 } else {
88 Err(EnforcementError::AccessDenied {
89 reason: format!("Access to stream '{stream_name}' denied by policy"),
90 })
91 }
92 }
93
94 pub fn filter_columns(&self, columns: &[String]) -> Vec<String> {
108 let allowed: Vec<String> = columns
109 .iter()
110 .filter(|col| self.policy.allows_column(col))
111 .cloned()
112 .collect();
113
114 if self.audit_enabled {
115 let denied: Vec<&String> = columns
116 .iter()
117 .filter(|col| !self.policy.allows_column(col))
118 .collect();
119
120 if !denied.is_empty() {
121 warn!(
122 role = ?self.policy.role,
123 denied_columns = ?denied,
124 "Columns filtered by policy"
125 );
126 }
127 }
128
129 allowed
130 }
131
132 pub fn row_filters(&self) -> &[RowFilter] {
149 self.policy.row_filters()
150 }
151
152 pub fn generate_where_clause(&self) -> Result<String> {
177 let filters = self.row_filters();
178
179 if filters.is_empty() {
180 return Ok(String::new());
181 }
182
183 let mut parts = Vec::with_capacity(filters.len());
184 for f in filters {
185 validate_sql_literal(&f.value)?;
186 let op = f.operator.to_sql();
187 parts.push(format!("{} {op} {}", f.column, f.value));
188 }
189
190 Ok(parts.join(" AND "))
191 }
192
193 pub fn enforce_query(
202 &self,
203 stream_name: &str,
204 requested_columns: &[String],
205 ) -> Result<(Vec<String>, String)> {
206 self.enforce_stream_access(stream_name)?;
208
209 let allowed_columns = self.filter_columns(requested_columns);
211
212 if allowed_columns.is_empty() {
213 return Err(EnforcementError::AccessDenied {
214 reason: "No authorized columns in query".to_string(),
215 });
216 }
217
218 let where_clause = self.generate_where_clause()?;
220
221 if self.audit_enabled {
222 info!(
223 stream = %stream_name,
224 role = ?self.policy.role,
225 columns = ?allowed_columns,
226 where_clause = %where_clause,
227 "Query access granted"
228 );
229 }
230
231 Ok((allowed_columns, where_clause))
232 }
233
234 pub fn policy(&self) -> &AccessPolicy {
236 &self.policy
237 }
238}
239
240fn validate_sql_literal(value: &str) -> Result<()> {
247 if value.parse::<i64>().is_ok() {
249 return Ok(());
250 }
251
252 if value.eq_ignore_ascii_case("true") || value.eq_ignore_ascii_case("false") {
254 return Ok(());
255 }
256
257 if value.eq_ignore_ascii_case("null") {
259 return Ok(());
260 }
261
262 if value.len() >= 2
264 && value.starts_with('\'')
265 && value.ends_with('\'')
266 && !value[1..value.len() - 1].contains('\'')
267 && !value[1..value.len() - 1].contains('\\')
268 {
269 return Ok(());
270 }
271
272 Err(EnforcementError::PolicyEvaluationFailed(format!(
273 "Invalid SQL literal in row filter: {value:?}"
274 )))
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::{RowFilter, RowFilterOperator, StandardPolicies};
    use crate::roles::Role;
    use kimberlite_types::TenantId;

    #[test]
    fn test_enforce_stream_access_allowed() {
        let policy = StandardPolicies::admin();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        assert!(enforcer.enforce_stream_access("any_stream").is_ok());
    }

    #[test]
    fn test_enforce_stream_access_denied() {
        let policy = StandardPolicies::auditor();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        assert!(enforcer.enforce_stream_access("audit_log").is_ok());
        assert!(enforcer.enforce_stream_access("patient_records").is_err());
    }

    #[test]
    fn test_filter_columns() {
        let policy = AccessPolicy::new(Role::Analyst)
            .allow_column("*")
            .deny_column("ssn");

        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let requested = vec!["name".to_string(), "email".to_string(), "ssn".to_string()];

        let allowed = enforcer.filter_columns(&requested);

        assert_eq!(allowed.len(), 2);
        assert!(allowed.contains(&"name".to_string()));
        assert!(allowed.contains(&"email".to_string()));
        assert!(!allowed.contains(&"ssn".to_string()));
    }

    #[test]
    fn test_generate_where_clause_single_filter() {
        let tenant_id = TenantId::new(42);
        let policy = StandardPolicies::user(tenant_id);
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let where_clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(where_clause, "tenant_id = 42");
    }

    #[test]
    fn test_generate_where_clause_multiple_filters() {
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("*")
            .allow_column("*")
            .with_row_filter(RowFilter::new("tenant_id", RowFilterOperator::Eq, "42"))
            .with_row_filter(RowFilter::new("status", RowFilterOperator::Eq, "'active'"));

        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let where_clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(where_clause, "tenant_id = 42 AND status = 'active'");
    }

    #[test]
    fn test_generate_where_clause_no_filters() {
        let policy = StandardPolicies::admin();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let where_clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(where_clause, "");
    }

    #[test]
    fn test_generate_where_clause_rejects_injection() {
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("*")
            .allow_column("*")
            .with_row_filter(RowFilter::new(
                "tenant_id",
                RowFilterOperator::Eq,
                "1; DROP TABLE users",
            ));

        let enforcer = PolicyEnforcer::new(policy).without_audit();
        let result = enforcer.generate_where_clause();
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_sql_literal_accepts_safe_forms() {
        for ok in ["42", "-7", "TRUE", "false", "NULL", "''", "'active'"] {
            assert!(validate_sql_literal(ok).is_ok(), "expected {ok:?} to be accepted");
        }
    }

    #[test]
    fn test_validate_sql_literal_rejects_unsafe_forms() {
        for bad in [
            "",
            "'",
            "active",
            "3.14",
            "'it's'",
            "'a\\'",
            "'a' OR '1'='1'",
        ] {
            assert!(validate_sql_literal(bad).is_err(), "expected {bad:?} to be rejected");
        }
    }

    #[test]
    fn test_enforce_query_full_flow() {
        let policy = AccessPolicy::new(Role::User)
            .with_tenant(TenantId::new(42))
            .allow_stream("patient_*")
            .allow_column("*")
            .deny_column("ssn")
            .with_row_filter(RowFilter::new("tenant_id", RowFilterOperator::Eq, "42"));

        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let requested_columns = vec!["name".to_string(), "email".to_string(), "ssn".to_string()];

        let (allowed_columns, where_clause) = enforcer
            .enforce_query("patient_records", &requested_columns)
            .unwrap();

        assert_eq!(allowed_columns.len(), 2);
        assert!(allowed_columns.contains(&"name".to_string()));
        assert!(allowed_columns.contains(&"email".to_string()));
        assert!(!allowed_columns.contains(&"ssn".to_string()));

        assert_eq!(where_clause, "tenant_id = 42");
    }

    #[test]
    fn test_enforce_query_stream_denied() {
        let policy = StandardPolicies::auditor();
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let columns = vec!["name".to_string()];
        let result = enforcer.enforce_query("patient_records", &columns);

        assert!(result.is_err());
        match result {
            Err(EnforcementError::AccessDenied { reason }) => {
                assert!(reason.contains("patient_records"));
            }
            _ => panic!("Expected AccessDenied error"),
        }
    }

    #[test]
    fn test_enforce_query_no_authorized_columns() {
        let policy = AccessPolicy::new(Role::User)
            .allow_stream("*")
            .allow_column("public_*");
        let enforcer = PolicyEnforcer::new(policy).without_audit();

        let requested = vec!["private_ssn".to_string(), "private_address".to_string()];

        let result = enforcer.enforce_query("patient_records", &requested);

        assert!(result.is_err());
        match result {
            Err(EnforcementError::AccessDenied { reason }) => {
                assert!(reason.contains("No authorized columns"));
            }
            _ => panic!("Expected AccessDenied error"),
        }
    }
}