use std::{borrow::Cow, fmt::Display};
use reqwest::{Client, StatusCode};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware};
use serde::{Deserialize, Serialize};
/// Base URL of the production Pharia IAM deployment (pass to [`IamClient::new`]).
pub const IAM_PRODUCTION_URL: &str = "https://pharia-iam.product.pharia.com";
/// Base URL of the stage Pharia IAM deployment (pass to [`IamClient::new`]).
pub const IAM_STAGE_URL: &str = "https://pharia-iam.stage.product.pharia.com";
/// Builder for [`IamClient`] that lets callers stack `reqwest_middleware` middleware onto the
/// HTTP client before it is built.
pub struct IamClientBuilder {
    /// Base URL of the IAM service, handed unchanged to the built [`IamClient`].
    base_url: String,
    /// Middleware-aware builder wrapping the rustls-backed `reqwest::Client` created in `new`.
    client_builder: ClientBuilder,
}
impl IamClientBuilder {
    /// Starts a builder for a client talking to the IAM service at `base_url`
    /// (for example [`IAM_PRODUCTION_URL`]).
    ///
    /// # Panics
    ///
    /// Panics if the underlying rustls-backed `reqwest::Client` cannot be built.
    pub fn new(base_url: String) -> Self {
        let tls_client = Client::builder().use_rustls_tls().build().expect(
            "Must be able to initialize TLS backend and resolver must be able to load system \
            configuration.",
        );
        Self {
            client_builder: ClientBuilder::new(tls_client),
            base_url,
        }
    }

    /// Appends `middleware` to the middleware stack of the client being built.
    pub fn with_middleware(self, middleware: impl Middleware) -> Self {
        let Self {
            base_url,
            client_builder,
        } = self;
        Self {
            base_url,
            client_builder: client_builder.with(middleware),
        }
    }

    /// Adds `reqwest_tracing`'s default tracing middleware.
    #[cfg(feature = "opentelemetry")]
    pub fn with_opentelemetry(self) -> Self {
        self.with_middleware(reqwest_tracing::TracingMiddleware::default())
    }

    /// Test-only: routes requests through a VCR middleware backed by `path_to_cassette`.
    ///
    /// An existing cassette file is replayed; otherwise a new one is recorded.
    #[cfg(test)]
    pub fn with_vcr(self, path_to_cassette: std::path::PathBuf) -> Self {
        let vcr_mode = if path_to_cassette.is_file() {
            reqwest_vcr::VCRMode::Replay
        } else {
            reqwest_vcr::VCRMode::Record
        };
        let vcr = reqwest_vcr::VCRMiddleware::try_from(path_to_cassette)
            .unwrap()
            .with_mode(vcr_mode)
            // Overwrite the bearer token in requests seen by the VCR so real credentials are
            // (presumably) never written into a cassette.
            .with_modify_request(|recorded| {
                if let Some(values) = recorded.headers.get_mut("authorization") {
                    *values = vec!["TOKEN_REMOVED".to_owned()];
                }
            });
        self.with_middleware(vcr)
    }

    /// Finalizes the middleware stack and produces the [`IamClient`].
    pub fn build(self) -> IamClient {
        IamClient {
            http_client: self.client_builder.build(),
            base_url: self.base_url,
        }
    }
}
/// Client for the IAM `check_user` endpoint.
///
/// Construct it with [`IamClient::new`] or, to add middleware first, via [`IamClient::builder`].
#[derive(Clone, Debug)]
pub struct IamClient {
    /// Base URL of the IAM service; `/check_user` is appended to it per request.
    base_url: String,
    /// HTTP client including any middleware configured on the builder.
    http_client: ClientWithMiddleware,
}
impl IamClient {
    /// Returns an [`IamClientBuilder`] for `base_url`, allowing middleware to be added before
    /// the client is built.
    pub fn builder(base_url: String) -> IamClientBuilder {
        IamClientBuilder::new(base_url)
    }

    /// Creates a client for the IAM service at `base_url` without extra middleware.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be initialized (see
    /// [`IamClientBuilder::new`]).
    pub fn new(base_url: String) -> Self {
        Self::builder(base_url).build()
    }

    /// Test-only constructor that records to / replays from the VCR cassette at
    /// `path_to_cassette`.
    #[cfg(test)]
    pub fn with_vcr(base_url: String, path_to_cassette: std::path::PathBuf) -> Self {
        Self::builder(base_url).with_vcr(path_to_cassette).build()
    }

    /// Asks IAM which of `permissions` the owner of `token` holds.
    ///
    /// Sends `POST {base_url}/check_user` with `{"permissions": [...]}` as JSON body and `token`
    /// as bearer authentication, and deserializes the response into [`UserInfoAndPermissions`].
    ///
    /// # Errors
    ///
    /// * [`CheckUserError::Unauthenticated`] if IAM answers `401 Unauthorized`.
    /// * [`CheckUserError::ConnectionError`] if the request cannot be sent, IAM answers with any
    ///   other error status (for `422 Unprocessable Entity` the response body is included in the
    ///   error), or the response body cannot be deserialized.
    pub async fn check_user<'a>(
        &self,
        token: impl Display,
        permissions: &'a [Permission<'a>],
    ) -> Result<UserInfoAndPermissions, CheckUserError> {
        let request_body = CheckUserRequestBody { permissions };
        // Tolerate a trailing slash in the configured base URL so we never request `//check_user`.
        let url = format!("{}/check_user", self.base_url.trim_end_matches('/'));
        let response = self
            .http_client
            .post(url)
            .bearer_auth(token)
            .json(&request_body)
            .send()
            .await
            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
        if response.status() == StatusCode::UNAUTHORIZED {
            return Err(CheckUserError::Unauthenticated);
        }
        if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
            use anyhow::anyhow;
            // Carry the server's validation message inside the error instead of printing it to
            // stderr; failing to read the body must not panic.
            let detail = match response.text().await {
                Ok(body) if !body.is_empty() => format!(": {body}"),
                Ok(_) => String::new(),
                Err(e) => format!(" (failed to read response body: {e})"),
            };
            return Err(CheckUserError::ConnectionError(anyhow!(
                "Unprocessable entity{detail}"
            )));
        }
        response
            .error_for_status_ref()
            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
        let user_info = response
            .json()
            .await
            .map_err(|e| CheckUserError::ConnectionError(e.into()))?;
        Ok(user_info)
    }

    /// Checks that the owner of `token` holds **all** of `permissions`.
    ///
    /// On success the user info reported by IAM is returned.
    ///
    /// # Errors
    ///
    /// * [`AuthorizationError::Unauthorized`] if any requested permission is missing from the
    ///   permissions IAM reports as granted.
    /// * [`AuthorizationError::Unauthenticated`] / [`AuthorizationError::ConnectionError`] in the
    ///   same situations as the corresponding [`CheckUserError`] variants of
    ///   [`IamClient::check_user`].
    pub async fn authorize<'a>(
        &self,
        token: impl Display,
        permissions: &'a [Permission<'a>],
    ) -> Result<UserInfoAndPermissions, AuthorizationError> {
        let user_info = self.check_user(token, permissions).await?;
        // Order-insensitive subset check: authorization must not depend on the order (or
        // de-duplication) in which IAM lists the granted permissions.
        let all_granted = permissions.iter().all(|requested| {
            user_info
                .permissions
                .iter()
                .any(|granted| granted == requested)
        });
        if all_granted {
            Ok(user_info)
        } else {
            Err(AuthorizationError::Unauthorized)
        }
    }
}
/// JSON request body sent to `POST {base_url}/check_user`.
#[derive(Serialize)]
struct CheckUserRequestBody<'a> {
    /// Permissions IAM should check for the owner of the bearer token.
    permissions: &'a [Permission<'a>],
}
/// User information deserialized from the response of IAM's `check_user` endpoint.
#[derive(Deserialize, PartialEq, Eq, Debug)]
pub struct UserInfoAndPermissions {
    /// Subject identifier of the token's owner (presumably the OIDC `sub` claim).
    pub sub: String,
    /// E-mail address of the token's owner; `None` when IAM does not report one (as in the
    /// recorded service-token response).
    pub email: Option<String>,
    /// Whether the e-mail address is verified; `None` when IAM does not report it.
    pub email_verified: Option<bool>,
    /// Permissions IAM reports as granted.
    /// NOTE(review): presumably restricted to those requested in the `check_user` call —
    /// confirm against the IAM API.
    pub permissions: Vec<Permission<'static>>,
}
/// Errors returned by [`IamClient::check_user`].
#[derive(thiserror::Error, Debug)]
pub enum CheckUserError {
    /// IAM rejected the token with `401 Unauthorized`.
    #[error("User is Unauthenticated. Token is invalid")]
    Unauthenticated,
    /// Any other failure: the request could not be sent, IAM answered with another error
    /// status, or the response body could not be deserialized.
    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
    ConnectionError(#[source] anyhow::Error),
}
/// A permission that can be checked through IAM.
///
/// Serialized as an internally tagged enum: the variant name is stored in a `permission` field,
/// e.g. `{"permission": "KernelAccess"}` or `{"permission": "AccessModel", "model": "*"}`.
/// Borrowing via `Cow` lets request-side values avoid allocations, while deserialized values
/// are owned (`Permission<'static>`).
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone, Hash)]
#[serde(tag = "permission")]
pub enum Permission<'a> {
    AccessAssistant,
    NuminousAccess,
    KernelAccess,
    ExecuteJobs,
    /// Access to the model named `model`; the tests request `"*"`, presumably meaning all
    /// models — confirm with IAM.
    AccessModel {
        model: Cow<'a, str>,
    },
    /// NOTE(review): presumably a relationship check (`relation` on `object`) — confirm the
    /// exact semantics with IAM.
    HasRelation {
        relation: Cow<'a, str>,
        object: Cow<'a, str>,
    },
}
/// Errors returned by [`IamClient::authorize`].
#[derive(thiserror::Error, Debug)]
pub enum AuthorizationError {
    /// The token was rejected (IAM answered `401 Unauthorized`).
    #[error("User is Unauthenticated. Token is invalid")]
    Unauthenticated,
    /// The token was accepted, but IAM did not confirm the requested permissions.
    #[error("Unauthorized")]
    Unauthorized,
    /// The request could not be sent, IAM answered with another error status, or the response
    /// could not be deserialized.
    #[error("User could not be authenticated due to connectivity issue:\n{0:#}")]
    ConnectionError(#[source] anyhow::Error),
}
impl From<CheckUserError> for AuthorizationError {
fn from(err: CheckUserError) -> Self {
match err {
CheckUserError::Unauthenticated => AuthorizationError::Unauthenticated,
CheckUserError::ConnectionError(err) => AuthorizationError::ConnectionError(err),
}
}
}
#[cfg(test)]
mod tests {
    use dotenvy::dotenv;
    use std::{borrow::Cow, env, path::PathBuf};

    // Import everything through `super` so the tests do not depend on the path at which this
    // module is mounted in the crate.
    use super::{
        CheckUserError, IAM_PRODUCTION_URL, IAM_STAGE_URL, IamClient, Permission,
        UserInfoAndPermissions,
    };

    /// `sub` of the user whose token recorded the production user cassettes.
    const RECORDED_USER_SUB: &str = "295355180126307110";
    /// E-mail of the user whose token recorded the production user cassettes.
    const RECORDED_USER_EMAIL: &str = "markus.klein@aleph-alpha.com";

    #[tokio::test]
    async fn valid_user_token() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("valid_user_token.vcr.json"),
        );
        let response = client.check_user(token(), &[]).await.unwrap();
        assert_eq!(recorded_user(vec![]), response);
    }

    #[tokio::test]
    async fn invalid_user_token() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("invalid_user_token.vcr.json"),
        );
        let result = client.check_user("I-AM-AN-INVALID-TOKEN", &[]).await;
        assert!(matches!(result, Err(CheckUserError::Unauthenticated)));
    }

    #[tokio::test]
    async fn asking_for_permissions() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("asking_for_permissions.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];
        let response = client.check_user(token(), &permissions).await.unwrap();
        assert_eq!(recorded_user(permissions.to_vec()), response);
    }

    #[tokio::test]
    async fn authorize() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("authorize.vcr.json"),
        );
        let permissions = [
            Permission::KernelAccess,
            Permission::ExecuteJobs,
            Permission::AccessAssistant,
            Permission::NuminousAccess,
            Permission::AccessModel { model: "*".into() },
        ];
        let response = client.authorize(token(), &permissions).await.unwrap();
        assert_eq!(recorded_user(permissions.to_vec()), response);
    }

    #[tokio::test]
    async fn asking_for_permissions_as_service() {
        let client = IamClient::with_vcr(
            IAM_PRODUCTION_URL.to_owned(),
            cassette("asking_for_permissions_as_service.vcr.json"),
        );
        let permissions = [Permission::AccessAssistant, Permission::NuminousAccess];
        let response = client
            .check_user(service_token(), &permissions)
            .await
            .unwrap();
        // The recorded service account has no e-mail and was granted none of the permissions.
        let expected = UserInfoAndPermissions {
            sub: "336362361919115278".to_owned(),
            email: None,
            email_verified: None,
            permissions: vec![],
        };
        assert_eq!(expected, response);
    }

    #[tokio::test]
    async fn verify_predefined_permissions() {
        let client = IamClient::with_vcr(
            IAM_STAGE_URL.to_owned(),
            cassette("verify_predefined_permissions.vcr.json"),
        );
        let permissions = [
            Permission::AccessAssistant,
            Permission::ExecuteJobs,
            Permission::KernelAccess,
            Permission::NuminousAccess,
            Permission::AccessModel {
                model: Cow::Borrowed("*"),
            },
        ];
        let result = client
            .authorize(stage_non_admin_token(), &permissions)
            .await;
        assert!(
            result.is_ok(),
            "expected authorization to succeed, got {result:?}"
        );
    }

    /// Path of the VCR cassette `file_name` below `tests/cassettes` in the crate root.
    fn cassette(file_name: &str) -> PathBuf {
        let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.push("tests/cassettes");
        path.push(file_name);
        path
    }

    /// Expected user info of the recorded human user, holding exactly `permissions`.
    fn recorded_user(permissions: Vec<Permission<'static>>) -> UserInfoAndPermissions {
        UserInfoAndPermissions {
            sub: RECORDED_USER_SUB.to_owned(),
            email: Some(RECORDED_USER_EMAIL.to_owned()),
            email_verified: Some(true),
            permissions,
        }
    }

    fn service_token() -> String {
        env_token("PHARIA_AI_SERVICE_TOKEN")
    }

    fn token() -> String {
        env_token("PHARIA_AI_TOKEN")
    }

    fn stage_non_admin_token() -> String {
        env_token("PHARIA_STAGE_NON_ADMIN")
    }

    /// Reads the token in environment variable `var`, loading a `.env` file first if present.
    /// Falls back to `"DUMMY"` when the variable is unset (e.g. when only replaying cassettes).
    fn env_token(var: &str) -> String {
        _ = dotenv();
        env::var(var).unwrap_or_else(|_| "DUMMY".to_owned())
    }
}