use std::collections::BTreeSet;
use crate::ast::Action;
use crate::rls::SuperAdminToken;
use super::ident::normalize_column_name;
/// Kind of table access an action needs.
///
/// Serialized in lowercase (`"read"`, `"create"`, `"update"`, `"delete"`). The derived
/// `Ord` follows declaration order, so reordering variants changes sort order.
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AccessOperation {
    /// Reading existing rows.
    Read,
    /// Inserting new rows.
    Create,
    /// Modifying existing rows.
    Update,
    /// Removing rows.
    Delete,
}
impl AccessOperation {
pub fn required_for_action(action: Action) -> Option<&'static [Self]> {
match action {
Action::Get
| Action::Cnt
| Action::Export
| Action::With
| Action::Search
| Action::Scroll => Some(&[Self::Read]),
Action::Add => Some(&[Self::Create]),
Action::Set | Action::Put | Action::Over => Some(&[Self::Update]),
Action::Upsert => Some(&[Self::Create, Self::Update]),
Action::Del => Some(&[Self::Delete]),
_ => None,
}
}
}
/// Identity and authorization attributes of the caller a policy is evaluated against.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct AccessContext {
    /// Identifier of the acting subject, if known.
    pub subject_id: Option<String>,
    /// Tenant the subject acts within, if any.
    pub tenant_id: Option<String>,
    /// Roles held by the subject.
    pub roles: BTreeSet<String>,
    /// Scopes granted to the subject.
    pub scopes: BTreeSet<String>,
    // Private so code outside this module cannot flip it directly; the visible way to
    // enable it is the token-gated `super_admin` constructor.
    bypass: bool,
}
impl AccessContext {
    /// Creates a context with no subject, tenant, roles, or scopes, and no bypass.
    #[must_use]
    pub fn anonymous() -> Self {
        Self {
            subject_id: None,
            tenant_id: None,
            roles: BTreeSet::new(),
            scopes: BTreeSet::new(),
            bypass: false,
        }
    }

    /// Creates a context for `subject_id`; every other field is as in
    /// [`AccessContext::anonymous`].
    #[must_use]
    pub fn subject(subject_id: impl Into<String>) -> Self {
        Self {
            subject_id: Some(subject_id.into()),
            ..Self::anonymous()
        }
    }

    /// Creates a context with the bypass flag set.
    ///
    /// A [`SuperAdminToken`] must be handed over by value; it is consumed but not stored.
    /// No subject, tenant, roles, or scopes are set.
    #[must_use]
    pub fn super_admin(_token: SuperAdminToken) -> Self {
        Self {
            bypass: true,
            ..Self::anonymous()
        }
    }

    /// Returns the context with its tenant set to `tenant_id`, replacing any previous one.
    #[must_use]
    pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
        self.tenant_id = Some(tenant_id.into());
        self
    }

    /// Returns the context with `role` added to its role set.
    #[must_use]
    pub fn with_role(mut self, role: impl Into<String>) -> Self {
        self.roles.insert(role.into());
        self
    }

    /// Returns the context with every item of `roles` added to its role set.
    #[must_use]
    pub fn with_roles<I, S>(mut self, roles: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.roles.extend(roles.into_iter().map(Into::into));
        self
    }

    /// Returns the context with `scope` added to its scope set.
    #[must_use]
    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.insert(scope.into());
        self
    }

    /// Returns the context with every item of `scopes` added to its scope set.
    #[must_use]
    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.scopes.extend(scopes.into_iter().map(Into::into));
        self
    }

    /// Returns the bypass flag, which [`AccessContext::super_admin`] sets to `true`.
    #[must_use]
    pub fn bypasses_access(&self) -> bool {
        self.bypass
    }

    /// Returns `true` if `required` is empty or shares at least one role with this context.
    pub(super) fn has_any_role(&self, required: &BTreeSet<String>) -> bool {
        // An empty requirement means "no role restriction", not "impossible to satisfy".
        required.is_empty() || required.iter().any(|role| self.roles.contains(role))
    }

    /// Returns `true` if this context holds every scope in `required`; vacuously true
    /// when `required` is empty.
    pub(super) fn has_all_scopes(&self, required: &BTreeSet<String>) -> bool {
        required.is_subset(&self.scopes)
    }
}
impl Default for AccessContext {
fn default() -> Self {
Self::anonymous()
}
}
/// Outcome of an access check; serialized as `"allow"` or `"deny"`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AccessDecision {
    /// Access is granted.
    Allow,
    /// Access is refused.
    Deny,
}
/// Which columns of a table an operation may touch.
///
/// Serialized with snake_case variant names (`any`, `deny_all`, `only`, `except`).
/// Defaults to [`ColumnRule::Any`].
#[derive(Debug, Clone, Default)]
#[derive(PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ColumnRule {
    /// Every column is permitted.
    #[default]
    Any,
    /// No column is permitted.
    DenyAll,
    /// Only the listed columns are permitted.
    Only(BTreeSet<String>),
    /// Every column except the listed ones is permitted.
    Except(BTreeSet<String>),
}
impl ColumnRule {
pub fn only<I, S>(columns: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
Self::Only(columns.into_iter().map(normalize_column_name).collect())
}
pub fn except<I, S>(columns: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
Self::Except(columns.into_iter().map(normalize_column_name).collect())
}
pub fn is_restrictive(&self) -> bool {
!matches!(self, Self::Any)
}
pub fn allows(&self, column: &str) -> bool {
let normalized = normalize_column_name(column);
match self {
Self::Any => true,
Self::DenyAll => false,
Self::Only(columns) => columns.contains(&normalized),
Self::Except(columns) => !columns.contains(&normalized),
}
}
}
/// Access policy for a single table: which operations are allowed or denied, column
/// rules per direction, and role/scope requirements.
///
/// Every field falls back to its type's default when missing from serialized input.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct TableAccessPolicy {
    /// Operations explicitly allowed.
    #[serde(default)]
    pub operations: BTreeSet<AccessOperation>,
    /// Operations explicitly denied; these win over `operations`.
    #[serde(default)]
    pub denied_operations: BTreeSet<AccessOperation>,
    /// Rule for columns that may be read.
    #[serde(default)]
    pub read_columns: ColumnRule,
    /// Rule for columns that may be written.
    #[serde(default)]
    pub write_columns: ColumnRule,
    /// Rule for columns that may be returned.
    #[serde(default)]
    pub returning_columns: ColumnRule,
    /// Roles of which the caller must hold at least one (empty means no requirement).
    #[serde(default)]
    pub require_any_role: BTreeSet<String>,
    /// Scopes the caller must hold all of.
    #[serde(default)]
    pub require_scopes: BTreeSet<String>,
}
impl TableAccessPolicy {
    /// Creates an empty policy; identical to [`TableAccessPolicy::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the policy with `operations` added to the allowed set.
    ///
    /// An operation added here is still refused if it is also in the denied set.
    #[must_use]
    pub fn allow_operations<I>(mut self, operations: I) -> Self
    where
        I: IntoIterator<Item = AccessOperation>,
    {
        self.operations.extend(operations);
        self
    }

    /// Returns the policy with `operations` added to the denied set, which overrides
    /// the allowed set.
    #[must_use]
    pub fn deny_operations<I>(mut self, operations: I) -> Self
    where
        I: IntoIterator<Item = AccessOperation>,
    {
        self.denied_operations.extend(operations);
        self
    }

    /// Returns the policy with its read-column rule replaced by `rule`.
    #[must_use]
    pub fn read_columns(mut self, rule: ColumnRule) -> Self {
        self.read_columns = rule;
        self
    }

    /// Returns the policy with its write-column rule replaced by `rule`.
    #[must_use]
    pub fn write_columns(mut self, rule: ColumnRule) -> Self {
        self.write_columns = rule;
        self
    }

    /// Returns the policy with its returning-column rule replaced by `rule`.
    #[must_use]
    pub fn returning_columns(mut self, rule: ColumnRule) -> Self {
        self.returning_columns = rule;
        self
    }

    /// Returns the policy with `roles` added to `require_any_role`; repeated calls
    /// accumulate.
    #[must_use]
    pub fn require_any_role<I, S>(mut self, roles: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.require_any_role
            .extend(roles.into_iter().map(Into::into));
        self
    }

    /// Returns the policy with `scopes` added to `require_scopes`; repeated calls
    /// accumulate.
    #[must_use]
    pub fn require_scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.require_scopes
            .extend(scopes.into_iter().map(Into::into));
        self
    }

    /// Returns `true` when `operation` is in the allowed set and not in the denied set.
    pub(super) fn allows_operation(&self, operation: AccessOperation) -> bool {
        // Deny takes precedence: an operation listed in both sets is refused.
        self.operations.contains(&operation) && !self.denied_operations.contains(&operation)
    }
}
impl Default for TableAccessPolicy {
fn default() -> Self {
Self {
operations: BTreeSet::new(),
denied_operations: BTreeSet::new(),
read_columns: ColumnRule::Any,
write_columns: ColumnRule::Any,
returning_columns: ColumnRule::Any,
require_any_role: BTreeSet::new(),
require_scopes: BTreeSet::new(),
}
}
}