use crate::policy::{AccessPolicy, RowFilter};
use thiserror::Error;
use tracing::{error, info, warn};
/// Errors produced while enforcing an access policy.
#[derive(Debug, Error)]
pub enum EnforcementError {
/// The policy refused the request outright.
///
/// In this file it is returned when a stream is not allowed by the policy
/// and when none of the requested columns survive column filtering.
/// `reason` is a human-readable explanation.
#[error("Access denied: {reason}")]
AccessDenied { reason: String },
/// An operation needs a permission the caller does not hold.
///
/// Not constructed anywhere in this file; `operation` and
/// `required_permission` are only interpolated into the message.
#[error("Insufficient permissions: {operation} requires {required_permission}")]
InsufficientPermissions {
operation: String,
required_permission: String,
},
/// The policy itself could not be turned into an enforceable form.
///
/// In this file it is returned when a row-filter value is not an accepted
/// SQL literal; the payload is the diagnostic message.
#[error("Policy evaluation failed: {0}")]
PolicyEvaluationFailed(String),
}
/// Result alias used throughout this module; the error type is fixed to
/// [`EnforcementError`].
pub type Result<T> = std::result::Result<T, EnforcementError>;
/// Applies an [`AccessPolicy`] to stream, column and row access decisions.
pub struct PolicyEnforcer {
// Policy consulted for every decision made by this enforcer.
policy: AccessPolicy,
// When true, decisions are reported through `tracing`. Set to true by
// `new` and cleared by `without_audit`.
audit_enabled: bool,
}
impl PolicyEnforcer {
    /// Creates an enforcer for `policy` with audit logging enabled.
    pub fn new(policy: AccessPolicy) -> Self {
        Self {
            policy,
            audit_enabled: true,
        }
    }

    /// Consumes the enforcer and returns it with audit logging disabled.
    pub fn without_audit(mut self) -> Self {
        self.audit_enabled = false;
        self
    }

    /// Checks whether the policy allows access to `stream_name`.
    ///
    /// When auditing is enabled the decision is logged: `info` for a grant,
    /// `warn` for a denial.
    ///
    /// # Errors
    ///
    /// Returns [`EnforcementError::AccessDenied`] when the policy does not
    /// allow the stream.
    pub fn enforce_stream_access(&self, stream_name: &str) -> Result<()> {
        let allowed = self.policy.allows_stream(stream_name);
        if self.audit_enabled {
            if allowed {
                info!(
                    stream = %stream_name,
                    role = ?self.policy.role,
                    "Stream access granted"
                );
            } else {
                warn!(
                    stream = %stream_name,
                    role = ?self.policy.role,
                    "Stream access denied"
                );
            }
        }
        if allowed {
            Ok(())
        } else {
            Err(EnforcementError::AccessDenied {
                reason: format!("Access to stream '{stream_name}' denied by policy"),
            })
        }
    }

    /// Returns the subset of `columns` the policy allows, in input order.
    ///
    /// When auditing is enabled and at least one column was removed, the
    /// removed columns are logged at `warn` level.
    pub fn filter_columns(&self, columns: &[String]) -> Vec<String> {
        // Single pass: each column is checked against the policy exactly once,
        // and denied columns are only collected when they will be logged.
        let mut allowed = Vec::with_capacity(columns.len());
        let mut denied: Vec<&String> = Vec::new();
        for col in columns {
            if self.policy.allows_column(col) {
                allowed.push(col.clone());
            } else if self.audit_enabled {
                denied.push(col);
            }
        }
        if !denied.is_empty() {
            warn!(
                role = ?self.policy.role,
                denied_columns = ?denied,
                "Columns filtered by policy"
            );
        }
        allowed
    }

    /// Returns the row filters defined by the policy.
    pub fn row_filters(&self) -> &[RowFilter] {
        self.policy.row_filters()
    }

    /// Renders the policy's row filters as a SQL predicate.
    ///
    /// Each filter becomes `<column> <operator> <value>` and the pieces are
    /// joined with ` AND `. The result is an empty string when the policy has
    /// no row filters.
    ///
    /// Because the predicate is built by string interpolation, both the value
    /// and the column name are validated before being spliced in.
    ///
    /// # Errors
    ///
    /// Returns [`EnforcementError::PolicyEvaluationFailed`] when a filter
    /// value is not an accepted SQL literal or a filter column is not a plain
    /// (optionally dot-qualified) identifier.
    pub fn generate_where_clause(&self) -> Result<String> {
        let filters = self.row_filters();
        if filters.is_empty() {
            return Ok(String::new());
        }
        let mut parts = Vec::with_capacity(filters.len());
        for f in filters {
            // Value first, so existing callers keep seeing the same error for
            // a bad literal; the column check closes the remaining hole.
            validate_sql_literal(&f.value)?;
            validate_sql_identifier(&f.column)?;
            let op = f.operator.to_sql();
            parts.push(format!("{} {op} {}", f.column, f.value));
        }
        Ok(parts.join(" AND "))
    }

    /// Runs the full check for a query against `stream_name`.
    ///
    /// Verifies stream access, filters `requested_columns` down to the
    /// allowed ones, and builds the row-filter predicate. On success returns
    /// `(allowed_columns, where_clause)`; when auditing is enabled the grant
    /// is logged at `info` level.
    ///
    /// # Errors
    ///
    /// Returns [`EnforcementError::AccessDenied`] when the stream is not
    /// allowed or no requested column is authorized, and propagates any error
    /// from [`Self::generate_where_clause`].
    pub fn enforce_query(
        &self,
        stream_name: &str,
        requested_columns: &[String],
    ) -> Result<(Vec<String>, String)> {
        self.enforce_stream_access(stream_name)?;
        let allowed_columns = self.filter_columns(requested_columns);
        if allowed_columns.is_empty() {
            return Err(EnforcementError::AccessDenied {
                reason: "No authorized columns in query".to_string(),
            });
        }
        let where_clause = self.generate_where_clause()?;
        if self.audit_enabled {
            info!(
                stream = %stream_name,
                role = ?self.policy.role,
                columns = ?allowed_columns,
                where_clause = %where_clause,
                "Query access granted"
            );
        }
        Ok((allowed_columns, where_clause))
    }

    /// Returns the policy this enforcer applies.
    pub fn policy(&self) -> &AccessPolicy {
        &self.policy
    }
}

/// Accepts only plain SQL identifiers, optionally dot-qualified.
///
/// Every dot-separated segment must start with an ASCII letter or `_` and
/// continue with ASCII letters, digits or `_`. Anything else (empty names,
/// whitespace, quotes, operators, comment markers) is rejected so a column
/// name cannot smuggle extra SQL into a generated predicate.
fn validate_sql_identifier(ident: &str) -> Result<()> {
    let is_segment = |segment: &str| {
        let mut chars = segment.chars();
        matches!(chars.next(), Some(c) if c.is_ascii_alphabetic() || c == '_')
            && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
    };
    if ident.split('.').all(is_segment) {
        Ok(())
    } else {
        Err(EnforcementError::PolicyEvaluationFailed(format!(
            "Invalid column identifier in row filter: {ident:?}"
        )))
    }
}
/// Checks that `value` is safe to splice into SQL text as a literal.
///
/// Accepted forms are:
/// * anything that parses as an `i64`,
/// * the keywords `true`, `false` and `null` in any ASCII case,
/// * a single-quoted string whose contents contain neither `'` nor `\`.
///
/// # Errors
///
/// Returns [`EnforcementError::PolicyEvaluationFailed`] for every other input.
fn validate_sql_literal(value: &str) -> Result<()> {
    const KEYWORDS: [&str; 3] = ["true", "false", "null"];

    let is_integer = || value.parse::<i64>().is_ok();
    let is_keyword = || KEYWORDS.iter().any(|kw| value.eq_ignore_ascii_case(kw));
    // Stripping one quote from each end only succeeds for inputs of at least
    // two bytes, so a lone `'` is rejected without an explicit length check.
    let is_plain_string = || {
        match value
            .strip_prefix('\'')
            .and_then(|rest| rest.strip_suffix('\''))
        {
            Some(inner) => !inner.chars().any(|c| c == '\'' || c == '\\'),
            None => false,
        }
    };

    if is_integer() || is_keyword() || is_plain_string() {
        Ok(())
    } else {
        Err(EnforcementError::PolicyEvaluationFailed(format!(
            "Invalid SQL literal in row filter: {value:?}"
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::policy::{RowFilter, RowFilterOperator, StandardPolicies};
    use crate::roles::Role;
    use kimberlite_types::TenantId;

    /// Wraps `policy` in an enforcer with audit logging switched off.
    fn quiet(policy: AccessPolicy) -> PolicyEnforcer {
        PolicyEnforcer::new(policy).without_audit()
    }

    /// Builds an owned column list from string slices.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Asserts that exactly `name` and `email` were kept and `ssn` dropped.
    fn assert_name_and_email_only(kept: &[String]) {
        assert_eq!(kept.len(), 2);
        assert!(kept.iter().any(|c| c == "name"));
        assert!(kept.iter().any(|c| c == "email"));
        assert!(kept.iter().all(|c| c != "ssn"));
    }

    // The admin policy grants access to an arbitrary stream.
    #[test]
    fn test_enforce_stream_access_allowed() {
        let enforcer = quiet(StandardPolicies::admin());
        assert!(enforcer.enforce_stream_access("any_stream").is_ok());
    }

    // The auditor policy allows the audit log but not patient data.
    #[test]
    fn test_enforce_stream_access_denied() {
        let enforcer = quiet(StandardPolicies::auditor());
        assert!(enforcer.enforce_stream_access("audit_log").is_ok());
        assert!(enforcer.enforce_stream_access("patient_records").is_err());
    }

    // An explicit deny removes a column from a wildcard allow.
    #[test]
    fn test_filter_columns() {
        let enforcer = quiet(
            AccessPolicy::new(Role::Analyst)
                .allow_column("*")
                .deny_column("ssn"),
        );
        let kept = enforcer.filter_columns(&cols(&["name", "email", "ssn"]));
        assert_name_and_email_only(&kept);
    }

    // The standard user policy yields a single tenant predicate.
    #[test]
    fn test_generate_where_clause_single_filter() {
        let enforcer = quiet(StandardPolicies::user(TenantId::new(42)));
        let clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(clause, "tenant_id = 42");
    }

    // Several filters are joined with AND in declaration order.
    #[test]
    fn test_generate_where_clause_multiple_filters() {
        let tenant = RowFilter::new("tenant_id", RowFilterOperator::Eq, "42");
        let status = RowFilter::new("status", RowFilterOperator::Eq, "'active'");
        let enforcer = quiet(
            AccessPolicy::new(Role::User)
                .allow_stream("*")
                .allow_column("*")
                .with_row_filter(tenant)
                .with_row_filter(status),
        );
        let clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(clause, "tenant_id = 42 AND status = 'active'");
    }

    // A policy without row filters produces an empty predicate.
    #[test]
    fn test_generate_where_clause_no_filters() {
        let enforcer = quiet(StandardPolicies::admin());
        let clause = enforcer.generate_where_clause().unwrap();
        assert_eq!(clause, "");
    }

    // A filter value carrying extra SQL is refused.
    #[test]
    fn test_generate_where_clause_rejects_injection() {
        let hostile = RowFilter::new("tenant_id", RowFilterOperator::Eq, "1; DROP TABLE users");
        let enforcer = quiet(
            AccessPolicy::new(Role::User)
                .allow_stream("*")
                .allow_column("*")
                .with_row_filter(hostile),
        );
        assert!(enforcer.generate_where_clause().is_err());
    }

    // Stream check, column filtering and predicate generation together.
    #[test]
    fn test_enforce_query_full_flow() {
        let enforcer = quiet(
            AccessPolicy::new(Role::User)
                .with_tenant(TenantId::new(42))
                .allow_stream("patient_*")
                .allow_column("*")
                .deny_column("ssn")
                .with_row_filter(RowFilter::new("tenant_id", RowFilterOperator::Eq, "42")),
        );
        let requested = cols(&["name", "email", "ssn"]);
        let (kept, clause) = enforcer
            .enforce_query("patient_records", &requested)
            .unwrap();
        assert_name_and_email_only(&kept);
        assert_eq!(clause, "tenant_id = 42");
    }

    // A denied stream surfaces as AccessDenied naming the stream.
    #[test]
    fn test_enforce_query_stream_denied() {
        let enforcer = quiet(StandardPolicies::auditor());
        let outcome = enforcer.enforce_query("patient_records", &cols(&["name"]));
        assert!(outcome.is_err());
        if let Err(EnforcementError::AccessDenied { reason }) = outcome {
            assert!(reason.contains("patient_records"));
        } else {
            panic!("Expected AccessDenied error");
        }
    }

    // Filtering away every requested column is reported as AccessDenied.
    #[test]
    fn test_enforce_query_no_authorized_columns() {
        let enforcer = quiet(
            AccessPolicy::new(Role::User)
                .allow_stream("*")
                .allow_column("public_*"),
        );
        let requested = cols(&["private_ssn", "private_address"]);
        let outcome = enforcer.enforce_query("patient_records", &requested);
        assert!(outcome.is_err());
        if let Err(EnforcementError::AccessDenied { reason }) = outcome {
            assert!(reason.contains("No authorized columns"));
        } else {
            panic!("Expected AccessDenied error");
        }
    }
}