use sea_orm::{
ColumnTrait, EntityTrait, ModelTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect,
Related, sea_query::Expr,
};
use std::sync::Arc;
use crate::secure::cond::build_scope_condition;
use crate::secure::error::ScopeError;
use crate::secure::{AccessScope, DBRunner, DBRunnerInternal, ScopableEntity, SeaOrmRunner};
#[derive(Debug, Clone, Copy)]
pub struct Unscoped;
#[derive(Debug, Clone)]
pub struct Scoped {
scope: Arc<AccessScope>,
}
#[must_use]
#[derive(Clone, Debug)]
pub struct SecureSelect<E: EntityTrait, S> {
pub(crate) inner: sea_orm::Select<E>,
pub(crate) state: S,
}
#[must_use]
#[derive(Clone, Debug)]
pub struct SecureSelectTwo<E: EntityTrait, F: EntityTrait, S> {
pub(crate) inner: sea_orm::SelectTwo<E, F>,
pub(crate) state: S,
}
#[must_use]
#[derive(Clone, Debug)]
pub struct SecureSelectTwoMany<E: EntityTrait, F: EntityTrait, S> {
pub(crate) inner: sea_orm::SelectTwoMany<E, F>,
pub(crate) state: S,
}
pub trait SecureEntityExt<E: EntityTrait>: Sized {
fn secure(self) -> SecureSelect<E, Unscoped>;
}
impl<E> SecureEntityExt<E> for sea_orm::Select<E>
where
E: EntityTrait,
{
fn secure(self) -> SecureSelect<E, Unscoped> {
SecureSelect {
inner: self,
state: Unscoped,
}
}
}
impl<E> SecureSelect<E, Unscoped>
where
E: ScopableEntity + EntityTrait,
E::Column: ColumnTrait + Copy,
{
pub fn scope_with(self, scope: &AccessScope) -> SecureSelect<E, Scoped> {
let cond = build_scope_condition::<E>(scope);
SecureSelect {
inner: self.inner.filter(cond),
state: Scoped {
scope: Arc::new(scope.clone()),
},
}
}
pub fn scope_with_arc(self, scope: Arc<AccessScope>) -> SecureSelect<E, Scoped> {
let cond = build_scope_condition::<E>(&scope);
SecureSelect {
inner: self.inner.filter(cond),
state: Scoped { scope },
}
}
}
impl<E> SecureSelect<E, Scoped>
where
E: EntityTrait,
{
#[allow(clippy::disallowed_methods)]
pub async fn all(self, runner: &impl DBRunner) -> Result<Vec<E::Model>, ScopeError> {
match DBRunnerInternal::as_seaorm(runner) {
SeaOrmRunner::Conn(db) => Ok(self.inner.all(db).await?),
SeaOrmRunner::Tx(tx) => Ok(self.inner.all(tx).await?),
}
}
#[allow(clippy::disallowed_methods)]
pub async fn one(self, runner: &impl DBRunner) -> Result<Option<E::Model>, ScopeError> {
match DBRunnerInternal::as_seaorm(runner) {
SeaOrmRunner::Conn(db) => Ok(self.inner.one(db).await?),
SeaOrmRunner::Tx(tx) => Ok(self.inner.one(tx).await?),
}
}
#[allow(clippy::disallowed_methods)]
pub async fn count(self, runner: &impl DBRunner) -> Result<u64, ScopeError>
where
E::Model: sea_orm::FromQueryResult + Send + Sync,
{
match DBRunnerInternal::as_seaorm(runner) {
SeaOrmRunner::Conn(db) => Ok(self.inner.count(db).await?),
SeaOrmRunner::Tx(tx) => Ok(self.inner.count(tx).await?),
}
}
pub fn and_id(self, id: uuid::Uuid) -> Result<Self, ScopeError>
where
E: ScopableEntity,
E::Column: ColumnTrait + Copy,
{
let resource_col = E::resource_col().ok_or(ScopeError::Invalid(
"Entity must have a resource_col to use and_id()",
))?;
let cond = sea_orm::Condition::all().add(Expr::col(resource_col).eq(id));
Ok(self.filter(cond))
}
}
impl<E> SecureSelect<E, Scoped>
where
E: EntityTrait,
{
pub fn filter(mut self, filter: sea_orm::Condition) -> Self {
self.inner = QueryFilter::filter(self.inner, filter);
self
}
pub fn order_by<C>(mut self, col: C, order: sea_orm::Order) -> Self
where
C: sea_orm::IntoSimpleExpr,
{
self.inner = QueryOrder::order_by(self.inner, col, order);
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.inner = QuerySelect::limit(self.inner, limit);
self
}
pub fn offset(mut self, offset: u64) -> Self {
self.inner = QuerySelect::offset(self.inner, offset);
self
}
pub fn and_scope_for<J>(mut self, scope: &AccessScope) -> Self
where
J: ScopableEntity + EntityTrait,
J::Column: ColumnTrait + Copy,
{
let cond = build_scope_condition::<J>(scope);
self.inner = QueryFilter::filter(self.inner, cond);
self
}
pub fn scope_via_exists<J>(mut self, scope: &AccessScope) -> Self
where
J: ScopableEntity + EntityTrait,
J::Column: ColumnTrait + Copy,
{
use sea_orm::sea_query::Query;
let cond = build_scope_condition::<J>(scope);
let mut sub = Query::select();
sub.expr(Expr::value(1)).from(J::default()).cond_where(cond);
self.inner =
QueryFilter::filter(self.inner, sea_orm::Condition::all().add(Expr::exists(sub)));
self
}
#[allow(clippy::disallowed_methods)]
pub async fn project_all<T, C, F>(self, runner: &C, project: F) -> Result<Vec<T>, ScopeError>
where
T: sea_orm::FromQueryResult + Send + Sync,
C: DBRunner,
F: FnOnce(sea_orm::Select<E>) -> sea_orm::Selector<sea_orm::SelectModel<T>>,
{
let selector = project(self.inner);
match DBRunnerInternal::as_seaorm(runner) {
SeaOrmRunner::Conn(db) => Ok(selector.all(db).await?),
SeaOrmRunner::Tx(tx) => Ok(selector.all(tx).await?),
}
}
#[must_use]
pub fn into_inner(self) -> sea_orm::Select<E> {
self.inner
}
}
fn apply_related_scope<R>(scope: &AccessScope) -> Option<sea_orm::Condition>
where
R: ScopableEntity + EntityTrait,
R::Column: ColumnTrait + Copy,
{
if scope.is_unconstrained() {
return None;
}
Some(build_scope_condition::<R>(scope))
}
impl<E> SecureSelect<E, Scoped>
where
E: EntityTrait,
{
#[must_use]
pub fn scope(&self) -> &AccessScope {
&self.state.scope
}
#[must_use]
pub fn scope_arc(&self) -> Arc<AccessScope> {
Arc::clone(&self.state.scope)
}
pub fn find_also_related<R>(self, r: R) -> SecureSelectTwo<E, R, Scoped>
where
R: ScopableEntity + EntityTrait,
R::Column: ColumnTrait + Copy,
E: Related<R>,
{
let select_two = self.inner.find_also_related(r);
let select_two = if let Some(cond) = apply_related_scope::<R>(&self.state.scope) {
QueryFilter::filter(select_two, cond)
} else {
select_two
};
SecureSelectTwo {
inner: select_two,
state: self.state,
}
}
pub fn find_with_related<R>(self, r: R) -> SecureSelectTwoMany<E, R, Scoped>
where
R: ScopableEntity + EntityTrait,
R::Column: ColumnTrait + Copy,
E: Related<R>,
{
let select_two_many = self.inner.find_with_related(r);
let select_two_many = if let Some(cond) = apply_related_scope::<R>(&self.state.scope) {
QueryFilter::filter(select_two_many, cond)
} else {
select_two_many
};
SecureSelectTwoMany {
inner: select_two_many,
state: self.state,
}
}
}
impl<E, F> SecureSelectTwo<E, F, Scoped>
where
E: EntityTrait,
F: EntityTrait,
{
#[must_use]
pub fn scope(&self) -> &AccessScope {
&self.state.scope
}
#[must_use]
pub fn scope_arc(&self) -> Arc<AccessScope> {
Arc::clone(&self.state.scope)
}
#[allow(clippy::disallowed_methods)]
pub async fn all(
self,
runner: &impl DBRunner,
) -> Result<Vec<(E::Model, Option<F::Model>)>, ScopeError> {
match DBRunnerInternal::as_seaorm(runner) {
SeaOrmRunner::Conn(db) => Ok(self.inner.all(db).await?),
SeaOrmRunner::Tx(tx) => Ok(self.inner.all(tx).await?),
}
}
#[allow(clippy::disallowed_methods)]
pub async fn one(
self,
runner: &impl DBRunner,
) -> Result<Option<(E::Model, Option<F::Model>)>, ScopeError> {
match DBRunnerInternal::as_seaorm(runner) {
SeaOrmRunner::Conn(db) => Ok(self.inner.one(db).await?),
SeaOrmRunner::Tx(tx) => Ok(self.inner.one(tx).await?),
}
}
pub fn filter(mut self, filter: sea_orm::Condition) -> Self {
self.inner = QueryFilter::filter(self.inner, filter);
self
}
pub fn order_by<C>(mut self, col: C, order: sea_orm::Order) -> Self
where
C: sea_orm::IntoSimpleExpr,
{
self.inner = QueryOrder::order_by(self.inner, col, order);
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.inner = QuerySelect::limit(self.inner, limit);
self
}
#[must_use]
pub fn into_inner(self) -> sea_orm::SelectTwo<E, F> {
self.inner
}
}
impl<E, F> SecureSelectTwoMany<E, F, Scoped>
where
E: EntityTrait,
F: EntityTrait,
{
#[must_use]
pub fn scope(&self) -> &AccessScope {
&self.state.scope
}
#[must_use]
pub fn scope_arc(&self) -> Arc<AccessScope> {
Arc::clone(&self.state.scope)
}
#[allow(clippy::disallowed_methods)]
pub async fn all(
self,
runner: &impl DBRunner,
) -> Result<Vec<(E::Model, Vec<F::Model>)>, ScopeError> {
match DBRunnerInternal::as_seaorm(runner) {
SeaOrmRunner::Conn(db) => Ok(self.inner.all(db).await?),
SeaOrmRunner::Tx(tx) => Ok(self.inner.all(tx).await?),
}
}
pub fn filter(mut self, filter: sea_orm::Condition) -> Self {
self.inner = QueryFilter::filter(self.inner, filter);
self
}
pub fn order_by<C>(mut self, col: C, order: sea_orm::Order) -> Self
where
C: sea_orm::IntoSimpleExpr,
{
self.inner = QueryOrder::order_by(self.inner, col, order);
self
}
#[must_use]
pub fn into_inner(self) -> sea_orm::SelectTwoMany<E, F> {
self.inner
}
}
pub trait SecureFindRelatedExt: ModelTrait {
fn secure_find_related<R>(&self, r: R, scope: &AccessScope) -> SecureSelect<R, Scoped>
where
R: ScopableEntity + EntityTrait,
R::Column: ColumnTrait + Copy,
Self::Entity: Related<R>;
}
impl<M> SecureFindRelatedExt for M
where
M: ModelTrait,
{
fn secure_find_related<R>(&self, r: R, scope: &AccessScope) -> SecureSelect<R, Scoped>
where
R: ScopableEntity + EntityTrait,
R::Column: ColumnTrait + Copy,
Self::Entity: Related<R>,
{
let select = self.find_related(r);
select.secure().scope_with(scope)
}
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use modkit_security::pep_properties;
#[test]
fn test_typestate_markers_exist() {
let unscoped = Unscoped;
assert!(std::mem::size_of_val(&unscoped) == 0);
let scope = AccessScope::default();
let scoped = Scoped {
scope: Arc::new(scope),
};
assert!(!scoped.scope.has_property(pep_properties::OWNER_TENANT_ID)); }
#[test]
fn test_scoped_state_holds_scope() {
let tenant_id = uuid::Uuid::new_v4();
let scope = AccessScope::for_tenants(vec![tenant_id]);
let scoped = Scoped {
scope: Arc::new(scope),
};
assert!(scoped.scope.has_property(pep_properties::OWNER_TENANT_ID));
assert_eq!(
scoped
.scope
.all_values_for(pep_properties::OWNER_TENANT_ID)
.len(),
1
);
assert!(
scoped
.scope
.all_uuid_values_for(pep_properties::OWNER_TENANT_ID)
.contains(&tenant_id)
);
}
#[test]
fn test_scoped_state_is_cloneable() {
let scope = AccessScope::for_tenants(vec![uuid::Uuid::new_v4()]);
let scoped = Scoped {
scope: Arc::new(scope),
};
let cloned = scoped.clone();
assert!(Arc::ptr_eq(&scoped.scope, &cloned.scope));
}
}