1use std::{borrow::Cow, fmt::Display};
5
6use reqwest::{Client, StatusCode};
7use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
8use serde::{Deserialize, Serialize};
9
/// Base URL of the production Pharia IAM service.
///
/// Pass it to [`IamClient::new`]; endpoint paths such as `/check_user` are
/// appended to it, so it is written without a trailing slash.
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
12
13pub struct IamClient {
15 base_url: String,
17 http_client: ClientWithMiddleware,
20}
21
22impl IamClient {
23 pub fn new(base_url: String) -> Self {
24 let client = Client::builder().use_rustls_tls().build().expect(
25 "Must be able to initialize TLS backend and resolver must be able to load system \
26 configuration.",
27 );
28
29 let http_client: ClientWithMiddleware = ClientBuilder::new(client).build();
30
31 Self {
32 base_url,
33 http_client,
34 }
35 }
36
37 #[cfg(test)]
38 pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
39 let cassette_does_exist = path_to_cassette.is_file();
40 let vcr_mode = if cassette_does_exist {
41 reqwest_vcr::VCRMode::Replay
42 } else {
43 reqwest_vcr::VCRMode::Record
44 };
45
46 let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
47 .unwrap()
48 .with_mode(vcr_mode)
49 .with_modify_request(|request| {
50 if let Some(header) = request.headers.get_mut("authorization") {
51 *header = vec!["TOKEN_REMOVED".to_owned()];
52 }
53 });
54
55 IamClient::with_middleware(base_url, middleware)
56 }
57
58 #[cfg(test)]
59 fn with_middleware(base_url: String, middleware: impl reqwest_middleware::Middleware) -> Self {
60 let client = Client::builder().use_rustls_tls().build().expect(
61 "Must be able to initialize TLS backend and resolver must be able to load system \
62 configuration.",
63 );
64
65 let http_client: ClientWithMiddleware = ClientBuilder::new(client).with(middleware).build();
66
67 IamClient {
68 base_url,
69 http_client,
70 }
71 }
72
73 pub async fn check_user<'a>(
75 &self,
76 token: impl Display,
77 permissions: &'a [Permission<'a>],
78 ) -> Result<UserInfoAndPermissions, CheckUserError> {
79 let request_body = CheckUserRequestBody { permissions };
80
81 let response = self
82 .http_client
83 .post(format!("{base_url}/check_user", base_url = self.base_url))
84 .bearer_auth(token)
85 .json(&request_body)
86 .send()
87 .await
88 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
89
90 if response.status() == StatusCode::UNAUTHORIZED {
95 return Err(CheckUserError::Unauthenticated);
96 }
97
98 if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
99 use anyhow::anyhow;
100 eprintln!("{}", response.text().await.unwrap());
101 return Err(CheckUserError::ConnectionError(anyhow!(
102 "Unprocessable entity"
103 )));
104 }
105
106 response
108 .error_for_status_ref()
109 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
110
111 let user_info = response
112 .json()
113 .await
114 .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
115
116 Ok(user_info)
117 }
118}
119
120#[derive(Serialize)]
123struct CheckUserRequestBody<'a> {
124 permissions: &'a [Permission<'a>],
126}
127
128#[derive(Deserialize, PartialEq, Eq, Debug)]
131pub struct UserInfoAndPermissions {
132 sub: String,
134 email: Option<String>,
136 email_verified: Option<bool>,
138 permissions: Vec<Permission<'static>>,
141}
142
143#[derive(thiserror::Error, Debug)]
146pub enum CheckUserError {
147 #[error("User is Unauthenticated. Token is invalid")]
148 Unauthenticated,
149 #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
150 ConnectionError(#[source] anyhow::Error),
151}
152
153#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
154#[serde(tag = "permission")]
155pub enum Permission<'a> {
156 AssistantAccess,
157 NuminousAccess,
158 KernelAccess,
160 ExecuteJob,
163 AccessModel {
166 model: Cow<'a, str>,
167 },
168 HasRelation {
169 relation: Cow<'a, str>,
170 object: Cow<'a, str>,
171 },
172}
173
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IamClient, Permission, UserInfoAndPermissions,
    };

    #[tokio::test]
    async fn valid_user_token() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("valid_user_token"),
        );

        let response = client.check_user(token(), &[]).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        let token = "I-AM-AN-INVALID-TOKEN";
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("invalid_user_token"),
        );

        let result = client.check_user(token, &[]).await;

        assert!(matches!(result, Err(CheckUserError::Unauthenticated)));
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJob,
            Permission::AssistantAccess,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let response = client.check_user(token(), &permissions).await.unwrap();

        let expected = UserInfoAndPermissions {
            sub: "295355180126307110".to_owned(),
            email: Some("markus.klein@aleph-alpha.com".to_owned()),
            email_verified: Some(true),
            permissions: permissions.to_vec(),
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette_path("asking_for_permissions_as_service"),
        );
        let permissions = [Permission::AssistantAccess, Permission::NuminousAccess];

        let response = client
            .check_user(service_token(), &permissions)
            .await
            .unwrap();

        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: Vec::new(),
        };
        assert_eq!(expected, response);
    }

    /// Path of the VCR cassette `<name>.vcr.json` under this crate's `tests/cassettes/`.
    ///
    /// `IamClient::with_vcr` replays from it if the file exists and records
    /// into it otherwise.
    fn cassette_path(name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join(format!("tests/cassettes/{name}.vcr.json"))
    }

    /// Service token from `PHARIA_AI_SERVICE_TOKEN` (a `.env` file is honoured).
    fn service_token() -> String {
        env_token("PHARIA_AI_SERVICE_TOKEN")
    }

    /// User token from `PHARIA_AI_TOKEN` (a `.env` file is honoured).
    fn token() -> String {
        env_token("PHARIA_AI_TOKEN")
    }

    /// Reads a token from environment variable `var`, loading `.env` first if present.
    ///
    /// Falls back to `"DUMMY"` so the tests still run without credentials.
    /// NOTE(review): assumes cassette replay does not match on the token
    /// value; confirm with reqwest_vcr.
    fn env_token(var: &str) -> String {
        // A missing `.env` file is fine; real environment variables still apply.
        _ = dotenv();
        env::var(var).unwrap_or_else(|_| "DUMMY".to_owned())
    }
}