pharia_common/
iam.rs

//! **IAM** is short for **I**dentity and **A**ccess **M**anagement. This module contains
//! opinionated adapters to connect to the internal Pharia IAM solution.
3
4use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use serde::{Deserialize, Serialize};
9
10/// URL of IAM in our production environment
11pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
12
13/// Client forPharia **I**dentity **A**ccess **M**anagement. Authenticate and authorize users.
14pub struct IamClient {
15    /// Environment specific URL to Pharia IAM. E.g. <https://pharia-iam.product.pharia.com>
16    base_url: String,
17    /// Used for sending the http requests. We are using `ClientWithMiddleware` to allow for VCR
18    /// recording in tests.
19    http_client: ClientWithMiddleware,
20}
21
22impl IamClient {
23    pub fn new(base_url: String) -> Self {
24        let client = Client::builder().use_rustls_tls().build().expect(
25            "Must be able to initialize TLS backend and resolver must be able to load system \
26            configuration.",
27        );
28
29        let http_client: ClientWithMiddleware = ClientBuilder::new(client).build();
30
31        Self {
32            base_url,
33            http_client,
34        }
35    }
36
37    #[cfg(test)]
38    pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
39        let cassette_does_exist = path_to_cassette.is_file();
40        let vcr_mode = if cassette_does_exist {
41            reqwest_vcr::VCRMode::Replay
42        } else {
43            reqwest_vcr::VCRMode::Record
44        };
45
46        let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
47            .unwrap()
48            .with_mode(vcr_mode)
49            .with_modify_request(|request| {
50                if let Some(header) = request.headers.get_mut("authorization") {
51                    *header = vec!["TOKEN_REMOVED".to_owned()];
52                }
53            });
54
55        IamClient::with_middleware(base_url, middleware)
56    }
57
58    #[cfg(test)]
59    fn with_middleware(base_url: String, middleware: impl reqwest_middleware::Middleware) -> Self {
60        let client = Client::builder().use_rustls_tls().build().expect(
61            "Must be able to initialize TLS backend and resolver must be able to load system \
62            configuration.",
63        );
64
65        let http_client: ClientWithMiddleware = ClientBuilder::new(client).with(middleware).build();
66
67        IamClient {
68            base_url,
69            http_client,
70        }
71    }
72
73    /// One stop shop for both authentication and authorization.
74    pub async fn check_user<'a>(
75        &self,
76        token: impl Display,
77        permissions: &'a [Permission<'a>],
78    ) -> Result<UserInfoAndPermissions, CheckUserError> {
79        let request_body = CheckUserRequestBody { permissions };
80
81        let response = self
82            .http_client
83            .post(format!("{base_url}/check_user", base_url = self.base_url))
84            .bearer_auth(token)
85            .json(&request_body)
86            .send()
87            .await
88            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
89
90        // A long standing quirk of the HTTP standard: Unauthorized 401 actually means
91        // "unauthenticated". We consider this a domain specific logic error, rather than a runtime
92        // error, which should be fixed with retry. Therfore we categorize this error differently
93        // the other connection errors
94        if response.status() == StatusCode::UNAUTHORIZED {
95            return Err(CheckUserError::Unauthenticated);
96        }
97
98        if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
99            use anyhow::anyhow;
100            eprintln!("{}", response.text().await.unwrap());
101            return Err(CheckUserError::ConnectionError(anyhow!(
102                "Unprocessable entity"
103            )));
104        }
105
106        // Map all other thing to ConnectionError
107        response
108            .error_for_status_ref()
109            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
110
111        let user_info = response
112            .json()
113            .await
114            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
115
116        Ok(user_info)
117    }
118}
119
120/// Body of the the IAM `/check_user` route. The token is not passed in the body but in the
121/// authorization header.
122#[derive(Serialize)]
123struct CheckUserRequestBody<'a> {
124    /// A list of permissions to query for the specific user.
125    permissions: &'a [Permission<'a>],
126}
127
128/// Returned by [`IamClient::check_user`]. Contains information describing the user as well as the
129/// union of the queried permissions and the privileges of the user.
130#[derive(Deserialize, PartialEq, Eq, Debug)]
131pub struct UserInfoAndPermissions {
132    /// Unique ID of the User
133    sub: String,
134    /// Email of the user. `None` for Service users
135    email: Option<String>,
136    /// May be `None` for Service Users
137    email_verified: Option<bool>,
138    /// List of requested permissions, which are privieleges of the User Service. They are in the
139    /// same order as in the query
140    permissions: Vec<Permission<'static>>,
141}
142
143/// An error returned by [`IamClient::check_user`]. Note that this does **not** include
144/// unauthorized. To check for authorization inspect the permissions of [`UserInfoAndPermissions`]
145#[derive(thiserror::Error, Debug)]
146pub enum CheckUserError {
147    #[error("User is Unauthenticated. Token is invalid")]
148    Unauthenticated,
149    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
150    ConnectionError(#[source] anyhow::Error),
151}
152
153#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
154#[serde(tag = "permission")]
155pub enum Permission<'a> {
156    AssistantAccess,
157    NuminousAccess,
158    /// The kernel uses this permission to authorize skill execution
159    KernelAccess,
160    /// Used by inference to decide wether a user is authorized to perform any kind of inference
161    /// requests.
162    ExecuteJob,
163    /// Is this user allowed to use this model? "*" Can be used as a model name in order to indicate
164    /// access to all models.
165    AccessModel {
166        model: Cow<'a, str>,
167    },
168    HasRelation {
169        relation: Cow<'a, str>,
170        object: Cow<'a, str>,
171    },
172}
173
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IamClient, Permission, UserInfoAndPermissions,
    };

    /// Path to the VCR cassette `tests/cassettes/<name>.vcr.json`.
    ///
    /// We use cassettes to record the requests. This makes the tests easy to execute even
    /// without a connection to Pharia. It also lets us run them without the token of the user
    /// who recorded them. If the file does not exist yet, [`IamClient::with_vcr`] records a new
    /// cassette instead of replaying one.
    fn cassette_path(name: &str) -> PathBuf {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("tests/cassettes");
        path.push(format!("{name}.vcr.json"));
        path
    }

    #[tokio::test]
    async fn valid_user_token() {
        // Given a client
        let client =
            IamClient::with_vcr(IAM_PRODUCTION_URL.to_owned(), cassette_path("valid_user_token"));

        // When sending a check user request with a valid token
        let response = client.check_user(token(), &[]).await.unwrap();

        // Then we receive an answer identifying the user
        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        // Given an invalid Pharia user token
        let token = "I-AM-AN-INVALID-TOKEN";
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("invalid_user_token"),
        );

        // When sending a check user request
        let result = client.check_user(token, &[]).await;

        // Then the user is unauthenticated
        assert!(matches!(result, Err(CheckUserError::Unauthenticated)));
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        // Given a client
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJob,
            Permission::AssistantAccess,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        // When sending a check user request with a token authorized for all permissions it is
        // asking for
        let response = client.check_user(token(), &permissions).await.unwrap();

        // Then we receive an answer identifying the user, and all the permissions are listed
        // in the answer
        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            // It seems the IAM backend maintains order. So this assertion works.
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        // Given a client
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions_as_service"),
        );
        let permissions = [Permission::AssistantAccess, Permission::NuminousAccess];

        // When sending a check user request with a service token
        let response = client
            .check_user(service_token(), &permissions)
            .await
            .unwrap();

        // Then we receive an answer identifying the service user. The service has no email,
        // and the recorded answer grants it none of the requested permissions.
        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: Vec::new(),
        };
        assert_eq!(expected, response);
    }

    /// Token of the `pharia-internal-rs-test` service account, used for recording cassettes.
    ///
    /// Read from `PHARIA_AI_SERVICE_TOKEN` (a `.env` file is honored). Falls back to `"DUMMY"`
    /// when unset, so the tests can still run from recorded cassettes.
    fn service_token() -> String {
        _ = dotenv();
        env::var("PHARIA_AI_SERVICE_TOKEN").unwrap_or_else(|_| "DUMMY".to_owned())
    }

    /// The user's (developer's) token, read from `PHARIA_AI_TOKEN` (a `.env` file is honored).
    /// Falls back to `"DUMMY"` when unset, so the tests can still run from recorded cassettes.
    fn token() -> String {
        _ = dotenv();
        env::var("PHARIA_AI_TOKEN").unwrap_or_else(|_| "DUMMY".to_owned())
    }
}