use std::{borrow::Cow, fmt::Display};
use reqwest::{Client, StatusCode};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use serde::{Deserialize, Serialize};
/// Base URL of the production PHARIA IAM deployment.
///
/// Intended to be passed as `base_url` to [`IamClient::new`]. It deliberately has no trailing
/// slash, because the client appends `/check_user` to the base URL.
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
/// Client for the PHARIA IAM service.
///
/// Used to validate a bearer token and query the permissions attached to it.
/// Construct it with [`IamClient::new`].
pub struct IamClient {
    /// HTTP client used for every request. It may be wrapped in middleware; test builds attach
    /// one.
    http_client: ClientWithMiddleware,
    /// Base URL of the IAM service (e.g. [`IAM_PRODUCTION_URL`]); endpoint paths are appended.
    base_url: String,
}
impl IamClient {
pub fn new(base_url: String) -> Self {
let client = Client::builder().use_rustls_tls().build().expect(
"Must be able to initialize TLS backend and resolver must be able to load system \
configuration.",
);
let http_client: ClientWithMiddleware = ClientBuilder::new(client).build();
Self {
base_url,
http_client,
}
}
#[cfg(test)]
pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
let cassette_does_exist = path_to_cassette.is_file();
let vcr_mode = if cassette_does_exist {
reqwest_vcr::VCRMode::Replay
} else {
reqwest_vcr::VCRMode::Record
};
let middleware = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
.unwrap()
.with_mode(vcr_mode)
.with_modify_request(|request| {
if let Some(header) = request.headers.get_mut("authorization") {
*header = vec!["TOKEN_REMOVED".to_owned()];
}
});
IamClient::with_middleware(base_url, middleware)
}
#[cfg(test)]
fn with_middleware(base_url: String, middleware: impl reqwest_middleware::Middleware) -> Self {
let client = Client::builder().use_rustls_tls().build().expect(
"Must be able to initialize TLS backend and resolver must be able to load system \
configuration.",
);
let http_client: ClientWithMiddleware = ClientBuilder::new(client).with(middleware).build();
IamClient {
base_url,
http_client,
}
}
pub async fn check_user<'a>(
&self,
token: impl Display,
permissions: &'a [Permission<'a>],
) -> Result<UserInfoAndPermissions, CheckUserError> {
let request_body = CheckUserRequestBody { permissions };
let response = self
.http_client
.post(format!("{base_url}/check_user", base_url = self.base_url))
.bearer_auth(token)
.json(&request_body)
.send()
.await
.map_err(|e| CheckUserError::ConnectionError(e.into()))?;
if response.status() == StatusCode::UNAUTHORIZED {
return Err(CheckUserError::Unauthenticated);
}
if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
use anyhow::anyhow;
eprintln!("{}", response.text().await.unwrap());
return Err(CheckUserError::ConnectionError(anyhow!(
"Unprocessable entity"
)));
}
response
.error_for_status_ref()
.map_err(|e| CheckUserError::ConnectionError(e.into()))?;
let user_info = response
.json()
.await
.map_err(|e| CheckUserError::ConnectionError(e.into()))?;
Ok(user_info)
}
}
/// JSON request body sent to the IAM `check_user` endpoint.
///
/// Serializes as `{"permissions": [...]}`. It borrows the caller's slice, so building a request
/// copies no permissions.
#[derive(Serialize)]
struct CheckUserRequestBody<'a> {
    /// Permissions to check for the authenticated token.
    permissions: &'a [Permission<'a>],
}
/// Successful response of the IAM `check_user` endpoint, deserialized from JSON.
#[derive(Deserialize, PartialEq, Eq, Debug)]
pub struct UserInfoAndPermissions {
    /// Identifier of the token's subject (user or service account).
    sub: String,
    /// E-mail address of the subject. It may be absent; the service-token test expects `None`.
    email: Option<String>,
    /// Whether `email` is verified. It may be absent, like `email`.
    email_verified: Option<bool>,
    /// Permissions reported by the service. These are presumably the subset of the requested
    /// permissions that the subject holds (TODO: confirm against the IAM API). They are owned
    /// (`'static`) because they are deserialized from the response.
    permissions: Vec<Permission<'static>>,
}
/// Errors returned by [`IamClient::check_user`].
#[derive(thiserror::Error, Debug)]
pub enum CheckUserError {
    /// The service answered `401 Unauthorized`; the supplied token was rejected.
    #[error("User is Unauthenticated. Token is invalid")]
    Unauthenticated,
    /// Any other failure. Despite the name, this covers more than network errors: it is also
    /// returned for non-401 error statuses (including 422) and for undecodable response bodies.
    /// `{0:#}` uses anyhow's alternate formatting, which includes the cause chain.
    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
    ConnectionError(#[source] anyhow::Error),
}
/// A permission that can be checked with the IAM service.
///
/// Serialized as an internally tagged JSON object whose `"permission"` key holds the variant
/// name, with any variant fields alongside it. For example, `AccessModel { model: "*" }` becomes
/// `{"permission":"AccessModel","model":"*"}`.
///
/// The `Cow` fields let requests borrow strings while deserialized values own them
/// (`Permission<'static>`).
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
#[serde(tag = "permission")]
pub enum Permission<'a> {
    AssistantAccess,
    NuminousAccess,
    KernelAccess,
    ExecuteJob,
    /// Access to a model; the tests request `"*"`, presumably a wildcard for all models.
    AccessModel {
        model: Cow<'a, str>,
    },
    /// Presumably a relationship check: the subject has `relation` on `object`
    /// (TODO: confirm semantics with the IAM API).
    HasRelation {
        relation: Cow<'a, str>,
        object: Cow<'a, str>,
    },
}
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{env, path::PathBuf};

    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IamClient, Permission, UserInfoAndPermissions,
    };

    /// Builds a production-URL client backed by `tests/cassettes/<name>.vcr.json`.
    fn vcr_client(name: &str) -> IamClient {
        let mut cassette = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        cassette.push(format!("tests/cassettes/{name}.vcr.json"));
        IamClient::with_vcr(IAM_PRODUCTION_URL.to_owned(), cassette)
    }

    #[tokio::test]
    async fn valid_user_token() {
        let client = vcr_client("valid_user_token");

        let response = client.check_user(token(), &[]).await.unwrap();

        assert_eq!(
            UserInfoAndPermissions {
                sub: "295355180126307110".to_owned(),
                email: Some("markus.klein@aleph-alpha.com".to_owned()),
                email_verified: Some(true),
                permissions: Vec::new(),
            },
            response
        );
    }

    #[tokio::test]
    async fn invalid_user_token() {
        let client = vcr_client("invalid_user_token");

        let outcome = client.check_user("I-AM-AN-INVALID-TOKEN", &[]).await;

        assert!(matches!(outcome, Err(CheckUserError::Unauthenticated)))
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        let client = vcr_client("asking_for_permissions");
        let requested = [
            Permission::KernelAccess,
            Permission::ExecuteJob,
            Permission::AssistantAccess,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];

        let response = client.check_user(token(), &requested).await.unwrap();

        assert_eq!(
            UserInfoAndPermissions {
                sub: "295355180126307110".to_owned(),
                email: Some("markus.klein@aleph-alpha.com".to_owned()),
                email_verified: Some(true),
                permissions: requested.to_vec(),
            },
            response
        );
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        let client = vcr_client("asking_for_permissions_as_service");
        let requested = [Permission::AssistantAccess, Permission::NuminousAccess];

        let response = client
            .check_user(service_token(), &requested)
            .await
            .unwrap();

        assert_eq!(
            UserInfoAndPermissions {
                sub: "336362361919115278".to_owned(),
                email: None,
                email_verified: None,
                permissions: Vec::new(),
            },
            response
        );
    }

    /// Service-account token from the environment (or `.env`), falling back to `DUMMY`.
    fn service_token() -> String {
        env_or_dummy("PHARIA_AI_SERVICE_TOKEN")
    }

    /// Personal user token from the environment (or `.env`), falling back to `DUMMY`.
    fn token() -> String {
        env_or_dummy("PHARIA_AI_TOKEN")
    }

    /// Loads `.env` if present, then reads `key`. A missing `.env` file is ignored. The `DUMMY`
    /// fallback presumably suffices for replaying cassettes, since recorded authorization
    /// headers are redacted.
    fn env_or_dummy(key: &str) -> String {
        let _ = dotenv();
        env::var(key).unwrap_or_else(|_| "DUMMY".to_owned())
    }
}