helios_persistence/strategy/
database_per_tenant.rs1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13
14use crate::tenant::TenantId;
15
16use super::{TenantResolution, TenantResolver, TenantValidationError};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct DatabasePerTenantConfig {
34 #[serde(default = "default_connection_template")]
44 pub connection_template: String,
45
46 #[serde(default = "default_true")]
51 pub pool_per_tenant: bool,
52
53 #[serde(default)]
58 pub max_pools: Option<usize>,
59
60 #[serde(default = "default_min_connections")]
62 pub min_connections_per_pool: u32,
63
64 #[serde(default = "default_max_connections")]
66 pub max_connections_per_pool: u32,
67
68 #[serde(default = "default_idle_timeout")]
70 pub idle_timeout_secs: u64,
71
72 #[serde(default = "default_max_tenant_id_length")]
74 pub max_tenant_id_length: usize,
75
76 #[serde(default = "default_tenant_id_pattern")]
78 pub tenant_id_pattern: String,
79
80 #[serde(default)]
82 pub auto_create_database: bool,
83
84 #[serde(default = "default_database_prefix")]
86 pub database_prefix: String,
87
88 #[serde(default)]
90 pub database_suffix: String,
91
92 #[serde(default = "default_host")]
94 pub default_host: String,
95
96 #[serde(default = "default_port")]
98 pub default_port: u16,
99
100 #[serde(default = "default_system_database")]
102 pub system_database: String,
103}
104
/// Default connection-string template.
///
/// Targets `{database}`, which the strategy expands to its `database_name`
/// (prefix + sanitized tenant ID + suffix). That is the same database the
/// `CREATE DATABASE` / `DROP DATABASE` helpers operate on and the name `resolve`
/// reports. For the admin connection, `{database}` expands to `system_database`.
fn default_connection_template() -> String {
    "postgres://{user}:{password}@{host}:{port}/{database}".to_string()
}
108
/// Serde default for flags that are enabled unless configured otherwise.
fn default_true() -> bool {
    true
}

/// Default lower bound on connections per tenant pool.
fn default_min_connections() -> u32 {
    1_u32
}

/// Default upper bound on connections per tenant pool.
fn default_max_connections() -> u32 {
    10_u32
}

/// Default idle timeout, in seconds.
fn default_idle_timeout() -> u64 {
    // Five minutes.
    5 * 60
}

/// Default byte length above which tenant IDs are hashed.
fn default_max_tenant_id_length() -> usize {
    32_usize
}

/// Default tenant-ID pattern: a letter followed by letters, digits or underscores.
fn default_tenant_id_pattern() -> String {
    String::from(r"^[a-zA-Z][a-zA-Z0-9_]*$")
}

/// Default prefix prepended to tenant database names.
fn default_database_prefix() -> String {
    String::from("tenant_")
}

/// Default database host.
fn default_host() -> String {
    "localhost".to_owned()
}

/// Default database port.
fn default_port() -> u16 {
    5432_u16
}

/// Default name of the system (non-tenant) database.
fn default_system_database() -> String {
    String::from("helios_system")
}
148
149impl Default for DatabasePerTenantConfig {
150 fn default() -> Self {
151 Self {
152 connection_template: default_connection_template(),
153 pool_per_tenant: true,
154 max_pools: Some(100),
155 min_connections_per_pool: default_min_connections(),
156 max_connections_per_pool: default_max_connections(),
157 idle_timeout_secs: default_idle_timeout(),
158 max_tenant_id_length: default_max_tenant_id_length(),
159 tenant_id_pattern: default_tenant_id_pattern(),
160 auto_create_database: false,
161 database_prefix: default_database_prefix(),
162 database_suffix: String::new(),
163 default_host: default_host(),
164 default_port: default_port(),
165 system_database: default_system_database(),
166 }
167 }
168}
169
170impl DatabasePerTenantConfig {
171 pub fn new() -> Self {
173 Self::default()
174 }
175
176 pub fn with_connection_template(mut self, template: impl Into<String>) -> Self {
178 self.connection_template = template.into();
179 self
180 }
181
182 pub fn with_auto_create(mut self) -> Self {
184 self.auto_create_database = true;
185 self
186 }
187
188 pub fn with_max_pools(mut self, max: usize) -> Self {
190 self.max_pools = Some(max);
191 self
192 }
193
194 pub fn with_database_prefix(mut self, prefix: impl Into<String>) -> Self {
196 self.database_prefix = prefix.into();
197 self
198 }
199
200 pub fn without_pool_per_tenant(mut self) -> Self {
202 self.pool_per_tenant = false;
203 self
204 }
205}
206
207#[derive(Debug)]
245pub struct DatabasePerTenantStrategy {
246 config: DatabasePerTenantConfig,
247 tenant_pattern: regex::Regex,
248 pool_access_times: Arc<RwLock<HashMap<String, Instant>>>,
250}
251
252impl Clone for DatabasePerTenantStrategy {
253 fn clone(&self) -> Self {
254 Self {
255 config: self.config.clone(),
256 tenant_pattern: regex::Regex::new(&self.config.tenant_id_pattern)
257 .expect("pattern was valid in original"),
258 pool_access_times: Arc::clone(&self.pool_access_times),
259 }
260 }
261}
262
263impl DatabasePerTenantStrategy {
264 pub fn new(config: DatabasePerTenantConfig) -> Result<Self, regex::Error> {
266 let tenant_pattern = regex::Regex::new(&config.tenant_id_pattern)?;
267 Ok(Self {
268 config,
269 tenant_pattern,
270 pool_access_times: Arc::new(RwLock::new(HashMap::new())),
271 })
272 }
273
274 pub fn config(&self) -> &DatabasePerTenantConfig {
276 &self.config
277 }
278
279 pub fn database_name(&self, tenant_id: &TenantId) -> String {
281 let sanitized = self.sanitize_tenant_id(tenant_id);
282 format!(
283 "{}{}{}",
284 self.config.database_prefix, sanitized, self.config.database_suffix
285 )
286 }
287
288 pub fn connection_string(&self, tenant_id: &TenantId, user: &str, password: &str) -> String {
290 self.connection_string_with_host(tenant_id, user, password, None)
291 }
292
293 pub fn connection_string_with_host(
295 &self,
296 tenant_id: &TenantId,
297 user: &str,
298 password: &str,
299 host: Option<&str>,
300 ) -> String {
301 let sanitized = self.sanitize_tenant_id(tenant_id);
302 let db_name = self.database_name(tenant_id);
303 let host = host.unwrap_or(&self.config.default_host);
304
305 self.config
306 .connection_template
307 .replace("{tenant}", &sanitized)
308 .replace("{tenant_hash}", &self.hash_tenant_id(tenant_id))
309 .replace("{host}", host)
310 .replace("{port}", &self.config.default_port.to_string())
311 .replace("{user}", user)
312 .replace("{password}", password)
313 .replace("{database}", &db_name)
314 }
315
316 pub fn create_database_sql(&self, tenant_id: &TenantId) -> String {
318 let db_name = self.database_name(tenant_id);
319 format!(
320 "CREATE DATABASE {} WITH ENCODING 'UTF8'",
321 self.quote_identifier(&db_name)
322 )
323 }
324
325 pub fn drop_database_sql(&self, tenant_id: &TenantId) -> String {
327 let db_name = self.database_name(tenant_id);
328 format!(
329 "DROP DATABASE IF EXISTS {}",
330 self.quote_identifier(&db_name)
331 )
332 }
333
334 pub fn database_exists_sql(&self, tenant_id: &TenantId) -> String {
336 let db_name = self.database_name(tenant_id);
337 format!(
338 "SELECT 1 FROM pg_database WHERE datname = '{}'",
339 self.escape_sql_string(&db_name)
340 )
341 }
342
343 pub fn record_pool_access(&self, tenant_id: &TenantId) {
345 let mut times = self.pool_access_times.write();
346 times.insert(tenant_id.as_str().to_string(), Instant::now());
347 }
348
349 pub fn tenants_to_evict(&self) -> Vec<String> {
351 let times = self.pool_access_times.read();
352 let max_pools = self.config.max_pools.unwrap_or(usize::MAX);
353
354 if times.len() <= max_pools {
355 return Vec::new();
356 }
357
358 let mut entries: Vec<_> = times.iter().collect();
359 entries.sort_by_key(|(_, time)| *time);
360
361 let to_evict = times.len() - max_pools;
362 entries
363 .into_iter()
364 .take(to_evict)
365 .map(|(id, _)| id.clone())
366 .collect()
367 }
368
369 pub fn remove_pool_tracking(&self, tenant_id: &str) {
371 let mut times = self.pool_access_times.write();
372 times.remove(tenant_id);
373 }
374
375 pub fn idle_tenants(&self) -> Vec<String> {
377 let times = self.pool_access_times.read();
378 let timeout = Duration::from_secs(self.config.idle_timeout_secs);
379 let now = Instant::now();
380
381 times
382 .iter()
383 .filter(|(_, last_access)| now.duration_since(**last_access) > timeout)
384 .map(|(id, _)| id.clone())
385 .collect()
386 }
387
388 fn sanitize_tenant_id(&self, tenant_id: &TenantId) -> String {
390 let id = tenant_id.as_str();
391
392 let sanitized = id.replace(['/', '-'], "_");
394
395 if sanitized.len() > self.config.max_tenant_id_length {
397 self.hash_tenant_id(tenant_id)
399 } else {
400 sanitized.to_lowercase()
401 }
402 }
403
404 fn hash_tenant_id(&self, tenant_id: &TenantId) -> String {
406 use std::collections::hash_map::DefaultHasher;
407 use std::hash::{Hash, Hasher};
408
409 let mut hasher = DefaultHasher::new();
410 tenant_id.as_str().hash(&mut hasher);
411 format!("t_{:016x}", hasher.finish())
412 }
413
414 fn quote_identifier(&self, id: &str) -> String {
416 format!("\"{}\"", id.replace('"', "\"\""))
417 }
418
419 fn escape_sql_string(&self, s: &str) -> String {
421 s.replace('\'', "''")
422 }
423}
424
425impl TenantResolver for DatabasePerTenantStrategy {
426 fn resolve(&self, tenant_id: &TenantId) -> TenantResolution {
427 self.record_pool_access(tenant_id);
428 TenantResolution::Database {
429 connection: self.database_name(tenant_id),
430 }
431 }
432
433 fn validate(&self, tenant_id: &TenantId) -> Result<(), TenantValidationError> {
434 let id = tenant_id.as_str();
435
436 let base_name = id.split('/').next().unwrap_or(id);
439
440 if !self.tenant_pattern.is_match(base_name) {
441 return Err(TenantValidationError {
442 tenant_id: id.to_string(),
443 reason: format!(
444 "tenant ID does not match required pattern for database names: {}",
445 self.config.tenant_id_pattern
446 ),
447 });
448 }
449
450 let sanitized = self.sanitize_tenant_id(tenant_id);
452 if sanitized.len() > 63 {
453 return Err(TenantValidationError {
455 tenant_id: id.to_string(),
456 reason: "sanitized tenant ID would exceed database name length limit (63 chars)"
457 .to_string(),
458 });
459 }
460
461 Ok(())
462 }
463
464 fn system_tenant(&self) -> TenantResolution {
465 TenantResolution::Database {
466 connection: self.config.system_database.clone(),
467 }
468 }
469}
470
471#[derive(Debug)]
475#[allow(dead_code)]
476pub struct TenantDatabaseManager {
477 strategy: DatabasePerTenantStrategy,
478 admin_user: String,
479 admin_password: String,
480}
481
482#[allow(dead_code)]
483impl TenantDatabaseManager {
484 pub fn new(
486 strategy: DatabasePerTenantStrategy,
487 admin_user: impl Into<String>,
488 admin_password: impl Into<String>,
489 ) -> Self {
490 Self {
491 strategy,
492 admin_user: admin_user.into(),
493 admin_password: admin_password.into(),
494 }
495 }
496
497 pub fn admin_connection_string(&self) -> String {
499 self.strategy
500 .config
501 .connection_template
502 .replace("{tenant}", "system")
503 .replace("{tenant_hash}", "system")
504 .replace("{host}", &self.strategy.config.default_host)
505 .replace("{port}", &self.strategy.config.default_port.to_string())
506 .replace("{user}", &self.admin_user)
507 .replace("{password}", &self.admin_password)
508 .replace("{database}", &self.strategy.config.system_database)
509 }
510
511 pub fn provision_tenant_sql(&self, tenant_id: &TenantId) -> Vec<String> {
513 let db_name = self.strategy.database_name(tenant_id);
514 let quoted_db = self.strategy.quote_identifier(&db_name);
515
516 vec![
517 format!("CREATE DATABASE {} WITH ENCODING 'UTF8'", quoted_db),
519 ]
521 }
522
523 pub fn deprovision_tenant_sql(&self, tenant_id: &TenantId) -> Vec<String> {
525 let db_name = self.strategy.database_name(tenant_id);
526 let quoted_db = self.strategy.quote_identifier(&db_name);
527
528 vec![
529 format!(
531 "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}'",
532 self.strategy.escape_sql_string(&db_name)
533 ),
534 format!("DROP DATABASE IF EXISTS {}", quoted_db),
536 ]
537 }
538
539 pub fn list_tenant_databases_sql(&self) -> String {
541 let prefix = &self.strategy.config.database_prefix;
542 format!(
543 "SELECT datname FROM pg_database WHERE datname LIKE '{}%' ORDER BY datname",
544 self.strategy.escape_sql_string(prefix)
545 )
546 }
547
548 pub fn database_stats_sql(&self, tenant_id: &TenantId) -> String {
550 let db_name = self.strategy.database_name(tenant_id);
551 format!(
552 r#"
553 SELECT
554 pg_database_size('{}') as size_bytes,
555 (SELECT count(*) FROM pg_stat_activity WHERE datname = '{}') as active_connections
556 "#,
557 self.strategy.escape_sql_string(&db_name),
558 self.strategy.escape_sql_string(&db_name)
559 )
560 }
561}
562
/// Snapshot of a tenant database's identity and usage statistics.
///
/// The optional fields stay `None` until filled in (for example from the
/// result of the stats query) — presumably by code outside this file; confirm.
// NOTE(review): `dead_code` is allowed presumably because nothing in the crate
// constructs this type yet — confirm before removing.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[allow(dead_code)]
pub struct TenantDatabaseInfo {
    /// Tenant identifier (presumably the raw `TenantId` string).
    pub tenant_id: String,
    /// Name of the tenant's database.
    pub database_name: String,
    /// Database size in bytes, if known.
    pub size_bytes: Option<u64>,
    /// Number of active connections to the database, if known.
    pub active_connections: Option<u32>,
    /// Time of the last recorded pool access, if tracked.
    pub last_access: Option<Instant>,
}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Strategy built from the default configuration.
    fn default_strategy() -> DatabasePerTenantStrategy {
        DatabasePerTenantStrategy::new(DatabasePerTenantConfig::default()).unwrap()
    }

    #[test]
    fn test_database_per_tenant_config_default() {
        let config = DatabasePerTenantConfig::default();
        assert!(config.pool_per_tenant);
        assert_eq!(config.max_pools, Some(100));
        assert_eq!(config.database_prefix, "tenant_");
    }

    #[test]
    fn test_database_per_tenant_config_builder() {
        let config = DatabasePerTenantConfig::new()
            .with_max_pools(50)
            .with_database_prefix("db_")
            .with_auto_create();

        assert_eq!(config.max_pools, Some(50));
        assert_eq!(config.database_prefix, "db_");
        assert!(config.auto_create_database);
    }

    #[test]
    fn test_database_per_tenant_strategy_creation() {
        let strategy = default_strategy();
        assert_eq!(strategy.config().database_prefix, "tenant_");
    }

    #[test]
    fn test_invalid_tenant_pattern_is_rejected() {
        let config = DatabasePerTenantConfig {
            tenant_id_pattern: "(".to_string(),
            ..Default::default()
        };
        assert!(DatabasePerTenantStrategy::new(config).is_err());
    }

    #[test]
    fn test_database_name_generation() {
        let strategy = default_strategy();

        let db_name = strategy.database_name(&TenantId::new("acme"));
        assert_eq!(db_name, "tenant_acme");

        // Hierarchical IDs flatten `/` into `_`.
        let db_name = strategy.database_name(&TenantId::new("acme/research"));
        assert_eq!(db_name, "tenant_acme_research");
    }

    #[test]
    fn test_tenant_resolution() {
        let strategy = default_strategy();
        let resolution = strategy.resolve(&TenantId::new("acme"));

        match resolution {
            TenantResolution::Database { connection } => {
                assert_eq!(connection, "tenant_acme");
            }
            _ => panic!("expected Database resolution"),
        }
    }

    #[test]
    fn test_tenant_validation_valid() {
        let strategy = default_strategy();
        assert!(strategy.validate(&TenantId::new("acme")).is_ok());
        assert!(strategy.validate(&TenantId::new("Acme123")).is_ok());
        assert!(strategy.validate(&TenantId::new("tenant_one")).is_ok());
    }

    #[test]
    fn test_tenant_validation_invalid_pattern() {
        let strategy = default_strategy();
        // Must start with a letter.
        let result = strategy.validate(&TenantId::new("123acme"));
        assert!(result.is_err());
    }

    #[test]
    fn test_connection_string_generation() {
        let config = DatabasePerTenantConfig {
            connection_template: "postgres://{user}:{password}@{host}:{port}/{tenant}_db"
                .to_string(),
            default_host: "db.example.com".to_string(),
            default_port: 5432,
            ..Default::default()
        };
        let strategy = DatabasePerTenantStrategy::new(config).unwrap();

        let conn = strategy.connection_string(&TenantId::new("acme"), "admin", "secret");
        assert!(conn.contains("admin:secret"));
        assert!(conn.contains("db.example.com:5432"));
        assert!(conn.contains("acme_db"));
    }

    #[test]
    fn test_create_database_sql() {
        let sql = default_strategy().create_database_sql(&TenantId::new("acme"));
        assert!(sql.contains("CREATE DATABASE"));
        assert!(sql.contains("tenant_acme"));
    }

    #[test]
    fn test_drop_database_sql() {
        let sql = default_strategy().drop_database_sql(&TenantId::new("acme"));
        assert!(sql.contains("DROP DATABASE IF EXISTS"));
        assert!(sql.contains("tenant_acme"));
    }

    #[test]
    fn test_database_exists_sql() {
        let sql = default_strategy().database_exists_sql(&TenantId::new("acme"));
        assert_eq!(
            sql,
            "SELECT 1 FROM pg_database WHERE datname = 'tenant_acme'"
        );
    }

    #[test]
    fn test_sql_quoting_helpers() {
        let strategy = default_strategy();
        assert_eq!(strategy.quote_identifier("a\"b"), "\"a\"\"b\"");
        assert_eq!(strategy.escape_sql_string("o'brien"), "o''brien");
    }

    #[test]
    fn test_system_tenant_resolution() {
        let resolution = default_strategy().system_tenant();

        match resolution {
            TenantResolution::Database { connection } => {
                assert_eq!(connection, "helios_system");
            }
            _ => panic!("expected Database resolution"),
        }
    }

    #[test]
    fn test_pool_access_tracking() {
        let strategy = default_strategy();

        strategy.record_pool_access(&TenantId::new("tenant1"));
        strategy.record_pool_access(&TenantId::new("tenant2"));

        let times = strategy.pool_access_times.read();
        assert_eq!(times.len(), 2);
        assert!(times.contains_key("tenant1"));
        assert!(times.contains_key("tenant2"));
    }

    #[test]
    fn test_clone_shares_pool_tracking() {
        let strategy = default_strategy();
        let cloned = strategy.clone();

        cloned.record_pool_access(&TenantId::new("shared"));
        assert!(strategy.pool_access_times.read().contains_key("shared"));
    }

    #[test]
    fn test_remove_pool_tracking() {
        let strategy = default_strategy();

        strategy.record_pool_access(&TenantId::new("tenant1"));
        strategy.remove_pool_tracking("tenant1");
        assert!(!strategy.pool_access_times.read().contains_key("tenant1"));
    }

    #[test]
    fn test_idle_tenants_excludes_recent_access() {
        let strategy = default_strategy();

        strategy.record_pool_access(&TenantId::new("tenant1"));
        assert!(strategy.idle_tenants().is_empty());
    }

    #[test]
    fn test_tenants_to_evict_within_limit() {
        let strategy = default_strategy();

        strategy.record_pool_access(&TenantId::new("tenant1"));
        assert!(strategy.tenants_to_evict().is_empty());
    }

    #[test]
    fn test_tenants_to_evict() {
        let config = DatabasePerTenantConfig {
            max_pools: Some(2),
            ..Default::default()
        };
        let strategy = DatabasePerTenantStrategy::new(config).unwrap();

        // Sleeps guarantee strictly increasing access times.
        strategy.record_pool_access(&TenantId::new("tenant1"));
        std::thread::sleep(std::time::Duration::from_millis(10));
        strategy.record_pool_access(&TenantId::new("tenant2"));
        std::thread::sleep(std::time::Duration::from_millis(10));
        strategy.record_pool_access(&TenantId::new("tenant3"));

        let to_evict = strategy.tenants_to_evict();
        assert_eq!(to_evict.len(), 1);
        // The least recently used tenant goes first.
        assert_eq!(to_evict[0], "tenant1");
    }

    #[test]
    fn test_tenant_database_manager() {
        let manager = TenantDatabaseManager::new(default_strategy(), "admin", "password");

        let provision_sql = manager.provision_tenant_sql(&TenantId::new("newcorp"));
        assert!(!provision_sql.is_empty());
        assert!(provision_sql[0].contains("CREATE DATABASE"));

        let deprovision_sql = manager.deprovision_tenant_sql(&TenantId::new("oldcorp"));
        assert!(deprovision_sql.len() >= 2);
        assert!(deprovision_sql.iter().any(|s| s.contains("DROP DATABASE")));
    }

    #[test]
    fn test_long_tenant_id_hashing() {
        let config = DatabasePerTenantConfig {
            max_tenant_id_length: 10,
            ..Default::default()
        };
        let strategy = DatabasePerTenantStrategy::new(config).unwrap();

        let long_id = TenantId::new("this_is_a_very_long_tenant_identifier");
        let db_name = strategy.database_name(&long_id);

        // Over-long IDs are replaced by `t_<hash>`.
        assert!(db_name.starts_with("tenant_t_"));
        // PostgreSQL identifiers are limited to 63 bytes.
        assert!(db_name.len() <= 63);
    }
}