Skip to main content

openstack_keystone_core/identity/
service.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5//     http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12//
13// SPDX-License-Identifier: Apache-2.0
14
15//! # Identity provider
16
17use async_trait::async_trait;
18use chrono::{DateTime, Utc};
19use std::collections::{HashMap, HashSet};
20use std::sync::Arc;
21use tokio::sync::RwLock;
22use uuid::Uuid;
23use validator::Validate;
24
25use crate::auth::AuthenticatedInfo;
26use crate::config::Config;
27use crate::identity::{IdentityProviderError, backend::IdentityBackend, types::*};
28use crate::keystone::ServiceState;
29use crate::plugin_manager::PluginManagerApi;
30use crate::resource::{ResourceApi, error::ResourceProviderError};
31
32/// Identity provider.
33pub struct IdentityService {
34    backend_driver: Arc<dyn IdentityBackend>,
35    /// Caching flag. When enabled certain data can be cached (i.e. `domain_id`
36    /// by `user_id`).
37    caching: bool,
38    /// Internal cache of `user_id` to `domain_id` mappings. This information if
39    /// fully static and can never change (well, except with a direct SQL
40    /// update).
41    user_id_domain_id_cache: RwLock<HashMap<String, String>>,
42}
43
44impl IdentityService {
45    pub fn new<P: PluginManagerApi>(
46        config: &Config,
47        plugin_manager: &P,
48    ) -> Result<Self, IdentityProviderError> {
49        let backend_driver = plugin_manager
50            .get_identity_backend(config.identity.driver.clone())?
51            .clone();
52        Ok(Self {
53            backend_driver,
54            caching: config.identity.caching,
55            user_id_domain_id_cache: HashMap::new().into(),
56        })
57    }
58
59    pub fn from_driver<I: IdentityBackend + 'static>(driver: I) -> Self {
60        Self {
61            backend_driver: Arc::new(driver),
62            caching: false,
63            user_id_domain_id_cache: HashMap::new().into(),
64        }
65    }
66}
67
68#[async_trait]
69impl IdentityApi for IdentityService {
70    #[tracing::instrument(skip(self, state))]
71    async fn add_user_to_group<'a>(
72        &self,
73        state: &ServiceState,
74        user_id: &'a str,
75        group_id: &'a str,
76    ) -> Result<(), IdentityProviderError> {
77        self.backend_driver
78            .add_user_to_group(state, user_id, group_id)
79            .await
80    }
81
82    #[tracing::instrument(skip(self, state))]
83    async fn add_user_to_group_expiring<'a>(
84        &self,
85        state: &ServiceState,
86        user_id: &'a str,
87        group_id: &'a str,
88        idp_id: &'a str,
89    ) -> Result<(), IdentityProviderError> {
90        self.backend_driver
91            .add_user_to_group_expiring(state, user_id, group_id, idp_id)
92            .await
93    }
94
95    #[tracing::instrument(skip(self, state))]
96    async fn add_users_to_groups<'a>(
97        &self,
98        state: &ServiceState,
99        memberships: Vec<(&'a str, &'a str)>,
100    ) -> Result<(), IdentityProviderError> {
101        self.backend_driver
102            .add_users_to_groups(state, memberships)
103            .await
104    }
105
106    #[tracing::instrument(skip(self, state))]
107    async fn add_users_to_groups_expiring<'a>(
108        &self,
109        state: &ServiceState,
110        memberships: Vec<(&'a str, &'a str)>,
111        idp_id: &'a str,
112    ) -> Result<(), IdentityProviderError> {
113        self.backend_driver
114            .add_users_to_groups_expiring(state, memberships, idp_id)
115            .await
116    }
117
118    /// Authenticate user with the password auth method.
119    #[tracing::instrument(skip(self, state, auth))]
120    async fn authenticate_by_password(
121        &self,
122        state: &ServiceState,
123        auth: &UserPasswordAuthRequest,
124    ) -> Result<AuthenticatedInfo, IdentityProviderError> {
125        let mut auth = auth.clone();
126        if auth.id.is_none() {
127            if auth.name.is_none() {
128                return Err(IdentityProviderError::UserIdOrNameWithDomain);
129            }
130
131            if let Some(ref mut domain) = auth.domain {
132                if let Some(dname) = &domain.name {
133                    let d = state
134                        .provider
135                        .get_resource_provider()
136                        .find_domain_by_name(state, dname)
137                        .await?
138                        .ok_or(ResourceProviderError::DomainNotFound(dname.clone()))?;
139                    domain.id = Some(d.id);
140                } else if domain.id.is_none() {
141                    return Err(IdentityProviderError::UserIdOrNameWithDomain);
142                }
143            } else {
144                return Err(IdentityProviderError::UserIdOrNameWithDomain);
145            }
146        }
147
148        self.backend_driver
149            .authenticate_by_password(state, &auth)
150            .await
151    }
152
153    /// Create group.
154    #[tracing::instrument(skip(self, state))]
155    async fn create_group(
156        &self,
157        state: &ServiceState,
158        group: GroupCreate,
159    ) -> Result<Group, IdentityProviderError> {
160        let mut res = group;
161        if res.id.is_none() {
162            res.id = Some(Uuid::new_v4().simple().to_string());
163        }
164        self.backend_driver.create_group(state, res).await
165    }
166
167    /// Create service account.
168    #[tracing::instrument(skip(self, state))]
169    async fn create_service_account(
170        &self,
171        state: &ServiceState,
172        sa: ServiceAccountCreate,
173    ) -> Result<ServiceAccount, IdentityProviderError> {
174        let mut mod_sa = sa;
175        if mod_sa.id.is_none() {
176            mod_sa.id = Some(Uuid::new_v4().simple().to_string());
177        }
178        if mod_sa.enabled.is_none() {
179            mod_sa.enabled = Some(true);
180        }
181        mod_sa.validate()?;
182        self.backend_driver
183            .create_service_account(state, mod_sa)
184            .await
185    }
186
187    /// Create user.
188    #[tracing::instrument(skip(self, state))]
189    async fn create_user(
190        &self,
191        state: &ServiceState,
192        user: UserCreate,
193    ) -> Result<UserResponse, IdentityProviderError> {
194        let mut mod_user = user;
195        if mod_user.id.is_none() {
196            mod_user.id = Some(Uuid::new_v4().simple().to_string());
197        }
198        if mod_user.enabled.is_none() {
199            mod_user.enabled = Some(true);
200        }
201        mod_user.validate()?;
202        self.backend_driver.create_user(state, mod_user).await
203    }
204
205    /// Delete group.
206    #[tracing::instrument(skip(self, state))]
207    async fn delete_group<'a>(
208        &self,
209        state: &ServiceState,
210        group_id: &'a str,
211    ) -> Result<(), IdentityProviderError> {
212        self.backend_driver.delete_group(state, group_id).await
213    }
214
215    /// Delete user.
216    #[tracing::instrument(skip(self, state))]
217    async fn delete_user<'a>(
218        &self,
219        state: &ServiceState,
220        user_id: &'a str,
221    ) -> Result<(), IdentityProviderError> {
222        self.backend_driver.delete_user(state, user_id).await?;
223        if self.caching {
224            self.user_id_domain_id_cache.write().await.remove(user_id);
225        }
226        Ok(())
227    }
228
229    /// Get a service account by ID.
230    #[tracing::instrument(skip(self, state))]
231    async fn get_service_account<'a>(
232        &self,
233        state: &ServiceState,
234        user_id: &'a str,
235    ) -> Result<Option<ServiceAccount>, IdentityProviderError> {
236        self.backend_driver
237            .get_service_account(state, user_id)
238            .await
239    }
240
241    /// Get single user.
242    #[tracing::instrument(skip(self, state))]
243    async fn get_user<'a>(
244        &self,
245        state: &ServiceState,
246        user_id: &'a str,
247    ) -> Result<Option<UserResponse>, IdentityProviderError> {
248        let user = self.backend_driver.get_user(state, user_id).await?;
249        if self.caching
250            && let Some(user) = &user
251        {
252            self.user_id_domain_id_cache
253                .write()
254                .await
255                .insert(user_id.to_string(), user.domain_id.clone());
256        }
257        Ok(user)
258    }
259
260    /// Get `domain_id` of a user.
261    ///
262    /// When the caching is enabled check for the cached value there. When no
263    /// data is present for the key - invoke the backend driver and place
264    /// the new value into the cache. Other operations (`get_user`,
265    /// `delete_user`) update the cache with `delete_user` purging the value
266    /// from the cache.
267    async fn get_user_domain_id<'a>(
268        &self,
269        state: &ServiceState,
270        user_id: &'a str,
271    ) -> Result<String, IdentityProviderError> {
272        if self.caching {
273            if let Some(domain_id) = self.user_id_domain_id_cache.read().await.get(user_id) {
274                return Ok(domain_id.clone());
275            } else {
276                let domain_id = self
277                    .backend_driver
278                    .get_user_domain_id(state, user_id)
279                    .await?;
280                self.user_id_domain_id_cache
281                    .write()
282                    .await
283                    .insert(user_id.to_string(), domain_id.clone());
284                return Ok(domain_id);
285            }
286        } else {
287            Ok(self
288                .backend_driver
289                .get_user_domain_id(state, user_id)
290                .await?)
291        }
292    }
293
294    /// Find federated user by `idp_id` and `unique_id`.
295    #[tracing::instrument(skip(self, state))]
296    async fn find_federated_user<'a>(
297        &self,
298        state: &ServiceState,
299        idp_id: &'a str,
300        unique_id: &'a str,
301    ) -> Result<Option<UserResponse>, IdentityProviderError> {
302        self.backend_driver
303            .find_federated_user(state, idp_id, unique_id)
304            .await
305    }
306
307    /// List users.
308    #[tracing::instrument(skip(self, state))]
309    async fn list_users(
310        &self,
311        state: &ServiceState,
312        params: &UserListParameters,
313    ) -> Result<Vec<UserResponse>, IdentityProviderError> {
314        self.backend_driver.list_users(state, params).await
315    }
316
317    /// List groups.
318    #[tracing::instrument(skip(self, state))]
319    async fn list_groups(
320        &self,
321        state: &ServiceState,
322        params: &GroupListParameters,
323    ) -> Result<Vec<Group>, IdentityProviderError> {
324        self.backend_driver.list_groups(state, params).await
325    }
326
327    /// Get single group.
328    #[tracing::instrument(skip(self, state))]
329    async fn get_group<'a>(
330        &self,
331        state: &ServiceState,
332        group_id: &'a str,
333    ) -> Result<Option<Group>, IdentityProviderError> {
334        self.backend_driver.get_group(state, group_id).await
335    }
336
337    /// List groups a user is a member of.
338    #[tracing::instrument(skip(self, state))]
339    async fn list_groups_of_user<'a>(
340        &self,
341        state: &ServiceState,
342        user_id: &'a str,
343    ) -> Result<Vec<Group>, IdentityProviderError> {
344        self.backend_driver
345            .list_groups_of_user(state, user_id)
346            .await
347    }
348
349    #[tracing::instrument(skip(self, state))]
350    async fn remove_user_from_group<'a>(
351        &self,
352        state: &ServiceState,
353        user_id: &'a str,
354        group_id: &'a str,
355    ) -> Result<(), IdentityProviderError> {
356        self.backend_driver
357            .remove_user_from_group(state, user_id, group_id)
358            .await
359    }
360
361    #[tracing::instrument(skip(self, state))]
362    async fn remove_user_from_group_expiring<'a>(
363        &self,
364        state: &ServiceState,
365        user_id: &'a str,
366        group_id: &'a str,
367        idp_id: &'a str,
368    ) -> Result<(), IdentityProviderError> {
369        self.backend_driver
370            .remove_user_from_group_expiring(state, user_id, group_id, idp_id)
371            .await
372    }
373
374    #[tracing::instrument(skip(self, state))]
375    async fn remove_user_from_groups<'a>(
376        &self,
377        state: &ServiceState,
378        user_id: &'a str,
379        group_ids: HashSet<&'a str>,
380    ) -> Result<(), IdentityProviderError> {
381        self.backend_driver
382            .remove_user_from_groups(state, user_id, group_ids)
383            .await
384    }
385
386    #[tracing::instrument(skip(self, state))]
387    async fn remove_user_from_groups_expiring<'a>(
388        &self,
389        state: &ServiceState,
390        user_id: &'a str,
391        group_ids: HashSet<&'a str>,
392        idp_id: &'a str,
393    ) -> Result<(), IdentityProviderError> {
394        self.backend_driver
395            .remove_user_from_groups_expiring(state, user_id, group_ids, idp_id)
396            .await
397    }
398
399    #[tracing::instrument(skip(self, state))]
400    async fn set_user_groups<'a>(
401        &self,
402        state: &ServiceState,
403        user_id: &'a str,
404        group_ids: HashSet<&'a str>,
405    ) -> Result<(), IdentityProviderError> {
406        self.backend_driver
407            .set_user_groups(state, user_id, group_ids)
408            .await
409    }
410
411    #[tracing::instrument(skip(self, state))]
412    async fn set_user_groups_expiring<'a>(
413        &self,
414        state: &ServiceState,
415        user_id: &'a str,
416        group_ids: HashSet<&'a str>,
417        idp_id: &'a str,
418        last_verified: Option<&'a DateTime<Utc>>,
419    ) -> Result<(), IdentityProviderError> {
420        self.backend_driver
421            .set_user_groups_expiring(state, user_id, group_ids, idp_id, last_verified)
422            .await
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::identity::backend::MockIdentityBackend;
    use crate::identity::types::user::{UserCreateBuilder, UserResponseBuilder};
    use crate::tests::get_mocked_state;

    #[tokio::test]
    async fn test_create_user() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend.expect_create_user().returning(|_, _| {
            Ok(UserResponseBuilder::default()
                .id("id")
                .domain_id("domain_id")
                .enabled(true)
                .name("name")
                .build()
                .unwrap())
        });
        let provider = IdentityService::from_driver(backend);

        assert_eq!(
            provider
                .create_user(
                    &state,
                    UserCreateBuilder::default()
                        .name("uname")
                        .domain_id("did")
                        .build()
                        .unwrap()
                )
                .await
                .unwrap(),
            UserResponseBuilder::default()
                .domain_id("domain_id")
                .enabled(true)
                .id("id")
                .name("name")
                .build()
                .unwrap()
        );
    }

    #[tokio::test]
    async fn test_get_user() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_get_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .returning(|_, _| {
                Ok(Some(
                    UserResponseBuilder::default()
                        .id("id")
                        .domain_id("domain_id")
                        .enabled(true)
                        .name("name")
                        .build()
                        .unwrap(),
                ))
            });
        let provider = IdentityService::from_driver(backend);

        assert_eq!(
            provider
                .get_user(&state, "uid")
                .await
                .unwrap()
                .expect("user should be there"),
            UserResponseBuilder::default()
                .domain_id("domain_id")
                .enabled(true)
                .id("id")
                .name("name")
                .build()
                .unwrap(),
        );
    }

    /// `get_user` must fill the `domain_id` cache so that a following
    /// `get_user_domain_id` does not reach the backend.
    #[tokio::test]
    async fn test_get_user_populates_domain_id_cache() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_get_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .times(1)
            .returning(|_, _| {
                Ok(Some(
                    UserResponseBuilder::default()
                        .id("uid")
                        .domain_id("did")
                        .enabled(true)
                        .name("name")
                        .build()
                        .unwrap(),
                ))
            });
        backend.expect_get_user_domain_id().never();
        let mut provider = IdentityService::from_driver(backend);
        provider.caching = true;

        assert!(provider.get_user(&state, "uid").await.unwrap().is_some());
        assert_eq!(
            provider.get_user_domain_id(&state, "uid").await.unwrap(),
            "did",
            "domain_id must be served from the cache filled by get_user"
        );
    }

    #[tokio::test]
    async fn test_get_user_domain_id() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_get_user_domain_id()
            .withf(|_, uid: &'_ str| uid == "uid")
            .times(2) // only 2 times
            .returning(|_, _| Ok("did".into()));
        backend
            .expect_get_user_domain_id()
            .withf(|_, uid: &'_ str| uid == "missing")
            .returning(|_, _| Err(IdentityProviderError::UserNotFound("missing".into())));
        let mut provider = IdentityService::from_driver(backend);
        provider.caching = true;

        assert_eq!(
            provider.get_user_domain_id(&state, "uid").await.unwrap(),
            "did"
        );
        assert_eq!(
            provider.get_user_domain_id(&state, "uid").await.unwrap(),
            "did",
            "second time data extracted from cache"
        );
        assert!(
            provider
                .get_user_domain_id(&state, "missing")
                .await
                .is_err()
        );
        provider.caching = false;
        assert_eq!(
            provider.get_user_domain_id(&state, "uid").await.unwrap(),
            "did",
            "third time backend is again triggered causing total of 2 invocations"
        );
    }

    #[tokio::test]
    async fn test_delete_user() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_delete_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .returning(|_, _| Ok(()));
        let provider = IdentityService::from_driver(backend);

        assert!(provider.delete_user(&state, "uid").await.is_ok());
    }

    /// A deleted user must not keep a stale `domain_id` entry in the cache.
    #[tokio::test]
    async fn test_delete_user_purges_domain_id_cache() {
        let state = get_mocked_state(None, None);
        let mut backend = MockIdentityBackend::default();
        backend
            .expect_delete_user()
            .withf(|_, uid: &'_ str| uid == "uid")
            .returning(|_, _| Ok(()));
        let mut provider = IdentityService::from_driver(backend);
        provider.caching = true;
        provider
            .user_id_domain_id_cache
            .write()
            .await
            .insert("uid".into(), "did".into());

        assert!(provider.delete_user(&state, "uid").await.is_ok());
        assert!(
            !provider
                .user_id_domain_id_cache
                .read()
                .await
                .contains_key("uid"),
            "deleted user must be purged from the cache"
        );
    }
}