1use super::context::{TenantContext, TenantId};
4use crate::auth::AuthFramework;
5use crate::config::AuthConfig;
6use crate::errors::{AuthError, Result};
7use dashmap::DashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use tracing::{debug, error, info, warn};
11
12#[derive(Debug, Clone)]
14pub enum TenantRegistryError {
15 TenantNotFound(String),
17
18 TenantAlreadyExists(String),
20
21 InvalidConfiguration(String),
23
24 TenantInactive(String),
26
27 InternalError(String),
29}
30
31impl std::fmt::Display for TenantRegistryError {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Self::TenantNotFound(id) => write!(f, "Tenant not found: {}", id),
35 Self::TenantAlreadyExists(id) => write!(f, "Tenant already exists: {}", id),
36 Self::InvalidConfiguration(msg) => write!(f, "Invalid configuration: {}", msg),
37 Self::TenantInactive(id) => write!(f, "Tenant is inactive: {}", id),
38 Self::InternalError(msg) => write!(f, "Internal error: {}", msg),
39 }
40 }
41}
42
43impl std::error::Error for TenantRegistryError {}
44
45impl From<TenantRegistryError> for AuthError {
46 fn from(err: TenantRegistryError) -> Self {
47 AuthError::internal(err.to_string())
48 }
49}
50
51#[derive(Clone)]
57pub struct TenantRegistry {
58 default_config: Arc<RwLock<AuthConfig>>,
60
61 frameworks: Arc<DashMap<TenantId, Arc<RwLock<AuthFramework>>>>,
63
64 tenants: Arc<DashMap<TenantId, TenantContext>>,
66}
67
68impl TenantRegistry {
69 pub fn new(default_config: AuthConfig) -> Self {
71 Self {
72 default_config: Arc::new(RwLock::new(default_config)),
73 frameworks: Arc::new(DashMap::new()),
74 tenants: Arc::new(DashMap::new()),
75 }
76 }
77
78 pub async fn register_tenant(
80 &self,
81 tenant_context: TenantContext,
82 config: Option<AuthConfig>,
83 ) -> Result<Arc<RwLock<AuthFramework>>> {
84 if !tenant_context.active {
86 warn!(
87 "Attempted to register inactive tenant: {}",
88 tenant_context.id
89 );
90 return Err(AuthError::internal(
91 TenantRegistryError::TenantInactive(tenant_context.id.to_string()).to_string(),
92 ));
93 }
94
95 let tenant_id = tenant_context.id.clone();
96
97 if self.tenants.contains_key(&tenant_id) {
99 error!("Tenant already registered: {}", tenant_id);
100 return Err(AuthError::internal(
101 TenantRegistryError::TenantAlreadyExists(tenant_id.to_string()).to_string(),
102 ));
103 }
104
105 let mut auth_config = if let Some(cfg) = config {
107 cfg
108 } else {
109 self.default_config.read().await.clone()
110 };
111
112 auth_config.method_configs.insert(
115 "tenant_id".to_string(),
116 serde_json::json!(tenant_id.as_str()),
117 );
118
119 let mut framework = AuthFramework::new(auth_config);
121
122 if let Err(e) = framework.initialize().await {
124 error!(
125 "Failed to initialize AuthFramework for tenant {}: {}",
126 tenant_id, e
127 );
128 return Err(e);
129 }
130
131 let framework = Arc::new(RwLock::new(framework));
132
133 self.frameworks.insert(tenant_id.clone(), framework.clone());
135 self.tenants
136 .insert(tenant_id.clone(), tenant_context.clone());
137
138 info!(
139 "Tenant registered: {} ({})",
140 tenant_id, tenant_context.metadata.name
141 );
142
143 Ok(framework)
144 }
145
146 pub fn get_tenant_framework(&self, tenant_id: &TenantId) -> Result<Arc<RwLock<AuthFramework>>> {
148 let tenant_ref = self.tenants.get(tenant_id).ok_or_else(|| {
149 AuthError::internal(
150 TenantRegistryError::TenantNotFound(tenant_id.to_string()).to_string(),
151 )
152 })?;
153
154 if !tenant_ref.active {
156 debug!("Attempted to access inactive tenant: {}", tenant_id);
157 return Err(AuthError::internal(
158 TenantRegistryError::TenantInactive(tenant_id.to_string()).to_string(),
159 ));
160 }
161
162 self.frameworks
163 .get(tenant_id)
164 .map(|f| f.clone())
165 .ok_or_else(|| {
166 error!("Framework not found for tenant: {}", tenant_id);
167 AuthError::internal(
168 TenantRegistryError::InternalError(format!(
169 "Framework not found for tenant: {}",
170 tenant_id
171 ))
172 .to_string(),
173 )
174 })
175 }
176
177 pub fn get_tenant_context(&self, tenant_id: &TenantId) -> Result<TenantContext> {
179 self.tenants
180 .get(tenant_id)
181 .map(|t| t.clone())
182 .ok_or_else(|| {
183 AuthError::internal(
184 TenantRegistryError::TenantNotFound(tenant_id.to_string()).to_string(),
185 )
186 })
187 }
188
189 pub async fn deactivate_tenant(&self, tenant_id: &TenantId) -> Result<()> {
191 if let Some(mut tenant) = self.tenants.get_mut(tenant_id) {
192 tenant.deactivate();
193 info!("Tenant deactivated: {}", tenant_id);
194 Ok(())
195 } else {
196 Err(AuthError::internal(
197 TenantRegistryError::TenantNotFound(tenant_id.to_string()).to_string(),
198 ))
199 }
200 }
201
202 pub async fn activate_tenant(&self, tenant_id: &TenantId) -> Result<()> {
204 if let Some(mut tenant) = self.tenants.get_mut(tenant_id) {
205 tenant.activate();
206 info!("Tenant activated: {}", tenant_id);
207 Ok(())
208 } else {
209 Err(AuthError::internal(
210 TenantRegistryError::TenantNotFound(tenant_id.to_string()).to_string(),
211 ))
212 }
213 }
214
215 pub async fn remove_tenant(&self, tenant_id: &TenantId) -> Result<()> {
220 self.frameworks.remove(tenant_id);
221 self.tenants.remove(tenant_id);
222 info!("Tenant removed from registry: {}", tenant_id);
223 Ok(())
224 }
225
226 pub async fn list_tenant_ids(&self) -> Vec<TenantId> {
228 self.tenants.iter().map(|t| t.id.clone()).collect()
229 }
230
231 pub async fn list_active_tenants(&self) -> Vec<TenantId> {
233 self.tenants
234 .iter()
235 .filter(|t| t.active)
236 .map(|t| t.id.clone())
237 .collect()
238 }
239
240 pub async fn tenant_count(&self) -> usize {
242 self.tenants.len()
243 }
244
245 pub async fn active_tenant_count(&self) -> usize {
247 self.tenants.iter().filter(|t| t.active).count()
248 }
249
250 pub async fn set_default_config(&self, config: AuthConfig) {
252 let mut default = self.default_config.write().await;
253 *default = config;
254 }
255
256 pub async fn get_default_config(&self) -> AuthConfig {
258 self.default_config.read().await.clone()
259 }
260}
261
262#[cfg(test)]
263mod tests {
264 use super::*;
265
266 fn create_test_config() -> AuthConfig {
268 AuthConfig {
269 secret: Some("Xk9#mP3$vQ7!nR2@wL8&jT5*cY1%fB4^z6".to_string()),
271 ..AuthConfig::default()
272 }
273 }
274
275 #[tokio::test]
276 async fn test_register_and_get_tenant() {
277 let registry = TenantRegistry::new(create_test_config());
278 let context = TenantContext::with_name("test-tenant", "Test Tenant").unwrap();
279
280 let result = registry.register_tenant(context, None).await;
281 assert!(
282 result.is_ok(),
283 "Failed to register tenant: {:?}",
284 result.err()
285 );
286
287 let tenant_id = TenantId::new("test-tenant");
288 let framework = registry.get_tenant_framework(&tenant_id);
289 assert!(
290 framework.is_ok(),
291 "Failed to get tenant framework: {:?}",
292 framework.err()
293 );
294 }
295
296 #[tokio::test]
297 async fn test_duplicate_tenant_registration() {
298 let registry = TenantRegistry::new(create_test_config());
299 let context = TenantContext::with_name("test", "Test").unwrap();
300
301 let _ = registry.register_tenant(context.clone(), None).await;
302 let result = registry.register_tenant(context, None).await;
303
304 assert!(result.is_err());
305 }
306
307 #[tokio::test]
308 async fn test_tenant_activation_deactivation() {
309 let registry = TenantRegistry::new(create_test_config());
310 let context = TenantContext::with_name("test", "Test").unwrap();
311 let tenant_id = context.id.clone();
312
313 let result = registry.register_tenant(context, None).await;
314 assert!(
315 result.is_ok(),
316 "Failed to register tenant: {:?}",
317 result.err()
318 );
319
320 let deactivate_result = registry.deactivate_tenant(&tenant_id).await;
322 assert!(
323 deactivate_result.is_ok(),
324 "Failed to deactivate tenant: {:?}",
325 deactivate_result.err()
326 );
327
328 let result = registry.get_tenant_framework(&tenant_id);
330 assert!(
331 result.is_err(),
332 "Should not be able to access deactivated tenant"
333 );
334
335 let activate_result = registry.activate_tenant(&tenant_id).await;
337 assert!(
338 activate_result.is_ok(),
339 "Failed to activate tenant: {:?}",
340 activate_result.err()
341 );
342
343 let result = registry.get_tenant_framework(&tenant_id);
345 assert!(
346 result.is_ok(),
347 "Should be able to access reactivated tenant: {:?}",
348 result.err()
349 );
350 }
351
352 #[tokio::test]
353 async fn test_list_tenants() {
354 let registry = TenantRegistry::new(create_test_config());
355
356 let c1 = TenantContext::with_name("tenant1", "Tenant 1").unwrap();
357 let c2 = TenantContext::with_name("tenant2", "Tenant 2").unwrap();
358
359 let r1 = registry.register_tenant(c1, None).await;
360 let r2 = registry.register_tenant(c2, None).await;
361
362 assert!(r1.is_ok(), "Failed to register tenant1: {:?}", r1.err());
363 assert!(r2.is_ok(), "Failed to register tenant2: {:?}", r2.err());
364
365 let count = registry.tenant_count().await;
366 assert_eq!(count, 2, "Expected 2 tenants, got {}", count);
367
368 let ids = registry.list_tenant_ids().await;
369 assert_eq!(ids.len(), 2, "Expected 2 tenant IDs, got {}", ids.len());
370 }
371
372 #[tokio::test]
374 async fn test_tenant_creation_with_minimal_config() {
375 let registry = TenantRegistry::new(create_test_config());
376 let context = TenantContext::with_name("minimal", "Minimal Tenant").unwrap();
377
378 let result = registry.register_tenant(context, None).await;
380 match result {
381 Ok(_) => {
382 let ids = registry.list_tenant_ids().await;
383 assert!(ids.contains(&TenantId::new("minimal")));
384 }
385 Err(e) => {
386 panic!("Failed to register tenant with test config: {}", e);
387 }
388 }
389 }
390
391 #[tokio::test]
393 async fn test_concurrent_tenant_registration() {
394 let registry = std::sync::Arc::new(TenantRegistry::new(create_test_config()));
395 let mut handles = vec![];
396
397 for i in 0..5 {
398 let reg = registry.clone();
399 let handle = tokio::spawn(async move {
400 let id = format!("tenant-{}", i);
401 let context = TenantContext::with_name(id, format!("Tenant {}", i)).unwrap();
402 reg.register_tenant(context, None).await
403 });
404 handles.push(handle);
405 }
406
407 for handle in handles {
408 let result = handle.await.unwrap();
409 assert!(
410 result.is_ok(),
411 "Concurrent registration failed: {:?}",
412 result.err()
413 );
414 }
415
416 let count = registry.tenant_count().await;
417 assert_eq!(
418 count, 5,
419 "Expected 5 tenants after concurrent registration, got {}",
420 count
421 );
422 }
423
424 #[tokio::test]
426 async fn test_tenant_data_isolation() {
427 let registry = TenantRegistry::new(create_test_config());
428
429 let c1 = TenantContext::with_name("tenant-a", "Tenant A").unwrap();
430 let c2 = TenantContext::with_name("tenant-b", "Tenant B").unwrap();
431
432 let r1 = registry.register_tenant(c1, None).await;
433 let r2 = registry.register_tenant(c2, None).await;
434
435 assert!(r1.is_ok());
436 assert!(r2.is_ok());
437
438 let deactivate = registry.deactivate_tenant(&TenantId::new("tenant-a")).await;
440 assert!(deactivate.is_ok());
441
442 let result = registry.get_tenant_framework(&TenantId::new("tenant-b"));
444 assert!(result.is_ok(), "Tenant B should still be accessible");
445
446 let result = registry.get_tenant_framework(&TenantId::new("tenant-a"));
448 assert!(
449 result.is_err(),
450 "Tenant A should not be accessible when deactivated"
451 );
452 }
453}