1use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use serde::{Deserialize, Serialize};
9
/// Base URL of the production IAM service, usable as the `base_url` of an [`IamClient`].
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
12
/// HTTP client for the IAM service; see [`IamClient::check_user`].
pub struct IamClient {
    /// Root URL of the IAM service (e.g. [`IAM_PRODUCTION_URL`]); endpoint paths are appended.
    base_url: String,
    /// HTTP client used for every request; in tests it may carry extra middleware.
    http_client: ClientWithMiddleware,
}
21
22impl IamClient {
23 pub fn new(base_url: String) -> Self {
25 let client = Client::builder().use_rustls_tls().build().expect(
26 "Must be able to initialize TLS backend and resolver must be able to load system \
27 configuration.",
28 );
29
30 let http_client: ClientWithMiddleware = ClientBuilder::new(client).build();
31
32 Self {
33 base_url,
34 http_client,
35 }
36 }
37
38 #[cfg(test)]
39 pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
40 let cassette_does_exist = path_to_cassette.is_file();
41 let vcr_mode = if cassette_does_exist {
42 reqwest_vcr::VCRMode::Replay
43 } else {
44 reqwest_vcr::VCRMode::Record
45 };
46
47 let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
48 .unwrap()
49 .with_mode(vcr_mode)
50 .with_modify_request(|request| {
51 if let Some(header) = request.headers.get_mut("authorization") {
52 *header = vec!["TOKEN_REMOVED".to_owned()];
53 }
54 });
55
56 IamClient::with_middleware(base_url, middleware)
57 }
58
59 #[cfg(test)]
60 fn with_middleware(base_url: String, middleware: impl reqwest_middleware::Middleware) -> Self {
61 let client = Client::builder().use_rustls_tls().build().expect(
62 "Must be able to initialize TLS backend and resolver must be able to load system \
63 configuration.",
64 );
65
66 let http_client: ClientWithMiddleware = ClientBuilder::new(client).with(middleware).build();
67
68 IamClient {
69 base_url,
70 http_client,
71 }
72 }
73
74 pub async fn check_user<'a>(
96 &self,
97 token: impl Display,
98 permissions: &'a [Permission<'a>],
99 ) -> Result<UserInfoAndPermissions, CheckUserError> {
100 let request_body = CheckUserRequestBody { permissions };
101
102 let response = self
103 .http_client
104 .post(format!("{base_url}/check_user", base_url = self.base_url))
105 .bearer_auth(token)
106 .json(&request_body)
107 .send()
108 .await
109 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
110
111 if response.status() == StatusCode::UNAUTHORIZED {
116 return Err(CheckUserError::Unauthenticated);
117 }
118
119 if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
120 use anyhow::anyhow;
121 eprintln!("{}", response.text().await.unwrap());
122 return Err(CheckUserError::ConnectionError(anyhow!(
123 "Unprocessable entity"
124 )));
125 }
126
127 response
129 .error_for_status_ref()
130 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
131
132 let user_info = response
133 .json()
134 .await
135 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
136
137 Ok(user_info)
138 }
139}
140
/// JSON body sent to the IAM `check_user` endpoint by [`IamClient::check_user`].
#[derive(Serialize)]
struct CheckUserRequestBody<'a> {
    /// Permissions to ask about; serialized as a JSON array under the `permissions` key.
    permissions: &'a [Permission<'a>],
}
148
/// Successful (JSON-decoded) response body of [`IamClient::check_user`].
#[derive(Deserialize, PartialEq, Eq, Debug)]
pub struct UserInfoAndPermissions {
    /// Identifier of the token's owner (presumably the OIDC `sub` claim — confirm).
    pub sub: String,
    /// E-mail address of the owner, if reported (the tests show `None` for a service token).
    pub email: Option<String>,
    /// Whether `email` is verified, if reported.
    pub email_verified: Option<bool>,
    /// Permissions reported back by the service.
    // NOTE(review): appears to be the subset of the requested permissions that is granted —
    // confirm against the IAM API documentation.
    pub permissions: Vec<Permission<'static>>,
}
163
/// Errors returned by [`IamClient::check_user`].
#[derive(thiserror::Error, Debug)]
pub enum CheckUserError {
    /// The IAM service rejected the token with `401 Unauthorized`.
    #[error("User is Unauthenticated. Token is invalid")]
    Unauthenticated,
    /// Every other failure: the request could not be sent, the service answered with another
    /// non-success status, or the response body could not be decoded. Despite the name, not all
    /// of these are connectivity problems.
    // `{0:#}` uses anyhow's alternate Display, which renders the whole cause chain in the
    // message; the same error is additionally exposed through `#[source]`.
    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
    ConnectionError(#[source] anyhow::Error),
}
173
/// A permission that can be sent to, and reported back by, the IAM service.
///
/// Serialized with serde's internally tagged representation under the `"permission"` key, e.g.
/// `{"permission": "KernelAccess"}` or `{"permission": "AccessModel", "model": "*"}`.
/// String payloads are `Cow`, so requests can borrow their strings while deserialized values
/// (`Permission<'static>`) own theirs.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
#[serde(tag = "permission")]
pub enum Permission<'a> {
    // NOTE(review): the meaning of each unit variant is defined by the IAM service; the names
    // suggest product-level access grants — confirm against the IAM API documentation.
    AssistantAccess,
    NuminousAccess,
    KernelAccess,
    ExecuteJob,
    /// Access to the model named `model`; the tests request `"*"`, presumably meaning any model.
    AccessModel {
        model: Cow<'a, str>,
    },
    /// Whether the caller has `relation` to `object`; semantics are defined by the IAM service.
    HasRelation {
        relation: Cow<'a, str>,
        object: Cow<'a, str>,
    },
}
194
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IamClient, Permission, UserInfoAndPermissions,
    };

    /// Production IAM client backed by the cassette `tests/cassettes/<name>.vcr.json`
    /// (recorded on first run, replayed afterwards).
    fn cassette_client(name: &str) -> IamClient {
        let cassette = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join(format!("tests/cassettes/{name}.vcr.json"));
        IamClient::with_vcr(IAM_PRODUCTION_URL.to_owned(), cassette)
    }

    /// Reads `var` after loading an optional `.env` file, defaulting to `"DUMMY"`
    /// (presumably enough when an existing cassette is replayed).
    fn env_token(var: &str) -> String {
        let _ = dotenv();
        env::var(var).unwrap_or_else(|_| "DUMMY".to_owned())
    }

    fn token() -> String {
        env_token("PHARIA_AI_TOKEN")
    }

    fn service_token() -> String {
        env_token("PHARIA_AI_SERVICE_TOKEN")
    }

    #[tokio::test]
    async fn valid_user_token() {
        let client = cassette_client("valid_user_token");

        let actual = client.check_user(token(), &[]).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: Vec::new(),
        };
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        let bogus_token = "I-AM-AN-INVALID-TOKEN";
        let client = cassette_client("invalid_user_token");

        let outcome = client.check_user(bogus_token, &[]).await;

        assert!(matches!(outcome, Err(CheckUserError::Unauthenticated)));
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        let client = cassette_client("asking_for_permissions");
        let requested = [
            Permission::KernelAccess,
            Permission::ExecuteJob,
            Permission::AssistantAccess,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let actual = client.check_user(token(), &requested).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: requested.to_vec(),
        };
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        let client = cassette_client("asking_for_permissions_as_service");
        let requested = [Permission::AssistantAccess, Permission::NuminousAccess];

        let actual = client
            .check_user(service_token(), &requested)
            .await
            .unwrap();

        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: Vec::new(),
        };
        assert_eq!(expected, actual);
    }
}