1use crate::database::{QueryBuilder, Scope};
13use crate::tenant::context::current_tenant;
14use sea_orm::{ColumnTrait, EntityTrait};
15
16pub struct TenantScope<C: ColumnTrait>(pub C);
38
39impl<E, C> Scope<E> for TenantScope<C>
40where
41 E: EntityTrait,
42 E::Model: Send + Sync,
43 C: ColumnTrait,
44{
45 fn apply(self, query: QueryBuilder<E>) -> QueryBuilder<E> {
46 let ctx = current_tenant().expect(
47 "TenantScope used outside TenantMiddleware scope — ensure this route is behind TenantMiddleware",
48 );
49 query.filter(self.0.eq(ctx.id))
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::database::Scope as ScopeTrait;
    use crate::tenant::context::{current_tenant, tenant_scope, with_tenant_scope};
    use crate::tenant::TenantContext;
    use sea_orm::{DbBackend, QueryTrait, Value};

    /// Minimal entity with both an `id` and a `tenant_id` column, used only to
    /// render SQL for the assertions below.
    mod post {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
        #[sea_orm(table_name = "posts")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i64,
            pub tenant_id: i64,
            pub title: String,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    /// Builds a tenant context whose slug and name are derived from `id`.
    fn make_tenant(id: i64) -> TenantContext {
        TenantContext {
            id,
            slug: format!("tenant-{id}"),
            name: format!("Tenant {id}"),
            plan: None,
            #[cfg(feature = "stripe")]
            subscription: None,
        }
    }

    /// Renders `query` as a SQLite statement so its SQL and bound values can be inspected.
    fn statement_from_query(query: QueryBuilder<post::Entity>) -> sea_orm::Statement {
        query.into_select().build(DbBackend::Sqlite)
    }

    /// Applies `TenantScope(column)` to a fresh `posts` query and renders the result.
    ///
    /// Must be called where a tenant context is available, otherwise `apply` panics.
    fn scoped_statement(column: post::Column) -> sea_orm::Statement {
        let builder = QueryBuilder::<post::Entity>::new();
        statement_from_query(TenantScope(column).apply(builder))
    }

    /// Returns the SQL text that follows `WHERE`.
    ///
    /// The column assertions look only at this part: the `SELECT` projection
    /// lists every column of the entity, so matching a column name against the
    /// whole statement would succeed even if no filter had been added for it.
    fn where_clause(stmt: &sea_orm::Statement) -> &str {
        match stmt.sql.split_once("WHERE") {
            Some((_, predicate)) => predicate,
            None => panic!("Expected SQL to have WHERE clause, got: {}", stmt.sql),
        }
    }

    /// The scope adds a `WHERE` filter on the tenant column bound to the current tenant id.
    #[tokio::test]
    async fn tenant_scope_apply_adds_tenant_id_filter() {
        let ctx = tenant_scope();
        *ctx.write().await = Some(make_tenant(42));

        let stmt =
            with_tenant_scope(ctx, async { scoped_statement(post::Column::TenantId) }).await;

        let predicate = where_clause(&stmt);
        assert!(
            predicate.contains("tenant_id"),
            "Expected WHERE clause to filter on tenant_id, got: {}",
            stmt.sql
        );

        let values = stmt.values.expect("expected bound values");
        assert!(
            values
                .0
                .iter()
                .any(|v| matches!(v, Value::BigInt(Some(42)))),
            "Expected bound value 42i64, got: {:?}",
            values.0
        );
    }

    /// Applying the scope with no tenant context available panics with a descriptive message.
    #[test]
    #[should_panic(expected = "TenantScope used outside TenantMiddleware scope")]
    fn tenant_scope_panics_outside_middleware_scope() {
        let builder = QueryBuilder::<post::Entity>::new();
        let _ = TenantScope(post::Column::TenantId).apply(builder);
    }

    /// The scope filters on whichever column it is given, not only on `tenant_id`.
    #[tokio::test]
    async fn tenant_scope_is_generic_over_column_type() {
        let ctx = tenant_scope();
        *ctx.write().await = Some(make_tenant(7));

        let stmt = with_tenant_scope(ctx, async { scoped_statement(post::Column::Id) }).await;

        // The quoted form `"id"` cannot match inside `"tenant_id"`, so this
        // only passes when the filter really targets the `id` column.
        let predicate = where_clause(&stmt);
        assert!(
            predicate.contains("\"id\""),
            "Expected SQL to filter on id column, got: {}",
            stmt.sql
        );

        let values = stmt.values.expect("expected bound values");
        assert!(
            values.0.iter().any(|v| matches!(v, Value::BigInt(Some(7)))),
            "Expected bound value 7i64, got: {:?}",
            values.0
        );
    }

    /// Two scopes driven concurrently each observe their own tenant.
    ///
    /// `tokio::join!` polls both futures concurrently on the current task; it
    /// does not spawn separate tasks.
    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_tasks_get_isolated_tenant_scopes() {
        let ctx1 = tenant_scope();
        let ctx2 = tenant_scope();
        *ctx1.write().await = Some(make_tenant(100));
        *ctx2.write().await = Some(make_tenant(200));

        let (result1, result2) = tokio::join!(
            with_tenant_scope(ctx1, async { current_tenant().map(|t| t.id) }),
            with_tenant_scope(ctx2, async { current_tenant().map(|t| t.id) }),
        );

        assert_eq!(result1, Some(100), "Task 1 should see tenant 100");
        assert_eq!(result2, Some(200), "Task 2 should see tenant 200");
    }
}