helios_persistence/strategy/
shared_schema.rs1use serde::{Deserialize, Serialize};
8
9use crate::tenant::TenantId;
10
11use super::{TenantResolution, TenantResolver, TenantValidationError};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct SharedSchemaConfig {
29 #[serde(default)]
34 pub use_row_level_security: bool,
35
36 #[serde(default = "default_tenant_column")]
38 pub tenant_column: String,
39
40 #[serde(default = "default_true")]
45 pub index_tenant_first: bool,
46
47 #[serde(default = "default_max_tenant_id_length")]
49 pub max_tenant_id_length: usize,
50
51 #[serde(default = "default_tenant_id_pattern")]
53 pub tenant_id_pattern: String,
54
55 #[serde(default)]
60 pub hash_long_ids: bool,
61}
62
/// Serde default for `SharedSchemaConfig::tenant_column`.
fn default_tenant_column() -> String {
    String::from("tenant_id")
}
66
/// Serde default helper for boolean fields that should be enabled unless the
/// configuration says otherwise.
fn default_true() -> bool {
    true
}
70
/// Serde default for `SharedSchemaConfig::max_tenant_id_length`.
fn default_max_tenant_id_length() -> usize {
    const DEFAULT_MAX_LEN: usize = 64;
    DEFAULT_MAX_LEN
}
74
/// Serde default for `SharedSchemaConfig::tenant_id_pattern`: one or more
/// ASCII letters, digits, underscores, hyphens or forward slashes.
fn default_tenant_id_pattern() -> String {
    String::from(r"^[a-zA-Z0-9_\-/]+$")
}
78
79impl Default for SharedSchemaConfig {
80 fn default() -> Self {
81 Self {
82 use_row_level_security: false,
83 tenant_column: default_tenant_column(),
84 index_tenant_first: true,
85 max_tenant_id_length: default_max_tenant_id_length(),
86 tenant_id_pattern: default_tenant_id_pattern(),
87 hash_long_ids: false,
88 }
89 }
90}
91
92impl SharedSchemaConfig {
93 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn with_rls(mut self) -> Self {
100 self.use_row_level_security = true;
101 self
102 }
103
104 pub fn with_tenant_column(mut self, column: impl Into<String>) -> Self {
106 self.tenant_column = column.into();
107 self
108 }
109}
110
111#[derive(Debug, Clone)]
150pub struct SharedSchemaStrategy {
151 config: SharedSchemaConfig,
152 tenant_pattern: regex::Regex,
153}
154
155impl SharedSchemaStrategy {
156 pub fn new(config: SharedSchemaConfig) -> Result<Self, regex::Error> {
158 let tenant_pattern = regex::Regex::new(&config.tenant_id_pattern)?;
159 Ok(Self {
160 config,
161 tenant_pattern,
162 })
163 }
164
165 pub fn config(&self) -> &SharedSchemaConfig {
167 &self.config
168 }
169
170 pub fn tenant_column(&self) -> &str {
172 &self.config.tenant_column
173 }
174
175 pub fn uses_rls(&self) -> bool {
177 self.config.use_row_level_security
178 }
179
180 pub fn set_tenant_sql(&self, tenant_id: &TenantId) -> String {
184 format!(
185 "SET LOCAL app.current_tenant = '{}'",
186 self.escape_sql_string(tenant_id.as_str())
187 )
188 }
189
190 pub fn clear_tenant_sql(&self) -> String {
192 "RESET app.current_tenant".to_string()
193 }
194
195 pub fn tenant_filter_sql(&self, table_alias: Option<&str>) -> String {
197 match table_alias {
198 Some(alias) => format!("{}.{} = $tenant_id", alias, self.config.tenant_column),
199 None => format!("{} = $tenant_id", self.config.tenant_column),
200 }
201 }
202
203 fn escape_sql_string(&self, s: &str) -> String {
205 s.replace('\'', "''")
206 }
207
208 fn normalize_tenant_id(&self, tenant_id: &TenantId) -> String {
210 let id = tenant_id.as_str();
211
212 if self.config.hash_long_ids && id.len() > self.config.max_tenant_id_length {
213 use std::collections::hash_map::DefaultHasher;
215 use std::hash::{Hash, Hasher};
216
217 let mut hasher = DefaultHasher::new();
218 id.hash(&mut hasher);
219 format!("h_{:016x}", hasher.finish())
220 } else {
221 id.to_string()
222 }
223 }
224}
225
226impl TenantResolver for SharedSchemaStrategy {
227 fn resolve(&self, tenant_id: &TenantId) -> TenantResolution {
228 TenantResolution::SharedSchema {
229 tenant_id: self.normalize_tenant_id(tenant_id),
230 }
231 }
232
233 fn validate(&self, tenant_id: &TenantId) -> Result<(), TenantValidationError> {
234 let id = tenant_id.as_str();
235
236 if !self.config.hash_long_ids && id.len() > self.config.max_tenant_id_length {
238 return Err(TenantValidationError {
239 tenant_id: id.to_string(),
240 reason: format!(
241 "tenant ID exceeds maximum length of {} characters",
242 self.config.max_tenant_id_length
243 ),
244 });
245 }
246
247 if !self.tenant_pattern.is_match(id) {
249 return Err(TenantValidationError {
250 tenant_id: id.to_string(),
251 reason: format!(
252 "tenant ID does not match required pattern: {}",
253 self.config.tenant_id_pattern
254 ),
255 });
256 }
257
258 Ok(())
259 }
260
261 fn system_tenant(&self) -> TenantResolution {
262 TenantResolution::SharedSchema {
263 tenant_id: crate::tenant::SYSTEM_TENANT.to_string(),
264 }
265 }
266}
267
268#[derive(Debug)]
270#[allow(dead_code)]
271pub struct TenantAwareTableBuilder {
272 table_name: String,
273 tenant_column: String,
274 columns: Vec<ColumnDef>,
275 indexes: Vec<IndexDef>,
276 use_rls: bool,
277}
278
279#[derive(Debug)]
280#[allow(dead_code)]
281struct ColumnDef {
282 name: String,
283 data_type: String,
284 nullable: bool,
285}
286
287#[derive(Debug)]
288#[allow(dead_code)]
289struct IndexDef {
290 name: String,
291 columns: Vec<String>,
292 unique: bool,
293}
294
295#[allow(dead_code)]
296impl TenantAwareTableBuilder {
297 pub fn new(table_name: impl Into<String>, config: &SharedSchemaConfig) -> Self {
299 Self {
300 table_name: table_name.into(),
301 tenant_column: config.tenant_column.clone(),
302 columns: Vec::new(),
303 indexes: Vec::new(),
304 use_rls: config.use_row_level_security,
305 }
306 }
307
308 pub fn column(
310 mut self,
311 name: impl Into<String>,
312 data_type: impl Into<String>,
313 nullable: bool,
314 ) -> Self {
315 self.columns.push(ColumnDef {
316 name: name.into(),
317 data_type: data_type.into(),
318 nullable,
319 });
320 self
321 }
322
323 pub fn index(mut self, name: impl Into<String>, columns: Vec<&str>, unique: bool) -> Self {
325 self.indexes.push(IndexDef {
326 name: name.into(),
327 columns: columns.into_iter().map(String::from).collect(),
328 unique,
329 });
330 self
331 }
332
333 pub fn to_postgres_ddl(&self) -> String {
335 let mut ddl = String::new();
336
337 ddl.push_str(&format!(
339 "CREATE TABLE IF NOT EXISTS {} (\n",
340 self.table_name
341 ));
342 ddl.push_str(&format!(
343 " {} VARCHAR(64) NOT NULL,\n",
344 self.tenant_column
345 ));
346
347 for col in &self.columns {
348 let null_str = if col.nullable { "" } else { " NOT NULL" };
349 ddl.push_str(&format!(
350 " {} {}{},\n",
351 col.name, col.data_type, null_str
352 ));
353 }
354
355 ddl.truncate(ddl.len() - 2);
357 ddl.push_str("\n);\n\n");
358
359 for idx in &self.indexes {
361 let unique_str = if idx.unique { "UNIQUE " } else { "" };
362 let columns: Vec<_> = std::iter::once(self.tenant_column.as_str())
363 .chain(idx.columns.iter().map(|s| s.as_str()))
364 .collect();
365 ddl.push_str(&format!(
366 "CREATE {}INDEX IF NOT EXISTS {} ON {} ({});\n",
367 unique_str,
368 idx.name,
369 self.table_name,
370 columns.join(", ")
371 ));
372 }
373
374 if self.use_rls {
376 ddl.push_str(&format!(
377 "\nALTER TABLE {} ENABLE ROW LEVEL SECURITY;\n",
378 self.table_name
379 ));
380 ddl.push_str(&format!(
381 "CREATE POLICY tenant_isolation ON {} USING ({} = current_setting('app.current_tenant'));\n",
382 self.table_name, self.tenant_column
383 ));
384 }
385
386 ddl
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Strategy built from the default configuration, shared by most tests.
    fn default_strategy() -> SharedSchemaStrategy {
        SharedSchemaStrategy::new(SharedSchemaConfig::default()).unwrap()
    }

    /// Extracts the tenant ID from a resolution, failing the test when the
    /// resolution is not the shared-schema variant.
    fn shared_schema_tenant_id(resolution: TenantResolution) -> String {
        match resolution {
            TenantResolution::SharedSchema { tenant_id } => tenant_id,
            _ => panic!("expected SharedSchema resolution"),
        }
    }

    #[test]
    fn test_shared_schema_config_default() {
        let config = SharedSchemaConfig::default();
        assert_eq!(config.tenant_column, "tenant_id");
        assert!(!config.use_row_level_security);
        assert!(config.index_tenant_first);
    }

    #[test]
    fn test_shared_schema_config_builder() {
        let config = SharedSchemaConfig::new()
            .with_rls()
            .with_tenant_column("org_id");

        assert!(config.use_row_level_security);
        assert_eq!(config.tenant_column, "org_id");
    }

    #[test]
    fn test_shared_schema_strategy_creation() {
        let strategy = default_strategy();
        assert_eq!(strategy.tenant_column(), "tenant_id");
    }

    #[test]
    fn test_tenant_resolution() {
        let resolution = default_strategy().resolve(&TenantId::new("acme"));
        assert_eq!(shared_schema_tenant_id(resolution), "acme");
    }

    #[test]
    fn test_tenant_validation_valid() {
        let strategy = default_strategy();
        for id in ["acme", "acme/research", "tenant_123"] {
            assert!(strategy.validate(&TenantId::new(id)).is_ok());
        }
    }

    #[test]
    fn test_tenant_validation_invalid_pattern() {
        let outcome = default_strategy().validate(&TenantId::new("tenant with spaces"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tenant_validation_too_long() {
        let config = SharedSchemaConfig {
            max_tenant_id_length: 10,
            ..Default::default()
        };
        let strategy = SharedSchemaStrategy::new(config).unwrap();
        let outcome = strategy.validate(&TenantId::new("this-is-a-very-long-tenant-id"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_set_tenant_sql() {
        let sql = default_strategy().set_tenant_sql(&TenantId::new("acme"));
        assert_eq!(sql, "SET LOCAL app.current_tenant = 'acme'");
    }

    #[test]
    fn test_set_tenant_sql_escapes() {
        let sql = default_strategy().set_tenant_sql(&TenantId::new("o'brien"));
        assert_eq!(sql, "SET LOCAL app.current_tenant = 'o''brien'");
    }

    #[test]
    fn test_tenant_filter_sql() {
        let strategy = default_strategy();

        assert_eq!(strategy.tenant_filter_sql(None), "tenant_id = $tenant_id");
        assert_eq!(
            strategy.tenant_filter_sql(Some("p")),
            "p.tenant_id = $tenant_id"
        );
    }

    #[test]
    fn test_table_builder() {
        let config = SharedSchemaConfig::default();
        let ddl = TenantAwareTableBuilder::new("patient", &config)
            .column("id", "VARCHAR(64)", false)
            .column("family_name", "TEXT", true)
            .index("idx_patient_id", vec!["id"], true)
            .to_postgres_ddl();

        let expected_fragments = [
            "CREATE TABLE IF NOT EXISTS patient",
            "tenant_id VARCHAR(64) NOT NULL",
            "id VARCHAR(64) NOT NULL",
            "CREATE UNIQUE INDEX",
            "(tenant_id, id)",
        ];
        for fragment in expected_fragments {
            assert!(ddl.contains(fragment));
        }
    }

    #[test]
    fn test_table_builder_with_rls() {
        let config = SharedSchemaConfig::new().with_rls();
        let ddl = TenantAwareTableBuilder::new("patient", &config)
            .column("id", "VARCHAR(64)", false)
            .to_postgres_ddl();

        assert!(ddl.contains("ENABLE ROW LEVEL SECURITY"));
        assert!(ddl.contains("CREATE POLICY tenant_isolation"));
    }

    #[test]
    fn test_system_tenant_resolution() {
        let resolution = default_strategy().system_tenant();
        assert_eq!(
            shared_schema_tenant_id(resolution),
            crate::tenant::SYSTEM_TENANT
        );
    }
}