Skip to main content

modkit_db/secure/
select.rs

1use sea_orm::{
2    ColumnTrait, ConnectionTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder,
3    QuerySelect, sea_query::Expr,
4};
5use std::marker::PhantomData;
6
7use crate::secure::cond::build_scope_condition;
8use crate::secure::error::ScopeError;
9use crate::secure::{AccessScope, ScopableEntity};
10
/// Typestate marker for a query that has not been scoped yet.
///
/// A `SecureSelect` carrying this marker has no methods in this module that
/// execute the query; it must first go through `.scope_with()`.
#[derive(Clone, Copy, Debug)]
pub struct Unscoped;
15
/// Typestate marker for a query that has had access control applied.
///
/// Only a `SecureSelect` carrying this marker exposes the executing methods
/// (`all`, `one`, `count`).
#[derive(Clone, Copy, Debug)]
pub struct Scoped;
20
21/// A type-safe wrapper around `SeaORM`'s `Select` that enforces scoping.
22///
23/// This wrapper uses the typestate pattern to ensure that queries cannot
24/// be executed without first applying access control via `.scope_with()`.
25///
26/// # Type Parameters
27/// - `E`: The `SeaORM` entity type
28/// - `S`: The typestate (`Unscoped` or `Scoped`)
29///
30/// # Example
31/// ```rust,ignore
32/// use modkit_db::secure::{AccessScope, SecureEntityExt};
33///
34/// let scope = AccessScope::tenants_only(vec![tenant_id]);
35/// let users = user::Entity::find()
36///     .secure()           // Returns SecureSelect<E, Unscoped>
37///     .scope_with(&scope)? // Returns SecureSelect<E, Scoped>
38///     .all(conn)          // Now can execute
39///     .await?;
40/// ```
41#[must_use]
42#[derive(Clone, Debug)]
43pub struct SecureSelect<E: EntityTrait, S> {
44    pub(crate) inner: sea_orm::Select<E>,
45    pub(crate) _state: PhantomData<S>,
46}
47
48/// Extension trait to convert a regular `SeaORM` `Select` into a `SecureSelect`.
49pub trait SecureEntityExt<E: EntityTrait>: Sized {
50    /// Convert this select query into a secure (unscoped) select.
51    /// You must call `.scope_with()` before executing the query.
52    fn secure(self) -> SecureSelect<E, Unscoped>;
53}
54
55impl<E> SecureEntityExt<E> for sea_orm::Select<E>
56where
57    E: EntityTrait,
58{
59    fn secure(self) -> SecureSelect<E, Unscoped> {
60        SecureSelect {
61            inner: self,
62            _state: PhantomData,
63        }
64    }
65}
66
67// Methods available only on Unscoped queries
68impl<E> SecureSelect<E, Unscoped>
69where
70    E: ScopableEntity + EntityTrait,
71    E::Column: ColumnTrait + Copy,
72{
73    /// Apply access control scope to this query, transitioning to the `Scoped` state.
74    ///
75    /// This applies the implicit policy:
76    /// - Empty scope → deny all
77    /// - Tenants only → filter by tenant
78    /// - Resources only → filter by resource IDs
79    /// - Both → AND them together
80    ///
81    pub fn scope_with(self, scope: &AccessScope) -> SecureSelect<E, Scoped> {
82        let cond = build_scope_condition::<E>(scope);
83        SecureSelect {
84            inner: self.inner.filter(cond),
85            _state: PhantomData,
86        }
87    }
88}
89
90// Methods available only on Scoped queries
91impl<E> SecureSelect<E, Scoped>
92where
93    E: EntityTrait,
94{
95    /// Execute the query and return all matching results.
96    ///
97    /// # Errors
98    /// Returns `ScopeError::Db` if the database query fails.
99    #[allow(clippy::disallowed_methods)]
100    pub async fn all<C>(self, conn: &C) -> Result<Vec<E::Model>, ScopeError>
101    where
102        C: ConnectionTrait + Send + Sync,
103    {
104        Ok(self.inner.all(conn).await?)
105    }
106
107    /// Execute the query and return at most one result.
108    ///
109    /// # Errors
110    /// Returns `ScopeError::Db` if the database query fails.
111    #[allow(clippy::disallowed_methods)]
112    pub async fn one<C>(self, conn: &C) -> Result<Option<E::Model>, ScopeError>
113    where
114        C: ConnectionTrait + Send + Sync,
115    {
116        Ok(self.inner.one(conn).await?)
117    }
118
119    /// Execute the query and return the number of matching results.
120    ///
121    /// # Errors
122    /// Returns `ScopeError::Db` if the database query fails.
123    #[allow(clippy::disallowed_methods)]
124    pub async fn count<C>(self, conn: &C) -> Result<u64, ScopeError>
125    where
126        C: ConnectionTrait + Send + Sync,
127        E::Model: sea_orm::FromQueryResult + Send + Sync,
128    {
129        Ok(self.inner.count(conn).await?)
130    }
131
132    // Note: count() uses SeaORM's `PaginatorTrait::count` internally.
133
134    // Note: For pagination, use `into_inner().paginate()` due to complex lifetime bounds
135
136    /// Add an additional filter for a specific resource ID.
137    ///
138    /// This is useful when you want to further narrow a scoped query
139    /// to a single resource.
140    ///
141    /// # Example
142    /// ```ignore
143    /// let user = User::find()
144    ///     .secure()
145    ///     .scope_with(&scope)?
146    ///     .and_id(user_id)
147    ///     .one(conn)
148    ///     .await?;
149    /// ```
150    ///
151    /// # Errors
152    /// Returns `ScopeError::Invalid` if the entity doesn't have a resource column.
153    pub fn and_id(self, id: uuid::Uuid) -> Result<Self, ScopeError>
154    where
155        E: ScopableEntity,
156        E::Column: ColumnTrait + Copy,
157    {
158        let resource_col = E::resource_col().ok_or(ScopeError::Invalid(
159            "Entity must have a resource_col to use and_id()",
160        ))?;
161        let cond = sea_orm::Condition::all().add(Expr::col(resource_col).eq(id));
162        Ok(self.filter(cond))
163    }
164
165    /// Unwrap the inner `SeaORM` `Select` for advanced use cases.
166    ///
167    /// This is an escape hatch if you need to add additional filters,
168    /// joins, or ordering after scoping has been applied.
169    ///
170    /// # Safety
171    /// The caller must ensure they don't remove or bypass the security
172    /// conditions that were applied during `.scope_with()`.
173    #[must_use]
174    pub fn into_inner(self) -> sea_orm::Select<E> {
175        self.inner
176    }
177}
178
179// Allow further chaining on Scoped queries before execution
180impl<E> SecureSelect<E, Scoped>
181where
182    E: EntityTrait,
183{
184    /// Add additional filters to the scoped query.
185    /// The scope conditions remain in place.
186    pub fn filter(mut self, filter: sea_orm::Condition) -> Self {
187        self.inner = QueryFilter::filter(self.inner, filter);
188        self
189    }
190
191    /// Add ordering to the scoped query.
192    pub fn order_by<C>(mut self, col: C, order: sea_orm::Order) -> Self
193    where
194        C: sea_orm::IntoSimpleExpr,
195    {
196        self.inner = QueryOrder::order_by(self.inner, col, order);
197        self
198    }
199
200    /// Add a limit to the scoped query.
201    pub fn limit(mut self, limit: u64) -> Self {
202        self.inner = QuerySelect::limit(self.inner, limit);
203        self
204    }
205
206    /// Add an offset to the scoped query.
207    pub fn offset(mut self, offset: u64) -> Self {
208        self.inner = QuerySelect::offset(self.inner, offset);
209        self
210    }
211
212    /// Apply scoping for a joined entity.
213    ///
214    /// This is useful when you need to filter by tenant on a joined table.
215    ///
216    /// # Example
217    /// ```ignore
218    /// // Select orders, ensuring both Order and Customer match tenant scope
219    /// Order::find()
220    ///     .secure()
221    ///     .scope_with(&scope)?
222    ///     .and_scope_for::<customer::Entity>(&scope)
223    ///     .all(conn)
224    ///     .await?
225    /// ```
226    pub fn and_scope_for<J>(mut self, scope: &AccessScope) -> Self
227    where
228        J: ScopableEntity + EntityTrait,
229        J::Column: ColumnTrait + Copy,
230    {
231        if !scope.tenant_ids().is_empty()
232            && let Some(tcol) = J::tenant_col()
233        {
234            let condition = sea_orm::Condition::all()
235                .add(Expr::col((J::default(), tcol)).is_in(scope.tenant_ids().to_vec()));
236            self.inner = QueryFilter::filter(self.inner, condition);
237        }
238        self
239    }
240
241    /// Apply scoping via EXISTS subquery on a related entity.
242    ///
243    /// This is particularly useful when the base entity doesn't have a tenant column
244    /// but is related to one that does.
245    ///
246    /// # Note
247    /// This is a simplified version that filters by tenant on the joined entity.
248    /// For complex join predicates, use `into_inner()` and build custom EXISTS clauses.
249    ///
250    /// # Example
251    /// ```ignore
252    /// // Find settings that exist in a tenant-scoped relationship
253    /// GlobalSetting::find()
254    ///     .secure()
255    ///     .scope_with(&AccessScope::resources_only(vec![]))?
256    ///     .scope_via_exists::<TenantSetting>(&scope)
257    ///     .all(conn)
258    ///     .await?
259    /// ```
260    pub fn scope_via_exists<J>(mut self, scope: &AccessScope) -> Self
261    where
262        J: ScopableEntity + EntityTrait,
263        J::Column: ColumnTrait + Copy,
264    {
265        if !scope.tenant_ids().is_empty()
266            && let Some(tcol) = J::tenant_col()
267        {
268            // Build EXISTS clause with tenant filter on joined entity
269            use sea_orm::sea_query::Query;
270
271            let mut sub = Query::select();
272            sub.expr(Expr::value(1))
273                .from(J::default())
274                .cond_where(Expr::col((J::default(), tcol)).is_in(scope.tenant_ids().to_vec()));
275
276            self.inner =
277                QueryFilter::filter(self.inner, sea_orm::Condition::all().add(Expr::exists(sub)));
278        }
279        self
280    }
281}
282
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;

    // Note: Full integration tests with real SeaORM entities should be written
    // in application code where actual entities are available.
    // The typestate pattern itself is enforced at compile time.
    //
    // See USAGE_EXAMPLE.md for complete usage patterns.

    /// The markers exist only at the type level, so neither they nor the
    /// `PhantomData` wrapper may add runtime size to `SecureSelect`.
    #[test]
    fn typestate_markers_are_zero_sized() {
        assert_eq!(std::mem::size_of::<Unscoped>(), 0);
        assert_eq!(std::mem::size_of::<Scoped>(), 0);
        assert_eq!(std::mem::size_of::<std::marker::PhantomData<Unscoped>>(), 0);
        assert_eq!(std::mem::size_of::<std::marker::PhantomData<Scoped>>(), 0);
    }

    /// `#[derive(Clone, Debug)]` on `SecureSelect` bounds the state parameter on
    /// `Clone` and `Debug`, so both markers must implement them.
    #[test]
    fn typestate_markers_implement_required_traits() {
        fn assert_traits<T: Copy + Clone + std::fmt::Debug + Send + Sync + 'static>() {}
        assert_traits::<Unscoped>();
        assert_traits::<Scoped>();

        // Copy semantics: the original binding stays usable after the move.
        let unscoped = Unscoped;
        let scoped = Scoped;
        let copies = (unscoped, scoped);
        let _ = (unscoped, scoped, copies);
    }
}