1use async_trait::async_trait;
18use chrono::{DateTime, Utc};
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21use tokio::sync::RwLock;
22use uuid::Uuid;
23use validator::Validate;
24
25use crate::auth::AuthenticatedInfo;
26use crate::config::Config;
27use crate::identity::{IdentityProviderError, backend::IdentityBackend, types::*};
28use crate::keystone::ServiceState;
29use crate::plugin_manager::PluginManagerApi;
30use crate::resource::{ResourceApi, error::ResourceProviderError};
31
32pub struct IdentityService {
34 backend_driver: Arc<dyn IdentityBackend>,
35 caching: bool,
38 user_id_domain_id_cache: RwLock<HashMap<String, String>>,
42}
43
44impl IdentityService {
45 pub fn new<P: PluginManagerApi>(
46 config: &Config,
47 plugin_manager: &P,
48 ) -> Result<Self, IdentityProviderError> {
49 let backend_driver = plugin_manager
50 .get_identity_backend(config.identity.driver.clone())?
51 .clone();
52 Ok(Self {
53 backend_driver,
54 caching: config.identity.caching,
55 user_id_domain_id_cache: HashMap::new().into(),
56 })
57 }
58
59 pub fn from_driver<I: IdentityBackend + 'static>(driver: I) -> Self {
60 Self {
61 backend_driver: Arc::new(driver),
62 caching: false,
63 user_id_domain_id_cache: HashMap::new().into(),
64 }
65 }
66}
67
68#[async_trait]
69impl IdentityApi for IdentityService {
70 #[tracing::instrument(skip(self, state))]
71 async fn add_user_to_group<'a>(
72 &self,
73 state: &ServiceState,
74 user_id: &'a str,
75 group_id: &'a str,
76 ) -> Result<(), IdentityProviderError> {
77 self.backend_driver
78 .add_user_to_group(state, user_id, group_id)
79 .await
80 }
81
82 #[tracing::instrument(skip(self, state))]
83 async fn add_user_to_group_expiring<'a>(
84 &self,
85 state: &ServiceState,
86 user_id: &'a str,
87 group_id: &'a str,
88 idp_id: &'a str,
89 ) -> Result<(), IdentityProviderError> {
90 self.backend_driver
91 .add_user_to_group_expiring(state, user_id, group_id, idp_id)
92 .await
93 }
94
95 #[tracing::instrument(skip(self, state))]
96 async fn add_users_to_groups<'a>(
97 &self,
98 state: &ServiceState,
99 memberships: Vec<(&'a str, &'a str)>,
100 ) -> Result<(), IdentityProviderError> {
101 self.backend_driver
102 .add_users_to_groups(state, memberships)
103 .await
104 }
105
106 #[tracing::instrument(skip(self, state))]
107 async fn add_users_to_groups_expiring<'a>(
108 &self,
109 state: &ServiceState,
110 memberships: Vec<(&'a str, &'a str)>,
111 idp_id: &'a str,
112 ) -> Result<(), IdentityProviderError> {
113 self.backend_driver
114 .add_users_to_groups_expiring(state, memberships, idp_id)
115 .await
116 }
117
118 #[tracing::instrument(skip(self, state, auth))]
120 async fn authenticate_by_password(
121 &self,
122 state: &ServiceState,
123 auth: &UserPasswordAuthRequest,
124 ) -> Result<AuthenticatedInfo, IdentityProviderError> {
125 let mut auth = auth.clone();
126 if auth.id.is_none() {
127 if auth.name.is_none() {
128 return Err(IdentityProviderError::UserIdOrNameWithDomain);
129 }
130
131 if let Some(ref mut domain) = auth.domain {
132 if let Some(dname) = &domain.name {
133 let d = state
134 .provider
135 .get_resource_provider()
136 .find_domain_by_name(state, dname)
137 .await?
138 .ok_or(ResourceProviderError::DomainNotFound(dname.clone()))?;
139 domain.id = Some(d.id);
140 } else if domain.id.is_none() {
141 return Err(IdentityProviderError::UserIdOrNameWithDomain);
142 }
143 } else {
144 return Err(IdentityProviderError::UserIdOrNameWithDomain);
145 }
146 }
147
148 self.backend_driver
149 .authenticate_by_password(state, &auth)
150 .await
151 }
152
153 #[tracing::instrument(skip(self, state))]
155 async fn create_group(
156 &self,
157 state: &ServiceState,
158 group: GroupCreate,
159 ) -> Result<Group, IdentityProviderError> {
160 let mut res = group;
161 if res.id.is_none() {
162 res.id = Some(Uuid::new_v4().simple().to_string());
163 }
164 self.backend_driver.create_group(state, res).await
165 }
166
167 #[tracing::instrument(skip(self, state))]
169 async fn create_service_account(
170 &self,
171 state: &ServiceState,
172 sa: ServiceAccountCreate,
173 ) -> Result<ServiceAccount, IdentityProviderError> {
174 let mut mod_sa = sa;
175 if mod_sa.id.is_none() {
176 mod_sa.id = Some(Uuid::new_v4().simple().to_string());
177 }
178 if mod_sa.enabled.is_none() {
179 mod_sa.enabled = Some(true);
180 }
181 mod_sa.validate()?;
182 self.backend_driver
183 .create_service_account(state, mod_sa)
184 .await
185 }
186
187 #[tracing::instrument(skip(self, state))]
189 async fn create_user(
190 &self,
191 state: &ServiceState,
192 user: UserCreate,
193 ) -> Result<UserResponse, IdentityProviderError> {
194 let mut mod_user = user;
195 if mod_user.id.is_none() {
196 mod_user.id = Some(Uuid::new_v4().simple().to_string());
197 }
198 if mod_user.enabled.is_none() {
199 mod_user.enabled = Some(true);
200 }
201 mod_user.validate()?;
202 self.backend_driver.create_user(state, mod_user).await
203 }
204
205 #[tracing::instrument(skip(self, state))]
207 async fn delete_group<'a>(
208 &self,
209 state: &ServiceState,
210 group_id: &'a str,
211 ) -> Result<(), IdentityProviderError> {
212 self.backend_driver.delete_group(state, group_id).await
213 }
214
215 #[tracing::instrument(skip(self, state))]
217 async fn delete_user<'a>(
218 &self,
219 state: &ServiceState,
220 user_id: &'a str,
221 ) -> Result<(), IdentityProviderError> {
222 self.backend_driver.delete_user(state, user_id).await?;
223 if self.caching {
224 self.user_id_domain_id_cache.write().await.remove(user_id);
225 }
226 Ok(())
227 }
228
229 #[tracing::instrument(skip(self, state))]
231 async fn get_service_account<'a>(
232 &self,
233 state: &ServiceState,
234 user_id: &'a str,
235 ) -> Result<Option<ServiceAccount>, IdentityProviderError> {
236 self.backend_driver
237 .get_service_account(state, user_id)
238 .await
239 }
240
241 #[tracing::instrument(skip(self, state))]
243 async fn get_user<'a>(
244 &self,
245 state: &ServiceState,
246 user_id: &'a str,
247 ) -> Result<Option<UserResponse>, IdentityProviderError> {
248 let user = self.backend_driver.get_user(state, user_id).await?;
249 if self.caching
250 && let Some(user) = &user
251 {
252 self.user_id_domain_id_cache
253 .write()
254 .await
255 .insert(user_id.to_string(), user.domain_id.clone());
256 }
257 Ok(user)
258 }
259
260 async fn get_user_domain_id<'a>(
268 &self,
269 state: &ServiceState,
270 user_id: &'a str,
271 ) -> Result<String, IdentityProviderError> {
272 if self.caching {
273 if let Some(domain_id) = self.user_id_domain_id_cache.read().await.get(user_id) {
274 return Ok(domain_id.clone());
275 } else {
276 let domain_id = self
277 .backend_driver
278 .get_user_domain_id(state, user_id)
279 .await?;
280 self.user_id_domain_id_cache
281 .write()
282 .await
283 .insert(user_id.to_string(), domain_id.clone());
284 return Ok(domain_id);
285 }
286 } else {
287 Ok(self
288 .backend_driver
289 .get_user_domain_id(state, user_id)
290 .await?)
291 }
292 }
293
294 #[tracing::instrument(skip(self, state))]
296 async fn find_federated_user<'a>(
297 &self,
298 state: &ServiceState,
299 idp_id: &'a str,
300 unique_id: &'a str,
301 ) -> Result<Option<UserResponse>, IdentityProviderError> {
302 self.backend_driver
303 .find_federated_user(state, idp_id, unique_id)
304 .await
305 }
306
307 #[tracing::instrument(skip(self, state))]
309 async fn list_users(
310 &self,
311 state: &ServiceState,
312 params: &UserListParameters,
313 ) -> Result<Vec<UserResponse>, IdentityProviderError> {
314 self.backend_driver.list_users(state, params).await
315 }
316
317 #[tracing::instrument(skip(self, state))]
319 async fn list_groups(
320 &self,
321 state: &ServiceState,
322 params: &GroupListParameters,
323 ) -> Result<Vec<Group>, IdentityProviderError> {
324 self.backend_driver.list_groups(state, params).await
325 }
326
327 #[tracing::instrument(skip(self, state))]
329 async fn get_group<'a>(
330 &self,
331 state: &ServiceState,
332 group_id: &'a str,
333 ) -> Result<Option<Group>, IdentityProviderError> {
334 self.backend_driver.get_group(state, group_id).await
335 }
336
337 #[tracing::instrument(skip(self, state))]
339 async fn list_groups_of_user<'a>(
340 &self,
341 state: &ServiceState,
342 user_id: &'a str,
343 ) -> Result<Vec<Group>, IdentityProviderError> {
344 self.backend_driver
345 .list_groups_of_user(state, user_id)
346 .await
347 }
348
349 #[tracing::instrument(skip(self, state))]
350 async fn remove_user_from_group<'a>(
351 &self,
352 state: &ServiceState,
353 user_id: &'a str,
354 group_id: &'a str,
355 ) -> Result<(), IdentityProviderError> {
356 self.backend_driver
357 .remove_user_from_group(state, user_id, group_id)
358 .await
359 }
360
361 #[tracing::instrument(skip(self, state))]
362 async fn remove_user_from_group_expiring<'a>(
363 &self,
364 state: &ServiceState,
365 user_id: &'a str,
366 group_id: &'a str,
367 idp_id: &'a str,
368 ) -> Result<(), IdentityProviderError> {
369 self.backend_driver
370 .remove_user_from_group_expiring(state, user_id, group_id, idp_id)
371 .await
372 }
373
374 #[tracing::instrument(skip(self, state))]
375 async fn remove_user_from_groups<'a>(
376 &self,
377 state: &ServiceState,
378 user_id: &'a str,
379 group_ids: HashSet<&'a str>,
380 ) -> Result<(), IdentityProviderError> {
381 self.backend_driver
382 .remove_user_from_groups(state, user_id, group_ids)
383 .await
384 }
385
386 #[tracing::instrument(skip(self, state))]
387 async fn remove_user_from_groups_expiring<'a>(
388 &self,
389 state: &ServiceState,
390 user_id: &'a str,
391 group_ids: HashSet<&'a str>,
392 idp_id: &'a str,
393 ) -> Result<(), IdentityProviderError> {
394 self.backend_driver
395 .remove_user_from_groups_expiring(state, user_id, group_ids, idp_id)
396 .await
397 }
398
399 #[tracing::instrument(skip(self, state))]
400 async fn set_user_groups<'a>(
401 &self,
402 state: &ServiceState,
403 user_id: &'a str,
404 group_ids: HashSet<&'a str>,
405 ) -> Result<(), IdentityProviderError> {
406 self.backend_driver
407 .set_user_groups(state, user_id, group_ids)
408 .await
409 }
410
411 #[tracing::instrument(skip(self, state))]
412 async fn set_user_groups_expiring<'a>(
413 &self,
414 state: &ServiceState,
415 user_id: &'a str,
416 group_ids: HashSet<&'a str>,
417 idp_id: &'a str,
418 last_verified: Option<&'a DateTime<Utc>>,
419 ) -> Result<(), IdentityProviderError> {
420 self.backend_driver
421 .set_user_groups_expiring(state, user_id, group_ids, idp_id, last_verified)
422 .await
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::backend::MockIdentityBackend;
    use crate::identity::types::user::{UserCreateBuilder, UserResponseBuilder};
    use crate::tests::get_mocked_state;

    /// User returned by the mocked backend throughout these tests.
    fn sample_user() -> UserResponse {
        UserResponseBuilder::default()
            .id("id")
            .domain_id("domain_id")
            .enabled(true)
            .name("name")
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_create_user() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_create_user()
            .returning(|_, _| Ok(sample_user()));
        let provider = IdentityService::from_driver(backend);

        assert_eq!(
            provider
                .create_user(
                    &state,
                    UserCreateBuilder::default()
                        .name("uname")
                        .domain_id("did")
                        .build()
                        .unwrap()
                )
                .await
                .unwrap(),
            sample_user()
        );
    }

    #[tokio::test]
    async fn test_get_user() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_get_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .returning(|_, _| Ok(Some(sample_user())));
        let provider = IdentityService::from_driver(backend);

        assert_eq!(
            provider
                .get_user(&state, "uid")
                .await
                .unwrap()
                .expect("user should be there"),
            sample_user(),
        );
    }

    #[tokio::test]
    async fn test_get_user_populates_cache() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_get_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .returning(|_, _| Ok(Some(sample_user())));
        let mut provider = IdentityService::from_driver(backend);
        provider.caching = true;

        assert!(provider.get_user(&state, "uid").await.unwrap().is_some());
        assert_eq!(
            provider
                .user_id_domain_id_cache
                .read()
                .await
                .get("uid")
                .map(String::as_str),
            Some("domain_id"),
            "get_user records the user's domain in the cache"
        );
    }

    #[tokio::test]
    async fn test_get_user_domain_id() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_get_user_domain_id()
            .withf(|_, uid: &'_ str| uid == "uid")
            .times(2)
            .returning(|_, _| Ok("did".into()));
        backend
            .expect_get_user_domain_id()
            .withf(|_, uid: &'_ str| uid == "missing")
            .returning(|_, _| Err(IdentityProviderError::UserNotFound("missing".into())));
        let mut provider = IdentityService::from_driver(backend);
        provider.caching = true;

        assert_eq!(
            provider.get_user_domain_id(&state, "uid").await.unwrap(),
            "did"
        );
        assert_eq!(
            provider.get_user_domain_id(&state, "uid").await.unwrap(),
            "did",
            "second time data extracted from cache"
        );
        assert!(
            provider
                .get_user_domain_id(&state, "missing")
                .await
                .is_err()
        );
        provider.caching = false;
        assert_eq!(
            provider.get_user_domain_id(&state, "uid").await.unwrap(),
            "did",
            "third time backend is again triggered causing total of 2 invocations"
        );
    }

    #[tokio::test]
    async fn test_delete_user() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_delete_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .returning(|_, _| Ok(()));
        let provider = IdentityService::from_driver(backend);

        assert!(provider.delete_user(&state, "uid").await.is_ok());
    }

    #[tokio::test]
    async fn test_delete_user_evicts_cache() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_delete_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .returning(|_, _| Ok(()));
        let mut provider = IdentityService::from_driver(backend);
        provider.caching = true;
        provider
            .user_id_domain_id_cache
            .write()
            .await
            .insert("uid".into(), "did".into());

        assert!(provider.delete_user(&state, "uid").await.is_ok());
        assert!(
            !provider
                .user_id_domain_id_cache
                .read()
                .await
                .contains_key("uid"),
            "a deleted user must not linger in the cache"
        );
    }
}