use serde_json::json;
use crate::db::where_clause::{WhereClause, WhereOperator};
/// Applies tenant (organisation) scoping to structured `WHERE` clauses and raw SQL strings.
///
/// An enforcer either carries an `org_id`, in which case its `enforce_*` methods add an
/// `org_id` equality filter to the query they are given, or it carries none, in which case
/// queries are passed through unchanged unless `require_tenant` turns that into an error.
#[derive(Debug, Clone)]
pub struct TenantEnforcer {
/// Tenant identifier used as the value of the injected `org_id` filter; `None` means the
/// request is not tenant-scoped.
org_id: Option<String>,
/// When `true`, the `enforce_*` methods return an error instead of passing the query
/// through if `org_id` is `None`.
require_tenant: bool,
}
impl TenantEnforcer {
    /// Error message returned when a tenant is mandatory but no `org_id` is available.
    const MISSING_TENANT_MESSAGE: &'static str = "Request must be tenant-scoped (missing org_id)";

    /// Clauses that may follow the `WHERE` clause of a statement; the tenant filter has to be
    /// inserted in front of the first one of these. Multi-word entries are separated by a
    /// single space here but match any run of whitespace in the SQL text.
    const TRAILING_CLAUSES: [&'static str; 6] =
        ["GROUP BY", "HAVING", "ORDER BY", "LIMIT", "OFFSET", "FETCH"];

    /// Creates an enforcer for which tenant scoping is optional.
    ///
    /// With `org_id == None` the `enforce_*` methods pass queries through unchanged.
    pub const fn new(org_id: Option<String>) -> Self {
        Self::with_requirement(org_id, false)
    }

    /// Creates an enforcer, choosing explicitly whether a missing `org_id` is an error.
    ///
    /// When `require_tenant` is `true` and `org_id` is `None`, every `enforce_*` call fails.
    pub const fn with_requirement(org_id: Option<String>, require_tenant: bool) -> Self {
        Self {
            org_id,
            require_tenant,
        }
    }

    /// Returns `true` when this enforcer carries an `org_id`.
    pub const fn is_tenant_scoped(&self) -> bool {
        self.org_id.is_some()
    }

    /// Returns the tenant identifier, if any.
    pub fn get_org_id(&self) -> Option<&str> {
        self.org_id.as_deref()
    }

    /// Combines `where_clause` with an `org_id = <tenant>` filter.
    ///
    /// * No `org_id` and tenant not required: the input clause is returned unchanged (cloned).
    /// * `org_id` present, no input clause: the result is the `org_id` field filter alone.
    /// * `org_id` present, input clause present: the result is
    ///   `And([input clause, org_id filter])`.
    ///
    /// # Errors
    ///
    /// Returns an error message when the enforcer requires a tenant but has no `org_id`.
    pub fn enforce_tenant_scope(
        &self,
        where_clause: Option<&WhereClause>,
    ) -> Result<Option<WhereClause>, String> {
        let Some(org_id) = self.resolve_org_id()? else {
            return Ok(where_clause.cloned());
        };

        let org_id_filter = WhereClause::Field {
            path: vec!["org_id".to_string()],
            operator: WhereOperator::Eq,
            value: json!(org_id),
        };

        Ok(Some(match where_clause {
            None => org_id_filter,
            Some(user_clause) => WhereClause::And(vec![user_clause.clone(), org_id_filter]),
        }))
    }

    /// Rewrites a raw SQL statement so that it is restricted to `org_id = '<tenant>'`.
    ///
    /// The filter is placed where a `WHERE` predicate belongs:
    ///
    /// * If the statement has a top-level `WHERE`, the filter is `AND`-ed onto the existing
    ///   condition, in front of any `GROUP BY` / `HAVING` / `ORDER BY` / `LIMIT` / `OFFSET` /
    ///   `FETCH` tail. The existing condition is parenthesised when it may contain an
    ///   operator that binds more loosely than `AND` (`OR`, `XOR`, `||`), so the tenant
    ///   filter always constrains the whole condition.
    /// * Otherwise a new `WHERE` clause is inserted in front of the first trailing clause, or
    ///   appended at the end when there is none.
    ///
    /// Keywords are matched ASCII case-insensitively, on word boundaries, and only outside
    /// quoted literals/identifiers and outside parentheses, so sub-queries, string contents
    /// and identifiers such as `somewhere` are not mistaken for clauses. Trailing whitespace
    /// and a single trailing `;` are detached and the `;` is re-appended after the filter.
    ///
    /// Single quotes in the tenant id are doubled before it is embedded in the literal.
    ///
    /// This is a textual heuristic, not an SQL parser: SQL comments, backslash-escaped quotes
    /// and set operations such as `UNION` are not understood.
    // NOTE(review): quote doubling is only sufficient where backslash is not an escape
    // character inside string literals; prefer bound parameters where the caller can use them.
    ///
    /// # Errors
    ///
    /// Returns an error message when the enforcer requires a tenant but has no `org_id`.
    pub fn enforce_tenant_scope_sql(&self, sql: &str) -> Result<String, String> {
        let Some(org_id) = self.resolve_org_id()? else {
            return Ok(sql.to_string());
        };
        let predicate = format!("org_id = '{}'", org_id.replace('\'', "''"));

        // Detach a statement terminator so the filter is never emitted after the `;`.
        let trimmed = sql.trim_end();
        let (statement, terminator) = match trimmed.strip_suffix(';') {
            Some(rest) => (rest.trim_end(), ";"),
            None => (trimmed, ""),
        };

        let where_end = Self::find_top_level_keyword(statement, 0, &["WHERE"])
            .map(|start| start + "WHERE".len());
        // Trailing clauses can only start after `WHERE` (or after `FROM` when there is no
        // `WHERE`); starting the search there keeps select-list identifiers out of it.
        let search_from = where_end
            .or_else(|| {
                Self::find_top_level_keyword(statement, 0, &["FROM"])
                    .map(|start| start + "FROM".len())
            })
            .unwrap_or(0);
        let tail_start =
            Self::find_top_level_keyword(statement, search_from, &Self::TRAILING_CLAUSES)
                .unwrap_or(statement.len());

        let (head, tail) = statement.split_at(tail_start);
        let head = head.trim_end();

        let mut scoped = match where_end {
            Some(condition_start) => {
                // `head` still ends at or after the `WHERE` keyword, so this split is in range.
                let (up_to_where, condition) = head.split_at(condition_start);
                if Self::may_bind_looser_than_and(condition) {
                    format!("{up_to_where} ({}) AND {predicate}", condition.trim())
                } else {
                    format!("{head} AND {predicate}")
                }
            }
            None => format!("{head} WHERE {predicate}"),
        };
        if !tail.is_empty() {
            scoped.push(' ');
            scoped.push_str(tail);
        }
        scoped.push_str(terminator);
        Ok(scoped)
    }

    /// Returns the tenant id to enforce, or an error when one is required but missing.
    fn resolve_org_id(&self) -> Result<Option<&str>, String> {
        match self.org_id.as_deref() {
            None if self.require_tenant => Err(Self::MISSING_TENANT_MESSAGE.to_string()),
            org_id => Ok(org_id),
        }
    }

    /// Reports whether `condition` may contain `OR`, `XOR` or `||`.
    ///
    /// Deliberately over-approximates (plain substring test): a false positive only adds a
    /// redundant pair of parentheses, a false negative could let rows of other tenants match.
    fn may_bind_looser_than_and(condition: &str) -> bool {
        condition.contains("||") || condition.to_ascii_uppercase().contains("OR")
    }

    /// Finds the byte offset of the first keyword from `keywords` that starts at or after
    /// `from`, outside quotes (`'`, `"`, `` ` ``) and outside parentheses.
    ///
    /// The scan always starts at offset 0 so that quote and parenthesis state is correct by
    /// the time `from` is reached. Only ASCII bytes are inspected, so every returned offset
    /// is a valid `char` boundary of `sql`.
    fn find_top_level_keyword(sql: &str, from: usize, keywords: &[&str]) -> Option<usize> {
        let bytes = sql.as_bytes();
        let mut depth = 0_usize;
        let mut open_quote: Option<u8> = None;
        let mut prev_is_word = false;

        for (index, &byte) in bytes.iter().enumerate() {
            let at_word_start = !prev_is_word;
            prev_is_word = Self::is_word_byte(byte);

            if let Some(quote) = open_quote {
                // A doubled quote inside a literal closes and immediately reopens it.
                if byte == quote {
                    open_quote = None;
                }
                continue;
            }

            match byte {
                b'\'' | b'"' | b'`' => open_quote = Some(byte),
                b'(' => depth += 1,
                b')' => depth = depth.saturating_sub(1),
                _ => {
                    let candidate = index >= from && depth == 0 && at_word_start;
                    if candidate
                        && keywords
                            .iter()
                            .any(|keyword| Self::keyword_at(bytes, index, keyword))
                    {
                        return Some(index);
                    }
                }
            }
        }
        None
    }

    /// Checks whether `keyword` occurs at `start`, ASCII case-insensitively.
    ///
    /// The words of a multi-word keyword must be separated by at least one whitespace byte,
    /// and the match must not be followed by another identifier byte.
    fn keyword_at(bytes: &[u8], start: usize, keyword: &str) -> bool {
        let mut pos = start;
        for (word_index, word) in keyword.split(' ').enumerate() {
            if word_index > 0 {
                let gap = bytes
                    .iter()
                    .skip(pos)
                    .take_while(|byte| byte.is_ascii_whitespace())
                    .count();
                if gap == 0 {
                    return false;
                }
                pos += gap;
            }
            let end = pos + word.len();
            let word_matches = matches!(
                bytes.get(pos..end),
                Some(candidate) if candidate.eq_ignore_ascii_case(word.as_bytes())
            );
            if !word_matches {
                return false;
            }
            pos = end;
        }
        !matches!(bytes.get(pos), Some(&next) if Self::is_word_byte(next))
    }

    /// Returns `true` for bytes that can belong to an identifier or keyword; non-ASCII bytes
    /// are treated as identifier bytes so they never count as a word boundary.
    const fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_' || !byte.is_ascii()
    }
}
#[cfg(test)]
mod tests {
    // Carried over from the original module: `unwrap` is permitted inside test code.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Tenant identifier shared by most scenarios.
    const ORG: &str = "org-123";

    /// Builds an enforcer bound to `org_id` for which tenant scoping is optional.
    fn scoped_to(org_id: &str) -> TenantEnforcer {
        TenantEnforcer::new(Some(org_id.to_string()))
    }

    /// Builds the user-supplied `status = "active"` clause used by the structured tests.
    fn active_status_clause() -> WhereClause {
        WhereClause::Field {
            path: vec!["status".to_string()],
            operator: WhereOperator::Eq,
            value: json!("active"),
        }
    }

    /// An enforcer built with an id reports itself as scoped and exposes that id.
    #[test]
    fn test_tenant_enforcer_with_org_id() {
        let enforcer = scoped_to(ORG);
        assert!(enforcer.is_tenant_scoped());
        assert_eq!(enforcer.get_org_id(), Some("org-123"));
    }

    /// An enforcer built without an id is unscoped and exposes no id.
    #[test]
    fn test_tenant_enforcer_without_org_id() {
        let enforcer = TenantEnforcer::new(None);
        assert!(!enforcer.is_tenant_scoped());
        assert_eq!(enforcer.get_org_id(), None);
    }

    /// With no caller clause the result is the bare `org_id` equality filter.
    #[test]
    fn test_enforce_tenant_scope_with_no_where_clause() {
        let enforced = match scoped_to(ORG).enforce_tenant_scope(None) {
            Ok(clause) => clause,
            Err(e) => panic!("expected Ok for enforce_tenant_scope: {e}"),
        };
        assert!(enforced.is_some());

        match enforced {
            Some(WhereClause::Field {
                path,
                operator,
                value,
            }) => {
                assert_eq!(path, vec!["org_id".to_string()]);
                assert_eq!(operator, WhereOperator::Eq);
                assert_eq!(value, json!("org-123"));
            }
            _ => panic!("Expected Field clause"),
        }
    }

    /// A caller clause is combined with the tenant filter into a two-element `And`.
    #[test]
    fn test_enforce_tenant_scope_with_existing_where_clause() {
        let user_clause = active_status_clause();
        let enforced = match scoped_to(ORG).enforce_tenant_scope(Some(&user_clause)) {
            Ok(clause) => clause,
            Err(e) => panic!("expected Ok for enforce_tenant_scope with clause: {e}"),
        };
        assert!(enforced.is_some());

        let Some(WhereClause::And(clauses)) = enforced else {
            panic!("Expected And clause");
        };
        assert_eq!(clauses.len(), 2);
    }

    /// SQL without a `WHERE` clause gains one holding the tenant filter.
    #[test]
    fn test_enforce_tenant_scope_sql_without_where() {
        let enforced = match scoped_to(ORG).enforce_tenant_scope_sql("SELECT * FROM users") {
            Ok(sql) => sql,
            Err(e) => panic!("expected Ok for SQL without WHERE: {e}"),
        };
        assert!(enforced.contains("WHERE org_id = 'org-123'"));
    }

    /// SQL with a `WHERE` clause keeps its condition and gets the filter `AND`-ed on.
    #[test]
    fn test_enforce_tenant_scope_sql_with_where() {
        let query = "SELECT * FROM users WHERE status = 'active'";
        let enforced = match scoped_to(ORG).enforce_tenant_scope_sql(query) {
            Ok(sql) => sql,
            Err(e) => panic!("expected Ok for SQL with WHERE: {e}"),
        };
        assert!(enforced.contains("WHERE status = 'active'"));
        assert!(enforced.contains("AND org_id = 'org-123'"));
    }

    /// SQL with `GROUP BY` but no `WHERE` gets the filter and keeps its grouping.
    #[test]
    fn test_enforce_tenant_scope_sql_with_group_by() {
        let query = "SELECT status, COUNT(*) as count FROM users GROUP BY status";
        let enforced = match scoped_to(ORG).enforce_tenant_scope_sql(query) {
            Ok(sql) => sql,
            Err(e) => panic!("expected Ok for SQL with GROUP BY: {e}"),
        };
        assert!(enforced.contains("WHERE org_id = 'org-123'"));
        assert!(enforced.contains("GROUP BY"));
    }

    /// Without an id (and no requirement) the caller clause comes back as it was.
    #[test]
    fn test_enforce_tenant_scope_without_org_id() {
        let user_clause = active_status_clause();
        let enforced = match TenantEnforcer::new(None).enforce_tenant_scope(Some(&user_clause)) {
            Ok(clause) => clause,
            Err(e) => panic!("expected Ok for enforce_tenant_scope without org_id: {e}"),
        };
        assert!(matches!(enforced, Some(WhereClause::Field { .. })));
    }

    /// Single quotes inside the tenant id are doubled so it stays inside the literal.
    #[test]
    fn test_enforce_tenant_scope_sql_injection_prevention() {
        let hostile = scoped_to("'; DROP TABLE users; --");
        let enforced = match hostile.enforce_tenant_scope_sql("SELECT * FROM users") {
            Ok(sql) => sql,
            Err(e) => panic!("expected Ok for SQL injection test: {e}"),
        };
        assert!(
            enforced.contains("''"),
            "Single quotes in org_id must be escaped (doubled)"
        );
        assert!(
            enforced.contains("WHERE org_id = '''"),
            "The escaped single quote should keep the value inside the string literal"
        );
    }

    /// A mandatory tenant with no id is rejected with the documented message.
    #[test]
    fn test_require_tenant_fails_without_org_id() {
        let err = TenantEnforcer::with_requirement(None, true)
            .enforce_tenant_scope(None)
            .expect_err("expected Err when tenant required but org_id absent");
        assert_eq!(err, "Request must be tenant-scoped (missing org_id)");
    }

    /// A mandatory tenant with an id present is accepted.
    #[test]
    fn test_require_tenant_succeeds_with_org_id() {
        let enforcer = TenantEnforcer::with_requirement(Some(ORG.to_string()), true);
        if let Err(e) = enforcer.enforce_tenant_scope(None) {
            panic!("expected Ok for enforce_tenant_scope with org_id present: {e}");
        }
    }
}