1use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware};
8use serde::{Deserialize, Serialize};
9
/// Base URL of the production IAM deployment; pass it to [`IamClient::new`].
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";

/// Base URL of the staging IAM deployment; pass it to [`IamClient::new`].
pub const IAM_STAGE_URL: &str = "https://pharia-iam.stage.product.pharia.com";
15
16pub struct IamClientBuilder {
17 base_url: String,
18 client_builder: ClientBuilder,
19}
20
21impl IamClientBuilder {
22 pub fn new(base_url: String) -> Self {
25 let client = Client::builder().use_rustls_tls().build().expect(
26 "Must be able to initialize TLS backend and resolver must be able to load system \
27 configuration.",
28 );
29
30 let client_builder = ClientBuilder::new(client);
31 IamClientBuilder {
32 base_url,
33 client_builder,
34 }
35 }
36
37 pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
39 self.client_builder = self.client_builder.with(middleware);
40 self
41 }
42
43 #[cfg(feature = "opentelemetry")]
46 pub fn with_opentelemetry(self) -> Self {
47 let middleware = reqwest_tracing::TracingMiddleware::default();
48 self.with_middleware(middleware)
49 }
50
51 #[cfg(test)]
52 pub fn with_vcr(self, path_to_cassette: std::path::PathBuf) -> Self {
56 let cassette_does_exist = path_to_cassette.is_file();
57 let vcr_mode = if cassette_does_exist {
58 reqwest_vcr::VCRMode::Replay
59 } else {
60 reqwest_vcr::VCRMode::Record
61 };
62
63 let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
64 .unwrap()
65 .with_mode(vcr_mode)
66 .with_modify_request(|request| {
67 if let Some(header) = request.headers.get_mut("authorization") {
68 *header = vec!["TOKEN_REMOVED".to_owned()];
69 }
70 });
71
72 self.with_middleware(middleware)
73 }
74
75 pub fn build(self) -> IamClient {
77 let client = self.client_builder.build();
78 IamClient {
79 base_url: self.base_url,
80 http_client: client,
81 }
82 }
83}
84
85#[derive(Clone, Debug)]
87pub struct IamClient {
88 base_url: String,
90 http_client: ClientWithMiddleware,
93}
94
95impl IamClient {
96 pub fn builder(base_url: String) -> IamClientBuilder {
98 IamClientBuilder::new(base_url)
99 }
100
101 pub fn new(base_url: String) -> Self {
103 Self::builder(base_url).build()
104 }
105
106 #[cfg(test)]
107 pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
108 Self::builder(base_url).with_vcr(path_to_cassette).build()
109 }
110
111 pub async fn check_user<'a>(
121 &self,
122 token: impl Display,
123 permissions: &'a [Permission<'a>],
124 ) -> Result<UserInfoAndPermissions, CheckUserError> {
125 let request_body = CheckUserRequestBody { permissions };
126
127 let response = self
128 .http_client
129 .post(format!("{base_url}/check_user", base_url = self.base_url))
130 .bearer_auth(token)
131 .json(&request_body)
132 .send()
133 .await
134 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
135
136 if response.status() == StatusCode::UNAUTHORIZED {
141 return Err(CheckUserError::Unauthenticated);
142 }
143
144 if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
145 use anyhow::anyhow;
146 eprintln!("{}", response.text().await.unwrap());
147 return Err(CheckUserError::ConnectionError(anyhow!(
148 "Unprocessable entity"
149 )));
150 }
151
152 response
154 .error_for_status_ref()
155 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
156
157 let user_info = response
158 .json()
159 .await
160 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
161
162 Ok(user_info)
163 }
164
165 pub async fn authorize<'a>(
187 &self,
188 token: impl Display,
189 permissions: &'a [Permission<'a>],
190 ) -> Result<UserInfoAndPermissions, AuthorizationError> {
191 let user_info = self.check_user(token, permissions).await?;
192 if user_info.permissions == permissions {
193 Ok(user_info)
194 } else {
195 Err(AuthorizationError::Unauthorized)
196 }
197 }
198}
199
200#[derive(Serialize)]
203struct CheckUserRequestBody<'a> {
204 permissions: &'a [Permission<'a>],
206}
207
208#[derive(Deserialize, PartialEq, Eq, Debug)]
211pub struct UserInfoAndPermissions {
212 pub sub: String,
214 pub email: Option<String>,
216 pub email_verified: Option<bool>,
218 pub permissions: Vec<Permission<'static>>,
221}
222
223#[derive(thiserror::Error, Debug)]
226pub enum CheckUserError {
227 #[error("User is Unauthenticated. Token is invalid")]
228 Unauthenticated,
229 #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
230 ConnectionError(#[source] anyhow::Error),
231}
232
233#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
234#[serde(tag = "permission")]
235pub enum Permission<'a> {
236 AccessAssistant,
237 NuminousAccess,
238 KernelAccess,
240 ExecuteJobs,
243 AccessModel {
246 model: Cow<'a, str>,
247 },
248 HasRelation {
249 relation: Cow<'a, str>,
250 object: Cow<'a, str>,
251 },
252}
253
254#[derive(thiserror::Error, Debug)]
255pub enum AuthorizationError {
256 #[error("User is Unauthenticated. Token is invalid")]
257 Unauthenticated,
258 #[error("Unauthorized")]
259 Unauthorized,
260 #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
261 ConnectionError(#[source] anyhow::Error),
262}
263
264impl From<CheckUserError> for AuthorizationError {
265 fn from(err: CheckUserError) -> Self {
266 match err {
267 CheckUserError::Unauthenticated => AuthorizationError::Unauthenticated,
268 CheckUserError::ConnectionError(err) => AuthorizationError::ConnectionError(err),
269 }
270 }
271}
272
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IAM_STAGE_URL, IamClient, Permission,
        UserInfoAndPermissions,
    };

    #[tokio::test]
    async fn valid_user_token() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("valid_user_token.vcr.json"),
        );

        let response = client.check_user(token(), &[]).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("invalid_user_token.vcr.json"),
        );

        let result = client.check_user("I-AM-AN-INVALID-TOKEN", &[]).await;

        assert!(
            matches!(result, Err(CheckUserError::Unauthenticated)),
            "expected Unauthenticated, got {result:?}"
        );
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("asking_for_permissions.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let response = client.check_user(token(), &permissions).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn authorize() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("authorize.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let response = client.authorize(token(), &permissions).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("asking_for_permissions_as_service.vcr.json"),
        );
        let permissions = [Permission::AccessAssistant, Permission::NuminousAccess];

        let response = client
            .check_user(service_token(), &permissions)
            .await
            .unwrap();

        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn verify_predefined_permissions() {
        let client = IamClient::with_vcr(
            IAM_STAGE_URL.to_owned(),
            cassette("verify_predefined_permissions.vcr.json"),
        );
        let permissions = [
            Permission::AccessAssistant,
            Permission::ExecuteJobs,
            Permission::KernelAccess,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let result = client
            .authorize(stage_non_admin_token(), &permissions)
            .await;

        // Put the error in the assertion message itself so a failure is self-explanatory.
        assert!(result.is_ok(), "authorization failed: {result:?}");
    }

    /// Path of the VCR cassette `file_name` inside this crate's `tests/cassettes` directory.
    fn cassette(file_name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("tests/cassettes")
            .join(file_name)
    }

    /// Reads the token stored in environment variable `key`. A `.env` file is loaded first if
    /// one exists. Falls back to `"DUMMY"` when the variable is unset, presumably so that
    /// cassette replays work without real credentials.
    fn env_token(key: &str) -> String {
        // A missing `.env` file is fine: the variable may be set directly or not at all.
        let _ = dotenv();
        env::var(key).unwrap_or_else(|_| "DUMMY".to_owned())
    }

    fn service_token() -> String {
        env_token("PHARIA_AI_SERVICE_TOKEN")
    }

    fn token() -> String {
        env_token("PHARIA_AI_TOKEN")
    }

    fn stage_non_admin_token() -> String {
        env_token("PHARIA_STAGE_NON_ADMIN")
    }
}