1use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use serde::{Deserialize, Serialize};
9
/// Base URL of the staging deployment of the Pharia IAM service.
///
/// Pass it to `IamClient::new` to talk to staging instead of production.
pub const IAM_STAGE_URL: &str = "https://pharia-iam.stage.product.pharia.com";

/// Base URL of the production deployment of the Pharia IAM service.
///
/// Pass it to `IamClient::new`; endpoint paths such as `/check_user` are appended to it.
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
14
15#[derive(Clone, Debug)]
17pub struct IamClient {
18 base_url: String,
20 http_client: ClientWithMiddleware,
23}
24
25impl IamClient {
26 pub fn new(base_url: String) -> Self {
28 let client = Client::builder().use_rustls_tls().build().expect(
29 "Must be able to initialize TLS backend and resolver must be able to load system \
30 configuration.",
31 );
32
33 let http_client: ClientWithMiddleware = ClientBuilder::new(client).build();
34
35 Self {
36 base_url,
37 http_client,
38 }
39 }
40
41 #[cfg(test)]
42 pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
43 let cassette_does_exist = path_to_cassette.is_file();
44 let vcr_mode = if cassette_does_exist {
45 reqwest_vcr::VCRMode::Replay
46 } else {
47 reqwest_vcr::VCRMode::Record
48 };
49
50 let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
51 .unwrap()
52 .with_mode(vcr_mode)
53 .with_modify_request(|request| {
54 if let Some(header) = request.headers.get_mut("authorization") {
55 *header = vec!["TOKEN_REMOVED".to_owned()];
56 }
57 });
58
59 IamClient::with_middleware(base_url, middleware)
60 }
61
62 #[cfg(test)]
63 fn with_middleware(base_url: String, middleware: impl reqwest_middleware::Middleware) -> Self {
64 let client = Client::builder().use_rustls_tls().build().expect(
65 "Must be able to initialize TLS backend and resolver must be able to load system \
66 configuration.",
67 );
68
69 let http_client: ClientWithMiddleware = ClientBuilder::new(client).with(middleware).build();
70
71 IamClient {
72 base_url,
73 http_client,
74 }
75 }
76
77 pub async fn check_user<'a>(
87 &self,
88 token: impl Display,
89 permissions: &'a [Permission<'a>],
90 ) -> Result<UserInfoAndPermissions, CheckUserError> {
91 let request_body = CheckUserRequestBody { permissions };
92
93 let response = self
94 .http_client
95 .post(format!("{base_url}/check_user", base_url = self.base_url))
96 .bearer_auth(token)
97 .json(&request_body)
98 .send()
99 .await
100 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
101
102 if response.status() == StatusCode::UNAUTHORIZED {
107 return Err(CheckUserError::Unauthenticated);
108 }
109
110 if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
111 use anyhow::anyhow;
112 eprintln!("{}", response.text().await.unwrap());
113 return Err(CheckUserError::ConnectionError(anyhow!(
114 "Unprocessable entity"
115 )));
116 }
117
118 response
120 .error_for_status_ref()
121 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
122
123 let user_info = response
124 .json()
125 .await
126 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
127
128 Ok(user_info)
129 }
130
131 pub async fn authorize<'a>(
153 &self,
154 token: impl Display,
155 permissions: &'a [Permission<'a>],
156 ) -> Result<UserInfoAndPermissions, AuthorizationError> {
157 let user_info = self.check_user(token, permissions).await?;
158 if user_info.permissions == permissions {
159 Ok(user_info)
160 } else {
161 Err(AuthorizationError::Unauthorized)
162 }
163 }
164}
165
166#[derive(Serialize)]
169struct CheckUserRequestBody<'a> {
170 permissions: &'a [Permission<'a>],
172}
173
174#[derive(Deserialize, PartialEq, Eq, Debug)]
177pub struct UserInfoAndPermissions {
178 pub sub: String,
180 pub email: Option<String>,
182 pub email_verified: Option<bool>,
184 pub permissions: Vec<Permission<'static>>,
187}
188
189#[derive(thiserror::Error, Debug)]
192pub enum CheckUserError {
193 #[error("User is Unauthenticated. Token is invalid")]
194 Unauthenticated,
195 #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
196 ConnectionError(#[source] anyhow::Error),
197}
198
199#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
200#[serde(tag = "permission")]
201pub enum Permission<'a> {
202 AccessAssistant,
203 NuminousAccess,
204 KernelAccess,
206 ExecuteJobs,
209 AccessModel {
212 model: Cow<'a, str>,
213 },
214 HasRelation {
215 relation: Cow<'a, str>,
216 object: Cow<'a, str>,
217 },
218}
219
220#[derive(thiserror::Error, Debug)]
221pub enum AuthorizationError {
222 #[error("User is Unauthenticated. Token is invalid")]
223 Unauthenticated,
224 #[error("Unauthorized")]
225 Unauthorized,
226 #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
227 ConnectionError(#[source] anyhow::Error),
228}
229
230impl From<CheckUserError> for AuthorizationError {
231 fn from(err: CheckUserError) -> Self {
232 match err {
233 CheckUserError::Unauthenticated => AuthorizationError::Unauthenticated,
234 CheckUserError::ConnectionError(err) => AuthorizationError::ConnectionError(err),
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{borrow::Cow, env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IAM_STAGE_URL, IamClient, Permission,
        UserInfoAndPermissions,
    };

    #[tokio::test]
    async fn valid_user_token() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("valid_user_token.vcr.json"),
        );

        let response = client.check_user(token(), &[]).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("invalid_user_token.vcr.json"),
        );
        let token = "I-AM-AN-INVALID-TOKEN";

        let result = client.check_user(token, &[]).await;

        assert!(matches!(result, Err(CheckUserError::Unauthenticated)));
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let response = client.check_user(token(), &permissions).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn authorize() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("authorize.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let response = client.authorize(token(), &permissions).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions_as_service.vcr.json"),
        );
        let permissions = [Permission::AccessAssistant, Permission::NuminousAccess];

        let response = client
            .check_user(service_token(), &permissions)
            .await
            .unwrap();

        // The recorded service token holds none of the requested permissions.
        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    /// Authorizes a non-admin user on the stage IAM for the predefined permissions.
    #[tokio::test]
    async fn verify_predefined_permissions() {
        let client = IamClient::with_vcr(
            IAM_STAGE_URL.to_owned(),
            cassette_path("verify_predefined_permissions.vcr.json"),
        );
        let permissions = [
            Permission::AccessAssistant,
            Permission::ExecuteJobs,
            Permission::KernelAccess,
            Permission::NuminousAccess,
            Permission::AccessModel {
                model: Cow::Borrowed("*"),
            },
        ];

        let result = client
            .authorize(stage_non_admin_token(), &permissions)
            .await;

        assert!(result.is_ok(), "authorization failed: {result:?}");
    }

    /// Path of the VCR cassette `file_name` inside `tests/cassettes` of this crate.
    fn cassette_path(file_name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/cassettes")
            .join(file_name)
    }

    /// Reads `var` from the environment after loading a `.env` file, if present.
    ///
    /// Falls back to `"DUMMY"` when the variable is unset (e.g. when only replaying
    /// existing cassettes).
    fn env_token(var: &str) -> String {
        _ = dotenv();
        env::var(var).unwrap_or_else(|_| "DUMMY".to_owned())
    }

    fn service_token() -> String {
        env_token("PHARIA_AI_SERVICE_TOKEN")
    }

    fn token() -> String {
        env_token("PHARIA_AI_TOKEN")
    }

    fn stage_non_admin_token() -> String {
        env_token("PHARIA_STAGE_NON_ADMIN")
    }
}