1use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use serde::{Deserialize, Serialize};
9
/// Base URL of the production Pharia IAM service.
///
/// Intended to be passed as `base_url` when constructing an [`IamClient`]; endpoint paths such
/// as `/check_user` are appended to it by the client.
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
12
/// HTTP client for the Pharia IAM service.
///
/// Holds the service location and the HTTP client used to send requests to it.
#[derive(Clone, Debug)]
pub struct IamClient {
    /// Base URL of the IAM service; endpoint paths are appended to it when sending requests.
    base_url: String,
    /// HTTP client (including its middleware stack) used for all requests.
    http_client: ClientWithMiddleware,
}
22
23impl IamClient {
24 pub fn new(base_url: String) -> Self {
26 let client = Client::builder().use_rustls_tls().build().expect(
27 "Must be able to initialize TLS backend and resolver must be able to load system \
28 configuration.",
29 );
30
31 let http_client: ClientWithMiddleware = ClientBuilder::new(client).build();
32
33 Self {
34 base_url,
35 http_client,
36 }
37 }
38
39 #[cfg(test)]
40 pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
41 let cassette_does_exist = path_to_cassette.is_file();
42 let vcr_mode = if cassette_does_exist {
43 reqwest_vcr::VCRMode::Replay
44 } else {
45 reqwest_vcr::VCRMode::Record
46 };
47
48 let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
49 .unwrap()
50 .with_mode(vcr_mode)
51 .with_modify_request(|request| {
52 if let Some(header) = request.headers.get_mut("authorization") {
53 *header = vec!["TOKEN_REMOVED".to_owned()];
54 }
55 });
56
57 IamClient::with_middleware(base_url, middleware)
58 }
59
60 #[cfg(test)]
61 fn with_middleware(base_url: String, middleware: impl reqwest_middleware::Middleware) -> Self {
62 let client = Client::builder().use_rustls_tls().build().expect(
63 "Must be able to initialize TLS backend and resolver must be able to load system \
64 configuration.",
65 );
66
67 let http_client: ClientWithMiddleware = ClientBuilder::new(client).with(middleware).build();
68
69 IamClient {
70 base_url,
71 http_client,
72 }
73 }
74
75 pub async fn check_user<'a>(
99 &self,
100 token: impl Display,
101 permissions: &'a [Permission<'a>],
102 ) -> Result<UserInfoAndPermissions, CheckUserError> {
103 let request_body = CheckUserRequestBody { permissions };
104
105 let response = self
106 .http_client
107 .post(format!("{base_url}/check_user", base_url = self.base_url))
108 .bearer_auth(token)
109 .json(&request_body)
110 .send()
111 .await
112 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
113
114 if response.status() == StatusCode::UNAUTHORIZED {
119 return Err(CheckUserError::Unauthenticated);
120 }
121
122 if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
123 use anyhow::anyhow;
124 eprintln!("{}", response.text().await.unwrap());
125 return Err(CheckUserError::ConnectionError(anyhow!(
126 "Unprocessable entity"
127 )));
128 }
129
130 response
132 .error_for_status_ref()
133 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
134
135 let user_info = response
136 .json()
137 .await
138 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
139
140 Ok(user_info)
141 }
142
143 pub async fn authorize<'a>(
144 &self,
145 token: impl Display,
146 permissions: &'a [Permission<'a>],
147 ) -> Result<UserInfoAndPermissions, AuthorizationError> {
148 let user_info = self.check_user(token, permissions).await?;
149 if user_info.permissions == permissions {
150 Ok(user_info)
151 } else {
152 Err(AuthorizationError::Unauthorized)
153 }
154 }
155}
156
/// JSON body of the request sent to the `check_user` endpoint.
#[derive(Serialize)]
struct CheckUserRequestBody<'a> {
    /// Permissions the caller asks the IAM service about.
    permissions: &'a [Permission<'a>],
}
164
/// Successful answer of the `check_user` endpoint, deserialized from the JSON response body.
#[derive(Deserialize, PartialEq, Eq, Debug)]
pub struct UserInfoAndPermissions {
    /// Identifier of the token's owner as reported by the service.
    /// NOTE(review): presumably the OIDC-style subject claim; confirm against the IAM API.
    pub sub: String,
    /// E-mail address of the user, if the response contains one.
    pub email: Option<String>,
    /// Whether the e-mail address has been verified, if the response says so.
    pub email_verified: Option<bool>,
    /// Permissions listed in the response. `IamClient::authorize` compares this list against
    /// the requested permissions.
    pub permissions: Vec<Permission<'static>>,
}
179
/// Reasons why `IamClient::check_user` can fail.
#[derive(thiserror::Error, Debug)]
pub enum CheckUserError {
    /// The IAM service rejected the token with `401 Unauthorized`.
    #[error("User is Unauthenticated. Token is invalid")]
    Unauthenticated,
    /// The request could not be sent, the service answered with another error status, or the
    /// response could not be deserialized. The underlying cause is attached.
    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
    ConnectionError(#[source] anyhow::Error),
}
189
/// A permission which can be sent to, and is returned by, the IAM service.
///
/// Serialized with serde's internally tagged representation: the variant name is stored in a
/// field called `permission`, next to the fields of the variant itself.
///
/// NOTE(review): the meaning of the individual variants is defined by the IAM service and is
/// not visible in this file; the variant comments below only restate their names.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
#[serde(tag = "permission")]
pub enum Permission<'a> {
    /// Presumably grants access to the assistant; confirm against the IAM API.
    AccessAssistant,
    /// Presumably grants access to Numinous; confirm against the IAM API.
    NuminousAccess,
    /// Presumably grants access to the Kernel; confirm against the IAM API.
    KernelAccess,
    /// Presumably allows executing jobs; confirm against the IAM API.
    ExecuteJob,
    /// Presumably grants access to a model; confirm against the IAM API.
    AccessModel {
        /// Name of the model. The tests in this file pass `"*"`, presumably a wildcard.
        model: Cow<'a, str>,
    },
    /// Presumably checks that the user has `relation` to `object`; confirm against the IAM API.
    HasRelation {
        /// Name of the relation.
        relation: Cow<'a, str>,
        /// Object the relation refers to.
        object: Cow<'a, str>,
    },
}
210
/// Reasons why `IamClient::authorize` can fail.
#[derive(thiserror::Error, Debug)]
pub enum AuthorizationError {
    /// The token was rejected by the IAM service; converted from
    /// `CheckUserError::Unauthenticated`.
    #[error("User is Unauthenticated. Token is invalid")]
    Unauthenticated,
    /// The token is valid, but the permissions returned by the service differ from the
    /// requested ones.
    #[error("Unauthorized")]
    Unauthorized,
    /// Talking to the IAM service failed; converted from `CheckUserError::ConnectionError`,
    /// keeping the underlying cause.
    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
    ConnectionError(#[source] anyhow::Error),
}
220
221impl From<CheckUserError> for AuthorizationError {
222 fn from(err: CheckUserError) -> Self {
223 match err {
224 CheckUserError::Unauthenticated => AuthorizationError::Unauthenticated,
225 CheckUserError::ConnectionError(err) => AuthorizationError::ConnectionError(err),
226 }
227 }
228}
229
#[cfg(test)]
mod tests {
    //! Tests against the IAM service, recorded and replayed with VCR cassettes stored in
    //! `tests/cassettes`. Tokens are read from the environment (or a `.env` file) and are only
    //! needed when a cassette is recorded.

    use dotenvy::dotenv;
    use std::{env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IamClient, Permission, UserInfoAndPermissions,
    };

    /// Client for the production IAM URL which records to, or replays from, the cassette
    /// `tests/cassettes/<cassette>.vcr.json` inside the crate directory.
    fn vcr_client(cassette: &str) -> IamClient {
        let cassette_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join(format!("tests/cassettes/{cassette}.vcr.json"));
        IamClient::with_vcr(IAM_PRODUCTION_URL.to_owned(), cassette_path)
    }

    /// The user behind [`token`] as recorded in the cassettes, holding `permissions`.
    fn recorded_user(permissions: Vec<Permission<'static>>) -> UserInfoAndPermissions {
        UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions,
        }
    }

    /// Permissions requested by the tests which expect all of them to be returned.
    fn requested_permissions() -> [Permission<'static>; 5] {
        [
            Permission::KernelAccess,
            Permission::ExecuteJob,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ]
    }

    /// Value of the environment variable `name`, or `"DUMMY"` if it is not set.
    fn env_or_dummy(name: &str) -> String {
        // A missing `.env` file is fine, e.g. when cassettes are replayed.
        _ = dotenv();
        env::var(name).unwrap_or_else(|_| "DUMMY".to_owned())
    }

    /// Token of a service user.
    fn service_token() -> String {
        env_or_dummy("PHARIA_AI_SERVICE_TOKEN")
    }

    /// Token of a human user.
    fn token() -> String {
        env_or_dummy("PHARIA_AI_TOKEN")
    }

    #[tokio::test]
    async fn valid_user_token() {
        // Given a client and a valid token
        let client = vcr_client("valid_user_token");

        // When checking the user without asking for permissions
        let response = client.check_user(token(), &[]).await.unwrap();

        // Then the user info is returned without permissions
        assert_eq!(recorded_user(Vec::new()), response);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        // Given a token the service does not know
        let client = vcr_client("invalid_user_token");
        let invalid_token = "I-AM-AN-INVALID-TOKEN";

        // When checking the user
        let result = client.check_user(invalid_token, &[]).await;

        // Then the user is reported as unauthenticated
        assert!(matches!(result, Err(CheckUserError::Unauthenticated)))
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        // Given a client and a list of permissions
        let client = vcr_client("asking_for_permissions");
        let permissions = requested_permissions();

        // When checking the user
        let response = client.check_user(token(), &permissions).await.unwrap();

        // Then all requested permissions are returned
        assert_eq!(recorded_user(permissions.to_vec()), response);
    }

    #[tokio::test]
    async fn authorize() {
        // Given a client and a list of permissions
        let client = vcr_client("authorize");
        let permissions = requested_permissions();

        // When authorizing the user
        let response = client.authorize(token(), &permissions).await.unwrap();

        // Then the user is authorized and all requested permissions are returned
        assert_eq!(recorded_user(permissions.to_vec()), response);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        // Given a client and permissions requested with a service token
        let client = vcr_client("asking_for_permissions_as_service");
        let permissions = [Permission::AccessAssistant, Permission::NuminousAccess];

        // When checking the service user
        let response = client
            .check_user(service_token(), &permissions)
            .await
            .unwrap();

        // Then neither e-mail information nor permissions are returned
        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: Vec::new(),
        };
        assert_eq!(expected, response);
    }
}