Skip to main content

pharia_common/
iam.rs

//! **IAM** is short for **I**dentity and **A**ccess **M**anagement. This module contains
//! opinionated adapters to connect to the internal Pharia IAM solution.
3
4use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware};
8use serde::{Deserialize, Serialize};
9
/// Base URL of the Pharia IAM instance in the production environment.
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";

/// Base URL of the Pharia IAM instance in the staging environment.
pub const IAM_STAGE_URL: &str = "https://pharia-iam.stage.product.pharia.com";
15
16pub struct IamClientBuilder {
17    base_url: String,
18    client_builder: ClientBuilder,
19}
20
21impl IamClientBuilder {
22    /// Create a builder for the IAM client. Use this builder if you want to attach middleware, for
23    /// testing, opentelemetry or any other purpose.
24    pub fn new(base_url: String) -> Self {
25        let client = Client::builder().use_rustls_tls().build().expect(
26            "Must be able to initialize TLS backend and resolver must be able to load system \
27            configuration.",
28        );
29
30        let client_builder = ClientBuilder::new(client);
31        IamClientBuilder {
32            base_url,
33            client_builder,
34        }
35    }
36
37    /// Attach a middleware to the client
38    pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
39        self.client_builder = self.client_builder.with(middleware);
40        self
41    }
42
43    /// Inject a middleware into the client extracting the open telemetry context from the current
44    /// tracing context and propagating the open telemetry request headers.
45    #[cfg(feature = "opentelemetry")]
46    pub fn with_opentelemetry(self) -> Self {
47        let middleware = reqwest_tracing::TracingMiddleware::default();
48        self.with_middleware(middleware)
49    }
50
51    #[cfg(test)]
52    /// Register a VCR middleware which would replay the response from the cassette, rather than
53    /// making an actual request. If the cassette does not exist, an actual request will be made
54    /// and the cassette will be recorded.
55    pub fn with_vcr(self, path_to_cassette: std::path::PathBuf) -> Self {
56        let cassette_does_exist = path_to_cassette.is_file();
57        let vcr_mode = if cassette_does_exist {
58            reqwest_vcr::VCRMode::Replay
59        } else {
60            reqwest_vcr::VCRMode::Record
61        };
62
63        let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
64            .unwrap()
65            .with_mode(vcr_mode)
66            .with_modify_request(|request| {
67                if let Some(header) = request.headers.get_mut("authorization") {
68                    *header = vec!["TOKEN_REMOVED".to_owned()];
69                }
70            });
71
72        self.with_middleware(middleware)
73    }
74
75    /// Construct the IAM client
76    pub fn build(self) -> IamClient {
77        let client = self.client_builder.build();
78        IamClient {
79            base_url: self.base_url,
80            http_client: client,
81        }
82    }
83}
84
85/// Client forPharia **I**dentity **A**ccess **M**anagement. Authenticate and authorize users.
86#[derive(Clone, Debug)]
87pub struct IamClient {
88    /// Environment specific URL to Pharia IAM. E.g. <https://pharia-iam.product.pharia.com>
89    base_url: String,
90    /// Used for sending the http requests. We are using `ClientWithMiddleware` to allow for VCR
91    /// recording in tests.
92    http_client: ClientWithMiddleware,
93}
94
95impl IamClient {
96    /// Use this instead of [`IamClient::new`] if you want to have additional middleware.
97    pub fn builder(base_url: String) -> IamClientBuilder {
98        IamClientBuilder::new(base_url)
99    }
100
101    /// Construct a new client using the respective IAM instance. E.g. [`IAM_PRODUCTION_URL`]
102    pub fn new(base_url: String) -> Self {
103        Self::builder(base_url).build()
104    }
105
106    #[cfg(test)]
107    pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
108        Self::builder(base_url).with_vcr(path_to_cassette).build()
109    }
110
111    /// One stop shop for both authentication and asking a set of permissions. While this method
112    /// returns a subset of permissions to which matches the privileges of the user it does not
113    /// perform the authorization check. Call `authorize`
114    ///
115    /// # Parameters
116    ///
117    /// * `token`: Service or user token used for authentication.
118    /// * `permissions`: A list of all permissions you are interested in. The response will contain
119    ///   the subset of these permissions which are privileges the user has.
120    pub async fn check_user<'a>(
121        &self,
122        token: impl Display,
123        permissions: &'a [Permission<'a>],
124    ) -> Result<UserInfoAndPermissions, CheckUserError> {
125        let request_body = CheckUserRequestBody { permissions };
126
127        let response = self
128            .http_client
129            .post(format!("{base_url}/check_user", base_url = self.base_url))
130            .bearer_auth(token)
131            .json(&request_body)
132            .send()
133            .await
134            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
135
136        // A long standing quirk of the HTTP standard: Unauthorized 401 actually means
137        // "unauthenticated". We consider this a domain specific logic error, rather than a runtime
138        // error, which should be fixed with retry. Therfore we categorize this error differently
139        // the other connection errors
140        if response.status() == StatusCode::UNAUTHORIZED {
141            return Err(CheckUserError::Unauthenticated);
142        }
143
144        if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
145            use anyhow::anyhow;
146            eprintln!("{}", response.text().await.unwrap());
147            return Err(CheckUserError::ConnectionError(anyhow!(
148                "Unprocessable entity"
149            )));
150        }
151
152        // Map all other thing to ConnectionError
153        response
154            .error_for_status_ref()
155            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
156
157        let user_info = response
158            .json()
159            .await
160            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
161
162        Ok(user_info)
163    }
164
165    /// Same as `check_user` but also performs the authorization check and fails if the user is not
166    /// authorized.
167    ///
168    /// # Parameters
169    ///
170    /// * `token`: Service or user token used for authentication.
171    /// * `permissions`: A list of all permissions you are interested in. The response will contain
172    ///   the subset of these permissions which are privileges the user has.
173    ///
174    /// Example: Check if the user has the `AccessAssistant` permission.
175    ///
176    /// ```
177    /// use pharia_common::{Permission, IamClient, AuthorizationError, IAM_PRODUCTION_URL};
178    ///
179    /// pub async fn authorize(token: &str) -> Result<(), AuthorizationError> {
180    ///     let iam = IamClient::new(IAM_PRODUCTION_URL.to_owned());
181    ///     let permissions = [Permission::AccessAssistant];
182    ///     let user_info = iam.authorize(token, &permissions).await?;
183    ///     Ok(())
184    /// }
185    /// ```
186    pub async fn authorize<'a>(
187        &self,
188        token: impl Display,
189        permissions: &'a [Permission<'a>],
190    ) -> Result<UserInfoAndPermissions, AuthorizationError> {
191        let user_info = self.check_user(token, permissions).await?;
192        if user_info.permissions == permissions {
193            Ok(user_info)
194        } else {
195            Err(AuthorizationError::Unauthorized)
196        }
197    }
198}
199
200/// Body of the the IAM `/check_user` route. The token is not passed in the body but in the
201/// authorization header.
202#[derive(Serialize)]
203struct CheckUserRequestBody<'a> {
204    /// A list of permissions to query for the specific user.
205    permissions: &'a [Permission<'a>],
206}
207
208/// Returned by [`IamClient::check_user`]. Contains information describing the user as well as the
209/// union of the queried permissions and the privileges of the user.
210#[derive(Deserialize, PartialEq, Eq, Debug)]
211pub struct UserInfoAndPermissions {
212    /// Unique ID of the User
213    pub sub: String,
214    /// Email of the user. `None` for Service users
215    pub email: Option<String>,
216    /// May be `None` for Service Users
217    pub email_verified: Option<bool>,
218    /// List of requested permissions, which are privieleges of the User Service. They are in the
219    /// same order as in the query
220    pub permissions: Vec<Permission<'static>>,
221}
222
223/// An error returned by [`IamClient::check_user`]. Note that this does **not** include
224/// unauthorized. To check for authorization inspect the permissions of [`UserInfoAndPermissions`]
225#[derive(thiserror::Error, Debug)]
226pub enum CheckUserError {
227    #[error("User is Unauthenticated. Token is invalid")]
228    Unauthenticated,
229    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
230    ConnectionError(#[source] anyhow::Error),
231}
232
233#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
234#[serde(tag = "permission")]
235pub enum Permission<'a> {
236    AccessAssistant,
237    NuminousAccess,
238    /// The kernel uses this permission to authorize skill execution
239    KernelAccess,
240    /// Used by inference to decide wether a user is authorized to perform any kind of inference
241    /// requests.
242    ExecuteJobs,
243    /// Is this user allowed to use this model? "*" Can be used as a model name in order to indicate
244    /// access to all models.
245    AccessModel {
246        model: Cow<'a, str>,
247    },
248    HasRelation {
249        relation: Cow<'a, str>,
250        object: Cow<'a, str>,
251    },
252}
253
254#[derive(thiserror::Error, Debug)]
255pub enum AuthorizationError {
256    #[error("User is Unauthenticated. Token is invalid")]
257    Unauthenticated,
258    #[error("Unauthorized")]
259    Unauthorized,
260    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
261    ConnectionError(#[source] anyhow::Error),
262}
263
264impl From<CheckUserError> for AuthorizationError {
265    fn from(err: CheckUserError) -> Self {
266        match err {
267            CheckUserError::Unauthenticated => AuthorizationError::Unauthenticated,
268            CheckUserError::ConnectionError(err) => AuthorizationError::ConnectionError(err),
269        }
270    }
271}
272
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{borrow::Cow, env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IAM_STAGE_URL, IamClient, Permission,
        UserInfoAndPermissions,
    };

    #[tokio::test]
    async fn valid_user_token() {
        // Given a client
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("valid_user_token.vcr.json"),
        );

        // When sending a check user request with a valid token
        let response = client.check_user(token(), &[]).await.unwrap();

        // Then we receive an answer, identifying the user
        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        // Given an invalid Pharia User Token
        let token = "I-AM-AN-INVALID-TOKEN";
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("invalid_user_token.vcr.json"),
        );

        // When sending a check user request
        let result = client.check_user(token, &[]).await;

        // Then the user is unauthenticated
        assert!(matches!(result, Err(CheckUserError::Unauthenticated)));
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        // Given a client
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        // When sending a check user request with a token authorized for all permissions it is
        // asking for.
        let response = client.check_user(token(), &permissions).await.unwrap();

        // Then we receive an answer, identifying the user and all the permissions are visible
        // in the answer.
        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            // It seems the IAM backend maintains order. So this assertion works.
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn authorize() {
        // Given a client
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("authorize.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        // When authorizing with a token which holds all permissions it is asking for
        let response = client.authorize(token(), &permissions).await.unwrap();

        // Then we receive an answer, identifying the user and all the permissions are visible
        // in the answer.
        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            // It seems the IAM backend maintains order. So this assertion works.
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        // Given a client
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions_as_service.vcr.json"),
        );
        let permissions = [Permission::AccessAssistant, Permission::NuminousAccess];

        // When sending a check user request with the service token
        let response = client
            .check_user(service_token(), &permissions)
            .await
            .unwrap();

        // Then we receive an answer identifying the service user, which has been granted none of
        // the requested permissions.
        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    /// The [`Permission`] enum is not exhaustive, and a request made as admin gets every requested
    /// permission mirrored back, even made up ones. This test therefore verifies that the
    /// predefined permissions actually exist by authorizing for them with a non-admin token
    /// against the stage environment.
    #[tokio::test]
    async fn verify_predefined_permissions() {
        // Given a client
        let client = IamClient::with_vcr(
            IAM_STAGE_URL.to_owned(),
            cassette_path("verify_predefined_permissions.vcr.json"),
        );
        let permissions = [
            Permission::AccessAssistant,
            Permission::ExecuteJobs,
            Permission::KernelAccess,
            Permission::NuminousAccess,
            Permission::AccessModel {
                model: Cow::Borrowed("*"),
            },
        ];

        // When authorizing a non-admin user for all predefined permissions
        let result = client
            .authorize(stage_non_admin_token(), &permissions)
            .await;

        // Then the authorization succeeds
        assert!(result.is_ok(), "{result:?}");
    }

    /// Path of the VCR cassette `file_name` within `tests/cassettes` of this crate.
    ///
    /// Cassettes record the requests. This makes the tests easy to execute even without a
    /// connection to Pharia. Additionally it allows us to execute them without the specific token
    /// of the user who recorded them at hand.
    fn cassette_path(file_name: &str) -> PathBuf {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("tests/cassettes");
        path.push(file_name);
        path
    }

    /// Service token used for recording cassettes (credentials: pharia-internal-rs-test), read
    /// from `PHARIA_AI_SERVICE_TOKEN`. Falls back to `"DUMMY"`, which is enough for replaying.
    fn service_token() -> String {
        _ = dotenv();
        env::var("PHARIA_AI_SERVICE_TOKEN").unwrap_or_else(|_| "DUMMY".to_owned())
    }

    /// The user (developer's) token, read from `PHARIA_AI_TOKEN`. Falls back to `"DUMMY"`.
    fn token() -> String {
        _ = dotenv();
        env::var("PHARIA_AI_TOKEN").unwrap_or_else(|_| "DUMMY".to_owned())
    }

    /// Token of a non-admin user on stage, read from `PHARIA_STAGE_NON_ADMIN`. Falls back to
    /// `"DUMMY"`.
    fn stage_non_admin_token() -> String {
        _ = dotenv();
        env::var("PHARIA_STAGE_NON_ADMIN").unwrap_or_else(|_| "DUMMY".to_owned())
    }
}