1use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use serde::{Deserialize, Serialize};
9
/// Base URL of the production IAM service, suitable as the `base_url` of [`IamClient::new`].
///
/// Requests are sent to endpoints below this URL, e.g. `{IAM_PRODUCTION_URL}/check_user`.
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
12
/// HTTP client for the IAM service's `check_user` endpoint.
///
/// Construct it with [`IamClient::new`] and query it with [`IamClient::check_user`].
#[derive(Clone, Debug)]
pub struct IamClient {
    /// Base URL of the IAM service without any endpoint path (e.g. [`IAM_PRODUCTION_URL`]).
    base_url: String,
    /// HTTP client used for every request. In tests it may be wrapped in extra middleware
    /// (record/replay of HTTP traffic).
    http_client: ClientWithMiddleware,
}
22
23impl IamClient {
24 pub fn new(base_url: String) -> Self {
26 let client = Client::builder().use_rustls_tls().build().expect(
27 "Must be able to initialize TLS backend and resolver must be able to load system \
28 configuration.",
29 );
30
31 let http_client: ClientWithMiddleware = ClientBuilder::new(client).build();
32
33 Self {
34 base_url,
35 http_client,
36 }
37 }
38
39 #[cfg(test)]
40 pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
41 let cassette_does_exist = path_to_cassette.is_file();
42 let vcr_mode = if cassette_does_exist {
43 reqwest_vcr::VCRMode::Replay
44 } else {
45 reqwest_vcr::VCRMode::Record
46 };
47
48 let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
49 .unwrap()
50 .with_mode(vcr_mode)
51 .with_modify_request(|request| {
52 if let Some(header) = request.headers.get_mut("authorization") {
53 *header = vec!["TOKEN_REMOVED".to_owned()];
54 }
55 });
56
57 IamClient::with_middleware(base_url, middleware)
58 }
59
60 #[cfg(test)]
61 fn with_middleware(base_url: String, middleware: impl reqwest_middleware::Middleware) -> Self {
62 let client = Client::builder().use_rustls_tls().build().expect(
63 "Must be able to initialize TLS backend and resolver must be able to load system \
64 configuration.",
65 );
66
67 let http_client: ClientWithMiddleware = ClientBuilder::new(client).with(middleware).build();
68
69 IamClient {
70 base_url,
71 http_client,
72 }
73 }
74
75 pub async fn check_user<'a>(
97 &self,
98 token: impl Display,
99 permissions: &'a [Permission<'a>],
100 ) -> Result<UserInfoAndPermissions, CheckUserError> {
101 let request_body = CheckUserRequestBody { permissions };
102
103 let response = self
104 .http_client
105 .post(format!("{base_url}/check_user", base_url = self.base_url))
106 .bearer_auth(token)
107 .json(&request_body)
108 .send()
109 .await
110 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
111
112 if response.status() == StatusCode::UNAUTHORIZED {
117 return Err(CheckUserError::Unauthenticated);
118 }
119
120 if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
121 use anyhow::anyhow;
122 eprintln!("{}", response.text().await.unwrap());
123 return Err(CheckUserError::ConnectionError(anyhow!(
124 "Unprocessable entity"
125 )));
126 }
127
128 response
130 .error_for_status_ref()
131 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
132
133 let user_info = response
134 .json()
135 .await
136 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
137
138 Ok(user_info)
139 }
140}
141
/// JSON body sent to the `check_user` endpoint by [`IamClient::check_user`].
///
/// Serializes as `{"permissions": [...]}`, each element in the [`Permission`] wire format.
#[derive(Serialize)]
struct CheckUserRequestBody<'a> {
    /// Permissions to check, borrowed from the caller of `check_user`.
    permissions: &'a [Permission<'a>],
}
149
/// Successful JSON response of the `check_user` endpoint, returned by
/// [`IamClient::check_user`].
#[derive(Deserialize, PartialEq, Eq, Debug)]
pub struct UserInfoAndPermissions {
    /// Identifier of the token's subject, as reported by the service. The tests show it for
    /// both user and service tokens.
    pub sub: String,
    /// E-mail address of the subject. `None` if the field is absent or `null`; the recorded
    /// service-token response has none.
    pub email: Option<String>,
    /// Whether the e-mail address is verified. `None` if the field is absent or `null`.
    pub email_verified: Option<bool>,
    /// Permissions reported back by the service. In the recorded tests this is the granted
    /// subset of the requested permissions (presumably; confirm against the IAM API).
    pub permissions: Vec<Permission<'static>>,
}
164
/// Ways in which [`IamClient::check_user`] can fail.
#[derive(thiserror::Error, Debug)]
pub enum CheckUserError {
    /// The service rejected the token with `401 Unauthorized`.
    #[error("User is Unauthenticated. Token is invalid")]
    Unauthenticated,
    /// Any other failure, such as:
    /// * the request could not be sent;
    /// * the service answered with another error status;
    /// * the response body could not be deserialized.
    ///
    /// The `{0:#}` format prints the wrapped `anyhow::Error` together with its cause chain.
    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
    ConnectionError(#[source] anyhow::Error),
}
174
/// A permission that can be requested from, and reported back by, the IAM service.
///
/// Because of `#[serde(tag = "permission")]`, values are internally tagged JSON objects, e.g.
/// `{"permission": "KernelAccess"}` or `{"permission": "AccessModel", "model": "*"}`.
/// Variant and field names are therefore part of the wire format and must not be renamed.
///
/// The lifetime `'a` lets request permissions borrow their strings via [`Cow`]. Values
/// deserialized from responses are stored as `Permission<'static>`, as in
/// [`UserInfoAndPermissions`].
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
#[serde(tag = "permission")]
pub enum Permission<'a> {
    AccessAssistant,
    AccessNuminous,
    KernelAccess,
    ExecuteJob,
    /// Access to the model named `model`. The tests request `"*"`, presumably a wildcard for
    /// all models (confirm with the IAM service).
    AccessModel {
        model: Cow<'a, str>,
    },
    /// Presumably a check that the subject has `relation` on `object` (TODO: confirm the
    /// semantics with the IAM service).
    HasRelation {
        relation: Cow<'a, str>,
        object: Cow<'a, str>,
    },
}
195
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IamClient, Permission, UserInfoAndPermissions,
    };

    #[tokio::test]
    async fn valid_user_token() {
        let client = client_for_cassette("valid_user_token");

        let response = client.check_user(token(), &[]).await.unwrap();

        assert_eq!(
            UserInfoAndPermissions {
                sub: "295355180126307110".to_owned(),
                email: Some("markus.klein@aleph-alpha.com".to_owned()),
                email_verified: Some(true),
                permissions: vec![],
            },
            response
        );
    }

    #[tokio::test]
    async fn invalid_user_token() {
        let client = client_for_cassette("invalid_user_token");

        let outcome = client.check_user("I-AM-AN-INVALID-TOKEN", &[]).await;

        assert!(matches!(outcome, Err(CheckUserError::Unauthenticated)));
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        let client = client_for_cassette("asking_for_permissions");
        let requested = [
            Permission::KernelAccess,
            Permission::ExecuteJob,
            Permission::AccessAssistant,
            Permission::AccessNuminous,
            Permission::AccessModel { model: "*".into() },
        ];

        let response = client.check_user(token(), &requested).await.unwrap();

        // The recorded user holds every requested permission, so all of them come back.
        assert_eq!(
            UserInfoAndPermissions {
                sub: "295355180126307110".to_owned(),
                email: Some("markus.klein@aleph-alpha.com".to_owned()),
                email_verified: Some(true),
                permissions: requested.to_vec(),
            },
            response
        );
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        let client = client_for_cassette("asking_for_permissions_as_service");
        let requested = [Permission::AccessAssistant, Permission::AccessNuminous];

        let response = client
            .check_user(service_token(), &requested)
            .await
            .unwrap();

        // The recorded response grants the service account none of the requested permissions.
        assert_eq!(
            UserInfoAndPermissions {
                sub: "336362361919115278".to_owned(),
                email: None,
                email_verified: None,
                permissions: Vec::new(),
            },
            response
        );
    }

    /// Client against the production IAM URL that replays from (or records into)
    /// `tests/cassettes/<name>.vcr.json` below the crate root.
    fn client_for_cassette(name: &str) -> IamClient {
        let cassette =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(format!("tests/cassettes/{name}.vcr.json"));
        IamClient::with_vcr(IAM_PRODUCTION_URL.to_owned(), cassette)
    }

    /// Token of a service account, taken from `PHARIA_AI_SERVICE_TOKEN`.
    fn service_token() -> String {
        token_from_env("PHARIA_AI_SERVICE_TOKEN")
    }

    /// Token of a human user, taken from `PHARIA_AI_TOKEN`.
    fn token() -> String {
        token_from_env("PHARIA_AI_TOKEN")
    }

    /// Reads `var` after loading a `.env` file if one exists. Falls back to `"DUMMY"`,
    /// presumably so that replaying existing cassettes works without real credentials.
    fn token_from_env(var: &str) -> String {
        // A missing `.env` file is fine; the variable may come from the real environment.
        let _ = dotenv();
        env::var(var).unwrap_or_else(|_| String::from("DUMMY"))
    }
}