Skip to main content

helios_persistence/strategy/
shared_schema.rs

//! Shared schema tenancy strategy.
//!
//! In this strategy, all tenants share the same database tables, and a tenant
//! discriminator column (`tenant_id` by default) is used to filter data. This is
//! the simplest and most common approach for multi-tenant applications.

7use serde::{Deserialize, Serialize};
8
9use crate::tenant::TenantId;
10
11use super::{TenantResolution, TenantResolver, TenantValidationError};
12
13/// Configuration for shared schema tenancy.
14///
15/// # Example
16///
17/// ```
18/// use helios_persistence::strategy::SharedSchemaConfig;
19///
20/// let config = SharedSchemaConfig {
21///     use_row_level_security: true,
22///     tenant_column: "tenant_id".to_string(),
23///     index_tenant_first: true,
24///     ..Default::default()
25/// };
26/// ```
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SharedSchemaConfig {
29    /// Whether to use Row-Level Security (PostgreSQL only).
30    ///
31    /// When enabled, the database enforces tenant isolation via RLS policies,
32    /// providing an additional layer of protection against application bugs.
33    #[serde(default)]
34    pub use_row_level_security: bool,
35
36    /// The name of the tenant ID column in tables.
37    #[serde(default = "default_tenant_column")]
38    pub tenant_column: String,
39
40    /// Whether to put tenant_id first in composite indexes.
41    ///
42    /// When true (recommended), indexes are created as (tenant_id, ...)
43    /// which improves query performance for tenant-filtered queries.
44    #[serde(default = "default_true")]
45    pub index_tenant_first: bool,
46
47    /// Maximum length for tenant IDs.
48    #[serde(default = "default_max_tenant_id_length")]
49    pub max_tenant_id_length: usize,
50
51    /// Allowed characters in tenant IDs (regex pattern).
52    #[serde(default = "default_tenant_id_pattern")]
53    pub tenant_id_pattern: String,
54
55    /// Whether to hash long tenant IDs.
56    ///
57    /// If a tenant ID exceeds `max_tenant_id_length`, it will be hashed
58    /// to a shorter value. The mapping is stored for reverse lookup.
59    #[serde(default)]
60    pub hash_long_ids: bool,
61}
62
/// Serde default for [`SharedSchemaConfig::tenant_column`].
fn default_tenant_column() -> String {
    String::from("tenant_id")
}

/// Serde default for boolean options that are on unless configured off.
fn default_true() -> bool {
    true
}

/// Serde default for [`SharedSchemaConfig::max_tenant_id_length`].
fn default_max_tenant_id_length() -> usize {
    64
}

/// Serde default for [`SharedSchemaConfig::tenant_id_pattern`]: one or more
/// ASCII letters, digits, `_`, `-`, or `/`.
fn default_tenant_id_pattern() -> String {
    r"^[a-zA-Z0-9_\-/]+$".to_owned()
}
78
79impl Default for SharedSchemaConfig {
80    fn default() -> Self {
81        Self {
82            use_row_level_security: false,
83            tenant_column: default_tenant_column(),
84            index_tenant_first: true,
85            max_tenant_id_length: default_max_tenant_id_length(),
86            tenant_id_pattern: default_tenant_id_pattern(),
87            hash_long_ids: false,
88        }
89    }
90}
91
92impl SharedSchemaConfig {
93    /// Creates a new configuration with defaults.
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    /// Enables Row-Level Security.
99    pub fn with_rls(mut self) -> Self {
100        self.use_row_level_security = true;
101        self
102    }
103
104    /// Sets the tenant column name.
105    pub fn with_tenant_column(mut self, column: impl Into<String>) -> Self {
106        self.tenant_column = column.into();
107        self
108    }
109}
110
111/// Shared schema tenancy strategy implementation.
112///
113/// This strategy uses a single database schema with a `tenant_id` column
114/// on all tables to isolate tenant data.
115///
116/// # Query Modification
117///
118/// All queries are modified to include a tenant filter:
119///
120/// ```sql
121/// -- Original query
122/// SELECT * FROM patient WHERE id = '123';
123///
124/// -- Modified query
125/// SELECT * FROM patient WHERE tenant_id = 'acme' AND id = '123';
126/// ```
127///
128/// # Index Strategy
129///
130/// For optimal performance, indexes should have `tenant_id` as the leading column:
131///
132/// ```sql
133/// CREATE INDEX idx_patient_tenant_id ON patient (tenant_id, id);
134/// CREATE INDEX idx_patient_tenant_name ON patient (tenant_id, family_name, given_name);
135/// ```
136///
137/// # Row-Level Security (PostgreSQL)
138///
139/// When RLS is enabled, additional protection is provided at the database level:
140///
141/// ```sql
142/// -- Enable RLS on table
143/// ALTER TABLE patient ENABLE ROW LEVEL SECURITY;
144///
145/// -- Create policy
146/// CREATE POLICY tenant_isolation ON patient
147///     USING (tenant_id = current_setting('app.current_tenant'));
148/// ```
149#[derive(Debug, Clone)]
150pub struct SharedSchemaStrategy {
151    config: SharedSchemaConfig,
152    tenant_pattern: regex::Regex,
153}
154
155impl SharedSchemaStrategy {
156    /// Creates a new shared schema strategy with the given configuration.
157    pub fn new(config: SharedSchemaConfig) -> Result<Self, regex::Error> {
158        let tenant_pattern = regex::Regex::new(&config.tenant_id_pattern)?;
159        Ok(Self {
160            config,
161            tenant_pattern,
162        })
163    }
164
165    /// Returns the configuration.
166    pub fn config(&self) -> &SharedSchemaConfig {
167        &self.config
168    }
169
170    /// Returns the tenant column name.
171    pub fn tenant_column(&self) -> &str {
172        &self.config.tenant_column
173    }
174
175    /// Returns whether RLS is enabled.
176    pub fn uses_rls(&self) -> bool {
177        self.config.use_row_level_security
178    }
179
180    /// Generates SQL for setting the current tenant (for RLS).
181    ///
182    /// This should be executed at the beginning of each request/transaction.
183    pub fn set_tenant_sql(&self, tenant_id: &TenantId) -> String {
184        format!(
185            "SET LOCAL app.current_tenant = '{}'",
186            self.escape_sql_string(tenant_id.as_str())
187        )
188    }
189
190    /// Generates SQL for clearing the current tenant.
191    pub fn clear_tenant_sql(&self) -> String {
192        "RESET app.current_tenant".to_string()
193    }
194
195    /// Generates a WHERE clause fragment for tenant filtering.
196    pub fn tenant_filter_sql(&self, table_alias: Option<&str>) -> String {
197        match table_alias {
198            Some(alias) => format!("{}.{} = $tenant_id", alias, self.config.tenant_column),
199            None => format!("{} = $tenant_id", self.config.tenant_column),
200        }
201    }
202
203    /// Escapes a string for safe inclusion in SQL.
204    fn escape_sql_string(&self, s: &str) -> String {
205        s.replace('\'', "''")
206    }
207
208    /// Normalizes a tenant ID (handles hashing if needed).
209    fn normalize_tenant_id(&self, tenant_id: &TenantId) -> String {
210        let id = tenant_id.as_str();
211
212        if self.config.hash_long_ids && id.len() > self.config.max_tenant_id_length {
213            // Use a simple hash for long IDs
214            use std::collections::hash_map::DefaultHasher;
215            use std::hash::{Hash, Hasher};
216
217            let mut hasher = DefaultHasher::new();
218            id.hash(&mut hasher);
219            format!("h_{:016x}", hasher.finish())
220        } else {
221            id.to_string()
222        }
223    }
224}
225
226impl TenantResolver for SharedSchemaStrategy {
227    fn resolve(&self, tenant_id: &TenantId) -> TenantResolution {
228        TenantResolution::SharedSchema {
229            tenant_id: self.normalize_tenant_id(tenant_id),
230        }
231    }
232
233    fn validate(&self, tenant_id: &TenantId) -> Result<(), TenantValidationError> {
234        let id = tenant_id.as_str();
235
236        // Check length
237        if !self.config.hash_long_ids && id.len() > self.config.max_tenant_id_length {
238            return Err(TenantValidationError {
239                tenant_id: id.to_string(),
240                reason: format!(
241                    "tenant ID exceeds maximum length of {} characters",
242                    self.config.max_tenant_id_length
243                ),
244            });
245        }
246
247        // Check pattern
248        if !self.tenant_pattern.is_match(id) {
249            return Err(TenantValidationError {
250                tenant_id: id.to_string(),
251                reason: format!(
252                    "tenant ID does not match required pattern: {}",
253                    self.config.tenant_id_pattern
254                ),
255            });
256        }
257
258        Ok(())
259    }
260
261    fn system_tenant(&self) -> TenantResolution {
262        TenantResolution::SharedSchema {
263            tenant_id: crate::tenant::SYSTEM_TENANT.to_string(),
264        }
265    }
266}
267
268/// Builder for creating table DDL with tenant support.
269#[derive(Debug)]
270#[allow(dead_code)]
271pub struct TenantAwareTableBuilder {
272    table_name: String,
273    tenant_column: String,
274    columns: Vec<ColumnDef>,
275    indexes: Vec<IndexDef>,
276    use_rls: bool,
277}
278
279#[derive(Debug)]
280#[allow(dead_code)]
281struct ColumnDef {
282    name: String,
283    data_type: String,
284    nullable: bool,
285}
286
287#[derive(Debug)]
288#[allow(dead_code)]
289struct IndexDef {
290    name: String,
291    columns: Vec<String>,
292    unique: bool,
293}
294
295#[allow(dead_code)]
296impl TenantAwareTableBuilder {
297    /// Creates a new table builder.
298    pub fn new(table_name: impl Into<String>, config: &SharedSchemaConfig) -> Self {
299        Self {
300            table_name: table_name.into(),
301            tenant_column: config.tenant_column.clone(),
302            columns: Vec::new(),
303            indexes: Vec::new(),
304            use_rls: config.use_row_level_security,
305        }
306    }
307
308    /// Adds a column to the table.
309    pub fn column(
310        mut self,
311        name: impl Into<String>,
312        data_type: impl Into<String>,
313        nullable: bool,
314    ) -> Self {
315        self.columns.push(ColumnDef {
316            name: name.into(),
317            data_type: data_type.into(),
318            nullable,
319        });
320        self
321    }
322
323    /// Adds an index (tenant_id will be prepended automatically).
324    pub fn index(mut self, name: impl Into<String>, columns: Vec<&str>, unique: bool) -> Self {
325        self.indexes.push(IndexDef {
326            name: name.into(),
327            columns: columns.into_iter().map(String::from).collect(),
328            unique,
329        });
330        self
331    }
332
333    /// Generates PostgreSQL DDL for the table.
334    pub fn to_postgres_ddl(&self) -> String {
335        let mut ddl = String::new();
336
337        // CREATE TABLE
338        ddl.push_str(&format!(
339            "CREATE TABLE IF NOT EXISTS {} (\n",
340            self.table_name
341        ));
342        ddl.push_str(&format!(
343            "    {} VARCHAR(64) NOT NULL,\n",
344            self.tenant_column
345        ));
346
347        for col in &self.columns {
348            let null_str = if col.nullable { "" } else { " NOT NULL" };
349            ddl.push_str(&format!(
350                "    {} {}{},\n",
351                col.name, col.data_type, null_str
352            ));
353        }
354
355        // Remove trailing comma and close
356        ddl.truncate(ddl.len() - 2);
357        ddl.push_str("\n);\n\n");
358
359        // CREATE INDEXES (with tenant_id first)
360        for idx in &self.indexes {
361            let unique_str = if idx.unique { "UNIQUE " } else { "" };
362            let columns: Vec<_> = std::iter::once(self.tenant_column.as_str())
363                .chain(idx.columns.iter().map(|s| s.as_str()))
364                .collect();
365            ddl.push_str(&format!(
366                "CREATE {}INDEX IF NOT EXISTS {} ON {} ({});\n",
367                unique_str,
368                idx.name,
369                self.table_name,
370                columns.join(", ")
371            ));
372        }
373
374        // RLS if enabled
375        if self.use_rls {
376            ddl.push_str(&format!(
377                "\nALTER TABLE {} ENABLE ROW LEVEL SECURITY;\n",
378                self.table_name
379            ));
380            ddl.push_str(&format!(
381                "CREATE POLICY tenant_isolation ON {} USING ({} = current_setting('app.current_tenant'));\n",
382                self.table_name, self.tenant_column
383            ));
384        }
385
386        ddl
387    }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Strategy built from the default configuration.
    fn default_strategy() -> SharedSchemaStrategy {
        SharedSchemaStrategy::new(SharedSchemaConfig::default())
            .expect("default tenant pattern should compile")
    }

    /// Extracts the tenant value from a resolution, failing on any other variant.
    fn shared_tenant_id(resolution: TenantResolution) -> String {
        if let TenantResolution::SharedSchema { tenant_id } = resolution {
            tenant_id
        } else {
            panic!("expected SharedSchema resolution")
        }
    }

    #[test]
    fn test_shared_schema_config_default() {
        let SharedSchemaConfig {
            tenant_column,
            use_row_level_security,
            index_tenant_first,
            ..
        } = SharedSchemaConfig::default();

        assert_eq!(tenant_column, "tenant_id");
        assert!(!use_row_level_security);
        assert!(index_tenant_first);
    }

    #[test]
    fn test_shared_schema_config_builder() {
        let SharedSchemaConfig {
            use_row_level_security,
            tenant_column,
            ..
        } = SharedSchemaConfig::new()
            .with_rls()
            .with_tenant_column("org_id");

        assert!(use_row_level_security);
        assert_eq!(tenant_column, "org_id");
    }

    #[test]
    fn test_shared_schema_strategy_creation() {
        assert_eq!(default_strategy().tenant_column(), "tenant_id");
    }

    #[test]
    fn test_tenant_resolution() {
        let resolved = shared_tenant_id(default_strategy().resolve(&TenantId::new("acme")));
        assert_eq!(resolved, "acme");
    }

    #[test]
    fn test_tenant_validation_valid() {
        let strategy = default_strategy();
        for candidate in ["acme", "acme/research", "tenant_123"] {
            assert!(
                strategy.validate(&TenantId::new(candidate)).is_ok(),
                "{} should be accepted",
                candidate
            );
        }
    }

    #[test]
    fn test_tenant_validation_invalid_pattern() {
        let outcome = default_strategy().validate(&TenantId::new("tenant with spaces"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tenant_validation_too_long() {
        let strategy = SharedSchemaStrategy::new(SharedSchemaConfig {
            max_tenant_id_length: 10,
            ..SharedSchemaConfig::default()
        })
        .expect("default tenant pattern should compile");

        let outcome = strategy.validate(&TenantId::new("this-is-a-very-long-tenant-id"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_set_tenant_sql() {
        assert_eq!(
            default_strategy().set_tenant_sql(&TenantId::new("acme")),
            "SET LOCAL app.current_tenant = 'acme'"
        );
    }

    #[test]
    fn test_set_tenant_sql_escapes() {
        assert_eq!(
            default_strategy().set_tenant_sql(&TenantId::new("o'brien")),
            "SET LOCAL app.current_tenant = 'o''brien'"
        );
    }

    #[test]
    fn test_tenant_filter_sql() {
        let strategy = default_strategy();
        assert_eq!(strategy.tenant_filter_sql(None), "tenant_id = $tenant_id");
        assert_eq!(
            strategy.tenant_filter_sql(Some("p")),
            "p.tenant_id = $tenant_id"
        );
    }

    #[test]
    fn test_table_builder() {
        let ddl = TenantAwareTableBuilder::new("patient", &SharedSchemaConfig::default())
            .column("id", "VARCHAR(64)", false)
            .column("family_name", "TEXT", true)
            .index("idx_patient_id", vec!["id"], true)
            .to_postgres_ddl();

        for fragment in [
            "CREATE TABLE IF NOT EXISTS patient",
            "tenant_id VARCHAR(64) NOT NULL",
            "id VARCHAR(64) NOT NULL",
            "CREATE UNIQUE INDEX",
            "(tenant_id, id)",
        ] {
            assert!(ddl.contains(fragment), "missing `{}` in:\n{}", fragment, ddl);
        }
    }

    #[test]
    fn test_table_builder_with_rls() {
        let ddl = TenantAwareTableBuilder::new("patient", &SharedSchemaConfig::new().with_rls())
            .column("id", "VARCHAR(64)", false)
            .to_postgres_ddl();

        assert!(ddl.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(ddl.contains("CREATE POLICY tenant_isolation"));
    }

    #[test]
    fn test_system_tenant_resolution() {
        let tenant_id = shared_tenant_id(default_strategy().system_tenant());
        assert_eq!(tenant_id, crate::tenant::SYSTEM_TENANT);
    }
}