use crate::{
AccessMode, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbBackend, DbErr,
ExecResult, IsolationLevel, QueryResult, Statement, StatementBuilder, TransactionError,
TransactionSession, TransactionTrait,
};
use crate::{
TransactionOptions,
rbac::{
PermissionRequest, RbacEngine, RbacError, RbacPermissionsByResources,
RbacResourcesAndPermissions, RbacRoleHierarchyList, RbacRolesAndRanks,
RbacUserRolePermissions, ResourceRequest,
entity::{role::RoleId, user::UserId},
},
};
use std::{
pin::Pin,
sync::{Arc, RwLock},
};
use tracing::instrument;
/// A database connection that authorizes every statement for `user_id` against the
/// RBAC engine before forwarding it to the wrapped connection.
///
/// Raw SQL entry points of its `ConnectionTrait` implementation are refused with
/// `DbErr::RbacError`.
#[derive(Debug, Clone)]
#[cfg_attr(docsrs, doc(cfg(feature = "rbac")))]
pub struct RestrictedConnection {
    /// The user on whose behalf every statement is authorized.
    pub(crate) user_id: UserId,
    /// The unrestricted connection that executes statements once they are authorized.
    /// Its `rbac` field supplies the engine used for the checks.
    pub(crate) conn: DatabaseConnection,
}
/// A transaction opened from a [`RestrictedConnection`] (or nested inside another
/// `RestrictedTransaction`) that keeps authorizing every statement for `user_id`.
#[derive(Debug)]
pub struct RestrictedTransaction {
    /// The user on whose behalf every statement is authorized.
    user_id: UserId,
    /// The underlying transaction that executes authorized statements.
    conn: DatabaseTransaction,
    /// Handle to the RBAC engine, cloned from whatever began this transaction.
    rbac: RbacEngineMount,
}
/// A shared, replaceable slot holding the [`RbacEngine`].
///
/// Clones share the same slot through the `Arc`, so an engine installed with
/// `replace` becomes visible to every holder of a clone. The `Default` value holds
/// no engine.
#[derive(Debug, Default, Clone)]
pub(crate) struct RbacEngineMount {
    inner: Arc<RwLock<Option<RbacEngine>>>,
}
impl ConnectionTrait for RestrictedConnection {
fn get_database_backend(&self) -> DbBackend {
self.conn.get_database_backend()
}
fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
Err(DbErr::RbacError(format!(
"Raw query is not supported: {stmt}"
)))
}
fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
self.user_can_run(stmt)?;
self.conn.execute(stmt)
}
fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
Err(DbErr::RbacError(format!(
"Raw query is not supported: {sql}"
)))
}
fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
Err(DbErr::RbacError(format!(
"Raw query is not supported: {stmt}"
)))
}
fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
self.user_can_run(stmt)?;
self.conn.query_one(stmt)
}
fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
Err(DbErr::RbacError(format!(
"Raw query is not supported: {stmt}"
)))
}
fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
self.user_can_run(stmt)?;
self.conn.query_all(stmt)
}
}
/// Same policy as the restricted connection: structured statements are authorized
/// for the transaction's user first, raw SQL is always rejected.
impl ConnectionTrait for RestrictedTransaction {
    fn get_database_backend(&self) -> DbBackend {
        self.conn.get_database_backend()
    }

    fn execute_raw(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
        let message = format!("Raw query is not supported: {stmt}");
        Err(DbErr::RbacError(message))
    }

    fn execute<S: StatementBuilder>(&self, stmt: &S) -> Result<ExecResult, DbErr> {
        // Only reach the wrapped transaction once the audit has passed.
        self.user_can_run(stmt)
            .and_then(|()| self.conn.execute(stmt))
    }

    fn execute_unprepared(&self, sql: &str) -> Result<ExecResult, DbErr> {
        let message = format!("Raw query is not supported: {sql}");
        Err(DbErr::RbacError(message))
    }

    fn query_one_raw(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
        let message = format!("Raw query is not supported: {stmt}");
        Err(DbErr::RbacError(message))
    }

    fn query_one<S: StatementBuilder>(&self, stmt: &S) -> Result<Option<QueryResult>, DbErr> {
        self.user_can_run(stmt)
            .and_then(|()| self.conn.query_one(stmt))
    }

    fn query_all_raw(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
        let message = format!("Raw query is not supported: {stmt}");
        Err(DbErr::RbacError(message))
    }

    fn query_all<S: StatementBuilder>(&self, stmt: &S) -> Result<Vec<QueryResult>, DbErr> {
        self.user_can_run(stmt)
            .and_then(|()| self.conn.query_all(stmt))
    }
}
impl RestrictedConnection {
pub fn user_id(&self) -> UserId {
self.user_id
}
pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
self.conn.rbac.user_can_run(self.user_id, stmt)
}
pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
where
P: Into<PermissionRequest>,
R: Into<ResourceRequest>,
{
self.conn.rbac.user_can(self.user_id, permission, resource)
}
pub fn current_user_role_permissions(&self) -> Result<RbacUserRolePermissions, DbErr> {
self.conn.rbac.user_role_permissions(self.user_id)
}
pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
self.conn.rbac.roles_and_ranks()
}
pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
self.conn.rbac.resources_and_permissions()
}
pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
self.conn.rbac.role_hierarchy_edges(role_id)
}
pub fn role_permissions_by_resources(
&self,
role_id: RoleId,
) -> Result<RbacPermissionsByResources, DbErr> {
self.conn.rbac.role_permissions_by_resources(role_id)
}
}
impl RestrictedTransaction {
    /// Returns the id of the user whose permissions govern this transaction.
    pub fn user_id(&self) -> UserId {
        self.user_id
    }

    /// Checks whether the current user may run `stmt` inside this transaction,
    /// without executing it.
    pub fn user_can_run<S: StatementBuilder>(&self, stmt: &S) -> Result<(), DbErr> {
        let Self { user_id, rbac, .. } = self;
        rbac.user_can_run(*user_id, stmt)
    }

    /// Asks the RBAC engine whether the current user holds `permission` on `resource`.
    pub fn user_can<P, R>(&self, permission: P, resource: R) -> Result<bool, DbErr>
    where
        P: Into<PermissionRequest>,
        R: Into<ResourceRequest>,
    {
        let Self { user_id, rbac, .. } = self;
        rbac.user_can(*user_id, permission, resource)
    }
}
/// Transactions opened from a [`RestrictedConnection`] carry the same user id and a
/// clone of the connection's RBAC engine mount, so authorization keeps applying
/// inside the transaction.
impl TransactionTrait for RestrictedConnection {
    type Transaction = RestrictedTransaction;

    #[instrument(level = "trace")]
    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
        self.conn.begin().map(|conn| RestrictedTransaction {
            user_id: self.user_id,
            conn,
            rbac: self.conn.rbac.clone(),
        })
    }

    #[instrument(level = "trace")]
    fn begin_with_config(
        &self,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<RestrictedTransaction, DbErr> {
        self.conn
            .begin_with_config(isolation_level, access_mode)
            .map(|conn| RestrictedTransaction {
                user_id: self.user_id,
                conn,
                rbac: self.conn.rbac.clone(),
            })
    }

    #[instrument(level = "trace")]
    fn begin_with_options(
        &self,
        options: TransactionOptions,
    ) -> Result<RestrictedTransaction, DbErr> {
        self.conn
            .begin_with_options(options)
            .map(|conn| RestrictedTransaction {
                user_id: self.user_id,
                conn,
                rbac: self.conn.rbac.clone(),
            })
    }

    /// Begins a transaction, runs `callback` in it, then commits on `Ok` or rolls
    /// back on `Err`.
    #[instrument(level = "trace", skip(callback))]
    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
        E: std::fmt::Display + std::fmt::Debug,
    {
        self.begin()
            .map_err(TransactionError::Connection)?
            .run(callback)
    }

    /// Like `transaction`, but begins with the given isolation level and access mode.
    #[instrument(level = "trace", skip(callback))]
    fn transaction_with_config<F, T, E>(
        &self,
        callback: F,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
        E: std::fmt::Display + std::fmt::Debug,
    {
        self.begin_with_config(isolation_level, access_mode)
            .map_err(TransactionError::Connection)?
            .run(callback)
    }
}
/// Nested transactions inherit the parent's user id and a clone of its RBAC engine
/// mount.
impl TransactionTrait for RestrictedTransaction {
    type Transaction = RestrictedTransaction;

    #[instrument(level = "trace")]
    fn begin(&self) -> Result<RestrictedTransaction, DbErr> {
        self.conn.begin().map(|conn| RestrictedTransaction {
            user_id: self.user_id,
            conn,
            rbac: self.rbac.clone(),
        })
    }

    #[instrument(level = "trace")]
    fn begin_with_config(
        &self,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<RestrictedTransaction, DbErr> {
        self.conn
            .begin_with_config(isolation_level, access_mode)
            .map(|conn| RestrictedTransaction {
                user_id: self.user_id,
                conn,
                rbac: self.rbac.clone(),
            })
    }

    #[instrument(level = "trace")]
    fn begin_with_options(
        &self,
        options: TransactionOptions,
    ) -> Result<RestrictedTransaction, DbErr> {
        self.conn
            .begin_with_options(options)
            .map(|conn| RestrictedTransaction {
                user_id: self.user_id,
                conn,
                rbac: self.rbac.clone(),
            })
    }

    /// Begins a nested transaction, runs `callback` in it, then commits on `Ok` or
    /// rolls back on `Err`.
    #[instrument(level = "trace", skip(callback))]
    fn transaction<F, T, E>(&self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
        E: std::fmt::Display + std::fmt::Debug,
    {
        self.begin()
            .map_err(TransactionError::Connection)?
            .run(callback)
    }

    /// Like `transaction`, but begins with the given isolation level and access mode.
    #[instrument(level = "trace", skip(callback))]
    fn transaction_with_config<F, T, E>(
        &self,
        callback: F,
        isolation_level: Option<IsolationLevel>,
        access_mode: Option<AccessMode>,
    ) -> Result<T, TransactionError<E>>
    where
        F: for<'c> FnOnce(&'c RestrictedTransaction) -> Result<T, E>,
        E: std::fmt::Display + std::fmt::Debug,
    {
        self.begin_with_config(isolation_level, access_mode)
            .map_err(TransactionError::Connection)?
            .run(callback)
    }
}
impl TransactionSession for RestrictedTransaction {
fn commit(self) -> Result<(), DbErr> {
self.commit()
}
fn rollback(self) -> Result<(), DbErr> {
self.rollback()
}
}
impl RestrictedTransaction {
    /// Runs `callback` inside this transaction. The transaction is committed when the
    /// callback returns `Ok` and rolled back when it returns `Err`.
    ///
    /// A failure to commit or roll back is reported as `TransactionError::Connection`
    /// and takes precedence over the callback's own error.
    #[instrument(level = "trace", skip(callback))]
    fn run<F, T, E>(self, callback: F) -> Result<T, TransactionError<E>>
    where
        F: for<'b> FnOnce(&'b RestrictedTransaction) -> Result<T, E>,
        E: std::fmt::Display + std::fmt::Debug,
    {
        let outcome = callback(&self);
        match outcome {
            Ok(value) => {
                self.commit().map_err(TransactionError::Connection)?;
                Ok(value)
            }
            Err(err) => {
                self.rollback().map_err(TransactionError::Connection)?;
                Err(TransactionError::Transaction(err))
            }
        }
    }

    /// Commits the wrapped transaction, consuming `self`.
    #[instrument(level = "trace")]
    pub fn commit(self) -> Result<(), DbErr> {
        let Self { conn, .. } = self;
        conn.commit()
    }

    /// Rolls back the wrapped transaction, consuming `self`.
    #[instrument(level = "trace")]
    pub fn rollback(self) -> Result<(), DbErr> {
        let Self { conn, .. } = self;
        conn.rollback()
    }
}
impl RbacEngineMount {
    /// Returns `true` if an [`RbacEngine`] has been installed with `replace`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned (a thread panicked while holding it).
    pub fn is_some(&self) -> bool {
        self.inner.read().expect("RBAC Engine died").is_some()
    }

    /// Installs `engine`, dropping any previously mounted engine. Every clone of this
    /// mount observes the new engine because they share the same `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned.
    pub fn replace(&self, engine: RbacEngine) {
        let mut inner = self.inner.write().expect("RBAC Engine died");
        *inner = Some(engine);
    }

    /// Runs `f` against the mounted engine while holding the read lock, so the engine
    /// cannot be swapped out mid-call.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::RbacError` if no engine has been mounted yet; otherwise
    /// whatever `f` returns.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned; that is treated as an unrecoverable bug.
    fn with_engine<T>(
        &self,
        f: impl FnOnce(&RbacEngine) -> Result<T, DbErr>,
    ) -> Result<T, DbErr> {
        let holder = self.inner.read().expect("RBAC Engine died");
        // A missing engine is a configuration problem the caller can handle, and every
        // public method already returns `Result`, so report it instead of panicking.
        // Queries are still refused (fail closed).
        let engine = holder
            .as_ref()
            .ok_or_else(|| DbErr::RbacError("RBAC Engine not set".to_owned()))?;
        f(engine)
    }

    /// Asks the engine whether `user_id` holds `permission` on `resource`.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::RbacError` if no engine is mounted or the engine fails.
    pub fn user_can<P, R>(&self, user_id: UserId, permission: P, resource: R) -> Result<bool, DbErr>
    where
        P: Into<PermissionRequest>,
        R: Into<ResourceRequest>,
    {
        // Convert before taking the lock to keep the critical section short.
        let permission = permission.into();
        let resource = resource.into();
        self.with_engine(|engine| {
            engine
                .user_can(user_id, permission, resource)
                .map_err(map_err)
        })
    }

    /// Audits `stmt` and checks each access request it produces for `user_id`.
    ///
    /// # Errors
    ///
    /// Returns `DbErr::AccessDenied` for the first request the user may not perform.
    /// Returns `DbErr::RbacError` if the statement cannot be audited, no engine is
    /// mounted, or the engine fails.
    pub fn user_can_run<S: StatementBuilder>(
        &self,
        user_id: UserId,
        stmt: &S,
    ) -> Result<(), DbErr> {
        let audit = stmt
            .audit()
            .map_err(|err| DbErr::RbacError(err.to_string()))?;
        self.with_engine(|engine| {
            for request in audit.requests {
                // The engine consumes the requests, so they are built on demand; the
                // second construction only happens on the (rare) denial path.
                let permission = || PermissionRequest {
                    action: request.access_type.as_str().to_owned(),
                };
                let resource = || ResourceRequest {
                    schema: request.schema_table.0.as_ref().map(|s| s.1.to_string()),
                    table: request.schema_table.1.to_string(),
                };
                let allowed = engine
                    .user_can(user_id, permission(), resource())
                    .map_err(map_err)?;
                if !allowed {
                    return Err(DbErr::AccessDenied {
                        permission: permission().action,
                        resource: resource().to_string(),
                    });
                }
            }
            Ok(())
        })
    }

    /// Returns the roles and permissions the engine reports for `user_id`.
    pub fn user_role_permissions(&self, user_id: UserId) -> Result<RbacUserRolePermissions, DbErr> {
        self.with_engine(|engine| {
            engine
                .get_user_role_permissions(user_id)
                .map_err(|err| DbErr::RbacError(err.to_string()))
        })
    }

    /// Returns the roles and their ranks as reported by the engine.
    pub fn roles_and_ranks(&self) -> Result<RbacRolesAndRanks, DbErr> {
        self.with_engine(|engine| {
            engine
                .get_roles_and_ranks()
                .map_err(|err| DbErr::RbacError(err.to_string()))
        })
    }

    /// Returns every resource together with its permissions.
    pub fn resources_and_permissions(&self) -> Result<RbacResourcesAndPermissions, DbErr> {
        self.with_engine(|engine| Ok(engine.list_resources_and_permissions()))
    }

    /// Returns the role-hierarchy edges the engine lists for `role_id`.
    pub fn role_hierarchy_edges(&self, role_id: RoleId) -> Result<RbacRoleHierarchyList, DbErr> {
        self.with_engine(|engine| Ok(engine.list_role_hierarchy_edges(role_id)))
    }

    /// Returns the permissions of `role_id`, grouped by resource.
    pub fn role_permissions_by_resources(
        &self,
        role_id: RoleId,
    ) -> Result<RbacPermissionsByResources, DbErr> {
        self.with_engine(|engine| {
            engine
                .list_role_permissions_by_resources(role_id)
                .map_err(|err| DbErr::RbacError(err.to_string()))
        })
    }
}
/// Wraps an engine-level [`RbacError`] into `DbErr::RbacError`, keeping only its
/// display text.
fn map_err(err: RbacError) -> DbErr {
    let message = err.to_string();
    DbErr::RbacError(message)
}