prax_query/tenant/
strategy.rs

1//! Tenant isolation strategies.
2
3use std::collections::HashSet;
4
5/// The isolation strategy for multi-tenancy.
6#[derive(Debug, Clone)]
7pub enum IsolationStrategy {
8    /// Row-level security: all tenants share tables, filtered by column.
9    RowLevel(RowLevelConfig),
10    /// Schema-based: each tenant has their own schema.
11    Schema(SchemaConfig),
12    /// Database-based: each tenant has their own database.
13    Database(DatabaseConfig),
14    /// Hybrid: combination of strategies (e.g., schema + row-level).
15    Hybrid(Box<IsolationStrategy>, Box<IsolationStrategy>),
16}
17
18impl IsolationStrategy {
19    /// Create a row-level isolation strategy.
20    pub fn row_level(column: impl Into<String>) -> Self {
21        Self::RowLevel(RowLevelConfig::new(column))
22    }
23
24    /// Create a schema-based isolation strategy.
25    pub fn schema_based() -> Self {
26        Self::Schema(SchemaConfig::default())
27    }
28
29    /// Create a database-based isolation strategy.
30    pub fn database_based() -> Self {
31        Self::Database(DatabaseConfig::default())
32    }
33
34    /// Check if this is row-level isolation.
35    pub fn is_row_level(&self) -> bool {
36        matches!(self, Self::RowLevel(_))
37    }
38
39    /// Check if this is schema-based isolation.
40    pub fn is_schema_based(&self) -> bool {
41        matches!(self, Self::Schema(_))
42    }
43
44    /// Check if this is database-based isolation.
45    pub fn is_database_based(&self) -> bool {
46        matches!(self, Self::Database(_))
47    }
48
49    /// Get the row-level config if applicable.
50    pub fn row_level_config(&self) -> Option<&RowLevelConfig> {
51        match self {
52            Self::RowLevel(config) => Some(config),
53            Self::Hybrid(a, b) => a.row_level_config().or_else(|| b.row_level_config()),
54            _ => None,
55        }
56    }
57
58    /// Get the schema config if applicable.
59    pub fn schema_config(&self) -> Option<&SchemaConfig> {
60        match self {
61            Self::Schema(config) => Some(config),
62            Self::Hybrid(a, b) => a.schema_config().or_else(|| b.schema_config()),
63            _ => None,
64        }
65    }
66
67    /// Get the database config if applicable.
68    pub fn database_config(&self) -> Option<&DatabaseConfig> {
69        match self {
70            Self::Database(config) => Some(config),
71            Self::Hybrid(a, b) => a.database_config().or_else(|| b.database_config()),
72            _ => None,
73        }
74    }
75}
76
77/// Configuration for row-level tenant isolation.
78#[derive(Debug, Clone)]
79pub struct RowLevelConfig {
80    /// The column name that stores the tenant ID.
81    pub column: String,
82    /// The column type (for type-safe comparisons).
83    pub column_type: ColumnType,
84    /// Tables that should be excluded from tenant filtering.
85    pub excluded_tables: HashSet<String>,
86    /// Tables that are shared across all tenants.
87    pub shared_tables: HashSet<String>,
88    /// Whether to automatically add tenant_id to INSERT statements.
89    pub auto_insert: bool,
90    /// Whether to validate tenant_id on UPDATE/DELETE.
91    pub validate_writes: bool,
92    /// Whether to use database-level RLS (PostgreSQL).
93    pub use_database_rls: bool,
94}
95
96impl RowLevelConfig {
97    /// Create a new row-level config with the given column name.
98    pub fn new(column: impl Into<String>) -> Self {
99        Self {
100            column: column.into(),
101            column_type: ColumnType::String,
102            excluded_tables: HashSet::new(),
103            shared_tables: HashSet::new(),
104            auto_insert: true,
105            validate_writes: true,
106            use_database_rls: false,
107        }
108    }
109
110    /// Set the column type.
111    pub fn with_column_type(mut self, column_type: ColumnType) -> Self {
112        self.column_type = column_type;
113        self
114    }
115
116    /// Exclude a table from tenant filtering.
117    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
118        self.excluded_tables.insert(table.into());
119        self
120    }
121
122    /// Mark a table as shared (no tenant filtering).
123    pub fn shared_table(mut self, table: impl Into<String>) -> Self {
124        self.shared_tables.insert(table.into());
125        self
126    }
127
128    /// Disable automatic tenant_id insertion.
129    pub fn without_auto_insert(mut self) -> Self {
130        self.auto_insert = false;
131        self
132    }
133
134    /// Disable write validation.
135    pub fn without_write_validation(mut self) -> Self {
136        self.validate_writes = false;
137        self
138    }
139
140    /// Enable PostgreSQL database-level RLS.
141    pub fn with_database_rls(mut self) -> Self {
142        self.use_database_rls = true;
143        self
144    }
145
146    /// Check if a table should be filtered.
147    pub fn should_filter(&self, table: &str) -> bool {
148        !self.excluded_tables.contains(table) && !self.shared_tables.contains(table)
149    }
150}
151
/// The type of the tenant column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ColumnType {
    /// String/VARCHAR/TEXT column.
    #[default]
    String,
    /// UUID column.
    Uuid,
    /// Integer column.
    Integer,
    /// BigInt column.
    BigInt,
}

impl ColumnType {
    /// Get the SQL placeholder for this column type.
    ///
    /// Returns a numbered placeholder of the form `$<index>`. The result is
    /// the same for every variant. Prefer binding the tenant ID through a
    /// placeholder over inlining it with [`Self::format_value`].
    pub fn placeholder(&self, index: usize) -> String {
        format!("${}", index)
    }

    /// Format a value as an inline SQL literal for this column type.
    ///
    /// * `String` — single-quoted, embedded `'` doubled.
    /// * `Uuid` — single-quoted with embedded `'` doubled, followed by a
    ///   `::uuid` cast.
    /// * `Integer` / `BigInt` — emitted verbatim when `value` is an optionally
    ///   signed run of ASCII digits. Anything else is emitted as an escaped,
    ///   single-quoted literal so it cannot change the shape of the
    ///   surrounding statement; converting or rejecting that literal is left
    ///   to the database.
    ///
    /// NOTE(review): escaping assumes `'` is the only character that can
    /// terminate a string literal (standard-conforming strings) — confirm for
    /// any backend that also honours backslash escapes.
    pub fn format_value(&self, value: &str) -> String {
        match self {
            Self::String => Self::quote_literal(value),
            // The value was previously interpolated unescaped, so a `'` in it
            // could close the literal early.
            Self::Uuid => format!("{}::uuid", Self::quote_literal(value)),
            Self::Integer | Self::BigInt => {
                if Self::is_integer_literal(value) {
                    value.to_string()
                } else {
                    // Never splice unvalidated text into the statement bare.
                    Self::quote_literal(value)
                }
            }
        }
    }

    /// Wrap `value` in single quotes, doubling any embedded single quote.
    fn quote_literal(value: &str) -> String {
        format!("'{}'", value.replace('\'', "''"))
    }

    /// `true` when `value` is one optional `+`/`-` sign followed by at least
    /// one ASCII digit and nothing else.
    fn is_integer_literal(value: &str) -> bool {
        let digits = value
            .strip_prefix(|c: char| c == '-' || c == '+')
            .unwrap_or(value);
        !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit())
    }
}
181
/// Configuration for schema-based tenant isolation.
///
/// The derived `Default` leaves every optional name unset, `auto_create`
/// off, and the search path format at [`SearchPathFormat::TenantFirst`].
#[derive(Debug, Clone, Default)]
pub struct SchemaConfig {
    /// Prefix for tenant schema names (e.g., "tenant_" -> "tenant_acme").
    pub schema_prefix: Option<String>,
    /// Suffix for tenant schema names.
    pub schema_suffix: Option<String>,
    /// Name of the shared schema for common tables.
    pub shared_schema: Option<String>,
    /// Whether to create schemas automatically.
    pub auto_create: bool,
    /// Default schema for new tenants.
    pub default_schema: Option<String>,
    /// Schema search path format.
    pub search_path_format: SearchPathFormat,
}

impl SchemaConfig {
    /// Set the schema prefix.
    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.schema_prefix = Some(prefix.into());
        self
    }

    /// Set the schema suffix.
    pub fn with_suffix(mut self, suffix: impl Into<String>) -> Self {
        self.schema_suffix = Some(suffix.into());
        self
    }

    /// Set the shared schema name.
    ///
    /// The name is emitted verbatim by [`Self::search_path`], so it must be
    /// trusted configuration, never request data.
    pub fn with_shared_schema(mut self, schema: impl Into<String>) -> Self {
        self.shared_schema = Some(schema.into());
        self
    }

    /// Enable auto-creation of schemas.
    pub fn with_auto_create(mut self) -> Self {
        self.auto_create = true;
        self
    }

    /// Set the default schema.
    pub fn with_default_schema(mut self, schema: impl Into<String>) -> Self {
        self.default_schema = Some(schema.into());
        self
    }

    /// Set the search path format.
    pub fn with_search_path(mut self, format: SearchPathFormat) -> Self {
        self.search_path_format = format;
        self
    }

    /// Generate the schema name for a tenant.
    ///
    /// Returns `prefix + tenant_id + suffix`, skipping whichever of the two
    /// is unset. The result is the raw name: it is NOT quoted or escaped, so
    /// do not splice it into SQL text directly.
    pub fn schema_name(&self, tenant_id: &str) -> String {
        let mut name = String::new();
        if let Some(prefix) = &self.schema_prefix {
            name.push_str(prefix);
        }
        name.push_str(tenant_id);
        if let Some(suffix) = &self.schema_suffix {
            name.push_str(suffix);
        }
        name
    }

    /// Generate the search_path SQL for a tenant.
    ///
    /// The statement lists the tenant schema alone, before, or after the
    /// shared schema, depending on `search_path_format`. When no shared schema
    /// is configured, `public` takes its place.
    ///
    /// The tenant schema name is derived from `tenant_id`, which may come
    /// from outside the program. If the name is anything other than a plain
    /// identifier it is emitted as a double-quoted identifier with embedded
    /// `"` doubled, so it cannot terminate or extend the statement. Plain
    /// identifiers are emitted unchanged (and therefore remain subject to the
    /// database's usual case folding).
    pub fn search_path(&self, tenant_id: &str) -> String {
        let tenant_schema = Self::quote_ident_if_needed(&self.schema_name(tenant_id));
        let shared = self.shared_schema.as_deref().unwrap_or("public");
        match self.search_path_format {
            SearchPathFormat::TenantOnly => {
                format!("SET search_path TO {}", tenant_schema)
            }
            SearchPathFormat::TenantFirst => {
                format!("SET search_path TO {}, {}", tenant_schema, shared)
            }
            SearchPathFormat::SharedFirst => {
                format!("SET search_path TO {}, {}", shared, tenant_schema)
            }
        }
    }

    /// Return `name` unchanged when it is a plain identifier (ASCII letter or
    /// `_` first, then ASCII letters, digits, `_` or `$`); otherwise return
    /// it wrapped in double quotes with embedded `"` doubled.
    fn quote_ident_if_needed(name: &str) -> String {
        let mut chars = name.chars();
        let is_plain = match chars.next() {
            Some(first) if first.is_ascii_alphabetic() || first == '_' => {
                chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$')
            }
            // Empty names and names with any other leading character.
            _ => false,
        };
        if is_plain {
            name.to_string()
        } else {
            format!("\"{}\"", name.replace('"', "\"\""))
        }
    }
}

/// Format for the schema search path.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SearchPathFormat {
    /// Only include the tenant schema.
    TenantOnly,
    /// Tenant schema first, then shared.
    #[default]
    TenantFirst,
    /// Shared schema first, then tenant.
    SharedFirst,
}
285
/// Settings for database-per-tenant isolation.
///
/// The derived `Default` leaves both names and the template unset,
/// `auto_create` off, and both sizes at `0`.
/// NOTE(review): how a size of `0` is interpreted is decided by code not
/// visible here — confirm before relying on the default.
#[derive(Debug, Clone, Default)]
pub struct DatabaseConfig {
    /// Text placed before the tenant ID in database names.
    pub database_prefix: Option<String>,
    /// Text placed after the tenant ID in database names.
    pub database_suffix: Option<String>,
    /// Whether databases are created automatically.
    pub auto_create: bool,
    /// Template database for new tenant databases.
    pub template_database: Option<String>,
    /// Connection pool size per tenant.
    pub pool_size_per_tenant: usize,
    /// Maximum number of tenant connections to keep.
    pub max_tenant_connections: usize,
}

impl DatabaseConfig {
    /// Use `prefix` at the start of every tenant database name.
    pub fn with_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            database_prefix: Some(prefix.into()),
            ..self
        }
    }

    /// Use `suffix` at the end of every tenant database name.
    pub fn with_suffix(self, suffix: impl Into<String>) -> Self {
        Self {
            database_suffix: Some(suffix.into()),
            ..self
        }
    }

    /// Turn on automatic database creation.
    pub fn with_auto_create(self) -> Self {
        Self {
            auto_create: true,
            ..self
        }
    }

    /// Record the template database for new tenant databases.
    pub fn with_template(self, template: impl Into<String>) -> Self {
        Self {
            template_database: Some(template.into()),
            ..self
        }
    }

    /// Record the connection pool size per tenant.
    pub fn with_pool_size(self, size: usize) -> Self {
        Self {
            pool_size_per_tenant: size,
            ..self
        }
    }

    /// Record the maximum number of tenant connections.
    pub fn with_max_connections(self, max: usize) -> Self {
        Self {
            max_tenant_connections: max,
            ..self
        }
    }

    /// Build the database name for `tenant_id`.
    ///
    /// The result is `prefix + tenant_id + suffix`; an unset prefix or suffix
    /// contributes nothing. No quoting or validation is applied.
    pub fn database_name(&self, tenant_id: &str) -> String {
        let prefix = self.database_prefix.as_deref().unwrap_or("");
        let suffix = self.database_suffix.as_deref().unwrap_or("");
        [prefix, tenant_id, suffix].concat()
    }
}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records the column settings, and both excluded and shared
    /// tables opt out of filtering.
    #[test]
    fn test_row_level_config() {
        let base = RowLevelConfig::new("tenant_id");
        let config = base
            .with_column_type(ColumnType::Uuid)
            .exclude_table("audit_logs")
            .shared_table("plans");

        assert_eq!(config.column, "tenant_id");
        assert_eq!(config.column_type, ColumnType::Uuid);

        // An ordinary table is filtered; registered ones are not.
        assert!(config.should_filter("users"));
        assert!(!config.should_filter("audit_logs"));
        assert!(!config.should_filter("plans"));
    }

    /// The prefix is applied to the schema name, and the search path mentions
    /// both the tenant schema and the shared schema.
    #[test]
    fn test_schema_config() {
        let config = SchemaConfig::default()
            .with_prefix("tenant_")
            .with_shared_schema("shared");

        let name = config.schema_name("acme");
        assert_eq!(name, "tenant_acme");

        let path = config.search_path("acme");
        assert!(path.contains("tenant_acme"));
        assert!(path.contains("shared"));
    }

    /// Prefix and suffix wrap the tenant ID in the database name.
    #[test]
    fn test_database_config() {
        let config = DatabaseConfig::default()
            .with_prefix("prax_")
            .with_suffix("_db");

        let name = config.database_name("acme");
        assert_eq!(name, "prax_acme_db");
    }

    /// Each column type renders a well-formed value as the expected literal.
    #[test]
    fn test_column_type_format() {
        let text = ColumnType::String.format_value("test");
        assert_eq!(text, "'test'");

        let uuid = ColumnType::Uuid.format_value("123e4567-e89b-12d3-a456-426614174000");
        assert_eq!(uuid, "'123e4567-e89b-12d3-a456-426614174000'::uuid");

        let number = ColumnType::Integer.format_value("42");
        assert_eq!(number, "42");
    }
}